diff --git a/substrate/frame/support/procedural/src/derive_impl.rs b/substrate/frame/support/procedural/src/derive_impl.rs index 54755f1163a1..69117c026816 100644 --- a/substrate/frame/support/procedural/src/derive_impl.rs +++ b/substrate/frame/support/procedural/src/derive_impl.rs @@ -17,13 +17,13 @@ //! Implementation of the `derive_impl` attribute macro. -use derive_syn_parse::Parse; use macro_magic::mm_core::ForeignPath; use proc_macro2::TokenStream as TokenStream2; use quote::{quote, ToTokens}; use std::collections::HashSet; use syn::{ - parse2, parse_quote, spanned::Spanned, token, Ident, ImplItem, ItemImpl, Path, Result, Token, + parse2, parse_quote, spanned::Spanned, token, AngleBracketedGenericArguments, Ident, ImplItem, + ItemImpl, Path, PathArguments, PathSegment, Result, Token, }; mod keyword { @@ -56,18 +56,60 @@ fn is_runtime_type(item: &syn::ImplItemType) -> bool { false }) } - -#[derive(Parse, Debug)] pub struct DeriveImplAttrArgs { pub default_impl_path: Path, + pub generics: Option, _as: Option, - #[parse_if(_as.is_some())] pub disambiguation_path: Option, _comma: Option, - #[parse_if(_comma.is_some())] pub no_aggregated_types: Option, } +impl syn::parse::Parse for DeriveImplAttrArgs { + fn parse(input: syn::parse::ParseStream) -> Result { + let mut default_impl_path: Path = input.parse()?; + // Extract the generics if any + let (default_impl_path, generics) = match default_impl_path.clone().segments.last() { + Some(PathSegment { ident, arguments: PathArguments::AngleBracketed(args) }) => { + default_impl_path.segments.pop(); + default_impl_path + .segments + .push(PathSegment { ident: ident.clone(), arguments: PathArguments::None }); + (default_impl_path, Some(args.clone())) + }, + Some(PathSegment { arguments: PathArguments::None, .. }) => (default_impl_path, None), + _ => return Err(syn::Error::new(default_impl_path.span(), "Invalid default impl path")), + }; + + let lookahead = input.lookahead1(); + let (_as, disambiguation_path) = if lookahead.peek(Token![as]) { + let _as: Token![as] = input.parse()?; + let disambiguation_path: Path = input.parse()?; + (Some(_as), Some(disambiguation_path)) + } else { + (None, None) + }; + + let lookahead = input.lookahead1(); + let (_comma, no_aggregated_types) = if lookahead.peek(Token![,]) { + let _comma: Token![,] = input.parse()?; + let no_aggregated_types: keyword::no_aggregated_types = input.parse()?; + (Some(_comma), Some(no_aggregated_types)) + } else { + (None, None) + }; + + Ok(DeriveImplAttrArgs { + default_impl_path, + generics, + _as, + disambiguation_path, + _comma, + no_aggregated_types, + }) + } +} + impl ForeignPath for DeriveImplAttrArgs { fn foreign_path(&self) -> &Path { &self.default_impl_path @@ -77,6 +119,7 @@ impl ForeignPath for DeriveImplAttrArgs { impl ToTokens for DeriveImplAttrArgs { fn to_tokens(&self, tokens: &mut TokenStream2) { tokens.extend(self.default_impl_path.to_token_stream()); + tokens.extend(self.generics.to_token_stream()); tokens.extend(self._as.to_token_stream()); tokens.extend(self.disambiguation_path.to_token_stream()); tokens.extend(self._comma.to_token_stream()); @@ -117,6 +160,7 @@ fn combine_impls( default_impl_path: Path, disambiguation_path: Path, inject_runtime_types: bool, + generics: Option, ) -> ItemImpl { let (existing_local_keys, existing_unsupported_items): (HashSet, HashSet) = local_impl @@ -155,7 +199,7 @@ fn combine_impls( // modify and insert uncolliding type items let modified_item: ImplItem = parse_quote! { #( #cfg_attrs )* - type #ident = <#default_impl_path as #disambiguation_path>::#ident; + type #ident = <#default_impl_path #generics as #disambiguation_path>::#ident; }; return Some(modified_item) } @@ -216,6 +260,7 @@ pub fn derive_impl( local_tokens: TokenStream2, disambiguation_path: Option, no_aggregated_types: Option, + generics: Option, ) -> Result { let local_impl = parse2::(local_tokens)?; let foreign_impl = parse2::(foreign_tokens)?; @@ -234,6 +279,7 @@ pub fn derive_impl( default_impl_path, disambiguation_path, no_aggregated_types.is_none(), + generics, ); Ok(quote!(#combined_impl)) @@ -301,3 +347,16 @@ fn test_disambiguation_path() { compute_disambiguation_path(None, foreign_impl.clone(), parse_quote!(SomeType)); assert_eq!(disambiguation_path.unwrap(), parse_quote!(SomeTrait)); } + +#[test] +fn test_derive_impl_attr_args_parsing_with_generic() { + let args = parse2::(quote!( + some::path::TestDefaultConfig as some::path::DefaultConfig + )) + .unwrap(); + assert_eq!(args.default_impl_path, parse_quote!(some::path::TestDefaultConfig)); + assert_eq!(args.generics.unwrap().args[0], parse_quote!(Config)); + let args = parse2::(quote!(TestDefaultConfig)).unwrap(); + assert_eq!(args.default_impl_path, parse_quote!(TestDefaultConfig)); + assert_eq!(args.generics.unwrap().args[0], parse_quote!(Config2)); +} diff --git a/substrate/frame/support/procedural/src/lib.rs b/substrate/frame/support/procedural/src/lib.rs index 8554a5b830de..d40a571c9eab 100644 --- a/substrate/frame/support/procedural/src/lib.rs +++ b/substrate/frame/support/procedural/src/lib.rs @@ -683,6 +683,7 @@ pub fn derive_impl(attrs: TokenStream, input: TokenStream) -> TokenStream { input.into(), custom_attrs.disambiguation_path, custom_attrs.no_aggregated_types, + custom_attrs.generics, ) .unwrap_or_else(|r| r.into_compile_error()) .into()