gccrs: Create and use CompilePatternLet visitor for compiling let statements
[official-gcc.git] / gcc / rust / backend / rust-compile-pattern.cc
blob e13d6caf7e6ab6bc09873090bfbac4d9e5876799
1 // Copyright (C) 2020-2023 Free Software Foundation, Inc.
3 // This file is part of GCC.
5 // GCC is free software; you can redistribute it and/or modify it under
6 // the terms of the GNU General Public License as published by the Free
7 // Software Foundation; either version 3, or (at your option) any later
8 // version.
10 // GCC is distributed in the hope that it will be useful, but WITHOUT ANY
11 // WARRANTY; without even the implied warranty of MERCHANTABILITY or
12 // FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 // for more details.
15 // You should have received a copy of the GNU General Public License
16 // along with GCC; see the file COPYING3. If not see
17 // <http://www.gnu.org/licenses/>.
19 #include "rust-compile-pattern.h"
20 #include "rust-compile-expr.h"
21 #include "rust-compile-resolve-path.h"
22 #include "rust-constexpr.h"
24 namespace Rust {
25 namespace Compile {
27 void
28 CompilePatternCaseLabelExpr::visit (HIR::PathInExpression &pattern)
30 // lookup the type
31 TyTy::BaseType *lookup = nullptr;
32 bool ok
33 = ctx->get_tyctx ()->lookup_type (pattern.get_mappings ().get_hirid (),
34 &lookup);
35 rust_assert (ok);
37 // this must be an enum
38 rust_assert (lookup->get_kind () == TyTy::TypeKind::ADT);
39 TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (lookup);
40 rust_assert (adt->is_enum ());
42 // lookup the variant
43 HirId variant_id;
44 ok = ctx->get_tyctx ()->lookup_variant_definition (
45 pattern.get_mappings ().get_hirid (), &variant_id);
46 rust_assert (ok);
48 TyTy::VariantDef *variant = nullptr;
49 ok = adt->lookup_variant_by_id (variant_id, &variant);
50 rust_assert (ok);
52 HIR::Expr *discrim_expr = variant->get_discriminant ();
53 tree discrim_expr_node = CompileExpr::Compile (discrim_expr, ctx);
54 tree folded_discrim_expr = fold_expr (discrim_expr_node);
55 tree case_low = folded_discrim_expr;
57 case_label_expr
58 = build_case_label (case_low, NULL_TREE, associated_case_label);
61 void
62 CompilePatternCaseLabelExpr::visit (HIR::StructPattern &pattern)
64 CompilePatternCaseLabelExpr::visit (pattern.get_path ());
67 void
68 CompilePatternCaseLabelExpr::visit (HIR::TupleStructPattern &pattern)
70 CompilePatternCaseLabelExpr::visit (pattern.get_path ());
73 void
74 CompilePatternCaseLabelExpr::visit (HIR::WildcardPattern &pattern)
76 // operand 0 being NULL_TREE signifies this is the default case label see:
77 // tree.def for documentation for CASE_LABEL_EXPR
78 case_label_expr
79 = build_case_label (NULL_TREE, NULL_TREE, associated_case_label);
82 void
83 CompilePatternCaseLabelExpr::visit (HIR::LiteralPattern &pattern)
85 // Compile the literal
86 HIR::LiteralExpr *litexpr
87 = new HIR::LiteralExpr (pattern.get_pattern_mappings (),
88 pattern.get_literal (), pattern.get_locus (),
89 std::vector<AST::Attribute> ());
91 // Note: Floating point literals are currently accepted but will likely be
92 // forbidden in LiteralPatterns in a future version of Rust.
93 // See: https://github.com/rust-lang/rust/issues/41620
94 // For now, we cannot compile them anyway as CASE_LABEL_EXPR does not support
95 // floating point types.
96 if (pattern.get_literal ().get_lit_type () == HIR::Literal::LitType::FLOAT)
98 rust_sorry_at (pattern.get_locus (), "floating-point literal in pattern");
101 tree lit = CompileExpr::Compile (litexpr, ctx);
103 case_label_expr = build_case_label (lit, NULL_TREE, associated_case_label);
106 static tree
107 compile_range_pattern_bound (HIR::RangePatternBound *bound,
108 Analysis::NodeMapping mappings, Location locus,
109 Context *ctx)
111 tree result = NULL_TREE;
112 switch (bound->get_bound_type ())
114 case HIR::RangePatternBound::RangePatternBoundType::LITERAL: {
115 HIR::RangePatternBoundLiteral &ref
116 = *static_cast<HIR::RangePatternBoundLiteral *> (bound);
118 HIR::LiteralExpr *litexpr
119 = new HIR::LiteralExpr (mappings, ref.get_literal (), locus,
120 std::vector<AST::Attribute> ());
122 result = CompileExpr::Compile (litexpr, ctx);
124 break;
126 case HIR::RangePatternBound::RangePatternBoundType::PATH: {
127 HIR::RangePatternBoundPath &ref
128 = *static_cast<HIR::RangePatternBoundPath *> (bound);
130 result = ResolvePathRef::Compile (ref.get_path (), ctx);
132 // If the path resolves to a const expression, fold it.
133 result = fold_expr (result);
135 break;
137 case HIR::RangePatternBound::RangePatternBoundType::QUALPATH: {
138 HIR::RangePatternBoundQualPath &ref
139 = *static_cast<HIR::RangePatternBoundQualPath *> (bound);
141 result = ResolvePathRef::Compile (ref.get_qualified_path (), ctx);
143 // If the path resolves to a const expression, fold it.
144 result = fold_expr (result);
148 return result;
151 void
152 CompilePatternCaseLabelExpr::visit (HIR::RangePattern &pattern)
154 tree upper = compile_range_pattern_bound (pattern.get_upper_bound ().get (),
155 pattern.get_pattern_mappings (),
156 pattern.get_locus (), ctx);
157 tree lower = compile_range_pattern_bound (pattern.get_lower_bound ().get (),
158 pattern.get_pattern_mappings (),
159 pattern.get_locus (), ctx);
161 case_label_expr = build_case_label (lower, upper, associated_case_label);
164 void
165 CompilePatternCaseLabelExpr::visit (HIR::GroupedPattern &pattern)
167 pattern.get_item ()->accept_vis (*this);
170 // setup the bindings
172 void
173 CompilePatternBindings::visit (HIR::TupleStructPattern &pattern)
175 // lookup the type
176 TyTy::BaseType *lookup = nullptr;
177 bool ok = ctx->get_tyctx ()->lookup_type (
178 pattern.get_path ().get_mappings ().get_hirid (), &lookup);
179 rust_assert (ok);
181 // this must be an enum
182 rust_assert (lookup->get_kind () == TyTy::TypeKind::ADT);
183 TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (lookup);
184 rust_assert (adt->number_of_variants () > 0);
186 int variant_index = 0;
187 TyTy::VariantDef *variant = adt->get_variants ().at (0);
188 if (adt->is_enum ())
190 HirId variant_id = UNKNOWN_HIRID;
191 bool ok = ctx->get_tyctx ()->lookup_variant_definition (
192 pattern.get_path ().get_mappings ().get_hirid (), &variant_id);
193 rust_assert (ok);
195 ok = adt->lookup_variant_by_id (variant_id, &variant, &variant_index);
196 rust_assert (ok);
199 rust_assert (variant->get_variant_type ()
200 == TyTy::VariantDef::VariantType::TUPLE);
202 std::unique_ptr<HIR::TupleStructItems> &items = pattern.get_items ();
203 switch (items->get_item_type ())
205 case HIR::TupleStructItems::RANGE: {
206 // TODO
207 gcc_unreachable ();
209 break;
211 case HIR::TupleStructItems::NO_RANGE: {
212 HIR::TupleStructItemsNoRange &items_no_range
213 = static_cast<HIR::TupleStructItemsNoRange &> (*items.get ());
215 rust_assert (items_no_range.get_patterns ().size ()
216 == variant->num_fields ());
218 if (adt->is_enum ())
220 // we are offsetting by + 1 here since the first field in the record
221 // is always the discriminator
222 size_t tuple_field_index = 1;
223 for (auto &pattern : items_no_range.get_patterns ())
225 tree variant_accessor
226 = ctx->get_backend ()->struct_field_expression (
227 match_scrutinee_expr, variant_index, pattern->get_locus ());
229 tree binding = ctx->get_backend ()->struct_field_expression (
230 variant_accessor, tuple_field_index++, pattern->get_locus ());
232 ctx->insert_pattern_binding (
233 pattern->get_pattern_mappings ().get_hirid (), binding);
236 else
238 size_t tuple_field_index = 0;
239 for (auto &pattern : items_no_range.get_patterns ())
241 tree variant_accessor = match_scrutinee_expr;
243 tree binding = ctx->get_backend ()->struct_field_expression (
244 variant_accessor, tuple_field_index++, pattern->get_locus ());
246 ctx->insert_pattern_binding (
247 pattern->get_pattern_mappings ().get_hirid (), binding);
251 break;
255 void
256 CompilePatternBindings::visit (HIR::StructPattern &pattern)
258 // lookup the type
259 TyTy::BaseType *lookup = nullptr;
260 bool ok = ctx->get_tyctx ()->lookup_type (
261 pattern.get_path ().get_mappings ().get_hirid (), &lookup);
262 rust_assert (ok);
264 // this must be an enum
265 rust_assert (lookup->get_kind () == TyTy::TypeKind::ADT);
266 TyTy::ADTType *adt = static_cast<TyTy::ADTType *> (lookup);
267 rust_assert (adt->number_of_variants () > 0);
269 int variant_index = 0;
270 TyTy::VariantDef *variant = adt->get_variants ().at (0);
271 if (adt->is_enum ())
273 HirId variant_id = UNKNOWN_HIRID;
274 bool ok = ctx->get_tyctx ()->lookup_variant_definition (
275 pattern.get_path ().get_mappings ().get_hirid (), &variant_id);
276 rust_assert (ok);
278 ok = adt->lookup_variant_by_id (variant_id, &variant, &variant_index);
279 rust_assert (ok);
282 rust_assert (variant->get_variant_type ()
283 == TyTy::VariantDef::VariantType::STRUCT);
285 auto &struct_pattern_elems = pattern.get_struct_pattern_elems ();
286 for (auto &field : struct_pattern_elems.get_struct_pattern_fields ())
288 switch (field->get_item_type ())
290 case HIR::StructPatternField::ItemType::TUPLE_PAT: {
291 // TODO
292 gcc_unreachable ();
294 break;
296 case HIR::StructPatternField::ItemType::IDENT_PAT: {
297 // TODO
298 gcc_unreachable ();
300 break;
302 case HIR::StructPatternField::ItemType::IDENT: {
303 HIR::StructPatternFieldIdent &ident
304 = static_cast<HIR::StructPatternFieldIdent &> (*field.get ());
306 size_t offs = 0;
308 = variant->lookup_field (ident.get_identifier (), nullptr, &offs);
309 rust_assert (ok);
311 tree binding = error_mark_node;
312 if (adt->is_enum ())
314 tree variant_accessor
315 = ctx->get_backend ()->struct_field_expression (
316 match_scrutinee_expr, variant_index, ident.get_locus ());
318 // we are offsetting by + 1 here since the first field in the
319 // record is always the discriminator
320 binding = ctx->get_backend ()->struct_field_expression (
321 variant_accessor, offs + 1, ident.get_locus ());
323 else
325 tree variant_accessor = match_scrutinee_expr;
326 binding = ctx->get_backend ()->struct_field_expression (
327 variant_accessor, offs, ident.get_locus ());
330 ctx->insert_pattern_binding (ident.get_mappings ().get_hirid (),
331 binding);
333 break;
338 void
339 CompilePatternBindings::visit (HIR::GroupedPattern &pattern)
341 pattern.get_item ()->accept_vis (*this);
344 void
345 CompilePatternLet::visit (HIR::IdentifierPattern &pattern)
347 Bvariable *var = nullptr;
348 rust_assert (
349 ctx->lookup_var_decl (pattern.get_pattern_mappings ().get_hirid (), &var));
351 auto fnctx = ctx->peek_fn ();
352 if (ty->is_unit ())
354 ctx->add_statement (init_expr);
356 tree stmt_type = TyTyResolveCompile::compile (ctx, ty);
358 auto unit_type_init_expr
359 = ctx->get_backend ()->constructor_expression (stmt_type, false, {}, -1,
360 rval_locus);
361 auto s = ctx->get_backend ()->init_statement (fnctx.fndecl, var,
362 unit_type_init_expr);
363 ctx->add_statement (s);
365 else
367 auto s
368 = ctx->get_backend ()->init_statement (fnctx.fndecl, var, init_expr);
369 ctx->add_statement (s);
373 void
374 CompilePatternLet::visit (HIR::WildcardPattern &pattern)
376 Bvariable *var = nullptr;
377 rust_assert (
378 ctx->lookup_var_decl (pattern.get_pattern_mappings ().get_hirid (), &var));
380 auto fnctx = ctx->peek_fn ();
381 if (ty->is_unit ())
383 ctx->add_statement (init_expr);
385 tree stmt_type = TyTyResolveCompile::compile (ctx, ty);
387 auto unit_type_init_expr
388 = ctx->get_backend ()->constructor_expression (stmt_type, false, {}, -1,
389 rval_locus);
390 auto s = ctx->get_backend ()->init_statement (fnctx.fndecl, var,
391 unit_type_init_expr);
392 ctx->add_statement (s);
394 else
396 auto s
397 = ctx->get_backend ()->init_statement (fnctx.fndecl, var, init_expr);
398 ctx->add_statement (s);
402 } // namespace Compile
403 } // namespace Rust