From be421160fd5b14938d9036b9a99469efd7ab6ce2 Mon Sep 17 00:00:00 2001 From: DongHun Kwak Date: Thu, 23 Mar 2023 15:07:00 +0900 Subject: [PATCH] Import parity-scale-codec-derive 3.1.4 --- .cargo_vcs_info.json | 6 + Cargo.toml | 43 ++++ Cargo.toml.orig | 26 +++ src/decode.rs | 206 ++++++++++++++++++++ src/encode.rs | 329 +++++++++++++++++++++++++++++++ src/lib.rs | 361 ++++++++++++++++++++++++++++++++++ src/max_encoded_len.rs | 141 ++++++++++++++ src/trait_bounds.rs | 248 +++++++++++++++++++++++ src/utils.rs | 433 +++++++++++++++++++++++++++++++++++++++++ 9 files changed, 1793 insertions(+) create mode 100644 .cargo_vcs_info.json create mode 100644 Cargo.toml create mode 100644 Cargo.toml.orig create mode 100644 src/decode.rs create mode 100644 src/encode.rs create mode 100644 src/lib.rs create mode 100644 src/max_encoded_len.rs create mode 100644 src/trait_bounds.rs create mode 100644 src/utils.rs diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json new file mode 100644 index 0000000..5577d12 --- /dev/null +++ b/.cargo_vcs_info.json @@ -0,0 +1,6 @@ +{ + "git": { + "sha1": "ef9f79438e9e51912b4445c6290de142b5d91e97" + }, + "path_in_vcs": "derive" +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..0063905 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,43 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2021" +rust-version = "1.56.1" +name = "parity-scale-codec-derive" +version = "3.1.4" +authors = ["Parity Technologies "] +description = "Serialization and deserialization derive macro for Parity SCALE Codec" +license = "Apache-2.0" + +[lib] +proc-macro = true + +[dependencies.proc-macro-crate] +version = "1.0.0" + +[dependencies.proc-macro2] +version = "1.0.6" + +[dependencies.quote] +version = "1.0.2" + +[dependencies.syn] +version = "1.0.98" +features = [ + "full", + "visit", +] + +[dev-dependencies] + +[features] +max-encoded-len = [] diff --git a/Cargo.toml.orig b/Cargo.toml.orig new file mode 100644 index 0000000..6fb552b --- /dev/null +++ b/Cargo.toml.orig @@ -0,0 +1,26 @@ +[package] +name = "parity-scale-codec-derive" +description = "Serialization and deserialization derive macro for Parity SCALE Codec" +version = "3.1.4" +authors = ["Parity Technologies "] +license = "Apache-2.0" +edition = "2021" +rust-version = "1.56.1" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "1.0.98", features = ["full", "visit"] } +quote = "1.0.2" +proc-macro2 = "1.0.6" +proc-macro-crate = "1.0.0" + +[dev-dependencies] +parity-scale-codec = { path = "..", features = ["max-encoded-len"] } + +[features] +# Enables the new `MaxEncodedLen` trait. +# NOTE: This is still considered experimental and is exempt from the usual +# SemVer guarantees. We do not guarantee no code breakage when using this. +max-encoded-len = [] diff --git a/src/decode.rs b/src/decode.rs new file mode 100644 index 0000000..df60390 --- /dev/null +++ b/src/decode.rs @@ -0,0 +1,206 @@ +// Copyright 2017, 2018 Parity Technologies +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use proc_macro2::{Span, TokenStream, Ident}; +use syn::{ + spanned::Spanned, + Data, Fields, Field, Error, +}; + +use crate::utils; + +/// Generate function block for function `Decode::decode`. +/// +/// * data: data info of the type, +/// * type_name: name of the type, +/// * type_generics: the generics of the type in turbofish format, without bounds, e.g. `::` +/// * input: the variable name for the argument of function `decode`. +pub fn quote( + data: &Data, + type_name: &Ident, + type_generics: &TokenStream, + input: &TokenStream, + crate_path: &syn::Path, +) -> TokenStream { + match *data { + Data::Struct(ref data) => match data.fields { + Fields::Named(_) | Fields::Unnamed(_) => create_instance( + quote! { #type_name #type_generics }, + &type_name.to_string(), + input, + &data.fields, + crate_path, + ), + Fields::Unit => { + quote_spanned! { data.fields.span() => + ::core::result::Result::Ok(#type_name) + } + }, + }, + Data::Enum(ref data) => { + let data_variants = || data.variants.iter().filter(|variant| !utils::should_skip(&variant.attrs)); + + if data_variants().count() > 256 { + return Error::new( + data.variants.span(), + "Currently only enums with at most 256 variants are encodable." + ).to_compile_error(); + } + + let recurse = data_variants().enumerate().map(|(i, v)| { + let name = &v.ident; + let index = utils::variant_index(v, i); + + let create = create_instance( + quote! { #type_name #type_generics :: #name }, + &format!("{}::{}", type_name, name), + input, + &v.fields, + crate_path, + ); + + quote_spanned! { v.span() => + __codec_x_edqy if __codec_x_edqy == #index as ::core::primitive::u8 => { + #create + }, + } + }); + + let read_byte_err_msg = format!( + "Could not decode `{}`, failed to read variant byte", + type_name, + ); + let invalid_variant_err_msg = format!( + "Could not decode `{}`, variant doesn't exist", + type_name, + ); + quote! { + match #input.read_byte() + .map_err(|e| e.chain(#read_byte_err_msg))? + { + #( #recurse )* + _ => ::core::result::Result::Err( + <_ as ::core::convert::Into<_>>::into(#invalid_variant_err_msg) + ), + } + } + + }, + Data::Union(_) => Error::new(Span::call_site(), "Union types are not supported.").to_compile_error(), + } +} + +fn create_decode_expr(field: &Field, name: &str, input: &TokenStream, crate_path: &syn::Path) -> TokenStream { + let encoded_as = utils::get_encoded_as_type(field); + let compact = utils::is_compact(field); + let skip = utils::should_skip(&field.attrs); + + let res = quote!(__codec_res_edqy); + + if encoded_as.is_some() as u8 + compact as u8 + skip as u8 > 1 { + return Error::new( + field.span(), + "`encoded_as`, `compact` and `skip` can only be used one at a time!" + ).to_compile_error(); + } + + let err_msg = format!("Could not decode `{}`", name); + + if compact { + let field_type = &field.ty; + quote_spanned! { field.span() => + { + let #res = < + <#field_type as #crate_path::HasCompact>::Type as #crate_path::Decode + >::decode(#input); + match #res { + ::core::result::Result::Err(e) => return ::core::result::Result::Err(e.chain(#err_msg)), + ::core::result::Result::Ok(#res) => #res.into(), + } + } + } + } else if let Some(encoded_as) = encoded_as { + quote_spanned! { field.span() => + { + let #res = <#encoded_as as #crate_path::Decode>::decode(#input); + match #res { + ::core::result::Result::Err(e) => return ::core::result::Result::Err(e.chain(#err_msg)), + ::core::result::Result::Ok(#res) => #res.into(), + } + } + } + } else if skip { + quote_spanned! { field.span() => ::core::default::Default::default() } + } else { + let field_type = &field.ty; + quote_spanned! { field.span() => + { + let #res = <#field_type as #crate_path::Decode>::decode(#input); + match #res { + ::core::result::Result::Err(e) => return ::core::result::Result::Err(e.chain(#err_msg)), + ::core::result::Result::Ok(#res) => #res, + } + } + } + } +} + +fn create_instance( + name: TokenStream, + name_str: &str, + input: &TokenStream, + fields: &Fields, + crate_path: &syn::Path, +) -> TokenStream { + match *fields { + Fields::Named(ref fields) => { + let recurse = fields.named.iter().map(|f| { + let name_ident = &f.ident; + let field_name = match name_ident { + Some(a) => format!("{}::{}", name_str, a), + None => format!("{}", name_str), // Should never happen, fields are named. + }; + let decode = create_decode_expr(f, &field_name, input, crate_path); + + quote_spanned! { f.span() => + #name_ident: #decode + } + }); + + quote_spanned! { fields.span() => + ::core::result::Result::Ok(#name { + #( #recurse, )* + }) + } + }, + Fields::Unnamed(ref fields) => { + let recurse = fields.unnamed.iter().enumerate().map(|(i, f) | { + let field_name = format!("{}.{}", name_str, i); + + create_decode_expr(f, &field_name, input, crate_path) + }); + + quote_spanned! { fields.span() => + ::core::result::Result::Ok(#name ( + #( #recurse, )* + )) + } + }, + Fields::Unit => { + quote_spanned! { fields.span() => + ::core::result::Result::Ok(#name) + } + }, + } +} diff --git a/src/encode.rs b/src/encode.rs new file mode 100644 index 0000000..b2badc1 --- /dev/null +++ b/src/encode.rs @@ -0,0 +1,329 @@ +// Copyright 2017, 2018 Parity Technologies +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::str::from_utf8; + +use proc_macro2::{Ident, Span, TokenStream}; +use syn::{ + punctuated::Punctuated, + spanned::Spanned, + token::Comma, + Data, Field, Fields, Error, +}; + +use crate::utils; + +type FieldsList = Punctuated; + +// Encode a signle field by using using_encoded, must not have skip attribute +fn encode_single_field( + field: &Field, + field_name: TokenStream, + crate_path: &syn::Path, +) -> TokenStream { + let encoded_as = utils::get_encoded_as_type(field); + let compact = utils::is_compact(field); + + if utils::should_skip(&field.attrs) { + return Error::new( + Span::call_site(), + "Internal error: cannot encode single field optimisation if skipped" + ).to_compile_error(); + } + + if encoded_as.is_some() && compact { + return Error::new( + Span::call_site(), + "`encoded_as` and `compact` can not be used at the same time!" + ).to_compile_error(); + } + + let final_field_variable = if compact { + let field_type = &field.ty; + quote_spanned! { + field.span() => { + <<#field_type as #crate_path::HasCompact>::Type as + #crate_path::EncodeAsRef<'_, #field_type>>::RefType::from(#field_name) + } + } + } else if let Some(encoded_as) = encoded_as { + let field_type = &field.ty; + quote_spanned! { + field.span() => { + <#encoded_as as + #crate_path::EncodeAsRef<'_, #field_type>>::RefType::from(#field_name) + } + } + } else { + quote_spanned! { field.span() => + #field_name + } + }; + + // This may have different hygiene than the field span + let i_self = quote! { self }; + + quote_spanned! { field.span() => + fn encode_to<__CodecOutputEdqy: #crate_path::Output + ?::core::marker::Sized>( + &#i_self, + __codec_dest_edqy: &mut __CodecOutputEdqy + ) { + #crate_path::Encode::encode_to(&#final_field_variable, __codec_dest_edqy) + } + + fn encode(&#i_self) -> #crate_path::alloc::vec::Vec<::core::primitive::u8> { + #crate_path::Encode::encode(&#final_field_variable) + } + + fn using_encoded R>(&#i_self, f: F) -> R { + #crate_path::Encode::using_encoded(&#final_field_variable, f) + } + } +} + +fn encode_fields( + dest: &TokenStream, + fields: &FieldsList, + field_name: F, + crate_path: &syn::Path, +) -> TokenStream where + F: Fn(usize, &Option) -> TokenStream, +{ + let recurse = fields.iter().enumerate().map(|(i, f)| { + let field = field_name(i, &f.ident); + let encoded_as = utils::get_encoded_as_type(f); + let compact = utils::is_compact(f); + let skip = utils::should_skip(&f.attrs); + + if encoded_as.is_some() as u8 + compact as u8 + skip as u8 > 1 { + return Error::new( + f.span(), + "`encoded_as`, `compact` and `skip` can only be used one at a time!" + ).to_compile_error(); + } + + // Based on the seen attribute, we generate the code that encodes the field. + // We call `push` from the `Output` trait on `dest`. + if compact { + let field_type = &f.ty; + quote_spanned! { + f.span() => { + #crate_path::Encode::encode_to( + &< + <#field_type as #crate_path::HasCompact>::Type as + #crate_path::EncodeAsRef<'_, #field_type> + >::RefType::from(#field), + #dest, + ); + } + } + } else if let Some(encoded_as) = encoded_as { + let field_type = &f.ty; + quote_spanned! { + f.span() => { + #crate_path::Encode::encode_to( + &< + #encoded_as as + #crate_path::EncodeAsRef<'_, #field_type> + >::RefType::from(#field), + #dest, + ); + } + } + } else if skip { + quote! { + let _ = #field; + } + } else { + quote_spanned! { f.span() => + #crate_path::Encode::encode_to(#field, #dest); + } + } + }); + + quote! { + #( #recurse )* + } +} + +fn try_impl_encode_single_field_optimisation(data: &Data, crate_path: &syn::Path) -> Option { + match *data { + Data::Struct(ref data) => { + match data.fields { + Fields::Named(ref fields) if utils::filter_skip_named(fields).count() == 1 => { + let field = utils::filter_skip_named(fields).next().unwrap(); + let name = &field.ident; + Some(encode_single_field( + field, + quote!(&self.#name), + crate_path, + )) + }, + Fields::Unnamed(ref fields) if utils::filter_skip_unnamed(fields).count() == 1 => { + let (id, field) = utils::filter_skip_unnamed(fields).next().unwrap(); + let id = syn::Index::from(id); + + Some(encode_single_field( + field, + quote!(&self.#id), + crate_path, + )) + }, + _ => None, + } + }, + _ => None, + } +} + +fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenStream { + let self_ = quote!(self); + let dest = "e!(__codec_dest_edqy); + let encoding = match *data { + Data::Struct(ref data) => { + match data.fields { + Fields::Named(ref fields) => encode_fields( + dest, + &fields.named, + |_, name| quote!(&#self_.#name), + crate_path, + ), + Fields::Unnamed(ref fields) => encode_fields( + dest, + &fields.unnamed, + |i, _| { + let i = syn::Index::from(i); + quote!(&#self_.#i) + }, + crate_path, + ), + Fields::Unit => quote!(), + } + }, + Data::Enum(ref data) => { + let data_variants = || data.variants.iter().filter(|variant| !utils::should_skip(&variant.attrs)); + + if data_variants().count() > 256 { + return Error::new( + data.variants.span(), + "Currently only enums with at most 256 variants are encodable." + ).to_compile_error(); + } + + // If the enum has no variants, we don't need to encode anything. + if data_variants().count() == 0 { + return quote!(); + } + + let recurse = data_variants().enumerate().map(|(i, f)| { + let name = &f.ident; + let index = utils::variant_index(f, i); + + match f.fields { + Fields::Named(ref fields) => { + let field_name = |_, ident: &Option| quote!(#ident); + let names = fields.named + .iter() + .enumerate() + .map(|(i, f)| field_name(i, &f.ident)); + + let encode_fields = encode_fields( + dest, + &fields.named, + |a, b| field_name(a, b), + crate_path, + ); + + quote_spanned! { f.span() => + #type_name :: #name { #( ref #names, )* } => { + #dest.push_byte(#index as ::core::primitive::u8); + #encode_fields + } + } + }, + Fields::Unnamed(ref fields) => { + let field_name = |i, _: &Option| { + let data = stringify(i as u8); + let ident = from_utf8(&data).expect("We never go beyond ASCII"); + let ident = Ident::new(ident, Span::call_site()); + quote!(#ident) + }; + let names = fields.unnamed + .iter() + .enumerate() + .map(|(i, f)| field_name(i, &f.ident)); + + let encode_fields = encode_fields( + dest, + &fields.unnamed, + |a, b| field_name(a, b), + crate_path, + ); + + quote_spanned! { f.span() => + #type_name :: #name ( #( ref #names, )* ) => { + #dest.push_byte(#index as ::core::primitive::u8); + #encode_fields + } + } + }, + Fields::Unit => { + quote_spanned! { f.span() => + #type_name :: #name => { + #dest.push_byte(#index as ::core::primitive::u8); + } + } + }, + } + }); + + quote! { + match *#self_ { + #( #recurse )*, + _ => (), + } + } + }, + Data::Union(ref data) => Error::new( + data.union_token.span(), + "Union types are not supported." + ).to_compile_error(), + }; + quote! { + fn encode_to<__CodecOutputEdqy: #crate_path::Output + ?::core::marker::Sized>( + &#self_, + #dest: &mut __CodecOutputEdqy + ) { + #encoding + } + } +} + +pub fn quote(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenStream { + if let Some(implementation) = try_impl_encode_single_field_optimisation(data, crate_path) { + implementation + } else { + impl_encode(data, type_name, crate_path) + } +} + +pub fn stringify(id: u8) -> [u8; 2] { + const CHARS: &[u8] = b"abcdefghijklmnopqrstuvwxyz"; + let len = CHARS.len() as u8; + let symbol = |id: u8| CHARS[(id % len) as usize]; + let a = symbol(id); + let b = symbol(id / len); + + [a, b] +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..9d80304 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,361 @@ +// Copyright 2017-2021 Parity Technologies +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Derives serialization and deserialization codec for complex structs for simple marshalling. + +#![recursion_limit = "128"] +extern crate proc_macro; + +#[macro_use] +extern crate syn; + +#[macro_use] +extern crate quote; + +use crate::utils::{codec_crate_path, is_lint_attribute}; +use syn::{spanned::Spanned, Data, DeriveInput, Error, Field, Fields}; + +mod decode; +mod encode; +mod max_encoded_len; +mod trait_bounds; +mod utils; + +/// Wraps the impl block in a "dummy const" +fn wrap_with_dummy_const( + input: DeriveInput, + impl_block: proc_macro2::TokenStream, +) -> proc_macro::TokenStream { + let attrs = input.attrs.into_iter().filter(is_lint_attribute); + let generated = quote! { + #[allow(deprecated)] + const _: () = { + #(#attrs)* + #impl_block + }; + }; + + generated.into() +} + +/// Derive `parity_scale_codec::Encode` and `parity_scale_codec::EncodeLike` for struct and enum. +/// +/// # Top level attributes +/// +/// By default the macro will add [`Encode`] and [`Decode`] bounds to all types, but the bounds can +/// be specified manually with the top level attributes: +/// * `#[codec(encode_bound(T: Encode))]`: a custom bound added to the `where`-clause when deriving +/// the `Encode` trait, overriding the default. +/// * `#[codec(decode_bound(T: Decode))]`: a custom bound added to the `where`-clause when deriving +/// the `Decode` trait, overriding the default. +/// +/// # Struct +/// +/// A struct is encoded by encoding each of its fields successively. +/// +/// Fields can have some attributes: +/// * `#[codec(skip)]`: the field is not encoded. It must derive `Default` if Decode is derived. +/// * `#[codec(compact)]`: the field is encoded in its compact representation i.e. the field must +/// implement `parity_scale_codec::HasCompact` and will be encoded as `HasCompact::Type`. +/// * `#[codec(encoded_as = "$EncodeAs")]`: the field is encoded as an alternative type. $EncodedAs +/// type must implement `parity_scale_codec::EncodeAsRef<'_, $FieldType>` with $FieldType the type +/// of the field with the attribute. This is intended to be used for types implementing +/// `HasCompact` as shown in the example. +/// +/// ``` +/// # use parity_scale_codec_derive::Encode; +/// # use parity_scale_codec::{Encode as _, HasCompact}; +/// #[derive(Encode)] +/// struct StructType { +/// #[codec(skip)] +/// a: u32, +/// #[codec(compact)] +/// b: u32, +/// #[codec(encoded_as = "::Type")] +/// c: u32, +/// } +/// ``` +/// +/// # Enum +/// +/// The variable is encoded with one byte for the variant and then the variant struct encoding. +/// The variant number is: +/// * if variant has attribute: `#[codec(index = "$n")]` then n +/// * else if variant has discrimant (like 3 in `enum T { A = 3 }`) then the discrimant. +/// * else its position in the variant set, excluding skipped variants, but including variant with +/// discrimant or attribute. Warning this position does collision with discrimant or attribute +/// index. +/// +/// variant attributes: +/// * `#[codec(skip)]`: the variant is not encoded. +/// * `#[codec(index = "$n")]`: override variant index. +/// +/// field attributes: same as struct fields attributes. +/// +/// ``` +/// # use parity_scale_codec_derive::Encode; +/// # use parity_scale_codec::Encode as _; +/// #[derive(Encode)] +/// enum EnumType { +/// #[codec(index = 15)] +/// A, +/// #[codec(skip)] +/// B, +/// C = 3, +/// D, +/// } +/// +/// assert_eq!(EnumType::A.encode(), vec![15]); +/// assert_eq!(EnumType::B.encode(), vec![]); +/// assert_eq!(EnumType::C.encode(), vec![3]); +/// assert_eq!(EnumType::D.encode(), vec![2]); +/// ``` +#[proc_macro_derive(Encode, attributes(codec))] +pub fn encode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let mut input: DeriveInput = match syn::parse(input) { + Ok(input) => input, + Err(e) => return e.to_compile_error().into(), + }; + + if let Err(e) = utils::check_attributes(&input) { + return e.to_compile_error().into() + } + + let crate_path = match codec_crate_path(&input.attrs) { + Ok(crate_path) => crate_path, + Err(error) => return error.into_compile_error().into(), + }; + + if let Err(e) = trait_bounds::add( + &input.ident, + &mut input.generics, + &input.data, + utils::custom_encode_trait_bound(&input.attrs), + parse_quote!(#crate_path::Encode), + None, + utils::has_dumb_trait_bound(&input.attrs), + &crate_path, + ) { + return e.to_compile_error().into() + } + + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + + let encode_impl = encode::quote(&input.data, name, &crate_path); + + let impl_block = quote! { + #[automatically_derived] + impl #impl_generics #crate_path::Encode for #name #ty_generics #where_clause { + #encode_impl + } + + #[automatically_derived] + impl #impl_generics #crate_path::EncodeLike for #name #ty_generics #where_clause {} + }; + + wrap_with_dummy_const(input, impl_block) +} + +/// Derive `parity_scale_codec::Decode` and for struct and enum. +/// +/// see derive `Encode` documentation. +#[proc_macro_derive(Decode, attributes(codec))] +pub fn decode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let mut input: DeriveInput = match syn::parse(input) { + Ok(input) => input, + Err(e) => return e.to_compile_error().into(), + }; + + if let Err(e) = utils::check_attributes(&input) { + return e.to_compile_error().into() + } + + let crate_path = match codec_crate_path(&input.attrs) { + Ok(crate_path) => crate_path, + Err(error) => return error.into_compile_error().into(), + }; + + if let Err(e) = trait_bounds::add( + &input.ident, + &mut input.generics, + &input.data, + utils::custom_decode_trait_bound(&input.attrs), + parse_quote!(#crate_path::Decode), + Some(parse_quote!(Default)), + utils::has_dumb_trait_bound(&input.attrs), + &crate_path, + ) { + return e.to_compile_error().into() + } + + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let ty_gen_turbofish = ty_generics.as_turbofish(); + + let input_ = quote!(__codec_input_edqy); + let decoding = + decode::quote(&input.data, name, "e!(#ty_gen_turbofish), &input_, &crate_path); + + let impl_block = quote! { + #[automatically_derived] + impl #impl_generics #crate_path::Decode for #name #ty_generics #where_clause { + fn decode<__CodecInputEdqy: #crate_path::Input>( + #input_: &mut __CodecInputEdqy + ) -> ::core::result::Result { + #decoding + } + } + }; + + wrap_with_dummy_const(input, impl_block) +} + +/// Derive `parity_scale_codec::Compact` and `parity_scale_codec::CompactAs` for struct with single +/// field. +/// +/// Attribute skip can be used to skip other fields. +/// +/// # Example +/// +/// ``` +/// # use parity_scale_codec_derive::CompactAs; +/// # use parity_scale_codec::{Encode, HasCompact}; +/// # use std::marker::PhantomData; +/// #[derive(CompactAs)] +/// struct MyWrapper(u32, #[codec(skip)] PhantomData); +/// ``` +#[proc_macro_derive(CompactAs, attributes(codec))] +pub fn compact_as_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let mut input: DeriveInput = match syn::parse(input) { + Ok(input) => input, + Err(e) => return e.to_compile_error().into(), + }; + + if let Err(e) = utils::check_attributes(&input) { + return e.to_compile_error().into() + } + + let crate_path = match codec_crate_path(&input.attrs) { + Ok(crate_path) => crate_path, + Err(error) => return error.into_compile_error().into(), + }; + + if let Err(e) = trait_bounds::add::<()>( + &input.ident, + &mut input.generics, + &input.data, + None, + parse_quote!(#crate_path::CompactAs), + None, + utils::has_dumb_trait_bound(&input.attrs), + &crate_path, + ) { + return e.to_compile_error().into() + } + + let name = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + + fn val_or_default(field: &Field) -> proc_macro2::TokenStream { + let skip = utils::should_skip(&field.attrs); + if skip { + quote_spanned!(field.span()=> Default::default()) + } else { + quote_spanned!(field.span()=> x) + } + } + + let (inner_ty, inner_field, constructor) = match input.data { + Data::Struct(ref data) => match data.fields { + Fields::Named(ref fields) if utils::filter_skip_named(fields).count() == 1 => { + let recurse = fields.named.iter().map(|f| { + let name_ident = &f.ident; + let val_or_default = val_or_default(&f); + quote_spanned!(f.span()=> #name_ident: #val_or_default) + }); + let field = utils::filter_skip_named(fields).next().expect("Exactly one field"); + let field_name = &field.ident; + let constructor = quote!( #name { #( #recurse, )* }); + (&field.ty, quote!(&self.#field_name), constructor) + }, + Fields::Unnamed(ref fields) if utils::filter_skip_unnamed(fields).count() == 1 => { + let recurse = fields.unnamed.iter().enumerate().map(|(_, f)| { + let val_or_default = val_or_default(&f); + quote_spanned!(f.span()=> #val_or_default) + }); + let (id, field) = + utils::filter_skip_unnamed(fields).next().expect("Exactly one field"); + let id = syn::Index::from(id); + let constructor = quote!( #name(#( #recurse, )*)); + (&field.ty, quote!(&self.#id), constructor) + }, + _ => + return Error::new( + data.fields.span(), + "Only structs with a single non-skipped field can derive CompactAs", + ) + .to_compile_error() + .into(), + }, + Data::Enum(syn::DataEnum { enum_token: syn::token::Enum { span }, .. }) | + Data::Union(syn::DataUnion { union_token: syn::token::Union { span }, .. }) => + return Error::new(span, "Only structs can derive CompactAs").to_compile_error().into(), + }; + + let impl_block = quote! { + #[automatically_derived] + impl #impl_generics #crate_path::CompactAs for #name #ty_generics #where_clause { + type As = #inner_ty; + fn encode_as(&self) -> &#inner_ty { + #inner_field + } + fn decode_from(x: #inner_ty) + -> ::core::result::Result<#name #ty_generics, #crate_path::Error> + { + ::core::result::Result::Ok(#constructor) + } + } + + #[automatically_derived] + impl #impl_generics From<#crate_path::Compact<#name #ty_generics>> + for #name #ty_generics #where_clause + { + fn from(x: #crate_path::Compact<#name #ty_generics>) -> #name #ty_generics { + x.0 + } + } + }; + + wrap_with_dummy_const(input, impl_block) +} + +/// Derive `parity_scale_codec::MaxEncodedLen` for struct and enum. +/// +/// # Top level attribute +/// +/// By default the macro will try to bound the types needed to implement `MaxEncodedLen`, but the +/// bounds can be specified manually with the top level attribute: +/// ``` +/// # use parity_scale_codec_derive::Encode; +/// # use parity_scale_codec::MaxEncodedLen; +/// # #[derive(Encode, MaxEncodedLen)] +/// #[codec(mel_bound(T: MaxEncodedLen))] +/// # struct MyWrapper(T); +/// ``` +#[cfg(feature = "max-encoded-len")] +#[proc_macro_derive(MaxEncodedLen, attributes(max_encoded_len_mod))] +pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + max_encoded_len::derive_max_encoded_len(input) +} diff --git a/src/max_encoded_len.rs b/src/max_encoded_len.rs new file mode 100644 index 0000000..bf8b20f --- /dev/null +++ b/src/max_encoded_len.rs @@ -0,0 +1,141 @@ +// Copyright (C) 2021 Parity Technologies (UK) Ltd. +// SPDX-License-Identifier: Apache-2.0 + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![cfg(feature = "max-encoded-len")] + +use crate::{ + trait_bounds, + utils::{codec_crate_path, custom_mel_trait_bound, has_dumb_trait_bound, should_skip}, +}; +use quote::{quote, quote_spanned}; +use syn::{parse_quote, spanned::Spanned, Data, DeriveInput, Fields, Type}; + +/// impl for `#[derive(MaxEncodedLen)]` +pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let mut input: DeriveInput = match syn::parse(input) { + Ok(input) => input, + Err(e) => return e.to_compile_error().into(), + }; + + let crate_path = match codec_crate_path(&input.attrs) { + Ok(crate_path) => crate_path, + Err(error) => return error.into_compile_error().into(), + }; + + let name = &input.ident; + if let Err(e) = trait_bounds::add( + &input.ident, + &mut input.generics, + &input.data, + custom_mel_trait_bound(&input.attrs), + parse_quote!(#crate_path::MaxEncodedLen), + None, + has_dumb_trait_bound(&input.attrs), + &crate_path + ) { + return e.to_compile_error().into() + } + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + + let data_expr = data_length_expr(&input.data); + + quote::quote!( + const _: () = { + impl #impl_generics #crate_path::MaxEncodedLen for #name #ty_generics #where_clause { + fn max_encoded_len() -> ::core::primitive::usize { + #data_expr + } + } + }; + ) + .into() +} + +/// generate an expression to sum up the max encoded length from several fields +fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream { + let type_iter: Box> = match fields { + Fields::Named(ref fields) => Box::new( + fields.named.iter().filter_map(|field| if should_skip(&field.attrs) { + None + } else { + Some(&field.ty) + }) + ), + Fields::Unnamed(ref fields) => Box::new( + fields.unnamed.iter().filter_map(|field| if should_skip(&field.attrs) { + None + } else { + Some(&field.ty) + }) + ), + Fields::Unit => Box::new(std::iter::empty()), + }; + // expands to an expression like + // + // 0 + // .saturating_add(::max_encoded_len()) + // .saturating_add(::max_encoded_len()) + // + // We match the span of each field to the span of the corresponding + // `max_encoded_len` call. This way, if one field's type doesn't implement + // `MaxEncodedLen`, the compiler's error message will underline which field + // caused the issue. + let expansion = type_iter.map(|ty| { + quote_spanned! { + ty.span() => .saturating_add(<#ty>::max_encoded_len()) + } + }); + quote! { + 0_usize #( #expansion )* + } +} + +// generate an expression to sum up the max encoded length of each field +fn data_length_expr(data: &Data) -> proc_macro2::TokenStream { + match *data { + Data::Struct(ref data) => fields_length_expr(&data.fields), + Data::Enum(ref data) => { + // We need an expression expanded for each variant like + // + // 0 + // .max() + // .max() + // .saturating_add(1) + // + // The 1 derives from the discriminant; see + // https://github.com/paritytech/parity-scale-codec/ + // blob/f0341dabb01aa9ff0548558abb6dcc5c31c669a1/derive/src/encode.rs#L211-L216 + // + // Each variant expression's sum is computed the way an equivalent struct's would be. + + let expansion = data.variants.iter().map(|variant| { + let variant_expression = fields_length_expr(&variant.fields); + quote! { + .max(#variant_expression) + } + }); + + quote! { + 0_usize #( #expansion )* .saturating_add(1) + } + }, + Data::Union(ref data) => { + // https://github.com/paritytech/parity-scale-codec/ + // blob/f0341dabb01aa9ff0548558abb6dcc5c31c669a1/derive/src/encode.rs#L290-L293 + syn::Error::new(data.union_token.span(), "Union types are not supported.") + .to_compile_error() + }, + } +} diff --git a/src/trait_bounds.rs b/src/trait_bounds.rs new file mode 100644 index 0000000..f808588 --- /dev/null +++ b/src/trait_bounds.rs @@ -0,0 +1,248 @@ +// Copyright 2019 Parity Technologies +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::iter; + +use proc_macro2::Ident; +use syn::{ + spanned::Spanned, + visit::{self, Visit}, + Generics, Result, Type, TypePath, +}; + +use crate::utils::{self, CustomTraitBound}; + +/// Visits the ast and checks if one of the given idents is found. +struct ContainIdents<'a> { + result: bool, + idents: &'a [Ident], +} + +impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> { + fn visit_ident(&mut self, i: &'ast Ident) { + if self.idents.iter().any(|id| id == i) { + self.result = true; + } + } +} + +/// Checks if the given type contains one of the given idents. +fn type_contain_idents(ty: &Type, idents: &[Ident]) -> bool { + let mut visitor = ContainIdents { result: false, idents }; + visitor.visit_type(ty); + visitor.result +} + +/// Visits the ast and checks if the a type path starts with the given ident. +struct TypePathStartsWithIdent<'a> { + result: bool, + ident: &'a Ident, +} + +impl<'a, 'ast> Visit<'ast> for TypePathStartsWithIdent<'a> { + fn visit_type_path(&mut self, i: &'ast TypePath) { + if let Some(segment) = i.path.segments.first() { + if &segment.ident == self.ident { + self.result = true; + return + } + } + + visit::visit_type_path(self, i); + } +} + +/// Checks if the given type path or any containing type path starts with the given ident. +fn type_path_or_sub_starts_with_ident(ty: &TypePath, ident: &Ident) -> bool { + let mut visitor = TypePathStartsWithIdent { result: false, ident }; + visitor.visit_type_path(ty); + visitor.result +} + +/// Checks if the given type or any containing type path starts with the given ident. +fn type_or_sub_type_path_starts_with_ident(ty: &Type, ident: &Ident) -> bool { + let mut visitor = TypePathStartsWithIdent { result: false, ident }; + visitor.visit_type(ty); + visitor.result +} + +/// Visits the ast and collects all type paths that do not start or contain the given ident. +/// +/// Returns `T`, `N`, `A` for `Vec<(Recursive, A)>` with `Recursive` as ident. +struct FindTypePathsNotStartOrContainIdent<'a> { + result: Vec, + ident: &'a Ident, +} + +impl<'a, 'ast> Visit<'ast> for FindTypePathsNotStartOrContainIdent<'a> { + fn visit_type_path(&mut self, i: &'ast TypePath) { + if type_path_or_sub_starts_with_ident(i, &self.ident) { + visit::visit_type_path(self, i); + } else { + self.result.push(i.clone()); + } + } +} + +/// Collects all type paths that do not start or contain the given ident in the given type. +/// +/// Returns `T`, `N`, `A` for `Vec<(Recursive, A)>` with `Recursive` as ident. +fn find_type_paths_not_start_or_contain_ident(ty: &Type, ident: &Ident) -> Vec { + let mut visitor = FindTypePathsNotStartOrContainIdent { result: Vec::new(), ident }; + visitor.visit_type(ty); + visitor.result +} + +/// Add required trait bounds to all generic types. +pub fn add( + input_ident: &Ident, + generics: &mut Generics, + data: &syn::Data, + custom_trait_bound: Option>, + codec_bound: syn::Path, + codec_skip_bound: Option, + dumb_trait_bounds: bool, + crate_path: &syn::Path, +) -> Result<()> { + let skip_type_params = match custom_trait_bound { + Some(CustomTraitBound::SpecifiedBounds { bounds, .. }) => { + generics.make_where_clause().predicates.extend(bounds); + return Ok(()) + }, + Some(CustomTraitBound::SkipTypeParams { type_names, .. }) => + type_names.into_iter().collect::>(), + None => Vec::new(), + }; + + let ty_params = generics + .type_params() + .filter_map(|tp| { + skip_type_params.iter().all(|skip| skip != &tp.ident).then(|| tp.ident.clone()) + }) + .collect::>(); + if ty_params.is_empty() { + return Ok(()) + } + + let codec_types = + get_types_to_add_trait_bound(input_ident, data, &ty_params, dumb_trait_bounds)?; + + let compact_types = collect_types(&data, utils::is_compact)? + .into_iter() + // Only add a bound if the type uses a generic + .filter(|ty| type_contain_idents(ty, &ty_params)) + .collect::>(); + + let skip_types = if codec_skip_bound.is_some() { + let needs_default_bound = |f: &syn::Field| utils::should_skip(&f.attrs); + collect_types(&data, needs_default_bound)? + .into_iter() + // Only add a bound if the type uses a generic + .filter(|ty| type_contain_idents(ty, &ty_params)) + .collect::>() + } else { + Vec::new() + }; + + if !codec_types.is_empty() || !compact_types.is_empty() || !skip_types.is_empty() { + let where_clause = generics.make_where_clause(); + + codec_types + .into_iter() + .for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #codec_bound))); + + let has_compact_bound: syn::Path = parse_quote!(#crate_path::HasCompact); + compact_types + .into_iter() + .for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #has_compact_bound))); + + skip_types.into_iter().for_each(|ty| { + let codec_skip_bound = codec_skip_bound.as_ref(); + where_clause.predicates.push(parse_quote!(#ty : #codec_skip_bound)) + }); + } + + Ok(()) +} + +/// Returns all types that must be added to the where clause with the respective trait bound. +fn get_types_to_add_trait_bound( + input_ident: &Ident, + data: &syn::Data, + ty_params: &[Ident], + dumb_trait_bound: bool, +) -> Result> { + if dumb_trait_bound { + Ok(ty_params.iter().map(|t| parse_quote!( #t )).collect()) + } else { + let needs_codec_bound = |f: &syn::Field| { + !utils::is_compact(f) && + utils::get_encoded_as_type(f).is_none() && + !utils::should_skip(&f.attrs) + }; + let res = collect_types(&data, needs_codec_bound)? + .into_iter() + // Only add a bound if the type uses a generic + .filter(|ty| type_contain_idents(ty, &ty_params)) + // If a struct contains itself as field type, we can not add this type into the where + // clause. This is required to work a round the following compiler bug: https://github.com/rust-lang/rust/issues/47032 + .flat_map(|ty| { + find_type_paths_not_start_or_contain_ident(&ty, input_ident) + .into_iter() + .map(|ty| Type::Path(ty.clone())) + // Remove again types that do not contain any of our generic parameters + .filter(|ty| type_contain_idents(ty, &ty_params)) + // Add back the original type, as we don't want to loose it. + .chain(iter::once(ty)) + }) + // Remove all remaining types that start/contain the input ident to not have them in the + // where clause. + .filter(|ty| !type_or_sub_type_path_starts_with_ident(ty, input_ident)) + .collect(); + + Ok(res) + } +} + +fn collect_types(data: &syn::Data, type_filter: fn(&syn::Field) -> bool) -> Result> { + use syn::*; + + let types = match *data { + Data::Struct(ref data) => match &data.fields { + | Fields::Named(FieldsNamed { named: fields, .. }) | + Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => + fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect(), + + Fields::Unit => Vec::new(), + }, + + Data::Enum(ref data) => data + .variants + .iter() + .filter(|variant| !utils::should_skip(&variant.attrs)) + .flat_map(|variant| match &variant.fields { + | Fields::Named(FieldsNamed { named: fields, .. }) | + Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => + fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect(), + + Fields::Unit => Vec::new(), + }) + .collect(), + + Data::Union(ref data) => + return Err(Error::new(data.union_token.span(), "Union types are not supported.")), + }; + + Ok(types) +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..585ce9a --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,433 @@ +// Copyright 2018-2020 Parity Technologies +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Various internal utils. +//! +//! NOTE: attributes finder must be checked using check_attribute first, +//! otherwise the macro can panic. + +use std::str::FromStr; + +use proc_macro2::TokenStream; +use quote::quote; +use syn::{ + parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DeriveInput, + Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path, Variant, +}; + +fn find_meta_item<'a, F, R, I, M>(mut itr: I, mut pred: F) -> Option +where + F: FnMut(M) -> Option + Clone, + I: Iterator, + M: Parse, +{ + itr.find_map(|attr| { + attr.path.is_ident("codec").then(|| pred(attr.parse_args().ok()?)).flatten() + }) +} + +/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute +/// is found, fall back to the discriminant or just the variant index. +pub fn variant_index(v: &Variant, i: usize) -> TokenStream { + // first look for an attribute + let index = find_meta_item(v.attrs.iter(), |meta| { + if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta { + if nv.path.is_ident("index") { + if let Lit::Int(ref v) = nv.lit { + let byte = v + .base10_parse::() + .expect("Internal error, index attribute must have been checked"); + return Some(byte) + } + } + } + + None + }); + + // then fallback to discriminant or just index + index.map(|i| quote! { #i }).unwrap_or_else(|| { + v.discriminant + .as_ref() + .map(|&(_, ref expr)| quote! { #expr }) + .unwrap_or_else(|| quote! { #i }) + }) +} + +/// Look for a `#[codec(encoded_as = "SomeType")]` outer attribute on the given +/// `Field`. +pub fn get_encoded_as_type(field: &Field) -> Option { + find_meta_item(field.attrs.iter(), |meta| { + if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta { + if nv.path.is_ident("encoded_as") { + if let Lit::Str(ref s) = nv.lit { + return Some( + TokenStream::from_str(&s.value()) + .expect("Internal error, encoded_as attribute must have been checked"), + ) + } + } + } + + None + }) +} + +/// Look for a `#[codec(compact)]` outer attribute on the given `Field`. +pub fn is_compact(field: &Field) -> bool { + find_meta_item(field.attrs.iter(), |meta| { + if let NestedMeta::Meta(Meta::Path(ref path)) = meta { + if path.is_ident("compact") { + return Some(()) + } + } + + None + }) + .is_some() +} + +/// Look for a `#[codec(skip)]` in the given attributes. +pub fn should_skip(attrs: &[Attribute]) -> bool { + find_meta_item(attrs.iter(), |meta| { + if let NestedMeta::Meta(Meta::Path(ref path)) = meta { + if path.is_ident("skip") { + return Some(path.span()) + } + } + + None + }) + .is_some() +} + +/// Look for a `#[codec(dumb_trait_bound)]`in the given attributes. +pub fn has_dumb_trait_bound(attrs: &[Attribute]) -> bool { + find_meta_item(attrs.iter(), |meta| { + if let NestedMeta::Meta(Meta::Path(ref path)) = meta { + if path.is_ident("dumb_trait_bound") { + return Some(()) + } + } + + None + }) + .is_some() +} + +/// Generate the crate access for the crate using 2018 syntax. +fn crate_access() -> syn::Result { + use proc_macro2::{Ident, Span}; + use proc_macro_crate::{crate_name, FoundCrate}; + const DEF_CRATE: &str = "parity-scale-codec"; + match crate_name(DEF_CRATE) { + Ok(FoundCrate::Itself) => { + let name = DEF_CRATE.to_string().replace("-", "_"); + Ok(syn::Ident::new(&name, Span::call_site())) + }, + Ok(FoundCrate::Name(name)) => Ok(Ident::new(&name, Span::call_site())), + Err(e) => Err(syn::Error::new(Span::call_site(), e)), + } +} + +/// This struct matches `crate = ...` where the ellipsis is a `Path`. +struct CratePath { + _crate_token: Token![crate], + _eq_token: Token![=], + path: Path, +} + +impl Parse for CratePath { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(CratePath { + _crate_token: input.parse()?, + _eq_token: input.parse()?, + path: input.parse()?, + }) + } +} + +impl From for Path { + fn from(CratePath { path, .. }: CratePath) -> Self { + path + } +} + +/// Match `#[codec(crate = ...)]` and return the `...` if it is a `Path`. +fn codec_crate_path_inner(attr: &Attribute) -> Option { + // match `#[codec ...]` + attr.path + .is_ident("codec") + .then(|| { + // match `#[codec(crate = ...)]` and return the `...` + attr.parse_args::().map(Into::into).ok() + }) + .flatten() +} + +/// Match `#[codec(crate = ...)]` and return the ellipsis as a `Path`. +/// +/// If not found, returns the default crate access pattern. +/// +/// If multiple items match the pattern, all but the first are ignored. +pub fn codec_crate_path(attrs: &[Attribute]) -> syn::Result { + match attrs.iter().find_map(codec_crate_path_inner) { + Some(path) => Ok(path), + None => crate_access().map(|ident| parse_quote!(::#ident)), + } +} + +/// Parse `name(T: Bound, N: Bound)` or `name(skip_type_params(T, N))` as a custom trait bound. +pub enum CustomTraitBound { + SpecifiedBounds { + _name: N, + _paren_token: token::Paren, + bounds: Punctuated, + }, + SkipTypeParams { + _name: N, + _paren_token_1: token::Paren, + _skip_type_params: skip_type_params, + _paren_token_2: token::Paren, + type_names: Punctuated, + }, +} + +impl Parse for CustomTraitBound { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut content; + let _name: N = input.parse()?; + let _paren_token = syn::parenthesized!(content in input); + if content.peek(skip_type_params) { + Ok(Self::SkipTypeParams { + _name, + _paren_token_1: _paren_token, + _skip_type_params: content.parse::()?, + _paren_token_2: syn::parenthesized!(content in content), + type_names: content.parse_terminated(syn::Ident::parse)?, + }) + } else { + Ok(Self::SpecifiedBounds { + _name, + _paren_token, + bounds: content.parse_terminated(syn::WherePredicate::parse)?, + }) + } + } +} + +syn::custom_keyword!(encode_bound); +syn::custom_keyword!(decode_bound); +syn::custom_keyword!(mel_bound); +syn::custom_keyword!(skip_type_params); + +/// Look for a `#[codec(decode_bound(T: Decode))]` in the given attributes. +/// +/// If found, it should be used as trait bounds when deriving the `Decode` trait. +pub fn custom_decode_trait_bound(attrs: &[Attribute]) -> Option> { + find_meta_item(attrs.iter(), Some) +} + +/// Look for a `#[codec(encode_bound(T: Encode))]` in the given attributes. +/// +/// If found, it should be used as trait bounds when deriving the `Encode` trait. +pub fn custom_encode_trait_bound(attrs: &[Attribute]) -> Option> { + find_meta_item(attrs.iter(), Some) +} + +/// Look for a `#[codec(mel_bound(T: MaxEncodedLen))]` in the given attributes. +/// +/// If found, it should be used as the trait bounds when deriving the `MaxEncodedLen` trait. +#[cfg(feature = "max-encoded-len")] +pub fn custom_mel_trait_bound(attrs: &[Attribute]) -> Option> { + find_meta_item(attrs.iter(), Some) +} + +/// Given a set of named fields, return an iterator of `Field` where all fields +/// marked `#[codec(skip)]` are filtered out. +pub fn filter_skip_named<'a>(fields: &'a syn::FieldsNamed) -> impl Iterator + 'a { + fields.named.iter().filter(|f| !should_skip(&f.attrs)) +} + +/// Given a set of unnamed fields, return an iterator of `(index, Field)` where all fields +/// marked `#[codec(skip)]` are filtered out. +pub fn filter_skip_unnamed<'a>( + fields: &'a syn::FieldsUnnamed, +) -> impl Iterator + 'a { + fields.unnamed.iter().enumerate().filter(|(_, f)| !should_skip(&f.attrs)) +} + +/// Ensure attributes are correctly applied. This *must* be called before using +/// any of the attribute finder methods or the macro may panic if it encounters +/// misapplied attributes. +/// +/// The top level can have the following attributes: +/// +/// * `#[codec(dumb_trait_bound)]` +/// * `#[codec(encode_bound(T: Encode))]` +/// * `#[codec(decode_bound(T: Decode))]` +/// * `#[codec(mel_bound(T: MaxEncodedLen))]` +/// * `#[codec(crate = path::to::crate)] +/// +/// Fields can have the following attributes: +/// +/// * `#[codec(skip)]` +/// * `#[codec(compact)]` +/// * `#[codec(encoded_as = "$EncodeAs")]` with $EncodedAs a valid TokenStream +/// +/// Variants can have the following attributes: +/// +/// * `#[codec(skip)]` +/// * `#[codec(index = $int)]` +pub fn check_attributes(input: &DeriveInput) -> syn::Result<()> { + for attr in &input.attrs { + check_top_attribute(attr)?; + } + + match input.data { + Data::Struct(ref data) => match &data.fields { + | Fields::Named(FieldsNamed { named: fields, .. }) | + Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) => + for field in fields { + for attr in &field.attrs { + check_field_attribute(attr)?; + } + }, + Fields::Unit => (), + }, + Data::Enum(ref data) => + for variant in data.variants.iter() { + for attr in &variant.attrs { + check_variant_attribute(attr)?; + } + for field in &variant.fields { + for attr in &field.attrs { + check_field_attribute(attr)?; + } + } + }, + Data::Union(_) => (), + } + Ok(()) +} + +// Check if the attribute is `#[allow(..)]`, `#[deny(..)]`, `#[forbid(..)]` or `#[warn(..)]`. +pub fn is_lint_attribute(attr: &Attribute) -> bool { + attr.path.is_ident("allow") || + attr.path.is_ident("deny") || + attr.path.is_ident("forbid") || + attr.path.is_ident("warn") +} + +// Ensure a field is decorated only with the following attributes: +// * `#[codec(skip)]` +// * `#[codec(compact)]` +// * `#[codec(encoded_as = "$EncodeAs")]` with $EncodedAs a valid TokenStream +fn check_field_attribute(attr: &Attribute) -> syn::Result<()> { + let field_error = "Invalid attribute on field, only `#[codec(skip)]`, `#[codec(compact)]` and \ + `#[codec(encoded_as = \"$EncodeAs\")]` are accepted."; + + if attr.path.is_ident("codec") { + match attr.parse_meta()? { + Meta::List(ref meta_list) if meta_list.nested.len() == 1 => { + match meta_list.nested.first().expect("Just checked that there is one item; qed") { + NestedMeta::Meta(Meta::Path(path)) + if path.get_ident().map_or(false, |i| i == "skip") => + Ok(()), + + NestedMeta::Meta(Meta::Path(path)) + if path.get_ident().map_or(false, |i| i == "compact") => + Ok(()), + + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(lit_str), + .. + })) if path.get_ident().map_or(false, |i| i == "encoded_as") => + TokenStream::from_str(&lit_str.value()) + .map(|_| ()) + .map_err(|_e| syn::Error::new(lit_str.span(), "Invalid token stream")), + + elt @ _ => Err(syn::Error::new(elt.span(), field_error)), + } + }, + meta @ _ => Err(syn::Error::new(meta.span(), field_error)), + } + } else { + Ok(()) + } +} + +// Ensure a field is decorated only with the following attributes: +// * `#[codec(skip)]` +// * `#[codec(index = $int)]` +fn check_variant_attribute(attr: &Attribute) -> syn::Result<()> { + let variant_error = "Invalid attribute on variant, only `#[codec(skip)]` and \ + `#[codec(index = $u8)]` are accepted."; + + if attr.path.is_ident("codec") { + match attr.parse_meta()? { + Meta::List(ref meta_list) if meta_list.nested.len() == 1 => { + match meta_list.nested.first().expect("Just checked that there is one item; qed") { + NestedMeta::Meta(Meta::Path(path)) + if path.get_ident().map_or(false, |i| i == "skip") => + Ok(()), + + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Int(lit_int), + .. + })) if path.get_ident().map_or(false, |i| i == "index") => lit_int + .base10_parse::() + .map(|_| ()) + .map_err(|_| syn::Error::new(lit_int.span(), "Index must be in 0..255")), + + elt @ _ => Err(syn::Error::new(elt.span(), variant_error)), + } + }, + meta @ _ => Err(syn::Error::new(meta.span(), variant_error)), + } + } else { + Ok(()) + } +} + +// Only `#[codec(dumb_trait_bound)]` is accepted as top attribute +fn check_top_attribute(attr: &Attribute) -> syn::Result<()> { + let top_error = "Invalid attribute: only `#[codec(dumb_trait_bound)]`, \ + `#[codec(crate = path::to::crate)]`, `#[codec(encode_bound(T: Encode))]`, \ + `#[codec(decode_bound(T: Decode))]`, or `#[codec(mel_bound(T: MaxEncodedLen))]` \ + are accepted as top attribute"; + if attr.path.is_ident("codec") && + attr.parse_args::>().is_err() && + attr.parse_args::>().is_err() && + attr.parse_args::>().is_err() && + codec_crate_path_inner(attr).is_none() + { + match attr.parse_meta()? { + Meta::List(ref meta_list) if meta_list.nested.len() == 1 => { + match meta_list.nested.first().expect("Just checked that there is one item; qed") { + NestedMeta::Meta(Meta::Path(path)) + if path.get_ident().map_or(false, |i| i == "dumb_trait_bound") => + Ok(()), + + elt @ _ => Err(syn::Error::new(elt.span(), top_error)), + } + }, + _ => Err(syn::Error::new(attr.span(), top_error)), + } + } else { + Ok(()) + } +} -- 2.34.1