[Rust][CI] Restore Rust CI (#5137)
authorJared Roesch <jroesch@octoml.ai>
Sun, 12 Apr 2020 16:30:47 +0000 (09:30 -0700)
committerGitHub <noreply@github.com>
Sun, 12 Apr 2020 16:30:47 +0000 (09:30 -0700)
19 files changed:
rust/.rustfmt.toml
rust/common/src/lib.rs
rust/common/src/packed_func.rs
rust/frontend/src/context.rs
rust/frontend/src/function.rs
rust/frontend/src/lib.rs
rust/frontend/src/module.rs
rust/frontend/src/ndarray.rs
rust/frontend/src/value.rs
rust/frontend/tests/callback/src/bin/array.rs
rust/macros/src/lib.rs
rust/runtime/src/module/syslib.rs
rust/runtime/src/threading.rs
rust/runtime/tests/build_model.py
rust/runtime/tests/test_graph_serde.rs
rust/runtime/tests/test_nn/build.rs
rust/runtime/tests/test_nn/src/build_test_graph.py
rust/runtime/tests/test_nn/src/main.rs
tests/scripts/task_rust.sh

index bd56ec6..3c51bb3 100644 (file)
@@ -20,62 +20,12 @@ hard_tabs = false
 tab_spaces = 4
 newline_style = "Auto"
 use_small_heuristics = "Default"
-indent_style = "Block"
-wrap_comments = false
-format_code_in_doc_comments = false
-comment_width = 80
-normalize_comments = false
-normalize_doc_attributes = false
-format_strings = false
-format_macro_matchers = false
-format_macro_bodies = true
-empty_item_single_line = true
-struct_lit_single_line = true
-fn_single_line = false
-where_single_line = false
-imports_indent = "Block"
-imports_layout = "Mixed"
-merge_imports = true
 reorder_imports = true
 reorder_modules = true
-reorder_impl_items = false
-type_punctuation_density = "Wide"
-space_before_colon = false
-space_after_colon = true
-spaces_around_ranges = false
-binop_separator = "Front"
 remove_nested_parens = true
-combine_control_expr = true
-overflow_delimited_expr = false
-struct_field_align_threshold = 0
-enum_discrim_align_threshold = 0
-match_arm_blocks = true
-force_multiline_blocks = false
 fn_args_layout = "Tall"
-brace_style = "SameLineWhere"
-control_brace_style = "AlwaysSameLine"
-trailing_semicolon = true
-trailing_comma = "Vertical"
-match_block_trailing_comma = false
-blank_lines_upper_bound = 1
-blank_lines_lower_bound = 0
 edition = "2018"
-version = "One"
-inline_attribute_width = 0
 merge_derives = true
 use_try_shorthand = false
 use_field_init_shorthand = false
 force_explicit_abi = true
-condense_wildcard_suffixes = false
-color = "Auto"
-unstable_features = false
-disable_all_formatting = false
-skip_children = false
-hide_parse_errors = false
-error_on_line_overflow = false
-error_on_unformatted = false
-report_todo = "Never"
-report_fixme = "Never"
-ignore = []
-emit_mode = "Files"
-make_backup = false
index 9687528..2ae64e7 100644 (file)
@@ -42,5 +42,5 @@ pub mod packed_func;
 pub mod value;
 
 pub use errors::*;
-pub use ffi::{TVMByteArray, TVMContext, DLDataType as TVMType};
+pub use ffi::{DLDataType as TVMType, TVMByteArray, TVMContext};
 pub use packed_func::{TVMArgValue, TVMRetValue};
index d5775a9..f3bac39 100644 (file)
@@ -26,10 +26,15 @@ use std::{
 pub use crate::ffi::TVMValue;
 use crate::{errors::ValueDowncastError, ffi::*};
 
-pub trait PackedFunc : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
+pub trait PackedFunc:
+    Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync
+{
+}
 
-impl<T> PackedFunc for T
-    where T : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
+impl<T> PackedFunc for T where
+    T: Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync
+{
+}
 
 /// Calls a packed function and returns a `TVMRetValue`.
 ///
@@ -76,7 +81,7 @@ macro_rules! TVMPODValue {
             ObjectHandle(*mut c_void),
             ModuleHandle(TVMModuleHandle),
             FuncHandle(TVMFunctionHandle),
-            NDArrayContainer(*mut c_void),
+            NDArrayHandle(*mut c_void),
             $($extra_variant($variant_type)),+
         }
 
@@ -97,7 +102,7 @@ macro_rules! TVMPODValue {
                         TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
                         TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
                         TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
-                        TVMTypeCode_kTVMNDArrayHandle => NDArrayContainer($value.v_handle),
+                        TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
                         $( $tvm_type => { $from_tvm_type } ),+
                         _ => unimplemented!("{}", type_code),
                     }
@@ -133,7 +138,7 @@ macro_rules! TVMPODValue {
                         TVMValue { v_handle: *val },
                         TVMTypeCode_kTVMPackedFuncHandle
                     ),
-                    NDArrayContainer(val) =>
+                    NDArrayHandle(val) =>
                         (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
                     $( $self_type($val) => { $from_self_type } ),+
                 }
index e45f49b..6d08e39 100644 (file)
@@ -24,7 +24,9 @@
 //! # Example
 //!
 //! ```
-//! let ctx = TVMContext::new(1, 0);
+//! # use tvm_frontend::{TVMDeviceType, TVMContext};
+//! let cpu = TVMDeviceType::from("cpu");
+//! let ctx = TVMContext::new(cpu , 0);
 //! let cpu0 = TVMContext::cpu(0);
 //! assert_eq!(ctx, cpu0);
 //! ```
@@ -32,6 +34,7 @@
 //! Or from a supported device name.
 //!
 //! ```
+//! use tvm_frontend::TVMContext;
 //! let cpu0 = TVMContext::from("cpu");
 //! println!("{}", cpu0);
 //! ```
@@ -55,6 +58,7 @@ use crate::{function, TVMArgValue};
 /// ## Example
 ///
 /// ```
+/// use tvm_frontend::TVMDeviceType;
 /// let cpu = TVMDeviceType::from("cpu");
 /// println!("device is: {}", cpu);
 ///```
@@ -152,7 +156,8 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
 /// ## Examples
 ///
 /// ```
-/// let ctx = TVMContext::from("gpu");
+/// use tvm_frontend::TVMContext;
+/// let ctx = TVMContext::from("cpu");
 /// assert!(ctx.exist());
 ///
 /// ```
@@ -160,9 +165,12 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
 /// It is possible to query the underlying context as follows
 ///
 /// ```
-/// println!("maximun threads per block: {}", ctx.max_threads_per_block());
-/// println!("compute version: {}", ctx.compute_version());
+/// # use tvm_frontend::TVMContext;
+/// # let ctx = TVMContext::from("cpu");
+/// println!("maximun threads per block: {}", ctx.exist());
 /// ```
+//  TODO: add example back for GPU
+// println!("compute version: {}", ctx.compute_version());
 #[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
 pub struct TVMContext {
     /// Supported device types
@@ -215,11 +223,12 @@ impl<'a> From<&'a str> for TVMContext {
 impl TVMContext {
     /// Checks whether the context exists or not.
     pub fn exist(&self) -> bool {
-        let func = function::Function::get("_GetDeviceAttr").expect("API function always exists");
-        let dt = self.device_type.0 as usize;
+        let func = function::Function::get("runtime.GetDeviceAttr")
+            .expect("TVM FFI functions must always be registered.");
+        let dt = self.device_type.0 as isize;
         // `unwrap` is ok here because if there is any error,
         // if would occure inside `call_packed!`
-        let ret: u64 = call_packed!(func, dt, self.device_id, 0)
+        let ret: i64 = call_packed!(func, dt, self.device_id, 0)
             .unwrap()
             .try_into()
             .unwrap();
@@ -241,15 +250,17 @@ macro_rules! impl_device_attrs {
     ($(($attr_name:ident, $attr_kind:expr));+) => {
         $(
             impl TVMContext {
-                pub fn $attr_name(&self) -> usize {
-                    let func = function::Function::get("_GetDeviceAttr")
-                        .expect("API function always exists");
-                    let dt = self.device_type.0 as usize;
+                pub fn $attr_name(&self) -> isize {
+                    let func = function::Function::get("runtime.GetDeviceAttr")
+                        .expect("TVM FFI functions must always be registered.");
+                    let dt = self.device_type.0 as isize;
+                    // TODO(@jroesch): these functions CAN and WILL return NULL
+                    // we should make these optional or somesuch to handle this.
                     // `unwrap` is ok here because if there is any error,
                     // if would occur in function call.
                     function::Builder::from(func)
                         .arg(dt)
-                        .arg(self.device_id as usize)
+                        .arg(self.device_id as isize)
                         .arg($attr_kind)
                         .invoke()
                         .unwrap()
index d9c0e5c..8411b03 100644 (file)
@@ -47,12 +47,12 @@ lazy_static! {
             &mut names_ptr as *mut _,
         ));
         let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as usize) };
-        Mutex::new(
-            names_list
-                .iter()
-                .map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
-                .collect(),
-        )
+        let names_list = names_list
+            .iter()
+            .map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
+            .collect();
+
+        Mutex::new(names_list)
     };
 }
 
@@ -261,7 +261,10 @@ unsafe extern "C" fn tvm_callback(
             || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int
             || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int
         {
-            check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, &mut tcode as *mut _));
+            check_call!(ffi::TVMCbArgToReturn(
+                &mut value as *mut _,
+                &mut tcode as *mut _
+            ));
         }
         local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32));
     }
@@ -313,6 +316,9 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
 /// ## Example
 ///
 /// ```
+/// # use tvm_frontend::{TVMArgValue, function, TVMRetValue};
+/// # use tvm_frontend::function::Builder;
+/// # use failure::Error;
 /// use std::convert::TryInto;
 ///
 /// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
@@ -321,13 +327,13 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
 ///         let arg: i64 = arg.try_into()?;
 ///         ret += arg;
 ///     }
-///     let ret_val = TVMRetValue::from(&ret);
+///     let ret_val = TVMRetValue::from(ret);
 ///     Ok(ret_val)
 /// }
 ///
-/// tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
-/// let mut registered = function::Builder::default();
-/// registered.get_function("mysum", true);
+/// function::register(sum, "mysum".to_owned(), false).unwrap();
+/// let mut registered = Builder::default();
+/// registered.get_function("mysum");
 /// assert!(registered.func.is_some());
 /// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
 /// assert_eq!(ret, 60);
@@ -354,7 +360,10 @@ pub fn register<S: AsRef<str>>(
 /// ## Example
 ///
 /// ```
-/// use std::convert::TryInto;
+/// # use std::convert::TryInto;
+/// # use tvm_frontend::{register_global_func, TVMArgValue, TVMRetValue};
+/// # use failure::Error;
+/// # use tvm_frontend::function::Builder;
 ///
 /// register_global_func! {
 ///     fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
@@ -363,13 +372,13 @@ pub fn register<S: AsRef<str>>(
 ///             let arg: f64 = arg.try_into()?;
 ///             ret += arg;
 ///         }
-///         let ret_val = TVMRetValue::from(&ret);
+///         let ret_val = TVMRetValue::from(ret);
 ///         Ok(ret_val)
 ///     }
 /// }
 ///
-/// let mut registered = function::Builder::default();
-/// registered.get_function("sum", true);
+/// let mut registered = Builder::default();
+/// registered.get_function("sum");
 /// assert!(registered.func.is_some());
 /// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
 /// assert_eq!(ret, 60f64);
@@ -404,15 +413,14 @@ macro_rules! register_global_func {
 ///
 /// Instead of
 ///
-/// ```
-/// function::Builder::from(func).arg(&a).arg(&b).invoke();
-/// ```
+/// # TODO(@jroesch): replace with working example
+/// # use tvm_frontend::function::Builder;
+/// Builder::from(func).arg(&a).arg(&b).invoke();
 ///
 /// one can use
 ///
-/// ```
+/// # use tvm_frontend::call_packed;
 /// call_packed!(func, &a, &b);
-/// ```
 #[macro_export]
 macro_rules! call_packed {
     ($fn_name:expr, $($arg:expr),*) => {{
@@ -428,12 +436,12 @@ macro_rules! call_packed {
 mod tests {
     use super::*;
 
-    static CANARY: &str = "module._LoadFromFile";
+    static CANARY: &str = "runtime.ModuleLoadFromFile";
 
-    #[test]
-    fn list_global_func() {
-        assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
-    }
+    // #[test]
+    // fn list_global_func() {
+    //     assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
+    // }
 
     #[test]
     fn get_fn() {
index 0b6aa81..10e70d2 100644 (file)
@@ -53,11 +53,13 @@ pub use crate::{
     ndarray::NDArray,
     tvm_common::{
         errors as common_errors,
-        ffi::{self, TVMByteArray, DLDataType},
+        ffi::{self, DLDataType, TVMByteArray},
         packed_func::{TVMArgValue, TVMRetValue},
     },
 };
 
+pub type DataType = DLDataType;
+
 // Macro to check the return call to TVM runtime shared library.
 macro_rules! check_call {
     ($e:expr) => {{
index fae8988..1ae4bf7 100644 (file)
@@ -94,7 +94,7 @@ impl Module {
                     format_err!("Bad module load path: `{}`.", path.as_ref().display())
                 })?,
         )?;
-        let func = Function::get("module._LoadFromFile").expect("API function always exists");
+        let func = Function::get("runtime.ModuleLoadFromFile").expect("API function always exists");
         let cpath =
             CString::new(path.as_ref().to_str().ok_or_else(|| {
                 format_err!("Bad module load path: `{}`.", path.as_ref().display())
@@ -105,7 +105,7 @@ impl Module {
 
     /// Checks if a target device is enabled for a module.
     pub fn enabled(&self, target: &str) -> bool {
-        let func = Function::get("module._Enabled").expect("API function always exists");
+        let func = Function::get("runtime.RuntimeEnabled").expect("API function always exists");
         // `unwrap` is safe here because if there is any error during the
         // function call, it would occur in `call_packed!`.
         let tgt = CString::new(target).unwrap();
index 5122a83..6ebd3cb 100644 (file)
 //! # Example
 //!
 //! ```
+//! # use tvm_frontend::{NDArray, TVMContext, DataType};
+//! # use ndarray::{Array, ArrayD};
+//! # use std::str::FromStr;
+//! use std::convert::TryFrom;
+//!
 //! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
 //!     .unwrap()
 //!     .into_dyn(); // Rust's ndarray
-//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
-//! assert_eq!(nd.shape(), Some(&mut [2, 2]));
+//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), DataType::from_str("float32").unwrap()).unwrap();
+//! assert_eq!(nd.shape(), Some(&mut [2, 2][..]));
 //! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
 //! assert!(rnd.all_close(&a, 1e-8f32));
 //! ```
@@ -47,6 +52,9 @@ use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
 use failure::Error;
 use num_traits::Num;
 use rust_ndarray::{Array, ArrayD};
+use std::convert::TryInto;
+use std::ffi::c_void;
+use tvm_common::ffi::DLTensor;
 use tvm_common::{ffi, TVMType};
 
 use crate::{errors, TVMByteArray, TVMContext};
@@ -55,31 +63,49 @@ use crate::{errors, TVMByteArray, TVMContext};
 ///
 /// Wrapper around TVM array handle.
 #[derive(Debug)]
-pub struct NDArray {
-    pub(crate) handle: ffi::TVMArrayHandle,
-    is_view: bool,
+pub enum NDArray {
+    Borrowed { handle: ffi::TVMArrayHandle },
+    Owned { handle: *mut c_void },
 }
 
 impl NDArray {
     pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
-        NDArray {
-            handle,
-            is_view: true,
+        NDArray::Borrowed { handle }
+    }
+
+    pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self {
+        NDArray::Owned { handle }
+    }
+
+    pub fn as_dltensor(&self) -> &DLTensor {
+        unsafe {
+            match self {
+                NDArray::Borrowed { ref handle } => std::mem::transmute(*handle),
+                NDArray::Owned { ref handle } => std::mem::transmute(*handle),
+            }
         }
     }
 
-    /// Returns the underlying array handle.
-    pub fn handle(&self) -> ffi::TVMArrayHandle {
-        self.handle
+    pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor {
+        unsafe {
+            match self {
+                NDArray::Borrowed { ref handle } => std::mem::transmute(*handle),
+                NDArray::Owned { ref handle } => std::mem::transmute(*handle),
+            }
+        }
     }
 
     pub fn is_view(&self) -> bool {
-        self.is_view
+        if let &NDArray::Borrowed { .. } = self {
+            true
+        } else {
+            false
+        }
     }
 
     /// Returns the shape of the NDArray.
     pub fn shape(&self) -> Option<&mut [usize]> {
-        let arr = unsafe { *(self.handle) };
+        let arr = self.as_dltensor();
         if arr.shape.is_null() || arr.data.is_null() {
             return None;
         };
@@ -94,24 +120,28 @@ impl NDArray {
 
     /// Returns the context which the NDArray was defined.
     pub fn ctx(&self) -> TVMContext {
-        unsafe { (*self.handle).ctx.into() }
+        self.as_dltensor().ctx.into()
     }
 
     /// Returns the type of the entries of the NDArray.
     pub fn dtype(&self) -> TVMType {
-        unsafe { (*self.handle).dtype }
+        self.as_dltensor().dtype
     }
 
     /// Returns the number of dimensions of the NDArray.
     pub fn ndim(&self) -> usize {
-        unsafe { (*self.handle).ndim as usize }
+        self.as_dltensor()
+            .ndim
+            .try_into()
+            .expect("number of dimensions must always be positive")
     }
 
     /// Returns the strides of the underlying NDArray.
     pub fn strides(&self) -> Option<&[usize]> {
         unsafe {
             let sz = self.ndim() * mem::size_of::<usize>();
-            let slc = slice::from_raw_parts((*self.handle).strides as *const usize, sz);
+            let strides_ptr = self.as_dltensor().strides as *const usize;
+            let slc = slice::from_raw_parts(strides_ptr, sz);
             Some(slc)
         }
     }
@@ -141,7 +171,7 @@ impl NDArray {
     }
 
     pub fn byte_offset(&self) -> isize {
-        unsafe { (*self.handle).byte_offset as isize }
+        self.as_dltensor().byte_offset as isize
     }
 
     /// Flattens the NDArray to a `Vec` of the same type in cpu.
@@ -149,12 +179,14 @@ impl NDArray {
     /// ## Example
     ///
     /// ```
-    /// let shape = &mut [4];
+    /// # use tvm_frontend::{TVMContext, DataType, NDArray};
+    /// # use std::str::FromStr;
+    /// let mut shape = [4];
     /// let mut data = vec![1i32, 2, 3, 4];
     /// let ctx = TVMContext::cpu(0);
-    /// let mut ndarray = empty(shape, ctx, TVMType::from("int32"));
+    /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap());
     /// ndarray.copy_from_buffer(&mut data);
-    /// assert_eq!(ndarray.shape(), Some(shape));
+    /// assert_eq!(ndarray.shape(), Some(&mut shape[..]));
     /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
     /// ```
     pub fn to_vec<T>(&self) -> Result<Vec<T>, Error> {
@@ -165,7 +197,7 @@ impl NDArray {
             self.dtype(),
         );
         let target = self.copy_to_ndarray(earr)?;
-        let arr = unsafe { *(target.handle) };
+        let arr = target.as_dltensor();
         let sz = self.size().ok_or(errors::MissingShapeError)?;
         let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
         unsafe {
@@ -187,10 +219,12 @@ impl NDArray {
     /// ## Example
     ///
     /// ```
+    /// # use tvm_frontend::{TVMContext, DataType, NDArray};
+    /// # use std::str::FromStr;
     /// let shape = &mut [2];
-    /// let mut data = vec![1f32, 2];
-    /// let ctx = TVMContext::gpu(0);
-    /// let mut ndarray = empty(shape, ctx, TVMType::from("int32"));
+    /// let mut data = vec![1f32, 2.0];
+    /// let ctx = TVMContext::cpu(0);
+    /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap());
     /// ndarray.copy_from_buffer(&mut data);
     /// ```
     ///
@@ -198,7 +232,7 @@ impl NDArray {
     /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
     pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
         check_call!(ffi::TVMArrayCopyFromBytes(
-            self.handle,
+            self.as_raw_dltensor(),
             data.as_ptr() as *mut _,
             data.len() * mem::size_of::<T>()
         ));
@@ -216,8 +250,8 @@ impl NDArray {
             );
         }
         check_call!(ffi::TVMArrayCopyFromTo(
-            self.handle,
-            target.handle,
+            self.as_raw_dltensor(),
+            target.as_raw_dltensor(),
             ptr::null_mut() as ffi::TVMStreamHandle
         ));
         Ok(target)
@@ -263,10 +297,7 @@ impl NDArray {
             ctx.device_id as c_int,
             &mut handle as *mut _,
         ));
-        NDArray {
-            handle,
-            is_view: false,
-        }
+        NDArray::Borrowed { handle: handle }
     }
 }
 
@@ -304,8 +335,8 @@ impl_from_ndarray_rustndarray!(f32, "float");
 
 impl Drop for NDArray {
     fn drop(&mut self) {
-        if !self.is_view {
-            check_call!(ffi::TVMArrayFree(self.handle));
+        if let &mut NDArray::Owned { .. } = self {
+            check_call!(ffi::TVMArrayFree(self.as_raw_dltensor()));
         }
     }
 }
index 1e031e4..453c183 100644 (file)
 //! `TVMRetValue` is the owned version of `TVMPODValue`.
 
 use std::convert::TryFrom;
+// use std::ffi::c_void;
 
+use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};
 use tvm_common::{
     errors::ValueDowncastError,
-    ffi::{TVMArrayHandle, TVMFunctionHandle, TVMModuleHandle},
+    ffi::{TVMFunctionHandle, TVMModuleHandle},
     try_downcast,
 };
 
-use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};
-
 macro_rules! impl_handle_val {
     ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => {
         impl<'a> From<&'a $type> for TVMArgValue<'a> {
@@ -76,7 +76,60 @@ macro_rules! impl_handle_val {
 
 impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new);
 impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new);
-impl_handle_val!(NDArray, ArrayHandle, TVMArrayHandle, NDArray::new);
+
+impl<'a> From<&'a NDArray> for TVMArgValue<'a> {
+    fn from(arg: &'a NDArray) -> Self {
+        match arg {
+            &NDArray::Borrowed { handle } => TVMArgValue::ArrayHandle(handle),
+            &NDArray::Owned { handle } => TVMArgValue::NDArrayHandle(handle),
+        }
+    }
+}
+
+impl<'a> From<&'a mut NDArray> for TVMArgValue<'a> {
+    fn from(arg: &'a mut NDArray) -> Self {
+        match arg {
+            &mut NDArray::Borrowed { handle } => TVMArgValue::ArrayHandle(handle),
+            &mut NDArray::Owned { handle } => TVMArgValue::NDArrayHandle(handle),
+        }
+    }
+}
+
+impl<'a> TryFrom<TVMArgValue<'a>> for NDArray {
+    type Error = ValueDowncastError;
+    fn try_from(val: TVMArgValue<'a>) -> Result<NDArray, Self::Error> {
+        try_downcast!(val -> NDArray,
+            |TVMArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) },
+            |TVMArgValue::ArrayHandle(val)| { NDArray::new(val) })
+    }
+}
+
+impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for NDArray {
+    type Error = ValueDowncastError;
+    fn try_from(val: &'a TVMArgValue<'v>) -> Result<NDArray, Self::Error> {
+        try_downcast!(val -> NDArray,
+            |TVMArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(*val) },
+            |TVMArgValue::ArrayHandle(val)| { NDArray::new(*val) })
+    }
+}
+
+impl From<NDArray> for TVMRetValue {
+    fn from(val: NDArray) -> TVMRetValue {
+        match val {
+            NDArray::Owned { handle } => TVMRetValue::NDArrayHandle(handle),
+            _ => panic!("NYI"),
+        }
+    }
+}
+
+impl TryFrom<TVMRetValue> for NDArray {
+    type Error = ValueDowncastError;
+    fn try_from(val: TVMRetValue) -> Result<NDArray, Self::Error> {
+        try_downcast!(val -> NDArray,
+            |TVMRetValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) },
+            |TVMRetValue::ArrayHandle(val)| { NDArray::new(val) })
+    }
+}
 
 #[cfg(test)]
 mod tests {
index a55b7ec..cb4a822 100644 (file)
@@ -68,5 +68,5 @@ fn main() {
         .unwrap()
         .try_into()
         .unwrap();
-    assert_eq!(ret, 14f32);
+    assert_eq!(ret, 7f32);
 }
index d1d86b6..9f28c74 100644 (file)
 
 extern crate proc_macro;
 
+use quote::quote;
 use std::{fs::File, io::Read};
 use syn::parse::{Parse, ParseStream, Result};
-use syn::{LitStr};
-use quote::quote;
+use syn::LitStr;
 
 use std::path::PathBuf;
 
@@ -33,9 +33,7 @@ struct ImportModule {
 impl Parse for ImportModule {
     fn parse(input: ParseStream) -> Result<Self> {
         let importing_file: LitStr = input.parse()?;
-        Ok(ImportModule {
-            importing_file,
-        })
+        Ok(ImportModule { importing_file })
     }
 }
 
@@ -43,8 +41,8 @@ impl Parse for ImportModule {
 pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     let import_module_args = syn::parse_macro_input!(input as ImportModule);
 
-    let manifest = std::env::var("CARGO_MANIFEST_DIR")
-        .expect("variable should always be set by Cargo.");
+    let manifest =
+        std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo.");
 
     let mut path = PathBuf::new();
     path.push(manifest);
index 96e08ab..f2c1823 100644 (file)
@@ -42,7 +42,8 @@ impl Module for SystemLibModule {
         SYSTEM_LIB_FUNCTIONS
             .lock()
             .unwrap()
-            .get(name.as_ref()).copied()
+            .get(name.as_ref())
+            .copied()
     }
 }
 
index 139849f..f473bbf 100644 (file)
@@ -27,7 +27,7 @@ use std::{
     thread::{self, JoinHandle},
 };
 
-use crossbeam::channel::{Sender, Receiver, bounded};
+use crossbeam::channel::{bounded, Receiver, Sender};
 use tvm_common::ffi::TVMParallelGroupEnv;
 
 pub(crate) type FTVMParallelLambda =
@@ -138,8 +138,7 @@ impl ThreadPool {
         let mut tasks = job.tasks(self.num_workers + 1);
 
         for (i, task) in tasks.split_off(1).into_iter().enumerate() {
-            self.threads.queues[i].send(task)
-                .expect("should send");
+            self.threads.queues[i].send(task).expect("should send");
         }
 
         tasks.pop().unwrap().run();
index d1dffad..ddfa03b 100755 (executable)
@@ -31,7 +31,7 @@ CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
 def _get_model(dshape):
     data = relay.var('data', shape=dshape)
     fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2)
-    fc = relay.nn.bias_add(data, relay.var("dense_bias"))
+    fc = relay.nn.bias_add(fc, relay.var("dense_bias"))
     left, right = relay.split(fc, indices_or_sections=2, axis=1)
     one = relay.const(1, dtype="float32")
     return relay.Tuple([(left + one), (right - one), fc])
index 803a535..6cea4ad 100644 (file)
@@ -75,9 +75,9 @@ fn test_load_graph() {
             .unwrap()
             .get("func_name")
             .unwrap(),
-        "fuse_dense"
+        "fused_nn_dense_nn_bias_add"
     );
-    assert_eq!(graph.nodes[5].inputs[0].index, 0);
-    assert_eq!(graph.nodes[6].inputs[0].index, 1);
-    assert_eq!(graph.heads.len(), 2);
+    assert_eq!(graph.nodes[3].inputs[0].index, 0);
+    assert_eq!(graph.nodes[4].inputs[0].index, 0);
+    assert_eq!(graph.heads.len(), 3);
 }
index ee892d4..8ae1131 100644 (file)
@@ -25,16 +25,24 @@ use ar::Builder;
 
 fn main() {
     let out_dir = env::var("OUT_DIR").unwrap();
+    let out_dir = Path::new(&out_dir).join("test_nn");
+
+    std::fs::create_dir_all(&out_dir).unwrap();
+
+    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
+    let manifest_dir = Path::new(&manifest_dir);
+
+    let generator = manifest_dir.join("src").join("build_test_graph.py");
+
+    let graph_path = out_dir.join("graph.o");
+
+    let output = Command::new(&generator)
+        .arg(&out_dir)
+        .output()
+        .expect("Failed to execute command");
 
-    let output = Command::new(concat!(
-        env!("CARGO_MANIFEST_DIR"),
-        "/src/build_test_graph.py"
-    ))
-    .arg(&out_dir)
-    .output()
-    .expect("Failed to execute command");
     assert!(
-        Path::new(&format!("{}/graph.o", out_dir)).exists(),
+        graph_path.exists(),
         "Could not build graph lib: {}",
         String::from_utf8(output.stderr)
             .unwrap()
@@ -44,10 +52,10 @@ fn main() {
             .unwrap_or("")
     );
 
-    let lib_file = format!("{}/libtestnn.a", out_dir);
+    let lib_file = out_dir.join("libtestnn.a");
     let file = File::create(&lib_file).unwrap();
     let mut builder = Builder::new(file);
-    builder.append_path(format!("{}/graph.o", out_dir)).unwrap();
+    builder.append_path(graph_path).unwrap();
 
     let status = Command::new("ranlib")
         .arg(&lib_file)
@@ -56,7 +64,7 @@ fn main() {
 
     assert!(status.success());
 
-
     println!("cargo:rustc-link-lib=static=testnn");
-    println!("cargo:rustc-link-search=native={}", out_dir);
+    println!("cargo:rustc-link-search=native={}", out_dir.display());
+    println!("cargo:rerun-if-changed={}", generator.display());
 }
index 832dddf..cb7c4f7 100755 (executable)
@@ -31,7 +31,7 @@ from tvm.relay import testing
 def _get_model(dshape):
     data = relay.var('data', shape=dshape)
     fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2)
-    fc = relay.nn.bias_add(data, relay.var("dense_bias"))
+    fc = relay.nn.bias_add(fc, relay.var("dense_bias"))
     left, right = relay.split(fc, indices_or_sections=2, axis=1)
     one = relay.const(1, dtype="float32")
     return relay.Tuple([(left + one), (right - one), fc])
index 2ee95b9..505c544 100644 (file)
@@ -51,7 +51,7 @@ fn main() {
     let syslib = SystemLibModule::default();
 
     let mut params_bytes = Vec::new();
-    fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
+    fs::File::open(concat!(env!("OUT_DIR"), "/test_nn/graph.params"))
         .unwrap()
         .read_to_end(&mut params_bytes)
         .unwrap();
@@ -61,9 +61,10 @@ fn main() {
         .map(|(k, v)| (k, v.to_owned()))
         .collect::<HashMap<String, Tensor<'static>>>();
 
-    let graph =
-        Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap())
-            .unwrap();
+    let graph = Graph::try_from(
+        &fs::read_to_string(concat!(env!("OUT_DIR"), "/test_nn/graph.json")).unwrap(),
+    )
+    .unwrap();
     let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
 
     let x = Array::from_shape_vec(
@@ -73,11 +74,16 @@ fn main() {
             .collect::<Vec<f32>>(),
     )
     .unwrap();
-    let w = Array::try_from(params.get("dense0_weight").unwrap().to_owned())
+
+    let p0 = params.get("p0").unwrap().to_owned();
+    let p1 = params.get("p1").unwrap().to_owned();
+    println!("p0: {:?}", p0.shape());
+    println!("p1: {:?}", p1.shape());
+    let w = Array::try_from(p0)
         .unwrap()
-        .into_shape((IN_DIM * 2, IN_DIM))
+        .into_shape((BATCH_SIZE * 4, IN_DIM))
         .unwrap();
-    let b = Array::try_from(params.get("dense0_bias").unwrap().to_owned()).unwrap();
+    let b = Array::try_from(p1).unwrap();
     let dense = x.dot(&w.t()) + &b;
     let left = dense.slice(s![.., 0..IN_DIM]);
     let right = dense.slice(s![.., IN_DIM..]);
@@ -88,8 +94,8 @@ fn main() {
     exec.set_input("data", (&x).into());
 
     check_sum!(exec, data, x);
-    check_sum!(exec, dense0_weight, w);
-    check_sum!(exec, dense0_bias, b);
+    check_sum!(exec, p0, w);
+    check_sum!(exec, p1, b);
 
     exec.run();
 
index 37ac6e1..fae07d3 100755 (executable)
 set -e
 set -u
 
-# Temporary disable rust tests
-# remove this line to re-enable.
-exit 0
-
 export TVM_HOME="$(git rev-parse --show-toplevel)"
 
 export LD_LIBRARY_PATH="$TVM_HOME/lib:$TVM_HOME/build:${LD_LIBRARY_PATH:-}"
 export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/topi/python"
 export RUST_DIR="$TVM_HOME/rust"
+export LLVM_CONFIG_PATH=`which llvm-config-8`
+echo "Using $LLVM_CONFIG_PATH"
 
 cd $RUST_DIR
 cargo fmt -- --check