Import parity-scale-codec-derive 3.1.4 (refs: upstream, upstream/3.1.4)
author    DongHun Kwak <dh0128.kwak@samsung.com>
          Thu, 23 Mar 2023 06:07:00 +0000 (15:07 +0900)
committer DongHun Kwak <dh0128.kwak@samsung.com>
          Thu, 23 Mar 2023 06:07:00 +0000 (15:07 +0900)
.cargo_vcs_info.json [new file with mode: 0644]
Cargo.toml [new file with mode: 0644]
Cargo.toml.orig [new file with mode: 0644]
src/decode.rs [new file with mode: 0644]
src/encode.rs [new file with mode: 0644]
src/lib.rs [new file with mode: 0644]
src/max_encoded_len.rs [new file with mode: 0644]
src/trait_bounds.rs [new file with mode: 0644]
src/utils.rs [new file with mode: 0644]

diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json
new file mode 100644 (file)
index 0000000..5577d12
--- /dev/null
@@ -0,0 +1,6 @@
+{
+  "git": {
+    "sha1": "ef9f79438e9e51912b4445c6290de142b5d91e97"
+  },
+  "path_in_vcs": "derive"
+}
\ No newline at end of file
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644 (file)
index 0000000..0063905
--- /dev/null
@@ -0,0 +1,43 @@
+# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
+#
+# When uploading crates to the registry Cargo will automatically
+# "normalize" Cargo.toml files for maximal compatibility
+# with all versions of Cargo and also rewrite `path` dependencies
+# to registry (e.g., crates.io) dependencies.
+#
+# If you are reading this file be aware that the original Cargo.toml
+# will likely look very different (and much more reasonable).
+# See Cargo.toml.orig for the original contents.
+
+[package]
+edition = "2021"
+rust-version = "1.56.1"
+name = "parity-scale-codec-derive"
+version = "3.1.4"
+authors = ["Parity Technologies <admin@parity.io>"]
+description = "Serialization and deserialization derive macro for Parity SCALE Codec"
+license = "Apache-2.0"
+
+[lib]
+proc-macro = true
+
+[dependencies.proc-macro-crate]
+version = "1.0.0"
+
+[dependencies.proc-macro2]
+version = "1.0.6"
+
+[dependencies.quote]
+version = "1.0.2"
+
+[dependencies.syn]
+version = "1.0.98"
+features = [
+    "full",
+    "visit",
+]
+
+[dev-dependencies]
+
+[features]
+max-encoded-len = []
diff --git a/Cargo.toml.orig b/Cargo.toml.orig
new file mode 100644 (file)
index 0000000..6fb552b
--- /dev/null
@@ -0,0 +1,26 @@
+[package]
+name = "parity-scale-codec-derive"
+description = "Serialization and deserialization derive macro for Parity SCALE Codec"
+version = "3.1.4"
+authors = ["Parity Technologies <admin@parity.io>"]
+license = "Apache-2.0"
+edition = "2021"
+rust-version = "1.56.1"
+
+[lib]
+proc-macro = true
+
+[dependencies]
+syn = { version = "1.0.98", features = ["full", "visit"] }
+quote = "1.0.2"
+proc-macro2 = "1.0.6"
+proc-macro-crate = "1.0.0"
+
+[dev-dependencies]
+parity-scale-codec = { path = "..", features = ["max-encoded-len"] }
+
+[features]
+# Enables the new `MaxEncodedLen` trait.
+# NOTE: This is still considered experimental and is exempt from the usual
+# SemVer guarantees; code using this feature may break without notice.
+max-encoded-len = []
diff --git a/src/decode.rs b/src/decode.rs
new file mode 100644 (file)
index 0000000..df60390
--- /dev/null
@@ -0,0 +1,206 @@
+// Copyright 2017, 2018 Parity Technologies
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use proc_macro2::{Span, TokenStream, Ident};
+use syn::{
+       spanned::Spanned,
+       Data, Fields, Field, Error,
+};
+
+use crate::utils;
+
+/// Generate the function body for `Decode::decode`.
+///
+/// * data: the data info of the type,
+/// * type_name: the name of the type,
+/// * type_generics: the generics of the type in turbofish format, without bounds, e.g. `::<T, I>`,
+/// * input: the variable name for the argument of `decode`,
+/// * crate_path: the path to the codec crate used in the generated code.
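+//
+// As a rough sketch (assuming a plain `parity_scale_codec` crate path), for
+// `struct S { a: u32 }` the generated body is equivalent to:
+//
+//   ::core::result::Result::Ok(S {
+//       a: match <u32 as parity_scale_codec::Decode>::decode(input) {
+//           Ok(v) => v,
+//           Err(e) => return Err(e.chain("Could not decode `S::a`")),
+//       },
+//   })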
+pub fn quote(
+       data: &Data,
+       type_name: &Ident,
+       type_generics: &TokenStream,
+       input: &TokenStream,
+       crate_path: &syn::Path,
+) -> TokenStream {
+       match *data {
+               Data::Struct(ref data) => match data.fields {
+                       Fields::Named(_) | Fields::Unnamed(_) => create_instance(
+                               quote! { #type_name #type_generics },
+                               &type_name.to_string(),
+                               input,
+                               &data.fields,
+                               crate_path,
+                       ),
+                       Fields::Unit => {
+                               quote_spanned! { data.fields.span() =>
+                                       ::core::result::Result::Ok(#type_name)
+                               }
+                       },
+               },
+               Data::Enum(ref data) => {
+                       let data_variants = || data.variants.iter().filter(|variant| !utils::should_skip(&variant.attrs));
+
+                       if data_variants().count() > 256 {
+                               return Error::new(
+                                       data.variants.span(),
+                                       "Currently only enums with at most 256 variants are encodable."
+                               ).to_compile_error();
+                       }
+
+                       let recurse = data_variants().enumerate().map(|(i, v)| {
+                               let name = &v.ident;
+                               let index = utils::variant_index(v, i);
+
+                               let create = create_instance(
+                                       quote! { #type_name #type_generics :: #name },
+                                       &format!("{}::{}", type_name, name),
+                                       input,
+                                       &v.fields,
+                                       crate_path,
+                               );
+
+                               quote_spanned! { v.span() =>
+                                       __codec_x_edqy if __codec_x_edqy == #index as ::core::primitive::u8 => {
+                                               #create
+                                       },
+                               }
+                       });
+
+                       let read_byte_err_msg = format!(
+                               "Could not decode `{}`, failed to read variant byte",
+                               type_name,
+                       );
+                       let invalid_variant_err_msg = format!(
+                               "Could not decode `{}`, variant doesn't exist",
+                               type_name,
+                       );
+                       quote! {
+                               match #input.read_byte()
+                                       .map_err(|e| e.chain(#read_byte_err_msg))?
+                               {
+                                       #( #recurse )*
+                                       _ => ::core::result::Result::Err(
+                                               <_ as ::core::convert::Into<_>>::into(#invalid_variant_err_msg)
+                                       ),
+                               }
+                       }
+
+               },
+               Data::Union(_) => Error::new(Span::call_site(), "Union types are not supported.").to_compile_error(),
+       }
+}
+
+fn create_decode_expr(field: &Field, name: &str, input: &TokenStream, crate_path: &syn::Path) -> TokenStream {
+       let encoded_as = utils::get_encoded_as_type(field);
+       let compact = utils::is_compact(field);
+       let skip = utils::should_skip(&field.attrs);
+
+       let res = quote!(__codec_res_edqy);
+
+       if encoded_as.is_some() as u8 + compact as u8 + skip as u8 > 1 {
+               return Error::new(
+                       field.span(),
+                       "`encoded_as`, `compact` and `skip` can only be used one at a time!"
+               ).to_compile_error();
+       }
+
+       let err_msg = format!("Could not decode `{}`", name);
+
+       if compact {
+               let field_type = &field.ty;
+               quote_spanned! { field.span() =>
+                       {
+                               let #res = <
+                                       <#field_type as #crate_path::HasCompact>::Type as #crate_path::Decode
+                               >::decode(#input);
+                               match #res {
+                                       ::core::result::Result::Err(e) => return ::core::result::Result::Err(e.chain(#err_msg)),
+                                       ::core::result::Result::Ok(#res) => #res.into(),
+                               }
+                       }
+               }
+       } else if let Some(encoded_as) = encoded_as {
+               quote_spanned! { field.span() =>
+                       {
+                               let #res = <#encoded_as as #crate_path::Decode>::decode(#input);
+                               match #res {
+                                       ::core::result::Result::Err(e) => return ::core::result::Result::Err(e.chain(#err_msg)),
+                                       ::core::result::Result::Ok(#res) => #res.into(),
+                               }
+                       }
+               }
+       } else if skip {
+               quote_spanned! { field.span() => ::core::default::Default::default() }
+       } else {
+               let field_type = &field.ty;
+               quote_spanned! { field.span() =>
+                       {
+                               let #res = <#field_type as #crate_path::Decode>::decode(#input);
+                               match #res {
+                                       ::core::result::Result::Err(e) => return ::core::result::Result::Err(e.chain(#err_msg)),
+                                       ::core::result::Result::Ok(#res) => #res,
+                               }
+                       }
+               }
+       }
+}
+
+fn create_instance(
+       name: TokenStream,
+       name_str: &str,
+       input: &TokenStream,
+       fields: &Fields,
+       crate_path: &syn::Path,
+) -> TokenStream {
+       match *fields {
+               Fields::Named(ref fields) => {
+                       let recurse = fields.named.iter().map(|f| {
+                               let name_ident = &f.ident;
+                               let field_name = match name_ident {
+                                       Some(a) => format!("{}::{}", name_str, a),
+                                       None => format!("{}", name_str), // Should never happen, fields are named.
+                               };
+                               let decode = create_decode_expr(f, &field_name, input, crate_path);
+
+                               quote_spanned! { f.span() =>
+                                       #name_ident: #decode
+                               }
+                       });
+
+                       quote_spanned! { fields.span() =>
+                               ::core::result::Result::Ok(#name {
+                                       #( #recurse, )*
+                               })
+                       }
+               },
+               Fields::Unnamed(ref fields) => {
+                       let recurse = fields.unnamed.iter().enumerate().map(|(i, f) | {
+                               let field_name = format!("{}.{}", name_str, i);
+
+                               create_decode_expr(f, &field_name, input, crate_path)
+                       });
+
+                       quote_spanned! { fields.span() =>
+                               ::core::result::Result::Ok(#name (
+                                       #( #recurse, )*
+                               ))
+                       }
+               },
+               Fields::Unit => {
+                       quote_spanned! { fields.span() =>
+                               ::core::result::Result::Ok(#name)
+                       }
+               },
+       }
+}
diff --git a/src/encode.rs b/src/encode.rs
new file mode 100644 (file)
index 0000000..b2badc1
--- /dev/null
@@ -0,0 +1,329 @@
+// Copyright 2017, 2018 Parity Technologies
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::str::from_utf8;
+
+use proc_macro2::{Ident, Span, TokenStream};
+use syn::{
+       punctuated::Punctuated,
+       spanned::Spanned,
+       token::Comma,
+       Data, Field, Fields, Error,
+};
+
+use crate::utils;
+
+type FieldsList = Punctuated<Field, Comma>;
+
+// Encode a single field by forwarding `encode`/`using_encoded` to it; the field
+// must not have the `skip` attribute.
+fn encode_single_field(
+       field: &Field,
+       field_name: TokenStream,
+       crate_path: &syn::Path,
+) -> TokenStream {
+       let encoded_as = utils::get_encoded_as_type(field);
+       let compact = utils::is_compact(field);
+
+       if utils::should_skip(&field.attrs) {
+               return Error::new(
+                       Span::call_site(),
+                       "Internal error: cannot encode single field optimisation if skipped"
+               ).to_compile_error();
+       }
+
+       if encoded_as.is_some() && compact {
+               return Error::new(
+                       Span::call_site(),
+                       "`encoded_as` and `compact` can not be used at the same time!"
+               ).to_compile_error();
+       }
+
+       let final_field_variable = if compact {
+               let field_type = &field.ty;
+               quote_spanned! {
+                       field.span() => {
+                               <<#field_type as #crate_path::HasCompact>::Type as
+                               #crate_path::EncodeAsRef<'_, #field_type>>::RefType::from(#field_name)
+                       }
+               }
+       } else if let Some(encoded_as) = encoded_as {
+               let field_type = &field.ty;
+               quote_spanned! {
+                       field.span() => {
+                               <#encoded_as as
+                               #crate_path::EncodeAsRef<'_, #field_type>>::RefType::from(#field_name)
+                       }
+               }
+       } else {
+               quote_spanned! { field.span() =>
+                       #field_name
+               }
+       };
+
+       // This may have different hygiene than the field span
+       let i_self = quote! { self };
+
+       quote_spanned! { field.span() =>
+                       fn encode_to<__CodecOutputEdqy: #crate_path::Output + ?::core::marker::Sized>(
+                               &#i_self,
+                               __codec_dest_edqy: &mut __CodecOutputEdqy
+                       ) {
+                               #crate_path::Encode::encode_to(&#final_field_variable, __codec_dest_edqy)
+                       }
+
+                       fn encode(&#i_self) -> #crate_path::alloc::vec::Vec<::core::primitive::u8> {
+                               #crate_path::Encode::encode(&#final_field_variable)
+                       }
+
+                       fn using_encoded<R, F: ::core::ops::FnOnce(&[::core::primitive::u8]) -> R>(&#i_self, f: F) -> R {
+                               #crate_path::Encode::using_encoded(&#final_field_variable, f)
+                       }
+       }
+}
+
+fn encode_fields<F>(
+       dest: &TokenStream,
+       fields: &FieldsList,
+       field_name: F,
+       crate_path: &syn::Path,
+) -> TokenStream where
+       F: Fn(usize, &Option<Ident>) -> TokenStream,
+{
+       let recurse = fields.iter().enumerate().map(|(i, f)| {
+               let field = field_name(i, &f.ident);
+               let encoded_as = utils::get_encoded_as_type(f);
+               let compact = utils::is_compact(f);
+               let skip = utils::should_skip(&f.attrs);
+
+               if encoded_as.is_some() as u8 + compact as u8 + skip as u8 > 1 {
+                       return Error::new(
+                               f.span(),
+                               "`encoded_as`, `compact` and `skip` can only be used one at a time!"
+                       ).to_compile_error();
+               }
+
+		// Based on the attributes present, generate the code that encodes the field
+		// by writing it to `dest`, which implements the `Output` trait.
+               if compact {
+                       let field_type = &f.ty;
+                       quote_spanned! {
+                               f.span() => {
+                                       #crate_path::Encode::encode_to(
+                                               &<
+                                                       <#field_type as #crate_path::HasCompact>::Type as
+                                                       #crate_path::EncodeAsRef<'_, #field_type>
+                                               >::RefType::from(#field),
+                                               #dest,
+                                       );
+                               }
+                       }
+               } else if let Some(encoded_as) = encoded_as {
+                       let field_type = &f.ty;
+                       quote_spanned! {
+                               f.span() => {
+                                       #crate_path::Encode::encode_to(
+                                               &<
+                                                       #encoded_as as
+                                                       #crate_path::EncodeAsRef<'_, #field_type>
+                                               >::RefType::from(#field),
+                                               #dest,
+                                       );
+                               }
+                       }
+               } else if skip {
+                       quote! {
+                               let _ = #field;
+                       }
+               } else {
+                       quote_spanned! { f.span() =>
+                               #crate_path::Encode::encode_to(#field, #dest);
+                       }
+               }
+       });
+
+       quote! {
+               #( #recurse )*
+       }
+}
+
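+// The single-field ("newtype") optimisation: when a struct has exactly one
+// non-skipped field, `encode`/`using_encoded` are forwarded directly to that
+// field instead of only implementing `encode_to`.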
+fn try_impl_encode_single_field_optimisation(data: &Data, crate_path: &syn::Path) -> Option<TokenStream> {
+       match *data {
+               Data::Struct(ref data) => {
+                       match data.fields {
+                               Fields::Named(ref fields) if utils::filter_skip_named(fields).count() == 1 => {
+                                       let field = utils::filter_skip_named(fields).next().unwrap();
+                                       let name = &field.ident;
+                                       Some(encode_single_field(
+                                               field,
+                                               quote!(&self.#name),
+                                               crate_path,
+                                       ))
+                               },
+                               Fields::Unnamed(ref fields) if utils::filter_skip_unnamed(fields).count() == 1 => {
+                                       let (id, field) = utils::filter_skip_unnamed(fields).next().unwrap();
+                                       let id = syn::Index::from(id);
+
+                                       Some(encode_single_field(
+                                               field,
+                                               quote!(&self.#id),
+                                               crate_path,
+                                       ))
+                               },
+                               _ => None,
+                       }
+               },
+               _ => None,
+       }
+}
+
+fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenStream {
+       let self_ = quote!(self);
+       let dest = &quote!(__codec_dest_edqy);
+       let encoding = match *data {
+               Data::Struct(ref data) => {
+                       match data.fields {
+                               Fields::Named(ref fields) => encode_fields(
+                                       dest,
+                                       &fields.named,
+                                       |_, name| quote!(&#self_.#name),
+                                       crate_path,
+                               ),
+                               Fields::Unnamed(ref fields) => encode_fields(
+                                       dest,
+                                       &fields.unnamed,
+                                       |i, _| {
+                                               let i = syn::Index::from(i);
+                                               quote!(&#self_.#i)
+                                       },
+                                       crate_path,
+                               ),
+                               Fields::Unit => quote!(),
+                       }
+               },
+               Data::Enum(ref data) => {
+                       let data_variants = || data.variants.iter().filter(|variant| !utils::should_skip(&variant.attrs));
+
+                       if data_variants().count() > 256 {
+                               return Error::new(
+                                       data.variants.span(),
+                                       "Currently only enums with at most 256 variants are encodable."
+                               ).to_compile_error();
+                       }
+
+                       // If the enum has no variants, we don't need to encode anything.
+                       if data_variants().count() == 0 {
+                               return quote!();
+                       }
+
+                       let recurse = data_variants().enumerate().map(|(i, f)| {
+                               let name = &f.ident;
+                               let index = utils::variant_index(f, i);
+
+                               match f.fields {
+                                       Fields::Named(ref fields) => {
+                                               let field_name = |_, ident: &Option<Ident>| quote!(#ident);
+                                               let names = fields.named
+                                                       .iter()
+                                                       .enumerate()
+                                                       .map(|(i, f)| field_name(i, &f.ident));
+
+                                               let encode_fields = encode_fields(
+                                                       dest,
+                                                       &fields.named,
+                                                       |a, b| field_name(a, b),
+                                                       crate_path,
+                                               );
+
+                                               quote_spanned! { f.span() =>
+                                                       #type_name :: #name { #( ref #names, )* } => {
+                                                               #dest.push_byte(#index as ::core::primitive::u8);
+                                                               #encode_fields
+                                                       }
+                                               }
+                                       },
+                                       Fields::Unnamed(ref fields) => {
+                                               let field_name = |i, _: &Option<Ident>| {
+                                                       let data = stringify(i as u8);
+                                                       let ident = from_utf8(&data).expect("We never go beyond ASCII");
+                                                       let ident = Ident::new(ident, Span::call_site());
+                                                       quote!(#ident)
+                                               };
+                                               let names = fields.unnamed
+                                                       .iter()
+                                                       .enumerate()
+                                                       .map(|(i, f)| field_name(i, &f.ident));
+
+                                               let encode_fields = encode_fields(
+                                                       dest,
+                                                       &fields.unnamed,
+                                                       |a, b| field_name(a, b),
+                                                       crate_path,
+                                               );
+
+                                               quote_spanned! { f.span() =>
+                                                       #type_name :: #name ( #( ref #names, )* ) => {
+                                                               #dest.push_byte(#index as ::core::primitive::u8);
+                                                               #encode_fields
+                                                       }
+                                               }
+                                       },
+                                       Fields::Unit => {
+                                               quote_spanned! { f.span() =>
+                                                       #type_name :: #name => {
+                                                               #dest.push_byte(#index as ::core::primitive::u8);
+                                                       }
+                                               }
+                                       },
+                               }
+                       });
+
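+			// Variants marked `#[codec(skip)]` were filtered out of `recurse` above, so they
+			// fall through to the trailing `_ => ()` arm and contribute nothing to the encoding.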
+                       quote! {
+                               match *#self_ {
+                                       #( #recurse )*,
+                                       _ => (),
+                               }
+                       }
+               },
+               Data::Union(ref data) => Error::new(
+                       data.union_token.span(),
+                       "Union types are not supported."
+               ).to_compile_error(),
+       };
+       quote! {
+               fn encode_to<__CodecOutputEdqy: #crate_path::Output + ?::core::marker::Sized>(
+                       &#self_,
+                       #dest: &mut __CodecOutputEdqy
+               ) {
+                       #encoding
+               }
+       }
+}
+
+pub fn quote(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenStream {
+       if let Some(implementation) = try_impl_encode_single_field_optimisation(data, crate_path) {
+               implementation
+       } else {
+               impl_encode(data, type_name, crate_path)
+       }
+}
+
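+// Maps a field index to a stable two-letter binding identifier used when
+// destructuring unnamed enum variants, e.g. 0 => "aa", 1 => "ba", 26 => "ab";
+// two letters cover every possible `u8` index.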
+pub fn stringify(id: u8) -> [u8; 2] {
+       const CHARS: &[u8] = b"abcdefghijklmnopqrstuvwxyz";
+       let len = CHARS.len() as u8;
+       let symbol = |id: u8| CHARS[(id % len) as usize];
+       let a = symbol(id);
+       let b = symbol(id / len);
+
+       [a, b]
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644 (file)
index 0000000..9d80304
--- /dev/null
@@ -0,0 +1,361 @@
+// Copyright 2017-2021 Parity Technologies
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Derives serialization and deserialization codec for complex structs for simple marshalling.
+
+#![recursion_limit = "128"]
+extern crate proc_macro;
+
+#[macro_use]
+extern crate syn;
+
+#[macro_use]
+extern crate quote;
+
+use crate::utils::{codec_crate_path, is_lint_attribute};
+use syn::{spanned::Spanned, Data, DeriveInput, Error, Field, Fields};
+
+mod decode;
+mod encode;
+mod max_encoded_len;
+mod trait_bounds;
+mod utils;
+
+/// Wraps the impl block in a "dummy const"
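+//
+// The anonymous `const _: () = { ... }` keeps the forwarded lint attributes and
+// the `#[allow(deprecated)]` scoped to the generated impls instead of leaking
+// into the caller's module.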
+fn wrap_with_dummy_const(
+       input: DeriveInput,
+       impl_block: proc_macro2::TokenStream,
+) -> proc_macro::TokenStream {
+       let attrs = input.attrs.into_iter().filter(is_lint_attribute);
+       let generated = quote! {
+               #[allow(deprecated)]
+               const _: () = {
+                       #(#attrs)*
+                       #impl_block
+               };
+       };
+
+       generated.into()
+}
+
+/// Derive `parity_scale_codec::Encode` and `parity_scale_codec::EncodeLike` for structs and enums.
+///
+/// # Top level attributes
+///
+/// By default the macro will add [`Encode`] and [`Decode`] bounds to all types, but the bounds can
+/// be specified manually with the top level attributes:
+/// * `#[codec(encode_bound(T: Encode))]`: a custom bound added to the `where`-clause when deriving
+///   the `Encode` trait, overriding the default.
+/// * `#[codec(decode_bound(T: Decode))]`: a custom bound added to the `where`-clause when deriving
+///   the `Decode` trait, overriding the default.
+///
+/// # Struct
+///
+/// A struct is encoded by encoding each of its fields successively.
+///
+/// Fields can have some attributes:
+/// * `#[codec(skip)]`: the field is not encoded. Its type must implement `Default` if `Decode` is
+///   derived.
+/// * `#[codec(compact)]`: the field is encoded in its compact representation, i.e. the field type
+///   must implement `parity_scale_codec::HasCompact` and will be encoded as `HasCompact::Type`.
+/// * `#[codec(encoded_as = "$EncodedAs")]`: the field is encoded as an alternative type. The
+///   `$EncodedAs` type must implement `parity_scale_codec::EncodeAsRef<'_, $FieldType>`, where
+///   `$FieldType` is the type of the field carrying the attribute. This is intended for types
+///   implementing `HasCompact`, as shown in the example.
+///
+/// ```
+/// # use parity_scale_codec_derive::Encode;
+/// # use parity_scale_codec::{Encode as _, HasCompact};
+/// #[derive(Encode)]
+/// struct StructType {
+///            #[codec(skip)]
+///            a: u32,
+///            #[codec(compact)]
+///            b: u32,
+///            #[codec(encoded_as = "<u32 as HasCompact>::Type")]
+///            c: u32,
+/// }
+/// ```
+///
+/// # Enum
+///
+/// A variant is encoded as one byte for the variant index followed by the encoding of the
+/// variant's fields, as for a struct.
+/// The variant index is:
+/// * if the variant has the attribute `#[codec(index = "$n")]`, then `n`;
+/// * else, if the variant has a discriminant (like 3 in `enum T { A = 3 }`), then the discriminant;
+/// * else, its position in the variant set, excluding skipped variants but including variants with
+///   a discriminant or an index attribute. Warning: this position can collide with an explicit
+///   discriminant or attribute index.
+///
+/// variant attributes:
+/// * `#[codec(skip)]`: the variant is not encoded.
+/// * `#[codec(index = "$n")]`: override variant index.
+///
+/// field attributes: same as struct fields attributes.
+///
+/// ```
+/// # use parity_scale_codec_derive::Encode;
+/// # use parity_scale_codec::Encode as _;
+/// #[derive(Encode)]
+/// enum EnumType {
+///    #[codec(index = 15)]
+///    A,
+///    #[codec(skip)]
+///    B,
+///    C = 3,
+///    D,
+/// }
+///
+/// assert_eq!(EnumType::A.encode(), vec![15]);
+/// assert_eq!(EnumType::B.encode(), vec![]);
+/// assert_eq!(EnumType::C.encode(), vec![3]);
+/// assert_eq!(EnumType::D.encode(), vec![2]);
+/// ```
+#[proc_macro_derive(Encode, attributes(codec))]
+pub fn encode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+       let mut input: DeriveInput = match syn::parse(input) {
+               Ok(input) => input,
+               Err(e) => return e.to_compile_error().into(),
+       };
+
+       if let Err(e) = utils::check_attributes(&input) {
+               return e.to_compile_error().into()
+       }
+
+       let crate_path = match codec_crate_path(&input.attrs) {
+               Ok(crate_path) => crate_path,
+               Err(error) => return error.into_compile_error().into(),
+       };
+
+       if let Err(e) = trait_bounds::add(
+               &input.ident,
+               &mut input.generics,
+               &input.data,
+               utils::custom_encode_trait_bound(&input.attrs),
+               parse_quote!(#crate_path::Encode),
+               None,
+               utils::has_dumb_trait_bound(&input.attrs),
+               &crate_path,
+       ) {
+               return e.to_compile_error().into()
+       }
+
+       let name = &input.ident;
+       let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
+
+       let encode_impl = encode::quote(&input.data, name, &crate_path);
+
+       let impl_block = quote! {
+               #[automatically_derived]
+               impl #impl_generics #crate_path::Encode for #name #ty_generics #where_clause {
+                       #encode_impl
+               }
+
+               #[automatically_derived]
+               impl #impl_generics #crate_path::EncodeLike for #name #ty_generics #where_clause {}
+       };
+
+       wrap_with_dummy_const(input, impl_block)
+}
+
+/// Derive `parity_scale_codec::Decode` for structs and enums.
+///
+/// See the `Encode` derive documentation.
+#[proc_macro_derive(Decode, attributes(codec))]
+pub fn decode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+       let mut input: DeriveInput = match syn::parse(input) {
+               Ok(input) => input,
+               Err(e) => return e.to_compile_error().into(),
+       };
+
+       if let Err(e) = utils::check_attributes(&input) {
+               return e.to_compile_error().into()
+       }
+
+       let crate_path = match codec_crate_path(&input.attrs) {
+               Ok(crate_path) => crate_path,
+               Err(error) => return error.into_compile_error().into(),
+       };
+
+       if let Err(e) = trait_bounds::add(
+               &input.ident,
+               &mut input.generics,
+               &input.data,
+               utils::custom_decode_trait_bound(&input.attrs),
+               parse_quote!(#crate_path::Decode),
+               Some(parse_quote!(Default)),
+               utils::has_dumb_trait_bound(&input.attrs),
+               &crate_path,
+       ) {
+               return e.to_compile_error().into()
+       }
+
+       let name = &input.ident;
+       let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
+       let ty_gen_turbofish = ty_generics.as_turbofish();
+
+       let input_ = quote!(__codec_input_edqy);
+       let decoding =
+               decode::quote(&input.data, name, &quote!(#ty_gen_turbofish), &input_, &crate_path);
+
+       let impl_block = quote! {
+               #[automatically_derived]
+               impl #impl_generics #crate_path::Decode for #name #ty_generics #where_clause {
+                       fn decode<__CodecInputEdqy: #crate_path::Input>(
+                               #input_: &mut __CodecInputEdqy
+                       ) -> ::core::result::Result<Self, #crate_path::Error> {
+                               #decoding
+                       }
+               }
+       };
+
+       wrap_with_dummy_const(input, impl_block)
+}
+
+/// Derive `parity_scale_codec::CompactAs`, plus a `From<parity_scale_codec::Compact<Self>>`
+/// conversion, for a struct with a single non-skipped field.
+///
+/// The `#[codec(skip)]` attribute can be used on the other fields to skip them.
+///
+/// # Example
+///
+/// ```
+/// # use parity_scale_codec_derive::CompactAs;
+/// # use parity_scale_codec::{Encode, HasCompact};
+/// # use std::marker::PhantomData;
+/// #[derive(CompactAs)]
+/// struct MyWrapper<T>(u32, #[codec(skip)] PhantomData<T>);
+/// ```
+#[proc_macro_derive(CompactAs, attributes(codec))]
+pub fn compact_as_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+       let mut input: DeriveInput = match syn::parse(input) {
+               Ok(input) => input,
+               Err(e) => return e.to_compile_error().into(),
+       };
+
+       if let Err(e) = utils::check_attributes(&input) {
+               return e.to_compile_error().into()
+       }
+
+       let crate_path = match codec_crate_path(&input.attrs) {
+               Ok(crate_path) => crate_path,
+               Err(error) => return error.into_compile_error().into(),
+       };
+
+       if let Err(e) = trait_bounds::add::<()>(
+               &input.ident,
+               &mut input.generics,
+               &input.data,
+               None,
+               parse_quote!(#crate_path::CompactAs),
+               None,
+               utils::has_dumb_trait_bound(&input.attrs),
+               &crate_path,
+       ) {
+               return e.to_compile_error().into()
+       }
+
+       let name = &input.ident;
+       let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
+
+       fn val_or_default(field: &Field) -> proc_macro2::TokenStream {
+               let skip = utils::should_skip(&field.attrs);
+               if skip {
+                       quote_spanned!(field.span()=> Default::default())
+               } else {
+                       quote_spanned!(field.span()=> x)
+               }
+       }
+
+       let (inner_ty, inner_field, constructor) = match input.data {
+               Data::Struct(ref data) => match data.fields {
+                       Fields::Named(ref fields) if utils::filter_skip_named(fields).count() == 1 => {
+                               let recurse = fields.named.iter().map(|f| {
+                                       let name_ident = &f.ident;
+                                       let val_or_default = val_or_default(&f);
+                                       quote_spanned!(f.span()=> #name_ident: #val_or_default)
+                               });
+                               let field = utils::filter_skip_named(fields).next().expect("Exactly one field");
+                               let field_name = &field.ident;
+                               let constructor = quote!( #name { #( #recurse, )* });
+                               (&field.ty, quote!(&self.#field_name), constructor)
+                       },
+                       Fields::Unnamed(ref fields) if utils::filter_skip_unnamed(fields).count() == 1 => {
+                               let recurse = fields.unnamed.iter().enumerate().map(|(_, f)| {
+                                       let val_or_default = val_or_default(&f);
+                                       quote_spanned!(f.span()=> #val_or_default)
+                               });
+                               let (id, field) =
+                                       utils::filter_skip_unnamed(fields).next().expect("Exactly one field");
+                               let id = syn::Index::from(id);
+                               let constructor = quote!( #name(#( #recurse, )*));
+                               (&field.ty, quote!(&self.#id), constructor)
+                       },
+                       _ =>
+                               return Error::new(
+                                       data.fields.span(),
+                                       "Only structs with a single non-skipped field can derive CompactAs",
+                               )
+                               .to_compile_error()
+                               .into(),
+               },
+               Data::Enum(syn::DataEnum { enum_token: syn::token::Enum { span }, .. }) |
+               Data::Union(syn::DataUnion { union_token: syn::token::Union { span }, .. }) =>
+                       return Error::new(span, "Only structs can derive CompactAs").to_compile_error().into(),
+       };
+
+       let impl_block = quote! {
+               #[automatically_derived]
+               impl #impl_generics #crate_path::CompactAs for #name #ty_generics #where_clause {
+                       type As = #inner_ty;
+                       fn encode_as(&self) -> &#inner_ty {
+                               #inner_field
+                       }
+                       fn decode_from(x: #inner_ty)
+                               -> ::core::result::Result<#name #ty_generics, #crate_path::Error>
+                       {
+                               ::core::result::Result::Ok(#constructor)
+                       }
+               }
+
+               #[automatically_derived]
+               impl #impl_generics From<#crate_path::Compact<#name #ty_generics>>
+                       for #name #ty_generics #where_clause
+               {
+                       fn from(x: #crate_path::Compact<#name #ty_generics>) -> #name #ty_generics {
+                               x.0
+                       }
+               }
+       };
+
+       wrap_with_dummy_const(input, impl_block)
+}
+
+/// Derive `parity_scale_codec::MaxEncodedLen` for struct and enum.
+///
+/// # Top level attribute
+///
+/// By default the macro will try to bound the types needed to implement `MaxEncodedLen`, but the
+/// bounds can be specified manually with the top level attribute:
+/// ```
+/// # use parity_scale_codec_derive::Encode;
+/// # use parity_scale_codec::MaxEncodedLen;
+/// # #[derive(Encode, MaxEncodedLen)]
+/// #[codec(mel_bound(T: MaxEncodedLen))]
+/// # struct MyWrapper<T>(T);
+/// ```
+#[cfg(feature = "max-encoded-len")]
+#[proc_macro_derive(MaxEncodedLen, attributes(max_encoded_len_mod))]
+pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+       max_encoded_len::derive_max_encoded_len(input)
+}
diff --git a/src/max_encoded_len.rs b/src/max_encoded_len.rs
new file mode 100644 (file)
index 0000000..bf8b20f
--- /dev/null
@@ -0,0 +1,141 @@
+// Copyright (C) 2021 Parity Technologies (UK) Ltd.
+// SPDX-License-Identifier: Apache-2.0
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#![cfg(feature = "max-encoded-len")]
+
+use crate::{
+       trait_bounds,
+       utils::{codec_crate_path, custom_mel_trait_bound, has_dumb_trait_bound, should_skip},
+};
+use quote::{quote, quote_spanned};
+use syn::{parse_quote, spanned::Spanned, Data, DeriveInput, Fields, Type};
+
+/// impl for `#[derive(MaxEncodedLen)]`
+pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+       let mut input: DeriveInput = match syn::parse(input) {
+               Ok(input) => input,
+               Err(e) => return e.to_compile_error().into(),
+       };
+
+       let crate_path = match codec_crate_path(&input.attrs) {
+               Ok(crate_path) => crate_path,
+               Err(error) => return error.into_compile_error().into(),
+       };
+
+       let name = &input.ident;
+       if let Err(e) = trait_bounds::add(
+               &input.ident,
+               &mut input.generics,
+               &input.data,
+               custom_mel_trait_bound(&input.attrs),
+               parse_quote!(#crate_path::MaxEncodedLen),
+               None,
+               has_dumb_trait_bound(&input.attrs),
+               &crate_path
+       ) {
+               return e.to_compile_error().into()
+       }
+       let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
+
+       let data_expr = data_length_expr(&input.data);
+
+       quote::quote!(
+               const _: () = {
+                       impl #impl_generics #crate_path::MaxEncodedLen for #name #ty_generics #where_clause {
+                               fn max_encoded_len() -> ::core::primitive::usize {
+                                       #data_expr
+                               }
+                       }
+               };
+       )
+       .into()
+}
+
+/// generate an expression to sum up the max encoded length from several fields
+fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream {
+       let type_iter: Box<dyn Iterator<Item = &Type>> = match fields {
+               Fields::Named(ref fields) => Box::new(
+                       fields.named.iter().filter_map(|field| if should_skip(&field.attrs) {
+                               None
+                       } else {
+                               Some(&field.ty)
+                       })
+               ),
+               Fields::Unnamed(ref fields) => Box::new(
+                       fields.unnamed.iter().filter_map(|field| if should_skip(&field.attrs) {
+                               None
+                       } else {
+                               Some(&field.ty)
+                       })
+               ),
+               Fields::Unit => Box::new(std::iter::empty()),
+       };
+       // expands to an expression like
+       //
+       //   0
+       //     .saturating_add(<type of first field>::max_encoded_len())
+       //     .saturating_add(<type of second field>::max_encoded_len())
+       //
+       // We match the span of each field to the span of the corresponding
+       // `max_encoded_len` call. This way, if one field's type doesn't implement
+       // `MaxEncodedLen`, the compiler's error message will underline which field
+       // caused the issue.
+       let expansion = type_iter.map(|ty| {
+               quote_spanned! {
+                       ty.span() => .saturating_add(<#ty>::max_encoded_len())
+               }
+       });
+       quote! {
+               0_usize #( #expansion )*
+       }
+}
+
+// generate an expression to sum up the max encoded length of each field
+fn data_length_expr(data: &Data) -> proc_macro2::TokenStream {
+       match *data {
+               Data::Struct(ref data) => fields_length_expr(&data.fields),
+               Data::Enum(ref data) => {
+                       // We need an expression expanded for each variant like
+                       //
+                       //   0
+                       //     .max(<variant expression>)
+                       //     .max(<variant expression>)
+                       //     .saturating_add(1)
+                       //
+                       // The 1 derives from the discriminant; see
+                       // https://github.com/paritytech/parity-scale-codec/
+                       //   blob/f0341dabb01aa9ff0548558abb6dcc5c31c669a1/derive/src/encode.rs#L211-L216
+                       //
+                       // Each variant expression's sum is computed the way an equivalent struct's would be.
+
+                       let expansion = data.variants.iter().map(|variant| {
+                               let variant_expression = fields_length_expr(&variant.fields);
+                               quote! {
+                                       .max(#variant_expression)
+                               }
+                       });
+
+                       quote! {
+                               0_usize #( #expansion )* .saturating_add(1)
+                       }
+               },
+               Data::Union(ref data) => {
+                       // https://github.com/paritytech/parity-scale-codec/
+                       //   blob/f0341dabb01aa9ff0548558abb6dcc5c31c669a1/derive/src/encode.rs#L290-L293
+                       syn::Error::new(data.union_token.span(), "Union types are not supported.")
+                               .to_compile_error()
+               },
+       }
+}
diff --git a/src/trait_bounds.rs b/src/trait_bounds.rs
new file mode 100644 (file)
index 0000000..f808588
--- /dev/null
@@ -0,0 +1,248 @@
+// Copyright 2019 Parity Technologies
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::iter;
+
+use proc_macro2::Ident;
+use syn::{
+       spanned::Spanned,
+       visit::{self, Visit},
+       Generics, Result, Type, TypePath,
+};
+
+use crate::utils::{self, CustomTraitBound};
+
+/// Visits the ast and checks if one of the given idents is found.
+struct ContainIdents<'a> {
+       result: bool,
+       idents: &'a [Ident],
+}
+
+impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> {
+       fn visit_ident(&mut self, i: &'ast Ident) {
+               if self.idents.iter().any(|id| id == i) {
+                       self.result = true;
+               }
+       }
+}
+
+/// Checks if the given type contains one of the given idents.
+fn type_contain_idents(ty: &Type, idents: &[Ident]) -> bool {
+       let mut visitor = ContainIdents { result: false, idents };
+       visitor.visit_type(ty);
+       visitor.result
+}
+
+/// Visits the ast and checks if a type path starts with the given ident.
+struct TypePathStartsWithIdent<'a> {
+       result: bool,
+       ident: &'a Ident,
+}
+
+impl<'a, 'ast> Visit<'ast> for TypePathStartsWithIdent<'a> {
+       fn visit_type_path(&mut self, i: &'ast TypePath) {
+               if let Some(segment) = i.path.segments.first() {
+                       if &segment.ident == self.ident {
+                               self.result = true;
+                               return
+                       }
+               }
+
+               visit::visit_type_path(self, i);
+       }
+}
+
+/// Checks if the given type path, or any type path contained in it, starts with the given ident.
+fn type_path_or_sub_starts_with_ident(ty: &TypePath, ident: &Ident) -> bool {
+       let mut visitor = TypePathStartsWithIdent { result: false, ident };
+       visitor.visit_type_path(ty);
+       visitor.result
+}
+
+/// Checks if the given type, or any type path contained in it, starts with the given ident.
+fn type_or_sub_type_path_starts_with_ident(ty: &Type, ident: &Ident) -> bool {
+       let mut visitor = TypePathStartsWithIdent { result: false, ident };
+       visitor.visit_type(ty);
+       visitor.result
+}
+
+/// Visits the ast and collects all type paths that do not start with or contain the given ident.
+///
+/// Returns `T`, `N`, `A` for `Vec<(Recursive<T, N>, A)>` with `Recursive` as ident.
+struct FindTypePathsNotStartOrContainIdent<'a> {
+       result: Vec<TypePath>,
+       ident: &'a Ident,
+}
+
+impl<'a, 'ast> Visit<'ast> for FindTypePathsNotStartOrContainIdent<'a> {
+       fn visit_type_path(&mut self, i: &'ast TypePath) {
+               if type_path_or_sub_starts_with_ident(i, &self.ident) {
+                       visit::visit_type_path(self, i);
+               } else {
+                       self.result.push(i.clone());
+               }
+       }
+}
+
+/// Collects all type paths in the given type that do not start with or contain the given ident.
+///
+/// Returns `T`, `N`, `A` for `Vec<(Recursive<T, N>, A)>` with `Recursive` as ident.
+fn find_type_paths_not_start_or_contain_ident(ty: &Type, ident: &Ident) -> Vec<TypePath> {
+       let mut visitor = FindTypePathsNotStartOrContainIdent { result: Vec::new(), ident };
+       visitor.visit_type(ty);
+       visitor.result
+}
+
+/// Add required trait bounds to all generic types.
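+//
+// Rough sketch of the effect: deriving `Encode` for `struct Wrapper<T> { inner: Vec<T> }`
+// adds a `Vec<T>: Encode` predicate to the where clause, while a custom bound such as
+// `#[codec(encode_bound(T: Encode))]` replaces the inferred predicates entirely.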
+pub fn add<N>(
+       input_ident: &Ident,
+       generics: &mut Generics,
+       data: &syn::Data,
+       custom_trait_bound: Option<CustomTraitBound<N>>,
+       codec_bound: syn::Path,
+       codec_skip_bound: Option<syn::Path>,
+       dumb_trait_bounds: bool,
+       crate_path: &syn::Path,
+) -> Result<()> {
+       let skip_type_params = match custom_trait_bound {
+               Some(CustomTraitBound::SpecifiedBounds { bounds, .. }) => {
+                       generics.make_where_clause().predicates.extend(bounds);
+                       return Ok(())
+               },
+               Some(CustomTraitBound::SkipTypeParams { type_names, .. }) =>
+                       type_names.into_iter().collect::<Vec<_>>(),
+               None => Vec::new(),
+       };
+
+       let ty_params = generics
+               .type_params()
+               .filter_map(|tp| {
+                       skip_type_params.iter().all(|skip| skip != &tp.ident).then(|| tp.ident.clone())
+               })
+               .collect::<Vec<_>>();
+       if ty_params.is_empty() {
+               return Ok(())
+       }
+
+       let codec_types =
+               get_types_to_add_trait_bound(input_ident, data, &ty_params, dumb_trait_bounds)?;
+
+       let compact_types = collect_types(&data, utils::is_compact)?
+               .into_iter()
+               // Only add a bound if the type uses a generic
+               .filter(|ty| type_contain_idents(ty, &ty_params))
+               .collect::<Vec<_>>();
+
+       let skip_types = if codec_skip_bound.is_some() {
+               let needs_default_bound = |f: &syn::Field| utils::should_skip(&f.attrs);
+               collect_types(&data, needs_default_bound)?
+                       .into_iter()
+                       // Only add a bound if the type uses a generic
+                       .filter(|ty| type_contain_idents(ty, &ty_params))
+                       .collect::<Vec<_>>()
+       } else {
+               Vec::new()
+       };
+
+       if !codec_types.is_empty() || !compact_types.is_empty() || !skip_types.is_empty() {
+               let where_clause = generics.make_where_clause();
+
+               codec_types
+                       .into_iter()
+                       .for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #codec_bound)));
+
+               let has_compact_bound: syn::Path = parse_quote!(#crate_path::HasCompact);
+               compact_types
+                       .into_iter()
+                       .for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #has_compact_bound)));
+
+               skip_types.into_iter().for_each(|ty| {
+                       let codec_skip_bound = codec_skip_bound.as_ref();
+                       where_clause.predicates.push(parse_quote!(#ty : #codec_skip_bound))
+               });
+       }
+
+       Ok(())
+}
+
+/// Returns all types that must be added to the where clause with the respective trait bound.
+fn get_types_to_add_trait_bound(
+       input_ident: &Ident,
+       data: &syn::Data,
+       ty_params: &[Ident],
+       dumb_trait_bound: bool,
+) -> Result<Vec<Type>> {
+       if dumb_trait_bound {
+               Ok(ty_params.iter().map(|t| parse_quote!( #t )).collect())
+       } else {
+               let needs_codec_bound = |f: &syn::Field| {
+                       !utils::is_compact(f) &&
+                               utils::get_encoded_as_type(f).is_none() &&
+                               !utils::should_skip(&f.attrs)
+               };
+               let res = collect_types(&data, needs_codec_bound)?
+                       .into_iter()
+                       // Only add a bound if the type uses a generic
+                       .filter(|ty| type_contain_idents(ty, &ty_params))
+			// If a struct contains itself as a field type, we cannot add this type to the
+			// where clause. This is required to work around the following compiler bug: https://github.com/rust-lang/rust/issues/47032
+                       .flat_map(|ty| {
+                               find_type_paths_not_start_or_contain_ident(&ty, input_ident)
+                                       .into_iter()
+                                       .map(|ty| Type::Path(ty.clone()))
+                                       // Remove again types that do not contain any of our generic parameters
+                                       .filter(|ty| type_contain_idents(ty, &ty_params))
+                                       // Add back the original type, as we don't want to lose it.
+                                       .chain(iter::once(ty))
+                       })
+                       // Remove all remaining types that start/contain the input ident to not have them in the
+                       // where clause.
+                       .filter(|ty| !type_or_sub_type_path_starts_with_ident(ty, input_ident))
+                       .collect();
+
+               Ok(res)
+       }
+}
+
+fn collect_types(data: &syn::Data, type_filter: fn(&syn::Field) -> bool) -> Result<Vec<syn::Type>> {
+       use syn::*;
+
+       let types = match *data {
+               Data::Struct(ref data) => match &data.fields {
+                       | Fields::Named(FieldsNamed { named: fields, .. }) |
+                       Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) =>
+                               fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect(),
+
+                       Fields::Unit => Vec::new(),
+               },
+
+               Data::Enum(ref data) => data
+                       .variants
+                       .iter()
+                       .filter(|variant| !utils::should_skip(&variant.attrs))
+                       .flat_map(|variant| match &variant.fields {
+                               | Fields::Named(FieldsNamed { named: fields, .. }) |
+                               Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) =>
+                                       fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect(),
+
+                               Fields::Unit => Vec::new(),
+                       })
+                       .collect(),
+
+               Data::Union(ref data) =>
+                       return Err(Error::new(data.union_token.span(), "Union types are not supported.")),
+       };
+
+       Ok(types)
+}
diff --git a/src/utils.rs b/src/utils.rs
new file mode 100644 (file)
index 0000000..585ce9a
--- /dev/null
@@ -0,0 +1,433 @@
+// Copyright 2018-2020 Parity Technologies
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Various internal utils.
+//!
+//! NOTE: attributes must be validated with `check_attributes` first,
+//! otherwise the finder functions below can panic.
+
+use std::str::FromStr;
+
+use proc_macro2::TokenStream;
+use quote::quote;
+use syn::{
+       parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DeriveInput,
+       Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path, Variant,
+};
+
+fn find_meta_item<'a, F, R, I, M>(mut itr: I, mut pred: F) -> Option<R>
+where
+       F: FnMut(M) -> Option<R> + Clone,
+       I: Iterator<Item = &'a Attribute>,
+       M: Parse,
+{
+       itr.find_map(|attr| {
+               attr.path.is_ident("codec").then(|| pred(attr.parse_args().ok()?)).flatten()
+       })
+}
+
+/// Look for a `#[codec(index = $int)]` attribute on a variant. If no attribute
+/// is found, fall back to the discriminant or just the variant index.
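+///
+/// Illustrative usage sketch (hypothetical enum, assuming the `Encode` derive from this crate):
+///
+/// ```ignore
+/// #[derive(Encode)]
+/// enum Event {
+///     // Encoded with index 0, its position in the enum.
+///     Created(u32),
+///     // Explicit override: encoded with index 5 instead of 1.
+///     #[codec(index = 5)]
+///     Removed(u32),
+/// }
+/// ```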
+pub fn variant_index(v: &Variant, i: usize) -> TokenStream {
+       // first look for an attribute
+       let index = find_meta_item(v.attrs.iter(), |meta| {
+               if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
+                       if nv.path.is_ident("index") {
+                               if let Lit::Int(ref v) = nv.lit {
+                                       let byte = v
+                                               .base10_parse::<u8>()
+                                               .expect("Internal error, index attribute must have been checked");
+                                       return Some(byte)
+                               }
+                       }
+               }
+
+               None
+       });
+
+       // then fallback to discriminant or just index
+       index.map(|i| quote! { #i }).unwrap_or_else(|| {
+               v.discriminant
+                       .as_ref()
+                       .map(|&(_, ref expr)| quote! { #expr })
+                       .unwrap_or_else(|| quote! { #i })
+       })
+}
+
+/// Look for a `#[codec(encoded_as = "SomeType")]` outer attribute on the given
+/// `Field`.
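+///
+/// Illustrative usage sketch (`SomeWrapper` is a hypothetical type that must satisfy the
+/// conversion bounds the derive emits for `encoded_as` types):
+///
+/// ```ignore
+/// #[derive(Encode)]
+/// struct Header {
+///     // Encode `number` as if it were `SomeWrapper` instead of `u64`.
+///     #[codec(encoded_as = "SomeWrapper")]
+///     number: u64,
+/// }
+/// ```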
+pub fn get_encoded_as_type(field: &Field) -> Option<TokenStream> {
+       find_meta_item(field.attrs.iter(), |meta| {
+               if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
+                       if nv.path.is_ident("encoded_as") {
+                               if let Lit::Str(ref s) = nv.lit {
+                                       return Some(
+                                               TokenStream::from_str(&s.value())
+                                                       .expect("Internal error, encoded_as attribute must have been checked"),
+                                       )
+                               }
+                       }
+               }
+
+               None
+       })
+}
+
+/// Look for a `#[codec(compact)]` outer attribute on the given `Field`.
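+///
+/// Illustrative usage sketch (hypothetical struct):
+///
+/// ```ignore
+/// #[derive(Encode, Decode)]
+/// struct Transfer {
+///     // Encoded using the SCALE compact representation.
+///     #[codec(compact)]
+///     amount: u128,
+/// }
+/// ```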
+pub fn is_compact(field: &Field) -> bool {
+       find_meta_item(field.attrs.iter(), |meta| {
+               if let NestedMeta::Meta(Meta::Path(ref path)) = meta {
+                       if path.is_ident("compact") {
+                               return Some(())
+                       }
+               }
+
+               None
+       })
+       .is_some()
+}
+
+/// Look for a `#[codec(skip)]` in the given attributes.
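+///
+/// Illustrative usage sketch (hypothetical struct):
+///
+/// ```ignore
+/// #[derive(Encode, Decode)]
+/// struct Cached {
+///     value: u32,
+///     // Not part of the encoding; reconstructed via `Default` when decoding.
+///     #[codec(skip)]
+///     cached_hash: Option<u64>,
+/// }
+/// ```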
+pub fn should_skip(attrs: &[Attribute]) -> bool {
+       find_meta_item(attrs.iter(), |meta| {
+               if let NestedMeta::Meta(Meta::Path(ref path)) = meta {
+                       if path.is_ident("skip") {
+                               return Some(path.span())
+                       }
+               }
+
+               None
+       })
+       .is_some()
+}
+
+/// Look for a `#[codec(dumb_trait_bound)]` in the given attributes.
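+///
+/// Illustrative usage sketch (hypothetical type): this attribute makes the derive bound the
+/// type parameters themselves instead of the field types that use them.
+///
+/// ```ignore
+/// #[derive(Encode)]
+/// #[codec(dumb_trait_bound)]
+/// struct Wrapper<T> {
+///     // With the attribute, the where clause contains `T: Encode` instead of `Vec<T>: Encode`.
+///     inner: Vec<T>,
+/// }
+/// ```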
+pub fn has_dumb_trait_bound(attrs: &[Attribute]) -> bool {
+       find_meta_item(attrs.iter(), |meta| {
+               if let NestedMeta::Meta(Meta::Path(ref path)) = meta {
+                       if path.is_ident("dumb_trait_bound") {
+                               return Some(())
+                       }
+               }
+
+               None
+       })
+       .is_some()
+}
+
+/// Generate the crate access for the crate using 2018 syntax.
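+///
+/// Illustrative note (based on `proc_macro_crate::crate_name` behaviour): when a downstream
+/// `Cargo.toml` renames the dependency, e.g. `codec = { package = "parity-scale-codec", .. }`,
+/// this resolves to the identifier `codec`; inside the codec crate itself it resolves to
+/// `parity_scale_codec`.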
+fn crate_access() -> syn::Result<proc_macro2::Ident> {
+       use proc_macro2::{Ident, Span};
+       use proc_macro_crate::{crate_name, FoundCrate};
+       const DEF_CRATE: &str = "parity-scale-codec";
+       match crate_name(DEF_CRATE) {
+               Ok(FoundCrate::Itself) => {
+                       let name = DEF_CRATE.to_string().replace("-", "_");
+                       Ok(syn::Ident::new(&name, Span::call_site()))
+               },
+               Ok(FoundCrate::Name(name)) => Ok(Ident::new(&name, Span::call_site())),
+               Err(e) => Err(syn::Error::new(Span::call_site(), e)),
+       }
+}
+
+/// This struct matches `crate = ...` where the ellipsis is a `Path`.
+struct CratePath {
+       _crate_token: Token![crate],
+       _eq_token: Token![=],
+       path: Path,
+}
+
+impl Parse for CratePath {
+       fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+               Ok(CratePath {
+                       _crate_token: input.parse()?,
+                       _eq_token: input.parse()?,
+                       path: input.parse()?,
+               })
+       }
+}
+
+impl From<CratePath> for Path {
+       fn from(CratePath { path, .. }: CratePath) -> Self {
+               path
+       }
+}
+
+/// Match `#[codec(crate = ...)]` and return the `...` if it is a `Path`.
+fn codec_crate_path_inner(attr: &Attribute) -> Option<Path> {
+       // match `#[codec ...]`
+       attr.path
+               .is_ident("codec")
+               .then(|| {
+                       // match `#[codec(crate = ...)]` and return the `...`
+                       attr.parse_args::<CratePath>().map(Into::into).ok()
+               })
+               .flatten()
+}
+
+/// Match `#[codec(crate = ...)]` and return the `...` as a `Path`.
+///
+/// If not found, returns the default crate access pattern.
+///
+/// If multiple items match the pattern, all but the first are ignored.
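+///
+/// Illustrative usage sketch (hypothetical re-export path):
+///
+/// ```ignore
+/// #[derive(Encode)]
+/// // Tell the derive where to find the codec crate when it is only reachable via a re-export.
+/// #[codec(crate = my_deps::parity_scale_codec)]
+/// struct Foo {
+///     a: u32,
+/// }
+/// ```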
+pub fn codec_crate_path(attrs: &[Attribute]) -> syn::Result<Path> {
+       match attrs.iter().find_map(codec_crate_path_inner) {
+               Some(path) => Ok(path),
+               None => crate_access().map(|ident| parse_quote!(::#ident)),
+       }
+}
+
+/// Parse `name(T: Bound, N: Bound)` or `name(skip_type_params(T, N))` as a custom trait bound.
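+///
+/// Illustrative attribute forms this corresponds to (hypothetical types and bounds):
+///
+/// ```ignore
+/// // Explicit bounds for the `Encode` derive:
+/// #[derive(Encode)]
+/// #[codec(encode_bound(T: Encode))]
+/// struct A<T>(T);
+///
+/// // Skip bound generation for the listed type parameters entirely:
+/// #[derive(MaxEncodedLen)]
+/// #[codec(mel_bound(skip_type_params(N)))]
+/// struct B<N>(core::marker::PhantomData<N>);
+/// ```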
+pub enum CustomTraitBound<N> {
+       SpecifiedBounds {
+               _name: N,
+               _paren_token: token::Paren,
+               bounds: Punctuated<syn::WherePredicate, Token![,]>,
+       },
+       SkipTypeParams {
+               _name: N,
+               _paren_token_1: token::Paren,
+               _skip_type_params: skip_type_params,
+               _paren_token_2: token::Paren,
+               type_names: Punctuated<syn::Ident, Token![,]>,
+       },
+}
+
+impl<N: Parse> Parse for CustomTraitBound<N> {
+       fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+               let mut content;
+               let _name: N = input.parse()?;
+               let _paren_token = syn::parenthesized!(content in input);
+               if content.peek(skip_type_params) {
+                       Ok(Self::SkipTypeParams {
+                               _name,
+                               _paren_token_1: _paren_token,
+                               _skip_type_params: content.parse::<skip_type_params>()?,
+                               _paren_token_2: syn::parenthesized!(content in content),
+                               type_names: content.parse_terminated(syn::Ident::parse)?,
+                       })
+               } else {
+                       Ok(Self::SpecifiedBounds {
+                               _name,
+                               _paren_token,
+                               bounds: content.parse_terminated(syn::WherePredicate::parse)?,
+                       })
+               }
+       }
+}
+
+syn::custom_keyword!(encode_bound);
+syn::custom_keyword!(decode_bound);
+syn::custom_keyword!(mel_bound);
+syn::custom_keyword!(skip_type_params);
+
+/// Look for a `#[codec(decode_bound(T: Decode))]` in the given attributes.
+///
+/// If found, it should be used as trait bounds when deriving the `Decode` trait.
+pub fn custom_decode_trait_bound(attrs: &[Attribute]) -> Option<CustomTraitBound<decode_bound>> {
+       find_meta_item(attrs.iter(), Some)
+}
+
+/// Look for a `#[codec(encode_bound(T: Encode))]` in the given attributes.
+///
+/// If found, it should be used as trait bounds when deriving the `Encode` trait.
+pub fn custom_encode_trait_bound(attrs: &[Attribute]) -> Option<CustomTraitBound<encode_bound>> {
+       find_meta_item(attrs.iter(), Some)
+}
+
+/// Look for a `#[codec(mel_bound(T: MaxEncodedLen))]` in the given attributes.
+///
+/// If found, it should be used as the trait bounds when deriving the `MaxEncodedLen` trait.
+#[cfg(feature = "max-encoded-len")]
+pub fn custom_mel_trait_bound(attrs: &[Attribute]) -> Option<CustomTraitBound<mel_bound>> {
+       find_meta_item(attrs.iter(), Some)
+}
+
+/// Given a set of named fields, return an iterator of `Field` where all fields
+/// marked `#[codec(skip)]` are filtered out.
+pub fn filter_skip_named<'a>(fields: &'a syn::FieldsNamed) -> impl Iterator<Item = &Field> + 'a {
+       fields.named.iter().filter(|f| !should_skip(&f.attrs))
+}
+
+/// Given a set of unnamed fields, return an iterator of `(index, Field)` where all fields
+/// marked `#[codec(skip)]` are filtered out.
+pub fn filter_skip_unnamed<'a>(
+       fields: &'a syn::FieldsUnnamed,
+) -> impl Iterator<Item = (usize, &Field)> + 'a {
+       fields.unnamed.iter().enumerate().filter(|(_, f)| !should_skip(&f.attrs))
+}
+
+/// Ensure attributes are correctly applied. This *must* be called before using
+/// any of the attribute finder methods or the macro may panic if it encounters
+/// misapplied attributes.
+///
+/// The top level can have the following attributes:
+///
+/// * `#[codec(dumb_trait_bound)]`
+/// * `#[codec(encode_bound(T: Encode))]`
+/// * `#[codec(decode_bound(T: Decode))]`
+/// * `#[codec(mel_bound(T: MaxEncodedLen))]`
+/// * `#[codec(crate = path::to::crate)]`
+///
+/// Fields can have the following attributes:
+///
+/// * `#[codec(skip)]`
+/// * `#[codec(compact)]`
+/// * `#[codec(encoded_as = "$EncodeAs")]` with `$EncodeAs` a valid `TokenStream`
+///
+/// Variants can have the following attributes:
+///
+/// * `#[codec(skip)]`
+/// * `#[codec(index = $int)]`
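+///
+/// Combined illustrative sketch (hypothetical types) exercising several of the accepted
+/// attributes:
+///
+/// ```ignore
+/// #[derive(Encode, Decode)]
+/// #[codec(crate = my_deps::parity_scale_codec)]
+/// enum Call {
+///     #[codec(index = 7)]
+///     Transfer {
+///         #[codec(compact)]
+///         amount: u128,
+///         #[codec(skip)]
+///         note: Option<u32>,
+///     },
+/// }
+/// ```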
+pub fn check_attributes(input: &DeriveInput) -> syn::Result<()> {
+       for attr in &input.attrs {
+               check_top_attribute(attr)?;
+       }
+
+       match input.data {
+               Data::Struct(ref data) => match &data.fields {
+                       | Fields::Named(FieldsNamed { named: fields, .. }) |
+                       Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) =>
+                               for field in fields {
+                                       for attr in &field.attrs {
+                                               check_field_attribute(attr)?;
+                                       }
+                               },
+                       Fields::Unit => (),
+               },
+               Data::Enum(ref data) =>
+                       for variant in data.variants.iter() {
+                               for attr in &variant.attrs {
+                                       check_variant_attribute(attr)?;
+                               }
+                               for field in &variant.fields {
+                                       for attr in &field.attrs {
+                                               check_field_attribute(attr)?;
+                                       }
+                               }
+                       },
+               Data::Union(_) => (),
+       }
+       Ok(())
+}
+
+// Check if the attribute is `#[allow(..)]`, `#[deny(..)]`, `#[forbid(..)]` or `#[warn(..)]`.
+pub fn is_lint_attribute(attr: &Attribute) -> bool {
+       attr.path.is_ident("allow") ||
+               attr.path.is_ident("deny") ||
+               attr.path.is_ident("forbid") ||
+               attr.path.is_ident("warn")
+}
+
+// Ensure a field is decorated only with the following attributes:
+// * `#[codec(skip)]`
+// * `#[codec(compact)]`
+// * `#[codec(encoded_as = "$EncodeAs")]` with `$EncodeAs` a valid TokenStream
+fn check_field_attribute(attr: &Attribute) -> syn::Result<()> {
+       let field_error = "Invalid attribute on field, only `#[codec(skip)]`, `#[codec(compact)]` and \
+               `#[codec(encoded_as = \"$EncodeAs\")]` are accepted.";
+
+       if attr.path.is_ident("codec") {
+               match attr.parse_meta()? {
+                       Meta::List(ref meta_list) if meta_list.nested.len() == 1 => {
+                               match meta_list.nested.first().expect("Just checked that there is one item; qed") {
+                                       NestedMeta::Meta(Meta::Path(path))
+                                               if path.get_ident().map_or(false, |i| i == "skip") =>
+                                               Ok(()),
+
+                                       NestedMeta::Meta(Meta::Path(path))
+                                               if path.get_ident().map_or(false, |i| i == "compact") =>
+                                               Ok(()),
+
+                                       NestedMeta::Meta(Meta::NameValue(MetaNameValue {
+                                               path,
+                                               lit: Lit::Str(lit_str),
+                                               ..
+                                       })) if path.get_ident().map_or(false, |i| i == "encoded_as") =>
+                                               TokenStream::from_str(&lit_str.value())
+                                                       .map(|_| ())
+                                                       .map_err(|_e| syn::Error::new(lit_str.span(), "Invalid token stream")),
+
+                                       elt @ _ => Err(syn::Error::new(elt.span(), field_error)),
+                               }
+                       },
+                       meta @ _ => Err(syn::Error::new(meta.span(), field_error)),
+               }
+       } else {
+               Ok(())
+       }
+}
+
+// Ensure a variant is decorated only with the following attributes:
+// * `#[codec(skip)]`
+// * `#[codec(index = $int)]`
+fn check_variant_attribute(attr: &Attribute) -> syn::Result<()> {
+       let variant_error = "Invalid attribute on variant, only `#[codec(skip)]` and \
+               `#[codec(index = $u8)]` are accepted.";
+
+       if attr.path.is_ident("codec") {
+               match attr.parse_meta()? {
+                       Meta::List(ref meta_list) if meta_list.nested.len() == 1 => {
+                               match meta_list.nested.first().expect("Just checked that there is one item; qed") {
+                                       NestedMeta::Meta(Meta::Path(path))
+                                               if path.get_ident().map_or(false, |i| i == "skip") =>
+                                               Ok(()),
+
+                                       NestedMeta::Meta(Meta::NameValue(MetaNameValue {
+                                               path,
+                                               lit: Lit::Int(lit_int),
+                                               ..
+                                       })) if path.get_ident().map_or(false, |i| i == "index") => lit_int
+                                               .base10_parse::<u8>()
+                                               .map(|_| ())
+                                               .map_err(|_| syn::Error::new(lit_int.span(), "Index must be in 0..=255")),
+
+                                       elt @ _ => Err(syn::Error::new(elt.span(), variant_error)),
+                               }
+                       },
+                       meta @ _ => Err(syn::Error::new(meta.span(), variant_error)),
+               }
+       } else {
+               Ok(())
+       }
+}
+
+// Ensure a top-level attribute is one of the accepted forms. The custom-bound and crate-path
+// forms are handled via `parse_args` below; beyond those, only `#[codec(dumb_trait_bound)]`
+// is accepted as a top attribute.
+fn check_top_attribute(attr: &Attribute) -> syn::Result<()> {
+       let top_error = "Invalid attribute: only `#[codec(dumb_trait_bound)]`, \
+               `#[codec(crate = path::to::crate)]`, `#[codec(encode_bound(T: Encode))]`, \
+               `#[codec(decode_bound(T: Decode))]`, or `#[codec(mel_bound(T: MaxEncodedLen))]` \
+               are accepted as top attribute";
+       if attr.path.is_ident("codec") &&
+               attr.parse_args::<CustomTraitBound<encode_bound>>().is_err() &&
+               attr.parse_args::<CustomTraitBound<decode_bound>>().is_err() &&
+               attr.parse_args::<CustomTraitBound<mel_bound>>().is_err() &&
+               codec_crate_path_inner(attr).is_none()
+       {
+               match attr.parse_meta()? {
+                       Meta::List(ref meta_list) if meta_list.nested.len() == 1 => {
+                               match meta_list.nested.first().expect("Just checked that there is one item; qed") {
+                                       NestedMeta::Meta(Meta::Path(path))
+                                               if path.get_ident().map_or(false, |i| i == "dumb_trait_bound") =>
+                                               Ok(()),
+
+                                       elt @ _ => Err(syn::Error::new(elt.span(), top_error)),
+                               }
+                       },
+                       _ => Err(syn::Error::new(attr.span(), top_error)),
+               }
+       } else {
+               Ok(())
+       }
+}