[Rust] Unify types between bindings and pure Rust impl (#2616)
authorNick Hynes <nhynes@berkeley.edu>
Wed, 3 Apr 2019 00:24:21 +0000 (17:24 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Wed, 3 Apr 2019 00:24:21 +0000 (17:24 -0700)
47 files changed:
rust/.gitignore [new file with mode: 0644]
rust/common/.gitignore [deleted file]
rust/common/Cargo.toml
rust/common/build.rs [new file with mode: 0644]
rust/common/src/array.rs [new file with mode: 0644]
rust/common/src/c_runtime_api.rs [deleted file]
rust/common/src/errors.rs
rust/common/src/lib.rs
rust/common/src/packed_func.rs [new file with mode: 0644]
rust/common/src/ty.rs
rust/common/src/value.rs
rust/common/tvm-sys/Cargo.toml [deleted file]
rust/common/tvm-sys/build.rs [deleted file]
rust/common/tvm-sys/src/lib.rs [deleted file]
rust/frontend/Cargo.toml
rust/frontend/examples/resnet/src/main.rs
rust/frontend/src/bytearray.rs
rust/frontend/src/context.rs
rust/frontend/src/errors.rs
rust/frontend/src/function.rs
rust/frontend/src/lib.rs
rust/frontend/src/module.rs
rust/frontend/src/ndarray.rs
rust/frontend/src/ty.rs [deleted file]
rust/frontend/src/value.rs
rust/frontend/tests/basics/src/main.rs
rust/frontend/tests/callback/src/bin/array.rs
rust/frontend/tests/callback/src/bin/error.rs
rust/frontend/tests/callback/src/bin/float.rs
rust/frontend/tests/callback/src/bin/int.rs
rust/frontend/tests/callback/src/bin/string.rs
rust/runtime/.gitignore [deleted file]
rust/runtime/Cargo.toml
rust/runtime/src/allocator.rs
rust/runtime/src/array.rs
rust/runtime/src/errors.rs
rust/runtime/src/graph.rs
rust/runtime/src/lib.rs
rust/runtime/src/module.rs
rust/runtime/src/packed_func.rs [deleted file]
rust/runtime/src/sgx.rs
rust/runtime/src/threading.rs
rust/runtime/src/workspace.rs
rust/runtime/tests/test_nnvm/Cargo.toml
rust/runtime/tests/test_tvm_basic/Cargo.toml
rust/runtime/tests/test_tvm_basic/src/main.rs
tests/scripts/task_rust.sh

diff --git a/rust/.gitignore b/rust/.gitignore
new file mode 100644 (file)
index 0000000..0cc6606
--- /dev/null
@@ -0,0 +1,4 @@
+target/
+*.rs.bk
+Cargo.lock
+c_runtime_api.rs
diff --git a/rust/common/.gitignore b/rust/common/.gitignore
deleted file mode 100644 (file)
index 84c2ae9..0000000
+++ /dev/null
@@ -1,4 +0,0 @@
-target
-**/*.rs.bk
-Cargo.lock
-/tvm-sys/src/bindgen.rs
index bcba5ad..5d21ee5 100644 (file)
@@ -5,9 +5,11 @@ authors = ["TVM Contributors"]
 license = "Apache-2.0"
 
 [features]
-runtime = []
-frontend = ["tvm-sys"]
+bindings = []
 
 [dependencies]
-error-chain = { version = "0.12.0", default-features = false }
-tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
+failure = "0.1.5"
+ndarray = "0.12.1"
+
+[build-dependencies]
+bindgen = "0.37.4"
diff --git a/rust/common/build.rs b/rust/common/build.rs
new file mode 100644 (file)
index 0000000..f07e71f
--- /dev/null
@@ -0,0 +1,31 @@
+extern crate bindgen;
+
+use std::path::PathBuf;
+
+fn main() {
+    if cfg!(feature = "bindings") {
+        println!("cargo:rerun-if-env-changed=TVM_HOME");
+        println!("cargo:rustc-link-lib=dylib=tvm_runtime");
+        println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
+    }
+
+    // @see rust-bindgen#550 for `blacklist_type`
+    bindgen::Builder::default()
+        .header(format!(
+            "{}/include/tvm/runtime/c_runtime_api.h",
+            env!("TVM_HOME")
+        ))
+        .header(format!(
+            "{}/include/tvm/runtime/c_backend_api.h",
+            env!("TVM_HOME")
+        ))
+        .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
+        .blacklist_type("max_align_t")
+        .layout_tests(false)
+        .derive_partialeq(true)
+        .derive_eq(true)
+        .generate()
+        .expect("unable to generate bindings")
+        .write_to_file(PathBuf::from("src/c_runtime_api.rs"))
+        .expect("can not write the bindings!");
+}
diff --git a/rust/common/src/array.rs b/rust/common/src/array.rs
new file mode 100644 (file)
index 0000000..e7b7585
--- /dev/null
@@ -0,0 +1,128 @@
+use std::{
+    any::TypeId,
+    mem,
+    os::raw::{c_int, c_void},
+};
+
+use crate::ffi::{
+    DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
+    DLDeviceType_kDLCPU, DLTensor,
+};
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct DataType {
+    pub code: usize,
+    pub bits: usize,
+    pub lanes: usize,
+}
+
+impl DataType {
+    /// Returns the number of bytes occupied by an element of this `DataType`.
+    pub fn itemsize(&self) -> usize {
+        (self.bits * self.lanes) >> 3
+    }
+
+    /// Returns whether this `DataType` represents primitive type `T`.
+    pub fn is_type<T: 'static>(&self) -> bool {
+        if self.lanes != 1 {
+            return false;
+        }
+        let typ = TypeId::of::<T>();
+        (typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
+            || (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
+            || (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
+            || (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
+            || (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
+            || (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
+    }
+
+    pub fn code(&self) -> usize {
+        self.code
+    }
+
+    pub fn bits(&self) -> usize {
+        self.bits
+    }
+
+    pub fn lanes(&self) -> usize {
+        self.lanes
+    }
+}
+
+impl<'a> From<&'a DataType> for DLDataType {
+    fn from(dtype: &'a DataType) -> Self {
+        Self {
+            code: dtype.code as u8,
+            bits: dtype.bits as u8,
+            lanes: dtype.lanes as u16,
+        }
+    }
+}
+
+impl From<DLDataType> for DataType {
+    fn from(dtype: DLDataType) -> Self {
+        Self {
+            code: dtype.code as usize,
+            bits: dtype.bits as usize,
+            lanes: dtype.lanes as usize,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct TVMContext {
+    pub device_type: usize,
+    pub device_id: usize,
+}
+
+impl<'a> From<&'a TVMContext> for DLContext {
+    fn from(ctx: &'a TVMContext) -> Self {
+        Self {
+            device_type: ctx.device_type as u32,
+            device_id: ctx.device_id as i32,
+        }
+    }
+}
+
+impl Default for TVMContext {
+    fn default() -> Self {
+        Self {
+            device_type: DLDeviceType_kDLCPU as usize,
+            device_id: 0,
+        }
+    }
+}
+
+/// `From` conversions to `DLTensor` for `ndarray::Array`.
+/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
+macro_rules! impl_dltensor_from_ndarray {
+    ($type:ty, $typecode:expr) => {
+        impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
+            fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
+                DLTensor {
+                    data: arr.as_mut_ptr() as *mut c_void,
+                    ctx: DLContext {
+                        device_type: DLDeviceType_kDLCPU,
+                        device_id: 0,
+                    },
+                    ndim: arr.ndim() as c_int,
+                    dtype: DLDataType {
+                        code: $typecode as u8,
+                        bits: 8 * mem::size_of::<$type>() as u8,
+                        lanes: 1,
+                    },
+                    shape: arr.shape().as_ptr() as *const i64 as *mut i64,
+                    strides: arr.strides().as_ptr() as *const isize as *mut i64,
+                    byte_offset: 0,
+                }
+            }
+        }
+    };
+}
+
+impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
+impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
+impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
+impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
+impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
+impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
diff --git a/rust/common/src/c_runtime_api.rs b/rust/common/src/c_runtime_api.rs
deleted file mode 100644 (file)
index 6facf9c..0000000
+++ /dev/null
@@ -1,770 +0,0 @@
-/* automatically generated by rust-bindgen for TVM revision 6292c78 */
-
-pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0";
-pub const DLPACK_VERSION: u32 = 8;
-pub const _STDINT_H: u32 = 1;
-pub const _FEATURES_H: u32 = 1;
-pub const _DEFAULT_SOURCE: u32 = 1;
-pub const __USE_ISOC11: u32 = 1;
-pub const __USE_ISOC99: u32 = 1;
-pub const __USE_ISOC95: u32 = 1;
-pub const __USE_POSIX_IMPLICITLY: u32 = 1;
-pub const _POSIX_SOURCE: u32 = 1;
-pub const _POSIX_C_SOURCE: u32 = 200809;
-pub const __USE_POSIX: u32 = 1;
-pub const __USE_POSIX2: u32 = 1;
-pub const __USE_POSIX199309: u32 = 1;
-pub const __USE_POSIX199506: u32 = 1;
-pub const __USE_XOPEN2K: u32 = 1;
-pub const __USE_XOPEN2K8: u32 = 1;
-pub const _ATFILE_SOURCE: u32 = 1;
-pub const __USE_MISC: u32 = 1;
-pub const __USE_ATFILE: u32 = 1;
-pub const __USE_FORTIFY_LEVEL: u32 = 0;
-pub const _STDC_PREDEF_H: u32 = 1;
-pub const __STDC_IEC_559__: u32 = 1;
-pub const __STDC_IEC_559_COMPLEX__: u32 = 1;
-pub const __STDC_ISO_10646__: u32 = 201505;
-pub const __STDC_NO_THREADS__: u32 = 1;
-pub const __GNU_LIBRARY__: u32 = 6;
-pub const __GLIBC__: u32 = 2;
-pub const __GLIBC_MINOR__: u32 = 23;
-pub const _SYS_CDEFS_H: u32 = 1;
-pub const __WORDSIZE: u32 = 64;
-pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1;
-pub const __SYSCALL_WORDSIZE: u32 = 64;
-pub const _BITS_WCHAR_H: u32 = 1;
-pub const INT8_MIN: i32 = -128;
-pub const INT16_MIN: i32 = -32768;
-pub const INT32_MIN: i32 = -2147483648;
-pub const INT8_MAX: u32 = 127;
-pub const INT16_MAX: u32 = 32767;
-pub const INT32_MAX: u32 = 2147483647;
-pub const UINT8_MAX: u32 = 255;
-pub const UINT16_MAX: u32 = 65535;
-pub const UINT32_MAX: u32 = 4294967295;
-pub const INT_LEAST8_MIN: i32 = -128;
-pub const INT_LEAST16_MIN: i32 = -32768;
-pub const INT_LEAST32_MIN: i32 = -2147483648;
-pub const INT_LEAST8_MAX: u32 = 127;
-pub const INT_LEAST16_MAX: u32 = 32767;
-pub const INT_LEAST32_MAX: u32 = 2147483647;
-pub const UINT_LEAST8_MAX: u32 = 255;
-pub const UINT_LEAST16_MAX: u32 = 65535;
-pub const UINT_LEAST32_MAX: u32 = 4294967295;
-pub const INT_FAST8_MIN: i32 = -128;
-pub const INT_FAST16_MIN: i64 = -9223372036854775808;
-pub const INT_FAST32_MIN: i64 = -9223372036854775808;
-pub const INT_FAST8_MAX: u32 = 127;
-pub const INT_FAST16_MAX: u64 = 9223372036854775807;
-pub const INT_FAST32_MAX: u64 = 9223372036854775807;
-pub const UINT_FAST8_MAX: u32 = 255;
-pub const UINT_FAST16_MAX: i32 = -1;
-pub const UINT_FAST32_MAX: i32 = -1;
-pub const INTPTR_MIN: i64 = -9223372036854775808;
-pub const INTPTR_MAX: u64 = 9223372036854775807;
-pub const UINTPTR_MAX: i32 = -1;
-pub const PTRDIFF_MIN: i64 = -9223372036854775808;
-pub const PTRDIFF_MAX: u64 = 9223372036854775807;
-pub const SIG_ATOMIC_MIN: i32 = -2147483648;
-pub const SIG_ATOMIC_MAX: u32 = 2147483647;
-pub const SIZE_MAX: i32 = -1;
-pub const WINT_MIN: u32 = 0;
-pub const WINT_MAX: u32 = 4294967295;
-pub type int_least8_t = ::std::os::raw::c_schar;
-pub type int_least16_t = ::std::os::raw::c_short;
-pub type int_least32_t = ::std::os::raw::c_int;
-pub type int_least64_t = ::std::os::raw::c_long;
-pub type uint_least8_t = ::std::os::raw::c_uchar;
-pub type uint_least16_t = ::std::os::raw::c_ushort;
-pub type uint_least32_t = ::std::os::raw::c_uint;
-pub type uint_least64_t = ::std::os::raw::c_ulong;
-pub type int_fast8_t = ::std::os::raw::c_schar;
-pub type int_fast16_t = ::std::os::raw::c_long;
-pub type int_fast32_t = ::std::os::raw::c_long;
-pub type int_fast64_t = ::std::os::raw::c_long;
-pub type uint_fast8_t = ::std::os::raw::c_uchar;
-pub type uint_fast16_t = ::std::os::raw::c_ulong;
-pub type uint_fast32_t = ::std::os::raw::c_ulong;
-pub type uint_fast64_t = ::std::os::raw::c_ulong;
-pub type intmax_t = ::std::os::raw::c_long;
-pub type uintmax_t = ::std::os::raw::c_ulong;
-pub type wchar_t = ::std::os::raw::c_int;
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct max_align_t {
-  pub __clang_max_align_nonce1: ::std::os::raw::c_longlong,
-  pub __bindgen_padding_0: u64,
-  pub __clang_max_align_nonce2: f64,
-}
-pub const DLDeviceType_kDLCPU: DLDeviceType = 1;
-pub const DLDeviceType_kDLGPU: DLDeviceType = 2;
-pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3;
-pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4;
-pub const DLDeviceType_kDLMetal: DLDeviceType = 8;
-pub const DLDeviceType_kDLVPI: DLDeviceType = 9;
-pub const DLDeviceType_kDLROCM: DLDeviceType = 10;
-/// \brief The device type in DLContext.
-pub type DLDeviceType = u32;
-/// \brief A Device context for Tensor and operator.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLContext {
-  /// \brief The device type used in the device.
-  pub device_type: DLDeviceType,
-  /// \brief The device index
-  pub device_id: ::std::os::raw::c_int,
-}
-pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0;
-pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1;
-pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2;
-/// \brief The type code options DLDataType.
-pub type DLDataTypeCode = u32;
-/// \brief The data type the tensor can hold.
-///
-/// Examples
-/// - float: type_code = 2, bits = 32, lanes=1
-/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
-/// - int8: type_code = 0, bits = 8, lanes=1
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLDataType {
-  /// \brief Type code of base types.
-  /// We keep it uint8_t instead of DLDataTypeCode for minimal memory
-  /// footprint, but the value should be one of DLDataTypeCode enum values.
-  ///
-  pub code: u8,
-  /// \brief Number of bits, common choices are 8, 16, 32.
-  pub bits: u8,
-  /// \brief Number of lanes in the type, used for vector types.
-  pub lanes: u16,
-}
-/// \brief Plain C Tensor object, does not manage memory.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLTensor {
-  /// \brief The opaque data pointer points to the allocated data.
-  /// This will be CUDA device pointer or cl_mem handle in OpenCL.
-  /// This pointer is always aligns to 256 bytes as in CUDA.
-  pub data: *mut ::std::os::raw::c_void,
-  /// \brief The device context of the tensor
-  pub ctx: DLContext,
-  /// \brief Number of dimensions
-  pub ndim: ::std::os::raw::c_int,
-  /// \brief The data type of the pointer
-  pub dtype: DLDataType,
-  /// \brief The shape of the tensor
-  pub shape: *mut i64,
-  /// \brief strides of the tensor,
-  /// can be NULL, indicating tensor is compact.
-  pub strides: *mut i64,
-  /// \brief The offset in bytes to the beginning pointer to data
-  pub byte_offset: u64,
-}
-/// \brief C Tensor object, manage memory of DLTensor. This data structure is
-/// intended to faciliate the borrowing of DLTensor by another framework. It is
-/// not meant to transfer the tensor. When the borrowing framework doesn't need
-/// the tensor, it should call the deleter to notify the host that the resource
-/// is no longer needed.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLManagedTensor {
-  /// \brief DLTensor which is being memory managed
-  pub dl_tensor: DLTensor,
-  /// \brief the context of the original host framework of DLManagedTensor in
-  /// which DLManagedTensor is used in the framework. It can also be NULL.
-  pub manager_ctx: *mut ::std::os::raw::c_void,
-  /// \brief Destructor signature void (*)(void*) - this should be called
-  /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
-  /// if there is no way for the caller to provide a reasonable destructor.
-  pub deleter: ::std::option::Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)>,
-}
-/// \brief type of array index.
-pub type tvm_index_t = i64;
-pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5;
-pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6;
-pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7;
-pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11;
-pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12;
-/// \brief Extension device types in TVM
-pub type TVMDeviceExtType = u32;
-pub const TVMTypeCode_kHandle: TVMTypeCode = 3;
-pub const TVMTypeCode_kNull: TVMTypeCode = 4;
-pub const TVMTypeCode_kTVMType: TVMTypeCode = 5;
-pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6;
-pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7;
-pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8;
-pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9;
-pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10;
-pub const TVMTypeCode_kStr: TVMTypeCode = 11;
-pub const TVMTypeCode_kBytes: TVMTypeCode = 12;
-pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13;
-pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15;
-pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16;
-pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20;
-pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64;
-pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128;
-/// \brief The type code in TVMType
-/// \note TVMType is used in two places.
-pub type TVMTypeCode = u32;
-/// \brief The data type used in TVM Runtime.
-///
-/// Examples
-/// - float: type_code = 2, bits = 32, lanes=1
-/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
-/// - int8: type_code = 0, bits = 8, lanes=1
-///
-/// \note Arguments TVM API function always takes bits=64 and lanes=1
-pub type TVMType = DLDataType;
-/// \brief The Device information, abstract away common device types.
-pub type TVMContext = DLContext;
-/// \brief The tensor array stucture to TVM API.
-pub type TVMArray = DLTensor;
-/// \brief the array handle
-pub type TVMArrayHandle = *mut TVMArray;
-/// \brief Union type of values
-/// being passed through API and function calls.
-#[repr(C)]
-#[derive(Copy, Clone)]
-pub union TVMValue {
-  pub v_int64: i64,
-  pub v_float64: f64,
-  pub v_handle: *mut ::std::os::raw::c_void,
-  pub v_str: *const ::std::os::raw::c_char,
-  pub v_type: TVMType,
-  pub v_ctx: TVMContext,
-  _bindgen_union_align: u64,
-}
-/// \brief Byte array type used to pass in byte array
-/// When kBytes is used as data type.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct TVMByteArray {
-  pub data: *const ::std::os::raw::c_char,
-  pub size: usize,
-}
-/// \brief Handle to TVM runtime modules.
-pub type TVMModuleHandle = *mut ::std::os::raw::c_void;
-/// \brief Handle to packed function handle.
-pub type TVMFunctionHandle = *mut ::std::os::raw::c_void;
-/// \brief Handle to hold return value.
-pub type TVMRetValueHandle = *mut ::std::os::raw::c_void;
-/// \brief The stream that is specific to device
-/// can be NULL, which indicates the default one.
-pub type TVMStreamHandle = *mut ::std::os::raw::c_void;
-extern "C" {
-  /// \brief Used for implementing C API function.
-  /// Set last error message before return.
-  /// \param msg The error message to be set.
-  pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char);
-}
-extern "C" {
-  /// \brief return str message of the last error
-  /// all function in this file will return 0 when success
-  /// and -1 when an error occured,
-  /// TVMGetLastError can be called to retrieve the error
-  ///
-  /// this function is threadsafe and can be called by different thread
-  /// \return error info
-  pub fn TVMGetLastError() -> *const ::std::os::raw::c_char;
-}
-extern "C" {
-  /// \brief Load module from file.
-  /// \param file_name The file name to load the module from.
-  /// \param format The format of the module.
-  /// \param out The result module
-  ///
-  /// \return 0 when success, -1 when failure happens
-  /// \note The resulting module do not contain import relation.
-  /// It can be reconstructed by TVMModImport.
-  pub fn TVMModLoadFromFile(
-    file_name: *const ::std::os::raw::c_char,
-    format: *const ::std::os::raw::c_char,
-    out: *mut TVMModuleHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Add dep to mod's dependency.
-  /// This allows functions in this module to use modules.
-  ///
-  /// \param mod The module handle.
-  /// \param dep The dependent module to be imported.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Get function from the module.
-  /// \param mod The module handle.
-  /// \param func_name The name of the function.
-  /// \param query_imports Whether to query imported modules
-  /// \param out The result function, can be NULL if it is not available.
-  /// \return 0 when no error is thrown, -1 when failure happens
-  pub fn TVMModGetFunction(
-    mod_: TVMModuleHandle,
-    func_name: *const ::std::os::raw::c_char,
-    query_imports: ::std::os::raw::c_int,
-    out: *mut TVMFunctionHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Free front-end extension type resource.
-  /// \param handle The extension handle.
-  /// \param type_code The type of of the extension type.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMExtTypeFree(
-    handle: *mut ::std::os::raw::c_void,
-    type_code: ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Free the Module
-  /// \param mod The module to be freed.
-  ///
-  /// \note This may not free up the module's resources.
-  /// If there is active TVMFunctionHandle uses the module
-  /// Or if this module is imported by another active module.
-  ///
-  /// The all functions remains valid until TVMFuncFree is called.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Free the function when it is no longer needed.
-  /// \param func The function handle
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Call a Packed TVM Function.
-  ///
-  /// \param func node handle of the function.
-  /// \param arg_values The arguments
-  /// \param type_codes The type codes of the arguments
-  /// \param num_args Number of arguments.
-  ///
-  /// \param ret_val The return value.
-  /// \param ret_type_code the type code of return value.
-  ///
-  /// \return 0 when success, -1 when failure happens
-  /// \note TVM calls always exchanges with type bits=64, lanes=1
-  ///
-  /// \note API calls always exchanges with type bits=64, lanes=1
-  /// If API call returns container handles (e.g. FunctionHandle)
-  /// these handles should be managed by the front-end.
-  /// The front-end need to call free function (e.g. TVMFuncFree)
-  /// to free these handles.
-  pub fn TVMFuncCall(
-    func: TVMFunctionHandle,
-    arg_values: *mut TVMValue,
-    type_codes: *mut ::std::os::raw::c_int,
-    num_args: ::std::os::raw::c_int,
-    ret_val: *mut TVMValue,
-    ret_type_code: *mut ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Set the return value of TVMPackedCFunc.
-  ///
-  /// This function is called by TVMPackedCFunc to set the return value.
-  /// When this function is not called, the function returns null by default.
-  ///
-  /// \param ret The return value handle, pass by ret in TVMPackedCFunc
-  /// \param value The value to be returned.
-  /// \param type_code The type of the value to be returned.
-  /// \param num_ret Number of return values, for now only 1 is supported.
-  pub fn TVMCFuncSetReturn(
-    ret: TVMRetValueHandle,
-    value: *mut TVMValue,
-    type_code: *mut ::std::os::raw::c_int,
-    num_ret: ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Inplace translate callback argument value to return value.
-  /// This is only needed for non-POD arguments.
-  ///
-  /// \param value The value to be translated.
-  /// \param code The type code to be translated.
-  /// \note This function will do a shallow copy when necessary.
-  ///
-  /// \return 0 when success, -1 when failure happens.
-  pub fn TVMCbArgToReturn(
-    value: *mut TVMValue,
-    code: ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
-/// \brief C type of packed function.
-///
-/// \param args The arguments
-/// \param type_codes The type codes of the arguments
-/// \param num_args Number of arguments.
-/// \param ret The return value handle.
-/// \param resource_handle The handle additional resouce handle from fron-end.
-/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
-/// \sa TVMCFuncSetReturn
-pub type TVMPackedCFunc = ::std::option::Option<
-  unsafe extern "C" fn(
-    args: *mut TVMValue,
-    type_codes: *mut ::std::os::raw::c_int,
-    num_args: ::std::os::raw::c_int,
-    ret: TVMRetValueHandle,
-    resource_handle: *mut ::std::os::raw::c_void,
-  ) -> ::std::os::raw::c_int,
->;
-/// \brief C callback to free the resource handle in C packed function.
-/// \param resource_handle The handle additional resouce handle from fron-end.
-pub type TVMPackedCFuncFinalizer =
-  ::std::option::Option<unsafe extern "C" fn(resource_handle: *mut ::std::os::raw::c_void)>;
-/// \brief Signature for extension function declarer.
-///
-/// TVM call this function to get the extension functions
-/// The declarer will call register_func to register function and their name.
-///
-/// \param register_func_handle The register function
-/// \return 0 if success, -1 if failure happens
-pub type TVMExtensionFuncDeclarer = ::std::option::Option<
-  unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int,
->;
-extern "C" {
-  /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
-  ///
-  /// The resource_handle will be managed by TVM API, until the function is no longer used.
-  ///
-  /// \param func The packed C function.
-  /// \param resource_handle The resource handle from front-end, can be NULL.
-  /// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
-  /// \param out the result function handle.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMFuncCreateFromCFunc(
-    func: TVMPackedCFunc,
-    resource_handle: *mut ::std::os::raw::c_void,
-    fin: TVMPackedCFuncFinalizer,
-    out: *mut TVMFunctionHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Register the function to runtime's global table.
-  ///
-  /// The registered function then can be pulled by the backend by the name.
-  ///
-  /// \param name The name of the function.
-  /// \param f The function to be registered.
-  /// \param override Whether allow override already registered function.
-  pub fn TVMFuncRegisterGlobal(
-    name: *const ::std::os::raw::c_char,
-    f: TVMFunctionHandle,
-    override_: ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Get a global function.
-  ///
-  /// \param name The name of the function.
-  /// \param out the result function pointer, NULL if it does not exist.
-  ///
-  /// \note The function handle of global function is managed by TVM runtime,
-  /// So TVMFuncFree is should not be called when it get deleted.
-  pub fn TVMFuncGetGlobal(
-    name: *const ::std::os::raw::c_char,
-    out: *mut TVMFunctionHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief List all the globally registered function name
-  /// \param out_size The number of functions
-  /// \param out_array The array of function names.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMFuncListGlobalNames(
-    out_size: *mut ::std::os::raw::c_int,
-    out_array: *mut *mut *const ::std::os::raw::c_char,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Allocate a nd-array's memory,
-  /// including space of shape, of given spec.
-  ///
-  /// \param shape The shape of the array, the data content will be copied to out
-  /// \param ndim The number of dimension of the array.
-  /// \param dtype_code The type code of the dtype
-  /// \param dtype_bits The number of bits of dtype
-  /// \param dtype_lanes The number of lanes in the dtype.
-  /// \param device_type The device type of context
-  /// \param device_id The device id of context.
-  /// \param out The output handle.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayAlloc(
-    shape: *const tvm_index_t,
-    ndim: ::std::os::raw::c_int,
-    dtype_code: ::std::os::raw::c_int,
-    dtype_bits: ::std::os::raw::c_int,
-    dtype_lanes: ::std::os::raw::c_int,
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    out: *mut TVMArrayHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Free the TVM Array.
-  /// \param handle The array handle to be freed.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Copy array data from CPU byte array.
-  /// \param handle The array handle.
-  /// \param data the data pointer
-  /// \param nbytes The number of bytes to copy.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayCopyFromBytes(
-    handle: TVMArrayHandle,
-    data: *mut ::std::os::raw::c_void,
-    nbytes: usize,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Copy array data to CPU byte array.
-  /// \param handle The array handle.
-  /// \param data the data pointer
-  /// \param nbytes The number of bytes to copy.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayCopyToBytes(
-    handle: TVMArrayHandle,
-    data: *mut ::std::os::raw::c_void,
-    nbytes: usize,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Copy the array, both from and to must be valid during the copy.
-  /// \param from The array to be copied from.
-  /// \param to The target space.
-  /// \param stream The stream where the copy happens, can be NULL.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayCopyFromTo(
-    from: TVMArrayHandle,
-    to: TVMArrayHandle,
-    stream: TVMStreamHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Produce an array from the DLManagedTensor that shares data memory
-  /// with the DLManagedTensor.
-  /// \param from The source DLManagedTensor.
-  /// \param out The output array handle.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayFromDLPack(
-    from: *mut DLManagedTensor,
-    out: *mut TVMArrayHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Produce a DLMangedTensor from the array that shares data memory with
-  /// the array.
-  /// \param from The source array.
-  /// \param out The DLManagedTensor handle.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayToDLPack(
-    from: TVMArrayHandle,
-    out: *mut *mut DLManagedTensor,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Delete (free) a DLManagedTensor's data.
-  /// \param dltensor Pointer to the DLManagedTensor.
-  pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor);
-}
-extern "C" {
-  /// \brief Create a new runtime stream.
-  ///
-  /// \param device_type The device type of context
-  /// \param device_id The device id of context
-  /// \param out The new stream handle
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMStreamCreate(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    out: *mut TVMStreamHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Free a created stream handle.
-  ///
-  /// \param device_type The device type of context
-  /// \param device_id The device id of context
-  /// \param stream The stream to be freed
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMStreamFree(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    stream: TVMStreamHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Set the runtime stream of current thread to be stream.
-  /// The subsequent calls to the same device_type
-  /// will use the setted stream handle.
-  /// The specific type of stream is runtime device dependent.
-  ///
-  /// \param device_type The device type of context
-  /// \param device_id The device id of context.
-  /// \param handle The stream handle.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMSetStream(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    handle: TVMStreamHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Wait until all computations on stream completes.
-  ///
-  /// \param device_type The device type of context
-  /// \param device_id The device id of context.
-  /// \param stream The stream to be synchronized.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMSynchronize(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    stream: TVMStreamHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Synchronize two streams of execution.
-  ///
-  /// \param device_type The device type of context
-  /// \param device_id The device id of context
-  /// \param src The source stream to synchronize.
-  /// \param dst The destination stream to synchronize.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMStreamStreamSynchronize(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    src: TVMStreamHandle,
-    dst: TVMStreamHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Backend function for modules to get function
-  /// from its environment mod_node (its imports and global function).
-  /// The user do should not call TVMFuncFree on func.
-  ///
-  /// \param mod_node The module handle.
-  /// \param func_name The name of the function.
-  /// \param out The result function.
-  /// \return 0 when no error is thrown, -1 when failure happens
-  pub fn TVMBackendGetFuncFromEnv(
-    mod_node: *mut ::std::os::raw::c_void,
-    func_name: *const ::std::os::raw::c_char,
-    out: *mut TVMFunctionHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Backend function to register system-wide library symbol.
-  ///
-  /// \param name The name of the symbol
-  /// \param ptr The symbol address.
-  /// \return 0 when no error is thrown, -1 when failure happens
-  pub fn TVMBackendRegisterSystemLibSymbol(
-    name: *const ::std::os::raw::c_char,
-    ptr: *mut ::std::os::raw::c_void,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Backend function to allocate temporal workspace.
-  ///
-  /// \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
-  ///
-  /// \param nbytes The size of the space requested.
-  /// \param device_type The device type which the space will be allocated.
-  /// \param device_id The device id which the space will be allocated.
-  /// \param dtype_code_hint The type code of the array elements. Only used in
-  /// certain backends such as OpenGL.
-  /// \param dtype_bits_hint The type bits of the array elements. Only used in
-  /// certain backends such as OpenGL.
-  /// \return nullptr when error is thrown, a valid ptr if success
-  pub fn TVMBackendAllocWorkspace(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    nbytes: u64,
-    dtype_code_hint: ::std::os::raw::c_int,
-    dtype_bits_hint: ::std::os::raw::c_int,
-  ) -> *mut ::std::os::raw::c_void;
-}
-extern "C" {
-  /// \brief Backend function to free temporal workspace.
-  ///
-  /// \param ptr The result allocated space pointer.
-  /// \param device_type The device type which the space will be allocated.
-  /// \param device_id The device id which the space will be allocated.
-  /// \return 0 when no error is thrown, -1 when failure happens
-  ///
-  /// \sa TVMBackendAllocWorkspace
-  pub fn TVMBackendFreeWorkspace(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    ptr: *mut ::std::os::raw::c_void,
-  ) -> ::std::os::raw::c_int;
-}
-/// \brief Environment for TVM parallel task.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct TVMParallelGroupEnv {
-  /// \brief Auxiliary used for synchronization
-  pub sync_handle: *mut ::std::os::raw::c_void,
-  /// \brief total amount of task
-  pub num_task: i32,
-}
-/// \brief The callback function to execute a parallel lambda
-/// \param task_id the task id of the function.
-/// \param penv The parallel environment backs the execution.
-/// \param cdata The supporting closure data.
-pub type FTVMParallelLambda = ::std::option::Option<
-  unsafe extern "C" fn(
-    task_id: ::std::os::raw::c_int,
-    penv: *mut TVMParallelGroupEnv,
-    cdata: *mut ::std::os::raw::c_void,
-  ) -> ::std::os::raw::c_int,
->;
-extern "C" {
-  /// \brief Backend function for running parallel jobs.
-  ///
-  /// \param flambda The parallel function to be launched.
-  /// \param cdata The closure data.
-  /// \param num_task Number of tasks to launch, can be 0, means launch
-  /// with all available threads.
-  ///
-  /// \return 0 when no error is thrown, -1 when failure happens
-  pub fn TVMBackendParallelLaunch(
-    flambda: FTVMParallelLambda,
-    cdata: *mut ::std::os::raw::c_void,
-    num_task: ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief BSP barrrier between parallel threads
-  /// \param task_id the task id of the function.
-  /// \param penv The parallel environment backs the execution.
-  /// \return 0 when no error is thrown, -1 when failure happens
-  pub fn TVMBackendParallelBarrier(
-    task_id: ::std::os::raw::c_int,
-    penv: *mut TVMParallelGroupEnv,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Simple static initialization function.
-  /// Run f once and set handle to be not null.
-  /// This function is mainly used for test purpose.
-  ///
-  /// \param handle An global address to indicate f
-  /// \param f The function to be ran
-  /// \param cdata The closure data to pass to the function.
-  /// \param nbytes Number of bytes in the closure data.
-  /// \return 0 when no error is thrown, -1 when failure happens
-  pub fn TVMBackendRunOnce(
-    handle: *mut *mut ::std::os::raw::c_void,
-    f: ::std::option::Option<
-      unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int,
-    >,
-    cdata: *mut ::std::os::raw::c_void,
-    nbytes: ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
index a81fab9..ad72f36 100644 (file)
@@ -1,15 +1,79 @@
-//! Error types for `TVMArgValue` and `TVMRetValue` conversions.
+use std::fmt;
 
-error_chain! {
-    errors {
-        TryFromTVMArgValueError(expected: String, actual: String) {
-              description("mismatched types while converting from TVMArgValue")
-              display("expected `{}` but given `{}`", expected, actual)
+static TYPE_CODE_STRS: [&str; 15] = [
+    "int",
+    "uint",
+    "float",
+    "handle",
+    "null",
+    "TVMType",
+    "TVMContext",
+    "ArrayHandle",
+    "NodeHandle",
+    "ModuleHandle",
+    "FuncHandle",
+    "str",
+    "bytes",
+    "NDArrayContainer",
+    "ExtBegin",
+];
+
+#[derive(Debug, Fail)]
+pub struct ValueDowncastError {
+    actual_type_code: i64,
+    expected_type_code: i64,
+}
+
+impl ValueDowncastError {
+    pub fn new(actual_type_code: i64, expected_type_code: i64) -> Self {
+        Self {
+            actual_type_code,
+            expected_type_code,
         }
+    }
+}
 
-        TryFromTVMRetValueError(expected: String, actual: String) {
-              description("mismatched types while downcasting TVMRetValue")
-              display("invalid downcast: expected `{}` but given `{}`", expected, actual)
+impl fmt::Display for ValueDowncastError {
+    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            formatter,
+            "Could not downcast TVMValue: expected `{}` but was {}",
+            TYPE_CODE_STRS[self.actual_type_code as usize],
+            TYPE_CODE_STRS[self.expected_type_code as usize]
+        )
+    }
+}
+
+#[derive(Debug, Fail)]
+#[fail(display = "Function call `{}` returned error: {}", context, message)]
+pub struct FuncCallError {
+    context: String,
+    message: String,
+}
+
+impl FuncCallError {
+    pub fn get_with_context(context: String) -> Self {
+        Self {
+            context,
+            message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) }
+                .to_str()
+                .expect("double fault")
+                .to_owned(),
         }
     }
 }
+
+// error_chain! {
+//     errors {
+//         TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) {
+//             description("mismatched types while downcasting TVMRetValue")
+//             display("invalid downcast: expected `{}` but was `{}`",
+//                     expected_type, type_code_to_string(actual_type_code))
+//         }
+//     }
+//     foreign_links {
+//         IntoString(std::ffi::IntoStringError);
+//         ParseInt(std::num::ParseIntError);
+//         Utf8(std::str::Utf8Error);
+//     }
+// }
index ad4c4f2..966655e 100644 (file)
@@ -1,39 +1,29 @@
 //! This crate contains the refactored basic components required
 //! for `runtime` and `frontend` TVM crates.
 
-#![crate_name = "tvm_common"]
-#![recursion_limit = "1024"]
-#![allow(non_camel_case_types, unused_imports)]
-#![feature(box_syntax, try_from)]
+#![feature(box_syntax, trait_alias)]
 
 #[macro_use]
-extern crate error_chain;
+extern crate failure;
 
 /// Unified ffi module for both runtime and frontend crates.
 pub mod ffi {
     #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]
 
-    #[cfg(feature = "frontend")]
-    pub extern crate tvm_sys as ts;
+    use std::os::raw::{c_char, c_int, c_void};
 
-    #[cfg(feature = "runtime")]
-    pub mod runtime {
-        use std::os::raw::{c_char, c_int, c_void};
+    include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
 
-        include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
-
-        pub type BackendPackedCFunc = extern "C" fn(
-            args: *const TVMValue,
-            type_codes: *const c_int,
-            num_args: c_int,
-        ) -> c_int;
-    }
+    pub type BackendPackedCFunc =
+        extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
 }
 
+pub mod array;
 pub mod errors;
-pub mod ty;
+#[macro_use]
+pub mod packed_func;
 pub mod value;
 
 pub use errors::*;
-pub use ty::TVMTypeCode;
-pub use value::{TVMArgValue, TVMRetValue, TVMValue};
+pub use ffi::{TVMContext, TVMType};
+pub use packed_func::{TVMArgValue, TVMRetValue};
diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs
new file mode 100644 (file)
index 0000000..a564fe6
--- /dev/null
@@ -0,0 +1,312 @@
+use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
+
+use failure::Error;
+
+pub use crate::ffi::TVMValue;
+use crate::ffi::*;
+
+pub trait PackedFunc =
+    Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync;
+
+/// Calls a packed function and returns a `TVMRetValue`.
+///
+/// # Example
+///
+/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
+#[macro_export]
+macro_rules! call_packed {
+  ($fn:expr, $($args:expr),+) => {
+    $fn(&[$($args.into(),)+])
+  };
+  ($fn:expr) => {
+    $fn(&Vec::new())
+  };
+}
+
+/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
+/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
+#[derive(Clone, Copy)]
+pub struct TVMArgValue<'a> {
+    pub _lifetime: PhantomData<&'a ()>,
+    pub value: TVMValue,
+    pub type_code: i64,
+}
+
+impl<'a> TVMArgValue<'a> {
+    pub fn new(value: TVMValue, type_code: i64) -> Self {
+        TVMArgValue {
+            _lifetime: PhantomData,
+            value: value,
+            type_code: type_code,
+        }
+    }
+}
+
+#[macro_export]
+macro_rules! ensure_type {
+    ($val:ident, $expected_type_code:expr) => {
+        ensure!(
+            $val.type_code == $expected_type_code as i64,
+            $crate::errors::ValueDowncastError::new(
+                $val.type_code as i64,
+                $expected_type_code as i64
+            )
+        );
+    };
+}
+
+/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
+macro_rules! impl_prim_tvm_arg {
+    ($type_code:ident, $field:ident, $field_type:ty, [ $( $type:ty ),+ ] ) => {
+        $(
+            impl From<$type> for TVMArgValue<'static> {
+                fn from(val: $type) -> Self {
+                    TVMArgValue {
+                        value: TVMValue { $field: val as $field_type },
+                        type_code: $type_code as i64,
+                        _lifetime: PhantomData,
+                    }
+                }
+            }
+            impl<'a> From<&'a $type> for TVMArgValue<'a> {
+                fn from(val: &'a $type) -> Self {
+                    TVMArgValue {
+                        value: TVMValue {
+                            $field: val.to_owned() as $field_type,
+                        },
+                        type_code: $type_code as i64,
+                        _lifetime: PhantomData,
+                    }
+                }
+            }
+            impl<'a> TryFrom<TVMArgValue<'a>> for $type {
+              type Error = Error;
+                fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> {
+                    ensure_type!(val, $type_code);
+                    Ok(unsafe { val.value.$field as $type })
+                }
+            }
+
+            impl<'a> TryFrom<&TVMArgValue<'a>> for $type {
+              type Error = Error;
+                fn try_from(val: &TVMArgValue<'a>) -> Result<Self, Self::Error> {
+                    ensure_type!(val, $type_code);
+                    Ok(unsafe { val.value.$field as $type })
+                }
+            }
+        )+
+    };
+}
+
+impl_prim_tvm_arg!(DLDataTypeCode_kDLFloat, v_float64, f64, [f32, f64]);
+impl_prim_tvm_arg!(
+    DLDataTypeCode_kDLInt,
+    v_int64,
+    i64,
+    [i8, i16, i32, i64, isize]
+);
+impl_prim_tvm_arg!(
+    DLDataTypeCode_kDLUInt,
+    v_int64,
+    i64,
+    [u8, u16, u32, u64, usize]
+);
+
+#[cfg(feature = "bindings")]
+// only allow this in bindings because pure-rust can't take ownership of leaked CString
+impl<'a> From<&String> for TVMArgValue<'a> {
+    fn from(string: &String) -> Self {
+        TVMArgValue {
+            value: TVMValue {
+                v_str: std::ffi::CString::new(string.clone()).unwrap().into_raw(),
+            },
+            type_code: TVMTypeCode_kStr as i64,
+            _lifetime: PhantomData,
+        }
+    }
+}
+
+impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> {
+    fn from(string: &std::ffi::CString) -> Self {
+        TVMArgValue {
+            value: TVMValue {
+                v_str: string.as_ptr(),
+            },
+            type_code: TVMTypeCode_kStr as i64,
+            _lifetime: PhantomData,
+        }
+    }
+}
+
+impl<'a> TryFrom<TVMArgValue<'a>> for &str {
+    type Error = Error;
+    fn try_from(arg: TVMArgValue<'a>) -> Result<Self, Self::Error> {
+        ensure_type!(arg, TVMTypeCode_kStr);
+        Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?)
+    }
+}
+
+impl<'a> TryFrom<&TVMArgValue<'a>> for &str {
+    type Error = Error;
+    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self, Self::Error> {
+        ensure_type!(arg, TVMTypeCode_kStr);
+        Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?)
+    }
+}
+
+/// Creates a conversion to a `TVMArgValue` for an object handle.
+impl<'a, T> From<*const T> for TVMArgValue<'a> {
+    fn from(ptr: *const T) -> Self {
+        TVMArgValue {
+            value: TVMValue {
+                v_handle: ptr as *mut T as *mut c_void,
+            },
+            type_code: TVMTypeCode_kArrayHandle as i64,
+            _lifetime: PhantomData,
+        }
+    }
+}
+
+/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
+impl<'a, T> From<*mut T> for TVMArgValue<'a> {
+    fn from(ptr: *mut T) -> Self {
+        TVMArgValue {
+            value: TVMValue {
+                v_handle: ptr as *mut c_void,
+            },
+            type_code: TVMTypeCode_kHandle as i64,
+            _lifetime: PhantomData,
+        }
+    }
+}
+
+impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
+    fn from(arr: &'a mut DLTensor) -> Self {
+        TVMArgValue {
+            value: TVMValue {
+                v_handle: arr as *mut _ as *mut c_void,
+            },
+            type_code: TVMTypeCode_kArrayHandle as i64,
+            _lifetime: PhantomData,
+        }
+    }
+}
+
+impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
+    fn from(arr: &'a DLTensor) -> Self {
+        TVMArgValue {
+            value: TVMValue {
+                v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
+            },
+            type_code: TVMTypeCode_kArrayHandle as i64,
+            _lifetime: PhantomData,
+        }
+    }
+}
+
+impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMType {
+    type Error = Error;
+    fn try_from(arg: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> {
+        ensure_type!(arg, TVMTypeCode_kTVMType);
+        Ok(unsafe { arg.value.v_type.into() })
+    }
+}
+
+/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
+/// Can be downcasted using `try_from` if it contains the desired type.
+///
+/// # Example
+///
+/// ```
+/// let a = 42u32;
+/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
+///
+/// let s = "hello, world!";
+/// let t: TVMRetValue = s.into();
+/// assert_eq!(String::try_from(t).unwrap(), s);
+/// ```
+pub struct TVMRetValue {
+    pub value: TVMValue,
+    pub box_value: Box<Any>,
+    pub type_code: i64,
+}
+
+impl TVMRetValue {
+    pub fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
+        Self {
+            value,
+            type_code,
+            box_value: box (),
+        }
+    }
+
+    pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
+        (self.value, self.type_code as TVMTypeCode)
+    }
+}
+
+impl Default for TVMRetValue {
+    fn default() -> Self {
+        TVMRetValue {
+            value: TVMValue { v_int64: 0 as i64 },
+            type_code: 0,
+            box_value: box (),
+        }
+    }
+}
+
+macro_rules! impl_pod_ret_value {
+    ($code:expr, [ $( $ty:ty ),+ ] ) => {
+        $(
+            impl From<$ty> for TVMRetValue {
+                fn from(val: $ty) -> Self {
+                    Self {
+                        value: val.into(),
+                        type_code: $code as i64,
+                        box_value: box (),
+                    }
+                }
+            }
+
+            impl TryFrom<TVMRetValue> for $ty {
+              type Error = Error;
+                fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> {
+                    ensure_type!(ret, $code);
+                    Ok(ret.value.into())
+                }
+            }
+        )+
+    };
+}
+
+impl_pod_ret_value!(DLDataTypeCode_kDLInt, [i8, i16, i32, i64, isize]);
+impl_pod_ret_value!(DLDataTypeCode_kDLUInt, [u8, u16, u32, u64, usize]);
+impl_pod_ret_value!(DLDataTypeCode_kDLFloat, [f32, f64]);
+impl_pod_ret_value!(TVMTypeCode_kTVMType, [TVMType]);
+impl_pod_ret_value!(TVMTypeCode_kTVMContext, [TVMContext]);
+
+impl TryFrom<TVMRetValue> for String {
+    type Error = Error;
+    fn try_from(ret: TVMRetValue) -> Result<String, Self::Error> {
+        ensure_type!(ret, TVMTypeCode_kStr);
+        let cs = unsafe { std::ffi::CString::from_raw(ret.value.v_handle as *mut i8) };
+        let ret_str = cs.clone().into_string();
+        if cfg!(feature = "bindings") {
+            std::mem::forget(cs); // TVM C++ takes ownership of CString. (@see TVMFuncCall)
+        }
+        Ok(ret_str?)
+    }
+}
+
+impl From<String> for TVMRetValue {
+    fn from(s: String) -> Self {
+        let cs = std::ffi::CString::new(s).unwrap();
+        Self {
+            value: TVMValue {
+                v_str: cs.into_raw() as *mut i8,
+            },
+            box_value: box (),
+            type_code: TVMTypeCode_kStr as i64,
+        }
+    }
+}
index 126bcd4..e69de29 100644 (file)
@@ -1,144 +0,0 @@
-//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods.
-//!
-//! # Example
-//!
-//! ```
-//! let dtype = TVMType::from("float");
-//! println!("dtype is: {}", dtype);
-//! ```
-
-use std::{
-    ffi::{CStr, CString},
-    fmt::{self, Display, Formatter},
-};
-
-/// TVM type codes.
-#[repr(u32)]
-#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
-pub enum TVMTypeCode {
-    kDLInt = 0,
-    kDLUInt = 1,
-    kDLFloat = 2,
-    kHandle = 3,
-    kNull = 4,
-    kTVMType = 5,
-    kTVMContext = 6,
-    kArrayHandle = 7,
-    kNodeHandle = 8,
-    kModuleHandle = 9,
-    kFuncHandle = 10,
-    kStr = 11,
-    kBytes = 12,
-    kNDArrayContainer = 13,
-}
-
-impl Default for TVMTypeCode {
-    fn default() -> Self {
-        TVMTypeCode::kDLInt
-    }
-}
-
-impl From<TVMTypeCode> for i64 {
-    fn from(arg: TVMTypeCode) -> i64 {
-        match arg {
-            TVMTypeCode::kDLInt => 0,
-            TVMTypeCode::kDLUInt => 1,
-            TVMTypeCode::kDLFloat => 2,
-            TVMTypeCode::kHandle => 3,
-            TVMTypeCode::kNull => 4,
-            TVMTypeCode::kTVMType => 5,
-            TVMTypeCode::kTVMContext => 6,
-            TVMTypeCode::kArrayHandle => 7,
-            TVMTypeCode::kNodeHandle => 8,
-            TVMTypeCode::kModuleHandle => 9,
-            TVMTypeCode::kFuncHandle => 10,
-            TVMTypeCode::kStr => 11,
-            TVMTypeCode::kBytes => 12,
-            TVMTypeCode::kNDArrayContainer => 13,
-        }
-    }
-}
-
-impl Into<TVMTypeCode> for i64 {
-    fn into(self) -> TVMTypeCode {
-        match self {
-            0 => TVMTypeCode::kDLInt,
-            1 => TVMTypeCode::kDLUInt,
-            2 => TVMTypeCode::kDLFloat,
-            3 => TVMTypeCode::kHandle,
-            4 => TVMTypeCode::kNull,
-            5 => TVMTypeCode::kTVMType,
-            6 => TVMTypeCode::kTVMContext,
-            7 => TVMTypeCode::kArrayHandle,
-            8 => TVMTypeCode::kNodeHandle,
-            9 => TVMTypeCode::kModuleHandle,
-            10 => TVMTypeCode::kFuncHandle,
-            11 => TVMTypeCode::kStr,
-            12 => TVMTypeCode::kBytes,
-            13 => TVMTypeCode::kNDArrayContainer,
-            _ => unreachable!(),
-        }
-    }
-}
-
-impl Display for TVMTypeCode {
-    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
-        write!(
-            f,
-            "{}",
-            match self {
-                TVMTypeCode::kDLInt => "int",
-                TVMTypeCode::kDLUInt => "uint",
-                TVMTypeCode::kDLFloat => "float",
-                TVMTypeCode::kHandle => "handle",
-                TVMTypeCode::kNull => "null",
-                TVMTypeCode::kTVMType => "TVM type",
-                TVMTypeCode::kTVMContext => "TVM context",
-                TVMTypeCode::kArrayHandle => "Array handle",
-                TVMTypeCode::kNodeHandle => "Node handle",
-                TVMTypeCode::kModuleHandle => "Module handle",
-                TVMTypeCode::kFuncHandle => "Function handle",
-                TVMTypeCode::kStr => "string",
-                TVMTypeCode::kBytes => "bytes",
-                TVMTypeCode::kNDArrayContainer => "ndarray container",
-            }
-        )
-    }
-}
-
-macro_rules! impl_prim_type {
-    ($type:ty, $variant:ident) => {
-        impl<'a> From<&'a $type> for TVMTypeCode {
-            fn from(_arg: &$type) -> Self {
-                TVMTypeCode::$variant
-            }
-        }
-
-        impl<'a> From<&'a mut $type> for TVMTypeCode {
-            fn from(_arg: &mut $type) -> Self {
-                TVMTypeCode::$variant
-            }
-        }
-    };
-}
-
-impl_prim_type!(usize, kDLInt);
-impl_prim_type!(i64, kDLInt);
-impl_prim_type!(i32, kDLInt);
-impl_prim_type!(i16, kDLInt);
-impl_prim_type!(i8, kDLInt);
-
-impl_prim_type!(u64, kDLUInt);
-impl_prim_type!(u32, kDLUInt);
-impl_prim_type!(u16, kDLUInt);
-impl_prim_type!(u8, kDLUInt);
-
-impl_prim_type!(f64, kDLFloat);
-impl_prim_type!(f32, kDLFloat);
-
-impl_prim_type!(str, kStr);
-impl_prim_type!(CStr, kStr);
-impl_prim_type!(String, kStr);
-impl_prim_type!(CString, kStr);
-
-impl_prim_type!([u8], kBytes);
index 6da8b27..c7c040b 100644 (file)
-//! This module provides the the wrapped `TVMValue`, `TVMArgValue` and `TVMRetValue`
-//! required for using TVM functions.
+use std::str::FromStr;
 
-use std::{
-    any::Any,
-    convert::TryFrom,
-    ffi::{CStr, CString},
-    fmt::{self, Debug, Formatter},
-    marker::PhantomData,
-    mem,
-    ops::Deref,
-    os::raw::{c_char, c_void},
-};
+use failure::Error;
 
-#[cfg(feature = "runtime")]
-use ffi::runtime::TVMValue as _TVMValue;
+use crate::ffi::*;
 
-#[cfg(feature = "frontend")]
-use ffi::ts::TVMValue as _TVMValue;
-
-use errors::*;
-
-use ty::TVMTypeCode;
-
-/// Wrapped TVMValue type.
-#[derive(Clone, Copy)]
-pub struct TVMValue {
-    pub inner: _TVMValue,
-}
-
-impl TVMValue {
-    /// Creates TVMValue from the raw part.
-    pub fn new(inner: _TVMValue) -> Self {
-        TVMValue { inner }
-    }
-
-    pub(crate) fn into_raw(self) -> _TVMValue {
-        self.inner
-    }
-}
-
-impl Debug for TVMValue {
-    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
-        unsafe {
-            write!(
-                f,
-                "TVMValue: [v_int64: {:?}], [v_float64: {:?}], [v_handle: {:?}],\
-                 [v_str: {:?}]",
-                self.inner.v_int64, self.inner.v_float64, self.inner.v_handle, self.inner.v_str
-            )
+impl TVMType {
+    fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
+        Self {
+            code: type_code,
+            bits,
+            lanes,
         }
     }
 }
 
-impl Deref for TVMValue {
-    type Target = _TVMValue;
-    fn deref(&self) -> &Self::Target {
-        &self.inner
-    }
-}
-
-macro_rules! impl_prim_val {
-    ($type:ty, $field:ident, $cast:ty) => {
-        impl From<$type> for TVMValue {
-            fn from(arg: $type) -> Self {
-                let inner = _TVMValue {
-                    $field: arg as $cast,
-                };
-                Self::new(inner)
-            }
+/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
+/// such as "int32", "float32" or with lane "float32x1".
+impl FromStr for TVMType {
+    type Err = Error;
+    fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+        if type_str == "bool" {
+            return Ok(TVMType::new(1, 1, 1));
         }
 
-        impl<'a> From<&'a $type> for TVMValue {
-            fn from(arg: &$type) -> Self {
-                let inner = _TVMValue {
-                    $field: *arg as $cast,
-                };
-                Self::new(inner)
+        let mut type_lanes = type_str.split("x");
+        let typ = type_lanes.next().expect("Missing dtype");
+        let lanes = type_lanes
+            .next()
+            .map(|l| <u16>::from_str_radix(l, 10))
+            .unwrap_or(Ok(1))?;
+        let (type_name, bits) = match typ.find(char::is_numeric) {
+            Some(idx) => {
+                let (name, bits_str) = typ.split_at(idx);
+                (name, u8::from_str_radix(bits_str, 10)?)
             }
-        }
-
-        impl<'a> From<&'a mut $type> for TVMValue {
-            fn from(arg: &mut $type) -> Self {
-                let inner = _TVMValue {
-                    $field: *arg as $cast,
-                };
-                Self::new(inner)
-            }
-        }
-
-        impl TryFrom<TVMValue> for $type {
-            type Error = Error;
-            fn try_from(val: TVMValue) -> Result<Self> {
-                Ok(unsafe { val.inner.$field as $type })
-            }
-        }
-
-        impl<'a> TryFrom<&'a TVMValue> for $type {
-            type Error = Error;
-            fn try_from(val: &TVMValue) -> Result<Self> {
-                Ok(unsafe { val.into_raw().$field as $type })
-            }
-        }
-
-        impl<'a> TryFrom<&'a mut TVMValue> for $type {
-            type Error = Error;
-            fn try_from(val: &mut TVMValue) -> Result<Self> {
-                Ok(unsafe { val.into_raw().$field as $type })
-            }
-        }
-    };
-}
-
-impl_prim_val!(isize, v_int64, i64);
-impl_prim_val!(i64, v_int64, i64);
-impl_prim_val!(i32, v_int64, i64);
-impl_prim_val!(i16, v_int64, i64);
-impl_prim_val!(i8, v_int64, i64);
-impl_prim_val!(usize, v_int64, i64);
-impl_prim_val!(u64, v_int64, i64);
-impl_prim_val!(u32, v_int64, i64);
-impl_prim_val!(u16, v_int64, i64);
-impl_prim_val!(u8, v_int64, i64);
-
-impl_prim_val!(f64, v_float64, f64);
-impl_prim_val!(f32, v_float64, f64);
-
-impl<'a> From<&'a str> for TVMValue {
-    fn from(arg: &str) -> TVMValue {
-        let arg = CString::new(arg).unwrap();
-        let inner = _TVMValue {
-            v_str: arg.as_ptr() as *const c_char,
+            None => (typ, 32),
         };
-        mem::forget(arg);
-        Self::new(inner)
-    }
-}
-
-impl<'a> From<&'a String> for TVMValue {
-    fn from(arg: &String) -> TVMValue {
-        let arg = CString::new(arg.as_bytes()).unwrap();
-        let inner = _TVMValue {
-            v_str: arg.as_ptr() as *const c_char,
-        };
-        mem::forget(arg);
-        Self::new(inner)
-    }
-}
-
-impl<'a> From<&'a CString> for TVMValue {
-    fn from(arg: &CString) -> TVMValue {
-        let arg = arg.to_owned();
-        let inner = _TVMValue {
-            v_str: arg.as_ptr() as *const c_char,
-        };
-        mem::forget(arg);
-        Self::new(inner)
-    }
-}
 
-impl<'a> From<&'a [u8]> for TVMValue {
-    fn from(arg: &[u8]) -> TVMValue {
-        let arg = arg.to_owned();
-        let inner = _TVMValue {
-            v_handle: &arg as *const _ as *mut c_void,
+        let type_code = match type_name {
+            "int" => 0,
+            "uint" => 1,
+            "float" => 2,
+            "handle" => 3,
+            _ => return Err(format_err!("Unknown type {}", type_name)),
         };
-        mem::forget(arg);
-        Self::new(inner)
-    }
-}
-
-/// Captures both `TVMValue` and `TVMTypeCode` needed for TVM function.
-/// The preferred way to obtain a `TVMArgValue` is automatically via `call_packed!`.
-/// or in the frontend crate, with `function::Builder`. Checkout the methods for conversions.
-///
-/// ## Example
-///
-/// ```
-/// let s = "hello".to_string();
-/// let arg = TVMArgValue::from(&s);
-/// let tvm: String = arg.try_into().unwrap();
-/// assert_eq!(arg, s);
-/// ```
-#[derive(Debug, Clone, Copy)]
-pub struct TVMArgValue<'a> {
-    /// The wrapped TVMValue
-    pub value: TVMValue,
-    /// The matching type code.
-    pub type_code: TVMTypeCode,
-    /// This is only exposed to runtime and frontend crates and is not meant to be used directly.
-    pub lifetime: PhantomData<&'a ()>,
-}
-
-impl<'a> TVMArgValue<'a> {
-    pub fn new(value: TVMValue, type_code: TVMTypeCode) -> Self {
-        TVMArgValue {
-            value: value,
-            type_code: type_code,
-            lifetime: PhantomData,
-        }
-    }
-}
-
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for i64 {
-    type Error = Error;
-    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
-        if (arg.type_code == TVMTypeCode::kDLInt)
-            | (arg.type_code == TVMTypeCode::kDLUInt)
-            | (arg.type_code == TVMTypeCode::kNull)
-        {
-            Ok(unsafe { arg.value.inner.v_int64 })
-        } else {
-            bail!(ErrorKind::TryFromTVMArgValueError(
-                stringify!(i64).to_string(),
-                arg.type_code.to_string()
-            ))
-        }
-    }
-}
-
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for f64 {
-    type Error = Error;
-    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
-        if arg.type_code == TVMTypeCode::kDLFloat {
-            Ok(unsafe { arg.value.inner.v_float64 })
-        } else {
-            bail!(ErrorKind::TryFromTVMArgValueError(
-                stringify!(f64).to_string(),
-                arg.type_code.to_string()
-            ))
-        }
-    }
-}
-
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for String {
-    type Error = Error;
-    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
-        if arg.type_code == TVMTypeCode::kStr {
-            let ret_str = unsafe {
-                match CStr::from_ptr(arg.value.inner.v_str).to_str() {
-                    Ok(s) => s,
-                    Err(_) => "Invalid UTF-8 message",
-                }
-            };
-            Ok(ret_str.to_string())
-        } else {
-            bail!(ErrorKind::TryFromTVMArgValueError(
-                stringify!(String).to_string(),
-                arg.type_code.to_string()
-            ))
-        }
-    }
-}
-
-/// Main way to create a TVMArgValue from suported Rust values.
-impl<'b, 'a: 'b, T: 'b + ?Sized> From<&'b T> for TVMArgValue<'a>
-where
-    TVMValue: From<&'b T>,
-    TVMTypeCode: From<&'b T>,
-{
-    fn from(arg: &'b T) -> Self {
-        TVMArgValue::new(TVMValue::from(arg), TVMTypeCode::from(arg))
-    }
-}
-
-/// Creates a conversion to a `TVMArgValue` for an object handle.
-impl<'a, T> From<*const T> for TVMArgValue<'a> {
-    fn from(ptr: *const T) -> Self {
-        let value = TVMValue::new(_TVMValue {
-            v_handle: ptr as *mut T as *mut c_void,
-        });
 
-        TVMArgValue::new(value, TVMTypeCode::kArrayHandle)
+        Ok(TVMType::new(type_code, bits, lanes))
     }
 }
 
-/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
-impl<'a, T> From<*mut T> for TVMArgValue<'a> {
-    fn from(ptr: *mut T) -> Self {
-        let value = TVMValue::new(_TVMValue {
-            v_handle: ptr as *mut c_void,
-        });
-
-        TVMArgValue::new(value, TVMTypeCode::kHandle)
-    }
-}
-
-/// An owned version of TVMPODValue. It can be converted from varieties of
-/// primitive and object types.
-/// It can be downcasted using `try_from` if it contains the desired type.
-///
-/// # Example
-///
-/// ```
-/// let a = 42u32;
-/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
-///
-/// let s = "hello, world!";
-/// let t: TVMRetValue = s.into();
-/// assert_eq!(String::try_from(t).unwrap(), s);
-/// ```
-pub struct TVMRetValue {
-    /// A primitive return value, if any.
-    pub prim_value: usize,
-    /// An object return value, if any.
-    pub box_value: Box<Any>,
-    pub type_code: TVMTypeCode,
-}
-
-impl TVMRetValue {
-    fn new(prim_value: usize, box_value: Box<Any>, type_code: TVMTypeCode) -> Self {
-        Self {
-            prim_value,
-            box_value,
-            type_code,
+impl std::fmt::Display for TVMType {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        if self.bits == 1 && self.lanes == 1 {
+            return write!(f, "bool");
         }
-    }
-
-    /// unsafe function to create `TVMRetValue` from `TVMValue` and
-    /// its matching `TVMTypeCode`.
-    pub unsafe fn from_tvm_value(value: TVMValue, type_code: TVMTypeCode) -> Self {
-        let value = value.into_raw();
-        match type_code {
-            TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => {
-                Self::new(value.v_int64 as usize, box (), type_code)
-            }
-            TVMTypeCode::kDLFloat => Self::new(value.v_float64 as usize, box (), type_code),
-            TVMTypeCode::kHandle
-            | TVMTypeCode::kArrayHandle
-            | TVMTypeCode::kNodeHandle
-            | TVMTypeCode::kModuleHandle
-            | TVMTypeCode::kFuncHandle => {
-                Self::new(value.v_handle as usize, box value.v_handle, type_code)
-            }
-            TVMTypeCode::kStr | TVMTypeCode::kBytes => {
-                Self::new(value.v_str as usize, box (value.v_str), type_code)
-            }
-            _ => Self::new(0usize, box (), type_code),
+        let mut type_str = match self.code {
+            0 => "int",
+            1 => "uint",
+            2 => "float",
+            4 => "handle",
+            _ => "unknown",
         }
-    }
+        .to_string();
 
-    /// Returns the underlying `TVMValue` and `TVMTypeCode`.
-    pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
-        let val = match self.type_code {
-            TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => TVMValue::new(_TVMValue {
-                v_int64: self.prim_value as i64,
-            }),
-            TVMTypeCode::kDLFloat => TVMValue::new(_TVMValue {
-                v_float64: self.prim_value as f64,
-            }),
-            TVMTypeCode::kHandle
-            | TVMTypeCode::kArrayHandle
-            | TVMTypeCode::kNodeHandle
-            | TVMTypeCode::kModuleHandle
-            | TVMTypeCode::kFuncHandle
-            | TVMTypeCode::kNDArrayContainer => TVMValue::new(_TVMValue {
-                v_handle: self.prim_value as *const c_void as *mut c_void,
-            }),
-            TVMTypeCode::kStr | TVMTypeCode::kBytes => TVMValue::new(_TVMValue {
-                v_str: self.prim_value as *const c_char,
-            }),
-            _ => unreachable!(),
-        };
-        (val, self.type_code)
-    }
-}
-
-impl Default for TVMRetValue {
-    fn default() -> Self {
-        TVMRetValue {
-            prim_value: 0usize,
-            box_value: box (),
-            type_code: TVMTypeCode::default(),
-        }
-    }
-}
-
-impl Clone for TVMRetValue {
-    fn clone(&self) -> Self {
-        match self.type_code {
-            TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt | TVMTypeCode::kDLFloat => {
-                Self::new(self.prim_value.clone(), box (), self.type_code.clone())
-            }
-            TVMTypeCode::kHandle
-            | TVMTypeCode::kArrayHandle
-            | TVMTypeCode::kNodeHandle
-            | TVMTypeCode::kModuleHandle
-            | TVMTypeCode::kFuncHandle
-            | TVMTypeCode::kNDArrayContainer => Self::new(
-                self.prim_value.clone(),
-                box (self.prim_value.clone() as *const c_void as *mut c_void),
-                self.type_code.clone(),
-            ),
-            TVMTypeCode::kStr | TVMTypeCode::kBytes => Self::new(
-                self.prim_value.clone(),
-                box (self.prim_value.clone() as *const c_char),
-                self.type_code.clone(),
-            ),
-            _ => unreachable!(),
+        type_str += &self.bits.to_string();
+        if self.lanes > 1 {
+            type_str += &format!("x{}", self.lanes);
         }
+        f.write_str(&type_str)
     }
 }
 
-impl Debug for TVMRetValue {
-    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
-        write!(
-            f,
-            "prim_value: {:?}, box_value: {:?}, type_code: {:?}",
-            self.prim_value, self.prim_value as *const c_void as *mut c_void, self.type_code
-        )
-    }
-}
-
-macro_rules! impl_prim_ret_value {
-    ($type:ty, $code:expr) => {
-        impl From<$type> for TVMRetValue {
-            fn from(val: $type) -> Self {
-                TVMRetValue {
-                    prim_value: val as usize,
-                    box_value: box (),
-                    type_code: $code,
-                }
-            }
-        }
-
-        impl<'a> From<&'a $type> for TVMRetValue {
-            fn from(val: &$type) -> Self {
-                TVMRetValue {
-                    prim_value: *val as usize,
-                    box_value: box (),
-                    type_code: $code,
+macro_rules! impl_pod_tvm_value {
+    ($field:ident, $field_ty:ty, $( $ty:ty ),+) => {
+        $(
+            impl From<$ty> for TVMValue {
+                fn from(val: $ty) -> Self {
+                    TVMValue { $field: val as $field_ty }
                 }
             }
-        }
 
-        impl<'a> From<&'a mut $type> for TVMRetValue {
-            fn from(val: &mut $type) -> Self {
-                TVMRetValue {
-                    prim_value: *val as usize,
-                    box_value: box (),
-                    type_code: $code,
+            impl From<TVMValue> for $ty {
+                fn from(val: TVMValue) -> Self {
+                    unsafe { val.$field as $ty }
                 }
             }
-        }
-
-        impl TryFrom<TVMRetValue> for $type {
-            type Error = Error;
-            fn try_from(ret: TVMRetValue) -> Result<$type> {
-                if ret.type_code == $code {
-                    Ok(ret.prim_value as $type)
-                } else {
-                    bail!(ErrorKind::TryFromTVMRetValueError(
-                        stringify!($type).to_string(),
-                        ret.type_code.to_string(),
-                    ))
-                }
-            }
-        }
+        )+
     };
+    ($field:ident, $ty:ty) => {
+        impl_pod_tvm_value!($field, $ty, $ty);
+    }
 }
 
-impl_prim_ret_value!(i8, TVMTypeCode::kDLInt);
-impl_prim_ret_value!(i16, TVMTypeCode::kDLInt);
-impl_prim_ret_value!(i32, TVMTypeCode::kDLInt);
-impl_prim_ret_value!(i64, TVMTypeCode::kDLInt);
-impl_prim_ret_value!(isize, TVMTypeCode::kDLInt);
-
-impl_prim_ret_value!(u8, TVMTypeCode::kDLUInt);
-impl_prim_ret_value!(u16, TVMTypeCode::kDLUInt);
-impl_prim_ret_value!(u32, TVMTypeCode::kDLUInt);
-impl_prim_ret_value!(u64, TVMTypeCode::kDLUInt);
-impl_prim_ret_value!(usize, TVMTypeCode::kDLUInt);
-
-impl_prim_ret_value!(f32, TVMTypeCode::kDLFloat);
-impl_prim_ret_value!(f64, TVMTypeCode::kDLFloat);
+impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize);
+impl_pod_tvm_value!(v_float64, f64, f32, f64);
+impl_pod_tvm_value!(v_type, TVMType);
+impl_pod_tvm_value!(v_ctx, TVMContext);
 
-macro_rules! impl_ptr_ret_value {
-    ($type:ty) => {
-        impl From<$type> for TVMRetValue {
-            fn from(ptr: $type) -> Self {
-                TVMRetValue {
-                    prim_value: ptr as usize,
-                    box_value: box (),
-                    type_code: TVMTypeCode::kHandle,
-                }
-            }
-        }
-
-        impl TryFrom<TVMRetValue> for $type {
-            type Error = Error;
-            fn try_from(ret: TVMRetValue) -> Result<$type> {
-                if ret.type_code == TVMTypeCode::kHandle {
-                    Ok(ret.prim_value as $type)
-                } else {
-                    bail!(ErrorKind::TryFromTVMRetValueError(
-                        stringify!($type).to_string(),
-                        ret.type_code.to_string(),
-                    ))
-                }
+macro_rules! impl_tvm_context {
+    ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => {
+        /// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev")
+        impl FromStr for TVMContext {
+            type Err = Error;
+            fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+                Ok(Self {
+                    device_type: match type_str {
+                         $( $(  stringify!($dev_name)  )|+ => $dev_type ),+,
+                        _ => return Err(format_err!("device {} not supported", type_str).into()),
+                    },
+                    device_id: 0,
+                })
             }
         }
-    };
-}
-
-impl_ptr_ret_value!(*const c_void);
-impl_ptr_ret_value!(*mut c_void);
 
-impl From<String> for TVMRetValue {
-    fn from(val: String) -> Self {
-        let pval = val.as_ptr() as *const c_char as usize;
-        let bval = box (val.as_ptr() as *const c_char);
-        mem::forget(val);
-        TVMRetValue::new(pval, bval, TVMTypeCode::kStr)
-    }
-}
-
-impl TryFrom<TVMRetValue> for String {
-    type Error = Error;
-    fn try_from(ret: TVMRetValue) -> Result<String> {
-        // Note: simple downcast doesn't work for function call return values
-        let ret_str = unsafe {
-            match CStr::from_ptr(ret.prim_value as *const c_char).to_str() {
-                Ok(s) => s,
-                Err(_) => "Invalid UTF-8 message",
-            }
-        };
-
-        Ok(ret_str.to_string())
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use std::convert::TryInto;
-
-    #[test]
-    fn numeric() {
-        macro_rules! arg_ret_tests {
-            ($v:expr; $($ty:ty),+) => {{
+        impl TVMContext {
+            $(
                 $(
-                    let v = $v as $ty;
-                    let b = TVMRetValue::from(&v);
-                    let b: $ty = b.try_into().unwrap();
-                    assert_eq!(b, v);
+                    pub fn $dev_name(device_id: usize) -> Self {
+                        Self {
+                            device_type: $dev_type,
+                            device_id: device_id as i32,
+                        }
+                    }
                 )+
-            }};
+            )+
         }
-
-        arg_ret_tests!(42; i8, i16, i32, i64, f32, f64);
-    }
-
-    #[test]
-    fn string() {
-        let s = "hello".to_string();
-        let tvm_arg: String = TVMRetValue::from(s.clone()).try_into().unwrap();
-        assert_eq!(tvm_arg, s);
-    }
+    };
 }
+
+impl_tvm_context!(
+    DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
+    DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+    DLDeviceType_kDLOpenCL: [cl],
+    DLDeviceType_kDLMetal: [metal],
+    DLDeviceType_kDLVPI: [vpi],
+    DLDeviceType_kDLROCM: [rocm],
+    DLDeviceType_kDLExtDev: [ext_dev]
+);
diff --git a/rust/common/tvm-sys/Cargo.toml b/rust/common/tvm-sys/Cargo.toml
deleted file mode 100644 (file)
index 117d174..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-[package]
-name = "tvm-sys"
-version = "0.1.0"
-authors = ["TVM Contributors"]
-license = "Apache-2.0"
-description = "Raw C API"
-
-[build-dependencies]
-bindgen = "0.37.4"
diff --git a/rust/common/tvm-sys/build.rs b/rust/common/tvm-sys/build.rs
deleted file mode 100644 (file)
index f842043..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-extern crate bindgen;
-
-use std::path::PathBuf;
-
-fn main() {
-    println!("cargo:rerun-if-env-changed=TVM_HOME");
-    println!("cargo:rustc-link-lib=dylib=tvm_runtime");
-    println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
-    let bindings = bindgen::Builder::default()
-        .header(format!(
-            "{}/include/tvm/runtime/c_runtime_api.h",
-            env!("TVM_HOME")
-        ))
-        .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
-        .blacklist_type("max_align_t") // @see rust-bindgen#550
-        .layout_tests(false)
-        .derive_partialeq(true)
-        .derive_eq(true)
-        .generate()
-        .expect("unable to generate bindings");
-
-    bindings
-        .write_to_file(PathBuf::from("src/bindgen.rs"))
-        .expect("can not write the bindings!");
-}
diff --git a/rust/common/tvm-sys/src/lib.rs b/rust/common/tvm-sys/src/lib.rs
deleted file mode 100644 (file)
index 15f1ea3..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-#![allow(
-    non_camel_case_types,
-    non_snake_case,
-    non_upper_case_globals,
-    dead_code,
-    improper_ctypes
-)]
-
-include!("bindgen.rs");
index db26155..eb1f5b8 100644 (file)
@@ -15,11 +15,11 @@ name = "tvm_frontend"
 crate-type = ["dylib"]
 
 [dependencies]
-error-chain = "0.12.0"
+failure = "0.1.5"
 lazy_static = "1.1.0"
 ndarray = "0.12.1"
 num-traits = "0.2"
-tvm-common = { version = "0.1.0", path = "../common/", features = ["frontend"] }
+tvm-common = { version = "0.1.0", path = "../common/", features = ["bindings"] }
 
 [features]
 blas = ["ndarray/blas"]
index 869a35b..2ad3efa 100644 (file)
@@ -1,5 +1,3 @@
-#![feature(try_from)]
-
 extern crate csv;
 extern crate image;
 extern crate ndarray;
@@ -10,6 +8,7 @@ use std::{
     convert::TryInto,
     fs::{self, File},
     path::Path,
+    str::FromStr,
 };
 
 use image::{FilterType, GenericImageView};
@@ -44,8 +43,12 @@ fn main() {
     // make arr shape as [1, 3, 224, 224] acceptable to resnet
     let arr = arr.insert_axis(Axis(0));
     // create input tensor from rust's ndarray
-    let input =
-        NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
+    let input = NDArray::from_rust_ndarray(
+        &arr,
+        TVMContext::cpu(0),
+        TVMType::from_str("float32").unwrap(),
+    )
+    .unwrap();
     println!(
         "input size is {:?}",
         input.shape().expect("cannot get the input shape")
@@ -59,7 +62,7 @@ fn main() {
     )))
     .unwrap();
     // get the global TVM graph runtime function
-    let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
+    let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap();
     let runtime_create_fn_ret = call_packed!(
         runtime_create_fn,
         &graph,
@@ -85,14 +88,19 @@ fn main() {
         .get_function("set_input", false)
         .unwrap();
 
-    call_packed!(set_input_fn, "data", &input).unwrap();
+    let data_str = "data".to_string();
+    call_packed!(set_input_fn, &data_str, &input).unwrap();
     // get `run` function from runtime module
     let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
     // execute the run function. Note that it has no argument
     call_packed!(run_fn,).unwrap();
     // prepare to get the output
     let output_shape = &mut [1, 1000];
-    let output = NDArray::empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
+    let output = NDArray::empty(
+        output_shape,
+        TVMContext::cpu(0),
+        TVMType::from_str("float32").unwrap(),
+    );
     // get the `get_output` function from runtime module
     let ref get_output_fn = graph_runtime_module
         .get_function("get_output", false)
index 395f34c..9274dba 100644 (file)
@@ -3,9 +3,9 @@
 //!
 //! For more detail, please see the example `resnet` in `examples` repository.
 
-use std::os::raw::c_char;
+use std::os::raw::{c_char, c_void};
 
-use crate::ts;
+use tvm_common::{ffi, TVMArgValue};
 
 /// A struct holding TVM byte-array.
 ///
@@ -19,11 +19,11 @@ use crate::ts;
 /// ```
 #[derive(Debug, Clone)]
 pub struct TVMByteArray {
-    pub(crate) inner: ts::TVMByteArray,
+    pub(crate) inner: ffi::TVMByteArray,
 }
 
 impl TVMByteArray {
-    pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray {
+    pub(crate) fn new(barr: ffi::TVMByteArray) -> TVMByteArray {
         TVMByteArray { inner: barr }
     }
 
@@ -46,7 +46,7 @@ impl TVMByteArray {
 
 impl<'a> From<&'a Vec<u8>> for TVMByteArray {
     fn from(arg: &Vec<u8>) -> Self {
-        let barr = ts::TVMByteArray {
+        let barr = ffi::TVMByteArray {
             data: arg.as_ptr() as *const c_char,
             size: arg.len(),
         };
@@ -54,6 +54,18 @@ impl<'a> From<&'a Vec<u8>> for TVMByteArray {
     }
 }
 
+impl<'a> From<&TVMByteArray> for TVMArgValue<'a> {
+    fn from(arr: &TVMByteArray) -> Self {
+        Self {
+            value: ffi::TVMValue {
+                v_handle: &arr.inner as *const ffi::TVMByteArray as *const c_void as *mut c_void,
+            },
+            type_code: ffi::TVMTypeCode_kBytes as i64,
+            _lifetime: std::marker::PhantomData,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
index 65e11d8..5d800a8 100644 (file)
 //! ```
 
 use std::{
+    convert::TryInto,
     fmt::{self, Display, Formatter},
     os::raw::c_void,
     ptr,
 };
 
-use crate::{function, ts, Result};
+use failure::Error;
+
+use tvm_common::{
+    ffi::{self, TVMValue},
+    TVMArgValue,
+};
+
+use crate::function;
 
 /// Device type can be from a supported device name. See the supported devices
 /// in [TVM](https://github.com/dmlc/tvm).
@@ -45,35 +53,35 @@ impl Default for TVMDeviceType {
     }
 }
 
-impl From<TVMDeviceType> for ts::DLDeviceType {
+impl From<TVMDeviceType> for ffi::DLDeviceType {
     fn from(device_type: TVMDeviceType) -> Self {
         match device_type.0 {
-            1 => ts::DLDeviceType_kDLCPU,
-            2 => ts::DLDeviceType_kDLGPU,
-            3 => ts::DLDeviceType_kDLCPUPinned,
-            4 => ts::DLDeviceType_kDLOpenCL,
-            7 => ts::DLDeviceType_kDLVulkan,
-            8 => ts::DLDeviceType_kDLMetal,
-            9 => ts::DLDeviceType_kDLVPI,
-            10 => ts::DLDeviceType_kDLROCM,
-            12 => ts::DLDeviceType_kDLExtDev,
+            1 => ffi::DLDeviceType_kDLCPU,
+            2 => ffi::DLDeviceType_kDLGPU,
+            3 => ffi::DLDeviceType_kDLCPUPinned,
+            4 => ffi::DLDeviceType_kDLOpenCL,
+            7 => ffi::DLDeviceType_kDLVulkan,
+            8 => ffi::DLDeviceType_kDLMetal,
+            9 => ffi::DLDeviceType_kDLVPI,
+            10 => ffi::DLDeviceType_kDLROCM,
+            12 => ffi::DLDeviceType_kDLExtDev,
             _ => panic!("device type not found!"),
         }
     }
 }
 
-impl From<ts::DLDeviceType> for TVMDeviceType {
-    fn from(device_type: ts::DLDeviceType) -> Self {
+impl From<ffi::DLDeviceType> for TVMDeviceType {
+    fn from(device_type: ffi::DLDeviceType) -> Self {
         match device_type {
-            ts::DLDeviceType_kDLCPU => TVMDeviceType(1),
-            ts::DLDeviceType_kDLGPU => TVMDeviceType(2),
-            ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
-            ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
-            ts::DLDeviceType_kDLVulkan => TVMDeviceType(7),
-            ts::DLDeviceType_kDLMetal => TVMDeviceType(8),
-            ts::DLDeviceType_kDLVPI => TVMDeviceType(9),
-            ts::DLDeviceType_kDLROCM => TVMDeviceType(10),
-            ts::DLDeviceType_kDLExtDev => TVMDeviceType(12),
+            ffi::DLDeviceType_kDLCPU => TVMDeviceType(1),
+            ffi::DLDeviceType_kDLGPU => TVMDeviceType(2),
+            ffi::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
+            ffi::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
+            ffi::DLDeviceType_kDLVulkan => TVMDeviceType(7),
+            ffi::DLDeviceType_kDLMetal => TVMDeviceType(8),
+            ffi::DLDeviceType_kDLVPI => TVMDeviceType(9),
+            ffi::DLDeviceType_kDLROCM => TVMDeviceType(10),
+            ffi::DLDeviceType_kDLExtDev => TVMDeviceType(12),
             _ => panic!("device type not found!"),
         }
     }
@@ -117,6 +125,18 @@ impl<'a> From<&'a str> for TVMDeviceType {
     }
 }
 
+impl<'a> From<&'a TVMDeviceType> for TVMArgValue<'a> {
+    fn from(dev_type: &'a TVMDeviceType) -> Self {
+        Self {
+            value: TVMValue {
+                v_int64: dev_type.0 as i64,
+            },
+            type_code: ffi::DLDataTypeCode_kDLInt as i64,
+            _lifetime: std::marker::PhantomData,
+        }
+    }
+}
+
 /// Represents the underlying device context. Default is cpu.
 ///
 /// ## Examples
@@ -138,15 +158,15 @@ pub struct TVMContext {
     /// Supported device types
     pub device_type: TVMDeviceType,
     /// Device id
-    pub device_id: usize,
+    pub device_id: i32,
 }
 
 impl TVMContext {
     /// Creates context from device type and id.
-    pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self {
+    pub fn new(device_type: TVMDeviceType, device_id: i32) -> Self {
         TVMContext {
-            device_type: device_type,
-            device_id: device_id,
+            device_type,
+            device_id,
         }
     }
 }
@@ -155,7 +175,7 @@ macro_rules! impl_ctxs {
     ($(($ctx:ident, $dldevt:expr));+) => {
         $(
             impl TVMContext {
-                pub fn $ctx(device_id: usize) -> Self {
+                pub fn $ctx(device_id: i32) -> Self {
                     Self::new(TVMDeviceType($dldevt), device_id)
                 }
             }
@@ -185,20 +205,20 @@ impl<'a> From<&'a str> for TVMContext {
 impl TVMContext {
     /// Checks whether the context exists or not.
     pub fn exist(&self) -> bool {
-        let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
-            .expect("API function always exists");
+        let func = function::Function::get("_GetDeviceAttr").expect("API function always exists");
         let dt = self.device_type.0 as usize;
         // `unwrap` is ok here because if there is any error,
         // if would occure inside `call_packed!`
-        let ret = call_packed!(func, &dt, &self.device_id, &0)
+        let ret: u64 = call_packed!(func, &dt, &self.device_id, &0)
             .unwrap()
-            .prim_value;
+            .try_into()
+            .unwrap();
         ret != 0
     }
 
     /// Synchronize the context stream.
-    pub fn sync(&self) -> Result<()> {
-        check_call!(ts::TVMSynchronize(
+    pub fn sync(&self) -> Result<(), Error> {
+        check_call!(ffi::TVMSynchronize(
             self.device_type.0 as i32,
             self.device_id as i32,
             ptr::null_mut() as *mut c_void
@@ -212,16 +232,17 @@ macro_rules! impl_device_attrs {
         $(
             impl TVMContext {
                 pub fn $attr_name(&self) -> usize {
-                    let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
+                    let func = function::Function::get("_GetDeviceAttr")
                         .expect("API function always exists");
                     let dt = self.device_type.0 as usize;
                     // `unwrap` is ok here because if there is any error,
                     // if would occur in function call.
-                    let ret = function::Builder::from(func)
-                        .args(&[dt, self.device_id, $attr_kind])
+                    function::Builder::from(func)
+                        .args(&[dt, self.device_id as usize, $attr_kind])
                         .invoke()
-                        .unwrap();
-                    ret.prim_value as usize
+                        .unwrap()
+                        .try_into()
+                        .unwrap()
                 }
             }
         )+
@@ -237,18 +258,18 @@ impl_device_attrs!((max_threads_per_block, 1);
                 (multi_processor_count, 7);
                 (max_thread_dimensions, 8));
 
-impl From<ts::DLContext> for TVMContext {
-    fn from(ctx: ts::DLContext) -> Self {
+impl From<ffi::DLContext> for TVMContext {
+    fn from(ctx: ffi::DLContext) -> Self {
         TVMContext {
             device_type: TVMDeviceType::from(ctx.device_type),
-            device_id: ctx.device_id as usize,
+            device_id: ctx.device_id,
         }
     }
 }
 
-impl From<TVMContext> for ts::DLContext {
+impl From<TVMContext> for ffi::DLContext {
     fn from(ctx: TVMContext) -> Self {
-        ts::DLContext {
+        ffi::DLContext {
             device_type: ctx.device_type.into(),
             device_id: ctx.device_id as i32,
         }
index a10f83c..96a70ca 100644 (file)
@@ -1,51 +1,26 @@
-//! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types.
+pub use failure::Error;
 
-use std::{ffi, option};
+#[derive(Debug, Fail)]
+#[fail(display = "Cannot convert from an empty array.")]
+pub struct EmptyArrayError;
 
-use crate::{common_errors, rust_ndarray};
-
-error_chain! {
-    errors {
-        EmptyArray {
-            description("cannot convert from an empty array")
-        }
-
-        NullHandle(name: String) {
-            description("null handle")
-            display("requested `{}` handle is null", name)
-        }
-
-        FunctionNotFound {
-            description("function not found")
-            display("function was not set in `function::Builder`")
-        }
-
-        TypeMismatch(expected: String, found: String) {
-            description("type mismatch!")
-            display("expected type `{}`, but found `{}`", expected, found)
-        }
-
-        MissingShapeError {
-            description("ndarray `shape()` returns `None`")
-            display("called `Option::unwrap()` on a `None` value")
-        }
-
-        AtMostOneReturn {
-            description("TVM functions accept at most one return value")
-        }
+#[derive(Debug, Fail)]
+#[fail(display = "Handle `{}` is null.", name)]
+pub struct NullHandleError {
+    pub name: String,
+}
 
-    }
+#[derive(Debug, Fail)]
+#[fail(display = "Function was not set in `function::Builder`")]
+pub struct FunctionNotFoundError;
 
-    foreign_links {
-        ShapeError(rust_ndarray::ShapeError);
-        NulError(ffi::NulError);
-        IntoStringError(ffi::IntoStringError);
-        CommonError(common_errors::Error);
-    }
+#[derive(Debug, Fail)]
+#[fail(display = "Expected type `{}` but found `{}`", expected, actual)]
+pub struct TypeMismatchError {
+    pub expected: String,
+    pub actual: String,
 }
 
-impl From<option::NoneError> for Error {
-    fn from(_err: option::NoneError) -> Self {
-        ErrorKind::MissingShapeError.into()
-    }
-}
+#[derive(Debug, Fail)]
+#[fail(display = "Missing NDArray shape.")]
+pub struct MissingShapeError;
index fa6bed1..f0fbcbe 100644 (file)
@@ -15,14 +15,20 @@ use std::{
     sync::Mutex,
 };
 
-use crate::{ts, ErrorKind, Module, Result, TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue};
+use failure::Error;
+
+use crate::{
+    errors,
+    ffi::{self, TVMValue},
+    Module, TVMArgValue, TVMRetValue,
+};
 
 lazy_static! {
     static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = {
         let mut out_size = 0 as c_int;
         let name = ptr::null_mut() as *mut c_char;
         let mut out_array = name as *mut _;
-        check_call!(ts::TVMFuncListGlobalNames(
+        check_call!(ffi::TVMFuncListGlobalNames(
             &mut out_size as *mut _,
             &mut out_array
         ));
@@ -37,17 +43,14 @@ lazy_static! {
 }
 
 /// Wrapper around TVM function handle which includes `is_global`
-/// indicating whether the function is global or not, `is_released`
-/// to hint dropping the function handle and `is_cloned` showing
+/// indicating whether the function is global or not, and `is_cloned` showing
 /// not to drop a cloned function from Rust side.
 /// The value of these fields can be accessed through their respective methods.
 #[derive(Debug, Hash)]
 pub struct Function {
-    pub(crate) handle: ts::TVMFunctionHandle,
+    pub(crate) handle: ffi::TVMFunctionHandle,
     // whether the registered function is global or not.
     is_global: bool,
-    // whether the function has been dropped from frontend or not.
-    is_released: bool,
     // whether the function has been cloned from frontend or not.
     is_cloned: bool,
 }
@@ -56,29 +59,30 @@ unsafe impl Send for Function {}
 unsafe impl Sync for Function {}
 
 impl Function {
-    pub(crate) fn new(handle: ts::TVMFunctionHandle, is_global: bool, is_released: bool) -> Self {
+    pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
         Function {
             handle: handle,
-            is_global: is_global,
-            is_released: is_released,
+            is_global: false,
             is_cloned: false,
         }
     }
 
     /// For a given function, it returns a function by name.
-    pub fn get<S: AsRef<str>>(name: S, is_global: bool) -> Option<&'static Function> {
+    pub fn get<S: AsRef<str>>(name: S) -> Option<&'static Function> {
         let mut globals = GLOBAL_FUNCTIONS.lock().unwrap();
         globals.get_mut(name.as_ref()).and_then(|maybe_func| {
             if maybe_func.is_none() {
                 let name = CString::new(name.as_ref()).unwrap();
-                let mut handle = ptr::null_mut() as ts::TVMFunctionHandle;
-                check_call!(ts::TVMFuncGetGlobal(
+                let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle;
+                check_call!(ffi::TVMFuncGetGlobal(
                     name.as_ptr() as *const c_char,
                     &mut handle as *mut _
                 ));
-                maybe_func.replace(Function::new(
-                    handle, is_global, false, /* is_released */
-                ));
+                maybe_func.replace(Function {
+                    handle: handle,
+                    is_global: true,
+                    is_cloned: false,
+                });
             }
             unsafe {
                 std::mem::transmute::<Option<&Function>, Option<&'static Function>>(
@@ -89,7 +93,7 @@ impl Function {
     }
 
     /// Returns the underlying TVM function handle.
-    pub fn handle(&self) -> ts::TVMFunctionHandle {
+    pub fn handle(&self) -> ffi::TVMFunctionHandle {
         self.handle
     }
 
@@ -98,12 +102,6 @@ impl Function {
         self.is_global
     }
 
-    /// Returns `true` if the underlying TVM function has been released
-    /// from the frontend and `false` otherwise.
-    pub fn is_released(&self) -> bool {
-        self.is_released
-    }
-
     /// Returns `true` if the underlying TVM function has been cloned
     /// from the frontend and `false` otherwise.
     pub fn is_cloned(&self) -> bool {
@@ -113,24 +111,18 @@ impl Function {
 
 impl Clone for Function {
     fn clone(&self) -> Function {
-        if !self.is_released && !self.is_cloned {
-            Self {
-                handle: self.handle,
-                is_global: self.is_global,
-                is_released: self.is_released,
-                is_cloned: true,
-            }
-        } else {
-            Function::new(self.handle, self.is_global, self.is_released)
+        Self {
+            handle: self.handle,
+            is_global: self.is_global,
+            is_cloned: true,
         }
     }
 }
 
 impl Drop for Function {
     fn drop(&mut self) {
-        if !self.is_released && !self.is_global && !self.is_cloned {
-            check_call!(ts::TVMFuncFree(self.handle));
-            self.is_released = true;
+        if !self.is_global && !self.is_cloned {
+            check_call!(ffi::TVMFuncFree(self.handle));
         }
     }
 }
@@ -138,17 +130,17 @@ impl Drop for Function {
 /// Function builder in order to create and call functions.
 ///
 /// *Note:* Currently TVM functions accept *at most* one return value.
-#[derive(Debug, Clone, Default)]
+#[derive(Default)]
 pub struct Builder<'a, 'm> {
     pub func: Option<&'m Function>,
-    pub arg_buf: Option<Box<[TVMArgValue<'a>]>>,
+    pub arg_buf: Vec<TVMArgValue<'a>>,
     pub ret_buf: Option<TVMRetValue>,
 }
 
 impl<'a, 'm> Builder<'a, 'm> {
     pub fn new(
         func: Option<&'m Function>,
-        arg_buf: Option<Box<[TVMArgValue<'a>]>>,
+        arg_buf: Vec<TVMArgValue<'a>>,
         ret_buf: Option<TVMRetValue>,
     ) -> Self {
         Self {
@@ -158,123 +150,66 @@ impl<'a, 'm> Builder<'a, 'm> {
         }
     }
 
-    pub fn get_function(&mut self, name: &'m str, is_global: bool) -> &mut Self {
-        self.func = Function::get(name, is_global);
+    pub fn get_function(&mut self, name: &'m str) -> &mut Self {
+        self.func = Function::get(name);
         self
     }
 
     /// Pushes a [`TVMArgValue`] into the function argument buffer.
-    pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self
+    pub fn arg<T: 'a>(&mut self, arg: &'a T) -> &mut Self
     where
-        TVMValue: From<&'b T>,
-        TVMTypeCode: From<&'b T>,
+        TVMArgValue<'a>: From<&'a T>,
     {
-        let tvm_arg = TVMArgValue::from(arg);
-        if self.arg_buf.is_none() {
-            self.arg_buf = Some(Box::new([tvm_arg]));
-        } else {
-            let new_arg_buf = self.arg_buf.take().map(|bbuf| {
-                let mut new_arg_buf = Vec::from(bbuf);
-                new_arg_buf.push(tvm_arg);
-                let new_len = new_arg_buf.len();
-                new_arg_buf.truncate(new_len);
-                new_arg_buf.into_boxed_slice()
-            });
-            self.arg_buf = new_arg_buf;
-        }
+        self.arg_buf.push(arg.into());
         self
     }
 
     /// Pushes multiple [`TVMArgValue`]s into the function argument buffer.
-    pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self
+    pub fn args<T: 'a, I>(&mut self, args: I) -> &mut Self
     where
-        I: IntoIterator<Item = &'b T>,
-        TVMValue: From<&'b T>,
-        TVMTypeCode: From<&'b T>,
+        I: IntoIterator<Item = &'a T>,
+        TVMArgValue<'a>: From<&'a T>,
     {
-        for arg in args {
+        args.into_iter().for_each(|arg| {
             self.arg(&arg);
-        }
+        });
         self
     }
 
     /// Sets an output for a function that requirs a mutable output to be provided.
     /// See the `basics` in tests for an example.
-    pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> Result<&mut Self>
+    pub fn set_output<T>(&mut self, ret: T) -> &mut Self
     where
-        TVMValue: From<&'b T>,
-        TVMTypeCode: From<&'b T>,
+        TVMRetValue: From<T>,
     {
-        if self.ret_buf.is_none() {
-            let tvm_ret =
-                unsafe { TVMRetValue::from_tvm_value(TVMValue::from(arg), TVMTypeCode::from(arg)) };
-            self.ret_buf = Some(tvm_ret);
-        } else {
-            bail!(ErrorKind::AtMostOneReturn)
-        }
-        Ok(self)
+        self.ret_buf = Some(ret.into());
+        self
     }
 
     /// Calls the function that created from `Builder`.
-    pub fn invoke(&mut self) -> Result<TVMRetValue> {
-        self.clone()(())
-    }
-}
-
-impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> {
-    type Output = Result<TVMRetValue>;
-    extern "rust-call" fn call_once(self, _: ((),)) -> Self::Output {
-        if self.func.is_none() {
-            bail!("{}", ErrorKind::FunctionNotFound);
-        }
-
-        let mut ret_val = unsafe { mem::uninitialized::<ts::TVMValue>() };
-        let mut ret_type_code = 0 as c_int;
-        if self.arg_buf.is_some() {
-            let arg_buf = self.arg_buf?;
-            let mut num_args = arg_buf.len();
-            let mut values = arg_buf
-                .iter()
-                .map(|tav| tav.value.inner)
-                .collect::<Vec<ts::TVMValue>>();
-            let mut tcodes = arg_buf
-                .iter()
-                .map(|tav| tav.type_code as c_int)
-                .collect::<Vec<_>>();
-
-            if self.ret_buf.is_some() {
-                num_args = num_args + 1;
-                let ret_buf = self.ret_buf?;
-                let (ret_val, ret_type_code) = TVMRetValue::into_tvm_value(ret_buf);
-                values.append(&mut vec![ret_val.inner]);
-                tcodes.append(&mut vec![ret_type_code as c_int]);
-            }
-
-            values.truncate(num_args);
-            tcodes.truncate(num_args);
-            check_call!(ts::TVMFuncCall(
-                self.func?.handle,
-                values.as_mut_ptr(),
-                tcodes.as_mut_ptr(),
-                num_args as c_int,
-                &mut ret_val as *mut _,
-                &mut ret_type_code as *mut _
-            ));
-        } else {
-            check_call!(ts::TVMFuncCall(
-                self.func?.handle,
-                ptr::null_mut(),
-                ptr::null_mut(),
-                0 as c_int,
-                &mut ret_val as *mut _,
-                &mut ret_type_code as *mut _
-            ));
-        }
+    pub fn invoke(&mut self) -> Result<TVMRetValue, Error> {
+        #![allow(unused_unsafe)]
+        ensure!(self.func.is_some(), errors::FunctionNotFoundError);
+
+        let num_args = self.arg_buf.len();
+        let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) = self
+            .arg_buf
+            .iter()
+            .map(|tvm_arg| (tvm_arg.value, tvm_arg.type_code as ffi::TVMTypeCode))
+            .unzip();
+
+        let mut ret_val = unsafe { std::mem::uninitialized::<TVMValue>() };
+        let mut ret_type_code = 0;
+        check_call!(ffi::TVMFuncCall(
+            self.func.ok_or(errors::FunctionNotFoundError)?.handle,
+            values.as_mut_ptr(),
+            type_codes.as_mut_ptr() as *mut i32,
+            num_args as c_int,
+            &mut ret_val as *mut _,
+            &mut ret_type_code as *mut _
+        ));
 
-        let ret = unsafe {
-            TVMRetValue::from_tvm_value(TVMValue::new(ret_val), (ret_type_code as i64).into())
-        };
-        Ok(ret)
+        Ok(unsafe { TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64) })
     }
 }
 
@@ -282,46 +217,44 @@ impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> {
 /// TVM functions.
 impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> {
     fn from(func: &'m Function) -> Self {
-        Builder::new(Some(func), None, None)
+        Builder::new(Some(func), Vec::new(), None)
     }
 }
 
 /// Converts a mutable reference of a [`Module`] to [`Builder`].
 impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> {
     fn from(module: &'m mut Module) -> Self {
-        Builder::new(module.entry(), None, None)
+        Builder::new(module.entry(), Vec::new(), None)
     }
 }
 
 unsafe extern "C" fn tvm_callback(
-    args: *mut ts::TVMValue,
+    args: *mut ffi::TVMValue,
     type_codes: *mut c_int,
     num_args: c_int,
-    ret: ts::TVMRetValueHandle,
+    ret: ffi::TVMRetValueHandle,
     fhandle: *mut c_void,
 ) -> c_int {
     // turning off the incorrect linter complaints
-    #![allow(unused_assignments)]
+    #![allow(unused_assignments, unused_unsafe)]
     let len = num_args as usize;
     let args_list = slice::from_raw_parts_mut(args, len);
     let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
     let mut local_args: Vec<TVMArgValue> = Vec::new();
-    let mut value = mem::uninitialized::<ts::TVMValue>();
+    let mut value = mem::uninitialized::<ffi::TVMValue>();
     let mut tcode = mem::uninitialized::<c_int>();
-    let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
+    let rust_fn =
+        mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
     for i in 0..len {
         value = args_list[i];
         tcode = type_codes_list[i];
-        if tcode == ts::TVMTypeCode_kNodeHandle as c_int
-            || tcode == ts::TVMTypeCode_kFuncHandle as c_int
-            || tcode == ts::TVMTypeCode_kModuleHandle as c_int
+        if tcode == ffi::TVMTypeCode_kNodeHandle as c_int
+            || tcode == ffi::TVMTypeCode_kFuncHandle as c_int
+            || tcode == ffi::TVMTypeCode_kModuleHandle as c_int
         {
-            check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode));
+            check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode));
         }
-        local_args.push(TVMArgValue::new(
-            TVMValue::new(value),
-            (tcode as i64).into(),
-        ));
+        local_args.push(TVMArgValue::new(value.into(), (tcode as i64).into()));
     }
 
     let rv = match rust_fn(local_args.as_slice()) {
@@ -332,10 +265,9 @@ unsafe extern "C" fn tvm_callback(
         }
     };
 
-    let (ret_val, ret_tcode) = TVMRetValue::into_tvm_value(rv);
-    let mut ret_val = ret_val.inner;
+    let (mut ret_val, ret_tcode) = rv.into_tvm_value();
     let mut ret_type_code = ret_tcode as c_int;
-    check_call!(ts::TVMCFuncSetReturn(
+    check_call!(ffi::TVMCFuncSetReturn(
         ret,
         &mut ret_val as *mut _,
         &mut ret_type_code as *mut _,
@@ -345,24 +277,25 @@ unsafe extern "C" fn tvm_callback(
 }
 
 unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) {
-    let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
+    let rust_fn =
+        mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
     mem::drop(rust_fn);
 }
 
-fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function {
-    let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
-    let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue>;
-    check_call!(ts::TVMFuncCreateFromCFunc(
+fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> Function {
+    let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
+    let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>;
+    check_call!(ffi::TVMFuncCreateFromCFunc(
         Some(tvm_callback),
         resource_handle as *mut c_void,
         Some(tvm_callback_finalizer),
         &mut fhandle as *mut _
     ));
-    Function::new(fhandle, false, false)
+    Function::new(fhandle)
 }
 
 /// Registers a Rust function with signature
-/// `fn(&[TVMArgValue]) -> Result<TVMRetValue>`
+/// `fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>`
 /// as a **global TVM packed function** from frontend to TVM backend.
 ///
 /// Use [`register_global_func`] if overriding an existing global TVM function
@@ -373,7 +306,7 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function
 /// ```
 /// use std::convert::TryInto;
 ///
-/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
 ///     let mut ret = 0i64;
 ///     for arg in args.iter() {
 ///         let arg: i64 = arg.try_into()?;
@@ -391,18 +324,17 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function
 /// assert_eq!(ret, 60);
 /// ```
 pub fn register<S: AsRef<str>>(
-    f: fn(&[TVMArgValue]) -> Result<TVMRetValue>,
+    f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>,
     name: S,
     override_: bool,
-) -> Result<()> {
+) -> Result<(), Error> {
     let func = convert_to_tvm_func(f);
     let name = CString::new(name.as_ref())?;
-    check_call!(ts::TVMFuncRegisterGlobal(
-        name.as_ref().as_ptr() as *const c_char,
+    check_call!(ffi::TVMFuncRegisterGlobal(
+        name.into_raw(),
         func.handle(),
         override_ as c_int
     ));
-    mem::forget(name);
     Ok(())
 }
 
@@ -416,7 +348,7 @@ pub fn register<S: AsRef<str>>(
 /// use std::convert::TryInto;
 ///
 /// register_global_func! {
-///     fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+///     fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
 ///         let mut ret = 0f64;
 ///         for arg in args.iter() {
 ///             let arg: f64 = arg.try_into()?;
@@ -437,12 +369,12 @@ pub fn register<S: AsRef<str>>(
 macro_rules! register_global_func {
     {
         $(#[$m:meta])*
-        fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue> {
+        fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue, Error> {
             $($code:tt)*
         }
     } => {{
         $(#[$m])*
-        fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue> {
+        fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
             $($code)*
         }
 
@@ -496,17 +428,17 @@ mod tests {
 
     #[test]
     fn get_fn() {
-        assert!(Function::get(CANARY, true).is_some());
-        assert!(Function::get("does not exists!", false).is_none());
+        assert!(Function::get(CANARY).is_some());
+        assert!(Function::get("does not exists!").is_none());
     }
 
     #[test]
     fn provide_args() {
+        let str_arg = CString::new("test").unwrap();
         let mut func = Builder::default();
-        func.get_function("tvm.graph_runtime.remote_create", true)
+        func.get_function("tvm.graph_runtime.remote_create")
             .args(&[10, 20])
-            .arg(&"test".to_owned());
-        assert!(func.arg_buf.is_some());
-        assert_eq!(func.arg_buf.take().map(|bv| Vec::from(bv).len()), Some(3));
+            .arg(&str_arg);
+        assert_eq!(func.arg_buf.len(), 3);
     }
 }
index 6e15e4f..a773b27 100644 (file)
 //!
 //! Checkout the `examples` repository for more details.
 
-#![crate_name = "tvm_frontend"]
-#![recursion_limit = "1024"]
-#![allow(non_camel_case_types, unused_unsafe)]
-#![feature(
-    try_from,
-    try_trait,
-    fn_traits,
-    unboxed_closures,
-    box_syntax,
-    option_replace
-)]
+#![feature(box_syntax)]
 
 #[macro_use]
-extern crate error_chain;
-extern crate tvm_common as common;
+extern crate failure;
 #[macro_use]
 extern crate lazy_static;
 extern crate ndarray as rust_ndarray;
 extern crate num_traits;
+extern crate tvm_common;
 
 use std::{
     ffi::{CStr, CString},
     str,
 };
 
-use crate::common::ffi::ts;
+use failure::Error;
+
+pub use crate::{
+    bytearray::TVMByteArray,
+    context::{TVMContext, TVMDeviceType},
+    errors::*,
+    function::Function,
+    module::Module,
+    ndarray::NDArray,
+    tvm_common::{
+        errors as common_errors,
+        ffi::{self, TVMType},
+        packed_func::{TVMArgValue, TVMRetValue},
+    },
+};
 
 // Macro to check the return call to TVM runtime shared library.
 macro_rules! check_call {
@@ -50,7 +54,7 @@ macro_rules! check_call {
 /// Gets the last error message.
 pub fn get_last_error() -> &'static str {
     unsafe {
-        match CStr::from_ptr(ts::TVMGetLastError()).to_str() {
+        match CStr::from_ptr(ffi::TVMGetLastError()).to_str() {
             Ok(s) => s,
             Err(_) => "Invalid UTF-8 message",
         }
@@ -60,7 +64,7 @@ pub fn get_last_error() -> &'static str {
 pub(crate) fn set_last_error(err: &Error) {
     let c_string = CString::new(err.to_string()).unwrap();
     unsafe {
-        ts::TVMAPISetLastError(c_string.as_ptr());
+        ffi::TVMAPISetLastError(c_string.as_ptr());
     }
 }
 
@@ -71,27 +75,11 @@ pub mod context;
 pub mod errors;
 pub mod module;
 pub mod ndarray;
-pub mod ty;
 pub mod value;
 
-pub use crate::{
-    bytearray::TVMByteArray,
-    common::{
-        errors as common_errors,
-        ty::TVMTypeCode,
-        value::{TVMArgValue, TVMRetValue, TVMValue},
-    },
-    context::{TVMContext, TVMDeviceType},
-    errors::*,
-    function::Function,
-    module::Module,
-    ndarray::NDArray,
-    ty::TVMType,
-};
-
 /// Outputs the current TVM version.
 pub fn version() -> &'static str {
-    match str::from_utf8(ts::TVM_VERSION) {
+    match str::from_utf8(ffi::TVM_VERSION) {
         Ok(s) => s,
         Err(_) => "Invalid UTF-8 string",
     }
@@ -108,8 +96,8 @@ mod tests {
 
     #[test]
     fn set_error() {
-        let err = ErrorKind::EmptyArray;
+        let err = errors::EmptyArrayError;
         set_last_error(&err.into());
-        assert_eq!(get_last_error().trim(), ErrorKind::EmptyArray.to_string());
+        assert_eq!(get_last_error().trim(), errors::EmptyArrayError.to_string());
     }
 }
index c12d9d4..9c27387 100644 (file)
@@ -8,30 +8,27 @@ use std::{
     ptr,
 };
 
-use crate::ts;
+use failure::Error;
+use tvm_common::ffi;
 
-use crate::{function::Function, ErrorKind, Result};
+use crate::{errors, function::Function};
 
 const ENTRY_FUNC: &'static str = "__tvm_main__";
 
 /// Wrapper around TVM module handle which contains an entry function.
 /// The entry function can be applied to an imported module through [`entry_func`].
-/// Also [`is_released`] shows whether the module is dropped or not.
 ///
 /// [`entry_func`]:struct.Module.html#method.entry_func
-/// [`is_released`]:struct.Module.html#method.is_released
 #[derive(Debug, Clone)]
 pub struct Module {
-    pub(crate) handle: ts::TVMModuleHandle,
-    is_released: bool,
+    pub(crate) handle: ffi::TVMModuleHandle,
     entry_func: Option<Function>,
 }
 
 impl Module {
-    pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self {
+    pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self {
         Self {
             handle,
-            is_released,
             entry_func: None,
         }
     }
@@ -44,62 +41,67 @@ impl Module {
     }
 
     /// Gets a function by name from a registered module.
-    pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function> {
+    pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function, Error> {
         let name = CString::new(name)?;
-        let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
-        check_call!(ts::TVMModGetFunction(
+        let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
+        check_call!(ffi::TVMModGetFunction(
             self.handle,
             name.as_ptr() as *const c_char,
             query_import as c_int,
             &mut fhandle as *mut _
         ));
-        if fhandle.is_null() {
-            bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?)))
-        } else {
-            Ok(Function::new(fhandle, false, false))
-        }
+        ensure!(
+            !fhandle.is_null(),
+            errors::NullHandleError {
+                name: format!("{}", name.into_string()?)
+            }
+        );
+        Ok(Function::new(fhandle))
     }
 
     /// Imports a dependent module such as `.ptx` for gpu.
     pub fn import_module(&self, dependent_module: Module) {
-        check_call!(ts::TVMModImport(self.handle, dependent_module.handle))
+        check_call!(ffi::TVMModImport(self.handle, dependent_module.handle))
     }
 
     /// Loads a module shared library from path.
-    pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module> {
-        let ext = path.as_ref().extension()?.to_str()?;
-        let func = Function::get("module._LoadFromFile", true /* is_global */)
-            .expect("API function always exists");
-        let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?;
+    pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module, Error> {
+        let ext = CString::new(
+            path.as_ref()
+                .extension()
+                .unwrap_or(std::ffi::OsStr::new(""))
+                .to_str()
+                .ok_or_else(|| {
+                    format_err!("Bad module load path: `{}`.", path.as_ref().display())
+                })?,
+        )?;
+        let func = Function::get("module._LoadFromFile").expect("API function always exists");
+        let cpath =
+            CString::new(path.as_ref().to_str().ok_or_else(|| {
+                format_err!("Bad module load path: `{}`.", path.as_ref().display())
+            })?)?;
+        let ret: Module = call_packed!(func, &cpath, &ext)?.try_into()?;
         Ok(ret)
     }
 
     /// Checks if a target device is enabled for a module.
     pub fn enabled(&self, target: &str) -> bool {
-        let func = Function::get("module._Enabled", true /* is_global */)
-            .expect("API function always exists");
+        let func = Function::get("module._Enabled").expect("API function always exists");
         // `unwrap` is safe here because if there is any error during the
         // function call, it would occur in `call_packed!`.
-        let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap();
+        let tgt = CString::new(target).unwrap();
+        let ret: i64 = call_packed!(func, &tgt).unwrap().try_into().unwrap();
         ret != 0
     }
 
     /// Returns the underlying module handle.
-    pub fn handle(&self) -> ts::TVMModuleHandle {
+    pub fn handle(&self) -> ffi::TVMModuleHandle {
         self.handle
     }
-
-    /// Returns true if the underlying module has been dropped and false otherwise.
-    pub fn is_released(&self) -> bool {
-        self.is_released
-    }
 }
 
 impl Drop for Module {
     fn drop(&mut self) {
-        if !self.is_released {
-            check_call!(ts::TVMModFree(self.handle));
-            self.is_released = true;
-        }
+        check_call!(ffi::TVMModFree(self.handle));
     }
 }
index 44dfcca..1939c92 100644 (file)
 //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
 //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
 
-use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice};
+use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
 
-use crate::rust_ndarray::{Array, ArrayD};
+use failure::Error;
 use num_traits::Num;
+use rust_ndarray::{Array, ArrayD};
+use tvm_common::{ffi, TVMType};
 
-use crate::ts;
-
-use crate::{Error, ErrorKind, Result, TVMByteArray, TVMContext, TVMType};
+use crate::{errors, TVMByteArray, TVMContext};
 
 /// See the [`module-level documentation`](../ndarray/index.html) for more details.
 ///
 /// Wrapper around TVM array handle.
 #[derive(Debug)]
 pub struct NDArray {
-    pub(crate) handle: ts::TVMArrayHandle,
+    pub(crate) handle: ffi::TVMArrayHandle,
     is_view: bool,
 }
 
 impl NDArray {
-    pub(crate) fn new(handle: ts::TVMArrayHandle, is_view: bool) -> Self {
+    pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
         NDArray {
             handle: handle,
-            is_view: is_view,
+            is_view: true,
         }
     }
 
     /// Returns the underlying array handle.
-    pub fn handle(&self) -> ts::TVMArrayHandle {
+    pub fn handle(&self) -> ffi::TVMArrayHandle {
         self.handle
     }
 
@@ -99,12 +99,13 @@ impl NDArray {
     }
 
     /// Shows whether the underlying ndarray is contiguous in memory or not.
-    pub fn is_contiguous(&self) -> Result<bool> {
+    pub fn is_contiguous(&self) -> Result<bool, Error> {
         Ok(match self.strides() {
             None => true,
             Some(strides) => {
-                // MissingShapeError in case shape is not determined
-                self.shape()?
+                // errors::MissingShapeError in case shape is not determined
+                self.shape()
+                    .ok_or(errors::MissingShapeError)?
                     .iter()
                     .zip(strides)
                     .rfold(
@@ -138,14 +139,16 @@ impl NDArray {
     /// assert_eq!(ndarray.shape(), Some(shape));
     /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
     /// ```
-    pub fn to_vec<T>(&self) -> Result<Vec<T>> {
-        if self.shape().is_none() {
-            bail!("{}", ErrorKind::EmptyArray);
-        }
-        let earr = NDArray::empty(self.shape()?, TVMContext::cpu(0), self.dtype());
+    pub fn to_vec<T>(&self) -> Result<Vec<T>, Error> {
+        ensure!(self.shape().is_some(), errors::EmptyArrayError);
+        let earr = NDArray::empty(
+            self.shape().ok_or(errors::MissingShapeError)?,
+            TVMContext::cpu(0),
+            self.dtype(),
+        );
         let target = self.copy_to_ndarray(earr)?;
         let arr = unsafe { *(target.handle) };
-        let sz = self.size()? as usize;
+        let sz = self.size().ok_or(errors::MissingShapeError)?;
         let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
         unsafe {
             v.as_mut_ptr()
@@ -156,7 +159,7 @@ impl NDArray {
     }
 
     /// Converts the NDArray to [`TVMByteArray`].
-    pub fn to_bytearray(&self) -> Result<TVMByteArray> {
+    pub fn to_bytearray(&self) -> Result<TVMByteArray, Error> {
         let v = self.to_vec::<u8>()?;
         Ok(TVMByteArray::from(&v))
     }
@@ -176,7 +179,7 @@ impl NDArray {
     /// *Note*: if something goes wrong during the copy, it will panic
     /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
     pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
-        check_call!(ts::TVMArrayCopyFromBytes(
+        check_call!(ffi::TVMArrayCopyFromBytes(
             self.handle,
             data.as_ptr() as *mut _,
             data.len() * mem::size_of::<T>()
@@ -184,27 +187,31 @@ impl NDArray {
     }
 
     /// Copies the NDArray to another target NDArray.
-    pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray> {
+    pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, Error> {
         if self.dtype() != target.dtype() {
             bail!(
                 "{}",
-                ErrorKind::TypeMismatch(
-                    format!("{}", self.dtype().to_string()),
-                    format!("{}", target.dtype().to_string()),
-                )
+                errors::TypeMismatchError {
+                    expected: format!("{}", self.dtype().to_string()),
+                    actual: format!("{}", target.dtype().to_string()),
+                }
             );
         }
-        check_call!(ts::TVMArrayCopyFromTo(
+        check_call!(ffi::TVMArrayCopyFromTo(
             self.handle,
             target.handle,
-            ptr::null_mut() as ts::TVMStreamHandle
+            ptr::null_mut() as ffi::TVMStreamHandle
         ));
         Ok(target)
     }
 
     /// Copies the NDArray to a target context.
-    pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray> {
-        let tmp = NDArray::empty(self.shape()?, target.clone(), self.dtype());
+    pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray, Error> {
+        let tmp = NDArray::empty(
+            self.shape().ok_or(errors::MissingShapeError)?,
+            target.clone(),
+            self.dtype(),
+        );
         let copy = self.copy_to_ndarray(tmp)?;
         Ok(copy)
     }
@@ -214,28 +221,34 @@ impl NDArray {
         rnd: &ArrayD<T>,
         ctx: TVMContext,
         dtype: TVMType,
-    ) -> Result<Self> {
+    ) -> Result<Self, Error> {
         let mut shape = rnd.shape().to_vec();
         let mut nd = NDArray::empty(&mut shape, ctx, dtype);
         let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T));
-        nd.copy_from_buffer(buf.as_slice_mut()?);
+        nd.copy_from_buffer(
+            buf.as_slice_mut()
+                .expect("Array from iter must be contiguous."),
+        );
         Ok(nd)
     }
 
     /// Allocates and creates an empty NDArray given the shape, context and dtype.
     pub fn empty(shape: &[usize], ctx: TVMContext, dtype: TVMType) -> NDArray {
-        let mut handle = ptr::null_mut() as ts::TVMArrayHandle;
-        check_call!(ts::TVMArrayAlloc(
+        let mut handle = ptr::null_mut() as ffi::TVMArrayHandle;
+        check_call!(ffi::TVMArrayAlloc(
             shape.as_ptr() as *const i64,
             shape.len() as c_int,
-            dtype.inner.code as c_int,
-            dtype.inner.bits as c_int,
-            dtype.inner.lanes as c_int,
+            dtype.code as c_int,
+            dtype.bits as c_int,
+            dtype.lanes as c_int,
             ctx.device_type.0 as c_int,
             ctx.device_id as c_int,
             &mut handle as *mut _,
         ));
-        NDArray::new(handle, false)
+        NDArray {
+            handle,
+            is_view: false,
+        }
     }
 }
 
@@ -243,23 +256,25 @@ macro_rules! impl_from_ndarray_rustndarray {
     ($type:ty, $type_name:tt) => {
         impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> {
             type Error = Error;
-            fn try_from(nd: &NDArray) -> Result<ArrayD<$type>> {
-                if nd.shape().is_none() {
-                    bail!("{}", ErrorKind::EmptyArray);
-                }
-                assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch");
-                Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?)
+            fn try_from(nd: &NDArray) -> Result<ArrayD<$type>, Self::Error> {
+                ensure!(nd.shape().is_some(), errors::MissingShapeError);
+                assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch");
+                Ok(Array::from_shape_vec(
+                    &*nd.shape().ok_or(errors::MissingShapeError)?,
+                    nd.to_vec::<$type>()?,
+                )?)
             }
         }
 
         impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> {
             type Error = Error;
-            fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>> {
-                if nd.shape().is_none() {
-                    bail!("{}", ErrorKind::EmptyArray);
-                }
-                assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch");
-                Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?)
+            fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>, Self::Error> {
+                ensure!(nd.shape().is_some(), errors::MissingShapeError);
+                assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch");
+                Ok(Array::from_shape_vec(
+                    &*nd.shape().ok_or(errors::MissingShapeError)?,
+                    nd.to_vec::<$type>()?,
+                )?)
             }
         }
     };
@@ -272,7 +287,7 @@ impl_from_ndarray_rustndarray!(f32, "float");
 impl Drop for NDArray {
     fn drop(&mut self) {
         if !self.is_view {
-            check_call!(ts::TVMArrayFree(self.handle));
+            check_call!(ffi::TVMArrayFree(self.handle));
         }
     }
 }
@@ -306,7 +321,7 @@ mod tests {
     fn basics() {
         let shape = &mut [1, 2, 3];
         let ctx = TVMContext::cpu(0);
-        let ndarray = NDArray::empty(shape, ctx, TVMType::from("int32"));
+        let ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap());
         assert_eq!(ndarray.shape().unwrap(), shape);
         assert_eq!(
             ndarray.size().unwrap(),
@@ -322,7 +337,7 @@ mod tests {
         let shape = &mut [4];
         let mut data = vec![1i32, 2, 3, 4];
         let ctx = TVMContext::cpu(0);
-        let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("int32"));
+        let mut ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap());
         assert!(ndarray.to_vec::<i32>().is_ok());
         ndarray.copy_from_buffer(&mut data);
         assert_eq!(ndarray.shape().unwrap(), shape);
@@ -331,7 +346,11 @@ mod tests {
         assert!(ndarray.is_contiguous().is_ok());
         assert_eq!(ndarray.byte_offset(), 0);
         let mut shape = vec![4];
-        let e = NDArray::empty(&mut shape, TVMContext::cpu(0), TVMType::from("int32"));
+        let e = NDArray::empty(
+            &mut shape,
+            TVMContext::cpu(0),
+            TVMType::from_str("int32").unwrap(),
+        );
         let nd = ndarray.copy_to_ndarray(e);
         assert!(nd.is_ok());
         assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
@@ -343,9 +362,13 @@ mod tests {
         let mut shape = vec![4];
         let mut data = vec![1f32, 2., 3., 4.];
         let ctx = TVMContext::cpu(0);
-        let mut nd_float = NDArray::empty(&mut shape, ctx.clone(), TVMType::from("float32"));
+        let mut nd_float = NDArray::empty(
+            &mut shape,
+            ctx.clone(),
+            TVMType::from_str("float32").unwrap(),
+        );
         nd_float.copy_from_buffer(&mut data);
-        let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from("int32"));
+        let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from_str("int32").unwrap());
         nd_float.copy_to_ndarray(empty_int).unwrap();
     }
 
@@ -354,8 +377,12 @@ mod tests {
         let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
             .unwrap()
             .into_dyn();
-        let nd =
-            NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
+        let nd = NDArray::from_rust_ndarray(
+            &a,
+            TVMContext::cpu(0),
+            TVMType::from_str("float32").unwrap(),
+        )
+        .unwrap();
         assert_eq!(nd.shape().unwrap(), &mut [2, 2]);
         let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
         assert!(rnd.all_close(&a, 1e-8f32));
diff --git a/rust/frontend/src/ty.rs b/rust/frontend/src/ty.rs
deleted file mode 100644 (file)
index 7e912a5..0000000
+++ /dev/null
@@ -1,150 +0,0 @@
-//! This module implements the required conversions from Rust types to TVM types.
-//!
-//! In TVM frontend only conversions from Rust's 32-bits (POD) numeric types (i32, u32, f32)
-//! and 64-bits pointers are supported.
-
-use std::{
-    fmt::{self, Display, Formatter},
-    ops::{Deref, DerefMut},
-};
-
-use crate::ts;
-
-use crate::{Function, Module, NDArray, TVMByteArray, TVMContext, TVMDeviceType, TVMTypeCode};
-
-macro_rules! impl_prim_type {
-    ($type:ty, $variant:ident) => {
-        impl From<$type> for TVMTypeCode {
-            fn from(_arg: $type) -> Self {
-                TVMTypeCode::$variant
-            }
-        }
-
-        impl<'a> From<&'a $type> for TVMTypeCode {
-            fn from(_arg: &$type) -> Self {
-                TVMTypeCode::$variant
-            }
-        }
-
-        impl<'a> From<&'a mut $type> for TVMTypeCode {
-            fn from(_arg: &mut $type) -> Self {
-                TVMTypeCode::$variant
-            }
-        }
-    };
-}
-
-impl_prim_type!(TVMDeviceType, kDLInt);
-impl_prim_type!(TVMContext, kTVMContext);
-impl_prim_type!(TVMType, kTVMType);
-impl_prim_type!(Function, kFuncHandle);
-impl_prim_type!(Module, kModuleHandle);
-impl_prim_type!(NDArray, kArrayHandle);
-impl_prim_type!(TVMByteArray, kBytes);
-
-/// See the [module-level documentation](../ty/index.html) for more details.
-///
-/// Wrapper around underlying TVMType
-#[derive(Debug, Copy, Clone, PartialEq, Eq)]
-pub struct TVMType {
-    // inner fields are (code: u8, bits: u8, lanes: u16)
-    pub inner: ts::TVMType,
-}
-
-impl TVMType {
-    pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
-        TVMType {
-            inner: ts::TVMType {
-                code: type_code,
-                bits: bits,
-                lanes: lanes,
-            },
-        }
-    }
-}
-
-/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
-/// such as "int32", "float32" or with lane "float32x1".
-impl<'a> From<&'a str> for TVMType {
-    fn from(type_str: &'a str) -> Self {
-        if type_str == "bool" {
-            return TVMType::new(1, 1, 1);
-        }
-
-        let mut type_lanes = type_str.split("x");
-        let typ = type_lanes.next().expect("Missing dtype");
-        let lanes = type_lanes
-            .next()
-            .map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l)))
-            .unwrap_or(1);
-        let (type_name, bits) = match typ.find(char::is_numeric) {
-            Some(idx) => {
-                let (name, bits_str) = typ.split_at(idx);
-                (
-                    name,
-                    u8::from_str_radix(bits_str, 10)
-                        .expect(&format!("Bad dtype bits: {}", bits_str)),
-                )
-            }
-            None => (typ, 32),
-        };
-
-        let type_code = match type_name {
-            "int" => 0,
-            "uint" => 1,
-            "float" => 2,
-            "handle" => 3,
-            _ => unimplemented!(),
-        };
-
-        TVMType::new(type_code, bits, lanes)
-    }
-}
-
-impl Display for TVMType {
-    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
-        let ts::TVMType { code, bits, lanes } = self.inner;
-        if bits == 1 && lanes == 1 {
-            return write!(f, "bool");
-        }
-        let mut tcode_str = match code {
-            0 => "int",
-            1 => "uint",
-            2 => "float",
-            4 => "handle",
-            _ => "Unknown",
-        }
-        .to_string();
-
-        tcode_str += &bits.to_string();
-        if lanes > 1 {
-            tcode_str += &format!("x{}", lanes.to_string());
-        }
-        f.write_str(&tcode_str)
-    }
-}
-
-impl From<TVMType> for ts::DLDataType {
-    fn from(dtype: TVMType) -> Self {
-        dtype.inner
-    }
-}
-
-impl From<ts::DLDataType> for TVMType {
-    fn from(dtype: ts::DLDataType) -> Self {
-        Self::new(dtype.code, dtype.bits, dtype.lanes)
-    }
-}
-
-impl Deref for TVMType {
-    type Target = ts::TVMType;
-    fn deref(&self) -> &Self::Target {
-        &self.inner
-    }
-}
-
-impl DerefMut for TVMType {
-    fn deref_mut(&mut self) -> &mut Self::Target {
-        &mut self.inner
-    }
-}
index 9fad7de..eb62f10 100644 (file)
 //! and their conversions needed for the types used in frontend crate.
 //! `TVMRetValue` is the owned version of `TVMPODValue`.
 
-use std::{convert::TryFrom, mem, os::raw::c_void};
+use std::{convert::TryFrom, os::raw::c_void};
+
+use failure::Error;
+use tvm_common::{
+    ensure_type,
+    ffi::{self, TVMValue},
+};
 
 use crate::{
-    common_errors::*, ts, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext,
-    TVMDeviceType, TVMRetValue, TVMType, TVMTypeCode, TVMValue,
+    common_errors::*, context::TVMContext, Function, Module, NDArray, TVMArgValue, TVMByteArray,
+    TVMRetValue,
 };
 
 macro_rules! impl_tvm_val_from_handle {
-    ($($ty:ty),+) => {
-        $(
-            impl<'a> From<&'a $ty> for TVMValue {
-                fn from(arg: &$ty) -> Self {
-                    let inner = ts::TVMValue {
+    ($ty:ident, $type_code:expr, $handle:ty) => {
+        impl<'a> From<&'a $ty> for TVMArgValue<'a> {
+            fn from(arg: &$ty) -> Self {
+                TVMArgValue {
+                    value: TVMValue {
                         v_handle: arg.handle as *mut _ as *mut c_void,
-                    };
-                    Self::new(inner)
+                    },
+                    type_code: $type_code as i64,
+                    _lifetime: std::marker::PhantomData,
                 }
             }
-        )+
-    }
-}
-
-impl_tvm_val_from_handle!(Module, Function, NDArray);
-
-impl<'a> From<&'a TVMType> for TVMValue {
-    fn from(ty: &TVMType) -> Self {
-        let inner = ts::TVMValue { v_type: ty.inner };
-        Self::new(inner)
-    }
-}
-
-impl<'a> From<&'a TVMContext> for TVMValue {
-    fn from(ctx: &TVMContext) -> Self {
-        let inner = ts::TVMValue {
-            v_ctx: ctx.clone().into(),
-        };
-        Self::new(inner)
-    }
-}
-
-impl<'a> From<&'a TVMDeviceType> for TVMValue {
-    fn from(dev: &TVMDeviceType) -> Self {
-        let inner = ts::TVMValue {
-            v_int64: dev.0 as i64,
-        };
-        Self::new(inner)
-    }
-}
-
-impl<'a> From<&'a TVMByteArray> for TVMValue {
-    fn from(barr: &TVMByteArray) -> Self {
-        let inner = ts::TVMValue {
-            v_handle: &barr.inner as *const ts::TVMByteArray as *mut c_void,
-        };
-        Self::new(inner)
-    }
-}
+        }
 
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for NDArray {
-    type Error = Error;
-    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
-        if arg.type_code == TVMTypeCode::kArrayHandle {
-            let handle = unsafe { arg.value.inner.v_handle };
-            let arr_handle = unsafe { mem::transmute::<*mut c_void, ts::TVMArrayHandle>(handle) };
-            Ok(Self::new(arr_handle, true))
-        } else {
-            bail!(ErrorKind::TryFromTVMArgValueError(
-                stringify!(NDArray).to_string(),
-                arg.type_code.to_string()
-            ))
+        impl<'a> From<&'a mut $ty> for TVMArgValue<'a> {
+            fn from(arg: &mut $ty) -> Self {
+                TVMArgValue {
+                    value: TVMValue {
+                        v_handle: arg.handle as *mut _ as *mut c_void,
+                    },
+                    type_code: $type_code as i64,
+                    _lifetime: std::marker::PhantomData,
+                }
+            }
         }
-    }
-}
 
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for Module {
-    type Error = Error;
-    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
-        if arg.type_code == TVMTypeCode::kModuleHandle {
-            let handle = unsafe { arg.value.inner.v_handle };
-            Ok(Self::new(handle, false))
-        } else {
-            bail!(ErrorKind::TryFromTVMArgValueError(
-                stringify!(Module).to_string(),
-                arg.type_code.to_string()
-            ))
+        impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $ty {
+            type Error = Error;
+            fn try_from(arg: &TVMArgValue<'v>) -> Result<$ty, Self::Error> {
+                ensure_type!(arg, $type_code);
+                Ok($ty::new(unsafe { arg.value.v_handle as $handle }))
+            }
         }
-    }
-}
 
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMByteArray {
-    type Error = Error;
-    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
-        if arg.type_code == TVMTypeCode::kBytes {
-            unsafe {
-                let barr_ptr =
-                    mem::transmute::<*mut c_void, *mut ts::TVMByteArray>(arg.value.inner.v_handle);
-                Ok(Self::new(*barr_ptr))
+        impl From<$ty> for TVMRetValue {
+            fn from(val: $ty) -> TVMRetValue {
+                TVMRetValue {
+                    value: TVMValue {
+                        v_handle: val.handle() as *mut c_void,
+                    },
+                    box_value: box val,
+                    type_code: $type_code as i64,
+                }
             }
-        } else {
-            bail!(ErrorKind::TryFromTVMArgValueError(
-                stringify!(TVMByteArray).to_string(),
-                arg.type_code.to_string()
-            ))
         }
-    }
-}
 
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMType {
-    type Error = Error;
-    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
-        if arg.type_code == TVMTypeCode::kTVMType {
-            let ty = unsafe { arg.value.inner.v_type };
-            Ok(TVMType::from(ty))
-        } else {
-            bail!(ErrorKind::TryFromTVMArgValueError(
-                stringify!(TVMType).to_string(),
-                arg.type_code.to_string()
-            ))
+        impl TryFrom<TVMRetValue> for $ty {
+            type Error = Error;
+            fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> {
+                ensure_type!(ret, $type_code);
+                Ok($ty::new(unsafe { ret.value.v_handle as $handle }))
+            }
         }
-    }
+    };
 }
 
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMContext {
-    type Error = Error;
-    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
-        if arg.type_code == TVMTypeCode::kTVMContext {
-            let ty = unsafe { arg.value.inner.v_ctx };
-            Ok(TVMContext::from(ty))
-        } else {
-            bail!(ErrorKind::TryFromTVMArgValueError(
-                stringify!(TVMContext).to_string(),
-                arg.type_code.to_string()
-            ))
+impl_tvm_val_from_handle!(
+    Function,
+    ffi::TVMTypeCode_kFuncHandle,
+    ffi::TVMFunctionHandle
+);
+impl_tvm_val_from_handle!(Module, ffi::TVMTypeCode_kModuleHandle, ffi::TVMModuleHandle);
+impl_tvm_val_from_handle!(NDArray, ffi::TVMTypeCode_kArrayHandle, ffi::TVMArrayHandle);
+
+impl<'a> From<&'a TVMByteArray> for TVMValue {
+    fn from(barr: &TVMByteArray) -> Self {
+        TVMValue {
+            v_handle: &barr.inner as *const ffi::TVMByteArray as *mut c_void,
         }
     }
 }
@@ -144,78 +92,43 @@ macro_rules! impl_boxed_ret_value {
         impl From<$type> for TVMRetValue {
             fn from(val: $type) -> Self {
                 TVMRetValue {
-                    prim_value: 0,
+                    value: TVMValue { v_int64: 0 },
                     box_value: box val,
-                    type_code: $code,
+                    type_code: $code as i64,
                 }
             }
         }
         impl TryFrom<TVMRetValue> for $type {
             type Error = Error;
-            fn try_from(ret: TVMRetValue) -> Result<$type> {
+            fn try_from(ret: TVMRetValue) -> Result<$type, Self::Error> {
                 if let Ok(val) = ret.box_value.downcast::<$type>() {
                     Ok(*val)
                 } else {
-                    bail!(ErrorKind::TryFromTVMRetValueError(
-                        stringify!($type).to_string(),
-                        ret.type_code.to_string()
-                    ))
+                    bail!(ValueDowncastError::new($code as i64, ret.type_code as i64))
                 }
             }
         }
     };
 }
 
-impl_boxed_ret_value!(TVMType, TVMTypeCode::kTVMType);
-impl_boxed_ret_value!(TVMContext, TVMTypeCode::kTVMContext);
-impl_boxed_ret_value!(TVMByteArray, TVMTypeCode::kBytes);
-
-impl TryFrom<TVMRetValue> for Module {
-    type Error = Error;
-    fn try_from(ret: TVMRetValue) -> Result<Module> {
-        if let Ok(handle) = ret.box_value.downcast::<ts::TVMModuleHandle>() {
-            Ok(Module::new(*handle, false))
-        } else {
-            bail!(ErrorKind::TryFromTVMRetValueError(
-                stringify!(TVMTypeCode::kModuleHandle).to_string(),
-                ret.type_code.to_string()
-            ))
-        }
-    }
-}
-
-impl TryFrom<TVMRetValue> for Function {
-    type Error = Error;
-    fn try_from(ret: TVMRetValue) -> Result<Function> {
-        if let Ok(handle) = ret.box_value.downcast::<ts::TVMFunctionHandle>() {
-            Ok(Function::new(*handle, false, false))
-        } else {
-            bail!(ErrorKind::TryFromTVMRetValueError(
-                stringify!(TVMTypeCode::kFuncHandle).to_string(),
-                ret.type_code.to_string()
-            ))
-        }
-    }
-}
+impl_boxed_ret_value!(TVMContext, ffi::TVMTypeCode_kTVMContext);
+impl_boxed_ret_value!(TVMByteArray, ffi::TVMTypeCode_kBytes);
 
-impl TryFrom<TVMRetValue> for NDArray {
+impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMByteArray {
     type Error = Error;
-    fn try_from(ret: TVMRetValue) -> Result<NDArray> {
-        if let Ok(handle) = ret.box_value.downcast::<ts::TVMArrayHandle>() {
-            Ok(NDArray::new(*handle, false))
-        } else {
-            bail!(ErrorKind::TryFromTVMRetValueError(
-                stringify!(TVMTypeCode::kArrayHandle).to_string(),
-                ret.type_code.to_string()
-            ))
-        }
+    fn try_from(arg: &TVMArgValue<'v>) -> Result<Self, Self::Error> {
+        ensure_type!(arg, ffi::TVMTypeCode_kBytes);
+        Ok(TVMByteArray::new(unsafe {
+            *(arg.value.v_handle as *mut ffi::TVMByteArray)
+        }))
     }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use std::convert::TryInto;
+    use std::{convert::TryInto, str::FromStr};
+    use tvm_common::ffi::TVMType;
 
     #[test]
     fn bytearray() {
@@ -227,7 +140,7 @@ mod tests {
 
     #[test]
     fn ty() {
-        let t = TVMType::from("int32");
+        let t = TVMType::from_str("int32").unwrap();
         let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap();
         assert_eq!(tvm, t);
     }
index 69b948e..55c537b 100644 (file)
@@ -1,6 +1,8 @@
 extern crate ndarray as rust_ndarray;
 extern crate tvm_frontend as tvm;
 
+use std::str::FromStr;
+
 use tvm::*;
 
 fn main() {
@@ -12,7 +14,7 @@ fn main() {
     } else {
         (TVMContext::gpu(0), "gpu")
     };
-    let dtype = TVMType::from("float32");
+    let dtype = TVMType::from_str("float32").unwrap();
     let mut arr = NDArray::empty(shape, ctx, dtype);
     arr.copy_from_buffer(data.as_mut_slice());
     let mut ret = NDArray::empty(shape, ctx, dtype);
@@ -26,8 +28,7 @@ fn main() {
     function::Builder::from(&mut fadd)
         .arg(&arr)
         .arg(&arr)
-        .set_output(&mut ret)
-        .unwrap()
+        .arg(&mut ret)
         .invoke()
         .unwrap();
 
index 81dcadc..e77ea43 100644 (file)
@@ -1,4 +1,3 @@
-#![feature(extern_crate_item_prelude, try_from)]
 #![allow(unused_imports)]
 
 extern crate ndarray as rust_ndarray;
@@ -6,17 +5,23 @@ extern crate ndarray as rust_ndarray;
 extern crate tvm_frontend as tvm;
 
 use rust_ndarray::ArrayD;
-use std::convert::{TryFrom, TryInto};
+use std::{
+    convert::{TryFrom, TryInto},
+    str::FromStr,
+};
 
-use tvm::*;
+use tvm::{errors::Error, *};
 
 fn main() {
     register_global_func! {
-        fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+        fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
             let mut ret = 0f32;
             let shape = &mut [2];
             for arg in args.iter() {
-                let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+                let e = NDArray::empty(
+                    shape, TVMContext::cpu(0),
+                    TVMType::from_str("float32").unwrap()
+                );
                 let arg: NDArray = arg.try_into()?;
                 let arr = arg.copy_to_ndarray(e)?;
                 let rnd: ArrayD<f32> = ArrayD::try_from(&arr)?;
@@ -28,12 +33,16 @@ fn main() {
 
     let shape = &mut [2];
     let mut data = vec![3f32, 4.0];
-    let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+    let mut arr = NDArray::empty(
+        shape,
+        TVMContext::cpu(0),
+        TVMType::from_str("float32").unwrap(),
+    );
     arr.copy_from_buffer(data.as_mut_slice());
 
     let mut registered = function::Builder::default();
     let ret: f32 = registered
-        .get_function("sum", true)
+        .get_function("sum")
         .arg(&arr)
         .arg(&arr)
         .invoke()
index f40f0f1..24a1f07 100644 (file)
@@ -1,4 +1,4 @@
-#![feature(extern_crate_item_prelude, panic_info_message)]
+#![feature(panic_info_message)]
 #![allow(unused_imports)]
 
 use std::panic;
@@ -6,20 +6,20 @@ use std::panic;
 #[macro_use]
 extern crate tvm_frontend as tvm;
 
-use tvm::*;
+use tvm::{errors::Error, *};
 
 fn main() {
     register_global_func! {
-        fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue> {
-            Err(ErrorKind::TypeMismatch(
-                format!("{}", "i64".to_string()),
-                format!("{}", "f64".to_string()),
-            ).into())
+        fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
+            Err(errors::TypeMismatchError{
+                expected: "i64".to_string(),
+                actual: "f64".to_string(),
+            }.into())
         }
     }
 
     let mut registered = function::Builder::default();
-    registered.get_function("error", true);
+    registered.get_function("error");
     assert!(registered.func.is_some());
     registered.args(&[10, 20]);
 
index 3070552..a26487b 100644 (file)
@@ -1,26 +1,25 @@
-#![feature(extern_crate_item_prelude, try_from)]
 #![allow(unused_imports)]
 
 #[macro_use]
 extern crate tvm_frontend as tvm;
 
 use std::convert::TryInto;
-use tvm::*;
+use tvm::{errors::Error, *};
 
 fn main() {
     register_global_func! {
-        fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+        fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
             let mut ret = 0.0;
-            for arg in args.iter() {
+            for arg in args.into_iter() {
                 let val: f64 = arg.try_into()?;
                 ret += val;
             }
-            Ok(TVMRetValue::from(&ret))
+            Ok(TVMRetValue::from(ret))
         }
     }
 
     let mut registered = function::Builder::default();
-    registered.get_function("sum", true);
+    registered.get_function("sum");
     assert!(registered.func.is_some());
     let ret: f64 = registered
         .args(&[10.0f64, 20.0, 30.0])
index 3018822..591f95a 100644 (file)
@@ -1,25 +1,24 @@
-#![feature(extern_crate_item_prelude, try_from)]
 #![allow(unused_imports)]
 
 extern crate tvm_frontend as tvm;
 
 use std::convert::TryInto;
-use tvm::*;
+use tvm::{errors::Error, *};
 
 fn main() {
-    fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+    fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
         let mut ret = 0i64;
         for arg in args.iter() {
             let val: i64 = arg.try_into()?;
             ret += val;
         }
-        Ok(TVMRetValue::from(&ret))
+        Ok(TVMRetValue::from(ret))
     }
 
     tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
 
     let mut registered = function::Builder::default();
-    registered.get_function("mysum", true);
+    registered.get_function("mysum");
     assert!(registered.func.is_some());
     let ret: i64 = registered
         .args(&[10, 20, 30])
index eafee31..3b2ad65 100644 (file)
@@ -1,31 +1,32 @@
-#![feature(extern_crate_item_prelude, try_from)]
 #![allow(unused_imports)]
 
 #[macro_use]
 extern crate tvm_frontend as tvm;
 use std::convert::TryInto;
-use tvm::*;
+use tvm::{errors::Error, *};
 
 // FIXME
 fn main() {
     register_global_func! {
-        fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+        fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
             let mut ret = "".to_string();
             for arg in args.iter() {
-                let val: String = arg.try_into()?;
-                ret += val.as_str();
+                let val: &str = arg.try_into()?;
+                ret += val;
             }
             Ok(TVMRetValue::from(ret))
         }
     }
+    let a = std::ffi::CString::new("a").unwrap();
+    let b = std::ffi::CString::new("b").unwrap();
+    let c = std::ffi::CString::new("c").unwrap();
     let mut registered = function::Builder::default();
-    registered.get_function("concate_str", true);
+    registered.get_function("concate_str");
     assert!(registered.func.is_some());
-    let a = "a".to_string();
-    let b = "b".to_string();
-    let c = "c".to_string();
     let ret: String = registered
-        .args(&[a, b, c])
+        .arg(&a)
+        .arg(&b)
+        .arg(&c)
         .invoke()
         .unwrap()
         .try_into()
diff --git a/rust/runtime/.gitignore b/rust/runtime/.gitignore
deleted file mode 100644 (file)
index 230ab66..0000000
+++ /dev/null
@@ -1,3 +0,0 @@
-Cargo.lock
-target/
-**/*.rs.bk
index d48c0d9..ae73ae7 100644 (file)
@@ -15,15 +15,15 @@ sgx = ["nom/alloc"]
 
 [dependencies]
 bounded-spsc-queue = "0.4.0"
-error-chain = { version = "0.12.0", default-features = false }
+failure = "0.1.5"
 itertools = "0.7.8"
 lazy_static = "1.1.0"
-ndarray = "0.11.2"
+ndarray="0.12.1"
 nom = {version = "4.0.0", default-features = false }
 serde = "1.0.59"
 serde_derive = "1.0.79"
 serde_json = "1.0.17"
-tvm-common = { version = "0.1.0", path = "../common/", features = ["runtime"] }
+tvm-common = { version = "0.1.0", path = "../common/" }
 
 [target.'cfg(not(target_env = "sgx"))'.dependencies]
 num_cpus = "1.8.0"
index 5f77037..0514dce 100644 (file)
@@ -1,9 +1,7 @@
 #[cfg(target_env = "sgx")]
-use alloc::alloc::{self, Layout};
+use alloc::alloc::{self, Layout, LayoutErr};
 #[cfg(not(target_env = "sgx"))]
-use std::alloc::{self, Layout};
-
-use crate::errors::*;
+use std::alloc::{self, Layout, LayoutErr};
 
 const DEFAULT_ALIGN_BYTES: usize = 4;
 
@@ -15,7 +13,7 @@ pub struct Allocation {
 
 impl Allocation {
     /// Allocates a chunk of memory of `size` bytes with optional alignment.
-    pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
+    pub fn new(size: usize, align: Option<usize>) -> Result<Self, LayoutErr> {
         let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
         let layout = Layout::from_size_align(size, alignment)?;
         let ptr = unsafe { alloc::alloc(layout.clone()) };
index 5c49515..3bb02f1 100644 (file)
@@ -1,23 +1,17 @@
-use std::{
-    any::TypeId,
-    convert::TryFrom,
-    mem,
-    ops::{Deref, DerefMut},
-    os::raw::{c_int, c_void},
-    ptr, slice,
-};
+use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
 
+use failure::Error;
 use ndarray;
-
-use crate::{
-    allocator::Allocation,
-    errors::*,
-    ffi::runtime::{
+use tvm_common::{
+    array::{DataType, TVMContext},
+    ffi::{
         DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt,
-        DLDataTypeCode_kDLUInt, DLDeviceType_kDLCPU, DLTensor as _DLTensor,
+        DLDataTypeCode_kDLUInt, DLTensor,
     },
 };
 
+use crate::allocator::Allocation;
+
 /// A `Storage` is a container which holds `Tensor` data.
 #[derive(PartialEq)]
 pub enum Storage<'a> {
@@ -29,7 +23,7 @@ pub enum Storage<'a> {
 }
 
 impl<'a> Storage<'a> {
-    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
+    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
         Ok(Storage::Owned(Allocation::new(size, align)?))
     }
 
@@ -237,6 +231,27 @@ impl<'a> Tensor<'a> {
             byte_offset: 0,
         }
     }
+
+    pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor {
+        assert!(!flatten || self.is_contiguous());
+        DLTensor {
+            data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void,
+            ctx: DLContext::from(&self.ctx),
+            ndim: if flatten { 1 } else { self.shape.len() } as i32,
+            dtype: DLDataType::from(&self.dtype),
+            shape: if flatten {
+                &self.size as *const _ as *mut i64
+            } else {
+                self.shape.as_ptr()
+            } as *mut i64,
+            strides: if flatten || self.is_contiguous() {
+                ptr::null_mut()
+            } else {
+                self.strides.as_ref().unwrap().as_ptr()
+            } as *mut i64,
+            byte_offset: 0,
+        }
+    }
 }
 
 /// Conversions to `ndarray::Array` from `Tensor`, if the types match.
@@ -244,7 +259,7 @@ macro_rules! impl_ndarray_try_from_tensor {
     ($type:ty, $dtype:expr) => {
         impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
             type Error = Error;
-            fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
+            fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
                 ensure!(
                     tensor.dtype == $dtype,
                     "Cannot convert Tensor with dtype {:?} to ndarray",
@@ -263,120 +278,9 @@ macro_rules! impl_ndarray_try_from_tensor {
     };
 }
 
-impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
-impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
-impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
-impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
-
-pub struct DLTensor {
-    pub(crate) inner: _DLTensor,
-}
-
-impl Deref for DLTensor {
-    type Target = _DLTensor;
-    fn deref(&self) -> &Self::Target {
-        &self.inner
-    }
-}
-
-impl DerefMut for DLTensor {
-    fn deref_mut(&mut self) -> &mut Self::Target {
-        &mut self.inner
-    }
-}
-
-impl DLTensor {
-    pub(crate) fn new(raw: _DLTensor) -> Self {
-        Self { inner: raw }
-    }
-
-    pub(crate) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
-        assert!(!flatten || tensor.is_contiguous());
-        Self {
-            inner: _DLTensor {
-                data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
-                ctx: DLContext::from(&tensor.ctx),
-                ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
-                dtype: DLDataType::from(&tensor.dtype),
-                shape: if flatten {
-                    &tensor.size as *const _ as *mut i64
-                } else {
-                    tensor.shape.as_ptr()
-                } as *mut i64,
-                strides: if flatten || tensor.is_contiguous() {
-                    ptr::null_mut()
-                } else {
-                    tensor.strides.as_ref().unwrap().as_ptr()
-                } as *mut i64,
-                byte_offset: 0,
-            },
-        }
-    }
-}
-
-impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
-    fn from(tensor: &'a Tensor<'t>) -> Self {
-        DLTensor::from_tensor(tensor, false /* flatten */)
-    }
-}
-
-impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
-    fn from(tensor: &'a mut Tensor<'t>) -> Self {
-        DLTensor::from_tensor(tensor, false /* flatten */)
-    }
-}
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-pub struct DataType {
-    pub(crate) code: usize,
-    pub(crate) bits: usize,
-    pub(crate) lanes: usize,
-}
-
-impl DataType {
-    /// Returns the number of bytes occupied by an element of this `DataType`.
-    pub fn itemsize(&self) -> usize {
-        (self.bits * self.lanes) >> 3
-    }
-
-    /// Returns whether this `DataType` represents primitive type `T`.
-    pub fn is_type<T: 'static>(&self) -> bool {
-        if self.lanes != 1 {
-            return false;
-        }
-        let typ = TypeId::of::<T>();
-        (typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
-            || (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
-            || (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
-            || (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
-            || (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
-            || (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
-    }
-}
-
-impl<'a> From<&'a DataType> for DLDataType {
-    fn from(dtype: &'a DataType) -> Self {
-        Self {
-            code: dtype.code as u8,
-            bits: dtype.bits as u8,
-            lanes: dtype.lanes as u16,
-        }
-    }
-}
-
-impl From<DLDataType> for DataType {
-    fn from(dtype: DLDataType) -> Self {
-        Self {
-            code: dtype.code as usize,
-            bits: dtype.bits as usize,
-            lanes: dtype.lanes as usize,
-        }
-    }
-}
-
 macro_rules! make_dtype_const {
     ($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
-        const $name: DataType = DataType {
+        pub const $name: DataType = DataType {
             code: $code as usize,
             bits: $bits,
             lanes: $lanes,
@@ -389,28 +293,20 @@ make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
 // make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
 make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
 make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
+impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
+impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
+impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
+impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
 
-#[derive(Debug, Clone, Copy, PartialEq)]
-pub struct TVMContext {
-    pub(crate) device_type: usize,
-    pub(crate) device_id: usize,
-}
-
-impl<'a> From<&'a TVMContext> for DLContext {
-    fn from(ctx: &'a TVMContext) -> Self {
-        Self {
-            device_type: ctx.device_type as u32,
-            device_id: ctx.device_id as i32,
-        }
+impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
+    fn from(tensor: &'a Tensor<'t>) -> Self {
+        Tensor::as_dltensor(tensor, false /* flatten */)
     }
 }
 
-impl Default for TVMContext {
-    fn default() -> Self {
-        Self {
-            device_type: DLDeviceType_kDLCPU as usize,
-            device_id: 0,
-        }
+impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
+    fn from(tensor: &'a mut Tensor<'t>) -> Self {
+        Tensor::as_dltensor(tensor, false /* flatten */)
     }
 }
 
@@ -463,42 +359,6 @@ macro_rules! impl_tensor_from_ndarray {
     };
 }
 
-/// `From` conversions to `DLTensor` for `ndarray::Array`.
-/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
-macro_rules! impl_dltensor_from_ndarray {
-    ($type:ty, $typecode:expr) => {
-        impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
-            fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
-                DLTensor {
-                    inner: _DLTensor {
-                        data: arr.as_mut_ptr() as *mut c_void,
-                        ctx: DLContext {
-                            device_type: DLDeviceType_kDLCPU,
-                            device_id: 0,
-                        },
-                        ndim: arr.ndim() as c_int,
-                        dtype: DLDataType {
-                            code: $typecode as u8,
-                            bits: 8 * mem::size_of::<$type>() as u8,
-                            lanes: 1,
-                        },
-                        shape: arr.shape().as_ptr() as *const i64 as *mut i64,
-                        strides: arr.strides().as_ptr() as *const isize as *mut i64,
-                        byte_offset: 0,
-                    },
-                }
-            }
-        }
-    };
-}
-
-impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
-impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
-impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
-impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
-impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
-impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
-
 impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
 impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
 impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
index cf77230..26a8961 100644 (file)
@@ -1,36 +1,19 @@
-#[cfg(target_env = "sgx")]
-use alloc::alloc;
-#[cfg(not(target_env = "sgx"))]
-use std::alloc;
-use std::num;
-
-use crate::common::errors as common_errors;
-use ndarray;
-use serde_json;
-
-error_chain! {
-  errors {
-    GraphFormatError(msg: String) {
-      description("unable to load graph")
-      display("could not load graph json: {}", msg)
-    }
-
-    LoadGraphParamsError(msg: String) {
-      description("unable to load graph params")
-      display("could not load graph params: {}", msg)
-    }
-  }
-  foreign_links {
-    Alloc(alloc::AllocErr);
-    GraphDeserialize(serde_json::Error);
-    ParseInt(num::ParseIntError);
-    ShapeError(ndarray::ShapeError);
-    CommonError(common_errors::Error);
-  }
+#[derive(Debug, Fail)]
+pub enum GraphFormatError {
+    #[fail(display = "Could not parse graph json")]
+    Parse(#[fail(cause)] failure::Error),
+    #[fail(display = "Could not parse graph params")]
+    Params,
+    #[fail(display = "{} is missing attr: {}", 0, 1)]
+    MissingAttr(String, String),
+    #[fail(display = "Missing field: {}", 0)]
+    MissingField(&'static str),
+    #[fail(display = "Invalid DLType: {}", 0)]
+    InvalidDLType(String),
 }
 
-impl From<alloc::LayoutErr> for Error {
-    fn from(_err: alloc::LayoutErr) -> Error {
-        Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
-    }
+#[derive(Debug, Fail)]
+#[fail(display = "SGX error: 0x{:x}", code)]
+pub struct SgxError {
+    pub code: u32,
 }
index 0d5e281..6e00d9c 100644 (file)
@@ -1,16 +1,17 @@
 use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
 
+use failure::Error;
 use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
 use serde;
 use serde_json;
-
-use super::{DLTensor, DataType, Module, Storage, TVMContext, Tensor};
-use crate::{
-    common::value::TVMArgValue,
-    errors::{Error, ErrorKind, Result},
-    ffi::runtime::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt},
+use tvm_common::{
+    array::{DataType, TVMContext},
+    ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor},
+    TVMArgValue,
 };
 
+use crate::{errors::GraphFormatError, Module, Storage, Tensor};
+
 // @see `kTVMNDArrayMagic` in `ndarray.h`
 const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
 // @see `kTVMNDArrayListMagic` in `graph_runtime.h`
@@ -41,28 +42,26 @@ pub struct Entry {
 }
 
 impl Graph {
-    fn entry_index(&self, entry: &Entry) -> Result<usize> {
+    fn entry_index(&self, entry: &Entry) -> Result<usize, GraphFormatError> {
         self.node_row_ptr
             .as_ref()
             .map(|nrp| nrp[entry.id] + entry.index)
-            .ok_or("Missing node_row_ptr.".into())
+            .ok_or_else(|| GraphFormatError::MissingField("node_row_ptr"))
     }
 
     /// Attempt to deserialize a JSON attribute to a type `T`.
-    fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
+    fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T, GraphFormatError> {
         Ok(serde_json::from_value::<T>(
             self.attrs
                 .as_ref()
-                .ok_or(ErrorKind::GraphFormatError(
-                    "Missing graph attrs".to_string(),
-                ))?
+                .ok_or(GraphFormatError::MissingField("attrs"))?
                 .get(attr)
-                .ok_or(ErrorKind::GraphFormatError(format!(
-                    "Missing {} attr",
-                    attr
-                )))?
+                .ok_or_else(|| {
+                    GraphFormatError::MissingAttr("graph".to_string(), attr.to_string())
+                })?
                 .to_owned(),
-        )?)
+        )
+        .map_err(|err| GraphFormatError::Parse(err.into()))?)
     }
 }
 
@@ -81,39 +80,31 @@ struct NodeAttrs {
     flatten_data: bool,
 }
 
+macro_rules! get_node_attr {
+    ($node:expr, $attrs:ident, $attr:literal) => {
+        $attrs
+            .get($attr)
+            .ok_or_else(|| GraphFormatError::MissingAttr($node.to_owned(), $attr.to_owned()))
+    };
+}
+
 impl Node {
-    fn parse_attrs(&self) -> Result<NodeAttrs> {
+    fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
         let attrs = self
             .attrs
             .as_ref()
-            .ok_or(format!("Missing node.attrs for `{}`", self.name))?;
-        let func_name = attrs
-            .get("func_name")
-            .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
-            .to_string();
-        let num_outputs = attrs
-            .get("num_outputs")
-            .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
-            .parse::<usize>()?;
-        let flatten_data = attrs
-            .get("flatten_data")
-            .ok_or(format!(
-                "Node `{}` is missing attrs.flatten_data",
-                self.name
-            ))?
-            .parse::<u8>()?
-            == 1;
+            .ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?;
         Ok(NodeAttrs {
-            func_name,
-            num_outputs,
-            flatten_data,
+            func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(),
+            num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::<usize>()?,
+            flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::<u8>()? == 1,
         })
     }
 }
 
 impl<'a> TryFrom<&'a String> for Graph {
     type Error = Error;
-    fn try_from(graph_json: &String) -> Result<Self> {
+    fn try_from(graph_json: &String) -> Result<Self, self::Error> {
         let graph = serde_json::from_str(graph_json)?;
         Ok(graph)
     }
@@ -121,7 +112,7 @@ impl<'a> TryFrom<&'a String> for Graph {
 
 impl<'a> TryFrom<&'a str> for Graph {
     type Error = Error;
-    fn try_from(graph_json: &'a str) -> Result<Self> {
+    fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
         let graph = serde_json::from_str(graph_json)?;
         Ok(graph)
     }
@@ -161,7 +152,7 @@ pub struct GraphExecutor<'m, 't> {
 unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
 
 impl<'m, 't> GraphExecutor<'m, 't> {
-    pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
+    pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error> {
         let tensors = Self::setup_storages(&graph)?;
         Ok(GraphExecutor {
             op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
@@ -178,7 +169,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
     }
 
     /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
-    fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>> {
+    fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error> {
         let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
         let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
         let dtypes = graph
@@ -189,18 +180,15 @@ impl<'m, 't> GraphExecutor<'m, 't> {
                 if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
                     Ok(dtype)
                 } else {
-                    Err(ErrorKind::GraphFormatError(
-                        format!("Invalid dltype: {}", dltype).to_string(),
-                    )
-                    .into())
+                    Err(GraphFormatError::InvalidDLType(dltype.to_string()))
                 }
             })
-            .collect::<Result<Vec<DataType>>>()?;
+            .collect::<Result<Vec<DataType>, GraphFormatError>>()?;
 
-        let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
+        let align = dtypes.iter().map(|dtype| dtype.bits() as usize).max();
         let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
         for (i, &storage_id) in storage_ids.iter().enumerate() {
-            let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
+            let dtype_size = dtypes[i].bits() * dtypes[i].lanes() >> 3;
             let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
             storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
         }
@@ -208,7 +196,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
         let mut storages: Vec<Storage> = storage_num_bytes
             .into_iter()
             .map(|nbytes| Storage::new(nbytes, align))
-            .collect::<Result<Vec<Storage>>>()?;
+            .collect::<Result<Vec<Storage>, Error>>()?;
 
         let tensors = izip!(storage_ids, shapes, dtypes)
             .map(|(storage_id, shape, dtype)| {
@@ -233,7 +221,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
         graph: &Graph,
         lib: &'m M,
         tensors: &Vec<Tensor<'t>>,
-    ) -> Result<Vec<Box<Fn() + 'm>>> {
+    ) -> Result<Vec<Box<Fn() + 'm>>, Error> {
         ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
         let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
 
@@ -251,9 +239,10 @@ impl<'m, 't> GraphExecutor<'m, 't> {
                 continue;
             }
 
-            let func = lib
-                .get_function(&attrs.func_name)
-                .ok_or(format!("Missing function {}", attrs.func_name))?;
+            let func = lib.get_function(&attrs.func_name).ok_or(format_err!(
+                "Library is missing function {}",
+                attrs.func_name
+            ))?;
             let arg_indices = node
                 .inputs
                 .iter()
@@ -264,19 +253,19 @@ impl<'m, 't> GraphExecutor<'m, 't> {
                 .map(|idx| {
                     let tensor = &tensors[idx?];
                     Ok(if attrs.flatten_data {
-                        DLTensor::from_tensor(tensor, true /* flatten */)
+                        Tensor::as_dltensor(tensor, true /* flatten */)
                     } else {
                         DLTensor::from(tensor)
                     })
                 })
-                .collect::<Result<Vec<DLTensor>>>()
+                .collect::<Result<Vec<DLTensor>, Error>>()
                 .unwrap();
             let op: Box<Fn()> = box move || {
                 let args = dl_tensors
                     .iter()
                     .map(|t| t.into())
                     .collect::<Vec<TVMArgValue>>();
-                func(args.as_slice());
+                func(args.as_slice()).unwrap();
             };
             op_execs.push(op);
         }
@@ -344,7 +333,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
     }
 }
 
-/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
+// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
 named!(
   tvm_str_to_type<CompleteStr, DataType>,
   do_parse!(
@@ -367,7 +356,7 @@ named!(
   )
 );
 
-/// Converts a bytes to String.
+// Converts a bytes to String.
 named!(
     name<String>,
     map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
@@ -375,7 +364,7 @@ named!(
     ))
 );
 
-/// Parses a TVMContext
+// Parses a TVMContext
 named!(
   tvm_ctx<&[u8], TVMContext>,
   do_parse!(
@@ -385,7 +374,7 @@ named!(
   )
 );
 
-/// Parses a DataType
+// Parses a DataType
 named!(
   data_type<&[u8], DataType>,
   do_parse!(
@@ -396,7 +385,7 @@ named!(
   )
 );
 
-/// Parses a Tensor from a TVM array file.
+// Parses a Tensor from a TVM array file.
 named!(
     tensor<Tensor>,
     do_parse!(
@@ -420,7 +409,7 @@ named!(
     )
 );
 
-/// Parses a graph params dict from a params binary file.
+// Parses a graph params dict from a params binary file.
 named!(
     parse_param_dict<HashMap<String, Tensor>>,
     do_parse!(
@@ -433,17 +422,15 @@ named!(
 );
 
 /// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
-pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> {
+pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>, GraphFormatError> {
     if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
-        if remaining_bytes.len() > 0 {
-            bail!(ErrorKind::LoadGraphParamsError("extra input".to_string()))
-        } else {
+        if remaining_bytes.len() == 0 {
             Ok(param_dict)
+        } else {
+            Err(GraphFormatError::Params)
         }
     } else {
-        bail!(ErrorKind::LoadGraphParamsError(
-            "invalid parameters file".to_string()
-        ))
+        Err(GraphFormatError::Params)
     }
 }
 
index da030bc..848db27 100644 (file)
@@ -14,7 +14,6 @@
     allocator_api,
     box_syntax,
     fn_traits,
-    try_from,
     unboxed_closures,
     vec_remove_item
 )]
@@ -25,7 +24,7 @@ extern crate bounded_spsc_queue;
 #[cfg(target_env = "sgx")]
 extern crate core;
 #[macro_use]
-extern crate error_chain;
+extern crate failure;
 #[macro_use]
 extern crate itertools;
 #[macro_use]
@@ -39,36 +38,45 @@ extern crate serde;
 #[macro_use]
 extern crate serde_derive;
 extern crate serde_json;
-extern crate tvm_common as common;
+extern crate tvm_common;
 
 mod allocator;
 mod array;
 pub mod errors;
-mod module;
-#[macro_use]
-mod packed_func;
 mod graph;
+mod module;
 #[cfg(target_env = "sgx")]
 #[macro_use]
 pub mod sgx;
 mod threading;
 mod workspace;
 
-pub use crate::common::{errors::*, ffi, TVMArgValue, TVMRetValue};
-
-pub use self::{
-    array::*, errors::*, graph::*, module::*, packed_func::*, threading::*, workspace::*,
+pub use tvm_common::{
+    call_packed,
+    errors::*,
+    ffi::{self, DLTensor},
+    packed_func::{self, *},
+    TVMArgValue, TVMRetValue,
 };
 
-#[cfg(target_env = "sgx")]
-use self::sgx::ocall_packed_func;
+pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*};
+
+lazy_static! {
+    static ref LAST_ERROR: std::sync::RwLock<Option<&'static std::ffi::CStr>> =
+        std::sync::RwLock::new(None);
+}
 
 #[no_mangle]
 pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
-    #[cfg(not(target_env = "sgx"))]
-    unsafe {
-        panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
-    }
+    *LAST_ERROR.write().unwrap() = Some(unsafe { std::ffi::CStr::from_ptr(cmsg) });
     #[cfg(target_env = "sgx")]
     ocall_packed!("__sgx_set_last_error__", cmsg);
 }
+
+#[no_mangle]
+pub extern "C" fn TVMGetLastError() -> *const std::os::raw::c_char {
+    match *LAST_ERROR.read().unwrap() {
+        Some(err) => err.as_ptr(),
+        None => std::ptr::null(),
+    }
+}
index 8e6f7d6..636c4e8 100644 (file)
@@ -2,29 +2,29 @@ use std::{
     collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
 };
 
-use crate::{
-    ffi::runtime::BackendPackedCFunc,
-    packed_func::{wrap_backend_packed_func, PackedFunc},
+use tvm_common::{
+    ffi::BackendPackedCFunc,
+    packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
 };
 
 pub trait Module {
-    fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
+    fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
 }
 
 pub struct SystemLibModule;
 
 lazy_static! {
-    static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
+    static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, &'static (dyn PackedFunc)>> =
         Mutex::new(HashMap::new());
 }
 
 impl Module for SystemLibModule {
-    fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
+    fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
         SYSTEM_LIB_FUNCTIONS
             .lock()
             .unwrap()
             .get(name.as_ref())
-            .map(|func| wrap_backend_packed_func(func.to_owned()))
+            .map(|f| *f)
     }
 }
 
@@ -34,15 +34,42 @@ impl Default for SystemLibModule {
     }
 }
 
+// @see `WrapPackedFunc` in `llvm_module.cc`.
+pub(super) fn wrap_backend_packed_func(
+    func_name: String,
+    func: BackendPackedCFunc,
+) -> Box<dyn PackedFunc> {
+    box move |args: &[TVMArgValue]| {
+        let exit_code = func(
+            args.iter()
+                .map(|ref arg| arg.value)
+                .collect::<Vec<TVMValue>>()
+                .as_ptr(),
+            args.iter()
+                .map(|ref arg| arg.type_code as i32)
+                .collect::<Vec<i32>>()
+                .as_ptr() as *const i32,
+            args.len() as i32,
+        );
+        if exit_code == 0 {
+            Ok(TVMRetValue::default())
+        } else {
+            Err(tvm_common::errors::FuncCallError::get_with_context(
+                func_name.clone(),
+            ))
+        }
+    }
+}
+
 #[no_mangle]
 pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
     cname: *const c_char,
     func: BackendPackedCFunc,
 ) -> i32 {
     let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
-    SYSTEM_LIB_FUNCTIONS
-        .lock()
-        .unwrap()
-        .insert(name.to_string(), func);
+    SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
+        name.to_string(),
+        &*Box::leak(wrap_backend_packed_func(name.to_string(), func)),
+    );
     return 0;
 }
diff --git a/rust/runtime/src/packed_func.rs b/rust/runtime/src/packed_func.rs
deleted file mode 100644 (file)
index 2fe0086..0000000
+++ /dev/null
@@ -1,118 +0,0 @@
-use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void};
-
-use super::Tensor;
-use crate::ffi::runtime::{
-    BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle,
-    TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue,
-};
-
-use super::DLTensor;
-use crate::{
-    common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue},
-    errors::*,
-};
-
-pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
-
-/// Calls a packed function and returns a `TVMRetValue`.
-///
-/// # Example
-///
-/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
-#[macro_export]
-macro_rules! call_packed {
-  ($fn:expr, $($args:expr),+) => {
-    $fn(&[$($args.into(),)+])
-  };
-  ($fn:expr) => {
-    $fn(&Vec::new())
-  };
-}
-
-impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
-    fn from(arr: &'a DLTensor) -> Self {
-        let raw = _TVMValue {
-            v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
-        };
-        TVMArgValue {
-            value: TVMValue::new(raw),
-            type_code: TVMTypeCode::kArrayHandle,
-            lifetime: PhantomData,
-        }
-    }
-}
-
-impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
-    fn from(arr: &'a mut DLTensor) -> Self {
-        let raw = _TVMValue {
-            v_handle: arr as *mut _ as *mut c_void,
-        };
-        TVMArgValue {
-            value: TVMValue::new(raw),
-            type_code: TVMTypeCode::kArrayHandle,
-            lifetime: PhantomData,
-        }
-    }
-}
-
-impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
-    type Error = Error;
-    fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
-        ensure!(
-            val.type_code == TVMTypeCode::kArrayHandle
-                || val.type_code == TVMTypeCode::kNDArrayContainer,
-            "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
-            TVMTypeCode::kArrayHandle,
-            TVMTypeCode::kNDArrayContainer,
-            val.type_code,
-        );
-
-        let dlt = unsafe { *(val.value.v_handle as *mut _DLTensor as *const _DLTensor) };
-        Ok(DLTensor::new(dlt).into())
-    }
-}
-
-impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
-    fn from(val: &'t Tensor<'a>) -> Self {
-        TVMRetValue {
-            prim_value: 0,
-            box_value: box DLTensor::from(val),
-            type_code: TVMTypeCode::kNDArrayContainer,
-        }
-    }
-}
-
-impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
-    type Error = Error;
-    fn try_from(ret: TVMRetValue) -> Result<Self> {
-        ensure!(
-            ret.type_code == TVMTypeCode::kArrayHandle
-                || ret.type_code == TVMTypeCode::kNDArrayContainer,
-            "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
-            TVMTypeCode_kArrayHandle,
-            TVMTypeCode_kNDArrayContainer,
-            ret.type_code,
-        );
-
-        let dlt = unsafe { *(ret.prim_value as *mut _DLTensor as *const _DLTensor) };
-        Ok(DLTensor::new(dlt).into())
-    }
-}
-
-// @see `WrapPackedFunc` in `llvm_module.cc`.
-pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
-    box move |args: &[TVMArgValue]| {
-        func(
-            args.iter()
-                .map(|ref arg| arg.value.inner)
-                .collect::<Vec<_TVMValue>>()
-                .as_ptr(),
-            args.iter()
-                .map(|ref arg| arg.type_code as i32)
-                .collect::<Vec<i32>>()
-                .as_ptr() as *const i32,
-            args.len() as i32,
-        );
-        TVMRetValue::default()
-    }
-}
index 1edf3ef..42d3aa4 100644 (file)
@@ -3,18 +3,17 @@ use std::{
     os::raw::{c_char, c_int},
 };
 
-use errors::Result;
-use ffi::runtime::TVMValue;
-use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
-
-pub use runtime::threading::tvm_run_worker as run_worker;
+pub use crate::threading::tvm_run_worker as run_worker;
+use crate::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
+use errors::SgxError;
+use ffi::TVMValue;
 
 #[macro_export]
 macro_rules! tvm_ocall {
     ($func: expr) => {
         match $func {
             0 => Ok(()),
-            err => Err(format!("SGX error: {}", err)),
+            code => Err(SgxError { code }),
         }
     };
 }
@@ -33,7 +32,10 @@ extern "C" {
     ) -> SgxStatus;
 }
 
-pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
+pub fn ocall_packed_func<S: AsRef<str>>(
+    fn_name: S,
+    args: &[TVMArgValue],
+) -> Result<TVMRetValue, SgxError> {
     let mut ret_val = TVMValue { v_int64: 0 };
     let ret_type_code = 0i64;
     unsafe {
@@ -58,11 +60,11 @@ pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Res
 #[macro_export]
 macro_rules! ocall_packed {
   ($fn_name:expr, $($args:expr),+) => {
-    ocall_packed_func($fn_name, &[$($args.into(),)+])
+    $crate::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+])
       .expect(concat!("Error calling `", $fn_name, "`"))
   };
   ($fn_name:expr) => {
-    ocall_packed_func($fn_name, &Vec::new())
+    $crate::sgx::ocall_packed_func($fn_name, &Vec::new())
       .expect(concat!("Error calling `", $fn_name, "`"))
   }
 }
index 38f4b7d..408c0b4 100644 (file)
@@ -1,7 +1,7 @@
 use std::{
     os::raw::{c_int, c_void},
     sync::{
-        atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT},
+        atomic::{AtomicUsize, Ordering},
         Arc, Barrier,
     },
 };
@@ -18,11 +18,10 @@ use std::{
 use std::{collections::VecDeque, ptr, sync::Mutex};
 
 use bounded_spsc_queue::{self, Producer};
-
-use crate::{errors::*, ffi::runtime::TVMParallelGroupEnv};
+use tvm_common::ffi::TVMParallelGroupEnv;
 
 #[cfg(target_env = "sgx")]
-use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue};
+use super::{TVMArgValue, TVMRetValue};
 
 type FTVMParallelLambda =
     extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
@@ -62,12 +61,11 @@ impl Job {
     }
 
     /// Waits for all tasks in this `Job` to be completed.
-    fn wait(&self) -> Result<()> {
+    fn wait(&self) {
         while self.pending.load(Ordering::Acquire) > 0 {
             #[cfg(not(target_env = "sgx"))]
             thread::yield_now();
         }
-        Ok(())
     }
 }
 
@@ -161,7 +159,7 @@ impl ThreadPool {
         }
 
         tasks.pop().unwrap()();
-        job.wait().unwrap();
+        job.wait();
     }
 
     fn run_worker(queue: Consumer<Task>) {
@@ -251,7 +249,7 @@ pub extern "C" fn TVMBackendParallelLaunch(
                 cb: cb,
                 cdata: cdata,
                 req_num_tasks: num_task,
-                pending: Arc::new(ATOMIC_USIZE_INIT),
+                pending: Arc::new(AtomicUsize::new(0)),
             });
         });
     }
@@ -273,7 +271,7 @@ pub(crate) fn sgx_join_threads() {
             cb: poison_pill,
             cdata: ptr::null(),
             req_num_tasks: 0,
-            pending: Arc::new(ATOMIC_USIZE_INIT),
+            pending: Arc::new(AtomicUsize::new(0)),
         });
     });
     ocall_packed!("__sgx_thread_group_join__", 0);
@@ -322,8 +320,8 @@ mod tests {
     #[test]
     fn test_parallel_launch() {
         TVMBackendParallelLaunch(flambda, ptr::null(), 6);
-        let counter = ATOMIC_USIZE_INIT;
-        let task_ids_sum = ATOMIC_USIZE_INIT;
+        let counter = AtomicUsize::new(0);
+        let task_ids_sum = AtomicUsize::new(0);
         let cdata = (counter, task_ids_sum);
         let num_tasks = 3;
         TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
index a12a27e..1e29ec1 100644 (file)
@@ -4,8 +4,9 @@ use std::{
     ptr,
 };
 
-use super::allocator::Allocation;
-use crate::errors::*;
+use failure::Error;
+
+use crate::allocator::Allocation;
 
 const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
 
@@ -24,13 +25,13 @@ impl WorkspacePool {
         }
     }
 
-    fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
+    fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> {
         self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
         self.in_use.push(self.workspaces.len() - 1);
         Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
     }
 
-    fn alloc(&mut self, size: usize) -> Result<*mut u8> {
+    fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
         if self.free.len() == 0 {
             return self.alloc_new(size);
         }
@@ -60,7 +61,7 @@ impl WorkspacePool {
         }
     }
 
-    fn free(&mut self, ptr: *mut u8) -> Result<()> {
+    fn free(&mut self, ptr: *mut u8) -> Result<(), Error> {
         let mut ws_idx = None;
         for i in 0..self.in_use.len() {
             let idx = self.in_use[i];
@@ -72,7 +73,7 @@ impl WorkspacePool {
         }
         Ok(self
             .free
-            .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?))
+            .push(ws_idx.ok_or(format_err!("Tried to free nonexistent workspace."))?))
     }
 }
 
index 14d0b39..259af23 100644 (file)
@@ -5,7 +5,7 @@ license = "Apache-2.0"
 authors = ["TVM Contributors"]
 
 [dependencies]
-ndarray = "0.11.2"
+ndarray="0.12.1"
 serde = "1.0.59"
 serde_json = "1.0.17"
 tvm-runtime = { path = "../../" }
index 2a753b4..561215d 100644 (file)
@@ -5,7 +5,7 @@ license = "Apache-2.0"
 authors = ["TVM Contributors"]
 
 [dependencies]
-ndarray = "0.11.2"
+ndarray="0.12.1"
 tvm-runtime = { path = "../../" }
 
 [build-dependencies]
index f14fbec..621315d 100644 (file)
@@ -17,6 +17,6 @@ fn main() {
     let mut a_dl: DLTensor = (&mut a).into();
     let mut b_dl: DLTensor = (&mut b).into();
     let mut c_dl: DLTensor = (&mut c).into();
-    call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
+    call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
     assert!(c.all_close(&e, 1e-8f32));
 }
index be0181b..6e17420 100755 (executable)
@@ -14,11 +14,11 @@ cargo fmt -- --check
 
 # test common
 cd $RUST_DIR/common
-cargo build --features runtime
-cargo test --features runtime --tests
+cargo build
+cargo test --tests
 
-cargo build --features frontend
-cargo test --features frontend --tests
+cargo build --features bindings
+cargo test --features bindings --tests
 
 # test runtime
 cd $RUST_DIR/runtime