gccrs: Clear the substitution callbacks when copying ArgumentMappings
[official-gcc.git] / gcc / rust / typecheck / rust-tyty-subst.cc
blob0e181efb359e64479ef83238351823e50a3b4c1d
1 // Copyright (C) 2020-2022 Free Software Foundation, Inc.
3 // This file is part of GCC.
5 // GCC is free software; you can redistribute it and/or modify it under
6 // the terms of the GNU General Public License as published by the Free
7 // Software Foundation; either version 3, or (at your option) any later
8 // version.
10 // GCC is distributed in the hope that it will be useful, but WITHOUT ANY
11 // WARRANTY; without even the implied warranty of MERCHANTABILITY or
12 // FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 // for more details.
15 // You should have received a copy of the GNU General Public License
16 // along with GCC; see the file COPYING3. If not see
17 // <http://www.gnu.org/licenses/>.
19 #include "rust-tyty-subst.h"
20 #include "rust-hir-full.h"
21 #include "rust-tyty.h"
22 #include "rust-hir-type-check.h"
23 #include "rust-substitution-mapper.h"
24 #include "rust-hir-type-check-type.h"
26 namespace Rust {
27 namespace TyTy {
29 SubstitutionParamMapping::SubstitutionParamMapping (
30 const HIR::TypeParam &generic, ParamType *param)
31 : generic (generic), param (param)
34 SubstitutionParamMapping::SubstitutionParamMapping (
35 const SubstitutionParamMapping &other)
36 : generic (other.generic), param (other.param)
39 std::string
40 SubstitutionParamMapping::as_string () const
42 if (param == nullptr)
43 return "nullptr";
45 return param->get_name ();
48 SubstitutionParamMapping
49 SubstitutionParamMapping::clone () const
51 return SubstitutionParamMapping (generic,
52 static_cast<ParamType *> (param->clone ()));
55 ParamType *
56 SubstitutionParamMapping::get_param_ty ()
58 return param;
61 const ParamType *
62 SubstitutionParamMapping::get_param_ty () const
64 return param;
67 const HIR::TypeParam &
68 SubstitutionParamMapping::get_generic_param ()
70 return generic;
73 bool
74 SubstitutionParamMapping::needs_substitution () const
76 return !(get_param_ty ()->is_concrete ());
79 Location
80 SubstitutionParamMapping::get_param_locus () const
82 return generic.get_locus ();
85 bool
86 SubstitutionParamMapping::param_has_default_ty () const
88 return generic.has_type ();
91 BaseType *
92 SubstitutionParamMapping::get_default_ty () const
94 TyVar var (generic.get_type_mappings ().get_hirid ());
95 return var.get_tyty ();
98 bool
99 SubstitutionParamMapping::need_substitution () const
101 if (!param->can_resolve ())
102 return true;
104 auto resolved = param->resolve ();
105 return !resolved->is_concrete ();
108 bool
109 SubstitutionParamMapping::fill_param_ty (
110 SubstitutionArgumentMappings &subst_mappings, Location locus)
112 SubstitutionArg arg = SubstitutionArg::error ();
113 bool ok = subst_mappings.get_argument_for_symbol (get_param_ty (), &arg);
114 if (!ok)
115 return true;
117 TyTy::BaseType &type = *arg.get_tyty ();
118 if (type.get_kind () == TyTy::TypeKind::INFER)
120 type.inherit_bounds (*param);
122 else
124 if (!param->bounds_compatible (type, locus, true))
125 return false;
128 if (type.get_kind () == TypeKind::PARAM)
130 // delete param;
131 param = static_cast<ParamType *> (type.clone ());
133 else
135 // check the substitution is compatible with bounds
136 if (!param->bounds_compatible (type, locus, true))
137 return false;
139 // recursively pass this down to all HRTB's
140 for (auto &bound : param->get_specified_bounds ())
141 bound.handle_substitions (subst_mappings);
143 param->set_ty_ref (type.get_ref ());
146 return true;
149 void
150 SubstitutionParamMapping::override_context ()
152 if (!param->can_resolve ())
153 return;
155 auto mappings = Analysis::Mappings::get ();
156 auto context = Resolver::TypeCheckContext::get ();
158 context->insert_type (Analysis::NodeMapping (mappings->get_current_crate (),
159 UNKNOWN_NODEID,
160 param->get_ref (),
161 UNKNOWN_LOCAL_DEFID),
162 param->resolve ());
165 SubstitutionArg::SubstitutionArg (const SubstitutionParamMapping *param,
166 BaseType *argument)
167 : param (param), argument (argument)
170 SubstitutionArg::SubstitutionArg (const SubstitutionArg &other)
171 : param (other.param), argument (other.argument)
174 SubstitutionArg &
175 SubstitutionArg::operator= (const SubstitutionArg &other)
177 param = other.param;
178 argument = other.argument;
179 return *this;
182 BaseType *
183 SubstitutionArg::get_tyty ()
185 return argument;
188 const BaseType *
189 SubstitutionArg::get_tyty () const
191 return argument;
194 const SubstitutionParamMapping *
195 SubstitutionArg::get_param_mapping () const
197 return param;
200 SubstitutionArg
201 SubstitutionArg::error ()
203 return SubstitutionArg (nullptr, nullptr);
206 bool
207 SubstitutionArg::is_error () const
209 return param == nullptr || argument == nullptr;
212 bool
213 SubstitutionArg::is_conrete () const
215 if (argument != nullptr)
216 return true;
218 if (argument->get_kind () == TyTy::TypeKind::PARAM)
219 return false;
221 return argument->is_concrete ();
224 std::string
225 SubstitutionArg::as_string () const
227 return param->as_string ()
228 + (argument != nullptr ? ":" + argument->as_string () : "");
231 // SubstitutionArgumentMappings
233 SubstitutionArgumentMappings::SubstitutionArgumentMappings (
234 std::vector<SubstitutionArg> mappings,
235 std::map<std::string, BaseType *> binding_args, Location locus,
236 ParamSubstCb param_subst_cb, bool trait_item_flag)
237 : mappings (mappings), binding_args (binding_args), locus (locus),
238 param_subst_cb (param_subst_cb), trait_item_flag (trait_item_flag)
241 SubstitutionArgumentMappings::SubstitutionArgumentMappings (
242 const SubstitutionArgumentMappings &other)
243 : mappings (other.mappings), binding_args (other.binding_args),
244 locus (other.locus), param_subst_cb (nullptr),
245 trait_item_flag (other.trait_item_flag)
248 SubstitutionArgumentMappings &
249 SubstitutionArgumentMappings::operator= (
250 const SubstitutionArgumentMappings &other)
252 mappings = other.mappings;
253 binding_args = other.binding_args;
254 locus = other.locus;
255 param_subst_cb = nullptr;
256 trait_item_flag = other.trait_item_flag;
258 return *this;
261 SubstitutionArgumentMappings
262 SubstitutionArgumentMappings::error ()
264 return SubstitutionArgumentMappings ({}, {}, Location (), nullptr, false);
267 bool
268 SubstitutionArgumentMappings::is_error () const
270 return mappings.size () == 0;
273 bool
274 SubstitutionArgumentMappings::get_argument_for_symbol (
275 const ParamType *param_to_find, SubstitutionArg *argument)
277 for (auto &mapping : mappings)
279 const SubstitutionParamMapping *param = mapping.get_param_mapping ();
280 const ParamType *p = param->get_param_ty ();
282 if (p->get_symbol ().compare (param_to_find->get_symbol ()) == 0)
284 *argument = mapping;
285 return true;
288 return false;
291 bool
292 SubstitutionArgumentMappings::get_argument_at (size_t index,
293 SubstitutionArg *argument)
295 if (index > mappings.size ())
296 return false;
298 *argument = mappings.at (index);
299 return true;
302 bool
303 SubstitutionArgumentMappings::is_concrete () const
305 for (auto &mapping : mappings)
307 if (!mapping.is_conrete ())
308 return false;
310 return true;
313 Location
314 SubstitutionArgumentMappings::get_locus () const
316 return locus;
319 size_t
320 SubstitutionArgumentMappings::size () const
322 return mappings.size ();
325 bool
326 SubstitutionArgumentMappings::is_empty () const
328 return size () == 0;
331 std::vector<SubstitutionArg> &
332 SubstitutionArgumentMappings::get_mappings ()
334 return mappings;
337 const std::vector<SubstitutionArg> &
338 SubstitutionArgumentMappings::get_mappings () const
340 return mappings;
343 std::map<std::string, BaseType *> &
344 SubstitutionArgumentMappings::get_binding_args ()
346 return binding_args;
349 const std::map<std::string, BaseType *> &
350 SubstitutionArgumentMappings::get_binding_args () const
352 return binding_args;
355 std::string
356 SubstitutionArgumentMappings::as_string () const
358 std::string buffer;
359 for (auto &mapping : mappings)
361 buffer += mapping.as_string () + ", ";
363 return "<" + buffer + ">";
366 void
367 SubstitutionArgumentMappings::on_param_subst (const ParamType &p,
368 const SubstitutionArg &a) const
370 if (param_subst_cb == nullptr)
371 return;
373 param_subst_cb (p, a);
376 ParamSubstCb
377 SubstitutionArgumentMappings::get_subst_cb () const
379 return param_subst_cb;
382 bool
383 SubstitutionArgumentMappings::trait_item_mode () const
385 return trait_item_flag;
388 // SubstitutionRef
390 SubstitutionRef::SubstitutionRef (
391 std::vector<SubstitutionParamMapping> substitutions,
392 SubstitutionArgumentMappings arguments)
393 : substitutions (substitutions), used_arguments (arguments)
396 bool
397 SubstitutionRef::has_substitutions () const
399 return substitutions.size () > 0;
402 std::string
403 SubstitutionRef::subst_as_string () const
405 std::string buffer;
406 for (size_t i = 0; i < substitutions.size (); i++)
408 const SubstitutionParamMapping &sub = substitutions.at (i);
409 buffer += sub.as_string ();
411 if ((i + 1) < substitutions.size ())
412 buffer += ", ";
415 return buffer.empty () ? "" : "<" + buffer + ">";
418 bool
419 SubstitutionRef::supports_associated_bindings () const
421 return get_num_associated_bindings () > 0;
424 size_t
425 SubstitutionRef::get_num_associated_bindings () const
427 return 0;
430 TypeBoundPredicateItem
431 SubstitutionRef::lookup_associated_type (const std::string &search)
433 return TypeBoundPredicateItem::error ();
436 size_t
437 SubstitutionRef::get_num_substitutions () const
439 return substitutions.size ();
442 std::vector<SubstitutionParamMapping> &
443 SubstitutionRef::get_substs ()
445 return substitutions;
448 const std::vector<SubstitutionParamMapping> &
449 SubstitutionRef::get_substs () const
451 return substitutions;
454 std::vector<SubstitutionParamMapping>
455 SubstitutionRef::clone_substs () const
457 std::vector<SubstitutionParamMapping> clone;
459 for (auto &sub : substitutions)
460 clone.push_back (sub.clone ());
462 return clone;
465 void
466 SubstitutionRef::override_context ()
468 for (auto &sub : substitutions)
470 sub.override_context ();
474 bool
475 SubstitutionRef::needs_substitution () const
477 for (auto &sub : substitutions)
479 if (sub.need_substitution ())
480 return true;
482 return false;
485 bool
486 SubstitutionRef::was_substituted () const
488 return !needs_substitution ();
491 SubstitutionArgumentMappings &
492 SubstitutionRef::get_substitution_arguments ()
494 return used_arguments;
497 const SubstitutionArgumentMappings &
498 SubstitutionRef::get_substitution_arguments () const
500 return used_arguments;
503 size_t
504 SubstitutionRef::num_required_substitutions () const
506 size_t n = 0;
507 for (auto &p : substitutions)
509 if (p.needs_substitution ())
510 n++;
512 return n;
515 size_t
516 SubstitutionRef::min_required_substitutions () const
518 size_t n = 0;
519 for (auto &p : substitutions)
521 if (p.needs_substitution () && !p.param_has_default_ty ())
522 n++;
524 return n;
527 SubstitutionArgumentMappings
528 SubstitutionRef::get_used_arguments () const
530 return used_arguments;
533 SubstitutionArgumentMappings
534 SubstitutionRef::get_mappings_from_generic_args (HIR::GenericArgs &args)
536 std::map<std::string, BaseType *> binding_arguments;
537 if (args.get_binding_args ().size () > 0)
539 if (supports_associated_bindings ())
541 if (args.get_binding_args ().size () > get_num_associated_bindings ())
543 RichLocation r (args.get_locus ());
545 rust_error_at (r,
546 "generic item takes at most %lu type binding "
547 "arguments but %lu were supplied",
548 (unsigned long) get_num_associated_bindings (),
549 (unsigned long) args.get_binding_args ().size ());
550 return SubstitutionArgumentMappings::error ();
553 for (auto &binding : args.get_binding_args ())
555 BaseType *resolved
556 = Resolver::TypeCheckType::Resolve (binding.get_type ().get ());
557 if (resolved == nullptr
558 || resolved->get_kind () == TyTy::TypeKind::ERROR)
560 rust_error_at (binding.get_locus (),
561 "failed to resolve type arguments");
562 return SubstitutionArgumentMappings::error ();
565 // resolve to relevant binding
566 auto binding_item
567 = lookup_associated_type (binding.get_identifier ());
568 if (binding_item.is_error ())
570 rust_error_at (binding.get_locus (),
571 "unknown associated type binding: %s",
572 binding.get_identifier ().c_str ());
573 return SubstitutionArgumentMappings::error ();
576 binding_arguments[binding.get_identifier ()] = resolved;
579 else
581 RichLocation r (args.get_locus ());
582 for (auto &binding : args.get_binding_args ())
583 r.add_range (binding.get_locus ());
585 rust_error_at (r, "associated type bindings are not allowed here");
586 return SubstitutionArgumentMappings::error ();
590 // for inherited arguments
591 size_t offs = used_arguments.size ();
592 if (args.get_type_args ().size () + offs > substitutions.size ())
594 RichLocation r (args.get_locus ());
595 r.add_range (substitutions.front ().get_param_locus ());
597 rust_error_at (
599 "generic item takes at most %lu type arguments but %lu were supplied",
600 (unsigned long) substitutions.size (),
601 (unsigned long) args.get_type_args ().size ());
602 return SubstitutionArgumentMappings::error ();
605 if (args.get_type_args ().size () + offs < min_required_substitutions ())
607 RichLocation r (args.get_locus ());
608 r.add_range (substitutions.front ().get_param_locus ());
610 rust_error_at (
612 "generic item takes at least %lu type arguments but %lu were supplied",
613 (unsigned long) (min_required_substitutions () - offs),
614 (unsigned long) args.get_type_args ().size ());
615 return SubstitutionArgumentMappings::error ();
618 std::vector<SubstitutionArg> mappings = used_arguments.get_mappings ();
619 for (auto &arg : args.get_type_args ())
621 BaseType *resolved = Resolver::TypeCheckType::Resolve (arg.get ());
622 if (resolved == nullptr || resolved->get_kind () == TyTy::TypeKind::ERROR)
624 rust_error_at (args.get_locus (), "failed to resolve type arguments");
625 return SubstitutionArgumentMappings::error ();
628 SubstitutionArg subst_arg (&substitutions.at (offs), resolved);
629 offs++;
630 mappings.push_back (std::move (subst_arg));
633 // we must need to fill out defaults
634 size_t left_over
635 = num_required_substitutions () - min_required_substitutions ();
636 if (left_over > 0)
638 for (size_t offs = mappings.size (); offs < substitutions.size (); offs++)
640 SubstitutionParamMapping &param = substitutions.at (offs);
641 rust_assert (param.param_has_default_ty ());
643 BaseType *resolved = param.get_default_ty ();
644 if (resolved->get_kind () == TypeKind::ERROR)
645 return SubstitutionArgumentMappings::error ();
647 // this resolved default might already contain default parameters
648 if (resolved->contains_type_parameters ())
650 SubstitutionArgumentMappings intermediate (mappings,
651 binding_arguments,
652 args.get_locus ());
653 resolved = Resolver::SubstMapperInternal::Resolve (resolved,
654 intermediate);
656 if (resolved->get_kind () == TypeKind::ERROR)
657 return SubstitutionArgumentMappings::error ();
660 SubstitutionArg subst_arg (&param, resolved);
661 mappings.push_back (std::move (subst_arg));
665 return SubstitutionArgumentMappings (mappings, binding_arguments,
666 args.get_locus ());
669 BaseType *
670 SubstitutionRef::infer_substitions (Location locus)
672 std::vector<SubstitutionArg> args;
673 std::map<std::string, BaseType *> argument_mappings;
674 for (auto &p : get_substs ())
676 if (p.needs_substitution ())
678 const std::string &symbol = p.get_param_ty ()->get_symbol ();
679 auto it = argument_mappings.find (symbol);
680 bool have_mapping = it != argument_mappings.end ();
682 if (have_mapping)
684 args.push_back (SubstitutionArg (&p, it->second));
686 else
688 TyVar infer_var = TyVar::get_implicit_infer_var (locus);
689 args.push_back (SubstitutionArg (&p, infer_var.get_tyty ()));
690 argument_mappings[symbol] = infer_var.get_tyty ();
693 else
695 args.push_back (SubstitutionArg (&p, p.get_param_ty ()->resolve ()));
699 // FIXME do we need to add inference variables to all the possible bindings?
700 // it might just lead to inference variable hell not 100% sure if rustc does
701 // this i think the language might needs this to be explicitly set
703 SubstitutionArgumentMappings infer_arguments (std::move (args),
704 {} /* binding_arguments */,
705 locus);
706 return handle_substitions (infer_arguments);
709 SubstitutionArgumentMappings
710 SubstitutionRef::adjust_mappings_for_this (
711 SubstitutionArgumentMappings &mappings)
713 std::vector<SubstitutionArg> resolved_mappings;
714 for (size_t i = 0; i < substitutions.size (); i++)
716 auto &subst = substitutions.at (i);
718 SubstitutionArg arg = SubstitutionArg::error ();
719 if (mappings.size () == substitutions.size ())
721 mappings.get_argument_at (i, &arg);
723 else
725 if (subst.needs_substitution ())
727 // get from passed in mappings
728 mappings.get_argument_for_symbol (subst.get_param_ty (), &arg);
730 else
732 // we should already have this somewhere
733 used_arguments.get_argument_for_symbol (subst.get_param_ty (),
734 &arg);
738 bool ok = !arg.is_error ();
739 if (ok)
741 SubstitutionArg adjusted (&subst, arg.get_tyty ());
742 resolved_mappings.push_back (std::move (adjusted));
746 if (resolved_mappings.empty ())
747 return SubstitutionArgumentMappings::error ();
749 return SubstitutionArgumentMappings (resolved_mappings,
750 mappings.get_binding_args (),
751 mappings.get_locus (),
752 mappings.get_subst_cb (),
753 mappings.trait_item_mode ());
756 bool
757 SubstitutionRef::are_mappings_bound (SubstitutionArgumentMappings &mappings)
759 std::vector<SubstitutionArg> resolved_mappings;
760 for (size_t i = 0; i < substitutions.size (); i++)
762 auto &subst = substitutions.at (i);
764 SubstitutionArg arg = SubstitutionArg::error ();
765 if (mappings.size () == substitutions.size ())
767 mappings.get_argument_at (i, &arg);
769 else
771 if (subst.needs_substitution ())
773 // get from passed in mappings
774 mappings.get_argument_for_symbol (subst.get_param_ty (), &arg);
776 else
778 // we should already have this somewhere
779 used_arguments.get_argument_for_symbol (subst.get_param_ty (),
780 &arg);
784 bool ok = !arg.is_error ();
785 if (ok)
787 SubstitutionArg adjusted (&subst, arg.get_tyty ());
788 resolved_mappings.push_back (std::move (adjusted));
792 return !resolved_mappings.empty ();
795 // this function assumes that the mappings being passed are for the same type as
796 // this new substitution reference so ordering matters here
797 SubstitutionArgumentMappings
798 SubstitutionRef::solve_mappings_from_receiver_for_self (
799 SubstitutionArgumentMappings &mappings) const
801 std::vector<SubstitutionArg> resolved_mappings;
803 rust_assert (mappings.size () == get_num_substitutions ());
804 for (size_t i = 0; i < get_num_substitutions (); i++)
806 const SubstitutionParamMapping &param_mapping = substitutions.at (i);
807 SubstitutionArg &arg = mappings.get_mappings ().at (i);
809 if (param_mapping.needs_substitution ())
811 SubstitutionArg adjusted (&param_mapping, arg.get_tyty ());
812 resolved_mappings.push_back (std::move (adjusted));
816 return SubstitutionArgumentMappings (resolved_mappings,
817 mappings.get_binding_args (),
818 mappings.get_locus ());
821 SubstitutionArgumentMappings
822 SubstitutionRef::solve_missing_mappings_from_this (SubstitutionRef &ref,
823 SubstitutionRef &to)
825 rust_assert (!ref.needs_substitution ());
826 rust_assert (needs_substitution ());
827 rust_assert (get_num_substitutions () == ref.get_num_substitutions ());
829 Location locus = used_arguments.get_locus ();
830 std::vector<SubstitutionArg> resolved_mappings;
832 std::map<HirId, std::pair<ParamType *, BaseType *>> substs;
833 for (size_t i = 0; i < get_num_substitutions (); i++)
835 SubstitutionParamMapping &a = substitutions.at (i);
836 SubstitutionParamMapping &b = ref.substitutions.at (i);
838 if (a.need_substitution ())
840 const BaseType *root = a.get_param_ty ()->resolve ()->get_root ();
841 rust_assert (root->get_kind () == TyTy::TypeKind::PARAM);
842 const ParamType *p = static_cast<const TyTy::ParamType *> (root);
844 substs[p->get_ty_ref ()] = {static_cast<ParamType *> (p->clone ()),
845 b.get_param_ty ()->resolve ()};
849 for (auto it = substs.begin (); it != substs.end (); it++)
851 HirId param_id = it->first;
852 BaseType *arg = it->second.second;
854 const SubstitutionParamMapping *associate_param = nullptr;
855 for (SubstitutionParamMapping &p : to.substitutions)
857 if (p.get_param_ty ()->get_ty_ref () == param_id)
859 associate_param = &p;
860 break;
864 rust_assert (associate_param != nullptr);
865 SubstitutionArg argument (associate_param, arg);
866 resolved_mappings.push_back (std::move (argument));
869 return SubstitutionArgumentMappings (resolved_mappings, {}, locus);
872 bool
873 SubstitutionRef::monomorphize ()
875 auto context = Resolver::TypeCheckContext::get ();
876 for (const auto &subst : get_substs ())
878 const TyTy::ParamType *pty = subst.get_param_ty ();
880 if (!pty->can_resolve ())
881 continue;
883 const TyTy::BaseType *binding = pty->resolve ();
884 if (binding->get_kind () == TyTy::TypeKind::PARAM)
885 continue;
887 for (const auto &bound : pty->get_specified_bounds ())
889 const Resolver::TraitReference *specified_bound_ref = bound.get ();
891 // setup any associated type mappings for the specified bonds and this
892 // type
893 auto candidates = Resolver::TypeBoundsProbe::Probe (binding);
895 Resolver::AssociatedImplTrait *associated_impl_trait = nullptr;
896 for (auto &probed_bound : candidates)
898 const Resolver::TraitReference *bound_trait_ref
899 = probed_bound.first;
900 const HIR::ImplBlock *associated_impl = probed_bound.second;
902 HirId impl_block_id
903 = associated_impl->get_mappings ().get_hirid ();
904 Resolver::AssociatedImplTrait *associated = nullptr;
905 bool found_impl_trait
906 = context->lookup_associated_trait_impl (impl_block_id,
907 &associated);
908 if (found_impl_trait)
910 bool found_trait
911 = specified_bound_ref->is_equal (*bound_trait_ref);
912 bool found_self
913 = associated->get_self ()->can_eq (binding, false);
914 if (found_trait && found_self)
916 associated_impl_trait = associated;
917 break;
922 if (associated_impl_trait != nullptr)
924 associated_impl_trait->setup_associated_types (binding, bound);
929 return true;
932 } // namespace TyTy
933 } // namespace Rust