[Rust] Clean up conversions between TVM and Rust functions (#6114)
authorMax Willsey <me@mwillsey.com>
Thu, 23 Jul 2020 08:14:17 +0000 (01:14 -0700)
committerGitHub <noreply@github.com>
Thu, 23 Jul 2020 08:14:17 +0000 (01:14 -0700)
* Replace ToBoxedFn with From

* Compact and improve Typed and ToFunction impls

- Clone one less time
- Don't panic if number of args is wrong, return an error
- Actually drop functions/closures on the rust side

* Retry

rust/tvm-macros/src/external.rs
rust/tvm-rt/README.md
rust/tvm-rt/src/function.rs
rust/tvm-rt/src/lib.rs
rust/tvm-rt/src/to_boxed_fn.rs [deleted file]
rust/tvm-rt/src/to_function.rs

index 2fcee49..4359db9 100644 (file)
@@ -144,7 +144,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
         let wrapper = quote! {
             pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
                 let func_ref: #tvm_rt_crate::Function = #global_name.clone();
-                let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.to_boxed_fn();
+                let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.into();
                 let res: #ret_type = func_ref(#(#args),*)?;
                 Ok(res)
             }
index 7c87939..662687e 100644 (file)
@@ -53,7 +53,7 @@ fn sum(x: i64, y: i64, z: i64) -> i64 {
 fn main() {
     register(sum, "mysum".to_owned()).unwrap();
     let func = Function::get("mysum").unwrap();
-    let boxed_fn = func.to_boxed_fn::<dyn Fn(i64, i64, i64) -> Result<i64>>();
+    let boxed_fn: Box<dyn Fn(i64, i64, i64) -> Result<i64>> = func.into();
     let ret = boxed_fn(10, 20, 30).unwrap();
     assert_eq!(ret, 60);
 }
index 591b5cc..94a20ac 100644 (file)
@@ -25,7 +25,7 @@
 //!
 //! See the tests and examples repository for more examples.
 
-use std::convert::TryFrom;
+use std::convert::{TryFrom, TryInto};
 use std::{
     ffi::CString,
     os::raw::{c_char, c_int},
@@ -34,8 +34,6 @@ use std::{
 
 use crate::errors::Error;
 
-use super::to_boxed_fn::ToBoxedFn;
-
 pub use super::to_function::{ToFunction, Typed};
 pub use tvm_sys::{ffi, ArgValue, RetValue};
 
@@ -94,11 +92,13 @@ impl Function {
         }
     }
 
-    pub fn get_boxed<F: ?Sized, S: AsRef<str>>(name: S) -> Option<Box<F>>
+    pub fn get_boxed<F, S>(name: S) -> Option<Box<F>>
     where
-        F: ToBoxedFn,
+        S: AsRef<str>,
+        F: ?Sized,
+        Self: Into<Box<F>>,
     {
-        Self::get(name).map(|f| f.to_boxed_fn::<F>())
+        Self::get(name).map(|f| f.into())
     }
 
     /// Returns the underlying TVM function handle.
@@ -141,15 +141,31 @@ impl Function {
 
         Ok(rv)
     }
+}
 
-    pub fn to_boxed_fn<F: ?Sized>(self) -> Box<F>
-    where
-        F: ToBoxedFn,
-    {
-        F::to_boxed_fn(self)
-    }
+macro_rules! impl_to_fn {
+    () => { impl_to_fn!(@impl); };
+    ($t:ident, $($ts:ident,)*) => { impl_to_fn!(@impl $t, $($ts,)*); impl_to_fn!($($ts,)*); };
+    (@impl $($t:ident,)*) => {
+        impl<Err, Out, $($t,)*> From<Function> for Box<dyn Fn($($t,)*) -> Result<Out>>
+        where
+            Error: From<Err>,
+            Out: TryFrom<RetValue, Error = Err>,
+            $($t: Into<ArgValue<'static>>),*
+        {
+            fn from(func: Function) -> Self {
+                #[allow(non_snake_case)]
+                Box::new(move |$($t : $t),*| {
+                    let args = vec![ $($t.into()),* ];
+                    Ok(func.invoke(args)?.try_into()?)
+                })
+            }
+        }
+    };
 }
 
+impl_to_fn!(T1, T2, T3, T4, T5, T6,);
+
 impl Clone for Function {
     fn clone(&self) -> Function {
         Self {
@@ -248,7 +264,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {
 ///
 /// register(sum, "mysum".to_owned()).unwrap();
 /// let func = Function::get("mysum").unwrap();
-/// let boxed_fn = func.to_boxed_fn::<dyn Fn(i64, i64, i64) -> Result<i64>>();
+/// let boxed_fn: Box<dyn Fn(i64, i64, i64) -> Result<i64>> = func.into();
 /// let ret = boxed_fn(10, 20, 30).unwrap();
 /// assert_eq!(ret, 60);
 /// ```
index a56a25b..ad4c1ca 100644 (file)
@@ -97,7 +97,6 @@ pub mod errors;
 pub mod function;
 pub mod module;
 pub mod ndarray;
-pub mod to_boxed_fn;
 mod to_function;
 pub mod value;
 
diff --git a/rust/tvm-rt/src/to_boxed_fn.rs b/rust/tvm-rt/src/to_boxed_fn.rs
deleted file mode 100644 (file)
index 8416f2c..0000000
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-//! This module provides a method for converting type erased TVM functions
-//! into a boxed Rust closure.
-//!
-//! To call a registered function check the [`ToBoxedFn::to_boxed_fn`] method.
-//!
-//! See the tests and examples repository for more examples.
-
-pub use tvm_sys::{ffi, ArgValue, RetValue};
-
-use crate::errors;
-
-use super::function::{Function, Result};
-
-pub trait ToBoxedFn {
-    fn to_boxed_fn(func: Function) -> Box<Self>;
-}
-
-use std::convert::{TryFrom, TryInto};
-
-impl<E, O> ToBoxedFn for dyn Fn() -> Result<O>
-where
-    errors::Error: From<E>,
-    O: TryFrom<RetValue, Error = E>,
-{
-    fn to_boxed_fn(func: Function) -> Box<Self> {
-        Box::new(move || {
-            let res = func.invoke(vec![])?;
-            let res = res.try_into()?;
-            Ok(res)
-        })
-    }
-}
-
-impl<E, A, O> ToBoxedFn for dyn Fn(A) -> Result<O>
-where
-    errors::Error: From<E>,
-    A: Into<ArgValue<'static>>,
-    O: TryFrom<RetValue, Error = E>,
-{
-    fn to_boxed_fn(func: Function) -> Box<Self> {
-        Box::new(move |a: A| {
-            let args = vec![a.into()];
-            let res = func.invoke(args)?;
-            let res = res.try_into()?;
-            Ok(res)
-        })
-    }
-}
-
-impl<E, A, B, O> ToBoxedFn for dyn Fn(A, B) -> Result<O>
-where
-    errors::Error: From<E>,
-    A: Into<ArgValue<'static>>,
-    B: Into<ArgValue<'static>>,
-    O: TryFrom<RetValue, Error = E>,
-{
-    fn to_boxed_fn(func: Function) -> Box<Self> {
-        Box::new(move |a: A, b: B| {
-            let args = vec![a.into(), b.into()];
-            let res = func.invoke(args)?;
-            let res = res.try_into()?;
-            Ok(res)
-        })
-    }
-}
-
-impl<E, A, B, C, O> ToBoxedFn for dyn Fn(A, B, C) -> Result<O>
-where
-    errors::Error: From<E>,
-    A: Into<ArgValue<'static>>,
-    B: Into<ArgValue<'static>>,
-    C: Into<ArgValue<'static>>,
-    O: TryFrom<RetValue, Error = E>,
-{
-    fn to_boxed_fn(func: Function) -> Box<Self> {
-        Box::new(move |a: A, b: B, c: C| {
-            let args = vec![a.into(), b.into(), c.into()];
-            let res = func.invoke(args)?;
-            let res = res.try_into()?;
-            Ok(res)
-        })
-    }
-}
-
-impl<E, A, B, C, D, O> ToBoxedFn for dyn Fn(A, B, C, D) -> Result<O>
-where
-    errors::Error: From<E>,
-    A: Into<ArgValue<'static>>,
-    B: Into<ArgValue<'static>>,
-    C: Into<ArgValue<'static>>,
-    D: Into<ArgValue<'static>>,
-    O: TryFrom<RetValue, Error = E>,
-{
-    fn to_boxed_fn(func: Function) -> Box<Self> {
-        Box::new(move |a: A, b: B, c: C, d: D| {
-            let args = vec![a.into(), b.into(), c.into(), d.into()];
-            let res = func.invoke(args)?;
-            let res = res.try_into()?;
-            Ok(res)
-        })
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use crate::function::{self, Function, Result};
-
-    #[test]
-    fn to_boxed_fn0() {
-        fn boxed0() -> i64 {
-            return 10;
-        }
-
-        function::register_override(boxed0, "boxed0".to_owned(), true).unwrap();
-        let func = Function::get("boxed0").unwrap();
-        let typed_func: Box<dyn Fn() -> Result<i64>> = func.to_boxed_fn();
-        assert_eq!(typed_func().unwrap(), 10);
-    }
-}
index 445c99e..a89652b 100644 (file)
@@ -49,85 +49,6 @@ pub trait Typed<I, O> {
     fn ret(o: O) -> Result<RetValue>;
 }
 
-impl<F, O, E> Typed<(), O> for F
-where
-    F: Fn() -> O,
-    Error: From<E>,
-    O: TryInto<RetValue, Error = E>,
-{
-    fn args(_args: Vec<ArgValue<'static>>) -> Result<()> {
-        debug_assert!(_args.len() == 0);
-        Ok(())
-    }
-
-    fn ret(o: O) -> Result<RetValue> {
-        o.try_into().map_err(|e| e.into())
-    }
-}
-
-impl<F, A, O, E1, E2> Typed<(A,), O> for F
-where
-    F: Fn(A) -> O,
-    Error: From<E1>,
-    Error: From<E2>,
-    A: TryFrom<ArgValue<'static>, Error = E1>,
-    O: TryInto<RetValue, Error = E2>,
-{
-    fn args(args: Vec<ArgValue<'static>>) -> Result<(A,)> {
-        debug_assert!(args.len() == 1);
-        let a: A = args[0].clone().try_into()?;
-        Ok((a,))
-    }
-
-    fn ret(o: O) -> Result<RetValue> {
-        o.try_into().map_err(|e| e.into())
-    }
-}
-
-impl<F, A, B, O, E1, E2> Typed<(A, B), O> for F
-where
-    F: Fn(A, B) -> O,
-    Error: From<E1>,
-    Error: From<E2>,
-    A: TryFrom<ArgValue<'static>, Error = E1>,
-    B: TryFrom<ArgValue<'static>, Error = E1>,
-    O: TryInto<RetValue, Error = E2>,
-{
-    fn args(args: Vec<ArgValue<'static>>) -> Result<(A, B)> {
-        debug_assert!(args.len() == 2);
-        let a: A = args[0].clone().try_into()?;
-        let b: B = args[1].clone().try_into()?;
-        Ok((a, b))
-    }
-
-    fn ret(o: O) -> Result<RetValue> {
-        o.try_into().map_err(|e| e.into())
-    }
-}
-
-impl<F, A, B, C, O, E1, E2> Typed<(A, B, C), O> for F
-where
-    F: Fn(A, B, C) -> O,
-    Error: From<E1>,
-    Error: From<E2>,
-    A: TryFrom<ArgValue<'static>, Error = E1>,
-    B: TryFrom<ArgValue<'static>, Error = E1>,
-    C: TryFrom<ArgValue<'static>, Error = E1>,
-    O: TryInto<RetValue, Error = E2>,
-{
-    fn args(args: Vec<ArgValue<'static>>) -> Result<(A, B, C)> {
-        debug_assert!(args.len() == 3);
-        let a: A = args[0].clone().try_into()?;
-        let b: B = args[1].clone().try_into()?;
-        let c: C = args[2].clone().try_into()?;
-        Ok((a, b, c))
-    }
-
-    fn ret(o: O) -> Result<RetValue> {
-        o.try_into().map_err(|e| e.into())
-    }
-}
-
 pub trait ToFunction<I, O>: Sized {
     type Handle;
 
@@ -269,95 +190,100 @@ impl ToFunction<Vec<ArgValue<'static>>, RetValue>
     fn drop(_: *mut Self::Handle) {}
 }
 
-impl<O, F> ToFunction<(), O> for F
-where
-    F: Fn() -> O + 'static,
-{
-    type Handle = Box<dyn Fn() -> O + 'static>;
-
-    fn into_raw(self) -> *mut Self::Handle {
-        let ptr: Box<Self::Handle> = Box::new(Box::new(self));
-        Box::into_raw(ptr)
-    }
+macro_rules! impl_typed_and_to_function {
+    ($len:literal; $($t:ident),*) => {
+        impl<F, Out, $($t),*> Typed<($($t,)*), Out> for F
+        where
+            F: Fn($($t),*) -> Out,
+            Out: TryInto<RetValue>,
+            Error: From<Out::Error>,
+            $( $t: TryFrom<ArgValue<'static>>,
+               Error: From<$t::Error>, )*
+        {
+            #[allow(non_snake_case, unused_variables, unused_mut)]
+            fn args(args: Vec<ArgValue<'static>>) -> Result<($($t,)*)> {
+                if args.len() != $len {
+                    return Err(Error::CallFailed(format!("{} expected {} arguments, got {}.\n",
+                                                         std::any::type_name::<Self>(),
+                                                         $len, args.len())))
+                }
+                let mut args = args.into_iter();
+                $(let $t = args.next().unwrap().try_into()?;)*
+                Ok(($($t,)*))
+            }
 
-    fn call(handle: *mut Self::Handle, _: Vec<ArgValue<'static>>) -> Result<RetValue>
-    where
-        F: Typed<(), O>,
-    {
-        // Ideally we shouldn't need to clone, probably doesn't really matter.
-        let out = unsafe { (*handle)() };
-        F::ret(out)
-    }
+            fn ret(out: Out) -> Result<RetValue> {
+                out.try_into().map_err(|e| e.into())
+            }
+        }
 
-    fn drop(_: *mut Self::Handle) {}
-}
 
-macro_rules! to_function_instance {
-    ($(($param:ident,$index:tt),)+) => {
-        impl<F, $($param,)+ O> ToFunction<($($param,)+), O> for
-        F where F: Fn($($param,)+) -> O + 'static {
-            type Handle = Box<dyn Fn($($param,)+) -> O + 'static>;
+        impl<F, $($t,)* Out> ToFunction<($($t,)*), Out> for F
+        where
+            F: Fn($($t,)*) -> Out + 'static
+        {
+            type Handle = Box<dyn Fn($($t,)*) -> Out + 'static>;
 
             fn into_raw(self) -> *mut Self::Handle {
                 let ptr: Box<Self::Handle> = Box::new(Box::new(self));
                 Box::into_raw(ptr)
             }
 
-            fn call(handle: *mut Self::Handle, args: Vec<ArgValue<'static>>) -> Result<RetValue> where F: Typed<($($param,)+), O> {
-                // Ideally we shouldn't need to clone, probably doesn't really matter.
-                let args = F::args(args)?;
-                let out = unsafe {
-                    (*handle)($(args.$index),+)
-                };
+            #[allow(non_snake_case)]
+            fn call(handle: *mut Self::Handle, args: Vec<ArgValue<'static>>) -> Result<RetValue>
+            where
+                F: Typed<($($t,)*), Out>
+            {
+                let ($($t,)*) = F::args(args)?;
+                let out = unsafe { (*handle)($($t),*) };
                 F::ret(out)
             }
 
-            fn drop(_: *mut Self::Handle) {}
+            fn drop(ptr: *mut Self::Handle) {
+                let bx = unsafe { Box::from_raw(ptr) };
+                std::mem::drop(bx)
+            }
         }
     }
 }
 
-to_function_instance!((A, 0),);
-to_function_instance!((A, 0), (B, 1),);
-to_function_instance!((A, 0), (B, 1), (C, 2),);
-to_function_instance!((A, 0), (B, 1), (C, 2), (D, 3),);
+impl_typed_and_to_function!(0;);
+impl_typed_and_to_function!(1; A);
+impl_typed_and_to_function!(2; A, B);
+impl_typed_and_to_function!(3; A, B, C);
+impl_typed_and_to_function!(4; A, B, C, D);
+impl_typed_and_to_function!(5; A, B, C, D, E);
 
 #[cfg(test)]
 mod tests {
-    use super::{Function, ToFunction, Typed};
-
-    fn zero() -> i32 {
-        10
-    }
+    use super::*;
 
-    fn helper<F, I, O>(f: F) -> Function
+    fn call<F, I, O>(f: F, args: Vec<ArgValue<'static>>) -> Result<RetValue>
     where
         F: ToFunction<I, O>,
         F: Typed<I, O>,
     {
-        f.to_function()
+        F::call(f.into_raw(), args)
     }
 
     #[test]
     fn test_to_function0() {
-        helper(zero);
-    }
-
-    fn one_arg(i: i32) -> i32 {
-        i
-    }
-
-    #[test]
-    fn test_to_function1() {
-        helper(one_arg);
-    }
-
-    fn two_arg(i: i32, j: i32) -> i32 {
-        i + j
+        fn zero() -> i32 {
+            10
+        }
+        let _ = zero.to_function();
+        let good = call(zero, vec![]).unwrap();
+        assert_eq!(i32::try_from(good).unwrap(), 10);
+        let bad = call(zero, vec![1.into()]).unwrap_err();
+        assert!(matches!(bad, Error::CallFailed(..)));
     }
 
     #[test]
     fn test_to_function2() {
-        helper(two_arg);
+        fn two_arg(i: i32, j: i32) -> i32 {
+            i + j
+        }
+        let good = call(two_arg, vec![3.into(), 4.into()]).unwrap();
+        assert_eq!(i32::try_from(good).unwrap(), 7);
     }
 }