From 8a7fe7380d0753048b4975dd4aed06b642bd3b32 Mon Sep 17 00:00:00 2001 From: DongHun Kwak Date: Thu, 20 Apr 2023 09:45:05 +0900 Subject: [PATCH 1/1] Import rstest_macros 0.17.0 --- .cargo_vcs_info.json | 6 + Cargo.toml | 81 ++ Cargo.toml.orig | 39 + README.md | 21 + build.rs | 30 + src/error.rs | 271 ++++++ src/lib.rs | 1078 ++++++++++++++++++++++++ src/parse/expressions.rs | 28 + src/parse/fixture.rs | 748 +++++++++++++++++ src/parse/future.rs | 260 ++++++ src/parse/macros.rs | 27 + src/parse/mod.rs | 826 +++++++++++++++++++ src/parse/rstest.rs | 935 +++++++++++++++++++++ src/parse/testcase.rs | 162 ++++ src/parse/vlist.rs | 105 +++ src/refident.rs | 86 ++ src/render/apply_argumets.rs | 249 ++++++ src/render/fixture.rs | 575 +++++++++++++ src/render/inject.rs | 205 +++++ src/render/mod.rs | 434 ++++++++++ src/render/test.rs | 1855 ++++++++++++++++++++++++++++++++++++++++++ src/render/wrapper.rs | 19 + src/resolver.rs | 174 ++++ src/test.rs | 328 ++++++++ src/utils.rs | 390 +++++++++ 25 files changed, 8932 insertions(+) create mode 100644 .cargo_vcs_info.json create mode 100644 Cargo.toml create mode 100644 Cargo.toml.orig create mode 100644 README.md create mode 100644 build.rs create mode 100644 src/error.rs create mode 100644 src/lib.rs create mode 100644 src/parse/expressions.rs create mode 100644 src/parse/fixture.rs create mode 100644 src/parse/future.rs create mode 100644 src/parse/macros.rs create mode 100644 src/parse/mod.rs create mode 100644 src/parse/rstest.rs create mode 100644 src/parse/testcase.rs create mode 100644 src/parse/vlist.rs create mode 100644 src/refident.rs create mode 100644 src/render/apply_argumets.rs create mode 100644 src/render/fixture.rs create mode 100644 src/render/inject.rs create mode 100644 src/render/mod.rs create mode 100644 src/render/test.rs create mode 100644 src/render/wrapper.rs create mode 100644 src/resolver.rs create mode 100644 src/test.rs create mode 100644 src/utils.rs diff --git a/.cargo_vcs_info.json 
b/.cargo_vcs_info.json new file mode 100644 index 0000000..3142f52 --- /dev/null +++ b/.cargo_vcs_info.json @@ -0,0 +1,6 @@ +{ + "git": { + "sha1": "7ac624c3a54de3d7a1506441863562371c5a2359" + }, + "path_in_vcs": "rstest_macros" +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..7cd6079 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,81 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2018" +name = "rstest_macros" +version = "0.17.0" +authors = ["Michele d'Amico "] +description = """ +Rust fixture based test framework. It use procedural macro +to implement fixtures and table based tests. 
+""" +homepage = "https://github.com/la10736/rstest" +readme = "README.md" +keywords = [ + "test", + "fixture", +] +categories = ["development-tools::testing"] +license = "MIT/Apache-2.0" +repository = "https://github.com/la10736/rstest" + +[lib] +proc-macro = true + +[dependencies.cfg-if] +version = "1.0.0" + +[dependencies.proc-macro2] +version = "1.0.39" + +[dependencies.quote] +version = "1.0.19" + +[dependencies.syn] +version = "1.0.98" +features = [ + "full", + "parsing", + "extra-traits", + "visit", + "visit-mut", +] + +[dependencies.unicode-ident] +version = "1.0.5" + +[dev-dependencies.actix-rt] +version = "2.7.0" + +[dev-dependencies.async-std] +version = "1.12.0" +features = ["attributes"] + +[dev-dependencies.pretty_assertions] +version = "1.2.1" + +[dev-dependencies.rstest] +version = "0.16.0" +default-features = false + +[dev-dependencies.rstest_reuse] +version = "0.5.0" + +[dev-dependencies.rstest_test] +version = "0.11.0" + +[build-dependencies.rustc_version] +version = "0.4.0" + +[features] +async-timeout = [] +default = ["async-timeout"] diff --git a/Cargo.toml.orig b/Cargo.toml.orig new file mode 100644 index 0000000..04b12ae --- /dev/null +++ b/Cargo.toml.orig @@ -0,0 +1,39 @@ +[package] +authors = ["Michele d'Amico "] +categories = ["development-tools::testing"] +description = """ +Rust fixture based test framework. It use procedural macro +to implement fixtures and table based tests. 
+""" +edition = "2018" +homepage = "https://github.com/la10736/rstest" +keywords = ["test", "fixture"] +license = "MIT/Apache-2.0" +name = "rstest_macros" +repository = "https://github.com/la10736/rstest" +version = "0.17.0" + +[lib] +proc-macro = true + +[features] +async-timeout = [] +default = ["async-timeout"] + +[dependencies] +cfg-if = "1.0.0" +proc-macro2 = "1.0.39" +quote = "1.0.19" +syn = {version = "1.0.98", features = ["full", "parsing", "extra-traits", "visit", "visit-mut"]} +unicode-ident = "1.0.5" + +[dev-dependencies] +actix-rt = "2.7.0" +async-std = {version = "1.12.0", features = ["attributes"]} +pretty_assertions = "1.2.1" +rstest = {version = "0.16.0", default-features = false} +rstest_reuse = {version = "0.5.0", path = "../rstest_reuse"} +rstest_test = {version = "0.11.0", path = "../rstest_test"} + +[build-dependencies] +rustc_version = "0.4.0" diff --git a/README.md b/README.md new file mode 100644 index 0000000..04edd46 --- /dev/null +++ b/README.md @@ -0,0 +1,21 @@ +[![Crate][crate-image]][crate-link] +[![Docs][docs-image]][docs-link] +[![Status][test-action-image]][test-action-link] +[![Apache 2.0 Licensed][license-apache-image]][license-apache-link] +[![MIT Licensed][license-mit-image]][license-mit-link] + +# `rstest`'s Macros Crate + +See [`rstest`][crate-link]. 
+ +[crate-image]: https://img.shields.io/crates/v/rstest.svg +[crate-link]: https://crates.io/crates/rstest +[docs-image]: https://docs.rs/rstest/badge.svg +[docs-link]: https://docs.rs/rstest/ +[test-action-image]: https://github.com/la10736/rstest/workflows/Test/badge.svg +[test-action-link]: https://github.com/la10736/rstest/actions?query=workflow:Test +[license-apache-image]: https://img.shields.io/badge/license-Apache2.0-blue.svg +[license-mit-image]: https://img.shields.io/badge/license-MIT-blue.svg +[license-apache-link]: http://www.apache.org/licenses/LICENSE-2.0 +[license-MIT-link]: http://opensource.org/licenses/MIT +[reuse-crate-link]: https://crates.io/crates/rstest_reuse diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..580edcb --- /dev/null +++ b/build.rs @@ -0,0 +1,30 @@ +use rustc_version::{version, version_meta, Channel}; + +fn allow_features() -> Option> { + std::env::var("CARGO_ENCODED_RUSTFLAGS").ok().map(|args| { + args.split('\u{001f}') + .filter(|arg| arg.starts_with("-Zallow-features=")) + .map(|arg| arg.split('=').nth(1).unwrap()) + .flat_map(|features| features.split(',')) + .map(|f| f.to_owned()) + .collect() + }) +} + +fn can_enable_proc_macro_diagnostic() -> bool { + allow_features() + .map(|f| f.iter().any(|f| dbg!(f) == "proc_macro_diagnostic")) + .unwrap_or(true) +} + +fn main() { + let ver = version().unwrap(); + assert!(ver.major >= 1); + + match version_meta().unwrap().channel { + Channel::Nightly | Channel::Dev if can_enable_proc_macro_diagnostic() => { + println!("cargo:rustc-cfg=use_proc_macro_diagnostic"); + } + _ => {} + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..a38eac3 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,271 @@ +/// Module for error rendering stuff +use std::collections::HashMap; + +use proc_macro2::TokenStream; +use syn::{spanned::Spanned, visit::Visit}; +use syn::{visit, ItemFn}; + +use crate::parse::{ + fixture::FixtureInfo, + rstest::{RsTestData, 
RsTestInfo}, +}; +use crate::refident::MaybeIdent; + +use super::utils::fn_args_has_ident; + +pub(crate) fn rstest(test: &ItemFn, info: &RsTestInfo) -> TokenStream { + missed_arguments(test, info.data.items.iter()) + .chain(duplicate_arguments(info.data.items.iter())) + .chain(invalid_cases(&info.data)) + .chain(case_args_without_cases(&info.data)) + .map(|e| e.to_compile_error()) + .collect() +} + +pub(crate) fn fixture(test: &ItemFn, info: &FixtureInfo) -> TokenStream { + missed_arguments(test, info.data.items.iter()) + .chain(duplicate_arguments(info.data.items.iter())) + .chain(async_once(test, info)) + .chain(generics_once(test, info)) + .map(|e| e.to_compile_error()) + .collect() +} + +fn async_once<'a>(test: &'a ItemFn, info: &FixtureInfo) -> Errors<'a> { + match (test.sig.asyncness, info.attributes.get_once()) { + (Some(_asyncness), Some(once)) => Box::new(std::iter::once(syn::Error::new( + once.span(), + "Cannot apply #[once] to async fixture.", + ))), + _ => Box::new(std::iter::empty()), + } +} + +#[derive(Default)] +struct SearchImpl(bool); + +impl<'ast> Visit<'ast> for SearchImpl { + fn visit_type(&mut self, i: &'ast syn::Type) { + if self.0 { + return; + } + if let syn::Type::ImplTrait(_) = i { + self.0 = true + } + visit::visit_type(self, i); + } +} + +impl SearchImpl { + fn function_has_some_impl(f: &ItemFn) -> bool { + let mut s = SearchImpl::default(); + visit::visit_item_fn(&mut s, f); + s.0 + } +} + +fn has_some_generics(test: &ItemFn) -> bool { + !test.sig.generics.params.is_empty() || SearchImpl::function_has_some_impl(test) +} + +fn generics_once<'a>(test: &'a ItemFn, info: &FixtureInfo) -> Errors<'a> { + match (has_some_generics(test), info.attributes.get_once()) { + (true, Some(once)) => Box::new(std::iter::once(syn::Error::new( + once.span(), + "Cannot apply #[once] on generic fixture.", + ))), + _ => Box::new(std::iter::empty()), + } +} + +#[derive(Debug, Default)] +pub(crate) struct ErrorsVec(Vec); + +pub(crate) fn _merge_errors( + r1: 
Result, + r2: Result, +) -> Result<(R1, R2), ErrorsVec> { + match (r1, r2) { + (Ok(r1), Ok(r2)) => Ok((r1, r2)), + (Ok(_), Err(e)) | (Err(e), Ok(_)) => Err(e), + (Err(mut e1), Err(mut e2)) => { + e1.append(&mut e2); + Err(e1) + } + } +} + +macro_rules! merge_errors { + ($e:expr) => { + $e + }; + ($e:expr, $($es:expr), +) => { + crate::error::_merge_errors($e, merge_errors!($($es),*)) + }; +} + +macro_rules! composed_tuple { + ($i:ident) => { + $i + }; + ($i:ident, $($is:ident), +) => { + ($i, composed_tuple!($($is),*)) + }; +} + +impl std::ops::Deref for ErrorsVec { + type Target = Vec; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for ErrorsVec { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From for ErrorsVec { + fn from(errors: syn::Error) -> Self { + vec![errors].into() + } +} + +impl From> for ErrorsVec { + fn from(errors: Vec) -> Self { + Self(errors) + } +} + +impl From for Vec { + fn from(v: ErrorsVec) -> Self { + v.0 + } +} + +impl quote::ToTokens for ErrorsVec { + fn to_tokens(&self, tokens: &mut TokenStream) { + tokens.extend(self.0.iter().map(|e| e.to_compile_error())) + } +} + +impl From for proc_macro::TokenStream { + fn from(v: ErrorsVec) -> Self { + use quote::ToTokens; + v.into_token_stream().into() + } +} + +type Errors<'a> = Box + 'a>; + +fn missed_arguments<'a, I: MaybeIdent + Spanned + 'a>( + test: &'a ItemFn, + args: impl Iterator + 'a, +) -> Errors<'a> { + Box::new( + args.filter_map(|it| it.maybe_ident().map(|ident| (it, ident))) + .filter(move |(_, ident)| !fn_args_has_ident(test, ident)) + .map(|(missed, ident)| { + syn::Error::new( + missed.span(), + format!("Missed argument: '{ident}' should be a test function argument."), + ) + }), + ) +} + +fn duplicate_arguments<'a, I: MaybeIdent + Spanned + 'a>( + args: impl Iterator + 'a, +) -> Errors<'a> { + let mut used = HashMap::new(); + Box::new( + args.filter_map(|it| it.maybe_ident().map(|ident| (it, ident))) + 
.filter_map(move |(it, ident)| { + let name = ident.to_string(); + let is_duplicate = used.contains_key(&name); + used.insert(name, it); + match is_duplicate { + true => Some((it, ident)), + false => None, + } + }) + .map(|(duplicate, ident)| { + syn::Error::new( + duplicate.span(), + format!("Duplicate argument: '{ident}' is already defined."), + ) + }), + ) +} + +fn invalid_cases(params: &RsTestData) -> Errors { + let n_args = params.case_args().count(); + Box::new( + params + .cases() + .filter(move |case| case.args.len() != n_args) + .map(|case| { + syn::Error::new_spanned( + case, + "Wrong case signature: should match the given parameters list.", + ) + }), + ) +} + +fn case_args_without_cases(params: &RsTestData) -> Errors { + if !params.has_cases() { + return Box::new( + params + .case_args() + .map(|a| syn::Error::new(a.span(), "No cases for this argument.")), + ); + } + Box::new(std::iter::empty()) +} + +#[cfg(test)] +mod test { + use crate::test::{assert_eq, *}; + use rstest_test::assert_in; + + use super::*; + + #[rstest] + #[case::generics("fn f(){}")] + #[case::const_generics("fn f(){}")] + #[case::lifetimes("fn f<'a>(){}")] + #[case::use_impl_in_answer("fn f() -> impl Iterator{}")] + #[case::use_impl_in_argumets("fn f(it: impl Iterator){}")] + #[should_panic] + #[case::sanity_check_with_no_generics("fn f() {}")] + fn generics_once_should_return_error(#[case] f: &str) { + let f: ItemFn = f.ast(); + let info = FixtureInfo::default().with_once(); + + let errors = generics_once(&f, &info); + + let out = errors + .map(|e| format!("{:?}", e)) + .collect::>() + .join("-----------------------\n"); + + assert_in!(out, "Cannot apply #[once] on generic fixture."); + } + + #[rstest] + #[case::generics("fn f(){}")] + #[case::const_generics("fn f(){}")] + #[case::lifetimes("fn f<'a>(){}")] + #[case::use_impl_in_answer("fn f() -> impl Iterator{}")] + #[case::use_impl_in_argumets("fn f(it: impl Iterator){}")] + fn generics_once_should_not_return_if_no_once(#[case] f: 
&str) { + let f: ItemFn = f.ast(); + let info = FixtureInfo::default(); + + let errors = generics_once(&f, &info); + + assert_eq!(0, errors.count()); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..5f6c5a4 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,1078 @@ +#![cfg_attr(use_proc_macro_diagnostic, feature(proc_macro_diagnostic))] +extern crate proc_macro; + +// Test utility module +#[cfg(test)] +pub(crate) mod test; +#[cfg(test)] +use rstest_reuse; + +#[macro_use] +mod error; +mod parse; +mod refident; +mod render; +mod resolver; +mod utils; + +use syn::{parse_macro_input, ItemFn}; + +use crate::parse::{fixture::FixtureInfo, rstest::RsTestInfo}; +use parse::ExtendWithFunctionAttrs; +use quote::ToTokens; + +/// Define a fixture that you can use in all `rstest`'s test arguments. You should just mark your +/// function as `#[fixture]` and then use it as a test's argument. Fixture functions can also +/// use other fixtures. +/// +/// Let's see a trivial example: +/// +/// ``` +/// use rstest::*; +/// +/// #[fixture] +/// fn twenty_one() -> i32 { 21 } +/// +/// #[fixture] +/// fn two() -> i32 { 2 } +/// +/// #[fixture] +/// fn injected(twenty_one: i32, two: i32) -> i32 { twenty_one * two } +/// +/// #[rstest] +/// fn the_test(injected: i32) { +/// assert_eq!(42, injected) +/// } +/// ``` +/// +/// If the fixture function is an [`async` function](#async) your fixture become an `async` +/// fixture. +/// +/// # Default values +/// +/// If you need to define argument default value you can use `#[default(expression)]` +/// argument's attribute: +/// +/// ``` +/// use rstest::*; +/// +/// #[fixture] +/// fn injected( +/// #[default(21)] +/// twenty_one: i32, +/// #[default(1 + 1)] +/// two: i32 +/// ) -> i32 { twenty_one * two } +/// +/// #[rstest] +/// fn the_test(injected: i32) { +/// assert_eq!(42, injected) +/// } +/// ``` +/// The `expression` could be any valid rust expression, even an `async` block if you need. 
+/// Moreover, if the type implements `FromStr` trait you can use a literal string to build it. +/// +/// ``` +/// # use rstest::*; +/// # use std::net::SocketAddr; +/// # struct DbConnection {} +/// #[fixture] +/// fn db_connection( +/// #[default("127.0.0.1:9000")] +/// addr: SocketAddr +/// ) -> DbConnection { +/// // create connection +/// # DbConnection{} +/// } +/// ``` +/// +/// # Async +/// +/// If you need you can write `async` fixtures to use in your `async` tests. Simply use `async` +/// keyword for your function and the fixture become an `async` fixture. +/// +/// ``` +/// use rstest::*; +/// +/// #[fixture] +/// async fn async_fixture() -> i32 { 42 } +/// +/// +/// #[rstest] +/// async fn the_test(#[future] async_fixture: i32) { +/// assert_eq!(42, async_fixture.await) +/// } +/// ``` +/// The `#[future]` argument attribute helps to remove the `impl Future` boilerplate. +/// In this case the macro expands it in: +/// +/// ``` +/// # use rstest::*; +/// # use std::future::Future; +/// # #[fixture] +/// # async fn async_fixture() -> i32 { 42 } +/// #[rstest] +/// async fn the_test(async_fixture: impl std::future::Future) { +/// assert_eq!(42, async_fixture.await) +/// } +/// ``` +/// If you need, you can use `#[future]` attribute also with an implicit lifetime reference +/// because the macro will replace the implicit lifetime with an explicit one. +/// +/// # Rename +/// +/// Sometimes you want to have long and descriptive name for your fixture but you prefer to use a much +/// shorter name for argument that represent it in your fixture or test. 
You can rename the fixture +/// using the `#[from(short_name)]` attribute as in the following example: +/// +/// ``` +/// use rstest::*; +/// +/// #[fixture] +/// fn long_and_boring_descriptive_name() -> i32 { 42 } +/// +/// #[rstest] +/// fn the_test(#[from(long_and_boring_descriptive_name)] short: i32) { +/// assert_eq!(42, short) +/// } +/// ``` +/// +/// # `#[once]` Fixture +/// +/// Especially in integration tests there are cases where you need a fixture that is called just once +/// for all tests. `rstest` provides the `#[once]` attribute for these cases. +/// +/// If you mark your fixture with this attribute, then `rstest` will compute a static reference to your +/// fixture result and return this reference to all your tests that need this fixture. +/// +/// In the following example all tests share the same reference to the `42` static value. +/// +/// ``` +/// use rstest::*; +/// +/// #[fixture] +/// #[once] +/// fn once_fixture() -> i32 { 42 } +/// +/// // Take care!!! You need to use a reference to the fixture value +/// +/// #[rstest] +/// #[case(1)] +/// #[case(2)] +/// fn cases_tests(once_fixture: &i32, #[case] v: i32) { +/// // Take care!!! You need to use a reference to the fixture value +/// assert_eq!(&42, once_fixture) +/// } +/// +/// #[rstest] +/// fn single(once_fixture: &i32) { +/// assert_eq!(&42, once_fixture) +/// } +/// ``` +/// +/// There are some limitations when you use a `#[once]` fixture. `rstest` forbids `#[once]` fixtures +/// for: +/// +/// - `async` functions +/// - Generic functions (either with generic types or using `impl` trait) +/// +/// Note that the `#[once]` fixture value will **never be dropped**.
+/// +/// # Partial Injection +/// +/// You can also partially inject fixture dependencies using the `#[with(v1, v2, ..)]` attribute: +/// +/// ``` +/// use rstest::*; +/// +/// #[fixture] +/// fn base() -> i32 { 1 } +/// +/// #[fixture] +/// fn first(base: i32) -> i32 { 1 * base } +/// +/// #[fixture] +/// fn second(base: i32) -> i32 { 2 * base } +/// +/// #[fixture] +/// fn injected(first: i32, #[with(3)] second: i32) -> i32 { first * second } +/// +/// #[rstest] +/// fn the_test(injected: i32) { +/// assert_eq!(-6, injected) +/// } +/// ``` +/// Note that the injected value can be an arbitrary rust expression. The `#[with(v1, ..., vn)]` +/// attribute will inject the `v1, ..., vn` expressions as fixture arguments: all remaining arguments +/// will be resolved as fixtures. +/// +/// Sometimes the return type cannot be inferred, so you must define it. For the few times you may +/// need to do it, you can use the `#[default(type)]` and `#[partial_n(type)]` function attributes +/// to define it: +/// +/// ``` +/// use rstest::*; +/// # use std::fmt::Debug; +/// +/// #[fixture] +/// pub fn i() -> u32 { +/// 42 +/// } +/// +/// #[fixture] +/// pub fn j() -> i32 { +/// -42 +/// } +/// +/// #[fixture] +/// #[default(impl Iterator)] +/// #[partial_1(impl Iterator)] +/// pub fn fx(i: I, j: J) -> impl Iterator { +/// std::iter::once((i, j)) +/// } +/// +/// #[rstest] +/// fn resolve_by_default(mut fx: impl Iterator) { +/// assert_eq!((42, -42), fx.next().unwrap()) +/// } +/// +/// #[rstest] +/// fn resolve_partial(#[with(42.0)] mut fx: impl Iterator) { +/// assert_eq!((42.0, -42), fx.next().unwrap()) +/// } +/// ``` +/// `partial_i` is the fixture used when you inject the first `i` arguments in the test call. +/// +/// # Old _compact_ syntax +/// +/// There is also a compact form for all previous features. This will be maintained for a long time, +/// but for `fixture` I strongly recommend migrating your code: you'll pay a little +/// verbosity but get more readable code back.
+/// +/// Follow the previous examples in old _compact_ syntax. +/// +/// ## Default +/// ``` +/// # use rstest::*; +/// #[fixture(twenty_one=21, two=2)] +/// fn injected(twenty_one: i32, two: i32) -> i32 { twenty_one * two } +/// ``` +/// +/// ## Rename +/// ``` +/// # use rstest::*; +/// #[fixture] +/// fn long_and_boring_descriptive_name() -> i32 { 42 } +/// +/// #[rstest(long_and_boring_descriptive_name as short)] +/// fn the_test(short: i32) { +/// assert_eq!(42, short) +/// } +/// ``` +/// +/// ## Partial Injection +/// ``` +/// # use rstest::*; +/// # #[fixture] +/// # fn base() -> i32 { 1 } +/// # +/// # #[fixture] +/// # fn first(base: i32) -> i32 { 1 * base } +/// # +/// # #[fixture] +/// # fn second(base: i32) -> i32 { 2 * base } +/// # +/// #[fixture(second(-3))] +/// fn injected(first: i32, second: i32) -> i32 { first * second } +/// ``` +/// ## Partial Type Injection +/// ``` +/// # use rstest::*; +/// # use std::fmt::Debug; +/// # +/// # #[fixture] +/// # pub fn i() -> u32 { +/// # 42 +/// # } +/// # +/// # #[fixture] +/// # pub fn j() -> i32 { +/// # -42 +/// # } +/// # +/// #[fixture(::default>::partial_1>)] +/// pub fn fx(i: I, j: J) -> impl Iterator { +/// std::iter::once((i, j)) +/// } +/// ``` + +#[proc_macro_attribute] +pub fn fixture( + args: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let mut info: FixtureInfo = parse_macro_input!(args as FixtureInfo); + let mut fixture = parse_macro_input!(input as ItemFn); + + let extend_result = info.extend_with_function_attrs(&mut fixture); + + let mut errors = error::fixture(&fixture, &info); + + if let Err(attrs_errors) = extend_result { + attrs_errors.to_tokens(&mut errors); + } + + if errors.is_empty() { + render::fixture(fixture, info) + } else { + errors + } + .into() +} + +/// The attribute that you should use for your tests. 
Your +/// annotated function's arguments can be +/// [injected](attr.rstest.html#injecting-fixtures) with +/// [`[fixture]`](macro@fixture)s, provided by +/// [parametrized cases](attr.rstest.html#test-parametrized-cases) +/// or by [value lists](attr.rstest.html#values-lists). +/// +/// `rstest` attribute can be applied to _any_ function and you can customize its +/// parameters by using function and arguments attributes. +/// +/// Your test function can use generics, `impl` or `dyn` and like any kind of rust tests: +/// +/// - return results +/// - marked by `#[should_panic]` attribute +/// +/// If the test function is an [`async` function](#async) `rstest` will run all tests as `async` +/// tests. You can use it just with `async-std` and you should include `attributes` in +/// `async-std`'s features. +/// +/// In your test function you can: +/// +/// - [injecting fixtures](#injecting-fixtures) +/// - Generate [parametrized test cases](#test-parametrized-cases) +/// - Generate tests for each combination of [value lists](#values-lists) +/// +/// ## Injecting Fixtures +/// +/// The simplest case is write a test that can be injected with +/// [`[fixture]`](macro@fixture)s. You can just declare all used fixtures by passing +/// them as a function's arguments. This can help your test to be neat +/// and make your dependecy clear. 
+/// +/// ``` +/// use rstest::*; +/// +/// #[fixture] +/// fn injected() -> i32 { 42 } +/// +/// #[rstest] +/// fn the_test(injected: i32) { +/// assert_eq!(42, injected) +/// } +/// ``` +/// +/// [`[rstest]`](macro@rstest) procedural macro will desugar it to something that isn't +/// so far from +/// +/// ``` +/// #[test] +/// fn the_test() { +/// let injected=injected(); +/// assert_eq!(42, injected) +/// } +/// ``` +/// +/// If you want to use long and descriptive names for your fixture but prefer to use +/// shorter names inside your tests you use rename feature described in +/// [fixture rename](attr.fixture.html#rename): +/// +/// ``` +/// use rstest::*; +/// +/// #[fixture] +/// fn long_and_boring_descriptive_name() -> i32 { 42 } +/// +/// #[rstest] +/// fn the_test(#[from(long_and_boring_descriptive_name)] short: i32) { +/// assert_eq!(42, short) +/// } +/// ``` +/// +/// Sometimes is useful to have some parametes in your fixtures but your test would +/// override the fixture's default values in some cases. Like in +/// [fixture partial injection](attr.fixture.html#partial-injection) you use `#[with]` +/// attribute to indicate some fixture's arguments also in `rstest`. +/// +/// ``` +/// # struct User(String, u8); +/// # impl User { fn name(&self) -> &str {&self.0} } +/// use rstest::*; +/// +/// #[fixture] +/// fn user( +/// #[default("Alice")] name: impl AsRef, +/// #[default(22)] age: u8 +/// ) -> User { User(name.as_ref().to_owned(), age) } +/// +/// #[rstest] +/// fn check_user(#[with("Bob")] user: User) { +/// assert_eq("Bob", user.name()) +/// } +/// ``` +/// +/// ## Test Parametrized Cases +/// +/// If you would execute your test for a set of input data cases +/// you can define the arguments to use and the cases list. Let see +/// the classical Fibonacci example. In this case we would give the +/// `input` value and the `expected` result for a set of cases to test. 
+/// +/// ``` +/// use rstest::rstest; +/// +/// #[rstest] +/// #[case(0, 0)] +/// #[case(1, 1)] +/// #[case(2, 1)] +/// #[case(3, 2)] +/// #[case(4, 3)] +/// fn fibonacci_test(#[case] input: u32,#[case] expected: u32) { +/// assert_eq!(expected, fibonacci(input)) +/// } +/// +/// fn fibonacci(input: u32) -> u32 { +/// match input { +/// 0 => 0, +/// 1 => 1, +/// n => fibonacci(n - 2) + fibonacci(n - 1) +/// } +/// } +/// ``` +/// +/// `rstest` will produce 5 independent tests and not just one that +/// checks every case. Every test can fail independently and `cargo test` +/// will give the following output: +/// +/// ```text +/// running 5 tests +/// test fibonacci_test::case_1 ... ok +/// test fibonacci_test::case_2 ... ok +/// test fibonacci_test::case_3 ... ok +/// test fibonacci_test::case_4 ... ok +/// test fibonacci_test::case_5 ... ok +/// +/// test result: ok. 5 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out +/// ``` +/// +/// The case input values can be arbitrary Rust expressions that return the +/// argument type. +/// +/// ``` +/// use rstest::rstest; +/// +/// fn sum(a: usize, b: usize) -> usize { a + b } +/// +/// #[rstest] +/// #[case("foo", 3)] +/// #[case(String::from("foo"), 2 + 1)] +/// #[case(format!("foo"), sum(2, 1))] +/// fn test_len(#[case] s: impl AsRef,#[case] len: usize) { +/// assert_eq!(s.as_ref().len(), len); +/// } +/// ``` +/// +/// ### Magic Conversion +/// +/// You can use the magic conversion feature every time you would define a variable +/// whose type implements the `FromStr` trait: the test will parse the string to build the value.
+/// +/// ``` +/// # use rstest::rstest; +/// # use std::path::PathBuf; +/// # fn count_words(path: PathBuf) -> usize {0} +/// #[rstest] +/// #[case("resources/empty", 0)] +/// #[case("resources/divine_commedy", 101.698)] +/// fn test_count_words(#[case] path: PathBuf, #[case] expected: usize) { +/// assert_eq!(expected, count_words(path)) +/// } +/// ``` +/// +/// ### Optional case description +/// +/// Optionally you can give a _description_ to every case simply by following `case` +/// with `::my_case_description` where `my_case_description` should be a valid +/// Rust ident. +/// +/// ``` +/// # use rstest::*; +/// #[rstest] +/// #[case::zero_base_case(0, 0)] +/// #[case::one_base_case(1, 1)] +/// #[case(2, 1)] +/// #[case(3, 2)] +/// fn fibonacci_test(#[case] input: u32,#[case] expected: u32) { +/// assert_eq!(expected, fibonacci(input)) +/// } +/// +/// # fn fibonacci(input: u32) -> u32 { +/// # match input { +/// # 0 => 0, +/// # 1 => 1, +/// # n => fibonacci(n - 2) + fibonacci(n - 1) +/// # } +/// # } +/// ``` +/// +/// The output will be +/// ```text +/// running 4 tests +/// test fibonacci_test::case_1_zero_base_case ... ok +/// test fibonacci_test::case_2_one_base_case ... ok +/// test fibonacci_test::case_3 ... ok +/// test fibonacci_test::case_4 ... ok +/// +/// test result: ok. 4 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out +/// ``` +/// +/// ### Use specific `case` attributes +/// +/// Function attributes that precede a `#[case]` attribute will +/// be used only for that test case, and all function attributes that follow the +/// last `#[case]` attribute will mark all test cases. +/// +/// This feature can be used to mark just some cases as `should_panic` +/// and to have fine-grained control over expected panic messages. +/// +/// In the following example we run 3 tests where the first passes without any +/// panic, in the second we catch a panic but we don't care about the message +/// and in the third one we also check the panic message.
+/// +/// ``` +/// use rstest::rstest; +/// +/// #[rstest] +/// #[case::no_panic(0)] +/// #[should_panic] +/// #[case::panic(1)] +/// #[should_panic(expected="expected")] +/// #[case::panic_with_message(2)] +/// fn attribute_per_case(#[case] val: i32) { +/// match val { +/// 0 => assert!(true), +/// 1 => panic!("No catch"), +/// 2 => panic!("expected"), +/// _ => unreachable!(), +/// } +/// } +/// ``` +/// +/// Output: +/// +/// ```text +/// running 3 tests +/// test attribute_per_case::case_1_no_panic ... ok +/// test attribute_per_case::case_3_panic_with_message ... ok +/// test attribute_per_case::case_2_panic ... ok +/// +/// test result: ok. 3 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out +/// ``` +/// +/// To mark all your tests as `#[should_panic]` use: +/// +/// ``` +/// # use rstest::rstest; +/// #[rstest] +/// #[case(1)] +/// #[case(2)] +/// #[case(3)] +/// #[should_panic] +/// fn fail(#[case] v: u32) { assert_eq!(0, v) } +/// ``` +/// +/// ## Values Lists +/// +/// Another useful way to write a test and execute it for some values +/// is to use the values list syntax. This syntax can be usefull both +/// for a plain list and for testing all combination of input arguments. +/// +/// ``` +/// # use rstest::*; +/// # fn is_valid(input: &str) -> bool { true } +/// +/// #[rstest] +/// fn should_be_valid( +/// #[values("Jhon", "alice", "My_Name", "Zigy_2001")] +/// input: &str +/// ) { +/// assert!(is_valid(input)) +/// } +/// ``` +/// +/// or +/// +/// ``` +/// # use rstest::*; +/// # fn valid_user(name: &str, age: u8) -> bool { true } +/// +/// #[rstest] +/// fn should_accept_all_corner_cases( +/// #[values("J", "A", "A________________________________________21")] +/// name: &str, +/// #[values(14, 100)] +/// age: u8 +/// ) { +/// assert!(valid_user(name, age)) +/// } +/// ``` +/// where `cargo test` output is +/// +/// ```text +/// test should_accept_all_corner_cases::name_1___J__::age_2_100 ... 
ok +/// test should_accept_all_corner_cases::name_2___A__::age_1_14 ... ok +/// test should_accept_all_corner_cases::name_2___A__::age_2_100 ... ok +/// test should_accept_all_corner_cases::name_3___A________________________________________21__::age_2_100 ... ok +/// test should_accept_all_corner_cases::name_3___A________________________________________21__::age_1_14 ... ok +/// test should_accept_all_corner_cases::name_1___J__::age_1_14 ... ok +/// +/// test result: ok. 6 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s +/// ``` +/// Note that the test names contain the given expression sanitized into +/// a valid Rust identifier name. This should help to identify which case fails. +/// +/// +/// Value lists also implement the magic conversion feature: every time the value type +/// implements the `FromStr` trait you can use a literal string to define it. +/// +/// ## Use Parametrize definition in more tests +/// +/// If you need to use a test list for more than one test you can use the +/// [`rstest_reuse`](https://crates.io/crates/rstest_reuse) crate. +/// With this helper crate you can define a template and use it everywhere. +/// +/// ``` +/// # use rstest::rstest; +/// # use std::net::SocketAddr; +/// #[rstest] +/// fn given_port(#[values("1.2.3.4:8000", "4.3.2.1:8000", "127.0.0.1:8000")] addr: SocketAddr) { +/// assert_eq(8000, addr.port()) +/// } +/// ``` +/// +/// ```rust,ignore +/// use rstest::rstest; +/// use rstest_reuse::{self, *}; +/// +/// #[template] +/// #[rstest] +/// #[case(2, 2)] +/// #[case(4/2, 2)] +/// fn two_simple_cases(#[case] a: u32, #[case] b: u32) {} +/// +/// #[apply(two_simple_cases)] +/// fn it_works(#[case] a: u32,#[case] b: u32) { +/// assert!(a == b); +/// } +/// ``` +/// +/// See [`rstest_reuse`](https://crates.io/crates/rstest_reuse) for more details. +/// +/// ## Async +/// +/// `rstest` provides out of the box `async` support.
Just mark your +/// test function as `async` and it'll use `#[async_std::test]` to +/// annotate it. This feature can be really useful to build async +/// parametric tests using a tidy syntax: +/// +/// ``` +/// use rstest::*; +/// # async fn async_sum(a: u32, b: u32) -> u32 { a + b } +/// +/// #[rstest] +/// #[case(5, 2, 3)] +/// #[should_panic] +/// #[case(42, 40, 1)] +/// async fn my_async_test(#[case] expected: u32, #[case] a: u32, #[case] b: u32) { +/// assert_eq!(expected, async_sum(a, b).await); +/// } +/// ``` +/// +/// Currently only `async-std` is supported out of the box. But if you need to use +/// another runtime that provides its own test attribute (e.g. `tokio::test` or +/// `actix_rt::test`) you can use it in your `async` test as described in +/// [Inject Test Attribute](attr.rstest.html#inject-test-attribute). +/// +/// To use this feature, you need to enable `attributes` in the `async-std` +/// features list in your `Cargo.toml`: +/// +/// ```toml +/// async-std = { version = "1.5", features = ["attributes"] } +/// ``` +/// +/// If your test input is an async value (fixture or test parameter) you can use the `#[future]` +/// attribute to remove the `impl Future` boilerplate and just use `T`: +/// +/// ``` +/// use rstest::*; +/// #[fixture] +/// async fn base() -> u32 { 42 } +/// +/// #[rstest] +/// #[case(21, async { 2 })] +/// #[case(6, async { 7 })] +/// async fn my_async_test(#[future] base: u32, #[case] expected: u32, #[future] #[case] div: u32) { +/// assert_eq!(expected, base.await / div.await); +/// } +/// ``` +/// +/// As you may have noticed, you must `.await` all _future_ values, and this can sometimes be really boring. +/// In this case you can use `#[future(awt)]` to _await_ a single input, or annotate your function +/// with the `#[awt]` attribute to globally `.await` all your _future_ inputs.
The previous code can be +/// simplified as follows: +/// +/// ``` +/// use rstest::*; +/// # #[fixture] +/// # async fn base() -> u32 { 42 } +/// +/// #[rstest] +/// #[case(21, async { 2 })] +/// #[case(6, async { 7 })] +/// #[awt] +/// async fn global(#[future] base: u32, #[case] expected: u32, #[future] #[case] div: u32) { +/// assert_eq!(expected, base / div); +/// } +/// +/// #[rstest] +/// #[case(21, async { 2 })] +/// #[case(6, async { 7 })] +/// async fn single(#[future] base: u32, #[case] expected: u32, #[future(awt)] #[case] div: u32) { +/// assert_eq!(expected, base.await / div); +/// } +/// ``` +/// +/// ### Test `#[timeout()]` +/// +/// You can define an execution timeout for your tests with the `#[timeout()]` attribute. Timeouts +/// work for both sync and async tests and are runtime agnostic. `#[timeout()]` takes an +/// expression that should return a `std::time::Duration`. Here is a simple async example: +/// +/// ```rust +/// use rstest::*; +/// use std::time::Duration; +/// +/// async fn delayed_sum(a: u32, b: u32,delay: Duration) -> u32 { +/// async_std::task::sleep(delay).await; +/// a + b +/// } +/// +/// #[rstest] +/// #[timeout(Duration::from_millis(80))] +/// async fn single_pass() { +/// assert_eq!(4, delayed_sum(2, 2, Duration::from_millis(10)).await); +/// } +/// ``` +/// In this case the test passes because the delay is just 10 milliseconds and the timeout is +/// 80 milliseconds. +/// +/// You can use the `timeout` attribute like any other attribute in your tests and you can +/// override a group timeout with a test-specific one. In the following example we have +/// 3 tests where the first and third use 100 millis but the second one uses 10 millis. +/// Another valuable point in this example is the use of an expression to compute the +/// duration.
+/// +/// ```rust +/// # use rstest::*; +/// # use std::time::Duration; +/// # +/// # async fn delayed_sum(a: u32, b: u32,delay: Duration) -> u32 { +/// # async_std::task::sleep(delay).await; +/// # a + b +/// # } +/// fn ms(ms: u32) -> Duration { +/// Duration::from_millis(ms.into()) +/// } +/// +/// #[rstest] +/// #[case::pass(ms(1), 4)] +/// #[timeout(ms(10))] +/// #[case::fail_timeout(ms(60), 4)] +/// #[case::fail_value(ms(1), 5)] +/// #[timeout(ms(100))] +/// async fn group_one_timeout_override(#[case] delay: Duration, #[case] expected: u32) { +/// assert_eq!(expected, delayed_sum(2, 2, delay).await); +/// } +/// ``` +/// +/// If you want to use `timeout` for `async` tests you need to enable the `async-timeout` +/// feature (enabled by default). +/// +/// ## Inject Test Attribute +/// +/// If you would like to use another `test` attribute for your test you can simply +/// indicate it in your test function's attributes. For instance, if you want +/// to test some async function using the `actix_rt::test` attribute you can just write: +/// +/// ``` +/// use rstest::*; +/// use actix_rt; +/// use std::future::Future; +/// +/// #[rstest] +/// #[case(2, async { 4 })] +/// #[case(21, async { 42 })] +/// #[actix_rt::test] +/// async fn my_async_test(#[case] a: u32, #[case] #[future] result: u32) { +/// assert_eq!(2 * a, result.await); +/// } +/// ``` +/// Only attributes whose last path segment is `test` can be injected: +/// in this case the `#[actix_rt::test]` attribute will replace the standard `#[test]` +/// attribute. +/// +/// ## Putting It All Together +/// +/// All these features can be used together with a mixture of fixture variables, +/// fixed cases and a bunch of values. For instance, you might need two +/// test cases which test for panics, one for a logged-in user and one for a guest user.
+/// +/// ```rust +/// # enum User { Guest, Logged, } +/// # impl User { fn logged(_n: &str, _d: &str, _w: &str, _s: &str) -> Self { Self::Logged } } +/// # struct Item {} +/// # trait Repository { fn find_items(&self, user: &User, query: &str) -> Result<Vec<Item>, String> { Err("Invalid query error".to_owned()) } } +/// # #[derive(Default)] struct InMemoryRepository {} +/// # impl Repository for InMemoryRepository {} +/// +/// use rstest::*; +/// +/// #[fixture] +/// fn repository() -> InMemoryRepository { +/// let mut r = InMemoryRepository::default(); +/// // fill repository with some data +/// r +/// } +/// +/// #[fixture] +/// fn alice() -> User { +/// User::logged("Alice", "2001-10-04", "London", "UK") +/// } +/// +/// #[rstest] +/// #[case::authed_user(alice())] // We can also use a `fixture` as a standard function +/// #[case::guest(User::Guest)] // We can give a name to every case: `guest` in this case +/// #[should_panic(expected = "Invalid query error")] // We expect a panic +/// fn should_be_invalid_query_error( +/// repository: impl Repository, +/// #[case] user: User, +/// #[values(" ", "^%$some#@invalid!chars", ".n.o.d.o.t.s.")] +/// query: &str +/// ) { +/// repository.find_items(&user, query).unwrap(); +/// } +/// ``` +/// +/// ## Trace Input Arguments +/// +/// Sometimes it can be very helpful to print all of a test's input arguments. To +/// do it you can use the `#[trace]` function attribute that you can apply +/// to all cases or just to some of them. +/// +/// ``` +/// use rstest::*; +/// +/// #[fixture] +/// fn injected() -> i32 { 42 } +/// +/// #[rstest] +/// #[trace] +/// fn the_test(injected: i32) { +/// assert_eq!(42, injected) +/// } +/// ``` +/// +/// This will print an output like +/// +/// ```bash +/// Testing started at 14.12 ...
+/// ------------ TEST ARGUMENTS ------------ +/// injected = 42 +/// -------------- TEST START -------------- +/// +/// +/// Expected :42 +/// Actual :43 +/// ``` +/// But +/// ``` +/// # use rstest::*; +/// #[rstest] +/// #[case(1)] +/// #[trace] +/// #[case(2)] +/// fn the_test(#[case] v: i32) { +/// assert_eq!(0, v) +/// } +/// ``` +/// will trace only the `case_2` input arguments. +/// +/// If you want to trace input arguments but skip some of them that don't +/// implement the `Debug` trait, you can also use the +/// `#[notrace]` argument attribute to skip them: +/// +/// ``` +/// # use rstest::*; +/// # struct Xyz; +/// # struct NoSense; +/// #[rstest] +/// #[trace] +/// fn the_test(injected: i32, #[notrace] xyz: Xyz, #[notrace] have_no_sense: NoSense) { +/// assert_eq!(42, injected) +/// } +/// ``` +/// # Old _compact_ syntax +/// +/// `rstest` also supports a syntax where all options and configuration can be written as +/// `rstest` attribute arguments. This syntax is a little less verbose but makes +/// composition harder: for instance, adding some cases to an `rstest_reuse` template +/// is really hard. +/// +/// So we'll continue to maintain the old syntax for a long time, but we strongly encourage +/// you to switch your tests to the new form. +/// +/// Anyway, here we recall this syntax and rewrite the previous examples in the _compact_ form. +/// +/// ```text +/// rstest( +/// arg_1, +/// ..., +/// arg_n[,] +/// [::attribute_1[:: ...
[::attribute_k]]] +/// ) +/// ``` +/// Where: +/// +/// - `arg_i` could be one of the following +/// - `ident` that matches one of the function arguments for parametrized cases +/// - `case[::description](v1, ..., vl)` a test case +/// - `fixture(v1, ..., vl) [as argument_name]` where fixture is the injected +/// fixture and argument_name (defaults to fixture) is one of the function arguments +/// and `v1, ..., vl` is a partial list of the fixture's arguments +/// - `ident => [v1, ..., vl]` where `ident` is one of the function arguments and +/// `v1, ..., vl` is a list of values for ident +/// - `attribute_j` a test attribute like `trace` or `notrace` +/// +/// ## Fixture Arguments +/// +/// ``` +/// # struct User(String, u8); +/// # impl User { fn name(&self) -> &str {&self.0} } +/// # use rstest::*; +/// # +/// # #[fixture] +/// # fn user( +/// # #[default("Alice")] name: impl AsRef<str>, +/// # #[default(22)] age: u8 +/// # ) -> User { User(name.as_ref().to_owned(), age) } +/// # +/// #[rstest(user("Bob"))] +/// fn check_user(user: User) { +/// assert_eq!("Bob", user.name()) +/// } +/// ``` +/// +/// ## Fixture Rename +/// ``` +/// # use rstest::*; +/// #[fixture] +/// fn long_and_boring_descriptive_name() -> i32 { 42 } +/// +/// #[rstest(long_and_boring_descriptive_name as short)] +/// fn the_test(short: i32) { +/// assert_eq!(42, short) +/// } +/// ``` +/// +/// ## Parametrized +/// +/// ``` +/// # use rstest::*; +/// #[rstest(input, expected, +/// case::zero_base_case(0, 0), +/// case::one_base_case(1, 1), +/// case(2, 1), +/// case(3, 2), +/// #[should_panic] +/// case(4, 42) +/// )] +/// fn fibonacci_test(input: u32, expected: u32) { +/// assert_eq!(expected, fibonacci(input)) +/// } +/// +/// # fn fibonacci(input: u32) -> u32 { +/// # match input { +/// # 0 => 0, +/// # 1 => 1, +/// # n => fibonacci(n - 2) + fibonacci(n - 1) +/// # } +/// # } +/// ``` +/// +/// ## Values Lists +/// +/// ``` +/// # use rstest::*; +/// # fn is_valid(input: &str) -> bool { true } +/// +/// 
#[rstest( +/// input => ["Jhon", "alice", "My_Name", "Zigy_2001"] +/// )] +/// fn should_be_valid(input: &str) { +/// assert!(is_valid(input)) +/// } +/// ``` +/// +/// ## `trace` and `notrace` +/// +/// ``` +/// # use rstest::*; +/// # struct Xyz; +/// # struct NoSense; +/// #[rstest(::trace::notrace(xzy, have_no_sense))] +/// fn the_test(injected: i32, xyz: Xyz, have_no_sense: NoSense) { +/// assert_eq!(42, injected) +/// } +/// ``` +/// +#[proc_macro_attribute] +pub fn rstest( + args: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let mut test = parse_macro_input!(input as ItemFn); + let mut info = parse_macro_input!(args as RsTestInfo); + + let extend_result = info.extend_with_function_attrs(&mut test); + + let mut errors = error::rstest(&test, &info); + + if let Err(attrs_errors) = extend_result { + attrs_errors.to_tokens(&mut errors); + } + + if errors.is_empty() { + if info.data.has_list_values() { + render::matrix(test, info) + } else if info.data.has_cases() { + render::parametrize(test, info) + } else { + render::single(test, info) + } + } else { + errors + } + .into() +} diff --git a/src/parse/expressions.rs b/src/parse/expressions.rs new file mode 100644 index 0000000..152662b --- /dev/null +++ b/src/parse/expressions.rs @@ -0,0 +1,28 @@ +use syn::{ + parse::{Parse, ParseStream, Result}, + Expr, Token, +}; + +pub(crate) struct Expressions(Vec); + +impl Expressions { + pub(crate) fn take(self) -> Vec { + self.0 + } +} + +impl Parse for Expressions { + fn parse(input: ParseStream) -> Result { + let values = input + .parse_terminated::<_, Token![,]>(Parse::parse)? 
+ .into_iter() + .collect(); + Ok(Self(values)) + } +} + +impl From for Vec { + fn from(expressions: Expressions) -> Self { + expressions.0 + } +} \ No newline at end of file diff --git a/src/parse/fixture.rs b/src/parse/fixture.rs new file mode 100644 index 0000000..07e3a8a --- /dev/null +++ b/src/parse/fixture.rs @@ -0,0 +1,748 @@ +/// `fixture`'s related data and parsing +use syn::{ + parse::{Parse, ParseStream}, + parse_quote, + visit_mut::VisitMut, + Expr, FnArg, Ident, ItemFn, Token, +}; + +use super::{ + arguments::ArgumentsInfo, extract_argument_attrs, extract_default_return_type, + extract_defaults, extract_fixtures, extract_partials_return_type, future::extract_futures, + parse_vector_trailing_till_double_comma, Attributes, ExtendWithFunctionAttrs, Fixture, +}; +use crate::{ + error::ErrorsVec, + parse::extract_once, + refident::{MaybeIdent, RefIdent}, + utils::attr_is, +}; +use crate::{parse::Attribute, utils::attr_in}; +use proc_macro2::TokenStream; +use quote::{format_ident, ToTokens}; + +#[derive(PartialEq, Debug, Default)] +pub(crate) struct FixtureInfo { + pub(crate) data: FixtureData, + pub(crate) attributes: FixtureModifiers, + pub(crate) arguments: ArgumentsInfo, +} + +impl Parse for FixtureModifiers { + fn parse(input: ParseStream) -> syn::Result { + Ok(input.parse::()?.into()) + } +} + +impl Parse for FixtureInfo { + fn parse(input: ParseStream) -> syn::Result { + Ok(if input.is_empty() { + Default::default() + } else { + Self { + data: input.parse()?, + attributes: input + .parse::() + .or_else(|_| Ok(Default::default())) + .and_then(|_| input.parse())?, + arguments: Default::default(), + } + }) + } +} + +impl ExtendWithFunctionAttrs for FixtureInfo { + fn extend_with_function_attrs( + &mut self, + item_fn: &mut ItemFn, + ) -> std::result::Result<(), ErrorsVec> { + let composed_tuple!( + fixtures, + defaults, + default_return_type, + partials_return_type, + once, + futures + ) = merge_errors!( + extract_fixtures(item_fn), + 
extract_defaults(item_fn), + extract_default_return_type(item_fn), + extract_partials_return_type(item_fn), + extract_once(item_fn), + extract_futures(item_fn) + )?; + self.data.items.extend( + fixtures + .into_iter() + .map(|f| f.into()) + .chain(defaults.into_iter().map(|d| d.into())), + ); + if let Some(return_type) = default_return_type { + self.attributes.set_default_return_type(return_type); + } + for (id, return_type) in partials_return_type { + self.attributes.set_partial_return_type(id, return_type); + } + if let Some(ident) = once { + self.attributes.set_once(ident) + }; + let (futures, global_awt) = futures; + self.arguments.set_global_await(global_awt); + self.arguments.set_futures(futures.into_iter()); + Ok(()) + } +} + +fn parse_attribute_args_just_once<'a, T: Parse>( + attributes: impl Iterator, + name: &str, +) -> (Option, Vec) { + let mut errors = Vec::new(); + let val = attributes + .filter(|&a| attr_is(a, name)) + .map(|a| (a, a.parse_args::())) + .fold(None, |first, (a, res)| match (first, res) { + (None, Ok(parsed)) => Some(parsed), + (first, Err(err)) => { + errors.push(err); + first + } + (first, _) => { + errors.push(syn::Error::new_spanned( + a, + format!( + "You cannot use '{name}' attribute more than once for the same argument" + ), + )); + first + } + }); + (val, errors) +} + +/// Simple struct used to visit function attributes and extract Fixtures and +/// eventualy parsing errors +#[derive(Default)] +pub(crate) struct FixturesFunctionExtractor(pub(crate) Vec, pub(crate) Vec); + +impl VisitMut for FixturesFunctionExtractor { + fn visit_fn_arg_mut(&mut self, node: &mut FnArg) { + if let FnArg::Typed(ref mut arg) = node { + let name = match arg.pat.as_ref() { + syn::Pat::Ident(ident) => ident.ident.clone(), + _ => return, + }; + let (extracted, remain): (Vec<_>, Vec<_>) = std::mem::take(&mut arg.attrs) + .into_iter() + .partition(|attr| attr_in(attr, &["with", "from"])); + arg.attrs = remain; + + let (pos, errors) = 
parse_attribute_args_just_once(extracted.iter(), "with"); + self.1.extend(errors.into_iter()); + let (resolve, errors) = parse_attribute_args_just_once(extracted.iter(), "from"); + self.1.extend(errors.into_iter()); + if pos.is_some() || resolve.is_some() { + self.0 + .push(Fixture::new(name, resolve, pos.unwrap_or_default())) + } + } + } +} + +/// Simple struct used to visit function attributes and extract fixture default values info and +/// eventualy parsing errors +#[derive(Default)] +pub(crate) struct DefaultsFunctionExtractor( + pub(crate) Vec, + pub(crate) Vec, +); + +impl VisitMut for DefaultsFunctionExtractor { + fn visit_fn_arg_mut(&mut self, node: &mut FnArg) { + for r in extract_argument_attrs( + node, + |a| attr_is(a, "default"), + |a, name| { + a.parse_args::() + .map(|e| ArgumentValue::new(name.clone(), e)) + }, + ) { + match r { + Ok(value) => self.0.push(value), + Err(err) => self.1.push(err), + } + } + } +} + +#[derive(PartialEq, Debug, Default)] +pub(crate) struct FixtureData { + pub items: Vec, +} + +impl FixtureData { + pub(crate) fn fixtures(&self) -> impl Iterator { + self.items.iter().filter_map(|f| match f { + FixtureItem::Fixture(ref fixture) => Some(fixture), + _ => None, + }) + } + + pub(crate) fn values(&self) -> impl Iterator { + self.items.iter().filter_map(|f| match f { + FixtureItem::ArgumentValue(ref value) => Some(value.as_ref()), + _ => None, + }) + } +} + +impl Parse for FixtureData { + fn parse(input: ParseStream) -> syn::Result { + if input.peek(Token![::]) { + Ok(Default::default()) + } else { + Ok(Self { + items: parse_vector_trailing_till_double_comma::<_, Token![,]>(input)?, + }) + } + } +} + +#[derive(PartialEq, Debug)] +pub(crate) struct ArgumentValue { + pub name: Ident, + pub expr: Expr, +} + +impl ArgumentValue { + pub(crate) fn new(name: Ident, expr: Expr) -> Self { + Self { name, expr } + } +} + +#[derive(PartialEq, Debug)] +pub(crate) enum FixtureItem { + Fixture(Fixture), + ArgumentValue(Box), +} + +impl From for 
FixtureItem { + fn from(f: Fixture) -> Self { + FixtureItem::Fixture(f) + } +} + +impl Parse for FixtureItem { + fn parse(input: ParseStream) -> syn::Result { + if input.peek2(Token![=]) { + input.parse::().map(|v| v.into()) + } else { + input.parse::().map(|v| v.into()) + } + } +} + +impl RefIdent for FixtureItem { + fn ident(&self) -> &Ident { + match self { + FixtureItem::Fixture(Fixture { ref name, .. }) => name, + FixtureItem::ArgumentValue(ref av) => &av.name, + } + } +} + +impl ToTokens for FixtureItem { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.ident().to_tokens(tokens) + } +} + +impl From for FixtureItem { + fn from(av: ArgumentValue) -> Self { + FixtureItem::ArgumentValue(Box::new(av)) + } +} + +impl Parse for ArgumentValue { + fn parse(input: ParseStream) -> syn::Result { + let name = input.parse()?; + let _eq: Token![=] = input.parse()?; + let expr = input.parse()?; + Ok(ArgumentValue::new(name, expr)) + } +} + +wrap_attributes!(FixtureModifiers); + +impl FixtureModifiers { + pub(crate) const DEFAULT_RET_ATTR: &'static str = "default"; + pub(crate) const PARTIAL_RET_ATTR: &'static str = "partial_"; + + pub(crate) fn extract_default_type(&self) -> Option { + self.extract_type(Self::DEFAULT_RET_ATTR) + } + + pub(crate) fn extract_partial_type(&self, pos: usize) -> Option { + self.extract_type(&format!("{}{}", Self::PARTIAL_RET_ATTR, pos)) + } + + pub(crate) fn set_default_return_type(&mut self, return_type: syn::Type) { + self.inner.attributes.push(Attribute::Type( + format_ident!("{}", Self::DEFAULT_RET_ATTR), + Box::new(return_type), + )) + } + + pub(crate) fn set_partial_return_type(&mut self, id: usize, return_type: syn::Type) { + self.inner.attributes.push(Attribute::Type( + format_ident!("{}{}", Self::PARTIAL_RET_ATTR, id), + Box::new(return_type), + )) + } + + pub(crate) fn set_once(&mut self, once: syn::Ident) { + self.inner.attributes.push(Attribute::Attr(once)) + } + + pub(crate) fn get_once(&self) -> Option<&Ident> { + 
self.iter() + .find(|&a| a == &Attribute::Attr(format_ident!("once"))) + .and_then(|a| a.maybe_ident()) + } + + pub(crate) fn is_once(&self) -> bool { + self.get_once().is_some() + } + + fn extract_type(&self, attr_name: &str) -> Option { + self.iter() + .filter_map(|m| match m { + Attribute::Type(name, t) if name == attr_name => Some(parse_quote! { -> #t}), + _ => None, + }) + .next() + } +} + +#[cfg(test)] +mod should { + use super::*; + use crate::test::{assert_eq, *}; + + mod parse { + use super::{assert_eq, *}; + + fn parse_fixture>(fixture_data: S) -> FixtureInfo { + parse_meta(fixture_data) + } + + #[test] + fn happy_path() { + let data = parse_fixture( + r#"my_fixture(42, "other"), other(vec![42]), value=42, other_value=vec![1.0] + :: trace :: no_trace(some)"#, + ); + + let expected = FixtureInfo { + data: vec![ + fixture("my_fixture", &["42", r#""other""#]).into(), + fixture("other", &["vec![42]"]).into(), + arg_value("value", "42").into(), + arg_value("other_value", "vec![1.0]").into(), + ] + .into(), + attributes: Attributes { + attributes: vec![ + Attribute::attr("trace"), + Attribute::tagged("no_trace", vec!["some"]), + ], + } + .into(), + arguments: Default::default(), + }; + + assert_eq!(expected, data); + } + + #[test] + fn some_literals() { + let args_expressions = literal_expressions_str(); + let fixture = parse_fixture(&format!("my_fixture({})", args_expressions.join(", "))); + let args = fixture.data.fixtures().next().unwrap().positional.clone(); + + assert_eq!(to_args!(args_expressions), args.0); + } + + #[test] + fn empty_fixtures() { + let data = parse_fixture(r#"::trace::no_trace(some)"#); + + let expected = FixtureInfo { + attributes: Attributes { + attributes: vec![ + Attribute::attr("trace"), + Attribute::tagged("no_trace", vec!["some"]), + ], + } + .into(), + ..Default::default() + }; + + assert_eq!(expected, data); + } + + #[test] + fn empty_attributes() { + let data = parse_fixture(r#"my_fixture(42, "other")"#); + + let expected = 
FixtureInfo { + data: vec![fixture("my_fixture", &["42", r#""other""#]).into()].into(), + ..Default::default() + }; + + assert_eq!(expected, data); + } + + #[rstest] + #[case("first(42),", 1)] + #[case("first(42), second=42,", 2)] + #[case(r#"fixture(42, "other"), :: trace"#, 1)] + #[case(r#"second=42, fixture(42, "other"), :: trace"#, 2)] + fn should_accept_trailing_comma(#[case] input: &str, #[case] expected: usize) { + let info: FixtureInfo = input.ast(); + + assert_eq!( + expected, + info.data.fixtures().count() + info.data.values().count() + ); + } + } +} + +#[cfg(test)] +mod extend { + use super::*; + use crate::test::{assert_eq, *}; + use syn::ItemFn; + + mod should { + use super::{assert_eq, *}; + + #[test] + fn use_with_attributes() { + let to_parse = r#" + fn my_fix(#[with(2)] f1: &str, #[with(vec![1,2], "s")] f2: u32) {} + "#; + + let mut item_fn: ItemFn = to_parse.ast(); + let mut info = FixtureInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + let expected = FixtureInfo { + data: vec![ + fixture("f1", &["2"]).into(), + fixture("f2", &["vec![1,2]", r#""s""#]).into(), + ] + .into(), + ..Default::default() + }; + + assert!(!format!("{:?}", item_fn).contains("with")); + assert_eq!(expected, info); + } + + #[test] + fn rename_with_attributes() { + let mut item_fn = r#" + fn test_fn( + #[from(long_fixture_name)] + #[with(42, "other")] short: u32, + #[from(simple)] + s: &str, + no_change: i32) { + } + "# + .ast(); + + let expected = FixtureInfo { + data: vec![ + fixture("short", &["42", r#""other""#]) + .with_resolve("long_fixture_name") + .into(), + fixture("s", &[]).with_resolve("simple").into(), + ] + .into(), + ..Default::default() + }; + + let mut data = FixtureInfo::default(); + data.extend_with_function_attrs(&mut item_fn).unwrap(); + + assert_eq!(expected, data); + } + + #[test] + fn use_default_values_attributes() { + let to_parse = r#" + fn my_fix(#[default(2)] f1: &str, #[default((vec![1,2], "s"))] f2: (Vec, &str)) {} 
+ "#; + + let mut item_fn: ItemFn = to_parse.ast(); + let mut info = FixtureInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + let expected = FixtureInfo { + data: vec![ + arg_value("f1", "2").into(), + arg_value("f2", r#"(vec![1,2], "s")"#).into(), + ] + .into(), + ..Default::default() + }; + + assert!(!format!("{:?}", item_fn).contains("default")); + assert_eq!(expected, info); + } + + #[test] + fn find_default_return_type() { + let mut item_fn: ItemFn = r#" + #[simple] + #[first(comp)] + #[second::default] + #[default(impl Iterator)] + #[last::more] + fn my_fix(f1: I, f2: J) -> impl Iterator {} + "# + .ast(); + + let mut info = FixtureInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + assert_eq!( + info.attributes.extract_default_type(), + Some(parse_quote! { -> impl Iterator }) + ); + assert_eq!( + attrs("#[simple]#[first(comp)]#[second::default]#[last::more]"), + item_fn.attrs + ); + } + + #[test] + fn find_partials_return_type() { + let mut item_fn: ItemFn = r#" + #[simple] + #[first(comp)] + #[second::default] + #[partial_1(impl Iterator)] + #[partial_2(impl Iterator)] + #[last::more] + fn my_fix(f1: I, f2: J, f3: K) -> impl Iterator {} + "# + .ast(); + + let mut info = FixtureInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + assert_eq!( + info.attributes.extract_partial_type(1), + Some(parse_quote! { -> impl Iterator }) + ); + assert_eq!( + info.attributes.extract_partial_type(2), + Some(parse_quote! 
{ -> impl Iterator }) + ); + assert_eq!( + attrs("#[simple]#[first(comp)]#[second::default]#[last::more]"), + item_fn.attrs + ); + } + + #[test] + fn find_once_attribute() { + let mut item_fn: ItemFn = r#" + #[simple] + #[first(comp)] + #[second::default] + #[once] + #[last::more] + fn my_fix(f1: I, f2: J, f3: K) -> impl Iterator {} + "# + .ast(); + + let mut info = FixtureInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + assert!(info.attributes.is_once()); + } + + #[test] + fn no_once_attribute() { + let mut item_fn: ItemFn = r#" + fn my_fix(f1: I, f2: J, f3: K) -> impl Iterator {} + "# + .ast(); + + let mut info = FixtureInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + assert!(!info.attributes.is_once()); + } + + #[rstest] + fn extract_future() { + let mut item_fn = "fn f(#[future] a: u32, b: u32) {}".ast(); + let expected = "fn f(a: u32, b: u32) {}".ast(); + + let mut info = FixtureInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + assert_eq!(item_fn, expected); + assert!(info.arguments.is_future(&ident("a"))); + assert!(!info.arguments.is_future(&ident("b"))); + } + + mod raise_error { + use super::{assert_eq, *}; + use rstest_test::assert_in; + + #[test] + fn for_invalid_expressions() { + let mut item_fn: ItemFn = r#" + fn my_fix(#[with(valid)] f1: &str, #[with(with(,.,))] f2: u32, #[with(with(use))] f3: u32) {} + "# + .ast(); + + let errors = FixtureInfo::default() + .extend_with_function_attrs(&mut item_fn) + .unwrap_err(); + + assert_eq!(2, errors.len()); + } + + #[test] + fn for_invalid_default_type() { + let mut item_fn: ItemFn = r#" + #[default(notype)] + fn my_fix() -> I {} + "# + .ast(); + + let errors = FixtureInfo::default() + .extend_with_function_attrs(&mut item_fn) + .unwrap_err(); + + assert_eq!(1, errors.len()); + } + + #[test] + fn with_used_more_than_once() { + let mut item_fn: ItemFn = r#" + fn my_fix(#[with(1)] #[with(2)] fixture1: &str, #[with(1)] 
#[with(2)] #[with(3)] fixture2: &str) {} + "# + .ast(); + + let errors = FixtureInfo::default() + .extend_with_function_attrs(&mut item_fn) + .err() + .unwrap_or_default(); + + assert_eq!(3, errors.len()); + } + + #[test] + fn from_used_more_than_once() { + let mut item_fn: ItemFn = r#" + fn my_fix(#[from(a)] #[from(b)] fixture1: &str, #[from(c)] #[from(d)] #[from(e)] fixture2: &str) {} + "# + .ast(); + + let errors = FixtureInfo::default() + .extend_with_function_attrs(&mut item_fn) + .err() + .unwrap_or_default(); + + assert_eq!(3, errors.len()); + } + + #[test] + fn if_once_is_defined_more_than_once() { + let mut item_fn: ItemFn = r#" + #[once] + #[once] + fn my_fix() -> I {} + "# + .ast(); + + let mut info = FixtureInfo::default(); + + let error = info.extend_with_function_attrs(&mut item_fn).unwrap_err(); + + assert_in!( + format!("{:?}", error).to_lowercase(), + "cannot use #[once] more than once" + ); + } + + #[test] + fn if_default_is_defined_more_than_once() { + let mut item_fn: ItemFn = r#" + #[default(u32)] + #[default(u32)] + fn my_fix() -> I {} + "# + .ast(); + + let mut info = FixtureInfo::default(); + + let error = info.extend_with_function_attrs(&mut item_fn).unwrap_err(); + + assert_in!( + format!("{:?}", error).to_lowercase(), + "cannot use default more than once" + ); + } + + #[test] + fn for_invalid_partial_type() { + let mut item_fn: ItemFn = r#" + #[partial_1(notype)] + fn my_fix(x: I, y: u32) -> I {} + "# + .ast(); + + let errors = FixtureInfo::default() + .extend_with_function_attrs(&mut item_fn) + .unwrap_err(); + + assert_eq!(1, errors.len()); + } + + #[test] + fn if_partial_is_not_correct() { + let mut item_fn: ItemFn = r#" + #[partial_not_a_number(u32)] + fn my_fix(f1: I, f2: &str) -> I {} + "# + .ast(); + + let mut info = FixtureInfo::default(); + + let error = info.extend_with_function_attrs(&mut item_fn).unwrap_err(); + + assert_in!( + format!("{:?}", error).to_lowercase(), + "invalid partial syntax" + ); + } + } + } +} diff --git 
a/src/parse/future.rs b/src/parse/future.rs new file mode 100644 index 0000000..45eec12 --- /dev/null +++ b/src/parse/future.rs @@ -0,0 +1,260 @@ +use quote::{format_ident, ToTokens}; +use syn::{visit_mut::VisitMut, FnArg, Ident, ItemFn, PatType, Type}; + +use crate::{error::ErrorsVec, refident::MaybeType, utils::attr_is}; + +use super::{arguments::FutureArg, extract_argument_attrs}; + +pub(crate) fn extract_futures( + item_fn: &mut ItemFn, +) -> Result<(Vec<(Ident, FutureArg)>, bool), ErrorsVec> { + let mut extractor = FutureFunctionExtractor::default(); + extractor.visit_item_fn_mut(item_fn); + extractor.take() +} + +pub(crate) trait MaybeFutureImplType { + fn as_future_impl_type(&self) -> Option<&Type>; + + fn as_mut_future_impl_type(&mut self) -> Option<&mut Type>; +} + +impl MaybeFutureImplType for FnArg { + fn as_future_impl_type(&self) -> Option<&Type> { + match self { + FnArg::Typed(PatType { ty, .. }) if can_impl_future(ty.as_ref()) => Some(ty.as_ref()), + _ => None, + } + } + + fn as_mut_future_impl_type(&mut self) -> Option<&mut Type> { + match self { + FnArg::Typed(PatType { ty, .. }) if can_impl_future(ty.as_ref()) => Some(ty.as_mut()), + _ => None, + } + } +} + +fn can_impl_future(ty: &Type) -> bool { + use Type::*; + !matches!( + ty, + Group(_) + | ImplTrait(_) + | Infer(_) + | Macro(_) + | Never(_) + | Slice(_) + | TraitObject(_) + | Verbatim(_) + ) +} + +/// Simple struct used to visit function attributes and extract future args to +/// implement the boilerplate. 
+#[derive(Default)] +struct FutureFunctionExtractor { + futures: Vec<(Ident, FutureArg)>, + awt: bool, + errors: Vec, +} + +impl FutureFunctionExtractor { + pub(crate) fn take(self) -> Result<(Vec<(Ident, FutureArg)>, bool), ErrorsVec> { + if self.errors.is_empty() { + Ok((self.futures, self.awt)) + } else { + Err(self.errors.into()) + } + } +} + +impl VisitMut for FutureFunctionExtractor { + fn visit_item_fn_mut(&mut self, node: &mut ItemFn) { + let attrs = std::mem::take(&mut node.attrs); + let (awts, remain): (Vec<_>, Vec<_>) = attrs.into_iter().partition(|a| attr_is(a, "awt")); + self.awt = match awts.len().cmp(&1) { + std::cmp::Ordering::Equal => true, + std::cmp::Ordering::Greater => { + self.errors.extend(awts.into_iter().skip(1).map(|a| { + syn::Error::new_spanned( + a.into_token_stream(), + "Cannot use #[awt] more than once.".to_owned(), + ) + })); + false + } + std::cmp::Ordering::Less => false, + }; + node.attrs = remain; + syn::visit_mut::visit_item_fn_mut(self, node); + } + + fn visit_fn_arg_mut(&mut self, node: &mut FnArg) { + if matches!(node, FnArg::Receiver(_)) { + return; + } + match extract_argument_attrs( + node, + |a| attr_is(a, "future"), + |arg, name| { + let kind = if arg.tokens.is_empty() { + FutureArg::Define + } else { + match arg.parse_args::>()? 
{ + Some(awt) if awt == format_ident!("awt") => FutureArg::Await, + None => FutureArg::Define, + Some(invalid) => { + return Err(syn::Error::new_spanned( + arg.parse_args::>()?.into_token_stream(), + format!("Invalid '{invalid}' #[future(...)] arg."), + )); + } + } + }; + Ok((arg, name.clone(), kind)) + }, + ) + .collect::, _>>() + { + Ok(futures) => match futures.len().cmp(&1) { + std::cmp::Ordering::Equal => match node.as_future_impl_type() { + Some(_) => self.futures.push((futures[0].1.clone(), futures[0].2)), + None => self.errors.push(syn::Error::new_spanned( + node.maybe_type().unwrap().into_token_stream(), + "This type cannot used to generate impl Future.".to_owned(), + )), + }, + std::cmp::Ordering::Greater => { + self.errors + .extend(futures.iter().skip(1).map(|(attr, _ident, _type)| { + syn::Error::new_spanned( + attr.into_token_stream(), + "Cannot use #[future] more than once.".to_owned(), + ) + })); + } + std::cmp::Ordering::Less => {} + }, + Err(e) => { + self.errors.push(e); + } + }; + } +} + +#[cfg(test)] +mod should { + use super::*; + use crate::test::{assert_eq, *}; + use rstest_test::assert_in; + + #[rstest] + #[case("fn simple(a: u32) {}")] + #[case("fn more(a: u32, b: &str) {}")] + #[case("fn gen>(a: u32, b: S) {}")] + #[case("fn attr(#[case] a: u32, #[values(1,2)] b: i32) {}")] + fn not_change_anything_if_no_future_attribute_found(#[case] item_fn: &str) { + let mut item_fn: ItemFn = item_fn.ast(); + let orig = item_fn.clone(); + + let (futures, awt) = extract_futures(&mut item_fn).unwrap(); + + assert_eq!(orig, item_fn); + assert!(futures.is_empty()); + assert!(!awt); + } + + #[rstest] + #[case::simple("fn f(#[future] a: u32) {}", "fn f(a: u32) {}", &[("a", FutureArg::Define)], false)] + #[case::global_awt("#[awt] fn f(a: u32) {}", "fn f(a: u32) {}", &[], true)] + #[case::simple_awaited("fn f(#[future(awt)] a: u32) {}", "fn f(a: u32) {}", &[("a", FutureArg::Await)], false)] + #[case::simple_awaited_and_global("#[awt] fn f(#[future(awt)] a: 
u32) {}", "fn f(a: u32) {}", &[("a", FutureArg::Await)], true)] + #[case::more_than_one( + "fn f(#[future] a: u32, #[future(awt)] b: String, #[future()] c: std::collection::HashMap) {}", + r#"fn f(a: u32, + b: String, + c: std::collection::HashMap) {}"#, + &[("a", FutureArg::Define), ("b", FutureArg::Await), ("c", FutureArg::Define)], + false, + )] + #[case::just_one( + "fn f(a: u32, #[future] b: String) {}", + r#"fn f(a: u32, b: String) {}"#, + &[("b", FutureArg::Define)], + false, + )] + #[case::just_one_awaited( + "fn f(a: u32, #[future(awt)] b: String) {}", + r#"fn f(a: u32, b: String) {}"#, + &[("b", FutureArg::Await)], + false, + )] + fn extract( + #[case] item_fn: &str, + #[case] expected: &str, + #[case] expected_futures: &[(&str, FutureArg)], + #[case] expected_awt: bool, + ) { + let mut item_fn: ItemFn = item_fn.ast(); + let expected: ItemFn = expected.ast(); + + let (futures, awt) = extract_futures(&mut item_fn).unwrap(); + + assert_eq!(expected, item_fn); + assert_eq!( + futures, + expected_futures + .into_iter() + .map(|(id, a)| (ident(id), *a)) + .collect::>() + ); + assert_eq!(expected_awt, awt); + } + + #[rstest] + #[case::base(r#"#[awt] fn f(a: u32) {}"#, r#"fn f(a: u32) {}"#)] + #[case::two( + r#" + #[awt] + #[awt] + fn f(a: u32) {} + "#, + r#"fn f(a: u32) {}"# + )] + #[case::inner( + r#" + #[one] + #[awt] + #[two] + fn f(a: u32) {} + "#, + r#" + #[one] + #[two] + fn f(a: u32) {} + "# + )] + fn remove_all_awt_attributes(#[case] item_fn: &str, #[case] expected: &str) { + let mut item_fn: ItemFn = item_fn.ast(); + let expected: ItemFn = expected.ast(); + + let _ = extract_futures(&mut item_fn); + + assert_eq!(item_fn, expected); + } + + #[rstest] + #[case::no_more_than_one("fn f(#[future] #[future] a: u32) {}", "more than once")] + #[case::no_impl("fn f(#[future] a: impl AsRef) {}", "generate impl Future")] + #[case::no_slice("fn f(#[future] a: [i32]) {}", "generate impl Future")] + #[case::invalid_arg("fn f(#[future(other)] a: [i32]) {}", "Invalid 
'other'")] + #[case::no_more_than_one_awt("#[awt] #[awt] fn f(a: u32) {}", "more than once")] + fn raise_error(#[case] item_fn: &str, #[case] message: &str) { + let mut item_fn: ItemFn = item_fn.ast(); + + let err = extract_futures(&mut item_fn).unwrap_err(); + + assert_in!(format!("{:?}", err), message); + } +} diff --git a/src/parse/macros.rs b/src/parse/macros.rs new file mode 100644 index 0000000..beb47cf --- /dev/null +++ b/src/parse/macros.rs @@ -0,0 +1,27 @@ +macro_rules! wrap_attributes { + ($ident:ident) => { + #[derive(Default, Debug, PartialEq, Clone)] + pub(crate) struct $ident { + inner: Attributes, + } + + impl From for $ident { + fn from(inner: Attributes) -> Self { + $ident { inner } + } + } + + impl $ident { + fn iter(&self) -> impl Iterator { + self.inner.attributes.iter() + } + } + + impl $ident { + #[allow(dead_code)] + pub(crate) fn append(&mut self, attr: Attribute) { + self.inner.attributes.push(attr) + } + } + }; +} diff --git a/src/parse/mod.rs b/src/parse/mod.rs new file mode 100644 index 0000000..c772163 --- /dev/null +++ b/src/parse/mod.rs @@ -0,0 +1,826 @@ +use proc_macro2::TokenStream; +use syn::{ + parse::{Parse, ParseStream}, + parse_quote, + punctuated::Punctuated, + token::{self, Async, Paren}, + visit_mut::VisitMut, + FnArg, Ident, ItemFn, Token, +}; + +use crate::{ + error::ErrorsVec, + refident::{MaybeIdent, RefIdent}, + utils::{attr_is, attr_starts_with}, +}; +use fixture::{ + ArgumentValue, DefaultsFunctionExtractor, FixtureModifiers, FixturesFunctionExtractor, +}; +use quote::ToTokens; +use testcase::TestCase; + +use self::{expressions::Expressions, vlist::ValueList}; + +// To use the macros this should be the first one module +#[macro_use] +pub(crate) mod macros; + +pub(crate) mod expressions; +pub(crate) mod fixture; +pub(crate) mod future; +pub(crate) mod rstest; +pub(crate) mod testcase; +pub(crate) mod vlist; + +pub(crate) trait ExtendWithFunctionAttrs { + fn extend_with_function_attrs( + &mut self, + item_fn: &mut 
ItemFn, + ) -> std::result::Result<(), ErrorsVec>; +} + +#[derive(Default, Debug, PartialEq, Clone)] +pub(crate) struct Attributes { + pub(crate) attributes: Vec, +} + +impl Parse for Attributes { + fn parse(input: ParseStream) -> syn::Result { + let vars = Punctuated::::parse_terminated(input)?; + Ok(Attributes { + attributes: vars.into_iter().collect(), + }) + } +} + +#[derive(Debug, PartialEq, Clone)] +pub(crate) enum Attribute { + Attr(Ident), + Tagged(Ident, Vec), + Type(Ident, Box), +} + +impl Parse for Attribute { + fn parse(input: ParseStream) -> syn::Result { + if input.peek2(Token![<]) { + let tag = input.parse()?; + let _open = input.parse::()?; + let inner = input.parse()?; + let _close = input.parse::]>()?; + Ok(Attribute::Type(tag, inner)) + } else if input.peek2(Token![::]) { + let inner = input.parse()?; + Ok(Attribute::Attr(inner)) + } else if input.peek2(token::Paren) { + let tag = input.parse()?; + let content; + let _ = syn::parenthesized!(content in input); + let args = Punctuated::::parse_terminated(&content)? + .into_iter() + .collect(); + + Ok(Attribute::Tagged(tag, args)) + } else { + Ok(Attribute::Attr(input.parse()?)) + } + } +} + +fn parse_vector_trailing_till_double_comma(input: ParseStream) -> syn::Result> +where + T: Parse, + P: syn::token::Token + Parse, +{ + Ok( + Punctuated::, P>::parse_separated_nonempty_with(input, |input_tokens| { + if input_tokens.is_empty() || input_tokens.peek(Token![::]) { + Ok(None) + } else { + T::parse(input_tokens).map(Some) + } + })? 
+ .into_iter() + .flatten() + .collect(), + ) +} + +#[allow(dead_code)] +pub(crate) fn drain_stream(input: ParseStream) { + // JUST TO SKIP ALL + let _ = input.step(|cursor| { + let mut rest = *cursor; + while let Some((_, next)) = rest.token_tree() { + rest = next + } + Ok(((), rest)) + }); +} + +#[derive(PartialEq, Debug, Clone, Default)] +pub(crate) struct Positional(pub(crate) Vec); + +impl Parse for Positional { + fn parse(input: ParseStream) -> syn::Result { + Ok(Self( + Punctuated::::parse_terminated(input)? + .into_iter() + .collect(), + )) + } +} + +#[derive(PartialEq, Debug, Clone)] +pub(crate) struct Fixture { + pub(crate) name: Ident, + pub(crate) resolve: Option, + pub(crate) positional: Positional, +} + +impl Fixture { + pub(crate) fn new(name: Ident, resolve: Option, positional: Positional) -> Self { + Self { + name, + resolve, + positional, + } + } +} + +impl Parse for Fixture { + fn parse(input: ParseStream) -> syn::Result { + let resolve = input.parse()?; + if input.peek(Paren) || input.peek(Token![as]) { + let positional = if input.peek(Paren) { + let content; + let _ = syn::parenthesized!(content in input); + content.parse()? 
+ } else { + Default::default() + }; + + if input.peek(Token![as]) { + let _: Token![as] = input.parse()?; + Ok(Self::new(input.parse()?, Some(resolve), positional)) + } else { + Ok(Self::new(resolve, None, positional)) + } + } else { + Err(syn::Error::new( + input.span(), + "fixture need arguments or 'as new_name' format", + )) + } + } +} + +impl RefIdent for Fixture { + fn ident(&self) -> &Ident { + &self.name + } +} + +impl ToTokens for Fixture { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.name.to_tokens(tokens) + } +} + +pub(crate) fn extract_fixtures(item_fn: &mut ItemFn) -> Result, ErrorsVec> { + let mut fixtures_extractor = FixturesFunctionExtractor::default(); + fixtures_extractor.visit_item_fn_mut(item_fn); + + if fixtures_extractor.1.is_empty() { + Ok(fixtures_extractor.0) + } else { + Err(fixtures_extractor.1.into()) + } +} + +pub(crate) fn extract_defaults(item_fn: &mut ItemFn) -> Result, ErrorsVec> { + let mut defaults_extractor = DefaultsFunctionExtractor::default(); + defaults_extractor.visit_item_fn_mut(item_fn); + + if defaults_extractor.1.is_empty() { + Ok(defaults_extractor.0) + } else { + Err(defaults_extractor.1.into()) + } +} + +pub(crate) fn extract_default_return_type( + item_fn: &mut ItemFn, +) -> Result, ErrorsVec> { + let mut default_type_extractor = DefaultTypeFunctionExtractor::default(); + default_type_extractor.visit_item_fn_mut(item_fn); + default_type_extractor.take() +} + +pub(crate) fn extract_partials_return_type( + item_fn: &mut ItemFn, +) -> Result, ErrorsVec> { + let mut partials_type_extractor = PartialsTypeFunctionExtractor::default(); + partials_type_extractor.visit_item_fn_mut(item_fn); + partials_type_extractor.take() +} + +pub(crate) fn extract_once(item_fn: &mut ItemFn) -> Result, ErrorsVec> { + let mut extractor = IsOnceAttributeFunctionExtractor::default(); + extractor.visit_item_fn_mut(item_fn); + extractor.take() +} + +pub(crate) fn extract_argument_attrs<'a, B: 'a + std::fmt::Debug>( + node: &mut 
FnArg, + is_valid_attr: fn(&syn::Attribute) -> bool, + build: fn(syn::Attribute, &Ident) -> syn::Result, +) -> Box> + 'a> { + let name = node.maybe_ident().cloned(); + if name.is_none() { + return Box::new(std::iter::empty()); + } + + let name = name.unwrap(); + if let FnArg::Typed(ref mut arg) = node { + // Extract interesting attributes + let attrs = std::mem::take(&mut arg.attrs); + let (extracted, remain): (Vec<_>, Vec<_>) = attrs.into_iter().partition(is_valid_attr); + + arg.attrs = remain; + + // Parse attrs + Box::new(extracted.into_iter().map(move |attr| build(attr, &name))) + } else { + Box::new(std::iter::empty()) + } +} + +/// Simple struct used to visit function attributes and extract default return +/// type +struct DefaultTypeFunctionExtractor(Result, ErrorsVec>); + +impl DefaultTypeFunctionExtractor { + fn take(self) -> Result, ErrorsVec> { + self.0 + } +} + +impl Default for DefaultTypeFunctionExtractor { + fn default() -> Self { + Self(Ok(None)) + } +} + +impl VisitMut for DefaultTypeFunctionExtractor { + fn visit_item_fn_mut(&mut self, node: &mut ItemFn) { + let attrs = std::mem::take(&mut node.attrs); + let (defaults, remain): (Vec<_>, Vec<_>) = attrs + .into_iter() + .partition(|attr| attr_is(attr, FixtureModifiers::DEFAULT_RET_ATTR)); + + node.attrs = remain; + let mut defaults = defaults.into_iter(); + let mut data = None; + let mut errors = ErrorsVec::default(); + match defaults.next().map(|def| def.parse_args::()) { + Some(Ok(t)) => data = Some(t), + Some(Err(e)) => errors.push(e), + None => {} + }; + errors.extend( + defaults.map(|a| syn::Error::new_spanned(a, "You cannot use default more than once")), + ); + self.0 = if errors.len() > 0 { + Err(errors) + } else { + Ok(data) + }; + + syn::visit_mut::visit_item_fn_mut(self, node); + } +} + +/// Simple struct used to visit function attributes and extract default return +/// type +struct PartialsTypeFunctionExtractor(Result, ErrorsVec>); + +impl PartialsTypeFunctionExtractor { + fn take(self) 
-> Result, ErrorsVec> { + self.0 + } +} + +impl Default for PartialsTypeFunctionExtractor { + fn default() -> Self { + Self(Ok(Vec::default())) + } +} + +impl VisitMut for PartialsTypeFunctionExtractor { + fn visit_item_fn_mut(&mut self, node: &mut ItemFn) { + let attrs = std::mem::take(&mut node.attrs); + let (partials, remain): (Vec<_>, Vec<_>) = + attrs + .into_iter() + .partition(|attr| match attr.path.get_ident() { + Some(name) => name + .to_string() + .starts_with(FixtureModifiers::PARTIAL_RET_ATTR), + None => false, + }); + + node.attrs = remain; + let mut errors = ErrorsVec::default(); + let mut data: Vec<(usize, syn::Type)> = Vec::default(); + for attr in partials { + match attr.parse_args::() { + Ok(t) => { + match attr.path.get_ident().unwrap().to_string() + [FixtureModifiers::PARTIAL_RET_ATTR.len()..] + .parse() + { + Ok(id) => data.push((id, t)), + Err(_) => errors.push(syn::Error::new_spanned( + attr, + "Invalid partial syntax: should be partial_", + )), + } + } + Err(e) => errors.push(e), + } + } + self.0 = if errors.len() > 0 { + Err(errors) + } else { + Ok(data) + }; + + syn::visit_mut::visit_item_fn_mut(self, node); + } +} + +/// Simple struct used to visit function attributes and extract once +/// type +struct IsOnceAttributeFunctionExtractor(Result, ErrorsVec>); + +impl IsOnceAttributeFunctionExtractor { + fn take(self) -> Result, ErrorsVec> { + self.0 + } +} + +impl Default for IsOnceAttributeFunctionExtractor { + fn default() -> Self { + Self(Ok(None)) + } +} + +impl VisitMut for IsOnceAttributeFunctionExtractor { + fn visit_item_fn_mut(&mut self, node: &mut ItemFn) { + let attrs = std::mem::take(&mut node.attrs); + let (onces, remain): (Vec<_>, Vec<_>) = + attrs.into_iter().partition(|attr| attr_is(attr, "once")); + + node.attrs = remain; + self.0 = match onces.len() { + 1 => Ok(onces[0].path.get_ident().cloned()), + 0 => Ok(None), + _ => Err(onces + .into_iter() + .skip(1) + .map(|attr| syn::Error::new_spanned(attr, "You cannot use #[once] 
more than once")) + .collect::>() + .into()), + }; + syn::visit_mut::visit_item_fn_mut(self, node); + } +} + +/// Simple struct used to visit function attributes and extract case arguments and +/// eventualy parsing errors +#[derive(Default)] +struct CaseArgsFunctionExtractor(Vec, Vec); + +impl VisitMut for CaseArgsFunctionExtractor { + fn visit_fn_arg_mut(&mut self, node: &mut FnArg) { + for r in extract_argument_attrs(node, |a| attr_is(a, "case"), |_a, name| Ok(name.clone())) { + match r { + Ok(value) => self.0.push(value), + Err(err) => self.1.push(err), + } + } + + syn::visit_mut::visit_fn_arg_mut(self, node); + } +} + +pub(crate) fn extract_case_args(item_fn: &mut ItemFn) -> Result, ErrorsVec> { + let mut case_args_extractor = CaseArgsFunctionExtractor::default(); + case_args_extractor.visit_item_fn_mut(item_fn); + + if case_args_extractor.1.is_empty() { + Ok(case_args_extractor.0) + } else { + Err(case_args_extractor.1.into()) + } +} + +/// Simple struct used to visit function attributes and extract cases and +/// eventualy parsing errors +#[derive(Default)] +struct CasesFunctionExtractor(Vec, Vec); + +impl VisitMut for CasesFunctionExtractor { + fn visit_item_fn_mut(&mut self, node: &mut ItemFn) { + let attrs = std::mem::take(&mut node.attrs); + let mut attrs_buffer = Default::default(); + let case: syn::PathSegment = parse_quote! 
{ case }; + for attr in attrs.into_iter() { + if attr_starts_with(&attr, &case) { + match attr.parse_args::() { + Ok(expressions) => { + let description = attr.path.segments.into_iter().nth(1).map(|p| p.ident); + self.0.push(TestCase { + args: expressions.into(), + attrs: std::mem::take(&mut attrs_buffer), + description, + }); + } + Err(err) => self.1.push(err), + }; + } else { + attrs_buffer.push(attr) + } + } + node.attrs = std::mem::take(&mut attrs_buffer); + syn::visit_mut::visit_item_fn_mut(self, node); + } +} + +pub(crate) fn extract_cases(item_fn: &mut ItemFn) -> Result, ErrorsVec> { + let mut cases_extractor = CasesFunctionExtractor::default(); + cases_extractor.visit_item_fn_mut(item_fn); + + if cases_extractor.1.is_empty() { + Ok(cases_extractor.0) + } else { + Err(cases_extractor.1.into()) + } +} + +/// Simple struct used to visit function attributes and extract value list and +/// eventualy parsing errors +#[derive(Default)] +struct ValueListFunctionExtractor(Vec, Vec); + +impl VisitMut for ValueListFunctionExtractor { + fn visit_fn_arg_mut(&mut self, node: &mut FnArg) { + for r in extract_argument_attrs( + node, + |a| attr_is(a, "values"), + |a, name| { + a.parse_args::().map(|v| ValueList { + arg: name.clone(), + values: v.take(), + }) + }, + ) { + match r { + Ok(vlist) => self.0.push(vlist), + Err(err) => self.1.push(err), + } + } + + syn::visit_mut::visit_fn_arg_mut(self, node); + } +} + +pub(crate) fn extract_value_list(item_fn: &mut ItemFn) -> Result, ErrorsVec> { + let mut vlist_extractor = ValueListFunctionExtractor::default(); + vlist_extractor.visit_item_fn_mut(item_fn); + + if vlist_extractor.1.is_empty() { + Ok(vlist_extractor.0) + } else { + Err(vlist_extractor.1.into()) + } +} + +/// Simple struct used to visit function args attributes to extract the +/// excluded ones and eventualy parsing errors +struct ExcludedTraceAttributesFunctionExtractor(Result, ErrorsVec>); +impl From, ErrorsVec>> for ExcludedTraceAttributesFunctionExtractor { + 
fn from(inner: Result, ErrorsVec>) -> Self { + Self(inner) + } +} + +impl ExcludedTraceAttributesFunctionExtractor { + pub(crate) fn take(self) -> Result, ErrorsVec> { + self.0 + } + + fn update_error(&mut self, mut errors: ErrorsVec) { + match &mut self.0 { + Ok(_) => self.0 = Err(errors), + Err(err) => err.append(&mut errors), + } + } + + fn update_excluded(&mut self, value: Ident) { + if let Some(inner) = self.0.iter_mut().next() { + inner.push(value); + } + } +} + +impl Default for ExcludedTraceAttributesFunctionExtractor { + fn default() -> Self { + Self(Ok(Default::default())) + } +} + +impl VisitMut for ExcludedTraceAttributesFunctionExtractor { + fn visit_fn_arg_mut(&mut self, node: &mut FnArg) { + for r in + extract_argument_attrs(node, |a| attr_is(a, "notrace"), |_a, name| Ok(name.clone())) + { + match r { + Ok(value) => self.update_excluded(value), + Err(err) => self.update_error(err.into()), + } + } + + syn::visit_mut::visit_fn_arg_mut(self, node); + } +} + +pub(crate) fn extract_excluded_trace(item_fn: &mut ItemFn) -> Result, ErrorsVec> { + let mut excluded_trace_extractor = ExcludedTraceAttributesFunctionExtractor::default(); + excluded_trace_extractor.visit_item_fn_mut(item_fn); + excluded_trace_extractor.take() +} + +/// Simple struct used to visit function args attributes to check timeout syntax +struct CheckTimeoutAttributesFunction(Result<(), ErrorsVec>); +impl From for CheckTimeoutAttributesFunction { + fn from(errors: ErrorsVec) -> Self { + Self(Err(errors)) + } +} + +impl CheckTimeoutAttributesFunction { + pub(crate) fn take(self) -> Result<(), ErrorsVec> { + self.0 + } + + fn check_if_can_implement_timeous( + &self, + timeouts: &[&syn::Attribute], + asyncness: Option<&Async>, + ) -> Option { + if cfg!(feature = "async-timeout") || timeouts.is_empty() { + None + } else { + asyncness.map(|a| { + syn::Error::new( + a.span, + "Enable async-timeout feature to use timeout in async tests", + ) + }) + } + } +} + +impl Default for 
CheckTimeoutAttributesFunction { + fn default() -> Self { + Self(Ok(())) + } +} + +impl VisitMut for CheckTimeoutAttributesFunction { + fn visit_item_fn_mut(&mut self, node: &mut ItemFn) { + let timeouts = node + .attrs + .iter() + .filter(|&a| attr_is(a, "timeout")) + .collect::>(); + let mut errors = timeouts + .iter() + .map(|&attr| attr.parse_args::()) + .filter_map(Result::err) + .collect::>(); + + if let Some(e) = + self.check_if_can_implement_timeous(timeouts.as_slice(), node.sig.asyncness.as_ref()) + { + errors.push(e); + } + if !errors.is_empty() { + *self = Self(Err(errors.into())); + } + } +} + +pub(crate) fn check_timeout_attrs(item_fn: &mut ItemFn) -> Result<(), ErrorsVec> { + let mut checker = CheckTimeoutAttributesFunction::default(); + checker.visit_item_fn_mut(item_fn); + checker.take() +} + +pub(crate) mod arguments { + use std::collections::HashMap; + + use syn::Ident; + + #[derive(PartialEq, Debug, Clone, Copy)] + #[allow(dead_code)] + pub(crate) enum FutureArg { + None, + Define, + Await, + } + + impl Default for FutureArg { + fn default() -> Self { + FutureArg::None + } + } + + #[derive(PartialEq, Default, Debug)] + pub(crate) struct ArgumentInfo { + future: FutureArg, + } + + impl ArgumentInfo { + fn future(future: FutureArg) -> Self { + Self { future } + } + + fn is_future(&self) -> bool { + use FutureArg::*; + + matches!(self.future, Define | Await) + } + + fn is_future_await(&self) -> bool { + use FutureArg::*; + + matches!(self.future, Await) + } + } + + #[derive(PartialEq, Default, Debug)] + pub(crate) struct ArgumentsInfo { + args: HashMap, + is_global_await: bool, + } + + impl ArgumentsInfo { + pub(crate) fn set_future(&mut self, ident: Ident, kind: FutureArg) { + self.args + .entry(ident) + .and_modify(|v| v.future = kind) + .or_insert_with(|| ArgumentInfo::future(kind)); + } + + pub(crate) fn set_futures(&mut self, futures: impl Iterator) { + futures.for_each(|(ident, k)| self.set_future(ident, k)); + } + + pub(crate) fn 
set_global_await(&mut self, is_global_await: bool) { + self.is_global_await = is_global_await; + } + + #[allow(dead_code)] + pub(crate) fn add_future(&mut self, ident: Ident) { + self.set_future(ident, FutureArg::Define); + } + + pub(crate) fn is_future(&self, id: &Ident) -> bool { + self.args + .get(id) + .map(|arg| arg.is_future()) + .unwrap_or_default() + } + + pub(crate) fn is_future_await(&self, ident: &Ident) -> bool { + match self.args.get(ident) { + Some(arg) => arg.is_future_await() || (arg.is_future() && self.is_global_await()), + None => false, + } + } + + pub(crate) fn is_global_await(&self) -> bool { + self.is_global_await + } + } + + #[cfg(test)] + mod should_implement_is_future_await_logic { + use super::*; + use crate::test::*; + + #[fixture] + fn info() -> ArgumentsInfo { + let mut a = ArgumentsInfo::default(); + a.set_future(ident("simple"), FutureArg::Define); + a.set_future(ident("other_simple"), FutureArg::Define); + a.set_future(ident("awaited"), FutureArg::Await); + a.set_future(ident("other_awaited"), FutureArg::Await); + a.set_future(ident("none"), FutureArg::None); + a + } + + #[rstest] + fn no_matching_ident(info: ArgumentsInfo) { + assert!(!info.is_future_await(&ident("some"))); + assert!(!info.is_future_await(&ident("simple"))); + assert!(!info.is_future_await(&ident("none"))); + } + + #[rstest] + fn matching_ident(info: ArgumentsInfo) { + assert!(info.is_future_await(&ident("awaited"))); + assert!(info.is_future_await(&ident("other_awaited"))); + } + + #[rstest] + fn global_matching_future_ident(mut info: ArgumentsInfo) { + info.set_global_await(true); + assert!(info.is_future_await(&ident("simple"))); + assert!(info.is_future_await(&ident("other_simple"))); + assert!(info.is_future_await(&ident("awaited"))); + + assert!(!info.is_future_await(&ident("some"))); + assert!(!info.is_future_await(&ident("none"))); + } + } +} + +#[cfg(test)] +mod should { + use super::*; + use crate::test::*; + + mod parse_attributes { + use 
super::assert_eq; + use super::*; + + fn parse_attributes>(attributes: S) -> Attributes { + parse_meta(attributes) + } + + #[test] + fn one_simple_ident() { + let attributes = parse_attributes("my_ident"); + + let expected = Attributes { + attributes: vec![Attribute::attr("my_ident")], + }; + + assert_eq!(expected, attributes); + } + + #[test] + fn one_simple_group() { + let attributes = parse_attributes("group_tag(first, second)"); + + let expected = Attributes { + attributes: vec![Attribute::tagged("group_tag", vec!["first", "second"])], + }; + + assert_eq!(expected, attributes); + } + + #[test] + fn one_simple_type() { + let attributes = parse_attributes("type_tag<(u32, T, (String, i32))>"); + + let expected = Attributes { + attributes: vec![Attribute::typed("type_tag", "(u32, T, (String, i32))")], + }; + + assert_eq!(expected, attributes); + } + + #[test] + fn integrated() { + let attributes = parse_attributes( + r#" + simple :: tagged(first, second) :: type_tag<(u32, T, (std::string::String, i32))> :: more_tagged(a,b)"#, + ); + + let expected = Attributes { + attributes: vec![ + Attribute::attr("simple"), + Attribute::tagged("tagged", vec!["first", "second"]), + Attribute::typed("type_tag", "(u32, T, (std::string::String, i32))"), + Attribute::tagged("more_tagged", vec!["a", "b"]), + ], + }; + + assert_eq!(expected, attributes); + } + } +} diff --git a/src/parse/rstest.rs b/src/parse/rstest.rs new file mode 100644 index 0000000..ef398b0 --- /dev/null +++ b/src/parse/rstest.rs @@ -0,0 +1,935 @@ +use syn::{ + parse::{Parse, ParseStream}, + Ident, ItemFn, Token, +}; + +use super::testcase::TestCase; +use super::{ + arguments::ArgumentsInfo, check_timeout_attrs, extract_case_args, extract_cases, + extract_excluded_trace, extract_fixtures, extract_value_list, future::extract_futures, + parse_vector_trailing_till_double_comma, Attribute, Attributes, ExtendWithFunctionAttrs, + Fixture, +}; +use crate::parse::vlist::ValueList; +use crate::{ + error::ErrorsVec, + 
refident::{MaybeIdent, RefIdent}, +}; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, ToTokens}; + +#[derive(PartialEq, Debug, Default)] +pub(crate) struct RsTestInfo { + pub(crate) data: RsTestData, + pub(crate) attributes: RsTestAttributes, + pub(crate) arguments: ArgumentsInfo, +} + +impl Parse for RsTestInfo { + fn parse(input: ParseStream) -> syn::Result { + Ok(if input.is_empty() { + Default::default() + } else { + Self { + data: input.parse()?, + attributes: input + .parse::() + .or_else(|_| Ok(Default::default())) + .and_then(|_| input.parse())?, + arguments: Default::default(), + } + }) + } +} + +impl ExtendWithFunctionAttrs for RsTestInfo { + fn extend_with_function_attrs(&mut self, item_fn: &mut ItemFn) -> Result<(), ErrorsVec> { + let composed_tuple!(_inner, excluded, _timeout, futures) = merge_errors!( + self.data.extend_with_function_attrs(item_fn), + extract_excluded_trace(item_fn), + check_timeout_attrs(item_fn), + extract_futures(item_fn) + )?; + let (futures, global_awt) = futures; + self.attributes.add_notraces(excluded); + self.arguments.set_global_await(global_awt); + self.arguments.set_futures(futures.into_iter()); + Ok(()) + } +} + +#[derive(PartialEq, Debug, Default)] +pub(crate) struct RsTestData { + pub(crate) items: Vec, +} + +impl RsTestData { + pub(crate) fn case_args(&self) -> impl Iterator { + self.items.iter().filter_map(|it| match it { + RsTestItem::CaseArgName(ref arg) => Some(arg), + _ => None, + }) + } + + #[allow(dead_code)] + pub(crate) fn has_case_args(&self) -> bool { + self.case_args().next().is_some() + } + + pub(crate) fn cases(&self) -> impl Iterator { + self.items.iter().filter_map(|it| match it { + RsTestItem::TestCase(ref case) => Some(case), + _ => None, + }) + } + + pub(crate) fn has_cases(&self) -> bool { + self.cases().next().is_some() + } + + pub(crate) fn fixtures(&self) -> impl Iterator { + self.items.iter().filter_map(|it| match it { + RsTestItem::Fixture(ref fixture) => Some(fixture), + _ => 
None, + }) + } + + #[allow(dead_code)] + pub(crate) fn has_fixtures(&self) -> bool { + self.fixtures().next().is_some() + } + + pub(crate) fn list_values(&self) -> impl Iterator { + self.items.iter().filter_map(|mv| match mv { + RsTestItem::ValueList(ref value_list) => Some(value_list), + _ => None, + }) + } + + pub(crate) fn has_list_values(&self) -> bool { + self.list_values().next().is_some() + } +} + +impl Parse for RsTestData { + fn parse(input: ParseStream) -> syn::Result { + if input.peek(Token![::]) { + Ok(Default::default()) + } else { + Ok(Self { + items: parse_vector_trailing_till_double_comma::<_, Token![,]>(input)?, + }) + } + } +} + +impl ExtendWithFunctionAttrs for RsTestData { + fn extend_with_function_attrs(&mut self, item_fn: &mut ItemFn) -> Result<(), ErrorsVec> { + let composed_tuple!(fixtures, case_args, cases, value_list) = merge_errors!( + extract_fixtures(item_fn), + extract_case_args(item_fn), + extract_cases(item_fn), + extract_value_list(item_fn) + )?; + + self.items.extend(fixtures.into_iter().map(|f| f.into())); + self.items.extend(case_args.into_iter().map(|f| f.into())); + self.items.extend(cases.into_iter().map(|f| f.into())); + self.items.extend(value_list.into_iter().map(|f| f.into())); + Ok(()) + } +} + +#[derive(PartialEq, Debug)] +pub(crate) enum RsTestItem { + Fixture(Fixture), + CaseArgName(Ident), + TestCase(TestCase), + ValueList(ValueList), +} + +impl From for RsTestItem { + fn from(f: Fixture) -> Self { + RsTestItem::Fixture(f) + } +} + +impl From for RsTestItem { + fn from(ident: Ident) -> Self { + RsTestItem::CaseArgName(ident) + } +} + +impl From for RsTestItem { + fn from(case: TestCase) -> Self { + RsTestItem::TestCase(case) + } +} + +impl From for RsTestItem { + fn from(value_list: ValueList) -> Self { + RsTestItem::ValueList(value_list) + } +} + +impl Parse for RsTestItem { + fn parse(input: ParseStream) -> syn::Result { + if input.fork().parse::().is_ok() { + input.parse::().map(RsTestItem::TestCase) + } else if 
input.peek2(Token![=>]) { + input.parse::().map(RsTestItem::ValueList) + } else if input.fork().parse::().is_ok() { + input.parse::().map(RsTestItem::Fixture) + } else if input.fork().parse::().is_ok() { + input.parse::().map(RsTestItem::CaseArgName) + } else { + Err(syn::Error::new(Span::call_site(), "Cannot parse it")) + } + } +} + +impl MaybeIdent for RsTestItem { + fn maybe_ident(&self) -> Option<&Ident> { + use RsTestItem::*; + match self { + Fixture(ref fixture) => Some(fixture.ident()), + CaseArgName(ref case_arg) => Some(case_arg), + ValueList(ref value_list) => Some(value_list.ident()), + TestCase(_) => None, + } + } +} + +impl ToTokens for RsTestItem { + fn to_tokens(&self, tokens: &mut TokenStream) { + use RsTestItem::*; + match self { + Fixture(ref fixture) => fixture.to_tokens(tokens), + CaseArgName(ref case_arg) => case_arg.to_tokens(tokens), + TestCase(ref case) => case.to_tokens(tokens), + ValueList(ref list) => list.to_tokens(tokens), + } + } +} + +wrap_attributes!(RsTestAttributes); + +impl RsTestAttributes { + const TRACE_VARIABLE_ATTR: &'static str = "trace"; + const NOTRACE_VARIABLE_ATTR: &'static str = "notrace"; + + pub(crate) fn trace_me(&self, ident: &Ident) -> bool { + if self.should_trace() { + !self.iter().any(|m| Self::is_notrace(ident, m)) + } else { + false + } + } + + fn is_notrace(ident: &Ident, m: &Attribute) -> bool { + match m { + Attribute::Tagged(i, args) if i == Self::NOTRACE_VARIABLE_ATTR => { + args.iter().any(|a| a == ident) + } + _ => false, + } + } + + pub(crate) fn should_trace(&self) -> bool { + self.iter().any(Self::is_trace) + } + + pub(crate) fn add_trace(&mut self, trace: Ident) { + self.inner.attributes.push(Attribute::Attr(trace)); + } + + pub(crate) fn add_notraces(&mut self, notraces: Vec) { + if notraces.is_empty() { + return; + } + self.inner.attributes.push(Attribute::Tagged( + format_ident!("{}", Self::NOTRACE_VARIABLE_ATTR), + notraces, + )); + } + + fn is_trace(m: &Attribute) -> bool { + matches!(m, 
Attribute::Attr(i) if i == Self::TRACE_VARIABLE_ATTR) + } +} + +impl Parse for RsTestAttributes { + fn parse(input: ParseStream) -> syn::Result { + Ok(input.parse::()?.into()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::test::{assert_eq, *}; + + mod parse_rstest_data { + use super::assert_eq; + use super::*; + + fn parse_rstest_data>(fixtures: S) -> RsTestData { + parse_meta(fixtures) + } + + #[test] + fn one_arg() { + let fixtures = parse_rstest_data("my_fixture(42)"); + + let expected = RsTestData { + items: vec![fixture("my_fixture", &["42"]).into()], + }; + + assert_eq!(expected, fixtures); + } + } + + #[test] + fn should_check_all_timeout_to_catch_the_right_errors() { + let mut item_fn = r#" + #[timeout()] + #[timeout(42)] + #[timeout] + #[timeout(Duration::from_millis(20))] + fn test_fn(#[case] arg: u32) { + } + "# + .ast(); + + let mut info = RsTestInfo::default(); + + let errors = info.extend_with_function_attrs(&mut item_fn).unwrap_err(); + + assert_eq!(2, errors.len()); + } + + #[cfg(feature = "async-timeout")] + #[test] + fn should_parse_async_timeout() { + let mut item_fn = r#" + #[timeout(Duration::from_millis(20))] + async fn test_fn(#[case] arg: u32) { + } + "# + .ast(); + + let mut info = RsTestInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + } + + #[cfg(not(feature = "async-timeout"))] + #[test] + fn should_return_error_for_async_timeout() { + let mut item_fn = r#" + #[timeout(Duration::from_millis(20))] + async fn test_fn(#[case] arg: u32) { + } + "# + .ast(); + + let mut info = RsTestInfo::default(); + + let errors = info.extend_with_function_attrs(&mut item_fn).unwrap_err(); + + assert_eq!(1, errors.len()); + assert!(format!("{:?}", errors).contains("async-timeout feature")) + } + + fn parse_rstest>(rstest_data: S) -> RsTestInfo { + parse_meta(rstest_data) + } + + mod no_cases { + use super::{assert_eq, *}; + use crate::parse::{Attribute, Attributes}; + + #[test] + fn happy_path() { + let data = 
parse_rstest( + r#"my_fixture(42, "other"), other(vec![42]) + :: trace :: no_trace(some)"#, + ); + + let expected = RsTestInfo { + data: vec![ + fixture("my_fixture", &["42", r#""other""#]).into(), + fixture("other", &["vec![42]"]).into(), + ] + .into(), + attributes: Attributes { + attributes: vec![ + Attribute::attr("trace"), + Attribute::tagged("no_trace", vec!["some"]), + ], + } + .into(), + ..Default::default() + }; + + assert_eq!(expected, data); + } + + mod fixture_extraction { + use super::{assert_eq, *}; + + #[test] + fn rename() { + let data = parse_rstest( + r#"long_fixture_name(42, "other") as short, simple as s, no_change()"#, + ); + + let expected = RsTestInfo { + data: vec![ + fixture("short", &["42", r#""other""#]) + .with_resolve("long_fixture_name") + .into(), + fixture("s", &[]).with_resolve("simple").into(), + fixture("no_change", &[]).into(), + ] + .into(), + ..Default::default() + }; + + assert_eq!(expected, data); + } + + #[test] + fn rename_with_attributes() { + let mut item_fn = r#" + fn test_fn( + #[from(long_fixture_name)] + #[with(42, "other")] short: u32, + #[from(simple)] + s: &str, + no_change: i32) { + } + "# + .ast(); + + let expected = RsTestInfo { + data: vec![ + fixture("short", &["42", r#""other""#]) + .with_resolve("long_fixture_name") + .into(), + fixture("s", &[]).with_resolve("simple").into(), + ] + .into(), + ..Default::default() + }; + + let mut data = RsTestInfo::default(); + + data.extend_with_function_attrs(&mut item_fn).unwrap(); + + assert_eq!(expected, data); + } + + #[test] + fn defined_via_with_attributes() { + let mut item_fn = r#" + fn test_fn(#[with(42, "other")] my_fixture: u32, #[with(vec![42])] other: &str) { + } + "# + .ast(); + + let expected = RsTestInfo { + data: vec![ + fixture("my_fixture", &["42", r#""other""#]).into(), + fixture("other", &["vec![42]"]).into(), + ] + .into(), + ..Default::default() + }; + + let mut data = RsTestInfo::default(); + + data.extend_with_function_attrs(&mut 
item_fn).unwrap(); + + assert_eq!(expected, data); + } + } + + #[test] + fn empty_fixtures() { + let data = parse_rstest(r#"::trace::no_trace(some)"#); + + let expected = RsTestInfo { + attributes: Attributes { + attributes: vec![ + Attribute::attr("trace"), + Attribute::tagged("no_trace", vec!["some"]), + ], + } + .into(), + ..Default::default() + }; + + assert_eq!(expected, data); + } + + #[test] + fn empty_attributes() { + let data = parse_rstest(r#"my_fixture(42, "other")"#); + + let expected = RsTestInfo { + data: vec![fixture("my_fixture", &["42", r#""other""#]).into()].into(), + ..Default::default() + }; + + assert_eq!(expected, data); + } + + #[test] + fn extract_notrace_args_atttribute() { + let mut item_fn = r#" + fn test_fn(#[notrace] a: u32, #[something_else] b: &str, #[notrace] c: i32) { + } + "# + .ast(); + + let mut info = RsTestInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + info.attributes.add_trace(ident("trace")); + + assert!(!info.attributes.trace_me(&ident("a"))); + assert!(info.attributes.trace_me(&ident("b"))); + assert!(!info.attributes.trace_me(&ident("c"))); + let b_args = item_fn + .sig + .inputs + .into_iter() + .nth(1) + .and_then(|id| match id { + syn::FnArg::Typed(arg) => Some(arg.attrs), + _ => None, + }) + .unwrap(); + assert_eq!(attrs("#[something_else]"), b_args); + } + + #[rstest] + fn extract_future() { + let mut item_fn = "fn f(#[future] a: u32, b: u32) {}".ast(); + let expected = "fn f(a: u32, b: u32) {}".ast(); + + let mut info = RsTestInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + assert_eq!(item_fn, expected); + assert!(info.arguments.is_future(&ident("a"))); + assert!(!info.arguments.is_future(&ident("b"))); + } + } + + mod parametrize_cases { + use super::{assert_eq, *}; + use std::iter::FromIterator; + + #[test] + fn one_simple_case_one_arg() { + let data = parse_rstest(r#"arg, case(42)"#).data; + + let args = data.case_args().collect::>(); + let cases = 
data.cases().collect::>(); + + assert_eq!(1, args.len()); + assert_eq!(1, cases.len()); + assert_eq!("arg", &args[0].to_string()); + assert_eq!(to_args!(["42"]), cases[0].args()) + } + + #[test] + fn happy_path() { + let info = parse_rstest( + r#" + my_fixture(42,"foo"), + arg1, arg2, arg3, + case(1,2,3), + case(11,12,13), + case(21,22,23) + "#, + ); + + let data = info.data; + let fixtures = data.fixtures().cloned().collect::>(); + + assert_eq!(vec![fixture("my_fixture", &["42", r#""foo""#])], fixtures); + assert_eq!( + to_strs!(vec!["arg1", "arg2", "arg3"]), + data.case_args() + .map(ToString::to_string) + .collect::>() + ); + + let cases = data.cases().collect::>(); + + assert_eq!(3, cases.len()); + assert_eq!(to_args!(["1", "2", "3"]), cases[0].args()); + assert_eq!(to_args!(["11", "12", "13"]), cases[1].args()); + assert_eq!(to_args!(["21", "22", "23"]), cases[2].args()); + } + + mod defined_via_with_attributes { + use super::{assert_eq, *}; + + #[test] + fn one_case() { + let mut item_fn = r#" + #[case::first(42, "first")] + fn test_fn(#[case] arg1: u32, #[case] arg2: &str) { + } + "# + .ast(); + + let mut info = RsTestInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + let case_args = info.data.case_args().cloned().collect::>(); + let cases = info.data.cases().cloned().collect::>(); + + assert_eq!(to_idents!(["arg1", "arg2"]), case_args); + assert_eq!( + vec![ + TestCase::from_iter(["42", r#""first""#].iter()).with_description("first"), + ], + cases + ); + } + + #[test] + fn parse_tuple_value() { + let mut item_fn = r#" + #[case(42, (24, "first"))] + fn test_fn(#[case] arg1: u32, #[case] tupled: (u32, &str)) { + } + "# + .ast(); + + let mut info = RsTestInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + let cases = info.data.cases().cloned().collect::>(); + + assert_eq!( + vec![TestCase::from_iter(["42", r#"(24, "first")"#].iter()),], + cases + ); + } + + #[test] + fn more_cases() { + let mut item_fn 
= r#" + #[case::first(42, "first")] + #[case(24, "second")] + #[case::third(0, "third")] + fn test_fn(#[case] arg1: u32, #[case] arg2: &str) { + } + "# + .ast(); + + let mut info = RsTestInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + let case_args = info.data.case_args().cloned().collect::>(); + let cases = info.data.cases().cloned().collect::>(); + + assert_eq!(to_idents!(["arg1", "arg2"]), case_args); + assert_eq!( + vec![ + TestCase::from_iter(["42", r#""first""#].iter()).with_description("first"), + TestCase::from_iter(["24", r#""second""#].iter()), + TestCase::from_iter(["0", r#""third""#].iter()).with_description("third"), + ], + cases + ); + } + + #[test] + fn should_collect_attributes() { + let mut item_fn = r#" + #[first] + #[first2(42)] + #[case(42)] + #[second] + #[case(24)] + #[global] + fn test_fn(#[case] arg: u32) { + } + "# + .ast(); + + let mut info = RsTestInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + let cases = info.data.cases().cloned().collect::>(); + + assert_eq!( + vec![ + TestCase::from_iter(["42"].iter()).with_attrs(attrs( + " + #[first] + #[first2(42)] + " + )), + TestCase::from_iter(["24"].iter()).with_attrs(attrs( + " + #[second] + " + )), + ], + cases + ); + } + + #[test] + fn should_consume_all_used_attributes() { + let mut item_fn = r#" + #[first] + #[first2(42)] + #[case(42)] + #[second] + #[case(24)] + #[global] + fn test_fn(#[case] arg: u32) { + } + "# + .ast(); + + let mut info = RsTestInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + assert_eq!( + item_fn.attrs, + attrs( + " + #[global] + " + ) + ); + assert!(!format!("{:?}", item_fn).contains("case")); + } + + #[test] + fn should_report_all_errors() { + let mut item_fn = r#" + #[case(#case_error#)] + fn test_fn(#[case] arg: u32, #[with(#fixture_error#)] err_fixture: u32) { + } + "# + .ast(); + + let mut info = RsTestInfo::default(); + + let errors = 
info.extend_with_function_attrs(&mut item_fn).unwrap_err(); + + assert_eq!(2, errors.len()); + } + } + + #[test] + fn should_accept_comma_at_the_end_of_cases() { + let data = parse_rstest( + r#" + arg, + case(42), + "#, + ) + .data; + + let args = data.case_args().collect::>(); + let cases = data.cases().collect::>(); + + assert_eq!(1, args.len()); + assert_eq!(1, cases.len()); + assert_eq!("arg", &args[0].to_string()); + assert_eq!(to_args!(["42"]), cases[0].args()) + } + + #[test] + #[should_panic] + fn should_not_accept_invalid_separator_from_args_and_cases() { + parse_rstest( + r#" + ret + case::should_success(Ok(())), + case::should_fail(Err("Return Error")) + "#, + ); + } + + #[test] + fn case_could_be_arg_name() { + let data = parse_rstest( + r#" + case, + case(42) + "#, + ) + .data; + + assert_eq!("case", &data.case_args().next().unwrap().to_string()); + + let cases = data.cases().collect::>(); + + assert_eq!(1, cases.len()); + assert_eq!(to_args!(["42"]), cases[0].args()); + } + } + + mod matrix_cases { + use crate::parse::Attribute; + + use super::{assert_eq, *}; + + #[test] + fn happy_path() { + let info = parse_rstest( + r#" + expected => [12, 34 * 2], + input => [format!("aa_{}", 2), "other"], + "#, + ); + + let value_ranges = info.data.list_values().collect::>(); + assert_eq!(2, value_ranges.len()); + assert_eq!(to_args!(["12", "34 * 2"]), value_ranges[0].args()); + assert_eq!( + to_args!([r#"format!("aa_{}", 2)"#, r#""other""#]), + value_ranges[1].args() + ); + assert_eq!(info.attributes, Default::default()); + } + + #[test] + fn should_parse_attributes_too() { + let info = parse_rstest( + r#" + a => [12, 24, 42] + ::trace + "#, + ); + + assert_eq!( + info.attributes, + Attributes { + attributes: vec![Attribute::attr("trace")] + } + .into() + ); + } + + #[test] + fn should_parse_injected_fixtures_too() { + let info = parse_rstest( + r#" + a => [12, 24, 42], + fixture_1(42, "foo"), + fixture_2("bar") + "#, + ); + + let fixtures = 
info.data.fixtures().cloned().collect::>(); + + assert_eq!( + vec![ + fixture("fixture_1", &["42", r#""foo""#]), + fixture("fixture_2", &[r#""bar""#]) + ], + fixtures + ); + } + + #[test] + #[should_panic(expected = "should not be empty")] + fn should_not_compile_if_empty_expression_slice() { + parse_rstest( + r#" + invalid => [] + "#, + ); + } + + mod defined_via_with_attributes { + use super::{assert_eq, *}; + + #[test] + fn one_arg() { + let mut item_fn = r#" + fn test_fn(#[values(1, 2, 1+2)] arg1: u32, #[values(format!("a"), "b b".to_owned(), String::new())] arg2: String) { + } + "# + .ast(); + + let mut info = RsTestInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + let list_values = info.data.list_values().cloned().collect::>(); + + assert_eq!(2, list_values.len()); + assert_eq!(to_args!(["1", "2", "1+2"]), list_values[0].args()); + assert_eq!( + to_args!([r#"format!("a")"#, r#""b b".to_owned()"#, "String::new()"]), + list_values[1].args() + ); + } + } + } + + mod integrated { + use super::{assert_eq, *}; + + #[test] + fn should_parse_fixture_cases_and_matrix_in_any_order() { + let data = parse_rstest( + r#" + u, + m => [1, 2], + case(42, A{}, D{}), + a, + case(43, A{}, D{}), + the_fixture(42), + mm => ["f", "oo", "BAR"], + d + "#, + ) + .data; + + let fixtures = data.fixtures().cloned().collect::>(); + assert_eq!(vec![fixture("the_fixture", &["42"])], fixtures); + + assert_eq!( + to_strs!(vec!["u", "a", "d"]), + data.case_args() + .map(ToString::to_string) + .collect::>() + ); + + let cases = data.cases().collect::>(); + assert_eq!(2, cases.len()); + assert_eq!(to_args!(["42", "A{}", "D{}"]), cases[0].args()); + assert_eq!(to_args!(["43", "A{}", "D{}"]), cases[1].args()); + + let value_ranges = data.list_values().collect::>(); + assert_eq!(2, value_ranges.len()); + assert_eq!(to_args!(["1", "2"]), value_ranges[0].args()); + assert_eq!( + to_args!([r#""f""#, r#""oo""#, r#""BAR""#]), + value_ranges[1].args() + ); + } + } +} diff 
--git a/src/parse/testcase.rs b/src/parse/testcase.rs new file mode 100644 index 0000000..20efd79 --- /dev/null +++ b/src/parse/testcase.rs @@ -0,0 +1,162 @@ +use syn::{ + parse::{Error, Parse, ParseStream, Result}, + punctuated::Punctuated, + Attribute, Expr, Ident, Token, +}; + +use proc_macro2::TokenStream; +use quote::ToTokens; + +#[derive(PartialEq, Debug, Clone)] +/// A test case instance data. Contains a list of arguments. It is parsed by parametrize +/// attributes. +pub(crate) struct TestCase { + pub(crate) args: Vec, + pub(crate) attrs: Vec, + pub(crate) description: Option, +} + +impl Parse for TestCase { + fn parse(input: ParseStream) -> Result { + let attrs = Attribute::parse_outer(input)?; + let case: Ident = input.parse()?; + if case == "case" { + let mut description = None; + if input.peek(Token![::]) { + let _ = input.parse::(); + description = Some(input.parse()?); + } + let content; + let _ = syn::parenthesized!(content in input); + let args = Punctuated::::parse_terminated(&content)? 
+ .into_iter() + .collect(); + Ok(TestCase { + args, + attrs, + description, + }) + } else { + Err(Error::new(case.span(), "expected a test case")) + } + } +} + +impl ToTokens for TestCase { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.args.iter().for_each(|c| c.to_tokens(tokens)) + } +} + +#[cfg(test)] +mod should { + use super::*; + use crate::test::{assert_eq, *}; + + fn parse_test_case>(test_case: S) -> TestCase { + parse_meta(test_case) + } + + #[test] + fn two_literal_args() { + let test_case = parse_test_case(r#"case(42, "value")"#); + let args = test_case.args(); + + let expected = to_args!(["42", r#""value""#]); + + assert_eq!(expected, args); + } + + #[test] + fn some_literals() { + let args_expressions = literal_expressions_str(); + let test_case = parse_test_case(&format!("case({})", args_expressions.join(", "))); + let args = test_case.args(); + + assert_eq!(to_args!(args_expressions), args); + } + + #[test] + fn accept_arbitrary_rust_code() { + let test_case = parse_test_case(r#"case(vec![1,2,3])"#); + let args = test_case.args(); + + assert_eq!(to_args!(["vec![1, 2, 3]"]), args); + } + + #[test] + #[should_panic] + fn raise_error_on_invalid_rust_code() { + parse_test_case(r#"case(some:<>(1,2,3))"#); + } + + #[test] + fn get_description_if_any() { + let test_case = parse_test_case(r#"case::this_test_description(42)"#); + let args = test_case.args(); + + assert_eq!( + "this_test_description", + &test_case.description.unwrap().to_string() + ); + assert_eq!(to_args!(["42"]), args); + } + + #[test] + fn get_description_also_with_more_args() { + let test_case = parse_test_case(r#"case :: this_test_description (42, 24)"#); + let args = test_case.args(); + + assert_eq!( + "this_test_description", + &test_case.description.unwrap().to_string() + ); + assert_eq!(to_args!(["42", "24"]), args); + } + + #[test] + fn parse_arbitrary_rust_code_as_expression() { + let test_case = parse_test_case( + r##" + case(42, -42, + pippo("pluto"), + Vec::new(), + 
String::from(r#"prrr"#), + { + let mut sum=0; + for i in 1..3 { + sum += i; + } + sum + }, + vec![1,2,3] + )"##, + ); + + let args = test_case.args(); + + assert_eq!( + to_args!([ + "42", + "-42", + r#"pippo("pluto")"#, + "Vec::new()", + r##"String::from(r#"prrr"#)"##, + r#"{let mut sum=0;for i in 1..3 {sum += i;}sum}"#, + "vec![1,2,3]" + ]), + args + ); + } + + #[test] + fn save_attributes() { + let test_case = parse_test_case(r#"#[should_panic]#[other_attr(x)]case(42)"#); + + let content = format!("{:?}", test_case.attrs); + + assert_eq!(2, test_case.attrs.len()); + assert!(content.contains("should_panic")); + assert!(content.contains("other_attr")); + } +} diff --git a/src/parse/vlist.rs b/src/parse/vlist.rs new file mode 100644 index 0000000..43a5aab --- /dev/null +++ b/src/parse/vlist.rs @@ -0,0 +1,105 @@ +use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::{ + parse::{Parse, ParseStream, Result}, + Expr, Ident, Token, +}; + +use crate::refident::RefIdent; + +use super::expressions::Expressions; + +#[derive(Debug, PartialEq, Clone)] +pub(crate) struct ValueList { + pub(crate) arg: Ident, + pub(crate) values: Vec, +} + +impl Parse for ValueList { + fn parse(input: ParseStream) -> Result { + let arg = input.parse()?; + let _to: Token![=>] = input.parse()?; + let content; + let paren = syn::bracketed!(content in input); + let values: Expressions = content.parse()?; + + let ret = Self { + arg, + values: values.take(), + }; + if ret.values.is_empty() { + Err(syn::Error::new( + paren.span, + "Values list should not be empty", + )) + } else { + Ok(ret) + } + } +} + +impl RefIdent for ValueList { + fn ident(&self) -> &Ident { + &self.arg + } +} + +impl ToTokens for ValueList { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.arg.to_tokens(tokens) + } +} + +#[cfg(test)] +mod should { + use crate::test::{assert_eq, *}; + + use super::*; + + mod parse_values_list { + use super::assert_eq; + use super::*; + + fn parse_values_list>(values_list: S) -> 
ValueList { + parse_meta(values_list) + } + + #[test] + fn some_literals() { + let literals = literal_expressions_str(); + let name = "argument"; + + let values_list = parse_values_list(format!( + r#"{} => [{}]"#, + name, + literals + .iter() + .map(ToString::to_string) + .collect::>() + .join(", ") + )); + + assert_eq!(name, &values_list.arg.to_string()); + assert_eq!(values_list.args(), to_args!(literals)); + } + + #[test] + fn raw_code() { + let values_list = parse_values_list(r#"no_mater => [vec![1,2,3]]"#); + + assert_eq!(values_list.args(), to_args!(["vec![1, 2, 3]"])); + } + + #[test] + #[should_panic] + fn raw_code_with_parsing_error() { + parse_values_list(r#"other => [some:<>(1,2,3)]"#); + } + + #[test] + #[should_panic(expected = r#"expected square brackets"#)] + fn forget_brackets() { + parse_values_list(r#"other => 42"#); + } + } +} diff --git a/src/refident.rs b/src/refident.rs new file mode 100644 index 0000000..14fbcd9 --- /dev/null +++ b/src/refident.rs @@ -0,0 +1,86 @@ +/// Provide `RefIdent` and `MaybeIdent` traits that give a shortcut to extract identity reference +/// (`syn::Ident` struct). +use proc_macro2::Ident; +use syn::{FnArg, Pat, PatType, Type}; + +pub trait RefIdent { + /// Return the reference to ident if any + fn ident(&self) -> &Ident; +} + +pub trait MaybeIdent { + /// Return the reference to ident if any + fn maybe_ident(&self) -> Option<&Ident>; +} + +impl MaybeIdent for I { + fn maybe_ident(&self) -> Option<&Ident> { + Some(self.ident()) + } +} + +impl RefIdent for Ident { + fn ident(&self) -> &Ident { + self + } +} + +impl<'a> RefIdent for &'a Ident { + fn ident(&self) -> &Ident { + self + } +} + +impl MaybeIdent for FnArg { + fn maybe_ident(&self) -> Option<&Ident> { + match self { + FnArg::Typed(PatType { pat, .. 
}) => match pat.as_ref() { + Pat::Ident(ident) => Some(&ident.ident), + _ => None, + }, + _ => None, + } + } +} + +impl MaybeIdent for Type { + fn maybe_ident(&self) -> Option<&Ident> { + match self { + Type::Path(tp) if tp.qself.is_none() => tp.path.get_ident(), + _ => None, + } + } +} + +pub trait MaybeType { + /// Return the reference to type if any + fn maybe_type(&self) -> Option<&Type>; +} + +impl MaybeType for FnArg { + fn maybe_type(&self) -> Option<&Type> { + match self { + FnArg::Typed(PatType { ty, .. }) => Some(ty.as_ref()), + _ => None, + } + } +} + +impl MaybeIdent for syn::GenericParam { + fn maybe_ident(&self) -> Option<&Ident> { + match self { + syn::GenericParam::Type(syn::TypeParam { ident, .. }) + | syn::GenericParam::Const(syn::ConstParam { ident, .. }) => Some(ident), + syn::GenericParam::Lifetime(syn::LifetimeDef { lifetime, .. }) => Some(&lifetime.ident), + } + } +} + +impl MaybeIdent for crate::parse::Attribute { + fn maybe_ident(&self) -> Option<&Ident> { + use crate::parse::Attribute::*; + match self { + Attr(ident) | Tagged(ident, _) | Type(ident, _) => Some(ident), + } + } +} diff --git a/src/render/apply_argumets.rs b/src/render/apply_argumets.rs new file mode 100644 index 0000000..b09c811 --- /dev/null +++ b/src/render/apply_argumets.rs @@ -0,0 +1,249 @@ +use quote::{format_ident, ToTokens}; +use syn::{parse_quote, FnArg, Generics, Ident, ItemFn, Lifetime, Signature, Type, TypeReference}; + +use crate::{ + parse::{arguments::ArgumentsInfo, future::MaybeFutureImplType}, + refident::MaybeIdent, +}; + +pub(crate) trait ApplyArgumets { + fn apply_argumets(&mut self, arguments: &ArgumentsInfo) -> R; +} + +impl ApplyArgumets> for FnArg { + fn apply_argumets(&mut self, arguments: &ArgumentsInfo) -> Option { + if self + .maybe_ident() + .map(|id| arguments.is_future(id)) + .unwrap_or_default() + { + self.impl_future_arg() + } else { + None + } + } +} + +fn move_generic_list(data: &mut Generics, other: Generics) { + data.lt_token = 
data.lt_token.or(other.lt_token); + data.params = other.params; + data.gt_token = data.gt_token.or(other.gt_token); +} + +fn extend_generics_with_lifetimes<'a, 'b>( + generics: impl Iterator, + lifetimes: impl Iterator, +) -> Generics { + let all = lifetimes + .map(|lt| lt as &dyn ToTokens) + .chain(generics.map(|gp| gp as &dyn ToTokens)); + parse_quote! { + <#(#all),*> + } +} + +impl ApplyArgumets for Signature { + fn apply_argumets(&mut self, arguments: &ArgumentsInfo) { + let new_lifetimes = self + .inputs + .iter_mut() + .filter_map(|arg| arg.apply_argumets(arguments)) + .collect::>(); + if !new_lifetimes.is_empty() || !self.generics.params.is_empty() { + let new_generics = + extend_generics_with_lifetimes(self.generics.params.iter(), new_lifetimes.iter()); + move_generic_list(&mut self.generics, new_generics); + } + } +} + +impl ApplyArgumets for ItemFn { + fn apply_argumets(&mut self, arguments: &ArgumentsInfo) { + let awaited_args = self + .sig + .inputs + .iter() + .filter_map(|a| a.maybe_ident()) + .filter(|&a| arguments.is_future_await(a)) + .cloned(); + let orig_block_impl = self.block.clone(); + self.block = parse_quote! { + { + #(let #awaited_args = #awaited_args.await;)* + #orig_block_impl + } + }; + self.sig.apply_argumets(arguments); + } +} + +pub(crate) trait ImplFutureArg { + fn impl_future_arg(&mut self) -> Option; +} + +impl ImplFutureArg for FnArg { + fn impl_future_arg(&mut self) -> Option { + let lifetime_id = self.maybe_ident().map(|id| format_ident!("_{}", id)); + match self.as_mut_future_impl_type() { + Some(ty) => { + let lifetime = lifetime_id.and_then(|id| update_type_with_lifetime(ty, id)); + *ty = parse_quote! { + impl std::future::Future + }; + lifetime + } + None => None, + } + } +} + +fn update_type_with_lifetime(ty: &mut Type, ident: Ident) -> Option { + if let Type::Reference(ty_ref @ TypeReference { lifetime: None, .. 
}) = ty { + let lifetime = Some(syn::Lifetime { + apostrophe: ident.span(), + ident, + }); + ty_ref.lifetime = lifetime.clone(); + lifetime + } else { + None + } +} + +#[cfg(test)] +mod should { + use super::*; + use crate::test::{assert_eq, *}; + use syn::ItemFn; + + #[rstest] + #[case("fn simple(a: u32) {}")] + #[case("fn more(a: u32, b: &str) {}")] + #[case("fn gen>(a: u32, b: S) {}")] + #[case("fn attr(#[case] a: u32, #[values(1,2)] b: i32) {}")] + fn no_change(#[case] item_fn: &str) { + let mut item_fn: ItemFn = item_fn.ast(); + let orig = item_fn.clone(); + + item_fn.sig.apply_argumets(&ArgumentsInfo::default()); + + assert_eq!(orig, item_fn) + } + + #[rstest] + #[case::simple( + "fn f(a: u32) {}", + &["a"], + "fn f(a: impl std::future::Future) {}" + )] + #[case::more_than_one( + "fn f(a: u32, b: String, c: std::collection::HashMap) {}", + &["a", "b", "c"], + r#"fn f(a: impl std::future::Future, + b: impl std::future::Future, + c: impl std::future::Future>) {}"#, + )] + #[case::just_one( + "fn f(a: u32, b: String) {}", + &["b"], + r#"fn f(a: u32, + b: impl std::future::Future) {}"# + )] + #[case::generics( + "fn f>(a: S) {}", + &["a"], + "fn f>(a: impl std::future::Future) {}" + )] + fn replace_future_basic_type( + #[case] item_fn: &str, + #[case] futures: &[&str], + #[case] expected: &str, + ) { + let mut item_fn: ItemFn = item_fn.ast(); + let expected: ItemFn = expected.ast(); + + let mut arguments = ArgumentsInfo::default(); + futures + .into_iter() + .for_each(|&f| arguments.add_future(ident(f))); + + item_fn.sig.apply_argumets(&arguments); + + assert_eq!(expected, item_fn) + } + + #[rstest] + #[case::base( + "fn f(ident_name: &u32) {}", + &["ident_name"], + "fn f<'_ident_name>(ident_name: impl std::future::Future) {}" + )] + #[case::lifetime_already_exists( + "fn f<'b>(a: &'b u32) {}", + &["a"], + "fn f<'b>(a: impl std::future::Future) {}" + )] + #[case::some_other_generics( + "fn f<'b, IT: Iterator>(a: &u32, it: IT) {}", + &["a"], + "fn f<'_a, 'b, IT: 
Iterator>(a: impl std::future::Future, it: IT) {}" + )] + fn replace_reference_type( + #[case] item_fn: &str, + #[case] futures: &[&str], + #[case] expected: &str, + ) { + let mut item_fn: ItemFn = item_fn.ast(); + let expected: ItemFn = expected.ast(); + + let mut arguments = ArgumentsInfo::default(); + futures + .into_iter() + .for_each(|&f| arguments.add_future(ident(f))); + + item_fn.sig.apply_argumets(&arguments); + + assert_eq!(expected, item_fn) + } + + mod await_future_args { + use rstest_test::{assert_in, assert_not_in}; + + use crate::parse::arguments::FutureArg; + + use super::*; + + #[test] + fn with_global_await() { + let mut item_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {} "#.ast(); + let mut arguments: ArgumentsInfo = Default::default(); + arguments.set_global_await(true); + arguments.add_future(ident("a")); + arguments.add_future(ident("b")); + + item_fn.apply_argumets(&arguments); + + let code = item_fn.block.display_code(); + + assert_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } + + #[test] + fn with_selective_await() { + let mut item_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {} "#.ast(); + let mut arguments: ArgumentsInfo = Default::default(); + arguments.set_future(ident("a"), FutureArg::Define); + arguments.set_future(ident("b"), FutureArg::Await); + + item_fn.apply_argumets(&arguments); + + let code = item_fn.block.display_code(); + + assert_not_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } + } +} diff --git a/src/render/fixture.rs b/src/render/fixture.rs new file mode 100644 index 0000000..cdcaeec --- /dev/null +++ b/src/render/fixture.rs @@ -0,0 +1,575 @@ +use proc_macro2::{Span, TokenStream}; +use syn::{parse_quote, Ident, ItemFn, ReturnType}; + +use quote::quote; + +use super::apply_argumets::ApplyArgumets; 
+use super::{inject, render_exec_call}; +use crate::resolver::{self, Resolver}; +use crate::utils::{fn_args, fn_args_idents}; +use crate::{parse::fixture::FixtureInfo, utils::generics_clean_up}; + +fn wrap_return_type_as_static_ref(rt: ReturnType) -> ReturnType { + match rt { + syn::ReturnType::Type(_, t) => parse_quote! { + -> &'static #t + }, + o => o, + } +} + +fn wrap_call_impl_with_call_once_impl(call_impl: TokenStream, rt: &ReturnType) -> TokenStream { + match rt { + syn::ReturnType::Type(_, t) => parse_quote! { + static mut S: Option<#t> = None; + static CELL: std::sync::Once = std::sync::Once::new(); + CELL.call_once(|| unsafe { S = Some(#call_impl) }); + unsafe { S.as_ref().unwrap() } + }, + _ => parse_quote! { + static CELL: std::sync::Once = std::sync::Once::new(); + CELL.call_once(|| #call_impl ); + }, + } +} + +pub(crate) fn render(mut fixture: ItemFn, info: FixtureInfo) -> TokenStream { + fixture.apply_argumets(&info.arguments); + let name = &fixture.sig.ident; + let asyncness = &fixture.sig.asyncness.clone(); + let vargs = fn_args_idents(&fixture).cloned().collect::>(); + let args = &vargs; + let orig_args = &fixture.sig.inputs; + let orig_attrs = &fixture.attrs; + let generics = &fixture.sig.generics; + let mut default_output = info + .attributes + .extract_default_type() + .unwrap_or_else(|| fixture.sig.output.clone()); + let default_generics = + generics_clean_up(&fixture.sig.generics, std::iter::empty(), &default_output); + let default_where_clause = &default_generics.where_clause; + let where_clause = &fixture.sig.generics.where_clause; + let mut output = fixture.sig.output.clone(); + let visibility = &fixture.vis; + let resolver = ( + resolver::fixtures::get(info.data.fixtures()), + resolver::values::get(info.data.values()), + ); + let generics_idents = generics + .type_params() + .map(|tp| &tp.ident) + .cloned() + .collect::>(); + let inject = inject::resolve_aruments(fixture.sig.inputs.iter(), &resolver, &generics_idents); + + let partials = 
+ (1..=orig_args.len()).map(|n| render_partial_impl(&fixture, n, &resolver, &info)); + + let call_get = render_exec_call(parse_quote! { Self::get }, args, asyncness.is_some()); + let mut call_impl = render_exec_call(parse_quote! { #name }, args, asyncness.is_some()); + + if info.attributes.is_once() { + call_impl = wrap_call_impl_with_call_once_impl(call_impl, &output); + output = wrap_return_type_as_static_ref(output); + default_output = wrap_return_type_as_static_ref(default_output); + } + + quote! { + #[allow(non_camel_case_types)] + #visibility struct #name {} + + impl #name { + #(#orig_attrs)* + #[allow(unused_mut)] + pub #asyncness fn get #generics (#orig_args) #output #where_clause { + #call_impl + } + + pub #asyncness fn default #default_generics () #default_output #default_where_clause { + #inject + #call_get + } + + #(#partials)* + } + + #[allow(dead_code)] + #fixture + } +} + +fn render_partial_impl( + fixture: &ItemFn, + n: usize, + resolver: &impl Resolver, + info: &FixtureInfo, +) -> TokenStream { + let mut output = info + .attributes + .extract_partial_type(n) + .unwrap_or_else(|| fixture.sig.output.clone()); + + if info.attributes.is_once() { + output = wrap_return_type_as_static_ref(output); + } + + let generics = generics_clean_up(&fixture.sig.generics, fn_args(fixture).take(n), &output); + let where_clause = &generics.where_clause; + let asyncness = &fixture.sig.asyncness; + + let genercs_idents = generics + .type_params() + .map(|tp| &tp.ident) + .cloned() + .collect::>(); + let inject = + inject::resolve_aruments(fixture.sig.inputs.iter().skip(n), resolver, &genercs_idents); + + let sign_args = fn_args(fixture).take(n); + let fixture_args = fn_args_idents(fixture).cloned().collect::>(); + let name = Ident::new(&format!("partial_{n}"), Span::call_site()); + + let call_get = render_exec_call( + parse_quote! { Self::get }, + &fixture_args, + asyncness.is_some(), + ); + + quote! 
{ + #[allow(unused_mut)] + pub #asyncness fn #name #generics (#(#sign_args),*) #output #where_clause { + #inject + #call_get + } + } +} + +#[cfg(test)] +mod should { + use rstest_test::{assert_in, assert_not_in}; + use syn::{ + parse::{Parse, ParseStream}, + parse2, parse_str, ItemFn, ItemImpl, ItemStruct, Result, + }; + + use crate::parse::{ + arguments::{ArgumentsInfo, FutureArg}, + Attribute, Attributes, + }; + + use super::*; + use crate::test::{assert_eq, *}; + use rstest_reuse::*; + + #[derive(Clone)] + struct FixtureOutput { + orig: ItemFn, + fixture: ItemStruct, + core_impl: ItemImpl, + } + + impl Parse for FixtureOutput { + fn parse(input: ParseStream) -> Result { + Ok(FixtureOutput { + fixture: input.parse()?, + core_impl: input.parse()?, + orig: input.parse()?, + }) + } + } + + fn parse_fixture>(code: S) -> (ItemFn, FixtureOutput) { + let item_fn = parse_str::(code.as_ref()).unwrap(); + + let tokens = render(item_fn.clone(), Default::default()); + (item_fn, parse2(tokens).unwrap()) + } + + fn test_maintains_function_visibility(code: &str) { + let (item_fn, out) = parse_fixture(code); + + assert_eq!(item_fn.vis, out.fixture.vis); + assert_eq!(item_fn.vis, out.orig.vis); + } + + fn select_method>(impl_code: ItemImpl, name: S) -> Option { + impl_code + .items + .into_iter() + .filter_map(|ii| match ii { + syn::ImplItem::Method(f) => Some(f), + _ => None, + }) + .find(|f| f.sig.ident == name.as_ref()) + } + + #[test] + fn maintains_pub_visibility() { + test_maintains_function_visibility(r#"pub fn test() { }"#); + } + + #[test] + fn maintains_no_pub_visibility() { + test_maintains_function_visibility(r#"fn test() { }"#); + } + + #[test] + fn implement_a_get_method_with_input_fixture_signature() { + let (item_fn, out) = parse_fixture( + r#" + pub fn test, B>(mut s: String, v: &u32, a: &mut [i32], r: R) -> (u32, B, String, &str) + where B: Borrow + { } + "#, + ); + + let mut signature = select_method(out.core_impl, "get").unwrap().sig; + + signature.ident = 
item_fn.sig.ident.clone(); + + assert_eq!(item_fn.sig, signature); + } + + #[test] + fn return_a_static_reference_if_once_attribute() { + let item_fn = parse_str::(r#" + pub fn test, B>(mut s: String, v: &u32, a: &mut [i32], r: R) -> (u32, B, String, &str) + where B: Borrow + { } + "#).unwrap(); + let info = FixtureInfo::default().with_once(); + + let out: FixtureOutput = parse2(render(item_fn.clone(), info)).unwrap(); + + let signature = select_method(out.core_impl, "get").unwrap().sig; + + assert_eq!(signature.output, "-> &'static (u32, B, String, &str)".ast()) + } + + #[template] + #[rstest( + method => ["default", "get", "partial_1", "partial_2", "partial_3"]) + ] + #[case::async_fn(true)] + #[case::not_async_fn(false)] + fn async_fixture_cases(#[case] is_async: bool, method: &str) {} + + #[apply(async_fixture_cases)] + fn fixture_method_should_be_async_if_fixture_function_is_async( + #[case] is_async: bool, + method: &str, + ) { + let prefix = if is_async { "async" } else { "" }; + let (_, out) = parse_fixture(&format!( + r#" + pub {} fn test(mut s: String, v: &u32, a: &mut [i32]) -> u32 + where B: Borrow + {{ }} + "#, + prefix + )); + + let signature = select_method(out.core_impl, method).unwrap().sig; + + assert_eq!(is_async, signature.asyncness.is_some()); + } + + #[apply(async_fixture_cases)] + fn fixture_method_should_use_await_if_fixture_function_is_async( + #[case] is_async: bool, + method: &str, + ) { + let prefix = if is_async { "async" } else { "" }; + let (_, out) = parse_fixture(&format!( + r#" + pub {} fn test(mut s: String, v: &u32, a: &mut [i32]) -> u32 + {{ }} + "#, + prefix + )); + + let body = select_method(out.core_impl, method).unwrap().block; + let last_statment = body.stmts.last().unwrap(); + let is_await = match last_statment { + syn::Stmt::Expr(syn::Expr::Await(_)) => true, + _ => false, + }; + + assert_eq!(is_async, is_await); + } + + #[test] + fn implement_a_default_method_with_input_cleaned_fixture_signature_and_no_args() { + let 
(item_fn, out) = parse_fixture( + r#" + pub fn test, B, F, H: Iterator>(mut s: String, v: &u32, a: &mut [i32], r: R) -> (H, B, String, &str) + where F: ToString, + B: Borrow + + { } + "#, + ); + + let default_decl = select_method(out.core_impl, "default").unwrap().sig; + + let expected = parse_str::( + r#" + pub fn default>() -> (H, B, String, &str) + where B: Borrow + { } + "#, + ) + .unwrap(); + + assert_eq!(expected.sig.generics, default_decl.generics); + assert_eq!(item_fn.sig.output, default_decl.output); + assert!(default_decl.inputs.is_empty()); + } + + #[test] + fn use_default_return_type_if_any() { + let item_fn = parse_str::( + r#" + pub fn test, B, F, H: Iterator>() -> (H, B) + where F: ToString, + B: Borrow + { } + "#, + ) + .unwrap(); + + let tokens = render( + item_fn.clone(), + FixtureInfo { + attributes: Attributes { + attributes: vec![Attribute::Type( + parse_str("default").unwrap(), + parse_str("(impl Iterator, B)").unwrap(), + )], + } + .into(), + ..Default::default() + }, + ); + let out: FixtureOutput = parse2(tokens).unwrap(); + + let expected = parse_str::( + r#" + pub fn default() -> (impl Iterator, B) + where B: Borrow + { } + "#, + ) + .unwrap(); + + let default_decl = select_method(out.core_impl, "default").unwrap().sig; + + assert_eq!(expected.sig, default_decl); + } + + #[test] + fn implement_partial_methods() { + let (item_fn, out) = parse_fixture( + r#" + pub fn test(mut s: String, v: &u32, a: &mut [i32]) -> usize + { } + "#, + ); + + let partials = (1..=3) + .map(|n| { + select_method(out.core_impl.clone(), format!("partial_{}", n)) + .unwrap() + .sig + }) + .collect::>(); + + // All 3 methods found + + assert!(select_method(out.core_impl, "partial_4").is_none()); + + let expected_1 = parse_str::( + r#" + pub fn partial_1(mut s: String) -> usize + { } + "#, + ) + .unwrap(); + + assert_eq!(expected_1.sig, partials[0]); + for p in partials { + assert_eq!(item_fn.sig.output, p.output); + } + } + + #[rstest] + #[case::base("fn test, U: 
AsRef, F: ToString>(mut s: S, v: U) -> F {}", + vec![ + "fn default() -> F {}", + "fn partial_1, F: ToString>(mut s: S) -> F {}", + "fn partial_2, U: AsRef, F: ToString>(mut s: S, v: U) -> F {}", + ] + )] + #[case::associated_type("fn test(mut i: T) where T::Item: Copy {}", + vec![ + "fn default() {}", + "fn partial_1(mut i: T) where T::Item: Copy {}", + ] + )] + #[case::not_remove_const_generics("fn test(v: [u32; N]) -> [i32; N] {}", + vec![ + "fn default() -> [i32; N] {}", + "fn partial_1(v: [u32; N]) -> [i32; N] {}", + ] + )] + #[case::remove_const_generics("fn test(a: i32, v: [u32; N]) {}", + vec![ + "fn default() {}", + "fn partial_1(a:i32) {}", + "fn partial_2(a:i32, v: [u32; N]) {}", + ] + )] + + fn clean_generics(#[case] code: &str, #[case] expected: Vec<&str>) { + let (item_fn, out) = parse_fixture(code); + let n_args = item_fn.sig.inputs.iter().count(); + + let mut signatures = vec![select_method(out.core_impl.clone(), "default").unwrap().sig]; + signatures.extend((1..=n_args).map(|n| { + select_method(out.core_impl.clone(), format!("partial_{}", n)) + .unwrap() + .sig + })); + + let expected = expected + .into_iter() + .map(parse_str::) + .map(|f| f.unwrap().sig) + .collect::>(); + + assert_eq!(expected, signatures); + } + + #[test] + fn use_partial_return_type_if_any() { + let item_fn = parse_str::( + r#" + pub fn test, B, F, H: Iterator>(h: H, b: B) -> (H, B) + where F: ToString, + B: Borrow + { } + "#, + ) + .unwrap(); + + let tokens = render( + item_fn.clone(), + FixtureInfo { + attributes: Attributes { + attributes: vec![Attribute::Type( + parse_str("partial_1").unwrap(), + parse_str("(H, impl Iterator)").unwrap(), + )], + } + .into(), + ..Default::default() + }, + ); + let out: FixtureOutput = parse2(tokens).unwrap(); + + let expected = parse_str::( + r#" + pub fn partial_1>(h: H) -> (H, impl Iterator) + { } + "#, + ) + .unwrap(); + + let partial = select_method(out.core_impl, "partial_1").unwrap(); + + assert_eq!(expected.sig, partial.sig); + } + 
+ #[test] + fn add_future_boilerplate_if_requested() { + let item_fn: ItemFn = + r#"async fn test(async_ref_u32: &u32, async_u32: u32,simple: u32) { }"#.ast(); + + let mut arguments = ArgumentsInfo::default(); + arguments.add_future(ident("async_ref_u32")); + arguments.add_future(ident("async_u32")); + + let tokens = render( + item_fn.clone(), + FixtureInfo { + arguments, + ..Default::default() + }, + ); + let out: FixtureOutput = parse2(tokens).unwrap(); + + let expected = parse_str::( + r#" + async fn get<'_async_ref_u32>( + async_ref_u32: impl std::future::Future, + async_u32: impl std::future::Future, + simple: u32 + ) + { } + "#, + ) + .unwrap(); + + let rendered = select_method(out.core_impl, "get").unwrap(); + + assert_eq!(expected.sig, rendered.sig); + } + + #[test] + fn use_global_await() { + let item_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {} "#.ast(); + let mut arguments: ArgumentsInfo = Default::default(); + arguments.set_global_await(true); + arguments.add_future(ident("a")); + arguments.add_future(ident("b")); + + let tokens = render( + item_fn.clone(), + FixtureInfo { + arguments, + ..Default::default() + }, + ); + let out: FixtureOutput = parse2(tokens).unwrap(); + + let code = out.orig.display_code(); + + assert_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } + + #[test] + fn use_selective_await() { + let item_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {} "#.ast(); + let mut arguments: ArgumentsInfo = Default::default(); + arguments.set_future(ident("a"), FutureArg::Define); + arguments.set_future(ident("b"), FutureArg::Await); + + let tokens = render( + item_fn.clone(), + FixtureInfo { + arguments, + ..Default::default() + }, + ); + let out: FixtureOutput = parse2(tokens).unwrap(); + + let code = out.orig.display_code(); + + assert_not_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); 
+ assert_not_in!(code, await_argument_code_string("c")); + } +} diff --git a/src/render/inject.rs b/src/render/inject.rs new file mode 100644 index 0000000..da72a0c --- /dev/null +++ b/src/render/inject.rs @@ -0,0 +1,205 @@ +use std::borrow::Cow; + +use proc_macro2::TokenStream; +use quote::quote; +use syn::{parse_quote, Expr, FnArg, Ident, Stmt, Type}; + +use crate::{ + refident::{MaybeIdent, MaybeType}, + resolver::Resolver, + utils::{fn_arg_mutability, IsLiteralExpression}, +}; + +pub(crate) fn resolve_aruments<'a>( + args: impl Iterator, + resolver: &impl Resolver, + generic_types: &[Ident], +) -> TokenStream { + let define_vars = args.map(|arg| ArgumentResolver::new(resolver, generic_types).resolve(arg)); + quote! { + #(#define_vars)* + } +} + +struct ArgumentResolver<'resolver, 'idents, 'f, R> +where + R: Resolver + 'resolver, +{ + resolver: &'resolver R, + generic_types_names: &'idents [Ident], + magic_conversion: &'f dyn Fn(Cow, &Type) -> Expr, +} + +impl<'resolver, 'idents, 'f, R> ArgumentResolver<'resolver, 'idents, 'f, R> +where + R: Resolver + 'resolver, +{ + fn new(resolver: &'resolver R, generic_types_names: &'idents [Ident]) -> Self { + Self { + resolver, + generic_types_names, + magic_conversion: &handling_magic_conversion_code, + } + } + + fn resolve(&self, arg: &FnArg) -> Option { + let ident = arg.maybe_ident()?; + let mutability = fn_arg_mutability(arg); + let unused_mut: Option = mutability + .as_ref() + .map(|_| parse_quote! {#[allow(unused_mut)]}); + let arg_type = arg.maybe_type()?; + let fixture_name = self.fixture_name(ident); + + let mut fixture = self + .resolver + .resolve(ident) + .or_else(|| self.resolver.resolve(&fixture_name)) + .unwrap_or_else(|| default_fixture_resolve(&fixture_name)); + + if fixture.is_literal() && self.type_can_be_get_from_literal_str(arg_type) { + fixture = Cow::Owned((self.magic_conversion)(fixture, arg_type)); + } + Some(parse_quote! 
{ + #unused_mut + let #mutability #ident = #fixture; + }) + } + + fn fixture_name<'a>(&self, ident: &'a Ident) -> Cow<'a, Ident> { + let id_str = ident.to_string(); + if id_str.starts_with('_') && !id_str.starts_with("__") { + Cow::Owned(Ident::new(&id_str[1..], ident.span())) + } else { + Cow::Borrowed(ident) + } + } + + fn type_can_be_get_from_literal_str(&self, t: &Type) -> bool { + // Check valid type to apply magic conversion + match t { + Type::ImplTrait(_) + | Type::TraitObject(_) + | Type::Infer(_) + | Type::Group(_) + | Type::Macro(_) + | Type::Never(_) + | Type::Paren(_) + | Type::Verbatim(_) + | Type::Slice(_) => return false, + _ => {} + } + match t.maybe_ident() { + Some(id) => !self.generic_types_names.contains(id), + None => true, + } + } +} + +fn default_fixture_resolve(ident: &Ident) -> Cow { + Cow::Owned(parse_quote! { #ident::default() }) +} + +fn handling_magic_conversion_code(fixture: Cow, arg_type: &Type) -> Expr { + parse_quote! { + { + use rstest::magic_conversion::*; + (&&&Magic::<#arg_type>(std::marker::PhantomData)).magic_conversion(#fixture) + } + } +} + +#[cfg(test)] +mod should { + use super::*; + use crate::{ + test::{assert_eq, *}, + utils::fn_args, + }; + + #[rstest] + #[case::as_is("fix: String", "let fix = fix::default();")] + #[case::without_underscore("_fix: String", "let _fix = fix::default();")] + #[case::do_not_remove_inner_underscores("f_i_x: String", "let f_i_x = f_i_x::default();")] + #[case::do_not_remove_double_underscore("__fix: String", "let __fix = __fix::default();")] + #[case::preserve_mut_but_annotate_as_allow_unused_mut( + "mut fix: String", + "#[allow(unused_mut)] let mut fix = fix::default();" + )] + fn call_fixture(#[case] arg_str: &str, #[case] expected: &str) { + let arg = arg_str.ast(); + + let injected = ArgumentResolver::new(&EmptyResolver {}, &[]) + .resolve(&arg) + .unwrap(); + + assert_eq!(injected, expected.ast()); + } + + #[rstest] + #[case::as_is("fix: String", ("fix", expr("bar()")), "let fix = 
bar();")] + #[case::with_allow_unused_mut("mut fix: String", ("fix", expr("bar()")), "#[allow(unused_mut)] let mut fix = bar();")] + #[case::without_undescore("_fix: String", ("fix", expr("bar()")), "let _fix = bar();")] + #[case::without_remove_underscore_if_value("_orig: S", ("_orig", expr("S{}")), r#"let _orig = S{};"#)] + fn call_given_fixture( + #[case] arg_str: &str, + #[case] rule: (&str, Expr), + #[case] expected: &str, + ) { + let arg = arg_str.ast(); + let mut resolver = std::collections::HashMap::new(); + resolver.insert(rule.0.to_owned(), &rule.1); + + let injected = ArgumentResolver::new(&resolver, &[]).resolve(&arg).unwrap(); + + assert_eq!(injected, expected.ast()); + } + + fn _mock_conversion_code(fixture: Cow, arg_type: &Type) -> Expr { + parse_quote! { + #fixture as #arg_type + } + } + + #[rstest] + #[case::implement_it( + "fn test(arg: MyType){}", + 0, + r#"let arg = "value to convert" as MyType;"# + )] + #[case::discard_impl( + "fn test(arg: impl AsRef){}", + 0, + r#"let arg = "value to convert";"# + )] + #[case::discard_generic_type( + "fn test>(arg: S){}", + 0, + r#"let arg = "value to convert";"# + )] + fn handle_magic_conversion(#[case] fn_str: &str, #[case] n_arg: usize, #[case] expected: &str) { + let function = fn_str.ast(); + let arg = fn_args(&function).nth(n_arg).unwrap(); + let generics = function + .sig + .generics + .type_params() + .map(|tp| &tp.ident) + .cloned() + .collect::>(); + + let mut resolver = std::collections::HashMap::new(); + let expr = expr(r#""value to convert""#); + resolver.insert(arg.maybe_ident().unwrap().to_string(), &expr); + + let ag = ArgumentResolver { + resolver: &resolver, + generic_types_names: &generics, + magic_conversion: &_mock_conversion_code, + }; + + let injected = ag.resolve(&arg).unwrap(); + + assert_eq!(injected, expected.ast()); + } +} diff --git a/src/render/mod.rs b/src/render/mod.rs new file mode 100644 index 0000000..404efc8 --- /dev/null +++ b/src/render/mod.rs @@ -0,0 +1,434 @@ 
+pub(crate) mod fixture; +mod test; +mod wrapper; + +use std::collections::HashMap; + +use syn::token::Async; + +use proc_macro2::{Span, TokenStream}; +use syn::{parse_quote, Attribute, Expr, FnArg, Ident, ItemFn, Path, ReturnType, Stmt}; + +use quote::{format_ident, quote, ToTokens}; +use unicode_ident::is_xid_continue; + +use crate::utils::attr_ends_with; +use crate::{ + parse::{ + rstest::{RsTestAttributes, RsTestData, RsTestInfo}, + testcase::TestCase, + vlist::ValueList, + }, + utils::attr_is, +}; +use crate::{ + refident::MaybeIdent, + resolver::{self, Resolver}, +}; +use wrapper::WrapByModule; + +pub(crate) use fixture::render as fixture; + +use self::apply_argumets::ApplyArgumets; +pub(crate) mod apply_argumets; +pub(crate) mod inject; + +pub(crate) fn single(mut test: ItemFn, info: RsTestInfo) -> TokenStream { + test.apply_argumets(&info.arguments); + let resolver = resolver::fixtures::get(info.data.fixtures()); + let args = test.sig.inputs.iter().cloned().collect::>(); + let attrs = std::mem::take(&mut test.attrs); + let asyncness = test.sig.asyncness; + let generic_types = test + .sig + .generics + .type_params() + .map(|tp| &tp.ident) + .cloned() + .collect::>(); + + single_test_case( + &test.sig.ident, + &test.sig.ident, + &args, + &attrs, + &test.sig.output, + asyncness, + Some(&test), + resolver, + &info.attributes, + &generic_types, + ) +} + +pub(crate) fn parametrize(mut test: ItemFn, info: RsTestInfo) -> TokenStream { + let RsTestInfo { + data, + attributes, + arguments, + } = info; + test.apply_argumets(&arguments); + let resolver_fixtures = resolver::fixtures::get(data.fixtures()); + + let rendered_cases = cases_data(&data, test.sig.ident.span()) + .map(|(name, attrs, resolver)| { + TestCaseRender::new(name, attrs, (resolver, &resolver_fixtures)) + }) + .map(|case| case.render(&test, &attributes)) + .collect(); + + test_group(test, rendered_cases) +} + +impl ValueList { + fn render( + &self, + test: &ItemFn, + resolver: &dyn Resolver, + attrs: 
&[syn::Attribute], + attributes: &RsTestAttributes, + ) -> TokenStream { + let span = test.sig.ident.span(); + let test_cases = self + .argument_data(resolver) + .map(|(name, r)| TestCaseRender::new(Ident::new(&name, span), attrs, r)) + .map(|test_case| test_case.render(test, attributes)); + + quote! { #(#test_cases)* } + } + + fn argument_data<'a>( + &'a self, + resolver: &'a dyn Resolver, + ) -> impl Iterator)> + 'a { + let max_len = self.values.len(); + self.values.iter().enumerate().map(move |(index, expr)| { + let sanitized_expr = sanitize_ident(expr); + let name = format!( + "{}_{:0len$}_{sanitized_expr:.64}", + self.arg, + index + 1, + len = max_len.display_len() + ); + let resolver_this = (self.arg.to_string(), expr.clone()); + (name, Box::new((resolver, resolver_this))) + }) + } +} + +fn _matrix_recursive<'a>( + test: &ItemFn, + list_values: &'a [&'a ValueList], + resolver: &dyn Resolver, + attrs: &'a [syn::Attribute], + attributes: &RsTestAttributes, +) -> TokenStream { + if list_values.is_empty() { + return Default::default(); + } + let vlist = list_values[0]; + let list_values = &list_values[1..]; + + if list_values.is_empty() { + let mut attrs = attrs.to_vec(); + attrs.push(parse_quote!( + #[allow(non_snake_case)] + )); + vlist.render(test, resolver, &attrs, attributes) + } else { + let span = test.sig.ident.span(); + let modules = vlist.argument_data(resolver).map(move |(name, resolver)| { + _matrix_recursive(test, list_values, &resolver, attrs, attributes) + .wrap_by_mod(&Ident::new(&name, span)) + }); + + quote! 
{ #( + #[allow(non_snake_case)] + #modules + )* } + } +} + +pub(crate) fn matrix(mut test: ItemFn, info: RsTestInfo) -> TokenStream { + let RsTestInfo { + data, + attributes, + arguments, + } = info; + test.apply_argumets(&arguments); + let span = test.sig.ident.span(); + + let cases = cases_data(&data, span).collect::>(); + + let resolver = resolver::fixtures::get(data.fixtures()); + let rendered_cases = if cases.is_empty() { + let list_values = data.list_values().collect::>(); + _matrix_recursive(&test, &list_values, &resolver, &[], &attributes) + } else { + cases + .into_iter() + .map(|(case_name, attrs, case_resolver)| { + let list_values = data.list_values().collect::>(); + _matrix_recursive( + &test, + &list_values, + &(case_resolver, &resolver), + attrs, + &attributes, + ) + .wrap_by_mod(&case_name) + }) + .collect() + }; + + test_group(test, rendered_cases) +} + +fn resolve_default_test_attr(is_async: bool) -> TokenStream { + if is_async { + quote! { #[async_std::test] } + } else { + quote! { #[test] } + } +} + +fn render_exec_call(fn_path: Path, args: &[Ident], is_async: bool) -> TokenStream { + if is_async { + quote! {#fn_path(#(#args),*).await} + } else { + quote! {#fn_path(#(#args),*)} + } +} + +fn render_test_call( + fn_path: Path, + args: &[Ident], + timeout: Option, + is_async: bool, +) -> TokenStream { + match (timeout, is_async) { + (Some(to_expr), true) => quote! { + use rstest::timeout::*; + execute_with_timeout_async(move || #fn_path(#(#args),*), #to_expr).await + }, + (Some(to_expr), false) => quote! 
{ + use rstest::timeout::*; + execute_with_timeout_sync(move || #fn_path(#(#args),*), #to_expr) + }, + _ => render_exec_call(fn_path, args, is_async), + } +} + +/// Render a single test case: +/// +/// * `name` - Test case name +/// * `testfn_name` - The name of test function to call +/// * `args` - The arguments of the test function +/// * `attrs` - The expected test attributes +/// * `output` - The expected test return type +/// * `asyncness` - The `async` fn token +/// * `test_impl` - If you want embed test function (should be the one called by `testfn_name`) +/// * `resolver` - The resolver used to resolve injected values +/// * `attributes` - Test attributes to select test behaviour +/// * `generic_types` - The genrics type used in signature +/// +// Ok I need some refactoring here but now that not a real issue +#[allow(clippy::too_many_arguments)] +fn single_test_case( + name: &Ident, + testfn_name: &Ident, + args: &[FnArg], + attrs: &[Attribute], + output: &ReturnType, + asyncness: Option, + test_impl: Option<&ItemFn>, + resolver: impl Resolver, + attributes: &RsTestAttributes, + generic_types: &[Ident], +) -> TokenStream { + let (attrs, trace_me): (Vec<_>, Vec<_>) = + attrs.iter().cloned().partition(|a| !attr_is(a, "trace")); + let mut attributes = attributes.clone(); + if !trace_me.is_empty() { + attributes.add_trace(format_ident!("trace")); + } + let inject = inject::resolve_aruments(args.iter(), &resolver, generic_types); + let args = args + .iter() + .filter_map(MaybeIdent::maybe_ident) + .cloned() + .collect::>(); + let trace_args = trace_arguments(args.iter(), &attributes); + + let is_async = asyncness.is_some(); + let (attrs, timeouts): (Vec<_>, Vec<_>) = + attrs.iter().cloned().partition(|a| !attr_is(a, "timeout")); + + let timeout = timeouts + .into_iter() + .last() + .map(|attribute| attribute.parse_args::().unwrap()); + + // If no injected attribut provided use the default one + let test_attr = if attrs + .iter() + .any(|a| attr_ends_with(a, 
&parse_quote! {test})) + { + None + } else { + Some(resolve_default_test_attr(is_async)) + }; + let execute = render_test_call(testfn_name.clone().into(), &args, timeout, is_async); + + quote! { + #test_attr + #(#attrs)* + #asyncness fn #name() #output { + #test_impl + #inject + #trace_args + #execute + } + } +} + +fn trace_arguments<'a>( + args: impl Iterator, + attributes: &RsTestAttributes, +) -> Option { + let mut statements = args + .filter(|&arg| attributes.trace_me(arg)) + .map(|arg| { + let s: Stmt = parse_quote! { + println!("{} = {:?}", stringify!(#arg), #arg); + }; + s + }) + .peekable(); + if statements.peek().is_some() { + Some(quote! { + println!("{:-^40}", " TEST ARGUMENTS "); + #(#statements)* + println!("{:-^40}", " TEST START "); + }) + } else { + None + } +} + +struct TestCaseRender<'a> { + name: Ident, + attrs: &'a [syn::Attribute], + resolver: Box, +} + +impl<'a> TestCaseRender<'a> { + pub fn new(name: Ident, attrs: &'a [syn::Attribute], resolver: R) -> Self { + TestCaseRender { + name, + attrs, + resolver: Box::new(resolver), + } + } + + fn render(self, testfn: &ItemFn, attributes: &RsTestAttributes) -> TokenStream { + let args = testfn.sig.inputs.iter().cloned().collect::>(); + let mut attrs = testfn.attrs.clone(); + attrs.extend(self.attrs.iter().cloned()); + let asyncness = testfn.sig.asyncness; + let generic_types = testfn + .sig + .generics + .type_params() + .map(|tp| &tp.ident) + .cloned() + .collect::>(); + + single_test_case( + &self.name, + &testfn.sig.ident, + &args, + &attrs, + &testfn.sig.output, + asyncness, + None, + self.resolver, + attributes, + &generic_types, + ) + } +} + +fn test_group(mut test: ItemFn, rendered_cases: TokenStream) -> TokenStream { + let fname = &test.sig.ident; + test.attrs = vec![]; + + quote! 
{ + #[cfg(test)] + #test + + #[cfg(test)] + mod #fname { + use super::*; + + #rendered_cases + } + } +} + +trait DisplayLen { + fn display_len(&self) -> usize; +} + +impl DisplayLen for D { + fn display_len(&self) -> usize { + format!("{self}").len() + } +} + +fn format_case_name(case: &TestCase, index: usize, display_len: usize) -> String { + let description = case + .description + .as_ref() + .map(|d| format!("_{d}")) + .unwrap_or_default(); + format!("case_{index:0display_len$}{description}") +} + +fn cases_data( + data: &RsTestData, + name_span: Span, +) -> impl Iterator)> { + let display_len = data.cases().count().display_len(); + data.cases().enumerate().map({ + move |(n, case)| { + let resolver_case = data + .case_args() + .map(|a| a.to_string()) + .zip(case.args.iter()) + .collect::>(); + ( + Ident::new(&format_case_name(case, n + 1, display_len), name_span), + case.attrs.as_slice(), + resolver_case, + ) + } + }) +} + +fn sanitize_ident(expr: &Expr) -> String { + expr.to_token_stream() + .to_string() + .chars() + .filter(|c| !c.is_whitespace()) + .map(|c| match c { + '"' | '\'' => "__".to_owned(), + ':' | '(' | ')' | '{' | '}' | '[' | ']' | ',' | '.' | '*' | '+' | '/' | '-' | '%' + | '^' | '!' | '&' | '|' => "_".to_owned(), + _ => c.to_string(), + }) + .collect::() + .chars() + .filter(|&c| is_xid_continue(c)) + .collect() +} diff --git a/src/render/test.rs b/src/render/test.rs new file mode 100644 index 0000000..cefd98a --- /dev/null +++ b/src/render/test.rs @@ -0,0 +1,1855 @@ +#![cfg(test)] + +use syn::{ + parse::{Parse, ParseStream, Result}, + parse2, parse_str, + visit::Visit, + ItemFn, ItemMod, +}; + +use super::*; +use crate::test::{assert_eq, fixture, *}; +use crate::utils::*; + +trait SetAsync { + fn set_async(&mut self, is_async: bool); +} + +impl SetAsync for ItemFn { + fn set_async(&mut self, is_async: bool) { + self.sig.asyncness = if is_async { + Some(parse_quote! 
{ async }) + } else { + None + }; + } +} + +fn trace_argument_code_string(arg_name: &str) -> String { + let arg_name = ident(arg_name); + let statment: Stmt = parse_quote! { + println!("{} = {:?}", stringify!(#arg_name) ,#arg_name); + }; + statment.display_code() +} + +#[rstest] +#[case("1", "1")] +#[case(r#""1""#, "__1__")] +#[case(r#"Some::SomeElse"#, "Some__SomeElse")] +#[case(r#""minnie".to_owned()"#, "__minnie___to_owned__")] +#[case( + r#"vec![1 , 2, + 3]"#, + "vec__1_2_3_" +)] +#[case( + r#"some_macro!("first", {second}, [third])"#, + "some_macro____first____second___third__" +)] +#[case(r#"'x'"#, "__x__")] +#[case::ops(r#"a*b+c/d-e%f^g"#, "a_b_c_d_e_f_g")] +fn sanitaze_ident_name(#[case] expression: impl AsRef, #[case] expected: impl AsRef) { + let expression: Expr = expression.as_ref().ast(); + + assert_eq!(expected.as_ref(), sanitize_ident(&expression)); +} + +mod single_test_should { + use rstest_test::{assert_in, assert_not_in}; + + use crate::{ + parse::arguments::{ArgumentsInfo, FutureArg}, + test::{assert_eq, *}, + }; + + use super::*; + + #[test] + fn add_return_type_if_any() { + let input_fn: ItemFn = "fn function(fix: String) -> Result { Ok(42) }".ast(); + + let result: ItemFn = single(input_fn.clone(), Default::default()).ast(); + + assert_eq!(result.sig.output, input_fn.sig.output); + } + + fn extract_inner_test_function(outer: &ItemFn) -> ItemFn { + let first_stmt = outer.block.stmts.get(0).unwrap(); + + parse_quote! 
{ + #first_stmt + } + } + + #[test] + fn include_given_function() { + let input_fn: ItemFn = r#" + pub fn test, B>(mut s: String, v: &u32, a: &mut [i32], r: R) -> (u32, B, String, &str) + where B: Borrow + { + let some = 42; + assert_eq!(42, some); + } + "#.ast(); + + let result: ItemFn = single(input_fn.clone(), Default::default()).ast(); + + let inner_fn = extract_inner_test_function(&result); + let inner_fn_impl: Stmt = inner_fn.block.stmts.last().cloned().unwrap(); + + assert_eq!(inner_fn.sig, input_fn.sig); + assert_eq!(inner_fn_impl.display_code(), input_fn.block.display_code()); + } + + #[rstest] + fn not_copy_any_attributes( + #[values( + "#[test]", + "#[very::complicated::path]", + "#[test]#[should_panic]", + "#[should_panic]#[test]", + "#[a]#[b]#[c]" + )] + attributes: &str, + ) { + let attributes = attrs(attributes); + let mut input_fn: ItemFn = r#"pub fn test(_s: String){}"#.ast(); + input_fn.attrs = attributes; + + let result: ItemFn = single(input_fn.clone(), Default::default()).ast(); + let first_stmt = result.block.stmts.get(0).unwrap(); + + let inner_fn: ItemFn = parse_quote! 
{ + #first_stmt + }; + + assert!(inner_fn.attrs.is_empty()); + } + + #[rstest] + #[case::sync(false)] + #[case::async_fn(true)] + fn use_injected_test_attribute_to_mark_test_functions_if_any( + #[case] is_async: bool, + #[values( + "#[test]", + "#[other::test]", + "#[very::complicated::path::test]", + "#[prev]#[test]", + "#[test]#[after]", + "#[prev]#[other::test]" + )] + attributes: &str, + ) { + let attributes = attrs(attributes); + let mut input_fn: ItemFn = r#"fn test(_s: String) {} "#.ast(); + input_fn.set_async(is_async); + input_fn.attrs = attributes.clone(); + + let result: ItemFn = single(input_fn.clone(), Default::default()).ast(); + + assert_eq!(result.attrs, attributes); + } + + #[test] + fn use_global_await() { + let input_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {} "#.ast(); + let mut info: RsTestInfo = Default::default(); + info.arguments.set_global_await(true); + info.arguments.add_future(ident("a")); + info.arguments.add_future(ident("b")); + + let item_fn: ItemFn = single(input_fn.clone(), info).ast(); + + assert_in!( + item_fn.block.display_code(), + await_argument_code_string("a") + ); + assert_in!( + item_fn.block.display_code(), + await_argument_code_string("b") + ); + assert_not_in!( + item_fn.block.display_code(), + await_argument_code_string("c") + ); + } + + #[test] + fn use_selective_await() { + let input_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {} "#.ast(); + let mut info: RsTestInfo = Default::default(); + info.arguments.set_future(ident("a"), FutureArg::Define); + info.arguments.set_future(ident("b"), FutureArg::Await); + + let item_fn: ItemFn = single(input_fn.clone(), info).ast(); + + assert_not_in!( + item_fn.block.display_code(), + await_argument_code_string("a",) + ); + assert_in!( + item_fn.block.display_code(), + await_argument_code_string("b") + ); + assert_not_in!( + item_fn.block.display_code(), + await_argument_code_string("c") + ); + } + + #[test] + fn trace_arguments_values() { + let input_fn: ItemFn = 
r#"#[trace]fn test(s: String, a:i32) {} "#.ast(); + + let item_fn: ItemFn = single(input_fn.clone(), Default::default()).ast(); + + assert_in!( + item_fn.block.display_code(), + trace_argument_code_string("s") + ); + assert_in!( + item_fn.block.display_code(), + trace_argument_code_string("a") + ); + } + + #[test] + fn trace_not_all_arguments_values() { + let input_fn: ItemFn = + r#"#[trace] fn test(a_trace: i32, b_no_trace:i32, c_no_trace:i32, d_trace:i32) {} "# + .ast(); + + let mut attributes = RsTestAttributes::default(); + attributes.add_notraces(vec![ident("b_no_trace"), ident("c_no_trace")]); + + let item_fn: ItemFn = single( + input_fn.clone(), + RsTestInfo { + attributes, + ..Default::default() + }, + ) + .ast(); + + assert_in!( + item_fn.block.display_code(), + trace_argument_code_string("a_trace") + ); + assert_not_in!( + item_fn.block.display_code(), + trace_argument_code_string("b_no_trace") + ); + assert_not_in!( + item_fn.block.display_code(), + trace_argument_code_string("c_no_trace") + ); + assert_in!( + item_fn.block.display_code(), + trace_argument_code_string("d_trace") + ); + } + + #[rstest] + #[case::sync("", parse_quote! { #[test] })] + #[case::async_fn("async", parse_quote! 
{ #[async_std::test] })] + fn add_default_test_attribute( + #[case] prefix: &str, + #[case] test_attribute: Attribute, + #[values( + "", + "#[no_one]", + "#[should_panic]", + "#[should_panic]#[other]", + "#[a::b::c]#[should_panic]" + )] + attributes: &str, + ) { + let attributes = attrs(attributes); + let mut input_fn: ItemFn = format!(r#"{} fn test(_s: String) {{}} "#, prefix).ast(); + input_fn.attrs = attributes.clone(); + + let result: ItemFn = single(input_fn.clone(), Default::default()).ast(); + + assert_eq!(result.attrs[0], test_attribute); + assert_eq!(&result.attrs[1..], attributes.as_slice()); + } + + #[rstest] + #[case::sync(false, false)] + #[case::async_fn(true, true)] + fn use_await_for_no_async_test_function(#[case] is_async: bool, #[case] use_await: bool) { + let mut input_fn: ItemFn = r#"fn test(_s: String) {} "#.ast(); + input_fn.set_async(is_async); + + let result: ItemFn = single(input_fn.clone(), Default::default()).ast(); + + let last_stmt = result.block.stmts.last().unwrap(); + + assert_eq!(use_await, last_stmt.is_await()); + } + #[test] + fn add_future_boilerplate_if_requested() { + let item_fn: ItemFn = r#" + async fn test(async_ref_u32: &u32, async_u32: u32,simple: u32) + { } + "# + .ast(); + + let mut arguments = ArgumentsInfo::default(); + arguments.add_future(ident("async_ref_u32")); + arguments.add_future(ident("async_u32")); + + let info = RsTestInfo { + arguments, + ..Default::default() + }; + + let result: ItemFn = single(item_fn.clone(), info).ast(); + let inner_fn = extract_inner_test_function(&result); + + let expected = parse_str::( + r#"async fn test<'_async_ref_u32>( + async_ref_u32: impl std::future::Future, + async_u32: impl std::future::Future, + simple: u32 + ) + { } + "#, + ) + .unwrap(); + + assert_eq!(inner_fn.sig, expected.sig); + } +} + +struct TestsGroup { + requested_test: ItemFn, + module: ItemMod, +} + +impl Parse for TestsGroup { + fn parse(input: ParseStream) -> Result { + Ok(Self { + requested_test: 
input.parse()?, + module: input.parse()?, + }) + } +} + +trait QueryAttrs { + fn has_attr(&self, attr: &syn::Path) -> bool; + fn has_attr_that_ends_with(&self, attr: &syn::PathSegment) -> bool; +} + +impl QueryAttrs for ItemFn { + fn has_attr(&self, attr: &syn::Path) -> bool { + self.attrs.iter().find(|a| &a.path == attr).is_some() + } + + fn has_attr_that_ends_with(&self, name: &syn::PathSegment) -> bool { + self.attrs + .iter() + .find(|a| attr_ends_with(a, name)) + .is_some() + } +} + +/// To extract all test functions +struct TestFunctions(Vec); + +fn is_test_fn(item_fn: &ItemFn) -> bool { + item_fn.has_attr_that_ends_with(&parse_quote! { test }) +} + +impl TestFunctions { + fn is_test_fn(item_fn: &ItemFn) -> bool { + is_test_fn(item_fn) + } +} + +impl<'ast> Visit<'ast> for TestFunctions { + //noinspection RsTypeCheck + fn visit_item_fn(&mut self, item_fn: &'ast ItemFn) { + if Self::is_test_fn(item_fn) { + self.0.push(item_fn.clone()) + } + } +} + +trait Named { + fn name(&self) -> String; +} + +impl Named for Ident { + fn name(&self) -> String { + self.to_string() + } +} + +impl Named for ItemFn { + fn name(&self) -> String { + self.sig.ident.name() + } +} + +impl Named for ItemMod { + fn name(&self) -> String { + self.ident.name() + } +} + +trait Names { + fn names(&self) -> Vec; +} + +impl Names for Vec { + fn names(&self) -> Vec { + self.iter().map(Named::name).collect() + } +} + +trait ModuleInspector { + fn get_all_tests(&self) -> Vec; + fn get_tests(&self) -> Vec; + fn get_modules(&self) -> Vec; +} + +impl ModuleInspector for ItemMod { + fn get_tests(&self) -> Vec { + self.content + .as_ref() + .map(|(_, items)| { + items + .iter() + .filter_map(|it| match it { + syn::Item::Fn(item_fn) if is_test_fn(item_fn) => Some(item_fn.clone()), + _ => None, + }) + .collect() + }) + .unwrap_or_default() + } + + fn get_all_tests(&self) -> Vec { + let mut f = TestFunctions(vec![]); + f.visit_item_mod(&self); + f.0 + } + + fn get_modules(&self) -> Vec { + self.content 
+ .as_ref() + .map(|(_, items)| { + items + .iter() + .filter_map(|it| match it { + syn::Item::Mod(item_mod) => Some(item_mod.clone()), + _ => None, + }) + .collect() + }) + .unwrap_or_default() + } +} + +impl ModuleInspector for TestsGroup { + fn get_all_tests(&self) -> Vec { + self.module.get_all_tests() + } + + fn get_tests(&self) -> Vec { + self.module.get_tests() + } + + fn get_modules(&self) -> Vec { + self.module.get_modules() + } +} + +#[derive(Default, Debug)] +struct Assignments(HashMap); + +impl<'ast> Visit<'ast> for Assignments { + //noinspection RsTypeCheck + fn visit_local(&mut self, assign: &syn::Local) { + match &assign { + syn::Local { + pat: syn::Pat::Ident(pat), + init: Some((_, expr)), + .. + } => { + self.0.insert(pat.ident.to_string(), expr.as_ref().clone()); + } + _ => {} + } + } +} + +impl Assignments { + pub fn collect_assignments(item_fn: &ItemFn) -> Self { + let mut collect = Self::default(); + collect.visit_item_fn(item_fn); + collect + } +} + +impl From for TestsGroup { + fn from(tokens: TokenStream) -> Self { + syn::parse2::(tokens).unwrap() + } +} + +mod cases_should { + use std::iter::FromIterator; + + use rstest_test::{assert_in, assert_not_in}; + + use crate::parse::{ + arguments::{ArgumentsInfo, FutureArg}, + rstest::{RsTestData, RsTestInfo, RsTestItem}, + testcase::TestCase, + }; + + use super::{assert_eq, *}; + + fn into_rstest_data(item_fn: &ItemFn) -> RsTestData { + RsTestData { + items: fn_args_idents(item_fn) + .cloned() + .map(RsTestItem::CaseArgName) + .collect(), + } + } + + struct TestCaseBuilder { + item_fn: ItemFn, + info: RsTestInfo, + } + + impl TestCaseBuilder { + fn new(item_fn: ItemFn) -> Self { + let info: RsTestInfo = into_rstest_data(&item_fn).into(); + Self { item_fn, info } + } + + fn from>(s: S) -> Self { + Self::new(s.as_ref().ast()) + } + + fn set_async(mut self, is_async: bool) -> Self { + self.item_fn.set_async(is_async); + self + } + + fn push_case>(mut self, case: T) -> Self { + 
self.info.push_case(case.into()); + self + } + + fn extend>(mut self, cases: impl Iterator) -> Self { + self.info.extend(cases.map(Into::into)); + self + } + + fn take(self) -> (ItemFn, RsTestInfo) { + (self.item_fn, self.info) + } + + fn add_notrace(mut self, idents: Vec) -> Self { + self.info.attributes.add_notraces(idents); + self + } + } + + fn one_simple_case() -> (ItemFn, RsTestInfo) { + TestCaseBuilder::from(r#"fn test(mut fix: String) { println!("user code") }"#) + .push_case(r#"String::from("3")"#) + .take() + } + + fn some_simple_cases(cases: i32) -> (ItemFn, RsTestInfo) { + TestCaseBuilder::from(r#"fn test(mut fix: String) { println!("user code") }"#) + .extend((0..cases).map(|_| r#"String::from("3")"#)) + .take() + } + + #[test] + fn create_a_module_named_as_test_function() { + let (item_fn, info) = + TestCaseBuilder::from("fn should_be_the_module_name(mut fix: String) {}").take(); + + let tokens = parametrize(item_fn, info); + + let output = TestsGroup::from(tokens); + + assert_eq!(output.module.ident, "should_be_the_module_name"); + } + + #[test] + fn copy_user_function() { + let (item_fn, info) = TestCaseBuilder::from( + r#"fn should_be_the_module_name(mut fix: String) { println!("user code") }"#, + ) + .take(); + + let tokens = parametrize(item_fn.clone(), info); + + let mut output = TestsGroup::from(tokens); + let test_impl: Stmt = output.requested_test.block.stmts.last().cloned().unwrap(); + + output.requested_test.attrs = vec![]; + assert_eq!(output.requested_test.sig, item_fn.sig); + assert_eq!(test_impl.display_code(), item_fn.block.display_code()); + } + + #[test] + fn should_not_copy_should_panic_attribute() { + let (item_fn, info) = TestCaseBuilder::from( + r#"#[should_panic] fn with_should_panic(mut fix: String) { println!("user code") }"#, + ) + .take(); + + let tokens = parametrize(item_fn.clone(), info); + + let output = TestsGroup::from(tokens); + + assert!(!format!("{:?}", output.requested_test.attrs).contains("should_panic")); + } + + 
#[test] + fn should_mark_test_with_given_attributes() { + let (item_fn, info) = + TestCaseBuilder::from(r#"#[should_panic] #[other(value)] fn test(s: String){}"#) + .push_case(r#"String::from("3")"#) + .take(); + + let tokens = parametrize(item_fn.clone(), info); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + // Sanity check + assert!(tests.len() > 0); + + for t in tests { + assert_eq!(item_fn.attrs, &t.attrs[1..]); + } + } + + #[rstest] + #[case::empty("")] + #[case::some_attrs("#[a]#[b::c]#[should_panic]")] + fn should_add_attributes_given_in_the_test_case( + #[case] fnattrs: &str, + #[values("", "#[should_panic]", "#[first]#[second(arg)]")] case_attrs: &str, + ) { + let given_attrs = attrs(fnattrs); + let case_attrs = attrs(case_attrs); + let (mut item_fn, info) = TestCaseBuilder::from(r#"fn test(v: i32){}"#) + .push_case(TestCase::from("42").with_attrs(case_attrs.clone())) + .take(); + + item_fn.attrs = given_attrs.clone(); + + let tokens = parametrize(item_fn, info); + + let test_attrs = &TestsGroup::from(tokens).get_all_tests()[0].attrs[1..]; + + let l = given_attrs.len(); + + assert_eq!(case_attrs.as_slice(), &test_attrs[l..]); + assert_eq!(given_attrs.as_slice(), &test_attrs[..l]); + } + + #[test] + fn mark_user_function_as_test() { + let (item_fn, info) = TestCaseBuilder::from( + r#"fn should_be_the_module_name(mut fix: String) { println!("user code") }"#, + ) + .take(); + let tokens = parametrize(item_fn, info); + + let output = TestsGroup::from(tokens); + + assert_eq!( + output.requested_test.attrs, + vec![parse_quote! {#[cfg(test)]}] + ); + } + + #[test] + fn mark_module_as_test() { + let (item_fn, info) = TestCaseBuilder::from( + r#"fn should_be_the_module_name(mut fix: String) { println!("user code") }"#, + ) + .take(); + let tokens = parametrize(item_fn, info); + + let output = TestsGroup::from(tokens); + + assert_eq!(output.module.attrs, vec![parse_quote! 
{#[cfg(test)]}]); + } + + #[test] + fn add_a_test_case() { + let (item_fn, info) = one_simple_case(); + + let tokens = parametrize(item_fn, info); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + assert_eq!(1, tests.len()); + assert!(&tests[0].sig.ident.to_string().starts_with("case_")) + } + + #[test] + fn add_return_type_if_any() { + let (item_fn, info) = + TestCaseBuilder::from("fn function(fix: String) -> Result { Ok(42) }") + .push_case(r#"String::from("3")"#) + .take(); + + let tokens = parametrize(item_fn.clone(), info); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + assert_eq!(tests[0].sig.output, item_fn.sig.output); + } + + #[test] + fn not_copy_user_function() { + let t_name = "test_name"; + let (item_fn, info) = TestCaseBuilder::from(format!( + "fn {}(fix: String) -> Result {{ Ok(42) }}", + t_name + )) + .push_case(r#"String::from("3")"#) + .take(); + + let tokens = parametrize(item_fn, info); + + let test = &TestsGroup::from(tokens).get_all_tests()[0]; + let inner_functions = extract_inner_functions(&test.block); + + assert_eq!(0, inner_functions.filter(|f| f.sig.ident == t_name).count()); + } + + #[test] + fn starts_case_number_from_1() { + let (item_fn, info) = one_simple_case(); + + let tokens = parametrize(item_fn.clone(), info); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + assert!( + &tests[0].sig.ident.to_string().starts_with("case_1"), + "Should starts with case_1 but is {}", + tests[0].sig.ident.to_string() + ) + } + + #[test] + fn add_all_test_cases() { + let (item_fn, info) = some_simple_cases(5); + + let tokens = parametrize(item_fn.clone(), info); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + let valid_names = tests + .iter() + .filter(|it| it.sig.ident.to_string().starts_with("case_")); + assert_eq!(5, valid_names.count()) + } + + #[test] + fn left_pad_case_number_by_zeros() { + let (item_fn, info) = some_simple_cases(1000); + + let tokens = parametrize(item_fn.clone(), info); + 
+ let tests = TestsGroup::from(tokens).get_all_tests(); + + let first_name = tests[0].sig.ident.to_string(); + let last_name = tests[999].sig.ident.to_string(); + + assert!( + first_name.ends_with("_0001"), + "Should ends by _0001 but is {}", + first_name + ); + assert!( + last_name.ends_with("_1000"), + "Should ends by _1000 but is {}", + last_name + ); + + let valid_names = tests + .iter() + .filter(|it| it.sig.ident.to_string().len() == first_name.len()); + assert_eq!(1000, valid_names.count()) + } + + #[test] + fn use_description_if_any() { + let (item_fn, mut info) = one_simple_case(); + let description = "show_this_description"; + + if let &mut RsTestItem::TestCase(ref mut case) = &mut info.data.items[1] { + case.description = Some(parse_str::(description).unwrap()); + } else { + panic!("Test case should be the second one"); + } + + let tokens = parametrize(item_fn.clone(), info); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + assert!(tests[0] + .sig + .ident + .to_string() + .ends_with(&format!("_{}", description))); + } + + #[rstest] + #[case::sync(false)] + #[case::async_fn(true)] + fn use_injected_test_attribute_to_mark_test_functions_if_any( + #[case] is_async: bool, + #[values( + "#[test]", + "#[other::test]", + "#[very::complicated::path::test]", + "#[prev]#[test]", + "#[test]#[after]", + "#[prev]#[other::test]" + )] + attributes: &str, + ) { + let attributes = attrs(attributes); + let (mut item_fn, info) = TestCaseBuilder::from(r#"fn test(s: String){}"#) + .push_case(r#"String::from("3")"#) + .set_async(is_async) + .take(); + item_fn.attrs = attributes.clone(); + item_fn.set_async(is_async); + + let tokens = parametrize(item_fn.clone(), info); + + let test = &TestsGroup::from(tokens).get_all_tests()[0]; + + assert_eq!(attributes, test.attrs); + } + + #[rstest] + #[case::sync(false, parse_quote! { #[test] })] + #[case::async_fn(true, parse_quote! 
{ #[async_std::test] })] + fn add_default_test_attribute( + #[case] is_async: bool, + #[case] test_attribute: Attribute, + #[values( + "", + "#[no_one]", + "#[should_panic]", + "#[should_panic]#[other]", + "#[a::b::c]#[should_panic]" + )] + attributes: &str, + ) { + let attributes = attrs(attributes); + let (mut item_fn, info) = TestCaseBuilder::from( + r#"fn should_be_the_module_name(mut fix: String) { println!("user code") }"#, + ) + .push_case("42") + .set_async(is_async) + .take(); + item_fn.attrs = attributes.clone(); + + let tokens = parametrize(item_fn, info); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + assert_eq!(tests[0].attrs[0], test_attribute); + assert_eq!(&tests[0].attrs[1..], attributes.as_slice()); + } + + #[test] + fn add_future_boilerplate_if_requested() { + let (item_fn, mut info) = TestCaseBuilder::from( + r#"async fn test(async_ref_u32: &u32, async_u32: u32,simple: u32) { }"#, + ) + .take(); + + let mut arguments = ArgumentsInfo::default(); + arguments.add_future(ident("async_ref_u32")); + arguments.add_future(ident("async_u32")); + + info.arguments = arguments; + + let tokens = parametrize(item_fn.clone(), info); + let test_function = TestsGroup::from(tokens).requested_test; + + let expected = parse_str::( + r#"async fn test<'_async_ref_u32>( + async_ref_u32: impl std::future::Future, + async_u32: impl std::future::Future, + simple: u32 + ) + { } + "#, + ) + .unwrap(); + + assert_eq!(test_function.sig, expected.sig); + } + + #[rstest] + #[case::sync(false, false)] + #[case::async_fn(true, true)] + fn use_await_for_async_test_function(#[case] is_async: bool, #[case] use_await: bool) { + let (item_fn, info) = + TestCaseBuilder::from(r#"fn test(mut fix: String) { println!("user code") }"#) + .set_async(is_async) + .push_case(r#"String::from("3")"#) + .take(); + + let tokens = parametrize(item_fn, info); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + let last_stmt = tests[0].block.stmts.last().unwrap(); + + 
assert_eq!(use_await, last_stmt.is_await()); + } + + #[test] + fn trace_arguments_value() { + let (item_fn, info) = + TestCaseBuilder::from(r#"#[trace] fn test(a_trace_me: i32, b_trace_me: i32) {}"#) + .push_case(TestCase::from_iter(vec!["1", "2"])) + .push_case(TestCase::from_iter(vec!["3", "4"])) + .take(); + + let tokens = parametrize(item_fn, info); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + assert!(tests.len() > 0); + for test in tests { + for name in &["a_trace_me", "b_trace_me"] { + assert_in!(test.block.display_code(), trace_argument_code_string(name)); + } + } + } + + #[test] + fn trace_just_some_arguments_value() { + let (item_fn, info) = + TestCaseBuilder::from(r#"#[trace] fn test(a_trace_me: i32, b_no_trace_me: i32, c_no_trace_me: i32, d_trace_me: i32) {}"#) + .push_case(TestCase::from_iter(vec!["1", "2", "1", "2"])) + .push_case(TestCase::from_iter(vec!["3", "4", "3", "4"])) + .add_notrace(to_idents!(["b_no_trace_me", "c_no_trace_me"])) + .take(); + + let tokens = parametrize(item_fn, info); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + assert!(tests.len() > 0); + for test in tests { + for should_be_present in &["a_trace_me", "d_trace_me"] { + assert_in!( + test.block.display_code(), + trace_argument_code_string(should_be_present) + ); + } + for should_not_be_present in &["b_trace_me", "c_trace_me"] { + assert_not_in!( + test.block.display_code(), + trace_argument_code_string(should_not_be_present) + ); + } + } + } + + #[test] + fn trace_just_one_case() { + let (item_fn, info) = + TestCaseBuilder::from(r#"fn test(a_no_trace_me: i32, b_trace_me: i32) {}"#) + .push_case(TestCase::from_iter(vec!["1", "2"])) + .push_case(TestCase::from_iter(vec!["3", "4"]).with_attrs(attrs("#[trace]"))) + .add_notrace(to_idents!(["a_no_trace_me"])) + .take(); + + let tokens = parametrize(item_fn, info); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + assert_not_in!( + tests[0].block.display_code(), + 
trace_argument_code_string("b_trace_me") + ); + assert_in!( + tests[1].block.display_code(), + trace_argument_code_string("b_trace_me") + ); + assert_not_in!( + tests[1].block.display_code(), + trace_argument_code_string("a_no_trace_me") + ); + } + + #[test] + fn use_global_await() { + let (item_fn, mut info) = TestCaseBuilder::from(r#"fn test(a: i32, b:i32, c:i32) {}"#) + .push_case(TestCase::from_iter(vec!["1", "2", "3"])) + .push_case(TestCase::from_iter(vec!["1", "2", "3"])) + .take(); + info.arguments.set_global_await(true); + info.arguments.add_future(ident("a")); + info.arguments.add_future(ident("b")); + + let tokens = parametrize(item_fn, info); + + let tests = TestsGroup::from(tokens); + + let code = tests.requested_test.block.display_code(); + + assert_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } + + #[test] + fn use_selective_await() { + let (item_fn, mut info) = TestCaseBuilder::from(r#"fn test(a: i32, b:i32, c:i32) {}"#) + .push_case(TestCase::from_iter(vec!["1", "2", "3"])) + .push_case(TestCase::from_iter(vec!["1", "2", "3"])) + .take(); + info.arguments.set_future(ident("a"), FutureArg::Define); + info.arguments.set_future(ident("b"), FutureArg::Await); + + let tokens = parametrize(item_fn, info); + + let tests = TestsGroup::from(tokens); + + let code = tests.requested_test.block.display_code(); + + assert_not_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } +} + +mod matrix_cases_should { + use rstest_test::{assert_in, assert_not_in}; + + use crate::parse::{ + arguments::{ArgumentsInfo, FutureArg}, + vlist::ValueList, + }; + + /// Should test matrix tests render without take in account MatrixInfo to RsTestInfo + /// transformation + use super::{assert_eq, *}; + + fn into_rstest_data(item_fn: &ItemFn) -> RsTestData { + RsTestData 
{ + items: fn_args_idents(item_fn) + .cloned() + .map(|it| { + ValueList { + arg: it, + values: vec![], + } + .into() + }) + .collect(), + } + } + + #[test] + fn create_a_module_named_as_test_function() { + let item_fn = "fn should_be_the_module_name(mut fix: String) {}".ast(); + let data = into_rstest_data(&item_fn); + + let tokens = matrix(item_fn.clone(), data.into()); + + let output = TestsGroup::from(tokens); + + assert_eq!(output.module.ident, "should_be_the_module_name"); + } + + #[test] + fn copy_user_function() { + let item_fn = + r#"fn should_be_the_module_name(mut fix: String) { println!("user code") }"#.ast(); + let data = into_rstest_data(&item_fn); + + let tokens = matrix(item_fn.clone(), data.into()); + + let mut output = TestsGroup::from(tokens); + let test_impl: Stmt = output.requested_test.block.stmts.last().cloned().unwrap(); + + output.requested_test.attrs = vec![]; + assert_eq!(output.requested_test.sig, item_fn.sig); + assert_eq!(test_impl.display_code(), item_fn.block.display_code()); + } + + #[test] + fn not_copy_user_function() { + let t_name = "test_name"; + let item_fn: ItemFn = format!( + "fn {}(fix: String) -> Result {{ Ok(42) }}", + t_name + ) + .ast(); + let info = RsTestInfo { + data: RsTestData { + items: vec![values_list("fix", &["1"]).into()].into(), + }, + ..Default::default() + }; + + let tokens = matrix(item_fn, info); + + let test = &TestsGroup::from(tokens).get_all_tests()[0]; + let inner_functions = extract_inner_functions(&test.block); + + assert_eq!(0, inner_functions.filter(|f| f.sig.ident == t_name).count()); + } + + #[test] + fn not_copy_should_panic_attribute() { + let item_fn = + r#"#[should_panic] fn with_should_panic(mut fix: String) { println!("user code") }"# + .ast(); + let info = RsTestInfo { + data: RsTestData { + items: vec![values_list("fix", &["1"]).into()].into(), + }, + ..Default::default() + }; + + let tokens = matrix(item_fn, info); + + let output = TestsGroup::from(tokens); + + assert!(!format!("{:?}", 
output.requested_test.attrs).contains("should_panic")); + } + + #[test] + fn should_mark_test_with_given_attributes() { + let item_fn: ItemFn = r#"#[should_panic] #[other(value)] fn test(_s: String){}"#.ast(); + + let info = RsTestInfo { + data: RsTestData { + items: vec![values_list("fix", &["1"]).into()].into(), + }, + ..Default::default() + }; + let tokens = matrix(item_fn.clone(), info); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + // Sanity check + assert!(tests.len() > 0); + + for t in tests { + let end = t.attrs.len() - 1; + assert_eq!(item_fn.attrs, &t.attrs[1..end]); + } + } + + #[test] + fn add_return_type_if_any() { + let item_fn: ItemFn = "fn function(fix: String) -> Result { Ok(42) }".ast(); + let info = RsTestInfo { + data: RsTestData { + items: vec![values_list("fix", &["1", "2", "3"]).into()].into(), + }, + ..Default::default() + }; + + let tokens = matrix(item_fn.clone(), info); + + let tests = TestsGroup::from(tokens).get_tests(); + + assert_eq!(tests[0].sig.output, item_fn.sig.output); + assert_eq!(tests[1].sig.output, item_fn.sig.output); + assert_eq!(tests[2].sig.output, item_fn.sig.output); + } + + #[test] + fn mark_user_function_as_test() { + let item_fn = + r#"fn should_be_the_module_name(mut fix: String) { println!("user code") }"#.ast(); + let data = into_rstest_data(&item_fn); + + let tokens = matrix(item_fn.clone(), data.into()); + + let output = TestsGroup::from(tokens); + + let expected = parse2::(quote! { + #[cfg(test)] + fn some() {} + }) + .unwrap() + .attrs; + + assert_eq!(expected, output.requested_test.attrs); + } + + #[test] + fn mark_module_as_test() { + let item_fn = + r#"fn should_be_the_module_name(mut fix: String) { println!("user code") }"#.ast(); + let data = into_rstest_data(&item_fn); + + let tokens = matrix(item_fn.clone(), data.into()); + + let output = TestsGroup::from(tokens); + + let expected = parse2::(quote! 
{ + #[cfg(test)] + mod some {} + }) + .unwrap() + .attrs; + + assert_eq!(expected, output.module.attrs); + } + + #[test] + fn with_just_one_arg() { + let arg_name = "fix"; + let info = RsTestInfo { + data: RsTestData { + items: vec![values_list(arg_name, &["1", "2", "3"]).into()].into(), + }, + ..Default::default() + }; + + let item_fn = format!(r#"fn test({}: u32) {{ println!("user code") }}"#, arg_name).ast(); + + let tokens = matrix(item_fn, info); + + let tests = TestsGroup::from(tokens).get_tests(); + + assert_eq!(3, tests.len()); + assert!(&tests[0].sig.ident.to_string().starts_with("fix_")) + } + + #[rstest] + #[case::sync(false)] + #[case::async_fn(true)] + fn use_injected_test_attribute_to_mark_test_functions_if_any( + #[case] is_async: bool, + #[values( + "#[test]", + "#[other::test]", + "#[very::complicated::path::test]", + "#[prev]#[test]", + "#[test]#[after]", + "#[prev]#[other::test]" + )] + attributes: &str, + ) { + let attributes = attrs(attributes); + let filter = attrs("#[allow(non_snake_case)]"); + let data = RsTestData { + items: vec![values_list("v", &["1", "2", "3"]).into()].into(), + }; + let mut item_fn: ItemFn = r#"fn test(v: u32) {{ println!("user code") }}"#.ast(); + item_fn.set_async(is_async); + item_fn.attrs = attributes.clone(); + + let tokens = matrix(item_fn, data.into()); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + // Sanity check + assert!(tests.len() > 0); + + for test in tests { + let filterd: Vec<_> = test + .attrs + .into_iter() + .filter(|a| !filter.contains(a)) + .collect(); + assert_eq!(attributes, filterd); + } + } + + #[rstest] + #[case::sync(false, parse_quote! { #[test] })] + #[case::async_fn(true, parse_quote! 
{ #[async_std::test] })] + fn add_default_test_attribute( + #[case] is_async: bool, + #[case] test_attribute: Attribute, + #[values( + "", + "#[no_one]", + "#[should_panic]", + "#[should_panic]#[other]", + "#[a::b::c]#[should_panic]" + )] + attributes: &str, + ) { + let attributes = attrs(attributes); + let data = RsTestData { + items: vec![values_list("v", &["1", "2", "3"]).into()].into(), + }; + + let mut item_fn: ItemFn = r#"fn test(v: u32) {{ println!("user code") }}"#.ast(); + item_fn.set_async(is_async); + item_fn.attrs = attributes.clone(); + + let tokens = matrix(item_fn, data.into()); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + // Sanity check + assert!(tests.len() > 0); + + for test in tests { + assert_eq!(test.attrs[0], test_attribute); + assert_eq!(&test.attrs[1..test.attrs.len() - 1], attributes.as_slice()); + } + } + + #[test] + fn add_future_boilerplate_if_requested() { + let item_fn = r#"async fn test(async_ref_u32: &u32, async_u32: u32,simple: u32) { }"#.ast(); + + let mut arguments = ArgumentsInfo::default(); + arguments.add_future(ident("async_ref_u32")); + arguments.add_future(ident("async_u32")); + + let info = RsTestInfo { + arguments, + ..Default::default() + }; + + let tokens = matrix(item_fn, info); + + let test_function = TestsGroup::from(tokens).requested_test; + + let expected = parse_str::( + r#"async fn test<'_async_ref_u32>( + async_ref_u32: impl std::future::Future, + async_u32: impl std::future::Future, + simple: u32 + ) + { } + "#, + ) + .unwrap(); + + assert_eq!(test_function.sig, expected.sig); + } + + #[rstest] + fn add_allow_non_snake_case( + #[values( + "", + "#[no_one]", + "#[should_panic]", + "#[should_panic]#[other]", + "#[a::b::c]#[should_panic]" + )] + attributes: &str, + ) { + let attributes = attrs(attributes); + let non_snake_case = &attrs("#[allow(non_snake_case)]")[0]; + let data = RsTestData { + items: vec![values_list("v", &["1", "2", "3"]).into()].into(), + }; + + let mut item_fn: ItemFn = r#"fn 
test(v: u32) {{ println!("user code") }}"#.ast(); + item_fn.attrs = attributes.clone(); + + let tokens = matrix(item_fn, data.into()); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + // Sanity check + assert!(tests.len() > 0); + + for test in tests { + assert_eq!(test.attrs.last().unwrap(), non_snake_case); + assert_eq!(&test.attrs[1..test.attrs.len() - 1], attributes.as_slice()); + } + } + + #[rstest] + #[case::sync(false, false)] + #[case::async_fn(true, true)] + fn use_await_for_async_test_function(#[case] is_async: bool, #[case] use_await: bool) { + let data = RsTestData { + items: vec![values_list("v", &["1", "2", "3"]).into()].into(), + }; + + let mut item_fn: ItemFn = r#"fn test(v: u32) {{ println!("user code") }}"#.ast(); + item_fn.set_async(is_async); + + let tokens = matrix(item_fn, data.into()); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + // Sanity check + assert!(tests.len() > 0); + + for test in tests { + let last_stmt = test.block.stmts.last().unwrap(); + assert_eq!(use_await, last_stmt.is_await()); + } + } + + #[test] + fn trace_arguments_value() { + let data = RsTestData { + items: vec![ + values_list("a_trace_me", &["1", "2"]).into(), + values_list("b_trace_me", &["3", "4"]).into(), + ] + .into(), + }; + let item_fn: ItemFn = r#"#[trace] fn test(a_trace_me: u32, b_trace_me: u32) {}"#.ast(); + + let tokens = matrix(item_fn, data.into()); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + assert!(tests.len() > 0); + for test in tests { + for name in &["a_trace_me", "b_trace_me"] { + assert_in!(test.block.display_code(), trace_argument_code_string(name)); + } + } + } + + #[test] + fn trace_just_some_arguments_value() { + let data = RsTestData { + items: vec![ + values_list("a_trace_me", &["1", "2"]).into(), + values_list("b_no_trace_me", &["3", "4"]).into(), + values_list("c_no_trace_me", &["5", "6"]).into(), + values_list("d_trace_me", &["7", "8"]).into(), + ] + .into(), + }; + let mut attributes: 
RsTestAttributes = Default::default(); + attributes.add_notraces(vec![ident("b_no_trace_me"), ident("c_no_trace_me")]); + let item_fn: ItemFn = r#"#[trace] fn test(a_trace_me: u32, b_no_trace_me: u32, c_no_trace_me: u32, d_trace_me: u32) {}"#.ast(); + + let tokens = matrix( + item_fn, + RsTestInfo { + data, + attributes, + ..Default::default() + }, + ); + + let tests = TestsGroup::from(tokens).get_all_tests(); + + assert!(tests.len() > 0); + for test in tests { + for should_be_present in &["a_trace_me", "d_trace_me"] { + assert_in!( + test.block.display_code(), + trace_argument_code_string(should_be_present) + ); + } + for should_not_be_present in &["b_no_trace_me", "c_no_trace_me"] { + assert_not_in!( + test.block.display_code(), + trace_argument_code_string(should_not_be_present) + ); + } + } + } + + #[test] + fn use_global_await() { + let item_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {}"#.ast(); + let data = RsTestData { + items: vec![ + values_list("a", &["1"]).into(), + values_list("b", &["2"]).into(), + values_list("c", &["3"]).into(), + ] + .into(), + }; + let mut info = RsTestInfo { + data, + attributes: Default::default(), + arguments: Default::default(), + }; + info.arguments.set_global_await(true); + info.arguments.add_future(ident("a")); + info.arguments.add_future(ident("b")); + + let tokens = matrix(item_fn, info); + + let tests = TestsGroup::from(tokens); + + let code = tests.requested_test.block.display_code(); + + assert_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } + + #[test] + fn use_selective_await() { + let item_fn: ItemFn = r#"fn test(a: i32, b:i32, c:i32) {}"#.ast(); + let data = RsTestData { + items: vec![ + values_list("a", &["1"]).into(), + values_list("b", &["2"]).into(), + values_list("c", &["3"]).into(), + ] + .into(), + }; + let mut info = RsTestInfo { + data, + attributes: Default::default(), + arguments: 
Default::default(), + }; + + info.arguments.set_future(ident("a"), FutureArg::Define); + info.arguments.set_future(ident("b"), FutureArg::Await); + + let tokens = matrix(item_fn, info); + + let tests = TestsGroup::from(tokens); + + let code = tests.requested_test.block.display_code(); + + assert_not_in!(code, await_argument_code_string("a")); + assert_in!(code, await_argument_code_string("b")); + assert_not_in!(code, await_argument_code_string("c")); + } + + mod two_args_should { + /// Should test matrix tests render without take in account MatrixInfo to RsTestInfo + /// transformation + use super::{assert_eq, *}; + + fn fixture<'a>() -> (Vec<&'a str>, ItemFn, RsTestInfo) { + let names = vec!["first", "second"]; + ( + names.clone(), + format!( + r#"fn test({}: u32, {}: u32) {{ println!("user code") }}"#, + names[0], names[1] + ) + .ast(), + RsTestInfo { + data: RsTestData { + items: vec![ + values_list(names[0], &["1", "2", "3"]).into(), + values_list(names[1], &["1", "2"]).into(), + ], + }, + ..Default::default() + }, + ) + } + + #[test] + fn contain_a_module_for_each_first_arg() { + let (names, item_fn, info) = fixture(); + + let tokens = matrix(item_fn, info); + + let modules = TestsGroup::from(tokens).module.get_modules().names(); + + let expected = (1..=3) + .map(|i| format!("{}_{}", names[0], i)) + .collect::>(); + + assert_eq!(expected.len(), modules.len()); + for (e, m) in expected.into_iter().zip(modules.into_iter()) { + assert_in!(m, e); + } + } + + #[test] + fn annotate_modules_with_allow_non_snake_name() { + let (_, item_fn, info) = fixture(); + let non_snake_case = &attrs("#[allow(non_snake_case)]")[0]; + + let tokens = matrix(item_fn, info); + + let modules = TestsGroup::from(tokens).module.get_modules(); + + for module in modules { + assert!(module.attrs.contains(&non_snake_case)); + } + } + + #[test] + fn create_all_tests() { + let (_, item_fn, info) = fixture(); + + let tokens = matrix(item_fn, info); + + let tests = 
TestsGroup::from(tokens).module.get_all_tests().names(); + + assert_eq!(6, tests.len()); + } + + #[test] + fn create_all_modules_with_the_same_functions() { + let (_, item_fn, info) = fixture(); + + let tokens = matrix(item_fn, info); + + let tests = TestsGroup::from(tokens) + .module + .get_modules() + .into_iter() + .map(|m| m.get_tests().names()) + .collect::>(); + + assert_eq!(tests[0], tests[1]); + assert_eq!(tests[1], tests[2]); + } + + #[test] + fn test_name_should_contain_argument_name() { + let (names, item_fn, info) = fixture(); + + let tokens = matrix(item_fn, info); + + let tests = TestsGroup::from(tokens).module.get_modules()[0] + .get_tests() + .names(); + + let expected = (1..=2) + .map(|i| format!("{}_{}", names[1], i)) + .collect::>(); + + assert_eq!(expected.len(), tests.len()); + for (e, m) in expected.into_iter().zip(tests.into_iter()) { + assert_in!(m, e); + } + } + } + + #[test] + fn three_args_should_create_all_function_4_mods_at_the_first_level_and_3_at_the_second() { + let (first, second, third) = ("first", "second", "third"); + let info = RsTestInfo { + data: RsTestData { + items: vec![ + values_list(first, &["1", "2", "3", "4"]).into(), + values_list(second, &["1", "2", "3"]).into(), + values_list(third, &["1", "2"]).into(), + ], + }, + ..Default::default() + }; + let item_fn = format!( + r#"fn test({}: u32, {}: u32, {}: u32) {{ println!("user code") }}"#, + first, second, third + ) + .ast(); + + let tokens = matrix(item_fn, info); + + let tg = TestsGroup::from(tokens); + + assert_eq!(24, tg.module.get_all_tests().len()); + assert_eq!(4, tg.module.get_modules().len()); + assert_eq!(3, tg.module.get_modules()[0].get_modules().len()); + assert_eq!(3, tg.module.get_modules()[3].get_modules().len()); + assert_eq!( + 2, + tg.module.get_modules()[0].get_modules()[0] + .get_tests() + .len() + ); + assert_eq!( + 2, + tg.module.get_modules()[3].get_modules()[1] + .get_tests() + .len() + ); + } + + #[test] + fn pad_case_index() { + let item_fn: 
ItemFn = + r#"fn test(first: u32, second: u32, third: u32) { println!("user code") }"#.ast(); + let values = (1..=100).map(|i| i.to_string()).collect::>(); + let info = RsTestInfo { + data: RsTestData { + items: vec![ + values_list("first", values.as_ref()).into(), + values_list("second", values[..10].as_ref()).into(), + values_list("third", values[..2].as_ref()).into(), + ], + }, + ..Default::default() + }; + + let tokens = matrix(item_fn.clone(), info); + + let tg = TestsGroup::from(tokens); + + let mods = tg.get_modules().names(); + + assert_in!(mods[0], "first_001"); + assert_in!(mods[99], "first_100"); + + let mods = tg.get_modules()[0].get_modules().names(); + + assert_in!(mods[0], "second_01"); + assert_in!(mods[9], "second_10"); + + let functions = tg.get_modules()[0].get_modules()[1].get_tests().names(); + + assert_in!(functions[0], "third_1"); + assert_in!(functions[1], "third_2"); + } +} + +mod complete_should { + use super::{assert_eq, *}; + + fn rendered_case(fn_name: &str) -> TestsGroup { + let item_fn: ItemFn = format!( + r#" #[first] + #[second(arg)] + fn {}( + fix: u32, + a: f64, b: f32, + x: i32, y: i32) {{}}"#, + fn_name + ) + .ast(); + let data = RsTestData { + items: vec![ + fixture("fix", &["2"]).into(), + ident("a").into(), + ident("b").into(), + vec!["1f64", "2f32"] + .into_iter() + .collect::() + .into(), + TestCase { + description: Some(ident("description")), + ..vec!["3f64", "4f32"].into_iter().collect::() + } + .with_attrs(attrs("#[third]#[forth(other)]")) + .into(), + values_list("x", &["12", "-2"]).into(), + values_list("y", &["-3", "42"]).into(), + ], + }; + + matrix(item_fn.clone(), data.into()).into() + } + + fn test_case() -> TestsGroup { + rendered_case("test_function") + } + + #[test] + fn use_function_name_as_outer_module() { + let rendered = rendered_case("should_be_the_outer_module_name"); + + assert_eq!(rendered.module.ident, "should_be_the_outer_module_name") + } + + #[test] + fn have_one_module_for_each_parametrized_case() 
{ + let rendered = test_case(); + + assert_eq!( + vec!["case_1", "case_2_description"], + rendered + .get_modules() + .iter() + .map(|m| m.ident.to_string()) + .collect::>() + ); + } + + #[test] + fn implement_exactly_8_tests() { + let rendered = test_case(); + + assert_eq!(8, rendered.get_all_tests().len()); + } + + #[test] + fn implement_exactly_4_tests_in_each_module() { + let modules = test_case().module.get_modules(); + + assert_eq!(4, modules[0].get_all_tests().len()); + assert_eq!(4, modules[1].get_all_tests().len()); + } + + #[test] + fn assign_same_case_value_for_each_test() { + let modules = test_case().module.get_modules(); + + for f in modules[0].get_all_tests() { + let assignments = Assignments::collect_assignments(&f); + assert_eq!(assignments.0["a"], expr("1f64")); + assert_eq!(assignments.0["b"], expr("2f32")); + } + + for f in modules[1].get_all_tests() { + let assignments = Assignments::collect_assignments(&f); + assert_eq!(assignments.0["a"], expr("3f64")); + assert_eq!(assignments.0["b"], expr("4f32")); + } + } + + #[test] + fn assign_all_case_combination_in_tests() { + let modules = test_case().module.get_modules(); + + let cases = vec![("12", "-3"), ("12", "42"), ("-2", "-3"), ("-2", "42")]; + for module in modules { + for ((x, y), f) in cases.iter().zip(module.get_all_tests().iter()) { + let assignments = Assignments::collect_assignments(f); + assert_eq!(assignments.0["x"], expr(x)); + assert_eq!(assignments.0["y"], expr(y)); + } + } + } + + #[test] + fn mark_test_with_given_attributes() { + let modules = test_case().module.get_modules(); + let attrs = attrs("#[first]#[second(arg)]"); + + for f in modules[0].get_all_tests() { + let end = f.attrs.len() - 1; + assert_eq!(attrs, &f.attrs[1..end]); + } + for f in modules[1].get_all_tests() { + assert_eq!(attrs, &f.attrs[1..3]); + } + } + #[test] + fn should_add_attributes_given_in_the_test_case() { + let modules = test_case().module.get_modules(); + let attrs = attrs("#[third]#[forth(other)]"); + 
+ for f in modules[1].get_all_tests() { + assert_eq!(attrs, &f.attrs[3..5]); + } + } +} diff --git a/src/render/wrapper.rs b/src/render/wrapper.rs new file mode 100644 index 0000000..a513a7c --- /dev/null +++ b/src/render/wrapper.rs @@ -0,0 +1,19 @@ +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::Ident; + +pub(crate) trait WrapByModule { + fn wrap_by_mod(&self, mod_name: &Ident) -> TokenStream; +} + +impl WrapByModule for T { + fn wrap_by_mod(&self, mod_name: &Ident) -> TokenStream { + quote! { + mod #mod_name { + use super::*; + + #self + } + } + } +} diff --git a/src/resolver.rs b/src/resolver.rs new file mode 100644 index 0000000..5111050 --- /dev/null +++ b/src/resolver.rs @@ -0,0 +1,174 @@ +/// Define `Resolver` trait and implement it on some hashmaps and also define the `Resolver` tuple +/// composition. Provide also some utility functions related to how to create a `Resolver` and +/// resolving render. +/// +use std::borrow::Cow; +use std::collections::HashMap; + +use proc_macro2::Ident; +use syn::{parse_quote, Expr}; + +use crate::parse::Fixture; + +pub(crate) mod fixtures { + use quote::format_ident; + + use super::*; + + pub(crate) fn get<'a>(fixtures: impl Iterator) -> impl Resolver + 'a { + fixtures + .map(|f| (f.name.to_string(), extract_resolve_expression(f))) + .collect::>() + } + + fn extract_resolve_expression(fixture: &Fixture) -> syn::Expr { + let resolve = fixture.resolve.as_ref().unwrap_or(&fixture.name); + let positional = &fixture.positional.0; + let f_name = match positional.len() { + 0 => format_ident!("default"), + l => format_ident!("partial_{}", l), + }; + parse_quote! 
{ #resolve::#f_name(#(#positional), *) } + } + + #[cfg(test)] + mod should { + use super::*; + use crate::test::{assert_eq, *}; + + #[rstest] + #[case(&[], "default()")] + #[case(&["my_expression"], "partial_1(my_expression)")] + #[case(&["first", "other"], "partial_2(first, other)")] + fn resolve_by_use_the_given_name(#[case] args: &[&str], #[case] expected: &str) { + let data = vec![fixture("pippo", args)]; + let resolver = get(data.iter()); + + let resolved = resolver.resolve(&ident("pippo")).unwrap().into_owned(); + + assert_eq!(resolved, format!("pippo::{}", expected).ast()); + } + + #[rstest] + #[case(&[], "default()")] + #[case(&["my_expression"], "partial_1(my_expression)")] + #[case(&["first", "other"], "partial_2(first, other)")] + fn resolve_by_use_the_resolve_field(#[case] args: &[&str], #[case] expected: &str) { + let data = vec![fixture("pippo", args).with_resolve("pluto")]; + let resolver = get(data.iter()); + + let resolved = resolver.resolve(&ident("pippo")).unwrap().into_owned(); + + assert_eq!(resolved, format!("pluto::{}", expected).ast()); + } + } +} + +pub(crate) mod values { + use super::*; + use crate::parse::fixture::ArgumentValue; + + pub(crate) fn get<'a>(values: impl Iterator) -> impl Resolver + 'a { + values + .map(|av| (av.name.to_string(), &av.expr)) + .collect::>() + } + + #[cfg(test)] + mod should { + use super::*; + use crate::test::{assert_eq, *}; + + #[test] + fn resolve_by_use_the_given_name() { + let data = vec![ + arg_value("pippo", "42"), + arg_value("donaldduck", "vec![1,2]"), + ]; + let resolver = get(data.iter()); + + assert_eq!( + resolver.resolve(&ident("pippo")).unwrap().into_owned(), + "42".ast() + ); + assert_eq!( + resolver.resolve(&ident("donaldduck")).unwrap().into_owned(), + "vec![1,2]".ast() + ); + } + } +} + +/// A trait that `resolve` the given ident to expression code to assign the value. 
+pub(crate) trait Resolver { + fn resolve(&self, ident: &Ident) -> Option>; +} + +impl<'a> Resolver for HashMap { + fn resolve(&self, ident: &Ident) -> Option> { + let ident = ident.to_string(); + self.get(&ident).map(|&c| Cow::Borrowed(c)) + } +} + +impl Resolver for HashMap { + fn resolve(&self, ident: &Ident) -> Option> { + let ident = ident.to_string(); + self.get(&ident).map(Cow::Borrowed) + } +} + +impl Resolver for (R1, R2) { + fn resolve(&self, ident: &Ident) -> Option> { + self.0.resolve(ident).or_else(|| self.1.resolve(ident)) + } +} + +impl Resolver for &R { + fn resolve(&self, ident: &Ident) -> Option> { + (*self).resolve(ident) + } +} + +impl Resolver for Box { + fn resolve(&self, ident: &Ident) -> Option> { + (**self).resolve(ident) + } +} + +impl Resolver for (String, Expr) { + fn resolve(&self, ident: &Ident) -> Option> { + if *ident == self.0 { + Some(Cow::Borrowed(&self.1)) + } else { + None + } + } +} + +#[cfg(test)] +mod should { + use super::*; + use crate::test::{assert_eq, *}; + use syn::parse_str; + + #[test] + fn return_the_given_expression() { + let ast = parse_str("fn function(mut foo: String) {}").unwrap(); + let arg = first_arg_ident(&ast); + let expected = expr("bar()"); + let mut resolver = HashMap::new(); + + resolver.insert("foo".to_string(), &expected); + + assert_eq!(expected, (&resolver).resolve(&arg).unwrap().into_owned()) + } + + #[test] + fn return_none_for_unknown_argument() { + let ast = "fn function(mut fix: String) {}".ast(); + let arg = first_arg_ident(&ast); + + assert!(EmptyResolver.resolve(&arg).is_none()) + } +} diff --git a/src/test.rs b/src/test.rs new file mode 100644 index 0000000..c4a25ca --- /dev/null +++ b/src/test.rs @@ -0,0 +1,328 @@ +#![macro_use] + +/// Unit testing utility module. Collect a bunch of functions¯o and impls to simplify unit +/// testing bolilerplate. 
+/// +use std::borrow::Cow; +use std::iter::FromIterator; + +pub(crate) use pretty_assertions::assert_eq; +use proc_macro2::TokenTree; +use quote::quote; +pub(crate) use rstest::{fixture, rstest}; +use syn::{parse::Parse, parse2, parse_quote, parse_str, Error, Expr, Ident, ItemFn, Stmt}; + +use super::*; +use crate::parse::{ + fixture::{FixtureData, FixtureItem}, + rstest::{RsTestData, RsTestItem}, + testcase::TestCase, + vlist::ValueList, + Attribute, Fixture, Positional, +}; +use crate::resolver::Resolver; +use crate::utils::fn_args_idents; +use parse::fixture::ArgumentValue; + +macro_rules! to_args { + ($e:expr) => {{ + $e.iter() + .map(|s| s as &dyn AsRef) + .map(expr) + .collect::>() + }}; +} + +macro_rules! to_exprs { + ($e:expr) => { + $e.iter().map(|s| expr(s)).collect::>() + }; +} + +macro_rules! to_strs { + ($e:expr) => { + $e.iter().map(ToString::to_string).collect::>() + }; +} + +macro_rules! to_idents { + ($e:expr) => { + $e.iter().map(|s| ident(s)).collect::>() + }; +} + +struct Outer(T); +impl Parse for Outer { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let outer: Ident = input.parse()?; + if outer == "outer" { + let content; + let _ = syn::parenthesized!(content in input); + content.parse().map(Outer) + } else { + Err(Error::new(outer.span(), "Expected 'outer'")) + } + } +} + +pub(crate) fn parse_meta>(test_case: S) -> T { + let to_parse = format!( + r#" + #[outer({})] + fn to_parse() {{}} + "#, + test_case.as_ref() + ); + + let item_fn = parse_str::(&to_parse).expect(&format!("Cannot parse '{}'", to_parse)); + + let tokens = quote!( + #item_fn + ); + + let tt = tokens.into_iter().skip(1).next().unwrap(); + + if let TokenTree::Group(g) = tt { + let ts = g.stream(); + parse2::>(ts).unwrap().0 + } else { + panic!("Cannot find group in {:#?}", tt) + } +} + +pub(crate) trait ToAst { + fn ast(self) -> T; +} + +impl ToAst for &str { + fn ast(self) -> T { + parse_str(self).unwrap() + } +} + +impl ToAst for String { + fn ast(self) -> T { 
+ parse_str(&self).unwrap() + } +} + +impl ToAst for proc_macro2::TokenStream { + fn ast(self) -> T { + parse2(self).unwrap() + } +} + +pub(crate) fn ident(s: impl AsRef) -> Ident { + s.as_ref().ast() +} + +pub(crate) fn expr(s: impl AsRef) -> syn::Expr { + s.as_ref().ast() +} + +pub(crate) fn attrs(s: impl AsRef) -> Vec { + parse_str::(&format!( + r#"{} + fn _no_name_() {{}} + "#, + s.as_ref() + )) + .unwrap() + .attrs +} + +pub(crate) fn fixture(name: impl AsRef, args: &[&str]) -> Fixture { + Fixture::new(ident(name), None, Positional(to_exprs!(args))) +} + +pub(crate) fn arg_value(name: impl AsRef, value: impl AsRef) -> ArgumentValue { + ArgumentValue::new(ident(name), expr(value)) +} + +pub(crate) fn values_list>(arg: &str, values: &[S]) -> ValueList { + ValueList { + arg: ident(arg), + values: values.into_iter().map(|s| expr(s)).collect(), + } +} + +pub(crate) fn first_arg_ident(ast: &ItemFn) -> &Ident { + fn_args_idents(&ast).next().unwrap() +} + +pub(crate) fn extract_inner_functions(block: &syn::Block) -> impl Iterator { + block.stmts.iter().filter_map(|s| match s { + syn::Stmt::Item(syn::Item::Fn(f)) => Some(f), + _ => None, + }) +} + +pub(crate) fn literal_expressions_str() -> Vec<&'static str> { + vec![ + "42", + "42isize", + "1.0", + "-1", + "-1.0", + "true", + "1_000_000u64", + "0b10100101u8", + r#""42""#, + "b'H'", + ] +} + +pub(crate) trait ExtractArgs { + fn args(&self) -> Vec; +} + +impl ExtractArgs for TestCase { + fn args(&self) -> Vec { + self.args.iter().cloned().collect() + } +} + +impl ExtractArgs for ValueList { + fn args(&self) -> Vec { + self.values.iter().cloned().collect() + } +} + +impl Attribute { + pub fn attr>(s: S) -> Self { + Attribute::Attr(ident(s)) + } + + pub fn tagged, SA: AsRef>(tag: SI, attrs: Vec) -> Self { + Attribute::Tagged(ident(tag), attrs.into_iter().map(|a| ident(a)).collect()) + } + + pub fn typed, T: AsRef>(tag: S, inner: T) -> Self { + Attribute::Type(ident(tag), parse_str(inner.as_ref()).unwrap()) + } +} + +impl 
RsTestInfo { + pub fn push_case(&mut self, case: TestCase) { + self.data.items.push(RsTestItem::TestCase(case)); + } + + pub fn extend(&mut self, cases: impl Iterator) { + self.data.items.extend(cases.map(RsTestItem::TestCase)); + } +} + +impl Fixture { + pub fn with_resolve(mut self, resolve_ident: &str) -> Self { + self.resolve = Some(ident(resolve_ident)); + self + } +} + +impl TestCase { + pub fn with_description(mut self, description: &str) -> Self { + self.description = Some(ident(description)); + self + } + + pub fn with_attrs(mut self, attrs: Vec) -> Self { + self.attrs = attrs; + self + } +} + +impl> FromIterator for TestCase { + fn from_iter>(iter: T) -> Self { + TestCase { + args: iter.into_iter().map(expr).collect(), + attrs: Default::default(), + description: None, + } + } +} + +impl<'a> From<&'a str> for TestCase { + fn from(argument: &'a str) -> Self { + std::iter::once(argument).collect() + } +} + +impl From> for RsTestData { + fn from(items: Vec) -> Self { + Self { items } + } +} + +impl From for RsTestInfo { + fn from(data: RsTestData) -> Self { + Self { + data, + ..Default::default() + } + } +} + +impl From> for Positional { + fn from(data: Vec) -> Self { + Positional(data) + } +} + +impl From> for FixtureData { + fn from(fixtures: Vec) -> Self { + Self { items: fixtures } + } +} + +pub(crate) struct EmptyResolver; + +impl<'a> Resolver for EmptyResolver { + fn resolve(&self, _ident: &Ident) -> Option> { + None + } +} + +pub(crate) trait IsAwait { + fn is_await(&self) -> bool; +} + +impl IsAwait for Stmt { + fn is_await(&self) -> bool { + match self { + Stmt::Expr(Expr::Await(_)) => true, + _ => false, + } + } +} + +pub(crate) trait DisplayCode { + fn display_code(&self) -> String; +} + +impl DisplayCode for T { + fn display_code(&self) -> String { + self.to_token_stream().to_string() + } +} + +impl crate::parse::fixture::FixtureInfo { + pub(crate) fn with_once(mut self) -> Self { + self.attributes = self.attributes.with_once(); + self + } +} + 
+impl crate::parse::fixture::FixtureModifiers { + pub(crate) fn with_once(mut self) -> Self { + self.append(Attribute::attr("once")); + self + } +} + +pub(crate) fn await_argument_code_string(arg_name: &str) -> String { + let arg_name = ident(arg_name); + let statment: Stmt = parse_quote! { + let #arg_name = #arg_name.await; + }; + statment.display_code() +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..9471c7a --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,390 @@ +/// Contains some unsorted functions used across others modules +/// +use quote::format_ident; +use std::collections::{HashMap, HashSet}; + +use crate::refident::MaybeIdent; +use syn::{Attribute, Expr, FnArg, Generics, Ident, ItemFn, ReturnType, Type, WherePredicate}; + +/// Return an iterator over fn arguments items. +/// +pub(crate) fn fn_args_idents(test: &ItemFn) -> impl Iterator { + fn_args(test).filter_map(MaybeIdent::maybe_ident) +} + +/// Return if function declaration has an ident +/// +pub(crate) fn fn_args_has_ident(fn_decl: &ItemFn, ident: &Ident) -> bool { + fn_args_idents(fn_decl).any(|id| id == ident) +} + +/// Return an iterator over fn arguments. 
+/// +pub(crate) fn fn_args(item_fn: &ItemFn) -> impl Iterator { + item_fn.sig.inputs.iter() +} + +pub(crate) fn attr_ends_with(attr: &Attribute, segment: &syn::PathSegment) -> bool { + attr.path.segments.iter().last() == Some(segment) +} + +pub(crate) fn attr_starts_with(attr: &Attribute, segment: &syn::PathSegment) -> bool { + attr.path.segments.iter().next() == Some(segment) +} + +pub(crate) fn attr_is(attr: &Attribute, name: &str) -> bool { + attr.path.is_ident(&format_ident!("{}", name)) +} + +pub(crate) fn attr_in(attr: &Attribute, names: &[&str]) -> bool { + names + .iter() + .any(|name| attr.path.is_ident(&format_ident!("{}", name))) +} + +pub(crate) trait IsLiteralExpression { + fn is_literal(&self) -> bool; +} + +impl> IsLiteralExpression for E { + fn is_literal(&self) -> bool { + matches!( + self.as_ref(), + Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(_), + .. + }) + ) + } +} + +// Recoursive search id by reference till find one in ends +fn _is_used( + visited: &mut HashSet, + id: &Ident, + references: &HashMap>, + ends: &HashSet, +) -> bool { + if visited.contains(id) { + return false; + } + visited.insert(id.clone()); + if ends.contains(id) { + return true; + } + if references.contains_key(id) { + for refered in references.get(id).unwrap() { + if _is_used(visited, refered, references, ends) { + return true; + } + } + } + false +} + +// Recoursive search id by reference till find one in ends +fn is_used(id: &Ident, references: &HashMap>, ends: &HashSet) -> bool { + let mut visited = Default::default(); + _is_used(&mut visited, id, references, ends) +} + +impl MaybeIdent for syn::WherePredicate { + fn maybe_ident(&self) -> Option<&Ident> { + match self { + WherePredicate::Type(syn::PredicateType { bounded_ty: t, .. }) => { + first_type_path_segment_ident(t) + } + WherePredicate::Lifetime(syn::PredicateLifetime { lifetime, .. 
}) => { + Some(&lifetime.ident) + } + WherePredicate::Eq(_) => None, + } + } +} + +#[derive(Default)] +struct SearchSimpleTypeName(HashSet); + +impl SearchSimpleTypeName { + fn take(self) -> HashSet { + self.0 + } + + fn visit_inputs<'a>(&mut self, inputs: impl Iterator) { + use syn::visit::Visit; + inputs.for_each(|fn_arg| self.visit_fn_arg(fn_arg)); + } + fn visit_output(&mut self, output: &ReturnType) { + use syn::visit::Visit; + self.visit_return_type(output); + } + + fn collect_from_type_param(tp: &syn::TypeParam) -> Self { + let mut s: Self = Default::default(); + use syn::visit::Visit; + s.visit_type_param(tp); + s + } + + fn collect_from_where_predicate(wp: &syn::WherePredicate) -> Self { + let mut s: Self = Default::default(); + use syn::visit::Visit; + s.visit_where_predicate(wp); + s + } +} + +impl<'ast> syn::visit::Visit<'ast> for SearchSimpleTypeName { + fn visit_path(&mut self, p: &'ast syn::Path) { + if let Some(id) = p.get_ident() { + self.0.insert(id.clone()); + } + syn::visit::visit_path(self, p) + } + + fn visit_lifetime(&mut self, i: &'ast syn::Lifetime) { + self.0.insert(i.ident.clone()); + syn::visit::visit_lifetime(self, i) + } +} + +// Take generics definitions and where clauses and return the +// a map from simple types (lifetime names or type with just names) +// to a set of all simple types that use it as some costrain. 
+fn extract_references_map(generics: &Generics) -> HashMap> { + let mut references = HashMap::>::default(); + // Extracts references from types param + generics.type_params().for_each(|tp| { + SearchSimpleTypeName::collect_from_type_param(tp) + .take() + .into_iter() + .for_each(|id| { + references.entry(id).or_default().insert(tp.ident.clone()); + }); + }); + // Extracts references from where clauses + generics + .where_clause + .iter() + .flat_map(|wc| wc.predicates.iter()) + .filter_map(|wp| wp.maybe_ident().map(|id| (id, wp))) + .for_each(|(ref_ident, wp)| { + SearchSimpleTypeName::collect_from_where_predicate(wp) + .take() + .into_iter() + .for_each(|id| { + references.entry(id).or_default().insert(ref_ident.clone()); + }); + }); + references +} + +// Return a hash set that contains all types and lifetimes referenced +// in input/output expressed by a single ident. +fn references_ident_types<'a>( + generics: &Generics, + inputs: impl Iterator, + output: &ReturnType, +) -> HashSet { + let mut used: SearchSimpleTypeName = Default::default(); + used.visit_output(output); + used.visit_inputs(inputs); + let references = extract_references_map(generics); + let mut used = used.take(); + let input_output = used.clone(); + // Extend the input output collected ref with the transitive ones: + used.extend( + generics + .params + .iter() + .filter_map(MaybeIdent::maybe_ident) + .filter(|&id| is_used(id, &references, &input_output)) + .cloned(), + ); + used +} + +fn filtered_predicates(mut wc: syn::WhereClause, valids: &HashSet) -> syn::WhereClause { + wc.predicates = wc + .predicates + .clone() + .into_iter() + .filter(|wp| { + wp.maybe_ident() + .map(|t| valids.contains(t)) + .unwrap_or_default() + }) + .collect(); + wc +} + +fn filtered_generics<'a>( + params: impl Iterator + 'a, + valids: &'a HashSet, +) -> impl Iterator + 'a { + params.filter(move |p| match p.maybe_ident() { + Some(id) => valids.contains(id), + None => false, + }) +} + +//noinspection RsTypeCheck 
+pub(crate) fn generics_clean_up<'a>( + original: &Generics, + inputs: impl Iterator, + output: &ReturnType, +) -> syn::Generics { + let used = references_ident_types(original, inputs, output); + let mut result: Generics = original.clone(); + result.params = filtered_generics(result.params.into_iter(), &used).collect(); + result.where_clause = result.where_clause.map(|wc| filtered_predicates(wc, &used)); + result +} + +// If type is not self and doesn't starts with :: return the first ident +// of its path segment: only if is a simple path. +// If type is a simple ident just return the this ident. That is useful to +// find the base type for associate type indication +fn first_type_path_segment_ident(t: &Type) -> Option<&Ident> { + match t { + Type::Path(tp) if tp.qself.is_none() && tp.path.leading_colon.is_none() => tp + .path + .segments + .iter() + .next() + .and_then(|ps| match ps.arguments { + syn::PathArguments::None => Some(&ps.ident), + _ => None, + }), + _ => None, + } +} + +pub(crate) fn fn_arg_mutability(arg: &FnArg) -> Option { + match arg { + FnArg::Typed(syn::PatType { pat, .. }) => match pat.as_ref() { + syn::Pat::Ident(syn::PatIdent { mutability, .. }) => *mutability, + _ => None, + }, + _ => None, + } +} + +#[cfg(test)] +mod test { + use syn::parse_quote; + + use super::*; + use crate::test::{assert_eq, *}; + + #[test] + fn fn_args_idents_should() { + let item_fn = parse_quote! { + fn the_functon(first: u32, second: u32) {} + }; + + let mut args = fn_args_idents(&item_fn); + + assert_eq!("first", args.next().unwrap().to_string()); + assert_eq!("second", args.next().unwrap().to_string()); + } + + #[test] + fn fn_args_has_ident_should() { + let item_fn = parse_quote! 
{ + fn the_functon(first: u32, second: u32) {} + }; + + assert!(fn_args_has_ident(&item_fn, &ident("first"))); + assert!(!fn_args_has_ident(&item_fn, &ident("third"))); + } + + #[rstest] + #[case::base("fn foo(a: A) -> B {}", &["A", "B"])] + #[case::use_const_in_array("fn foo(a: A) -> [u32; B] {}", &["A", "B", "u32"])] + #[case::in_type_args("fn foo(a: A) -> SomeType {}", &["A", "B"])] + #[case::in_type_args("fn foo(a: SomeType, b: SomeType) {}", &["A", "B"])] + #[case::pointers("fn foo(a: *const A, b: &B) {}", &["A", "B"])] + #[case::lifetime("fn foo<'a, A, B, C>(a: A, b: &'a B) {}", &["a", "A", "B"])] + #[case::transitive_lifetime("fn foo<'a, A, B, C>(a: A, b: B) where B: Iterator + 'a {}", &["a", "A", "B"])] + #[case::associated("fn foo<'a, A:Copy, C>(b: impl Iterator + 'a) {}", &["a", "A"])] + #[case::transitive_in_defs("fn foo>(b: B) {}", &["A", "B"])] + #[case::transitive_in_where("fn foo(b: B) where B: Iterator {}", &["A", "B"])] + #[case::transitive_const("fn foo(b: B) where B: Some {}", &["A", "B"])] + #[case::transitive_lifetime("fn foo<'a, A, B, C>(a: A, b: B) where B: Iterator + 'a {}", &["a", "A", "B"])] + #[case::transitive_lifetime(r#"fn foo<'a, 'b, 'c, 'd, A, B, C> + (a: A, b: B) + where B: Iterator + 'c, + 'c: 'a + 'b {}"#, &["a", "b", "c", "A", "B"])] + fn references_ident_types_should(#[case] f: &str, #[case] expected: &[&str]) { + let f: ItemFn = f.ast(); + let used = references_ident_types(&f.sig.generics, f.sig.inputs.iter(), &f.sig.output); + + let expected = to_idents!(expected) + .into_iter() + .collect::>(); + + assert_eq!(expected, used); + } + + #[rstest] + #[case::remove_not_in_output( + r#"fn test, B, F, H: Iterator>() -> (H, B, String, &str) + where F: ToString, + B: Borrow + {}"#, + r#"fn test>() -> (H, B, String, &str) + where B: Borrow + {}"# + )] + #[case::not_remove_used_in_arguments( + r#"fn test, B, F, H: Iterator> + (h: H, it: impl Iterator, j: &[B]) + where F: ToString, + B: Borrow + {}"#, + r#"fn test, B, H: Iterator> + (h: 
H, it: impl Iterator, j: &[B]) + where + B: Borrow + {}"# + )] + #[case::dont_remove_transitive( + r#"fn test(a: A) where + B: AsRef, + A: Iterator, + D: ArsRef {}"#, + r#"fn test(a: A) where + B: AsRef, + A: Iterator {}"# + )] + #[case::remove_unused_lifetime( + "fn test<'a, 'b, 'c, 'd, 'e, 'f, 'g, A>(a: &'a uint32, b: impl AsRef + 'b) where 'b: 'c + 'd, A: Copy + 'e, 'f: 'g {}", + "fn test<'a, 'b, 'c, 'd, 'e, A>(a: &'a uint32, b: impl AsRef + 'b) where 'b: 'c + 'd, A: Copy + 'e {}" + )] + #[case::remove_unused_const( + r#"fn test + (a: [u32; A], b: SomeType, c: T) where + T: Iterator, + O: AsRef + {}"#, + r#"fn test + (a: [u32; A], b: SomeType, c: T) where + T: Iterator + {}"# + )] + fn generics_cleaner(#[case] code: &str, #[case] expected: &str) { + // Should remove all generics parameters that are not present in output + let item_fn: ItemFn = code.ast(); + + let expected: ItemFn = expected.ast(); + + let cleaned = generics_clean_up( + &item_fn.sig.generics, + item_fn.sig.inputs.iter(), + &item_fn.sig.output, + ); + + assert_eq!(expected.sig.generics, cleaned); + } +} -- 2.7.4