Bug 1888590 - Mark some subtests on trusted-types-event-handlers.html as failing...
[gecko.git] / third_party / rust / async-trait / src / expand.rs
blob0a1ef1cd5a94778b05753879b7c9c1f3c68a8b30
1 use crate::bound::{has_bound, InferredBound, Supertraits};
2 use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes};
3 use crate::parse::Item;
4 use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
5 use crate::verbatim::VerbatimFn;
6 use proc_macro2::{Span, TokenStream};
7 use quote::{format_ident, quote, quote_spanned, ToTokens};
8 use std::collections::BTreeSet as Set;
9 use std::mem;
10 use syn::punctuated::Punctuated;
11 use syn::visit_mut::{self, VisitMut};
12 use syn::{
13     parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam,
14     Generics, Ident, ImplItem, Lifetime, LifetimeParam, Pat, PatIdent, PathArguments, Receiver,
15     ReturnType, Signature, Token, TraitItem, Type, TypePath, WhereClause,
18 impl ToTokens for Item {
19     fn to_tokens(&self, tokens: &mut TokenStream) {
20         match self {
21             Item::Trait(item) => item.to_tokens(tokens),
22             Item::Impl(item) => item.to_tokens(tokens),
23         }
24     }
27 #[derive(Clone, Copy)]
28 enum Context<'a> {
29     Trait {
30         generics: &'a Generics,
31         supertraits: &'a Supertraits,
32     },
33     Impl {
34         impl_generics: &'a Generics,
35         associated_type_impl_traits: &'a Set<Ident>,
36     },
39 impl Context<'_> {
40     fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam> {
41         let generics = match self {
42             Context::Trait { generics, .. } => generics,
43             Context::Impl { impl_generics, .. } => impl_generics,
44         };
45         generics.params.iter().filter_map(move |param| {
46             if let GenericParam::Lifetime(param) = param {
47                 if used.contains(&param.lifetime) {
48                     return Some(param);
49                 }
50             }
51             None
52         })
53     }
56 pub fn expand(input: &mut Item, is_local: bool) {
57     match input {
58         Item::Trait(input) => {
59             let context = Context::Trait {
60                 generics: &input.generics,
61                 supertraits: &input.supertraits,
62             };
63             for inner in &mut input.items {
64                 if let TraitItem::Fn(method) = inner {
65                     let sig = &mut method.sig;
66                     if sig.asyncness.is_some() {
67                         let block = &mut method.default;
68                         let mut has_self = has_self_in_sig(sig);
69                         method.attrs.push(parse_quote!(#[must_use]));
70                         if let Some(block) = block {
71                             has_self |= has_self_in_block(block);
72                             transform_block(context, sig, block);
73                             method.attrs.push(lint_suppress_with_body());
74                         } else {
75                             method.attrs.push(lint_suppress_without_body());
76                         }
77                         let has_default = method.default.is_some();
78                         transform_sig(context, sig, has_self, has_default, is_local);
79                     }
80                 }
81             }
82         }
83         Item::Impl(input) => {
84             let mut associated_type_impl_traits = Set::new();
85             for inner in &input.items {
86                 if let ImplItem::Type(assoc) = inner {
87                     if let Type::ImplTrait(_) = assoc.ty {
88                         associated_type_impl_traits.insert(assoc.ident.clone());
89                     }
90                 }
91             }
93             let context = Context::Impl {
94                 impl_generics: &input.generics,
95                 associated_type_impl_traits: &associated_type_impl_traits,
96             };
97             for inner in &mut input.items {
98                 match inner {
99                     ImplItem::Fn(method) if method.sig.asyncness.is_some() => {
100                         let sig = &mut method.sig;
101                         let block = &mut method.block;
102                         let has_self = has_self_in_sig(sig) || has_self_in_block(block);
103                         transform_block(context, sig, block);
104                         transform_sig(context, sig, has_self, false, is_local);
105                         method.attrs.push(lint_suppress_with_body());
106                     }
107                     ImplItem::Verbatim(tokens) => {
108                         let mut method = match syn::parse2::<VerbatimFn>(tokens.clone()) {
109                             Ok(method) if method.sig.asyncness.is_some() => method,
110                             _ => continue,
111                         };
112                         let sig = &mut method.sig;
113                         let has_self = has_self_in_sig(sig);
114                         transform_sig(context, sig, has_self, false, is_local);
115                         method.attrs.push(lint_suppress_with_body());
116                         *tokens = quote!(#method);
117                     }
118                     _ => {}
119                 }
120             }
121         }
122     }
125 fn lint_suppress_with_body() -> Attribute {
126     parse_quote! {
127         #[allow(
128             clippy::async_yields_async,
129             clippy::let_unit_value,
130             clippy::no_effect_underscore_binding,
131             clippy::shadow_same,
132             clippy::type_complexity,
133             clippy::type_repetition_in_bounds,
134             clippy::used_underscore_binding
135         )]
136     }
139 fn lint_suppress_without_body() -> Attribute {
140     parse_quote! {
141         #[allow(
142             clippy::type_complexity,
143             clippy::type_repetition_in_bounds
144         )]
145     }
148 // Input:
149 //     async fn f<T>(&self, x: &T) -> Ret;
151 // Output:
152 //     fn f<'life0, 'life1, 'async_trait, T>(
153 //         &'life0 self,
154 //         x: &'life1 T,
155 //     ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
156 //     where
157 //         'life0: 'async_trait,
158 //         'life1: 'async_trait,
159 //         T: 'async_trait,
160 //         Self: Sync + 'async_trait;
161 fn transform_sig(
162     context: Context,
163     sig: &mut Signature,
164     has_self: bool,
165     has_default: bool,
166     is_local: bool,
167 ) {
168     let default_span = sig.asyncness.take().unwrap().span;
169     sig.fn_token.span = default_span;
171     let (ret_arrow, ret) = match &sig.output {
172         ReturnType::Default => (Token![->](default_span), quote_spanned!(default_span=> ())),
173         ReturnType::Type(arrow, ret) => (*arrow, quote!(#ret)),
174     };
176     let mut lifetimes = CollectLifetimes::new();
177     for arg in sig.inputs.iter_mut() {
178         match arg {
179             FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
180             FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
181         }
182     }
184     for param in &mut sig.generics.params {
185         match param {
186             GenericParam::Type(param) => {
187                 let param_name = &param.ident;
188                 let span = match param.colon_token.take() {
189                     Some(colon_token) => colon_token.span,
190                     None => param_name.span(),
191                 };
192                 let bounds = mem::replace(&mut param.bounds, Punctuated::new());
193                 where_clause_or_default(&mut sig.generics.where_clause)
194                     .predicates
195                     .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds));
196             }
197             GenericParam::Lifetime(param) => {
198                 let param_name = &param.lifetime;
199                 let span = match param.colon_token.take() {
200                     Some(colon_token) => colon_token.span,
201                     None => param_name.span(),
202                 };
203                 let bounds = mem::replace(&mut param.bounds, Punctuated::new());
204                 where_clause_or_default(&mut sig.generics.where_clause)
205                     .predicates
206                     .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds));
207             }
208             GenericParam::Const(_) => {}
209         }
210     }
212     for param in context.lifetimes(&lifetimes.explicit) {
213         let param = &param.lifetime;
214         let span = param.span();
215         where_clause_or_default(&mut sig.generics.where_clause)
216             .predicates
217             .push(parse_quote_spanned!(span=> #param: 'async_trait));
218     }
220     if sig.generics.lt_token.is_none() {
221         sig.generics.lt_token = Some(Token![<](sig.ident.span()));
222     }
223     if sig.generics.gt_token.is_none() {
224         sig.generics.gt_token = Some(Token![>](sig.paren_token.span.join()));
225     }
227     for elided in lifetimes.elided {
228         sig.generics.params.push(parse_quote!(#elided));
229         where_clause_or_default(&mut sig.generics.where_clause)
230             .predicates
231             .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
232     }
234     sig.generics
235         .params
236         .push(parse_quote_spanned!(default_span=> 'async_trait));
238     if has_self {
239         let bounds: &[InferredBound] = if let Some(receiver) = sig.receiver() {
240             match receiver.ty.as_ref() {
241                 // self: &Self
242                 Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync],
243                 // self: Arc<Self>
244                 Type::Path(ty)
245                     if {
246                         let segment = ty.path.segments.last().unwrap();
247                         segment.ident == "Arc"
248                             && match &segment.arguments {
249                                 PathArguments::AngleBracketed(arguments) => {
250                                     arguments.args.len() == 1
251                                         && match &arguments.args[0] {
252                                             GenericArgument::Type(Type::Path(arg)) => {
253                                                 arg.path.is_ident("Self")
254                                             }
255                                             _ => false,
256                                         }
257                                 }
258                                 _ => false,
259                             }
260                     } =>
261                 {
262                     &[InferredBound::Sync, InferredBound::Send]
263                 }
264                 _ => &[InferredBound::Send],
265             }
266         } else {
267             &[InferredBound::Send]
268         };
270         let bounds = bounds.iter().filter_map(|bound| {
271             let assume_bound = match context {
272                 Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, bound),
273                 Context::Impl { .. } => true,
274             };
275             if assume_bound || is_local {
276                 None
277             } else {
278                 Some(bound.spanned_path(default_span))
279             }
280         });
282         where_clause_or_default(&mut sig.generics.where_clause)
283             .predicates
284             .push(parse_quote_spanned! {default_span=>
285                 Self: #(#bounds +)* 'async_trait
286             });
287     }
289     for (i, arg) in sig.inputs.iter_mut().enumerate() {
290         match arg {
291             FnArg::Receiver(receiver) => {
292                 if receiver.reference.is_none() {
293                     receiver.mutability = None;
294                 }
295             }
296             FnArg::Typed(arg) => {
297                 if match *arg.ty {
298                     Type::Reference(_) => false,
299                     _ => true,
300                 } {
301                     if let Pat::Ident(pat) = &mut *arg.pat {
302                         pat.by_ref = None;
303                         pat.mutability = None;
304                     } else {
305                         let positional = positional_arg(i, &arg.pat);
306                         let m = mut_pat(&mut arg.pat);
307                         arg.pat = parse_quote!(#m #positional);
308                     }
309                 }
310                 AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty);
311             }
312         }
313     }
315     let bounds = if is_local {
316         quote_spanned!(default_span=> 'async_trait)
317     } else {
318         quote_spanned!(default_span=> ::core::marker::Send + 'async_trait)
319     };
320     sig.output = parse_quote_spanned! {default_span=>
321         #ret_arrow ::core::pin::Pin<Box<
322             dyn ::core::future::Future<Output = #ret> + #bounds
323         >>
324     };
327 // Input:
328 //     async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
329 //         self + x + a + b
330 //     }
332 // Output:
333 //     Box::pin(async move {
334 //         let ___ret: Ret = {
335 //             let __self = self;
336 //             let x = x;
337 //             let (a, b) = __arg1;
339 //             __self + x + a + b
340 //         };
342 //         ___ret
343 //     })
344 fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
345     let mut self_span = None;
346     let decls = sig
347         .inputs
348         .iter()
349         .enumerate()
350         .map(|(i, arg)| match arg {
351             FnArg::Receiver(Receiver {
352                 self_token,
353                 mutability,
354                 ..
355             }) => {
356                 let ident = Ident::new("__self", self_token.span);
357                 self_span = Some(self_token.span);
358                 quote!(let #mutability #ident = #self_token;)
359             }
360             FnArg::Typed(arg) => {
361                 // If there is a #[cfg(...)] attribute that selectively enables
362                 // the parameter, forward it to the variable.
363                 //
364                 // This is currently not applied to the `self` parameter.
365                 let attrs = arg.attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
367                 if let Type::Reference(_) = *arg.ty {
368                     quote!()
369                 } else if let Pat::Ident(PatIdent {
370                     ident, mutability, ..
371                 }) = &*arg.pat
372                 {
373                     quote! {
374                         #(#attrs)*
375                         let #mutability #ident = #ident;
376                     }
377                 } else {
378                     let pat = &arg.pat;
379                     let ident = positional_arg(i, pat);
380                     if let Pat::Wild(_) = **pat {
381                         quote! {
382                             #(#attrs)*
383                             let #ident = #ident;
384                         }
385                     } else {
386                         quote! {
387                             #(#attrs)*
388                             let #pat = {
389                                 let #ident = #ident;
390                                 #ident
391                             };
392                         }
393                     }
394                 }
395             }
396         })
397         .collect::<Vec<_>>();
399     if let Some(span) = self_span {
400         let mut replace_self = ReplaceSelf(span);
401         replace_self.visit_block_mut(block);
402     }
404     let stmts = &block.stmts;
405     let let_ret = match &mut sig.output {
406         ReturnType::Default => quote_spanned! {block.brace_token.span=>
407             #(#decls)*
408             let _: () = { #(#stmts)* };
409         },
410         ReturnType::Type(_, ret) => {
411             if contains_associated_type_impl_trait(context, ret) {
412                 if decls.is_empty() {
413                     quote!(#(#stmts)*)
414                 } else {
415                     quote!(#(#decls)* { #(#stmts)* })
416                 }
417             } else {
418                 quote_spanned! {block.brace_token.span=>
419                     if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
420                         return __ret;
421                     }
422                     #(#decls)*
423                     let __ret: #ret = { #(#stmts)* };
424                     #[allow(unreachable_code)]
425                     __ret
426                 }
427             }
428         }
429     };
430     let box_pin = quote_spanned!(block.brace_token.span=>
431         Box::pin(async move { #let_ret })
432     );
433     block.stmts = parse_quote!(#box_pin);
436 fn positional_arg(i: usize, pat: &Pat) -> Ident {
437     let span: Span = syn::spanned::Spanned::span(pat);
438     #[cfg(not(no_span_mixed_site))]
439     let span = span.resolved_at(Span::mixed_site());
440     format_ident!("__arg{}", i, span = span)
443 fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
444     struct AssociatedTypeImplTraits<'a> {
445         set: &'a Set<Ident>,
446         contains: bool,
447     }
449     impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
450         fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
451             if ty.qself.is_none()
452                 && ty.path.segments.len() == 2
453                 && ty.path.segments[0].ident == "Self"
454                 && self.set.contains(&ty.path.segments[1].ident)
455             {
456                 self.contains = true;
457             }
458             visit_mut::visit_type_path_mut(self, ty);
459         }
460     }
462     match context {
463         Context::Trait { .. } => false,
464         Context::Impl {
465             associated_type_impl_traits,
466             ..
467         } => {
468             let mut visit = AssociatedTypeImplTraits {
469                 set: associated_type_impl_traits,
470                 contains: false,
471             };
472             visit.visit_type_mut(ret);
473             visit.contains
474         }
475     }
478 fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
479     clause.get_or_insert_with(|| WhereClause {
480         where_token: Default::default(),
481         predicates: Punctuated::new(),
482     })