[RUST][FRONTEND] Add rust frontend v0.1 (#2292)
authorEhsan M. Kermani <ehsanmo1367@gmail.com>
Sun, 3 Feb 2019 03:56:11 +0000 (19:56 -0800)
committerNick Hynes <nhynes@berkeley.edu>
Sun, 3 Feb 2019 03:56:11 +0000 (19:56 -0800)
91 files changed:
rust/.gitignore [deleted file]
rust/.rustfmt.toml
rust/.travis.yml [deleted file]
rust/Cargo.toml
rust/common/.gitignore [new file with mode: 0644]
rust/common/Cargo.toml [new file with mode: 0644]
rust/common/src/c_runtime_api.rs [new file with mode: 0644]
rust/common/src/errors.rs [new file with mode: 0644]
rust/common/src/lib.rs [new file with mode: 0644]
rust/common/src/ty.rs [new file with mode: 0644]
rust/common/src/value.rs [new file with mode: 0644]
rust/common/tvm-sys/Cargo.toml [new file with mode: 0644]
rust/common/tvm-sys/build.rs [new file with mode: 0644]
rust/common/tvm-sys/src/lib.rs [new file with mode: 0644]
rust/frontend/.gitignore [new file with mode: 0644]
rust/frontend/.travis.yml [new file with mode: 0644]
rust/frontend/Cargo.toml [new file with mode: 0644]
rust/frontend/README.md [new file with mode: 0644]
rust/frontend/examples/resnet/Cargo.toml [new file with mode: 0644]
rust/frontend/examples/resnet/README.md [new file with mode: 0644]
rust/frontend/examples/resnet/build.rs [new file with mode: 0644]
rust/frontend/examples/resnet/src/build_resnet.py [new file with mode: 0755]
rust/frontend/examples/resnet/src/main.rs [new file with mode: 0644]
rust/frontend/src/bytearray.rs [new file with mode: 0644]
rust/frontend/src/context.rs [new file with mode: 0644]
rust/frontend/src/errors.rs [new file with mode: 0644]
rust/frontend/src/function.rs [new file with mode: 0644]
rust/frontend/src/lib.rs [new file with mode: 0644]
rust/frontend/src/module.rs [new file with mode: 0644]
rust/frontend/src/ndarray.rs [new file with mode: 0644]
rust/frontend/src/ty.rs [new file with mode: 0644]
rust/frontend/src/value.rs [new file with mode: 0644]
rust/frontend/tests/basics/.gitignore [new file with mode: 0644]
rust/frontend/tests/basics/Cargo.toml [new file with mode: 0644]
rust/frontend/tests/basics/build.rs [new file with mode: 0644]
rust/frontend/tests/basics/src/main.rs [new file with mode: 0644]
rust/frontend/tests/basics/src/tvm_add.py [new file with mode: 0755]
rust/frontend/tests/callback/Cargo.toml [new file with mode: 0644]
rust/frontend/tests/callback/src/bin/array.rs [new file with mode: 0644]
rust/frontend/tests/callback/src/bin/error.rs [new file with mode: 0644]
rust/frontend/tests/callback/src/bin/float.rs [new file with mode: 0644]
rust/frontend/tests/callback/src/bin/int.rs [new file with mode: 0644]
rust/frontend/tests/callback/src/bin/string.rs [new file with mode: 0644]
rust/runtime/.gitignore [new file with mode: 0644]
rust/runtime/.travis.yml [new file with mode: 0644]
rust/runtime/Cargo.toml [new file with mode: 0644]
rust/runtime/src/allocator.rs [new file with mode: 0644]
rust/runtime/src/array.rs [new file with mode: 0644]
rust/runtime/src/errors.rs [new file with mode: 0644]
rust/runtime/src/graph.rs [new file with mode: 0644]
rust/runtime/src/lib.rs [new file with mode: 0644]
rust/runtime/src/module.rs [new file with mode: 0644]
rust/runtime/src/packed_func.rs [new file with mode: 0644]
rust/runtime/src/sgx.rs [new file with mode: 0644]
rust/runtime/src/threading.rs [new file with mode: 0644]
rust/runtime/src/workspace.rs [new file with mode: 0644]
rust/runtime/tests/.gitignore [new file with mode: 0644]
rust/runtime/tests/build_model.py [new file with mode: 0755]
rust/runtime/tests/test_graph_serde.rs [new file with mode: 0644]
rust/runtime/tests/test_nnvm/Cargo.toml [new file with mode: 0644]
rust/runtime/tests/test_nnvm/build.rs [new file with mode: 0644]
rust/runtime/tests/test_nnvm/src/build_test_graph.py [new file with mode: 0755]
rust/runtime/tests/test_nnvm/src/main.rs [new file with mode: 0644]
rust/runtime/tests/test_tvm_basic/Cargo.toml [new file with mode: 0644]
rust/runtime/tests/test_tvm_basic/build.rs [new file with mode: 0644]
rust/runtime/tests/test_tvm_basic/src/build_test_lib.py [new file with mode: 0755]
rust/runtime/tests/test_tvm_basic/src/main.rs [new file with mode: 0644]
rust/src/errors.rs [deleted file]
rust/src/lib.rs [deleted file]
rust/src/runtime/allocator.rs [deleted file]
rust/src/runtime/array.rs [deleted file]
rust/src/runtime/c_runtime_api.rs [deleted file]
rust/src/runtime/graph.rs [deleted file]
rust/src/runtime/mod.rs [deleted file]
rust/src/runtime/module.rs [deleted file]
rust/src/runtime/packed_func.rs [deleted file]
rust/src/runtime/sgx.rs [deleted file]
rust/src/runtime/threading.rs [deleted file]
rust/src/runtime/workspace.rs [deleted file]
rust/tests/.gitignore [deleted file]
rust/tests/build_model.py [deleted file]
rust/tests/test_graph_serde.rs [deleted file]
rust/tests/test_nnvm/Cargo.toml [deleted file]
rust/tests/test_nnvm/build.rs [deleted file]
rust/tests/test_nnvm/src/build_test_graph.py [deleted file]
rust/tests/test_nnvm/src/main.rs [deleted file]
rust/tests/test_tvm_basic/Cargo.toml [deleted file]
rust/tests/test_tvm_basic/build.rs [deleted file]
rust/tests/test_tvm_basic/src/build_test_lib.py [deleted file]
rust/tests/test_tvm_basic/src/main.rs [deleted file]
tests/scripts/task_rust.sh

diff --git a/rust/.gitignore b/rust/.gitignore
deleted file mode 100644 (file)
index 230ab66..0000000
+++ /dev/null
@@ -1,3 +0,0 @@
-Cargo.lock
-target/
-**/*.rs.bk
index 51e3cbfa76f145549cf5245894847d158f74a0ed..9e52f9efacc8ebbaa7ef2d0878362d28779894dc 100644 (file)
@@ -1,6 +1,6 @@
 max_width = 100
 hard_tabs = false
-tab_spaces = 2
+tab_spaces = 4
 newline_style = "Auto"
 use_small_heuristics = "Default"
 indent_style = "Block"
@@ -38,7 +38,7 @@ trailing_comma = "Vertical"
 match_block_trailing_comma = false
 blank_lines_upper_bound = 1
 blank_lines_lower_bound = 0
-edition = "2015"
+edition = "2018"
 merge_derives = true
 use_try_shorthand = true
 use_field_init_shorthand = false
@@ -50,8 +50,8 @@ unstable_features = false
 disable_all_formatting = false
 skip_children = false
 hide_parse_errors = false
-error_on_line_overflow = false
-error_on_unformatted = false
+error_on_line_overflow = true
+error_on_unformatted = true
 report_todo = "Never"
 report_fixme = "Never"
 ignore = []
diff --git a/rust/.travis.yml b/rust/.travis.yml
deleted file mode 100644 (file)
index 63a3d02..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-language: rust
-rust:
-  - nightly
-matrix:
-  fast_finish: true
index 4dd793e415a89db66f9c62dd25cb7606d9e45adf..448cbfe30d1e51635ceb77247a312d7d50e52b28 100644 (file)
@@ -1,28 +1,11 @@
-[package]
-name = "tvm"
-version = "0.1.0"
-license = "Apache-2.0"
-description = "TVM Rust runtime"
-repository = "https://github.com/dmlc/tvm"
-readme = "README.md"
-keywords = ["tvm", "nnvm"]
-categories = ["api-bindings", "science"]
-authors = ["TVM Contributors"]
-
-[features]
-default = ["nom/std"]
-sgx = ["nom/alloc"]
-
-[dependencies]
-bounded-spsc-queue = "0.4.0"
-error-chain = { version = "0.12.0", default-features = false }
-itertools = "0.7.8"
-lazy_static = "1.1.0"
-ndarray = "0.11.2"
-nom = {version = "4.0.0", default-features = false }
-serde = "1.0.59"
-serde_derive = "1.0.79"
-serde_json = "1.0.17"
-
-[target.'cfg(not(target_env = "sgx"))'.dependencies]
-num_cpus = "1.8.0"
+[workspace]
+members = [
+       "common",
+       "runtime",
+       "runtime/tests/test_tvm_basic",
+       "runtime/tests/test_nnvm",
+       "frontend",
+       "frontend/tests/basics",
+       "frontend/tests/callback",
+       "frontend/examples/resnet"
+]
diff --git a/rust/common/.gitignore b/rust/common/.gitignore
new file mode 100644 (file)
index 0000000..84c2ae9
--- /dev/null
@@ -0,0 +1,4 @@
+target
+**/*.rs.bk
+Cargo.lock
+/tvm-sys/src/bindgen.rs
diff --git a/rust/common/Cargo.toml b/rust/common/Cargo.toml
new file mode 100644 (file)
index 0000000..bcba5ad
--- /dev/null
@@ -0,0 +1,13 @@
+[package]
+name = "tvm-common"
+version = "0.1.0"
+authors = ["TVM Contributors"]
+license = "Apache-2.0"
+
+[features]
+runtime = []
+frontend = ["tvm-sys"]
+
+[dependencies]
+error-chain = { version = "0.12.0", default-features = false }
+tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
diff --git a/rust/common/src/c_runtime_api.rs b/rust/common/src/c_runtime_api.rs
new file mode 100644 (file)
index 0000000..6facf9c
--- /dev/null
@@ -0,0 +1,770 @@
+/* automatically generated by rust-bindgen for TVM revision 6292c78 */
+
+pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0";
+pub const DLPACK_VERSION: u32 = 8;
+pub const _STDINT_H: u32 = 1;
+pub const _FEATURES_H: u32 = 1;
+pub const _DEFAULT_SOURCE: u32 = 1;
+pub const __USE_ISOC11: u32 = 1;
+pub const __USE_ISOC99: u32 = 1;
+pub const __USE_ISOC95: u32 = 1;
+pub const __USE_POSIX_IMPLICITLY: u32 = 1;
+pub const _POSIX_SOURCE: u32 = 1;
+pub const _POSIX_C_SOURCE: u32 = 200809;
+pub const __USE_POSIX: u32 = 1;
+pub const __USE_POSIX2: u32 = 1;
+pub const __USE_POSIX199309: u32 = 1;
+pub const __USE_POSIX199506: u32 = 1;
+pub const __USE_XOPEN2K: u32 = 1;
+pub const __USE_XOPEN2K8: u32 = 1;
+pub const _ATFILE_SOURCE: u32 = 1;
+pub const __USE_MISC: u32 = 1;
+pub const __USE_ATFILE: u32 = 1;
+pub const __USE_FORTIFY_LEVEL: u32 = 0;
+pub const _STDC_PREDEF_H: u32 = 1;
+pub const __STDC_IEC_559__: u32 = 1;
+pub const __STDC_IEC_559_COMPLEX__: u32 = 1;
+pub const __STDC_ISO_10646__: u32 = 201505;
+pub const __STDC_NO_THREADS__: u32 = 1;
+pub const __GNU_LIBRARY__: u32 = 6;
+pub const __GLIBC__: u32 = 2;
+pub const __GLIBC_MINOR__: u32 = 23;
+pub const _SYS_CDEFS_H: u32 = 1;
+pub const __WORDSIZE: u32 = 64;
+pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1;
+pub const __SYSCALL_WORDSIZE: u32 = 64;
+pub const _BITS_WCHAR_H: u32 = 1;
+pub const INT8_MIN: i32 = -128;
+pub const INT16_MIN: i32 = -32768;
+pub const INT32_MIN: i32 = -2147483648;
+pub const INT8_MAX: u32 = 127;
+pub const INT16_MAX: u32 = 32767;
+pub const INT32_MAX: u32 = 2147483647;
+pub const UINT8_MAX: u32 = 255;
+pub const UINT16_MAX: u32 = 65535;
+pub const UINT32_MAX: u32 = 4294967295;
+pub const INT_LEAST8_MIN: i32 = -128;
+pub const INT_LEAST16_MIN: i32 = -32768;
+pub const INT_LEAST32_MIN: i32 = -2147483648;
+pub const INT_LEAST8_MAX: u32 = 127;
+pub const INT_LEAST16_MAX: u32 = 32767;
+pub const INT_LEAST32_MAX: u32 = 2147483647;
+pub const UINT_LEAST8_MAX: u32 = 255;
+pub const UINT_LEAST16_MAX: u32 = 65535;
+pub const UINT_LEAST32_MAX: u32 = 4294967295;
+pub const INT_FAST8_MIN: i32 = -128;
+pub const INT_FAST16_MIN: i64 = -9223372036854775808;
+pub const INT_FAST32_MIN: i64 = -9223372036854775808;
+pub const INT_FAST8_MAX: u32 = 127;
+pub const INT_FAST16_MAX: u64 = 9223372036854775807;
+pub const INT_FAST32_MAX: u64 = 9223372036854775807;
+pub const UINT_FAST8_MAX: u32 = 255;
+pub const UINT_FAST16_MAX: i32 = -1;
+pub const UINT_FAST32_MAX: i32 = -1;
+pub const INTPTR_MIN: i64 = -9223372036854775808;
+pub const INTPTR_MAX: u64 = 9223372036854775807;
+pub const UINTPTR_MAX: i32 = -1;
+pub const PTRDIFF_MIN: i64 = -9223372036854775808;
+pub const PTRDIFF_MAX: u64 = 9223372036854775807;
+pub const SIG_ATOMIC_MIN: i32 = -2147483648;
+pub const SIG_ATOMIC_MAX: u32 = 2147483647;
+pub const SIZE_MAX: i32 = -1;
+pub const WINT_MIN: u32 = 0;
+pub const WINT_MAX: u32 = 4294967295;
+pub type int_least8_t = ::std::os::raw::c_schar;
+pub type int_least16_t = ::std::os::raw::c_short;
+pub type int_least32_t = ::std::os::raw::c_int;
+pub type int_least64_t = ::std::os::raw::c_long;
+pub type uint_least8_t = ::std::os::raw::c_uchar;
+pub type uint_least16_t = ::std::os::raw::c_ushort;
+pub type uint_least32_t = ::std::os::raw::c_uint;
+pub type uint_least64_t = ::std::os::raw::c_ulong;
+pub type int_fast8_t = ::std::os::raw::c_schar;
+pub type int_fast16_t = ::std::os::raw::c_long;
+pub type int_fast32_t = ::std::os::raw::c_long;
+pub type int_fast64_t = ::std::os::raw::c_long;
+pub type uint_fast8_t = ::std::os::raw::c_uchar;
+pub type uint_fast16_t = ::std::os::raw::c_ulong;
+pub type uint_fast32_t = ::std::os::raw::c_ulong;
+pub type uint_fast64_t = ::std::os::raw::c_ulong;
+pub type intmax_t = ::std::os::raw::c_long;
+pub type uintmax_t = ::std::os::raw::c_ulong;
+pub type wchar_t = ::std::os::raw::c_int;
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct max_align_t {
+  pub __clang_max_align_nonce1: ::std::os::raw::c_longlong,
+  pub __bindgen_padding_0: u64,
+  pub __clang_max_align_nonce2: f64,
+}
+pub const DLDeviceType_kDLCPU: DLDeviceType = 1;
+pub const DLDeviceType_kDLGPU: DLDeviceType = 2;
+pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3;
+pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4;
+pub const DLDeviceType_kDLMetal: DLDeviceType = 8;
+pub const DLDeviceType_kDLVPI: DLDeviceType = 9;
+pub const DLDeviceType_kDLROCM: DLDeviceType = 10;
+/// \brief The device type in DLContext.
+pub type DLDeviceType = u32;
+/// \brief A Device context for Tensor and operator.
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct DLContext {
+  /// \brief The device type used in the device.
+  pub device_type: DLDeviceType,
+  /// \brief The device index
+  pub device_id: ::std::os::raw::c_int,
+}
+pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0;
+pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1;
+pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2;
+/// \brief The type code options DLDataType.
+pub type DLDataTypeCode = u32;
+/// \brief The data type the tensor can hold.
+///
+/// Examples
+/// - float: type_code = 2, bits = 32, lanes=1
+/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
+/// - int8: type_code = 0, bits = 8, lanes=1
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct DLDataType {
+  /// \brief Type code of base types.
+  /// We keep it uint8_t instead of DLDataTypeCode for minimal memory
+  /// footprint, but the value should be one of DLDataTypeCode enum values.
+  ///
+  pub code: u8,
+  /// \brief Number of bits, common choices are 8, 16, 32.
+  pub bits: u8,
+  /// \brief Number of lanes in the type, used for vector types.
+  pub lanes: u16,
+}
+/// \brief Plain C Tensor object, does not manage memory.
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct DLTensor {
+  /// \brief The opaque data pointer points to the allocated data.
+  /// This will be CUDA device pointer or cl_mem handle in OpenCL.
+  /// This pointer is always aligns to 256 bytes as in CUDA.
+  pub data: *mut ::std::os::raw::c_void,
+  /// \brief The device context of the tensor
+  pub ctx: DLContext,
+  /// \brief Number of dimensions
+  pub ndim: ::std::os::raw::c_int,
+  /// \brief The data type of the pointer
+  pub dtype: DLDataType,
+  /// \brief The shape of the tensor
+  pub shape: *mut i64,
+  /// \brief strides of the tensor,
+  /// can be NULL, indicating tensor is compact.
+  pub strides: *mut i64,
+  /// \brief The offset in bytes to the beginning pointer to data
+  pub byte_offset: u64,
+}
+/// \brief C Tensor object, manage memory of DLTensor. This data structure is
+/// intended to faciliate the borrowing of DLTensor by another framework. It is
+/// not meant to transfer the tensor. When the borrowing framework doesn't need
+/// the tensor, it should call the deleter to notify the host that the resource
+/// is no longer needed.
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct DLManagedTensor {
+  /// \brief DLTensor which is being memory managed
+  pub dl_tensor: DLTensor,
+  /// \brief the context of the original host framework of DLManagedTensor in
+  /// which DLManagedTensor is used in the framework. It can also be NULL.
+  pub manager_ctx: *mut ::std::os::raw::c_void,
+  /// \brief Destructor signature void (*)(void*) - this should be called
+  /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
+  /// if there is no way for the caller to provide a reasonable destructor.
+  pub deleter: ::std::option::Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)>,
+}
+/// \brief type of array index.
+pub type tvm_index_t = i64;
+pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5;
+pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6;
+pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7;
+pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11;
+pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12;
+/// \brief Extension device types in TVM
+pub type TVMDeviceExtType = u32;
+pub const TVMTypeCode_kHandle: TVMTypeCode = 3;
+pub const TVMTypeCode_kNull: TVMTypeCode = 4;
+pub const TVMTypeCode_kTVMType: TVMTypeCode = 5;
+pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6;
+pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7;
+pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8;
+pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9;
+pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10;
+pub const TVMTypeCode_kStr: TVMTypeCode = 11;
+pub const TVMTypeCode_kBytes: TVMTypeCode = 12;
+pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13;
+pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15;
+pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16;
+pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20;
+pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64;
+pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128;
+/// \brief The type code in TVMType
+/// \note TVMType is used in two places.
+pub type TVMTypeCode = u32;
+/// \brief The data type used in TVM Runtime.
+///
+/// Examples
+/// - float: type_code = 2, bits = 32, lanes=1
+/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
+/// - int8: type_code = 0, bits = 8, lanes=1
+///
+/// \note Arguments TVM API function always takes bits=64 and lanes=1
+pub type TVMType = DLDataType;
+/// \brief The Device information, abstract away common device types.
+pub type TVMContext = DLContext;
+/// \brief The tensor array stucture to TVM API.
+pub type TVMArray = DLTensor;
+/// \brief the array handle
+pub type TVMArrayHandle = *mut TVMArray;
+/// \brief Union type of values
+/// being passed through API and function calls.
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub union TVMValue {
+  pub v_int64: i64,
+  pub v_float64: f64,
+  pub v_handle: *mut ::std::os::raw::c_void,
+  pub v_str: *const ::std::os::raw::c_char,
+  pub v_type: TVMType,
+  pub v_ctx: TVMContext,
+  _bindgen_union_align: u64,
+}
+/// \brief Byte array type used to pass in byte array
+/// When kBytes is used as data type.
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct TVMByteArray {
+  pub data: *const ::std::os::raw::c_char,
+  pub size: usize,
+}
+/// \brief Handle to TVM runtime modules.
+pub type TVMModuleHandle = *mut ::std::os::raw::c_void;
+/// \brief Handle to packed function handle.
+pub type TVMFunctionHandle = *mut ::std::os::raw::c_void;
+/// \brief Handle to hold return value.
+pub type TVMRetValueHandle = *mut ::std::os::raw::c_void;
+/// \brief The stream that is specific to device
+/// can be NULL, which indicates the default one.
+pub type TVMStreamHandle = *mut ::std::os::raw::c_void;
+extern "C" {
+  /// \brief Used for implementing C API function.
+  /// Set last error message before return.
+  /// \param msg The error message to be set.
+  pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char);
+}
+extern "C" {
+  /// \brief return str message of the last error
+  /// all function in this file will return 0 when success
+  /// and -1 when an error occured,
+  /// TVMGetLastError can be called to retrieve the error
+  ///
+  /// this function is threadsafe and can be called by different thread
+  /// \return error info
+  pub fn TVMGetLastError() -> *const ::std::os::raw::c_char;
+}
+extern "C" {
+  /// \brief Load module from file.
+  /// \param file_name The file name to load the module from.
+  /// \param format The format of the module.
+  /// \param out The result module
+  ///
+  /// \return 0 when success, -1 when failure happens
+  /// \note The resulting module do not contain import relation.
+  /// It can be reconstructed by TVMModImport.
+  pub fn TVMModLoadFromFile(
+    file_name: *const ::std::os::raw::c_char,
+    format: *const ::std::os::raw::c_char,
+    out: *mut TVMModuleHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Add dep to mod's dependency.
+  /// This allows functions in this module to use modules.
+  ///
+  /// \param mod The module handle.
+  /// \param dep The dependent module to be imported.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Get function from the module.
+  /// \param mod The module handle.
+  /// \param func_name The name of the function.
+  /// \param query_imports Whether to query imported modules
+  /// \param out The result function, can be NULL if it is not available.
+  /// \return 0 when no error is thrown, -1 when failure happens
+  pub fn TVMModGetFunction(
+    mod_: TVMModuleHandle,
+    func_name: *const ::std::os::raw::c_char,
+    query_imports: ::std::os::raw::c_int,
+    out: *mut TVMFunctionHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Free front-end extension type resource.
+  /// \param handle The extension handle.
+  /// \param type_code The type of of the extension type.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMExtTypeFree(
+    handle: *mut ::std::os::raw::c_void,
+    type_code: ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Free the Module
+  /// \param mod The module to be freed.
+  ///
+  /// \note This may not free up the module's resources.
+  /// If there is active TVMFunctionHandle uses the module
+  /// Or if this module is imported by another active module.
+  ///
+  /// The all functions remains valid until TVMFuncFree is called.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Free the function when it is no longer needed.
+  /// \param func The function handle
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Call a Packed TVM Function.
+  ///
+  /// \param func node handle of the function.
+  /// \param arg_values The arguments
+  /// \param type_codes The type codes of the arguments
+  /// \param num_args Number of arguments.
+  ///
+  /// \param ret_val The return value.
+  /// \param ret_type_code the type code of return value.
+  ///
+  /// \return 0 when success, -1 when failure happens
+  /// \note TVM calls always exchanges with type bits=64, lanes=1
+  ///
+  /// \note API calls always exchanges with type bits=64, lanes=1
+  /// If API call returns container handles (e.g. FunctionHandle)
+  /// these handles should be managed by the front-end.
+  /// The front-end need to call free function (e.g. TVMFuncFree)
+  /// to free these handles.
+  pub fn TVMFuncCall(
+    func: TVMFunctionHandle,
+    arg_values: *mut TVMValue,
+    type_codes: *mut ::std::os::raw::c_int,
+    num_args: ::std::os::raw::c_int,
+    ret_val: *mut TVMValue,
+    ret_type_code: *mut ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Set the return value of TVMPackedCFunc.
+  ///
+  /// This function is called by TVMPackedCFunc to set the return value.
+  /// When this function is not called, the function returns null by default.
+  ///
+  /// \param ret The return value handle, pass by ret in TVMPackedCFunc
+  /// \param value The value to be returned.
+  /// \param type_code The type of the value to be returned.
+  /// \param num_ret Number of return values, for now only 1 is supported.
+  pub fn TVMCFuncSetReturn(
+    ret: TVMRetValueHandle,
+    value: *mut TVMValue,
+    type_code: *mut ::std::os::raw::c_int,
+    num_ret: ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Inplace translate callback argument value to return value.
+  /// This is only needed for non-POD arguments.
+  ///
+  /// \param value The value to be translated.
+  /// \param code The type code to be translated.
+  /// \note This function will do a shallow copy when necessary.
+  ///
+  /// \return 0 when success, -1 when failure happens.
+  pub fn TVMCbArgToReturn(
+    value: *mut TVMValue,
+    code: ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
+/// \brief C type of packed function.
+///
+/// \param args The arguments
+/// \param type_codes The type codes of the arguments
+/// \param num_args Number of arguments.
+/// \param ret The return value handle.
+/// \param resource_handle The handle additional resouce handle from fron-end.
+/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
+/// \sa TVMCFuncSetReturn
+pub type TVMPackedCFunc = ::std::option::Option<
+  unsafe extern "C" fn(
+    args: *mut TVMValue,
+    type_codes: *mut ::std::os::raw::c_int,
+    num_args: ::std::os::raw::c_int,
+    ret: TVMRetValueHandle,
+    resource_handle: *mut ::std::os::raw::c_void,
+  ) -> ::std::os::raw::c_int,
+>;
+/// \brief C callback to free the resource handle in C packed function.
+/// \param resource_handle The handle additional resouce handle from fron-end.
+pub type TVMPackedCFuncFinalizer =
+  ::std::option::Option<unsafe extern "C" fn(resource_handle: *mut ::std::os::raw::c_void)>;
+/// \brief Signature for extension function declarer.
+///
+/// TVM call this function to get the extension functions
+/// The declarer will call register_func to register function and their name.
+///
+/// \param register_func_handle The register function
+/// \return 0 if success, -1 if failure happens
+pub type TVMExtensionFuncDeclarer = ::std::option::Option<
+  unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int,
+>;
+extern "C" {
+  /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
+  ///
+  /// The resource_handle will be managed by TVM API, until the function is no longer used.
+  ///
+  /// \param func The packed C function.
+  /// \param resource_handle The resource handle from front-end, can be NULL.
+  /// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
+  /// \param out the result function handle.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMFuncCreateFromCFunc(
+    func: TVMPackedCFunc,
+    resource_handle: *mut ::std::os::raw::c_void,
+    fin: TVMPackedCFuncFinalizer,
+    out: *mut TVMFunctionHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Register the function to runtime's global table.
+  ///
+  /// The registered function then can be pulled by the backend by the name.
+  ///
+  /// \param name The name of the function.
+  /// \param f The function to be registered.
+  /// \param override Whether allow override already registered function.
+  pub fn TVMFuncRegisterGlobal(
+    name: *const ::std::os::raw::c_char,
+    f: TVMFunctionHandle,
+    override_: ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Get a global function.
+  ///
+  /// \param name The name of the function.
+  /// \param out the result function pointer, NULL if it does not exist.
+  ///
+  /// \note The function handle of global function is managed by TVM runtime,
+  /// So TVMFuncFree is should not be called when it get deleted.
+  pub fn TVMFuncGetGlobal(
+    name: *const ::std::os::raw::c_char,
+    out: *mut TVMFunctionHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief List all the globally registered function name
+  /// \param out_size The number of functions
+  /// \param out_array The array of function names.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMFuncListGlobalNames(
+    out_size: *mut ::std::os::raw::c_int,
+    out_array: *mut *mut *const ::std::os::raw::c_char,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Allocate a nd-array's memory,
+  /// including space of shape, of given spec.
+  ///
+  /// \param shape The shape of the array, the data content will be copied to out
+  /// \param ndim The number of dimension of the array.
+  /// \param dtype_code The type code of the dtype
+  /// \param dtype_bits The number of bits of dtype
+  /// \param dtype_lanes The number of lanes in the dtype.
+  /// \param device_type The device type of context
+  /// \param device_id The device id of context.
+  /// \param out The output handle.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayAlloc(
+    shape: *const tvm_index_t,
+    ndim: ::std::os::raw::c_int,
+    dtype_code: ::std::os::raw::c_int,
+    dtype_bits: ::std::os::raw::c_int,
+    dtype_lanes: ::std::os::raw::c_int,
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    out: *mut TVMArrayHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Free the TVM Array.
+  /// \param handle The array handle to be freed.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Copy array data from CPU byte array.
+  /// \param handle The array handle.
+  /// \param data the data pointer
+  /// \param nbytes The number of bytes to copy.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayCopyFromBytes(
+    handle: TVMArrayHandle,
+    data: *mut ::std::os::raw::c_void,
+    nbytes: usize,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Copy array data to CPU byte array.
+  /// \param handle The array handle.
+  /// \param data the data pointer
+  /// \param nbytes The number of bytes to copy.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayCopyToBytes(
+    handle: TVMArrayHandle,
+    data: *mut ::std::os::raw::c_void,
+    nbytes: usize,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Copy the array, both from and to must be valid during the copy.
+  /// \param from The array to be copied from.
+  /// \param to The target space.
+  /// \param stream The stream where the copy happens, can be NULL.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayCopyFromTo(
+    from: TVMArrayHandle,
+    to: TVMArrayHandle,
+    stream: TVMStreamHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Produce an array from the DLManagedTensor that shares data memory
+  /// with the DLManagedTensor.
+  /// \param from The source DLManagedTensor.
+  /// \param out The output array handle.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayFromDLPack(
+    from: *mut DLManagedTensor,
+    out: *mut TVMArrayHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Produce a DLMangedTensor from the array that shares data memory with
+  /// the array.
+  /// \param from The source array.
+  /// \param out The DLManagedTensor handle.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayToDLPack(
+    from: TVMArrayHandle,
+    out: *mut *mut DLManagedTensor,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Delete (free) a DLManagedTensor's data.
+  /// \param dltensor Pointer to the DLManagedTensor.
+  pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor);
+}
+extern "C" {
+  /// \brief Create a new runtime stream.
+  ///
+  /// \param device_type The device type of context
+  /// \param device_id The device id of context
+  /// \param out The new stream handle
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMStreamCreate(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    out: *mut TVMStreamHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Free a created stream handle.
+  ///
+  /// \param device_type The device type of context
+  /// \param device_id The device id of context
+  /// \param stream The stream to be freed
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMStreamFree(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    stream: TVMStreamHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Set the runtime stream of current thread to be stream.
+  /// The subsequent calls to the same device_type
+  /// will use the setted stream handle.
+  /// The specific type of stream is runtime device dependent.
+  ///
+  /// \param device_type The device type of context
+  /// \param device_id The device id of context.
+  /// \param handle The stream handle.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMSetStream(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    handle: TVMStreamHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Wait until all computations on stream completes.
+  ///
+  /// \param device_type The device type of context
+  /// \param device_id The device id of context.
+  /// \param stream The stream to be synchronized.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMSynchronize(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    stream: TVMStreamHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Synchronize two streams of execution.
+  ///
+  /// \param device_type The device type of context
+  /// \param device_id The device id of context
+  /// \param src The source stream to synchronize.
+  /// \param dst The destination stream to synchronize.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMStreamStreamSynchronize(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    src: TVMStreamHandle,
+    dst: TVMStreamHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Backend function for modules to get function
+  /// from its environment mod_node (its imports and global function).
+  /// The user do should not call TVMFuncFree on func.
+  ///
+  /// \param mod_node The module handle.
+  /// \param func_name The name of the function.
+  /// \param out The result function.
+  /// \return 0 when no error is thrown, -1 when failure happens
+  pub fn TVMBackendGetFuncFromEnv(
+    mod_node: *mut ::std::os::raw::c_void,
+    func_name: *const ::std::os::raw::c_char,
+    out: *mut TVMFunctionHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Backend function to register system-wide library symbol.
+  ///
+  /// \param name The name of the symbol
+  /// \param ptr The symbol address.
+  /// \return 0 when no error is thrown, -1 when failure happens
+  pub fn TVMBackendRegisterSystemLibSymbol(
+    name: *const ::std::os::raw::c_char,
+    ptr: *mut ::std::os::raw::c_void,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Backend function to allocate temporal workspace.
+  ///
+  /// \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
+  ///
+  /// \param nbytes The size of the space requested.
+  /// \param device_type The device type which the space will be allocated.
+  /// \param device_id The device id which the space will be allocated.
+  /// \param dtype_code_hint The type code of the array elements. Only used in
+  /// certain backends such as OpenGL.
+  /// \param dtype_bits_hint The type bits of the array elements. Only used in
+  /// certain backends such as OpenGL.
+  /// \return nullptr when error is thrown, a valid ptr if success
+  pub fn TVMBackendAllocWorkspace(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    nbytes: u64,
+    dtype_code_hint: ::std::os::raw::c_int,
+    dtype_bits_hint: ::std::os::raw::c_int,
+  ) -> *mut ::std::os::raw::c_void;
+}
+extern "C" {
+  /// \brief Backend function to free temporal workspace.
+  ///
+  /// \param ptr The result allocated space pointer.
+  /// \param device_type The device type which the space will be allocated.
+  /// \param device_id The device id which the space will be allocated.
+  /// \return 0 when no error is thrown, -1 when failure happens
+  ///
+  /// \sa TVMBackendAllocWorkspace
+  pub fn TVMBackendFreeWorkspace(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    ptr: *mut ::std::os::raw::c_void,
+  ) -> ::std::os::raw::c_int;
+}
+/// \brief Environment for TVM parallel task.
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct TVMParallelGroupEnv {
+  /// \brief Auxiliary used for synchronization
+  pub sync_handle: *mut ::std::os::raw::c_void,
+  /// \brief total amount of task
+  pub num_task: i32,
+}
+/// \brief The callback function to execute a parallel lambda
+/// \param task_id the task id of the function.
+/// \param penv The parallel environment backs the execution.
+/// \param cdata The supporting closure data.
+pub type FTVMParallelLambda = ::std::option::Option<
+  unsafe extern "C" fn(
+    task_id: ::std::os::raw::c_int,
+    penv: *mut TVMParallelGroupEnv,
+    cdata: *mut ::std::os::raw::c_void,
+  ) -> ::std::os::raw::c_int,
+>;
+extern "C" {
+  /// \brief Backend function for running parallel jobs.
+  ///
+  /// \param flambda The parallel function to be launched.
+  /// \param cdata The closure data.
+  /// \param num_task Number of tasks to launch, can be 0, means launch
+  /// with all available threads.
+  ///
+  /// \return 0 when no error is thrown, -1 when failure happens
+  pub fn TVMBackendParallelLaunch(
+    flambda: FTVMParallelLambda,
+    cdata: *mut ::std::os::raw::c_void,
+    num_task: ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief BSP barrrier between parallel threads
+  /// \param task_id the task id of the function.
+  /// \param penv The parallel environment backs the execution.
+  /// \return 0 when no error is thrown, -1 when failure happens
+  pub fn TVMBackendParallelBarrier(
+    task_id: ::std::os::raw::c_int,
+    penv: *mut TVMParallelGroupEnv,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Simple static initialization function.
+  /// Run f once and set handle to be not null.
+  /// This function is mainly used for test purpose.
+  ///
+  /// \param handle An global address to indicate f
+  /// \param f The function to be ran
+  /// \param cdata The closure data to pass to the function.
+  /// \param nbytes Number of bytes in the closure data.
+  /// \return 0 when no error is thrown, -1 when failure happens
+  pub fn TVMBackendRunOnce(
+    handle: *mut *mut ::std::os::raw::c_void,
+    f: ::std::option::Option<
+      unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int,
+    >,
+    cdata: *mut ::std::os::raw::c_void,
+    nbytes: ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
diff --git a/rust/common/src/errors.rs b/rust/common/src/errors.rs
new file mode 100644 (file)
index 0000000..a81fab9
--- /dev/null
@@ -0,0 +1,15 @@
+//! Error types for `TVMArgValue` and `TVMRetValue` conversions.
+
+error_chain! {
+    errors {
+        TryFromTVMArgValueError(expected: String, actual: String) {
+              description("mismatched types while converting from TVMArgValue")
+              display("expected `{}` but given `{}`", expected, actual)
+        }
+
+        TryFromTVMRetValueError(expected: String, actual: String) {
+              description("mismatched types while downcasting TVMRetValue")
+              display("invalid downcast: expected `{}` but given `{}`", expected, actual)
+        }
+    }
+}
diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs
new file mode 100644 (file)
index 0000000..ad4c4f2
--- /dev/null
@@ -0,0 +1,39 @@
+//! This crate contains the refactored basic components required
+//! for `runtime` and `frontend` TVM crates.
+
+#![crate_name = "tvm_common"]
+#![recursion_limit = "1024"]
+#![allow(non_camel_case_types, unused_imports)]
+#![feature(box_syntax, try_from)]
+
+#[macro_use]
+extern crate error_chain;
+
+/// Unified ffi module for both runtime and frontend crates.
+pub mod ffi {
+    #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]
+
+    #[cfg(feature = "frontend")]
+    pub extern crate tvm_sys as ts;
+
+    #[cfg(feature = "runtime")]
+    pub mod runtime {
+        use std::os::raw::{c_char, c_int, c_void};
+
+        include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
+
+        pub type BackendPackedCFunc = extern "C" fn(
+            args: *const TVMValue,
+            type_codes: *const c_int,
+            num_args: c_int,
+        ) -> c_int;
+    }
+}
+
+pub mod errors;
+pub mod ty;
+pub mod value;
+
+pub use errors::*;
+pub use ty::TVMTypeCode;
+pub use value::{TVMArgValue, TVMRetValue, TVMValue};
diff --git a/rust/common/src/ty.rs b/rust/common/src/ty.rs
new file mode 100644 (file)
index 0000000..126bcd4
--- /dev/null
@@ -0,0 +1,144 @@
+//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods.
+//!
+//! # Example
+//!
+//! ```
+//! let dtype = TVMType::from("float");
+//! println!("dtype is: {}", dtype);
+//! ```
+
+use std::{
+    ffi::{CStr, CString},
+    fmt::{self, Display, Formatter},
+};
+
+/// TVM type codes.
+#[repr(u32)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
+pub enum TVMTypeCode {
+    kDLInt = 0,
+    kDLUInt = 1,
+    kDLFloat = 2,
+    kHandle = 3,
+    kNull = 4,
+    kTVMType = 5,
+    kTVMContext = 6,
+    kArrayHandle = 7,
+    kNodeHandle = 8,
+    kModuleHandle = 9,
+    kFuncHandle = 10,
+    kStr = 11,
+    kBytes = 12,
+    kNDArrayContainer = 13,
+}
+
+impl Default for TVMTypeCode {
+    fn default() -> Self {
+        TVMTypeCode::kDLInt
+    }
+}
+
+impl From<TVMTypeCode> for i64 {
+    fn from(arg: TVMTypeCode) -> i64 {
+        match arg {
+            TVMTypeCode::kDLInt => 0,
+            TVMTypeCode::kDLUInt => 1,
+            TVMTypeCode::kDLFloat => 2,
+            TVMTypeCode::kHandle => 3,
+            TVMTypeCode::kNull => 4,
+            TVMTypeCode::kTVMType => 5,
+            TVMTypeCode::kTVMContext => 6,
+            TVMTypeCode::kArrayHandle => 7,
+            TVMTypeCode::kNodeHandle => 8,
+            TVMTypeCode::kModuleHandle => 9,
+            TVMTypeCode::kFuncHandle => 10,
+            TVMTypeCode::kStr => 11,
+            TVMTypeCode::kBytes => 12,
+            TVMTypeCode::kNDArrayContainer => 13,
+        }
+    }
+}
+
+impl Into<TVMTypeCode> for i64 {
+    fn into(self) -> TVMTypeCode {
+        match self {
+            0 => TVMTypeCode::kDLInt,
+            1 => TVMTypeCode::kDLUInt,
+            2 => TVMTypeCode::kDLFloat,
+            3 => TVMTypeCode::kHandle,
+            4 => TVMTypeCode::kNull,
+            5 => TVMTypeCode::kTVMType,
+            6 => TVMTypeCode::kTVMContext,
+            7 => TVMTypeCode::kArrayHandle,
+            8 => TVMTypeCode::kNodeHandle,
+            9 => TVMTypeCode::kModuleHandle,
+            10 => TVMTypeCode::kFuncHandle,
+            11 => TVMTypeCode::kStr,
+            12 => TVMTypeCode::kBytes,
+            13 => TVMTypeCode::kNDArrayContainer,
+            _ => unreachable!(),
+        }
+    }
+}
+
+impl Display for TVMTypeCode {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        write!(
+            f,
+            "{}",
+            match self {
+                TVMTypeCode::kDLInt => "int",
+                TVMTypeCode::kDLUInt => "uint",
+                TVMTypeCode::kDLFloat => "float",
+                TVMTypeCode::kHandle => "handle",
+                TVMTypeCode::kNull => "null",
+                TVMTypeCode::kTVMType => "TVM type",
+                TVMTypeCode::kTVMContext => "TVM context",
+                TVMTypeCode::kArrayHandle => "Array handle",
+                TVMTypeCode::kNodeHandle => "Node handle",
+                TVMTypeCode::kModuleHandle => "Module handle",
+                TVMTypeCode::kFuncHandle => "Function handle",
+                TVMTypeCode::kStr => "string",
+                TVMTypeCode::kBytes => "bytes",
+                TVMTypeCode::kNDArrayContainer => "ndarray container",
+            }
+        )
+    }
+}
+
+macro_rules! impl_prim_type {
+    ($type:ty, $variant:ident) => {
+        impl<'a> From<&'a $type> for TVMTypeCode {
+            fn from(_arg: &$type) -> Self {
+                TVMTypeCode::$variant
+            }
+        }
+
+        impl<'a> From<&'a mut $type> for TVMTypeCode {
+            fn from(_arg: &mut $type) -> Self {
+                TVMTypeCode::$variant
+            }
+        }
+    };
+}
+
+impl_prim_type!(usize, kDLInt);
+impl_prim_type!(i64, kDLInt);
+impl_prim_type!(i32, kDLInt);
+impl_prim_type!(i16, kDLInt);
+impl_prim_type!(i8, kDLInt);
+
+impl_prim_type!(u64, kDLUInt);
+impl_prim_type!(u32, kDLUInt);
+impl_prim_type!(u16, kDLUInt);
+impl_prim_type!(u8, kDLUInt);
+
+impl_prim_type!(f64, kDLFloat);
+impl_prim_type!(f32, kDLFloat);
+
+impl_prim_type!(str, kStr);
+impl_prim_type!(CStr, kStr);
+impl_prim_type!(String, kStr);
+impl_prim_type!(CString, kStr);
+
+impl_prim_type!([u8], kBytes);
diff --git a/rust/common/src/value.rs b/rust/common/src/value.rs
new file mode 100644 (file)
index 0000000..6da8b27
--- /dev/null
@@ -0,0 +1,559 @@
+//! This module provides the the wrapped `TVMValue`, `TVMArgValue` and `TVMRetValue`
+//! required for using TVM functions.
+
+use std::{
+    any::Any,
+    convert::TryFrom,
+    ffi::{CStr, CString},
+    fmt::{self, Debug, Formatter},
+    marker::PhantomData,
+    mem,
+    ops::Deref,
+    os::raw::{c_char, c_void},
+};
+
+#[cfg(feature = "runtime")]
+use ffi::runtime::TVMValue as _TVMValue;
+
+#[cfg(feature = "frontend")]
+use ffi::ts::TVMValue as _TVMValue;
+
+use errors::*;
+
+use ty::TVMTypeCode;
+
+/// Wrapped TVMValue type.
+#[derive(Clone, Copy)]
+pub struct TVMValue {
+    pub inner: _TVMValue,
+}
+
+impl TVMValue {
+    /// Creates TVMValue from the raw part.
+    pub fn new(inner: _TVMValue) -> Self {
+        TVMValue { inner }
+    }
+
+    pub(crate) fn into_raw(self) -> _TVMValue {
+        self.inner
+    }
+}
+
+impl Debug for TVMValue {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        unsafe {
+            write!(
+                f,
+                "TVMValue: [v_int64: {:?}], [v_float64: {:?}], [v_handle: {:?}],\
+                 [v_str: {:?}]",
+                self.inner.v_int64, self.inner.v_float64, self.inner.v_handle, self.inner.v_str
+            )
+        }
+    }
+}
+
+impl Deref for TVMValue {
+    type Target = _TVMValue;
+    fn deref(&self) -> &Self::Target {
+        &self.inner
+    }
+}
+
+macro_rules! impl_prim_val {
+    ($type:ty, $field:ident, $cast:ty) => {
+        impl From<$type> for TVMValue {
+            fn from(arg: $type) -> Self {
+                let inner = _TVMValue {
+                    $field: arg as $cast,
+                };
+                Self::new(inner)
+            }
+        }
+
+        impl<'a> From<&'a $type> for TVMValue {
+            fn from(arg: &$type) -> Self {
+                let inner = _TVMValue {
+                    $field: *arg as $cast,
+                };
+                Self::new(inner)
+            }
+        }
+
+        impl<'a> From<&'a mut $type> for TVMValue {
+            fn from(arg: &mut $type) -> Self {
+                let inner = _TVMValue {
+                    $field: *arg as $cast,
+                };
+                Self::new(inner)
+            }
+        }
+
+        impl TryFrom<TVMValue> for $type {
+            type Error = Error;
+            fn try_from(val: TVMValue) -> Result<Self> {
+                Ok(unsafe { val.inner.$field as $type })
+            }
+        }
+
+        impl<'a> TryFrom<&'a TVMValue> for $type {
+            type Error = Error;
+            fn try_from(val: &TVMValue) -> Result<Self> {
+                Ok(unsafe { val.into_raw().$field as $type })
+            }
+        }
+
+        impl<'a> TryFrom<&'a mut TVMValue> for $type {
+            type Error = Error;
+            fn try_from(val: &mut TVMValue) -> Result<Self> {
+                Ok(unsafe { val.into_raw().$field as $type })
+            }
+        }
+    };
+}
+
+impl_prim_val!(isize, v_int64, i64);
+impl_prim_val!(i64, v_int64, i64);
+impl_prim_val!(i32, v_int64, i64);
+impl_prim_val!(i16, v_int64, i64);
+impl_prim_val!(i8, v_int64, i64);
+impl_prim_val!(usize, v_int64, i64);
+impl_prim_val!(u64, v_int64, i64);
+impl_prim_val!(u32, v_int64, i64);
+impl_prim_val!(u16, v_int64, i64);
+impl_prim_val!(u8, v_int64, i64);
+
+impl_prim_val!(f64, v_float64, f64);
+impl_prim_val!(f32, v_float64, f64);
+
+impl<'a> From<&'a str> for TVMValue {
+    fn from(arg: &str) -> TVMValue {
+        let arg = CString::new(arg).unwrap();
+        let inner = _TVMValue {
+            v_str: arg.as_ptr() as *const c_char,
+        };
+        mem::forget(arg);
+        Self::new(inner)
+    }
+}
+
+impl<'a> From<&'a String> for TVMValue {
+    fn from(arg: &String) -> TVMValue {
+        let arg = CString::new(arg.as_bytes()).unwrap();
+        let inner = _TVMValue {
+            v_str: arg.as_ptr() as *const c_char,
+        };
+        mem::forget(arg);
+        Self::new(inner)
+    }
+}
+
+impl<'a> From<&'a CString> for TVMValue {
+    fn from(arg: &CString) -> TVMValue {
+        let arg = arg.to_owned();
+        let inner = _TVMValue {
+            v_str: arg.as_ptr() as *const c_char,
+        };
+        mem::forget(arg);
+        Self::new(inner)
+    }
+}
+
+impl<'a> From<&'a [u8]> for TVMValue {
+    fn from(arg: &[u8]) -> TVMValue {
+        let arg = arg.to_owned();
+        let inner = _TVMValue {
+            v_handle: &arg as *const _ as *mut c_void,
+        };
+        mem::forget(arg);
+        Self::new(inner)
+    }
+}
+
+/// Captures both `TVMValue` and `TVMTypeCode` needed for TVM function.
+/// The preferred way to obtain a `TVMArgValue` is automatically via `call_packed!`.
+/// or in the frontend crate, with `function::Builder`. Checkout the methods for conversions.
+///
+/// ## Example
+///
+/// ```
+/// let s = "hello".to_string();
+/// let arg = TVMArgValue::from(&s);
+/// let tvm: String = arg.try_into().unwrap();
+/// assert_eq!(arg, s);
+/// ```
+#[derive(Debug, Clone, Copy)]
+pub struct TVMArgValue<'a> {
+    /// The wrapped TVMValue
+    pub value: TVMValue,
+    /// The matching type code.
+    pub type_code: TVMTypeCode,
+    /// This is only exposed to runtime and frontend crates and is not meant to be used directly.
+    pub lifetime: PhantomData<&'a ()>,
+}
+
+impl<'a> TVMArgValue<'a> {
+    pub fn new(value: TVMValue, type_code: TVMTypeCode) -> Self {
+        TVMArgValue {
+            value: value,
+            type_code: type_code,
+            lifetime: PhantomData,
+        }
+    }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for i64 {
+    type Error = Error;
+    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+        if (arg.type_code == TVMTypeCode::kDLInt)
+            | (arg.type_code == TVMTypeCode::kDLUInt)
+            | (arg.type_code == TVMTypeCode::kNull)
+        {
+            Ok(unsafe { arg.value.inner.v_int64 })
+        } else {
+            bail!(ErrorKind::TryFromTVMArgValueError(
+                stringify!(i64).to_string(),
+                arg.type_code.to_string()
+            ))
+        }
+    }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for f64 {
+    type Error = Error;
+    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+        if arg.type_code == TVMTypeCode::kDLFloat {
+            Ok(unsafe { arg.value.inner.v_float64 })
+        } else {
+            bail!(ErrorKind::TryFromTVMArgValueError(
+                stringify!(f64).to_string(),
+                arg.type_code.to_string()
+            ))
+        }
+    }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for String {
+    type Error = Error;
+    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+        if arg.type_code == TVMTypeCode::kStr {
+            let ret_str = unsafe {
+                match CStr::from_ptr(arg.value.inner.v_str).to_str() {
+                    Ok(s) => s,
+                    Err(_) => "Invalid UTF-8 message",
+                }
+            };
+            Ok(ret_str.to_string())
+        } else {
+            bail!(ErrorKind::TryFromTVMArgValueError(
+                stringify!(String).to_string(),
+                arg.type_code.to_string()
+            ))
+        }
+    }
+}
+
+/// Main way to create a TVMArgValue from suported Rust values.
+impl<'b, 'a: 'b, T: 'b + ?Sized> From<&'b T> for TVMArgValue<'a>
+where
+    TVMValue: From<&'b T>,
+    TVMTypeCode: From<&'b T>,
+{
+    fn from(arg: &'b T) -> Self {
+        TVMArgValue::new(TVMValue::from(arg), TVMTypeCode::from(arg))
+    }
+}
+
+/// Creates a conversion to a `TVMArgValue` for an object handle.
+impl<'a, T> From<*const T> for TVMArgValue<'a> {
+    fn from(ptr: *const T) -> Self {
+        let value = TVMValue::new(_TVMValue {
+            v_handle: ptr as *mut T as *mut c_void,
+        });
+
+        TVMArgValue::new(value, TVMTypeCode::kArrayHandle)
+    }
+}
+
+/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
+impl<'a, T> From<*mut T> for TVMArgValue<'a> {
+    fn from(ptr: *mut T) -> Self {
+        let value = TVMValue::new(_TVMValue {
+            v_handle: ptr as *mut c_void,
+        });
+
+        TVMArgValue::new(value, TVMTypeCode::kHandle)
+    }
+}
+
+/// An owned version of TVMPODValue. It can be converted from varieties of
+/// primitive and object types.
+/// It can be downcasted using `try_from` if it contains the desired type.
+///
+/// # Example
+///
+/// ```
+/// let a = 42u32;
+/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
+///
+/// let s = "hello, world!";
+/// let t: TVMRetValue = s.into();
+/// assert_eq!(String::try_from(t).unwrap(), s);
+/// ```
+pub struct TVMRetValue {
+    /// A primitive return value, if any.
+    pub prim_value: usize,
+    /// An object return value, if any.
+    pub box_value: Box<Any>,
+    pub type_code: TVMTypeCode,
+}
+
+impl TVMRetValue {
+    fn new(prim_value: usize, box_value: Box<Any>, type_code: TVMTypeCode) -> Self {
+        Self {
+            prim_value,
+            box_value,
+            type_code,
+        }
+    }
+
+    /// unsafe function to create `TVMRetValue` from `TVMValue` and
+    /// its matching `TVMTypeCode`.
+    pub unsafe fn from_tvm_value(value: TVMValue, type_code: TVMTypeCode) -> Self {
+        let value = value.into_raw();
+        match type_code {
+            TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => {
+                Self::new(value.v_int64 as usize, box (), type_code)
+            }
+            TVMTypeCode::kDLFloat => Self::new(value.v_float64 as usize, box (), type_code),
+            TVMTypeCode::kHandle
+            | TVMTypeCode::kArrayHandle
+            | TVMTypeCode::kNodeHandle
+            | TVMTypeCode::kModuleHandle
+            | TVMTypeCode::kFuncHandle => {
+                Self::new(value.v_handle as usize, box value.v_handle, type_code)
+            }
+            TVMTypeCode::kStr | TVMTypeCode::kBytes => {
+                Self::new(value.v_str as usize, box (value.v_str), type_code)
+            }
+            _ => Self::new(0usize, box (), type_code),
+        }
+    }
+
+    /// Returns the underlying `TVMValue` and `TVMTypeCode`.
+    pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
+        let val = match self.type_code {
+            TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => TVMValue::new(_TVMValue {
+                v_int64: self.prim_value as i64,
+            }),
+            TVMTypeCode::kDLFloat => TVMValue::new(_TVMValue {
+                v_float64: self.prim_value as f64,
+            }),
+            TVMTypeCode::kHandle
+            | TVMTypeCode::kArrayHandle
+            | TVMTypeCode::kNodeHandle
+            | TVMTypeCode::kModuleHandle
+            | TVMTypeCode::kFuncHandle
+            | TVMTypeCode::kNDArrayContainer => TVMValue::new(_TVMValue {
+                v_handle: self.prim_value as *const c_void as *mut c_void,
+            }),
+            TVMTypeCode::kStr | TVMTypeCode::kBytes => TVMValue::new(_TVMValue {
+                v_str: self.prim_value as *const c_char,
+            }),
+            _ => unreachable!(),
+        };
+        (val, self.type_code)
+    }
+}
+
+impl Default for TVMRetValue {
+    fn default() -> Self {
+        TVMRetValue {
+            prim_value: 0usize,
+            box_value: box (),
+            type_code: TVMTypeCode::default(),
+        }
+    }
+}
+
+impl Clone for TVMRetValue {
+    fn clone(&self) -> Self {
+        match self.type_code {
+            TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt | TVMTypeCode::kDLFloat => {
+                Self::new(self.prim_value.clone(), box (), self.type_code.clone())
+            }
+            TVMTypeCode::kHandle
+            | TVMTypeCode::kArrayHandle
+            | TVMTypeCode::kNodeHandle
+            | TVMTypeCode::kModuleHandle
+            | TVMTypeCode::kFuncHandle
+            | TVMTypeCode::kNDArrayContainer => Self::new(
+                self.prim_value.clone(),
+                box (self.prim_value.clone() as *const c_void as *mut c_void),
+                self.type_code.clone(),
+            ),
+            TVMTypeCode::kStr | TVMTypeCode::kBytes => Self::new(
+                self.prim_value.clone(),
+                box (self.prim_value.clone() as *const c_char),
+                self.type_code.clone(),
+            ),
+            _ => unreachable!(),
+        }
+    }
+}
+
+impl Debug for TVMRetValue {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        write!(
+            f,
+            "prim_value: {:?}, box_value: {:?}, type_code: {:?}",
+            self.prim_value, self.prim_value as *const c_void as *mut c_void, self.type_code
+        )
+    }
+}
+
+macro_rules! impl_prim_ret_value {
+    ($type:ty, $code:expr) => {
+        impl From<$type> for TVMRetValue {
+            fn from(val: $type) -> Self {
+                TVMRetValue {
+                    prim_value: val as usize,
+                    box_value: box (),
+                    type_code: $code,
+                }
+            }
+        }
+
+        impl<'a> From<&'a $type> for TVMRetValue {
+            fn from(val: &$type) -> Self {
+                TVMRetValue {
+                    prim_value: *val as usize,
+                    box_value: box (),
+                    type_code: $code,
+                }
+            }
+        }
+
+        impl<'a> From<&'a mut $type> for TVMRetValue {
+            fn from(val: &mut $type) -> Self {
+                TVMRetValue {
+                    prim_value: *val as usize,
+                    box_value: box (),
+                    type_code: $code,
+                }
+            }
+        }
+
+        impl TryFrom<TVMRetValue> for $type {
+            type Error = Error;
+            fn try_from(ret: TVMRetValue) -> Result<$type> {
+                if ret.type_code == $code {
+                    Ok(ret.prim_value as $type)
+                } else {
+                    bail!(ErrorKind::TryFromTVMRetValueError(
+                        stringify!($type).to_string(),
+                        ret.type_code.to_string(),
+                    ))
+                }
+            }
+        }
+    };
+}
+
+impl_prim_ret_value!(i8, TVMTypeCode::kDLInt);
+impl_prim_ret_value!(i16, TVMTypeCode::kDLInt);
+impl_prim_ret_value!(i32, TVMTypeCode::kDLInt);
+impl_prim_ret_value!(i64, TVMTypeCode::kDLInt);
+impl_prim_ret_value!(isize, TVMTypeCode::kDLInt);
+
+impl_prim_ret_value!(u8, TVMTypeCode::kDLUInt);
+impl_prim_ret_value!(u16, TVMTypeCode::kDLUInt);
+impl_prim_ret_value!(u32, TVMTypeCode::kDLUInt);
+impl_prim_ret_value!(u64, TVMTypeCode::kDLUInt);
+impl_prim_ret_value!(usize, TVMTypeCode::kDLUInt);
+
+impl_prim_ret_value!(f32, TVMTypeCode::kDLFloat);
+impl_prim_ret_value!(f64, TVMTypeCode::kDLFloat);
+
+macro_rules! impl_ptr_ret_value {
+    ($type:ty) => {
+        impl From<$type> for TVMRetValue {
+            fn from(ptr: $type) -> Self {
+                TVMRetValue {
+                    prim_value: ptr as usize,
+                    box_value: box (),
+                    type_code: TVMTypeCode::kHandle,
+                }
+            }
+        }
+
+        impl TryFrom<TVMRetValue> for $type {
+            type Error = Error;
+            fn try_from(ret: TVMRetValue) -> Result<$type> {
+                if ret.type_code == TVMTypeCode::kHandle {
+                    Ok(ret.prim_value as $type)
+                } else {
+                    bail!(ErrorKind::TryFromTVMRetValueError(
+                        stringify!($type).to_string(),
+                        ret.type_code.to_string(),
+                    ))
+                }
+            }
+        }
+    };
+}
+
+impl_ptr_ret_value!(*const c_void);
+impl_ptr_ret_value!(*mut c_void);
+
+impl From<String> for TVMRetValue {
+    fn from(val: String) -> Self {
+        let pval = val.as_ptr() as *const c_char as usize;
+        let bval = box (val.as_ptr() as *const c_char);
+        mem::forget(val);
+        TVMRetValue::new(pval, bval, TVMTypeCode::kStr)
+    }
+}
+
+impl TryFrom<TVMRetValue> for String {
+    type Error = Error;
+    fn try_from(ret: TVMRetValue) -> Result<String> {
+        // Note: simple downcast doesn't work for function call return values
+        let ret_str = unsafe {
+            match CStr::from_ptr(ret.prim_value as *const c_char).to_str() {
+                Ok(s) => s,
+                Err(_) => "Invalid UTF-8 message",
+            }
+        };
+
+        Ok(ret_str.to_string())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::convert::TryInto;
+
+    #[test]
+    fn numeric() {
+        macro_rules! arg_ret_tests {
+            ($v:expr; $($ty:ty),+) => {{
+                $(
+                    let v = $v as $ty;
+                    let b = TVMRetValue::from(&v);
+                    let b: $ty = b.try_into().unwrap();
+                    assert_eq!(b, v);
+                )+
+            }};
+        }
+
+        arg_ret_tests!(42; i8, i16, i32, i64, f32, f64);
+    }
+
+    #[test]
+    fn string() {
+        let s = "hello".to_string();
+        let tvm_arg: String = TVMRetValue::from(s.clone()).try_into().unwrap();
+        assert_eq!(tvm_arg, s);
+    }
+}
diff --git a/rust/common/tvm-sys/Cargo.toml b/rust/common/tvm-sys/Cargo.toml
new file mode 100644 (file)
index 0000000..117d174
--- /dev/null
@@ -0,0 +1,9 @@
+[package]
+name = "tvm-sys"
+version = "0.1.0"
+authors = ["TVM Contributors"]
+license = "Apache-2.0"
+description = "Raw C API"
+
+[build-dependencies]
+bindgen = "0.37.4"
diff --git a/rust/common/tvm-sys/build.rs b/rust/common/tvm-sys/build.rs
new file mode 100644 (file)
index 0000000..f842043
--- /dev/null
@@ -0,0 +1,25 @@
+extern crate bindgen;
+
+use std::path::PathBuf;
+
+fn main() {
+    println!("cargo:rerun-if-env-changed=TVM_HOME");
+    println!("cargo:rustc-link-lib=dylib=tvm_runtime");
+    println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
+    let bindings = bindgen::Builder::default()
+        .header(format!(
+            "{}/include/tvm/runtime/c_runtime_api.h",
+            env!("TVM_HOME")
+        ))
+        .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
+        .blacklist_type("max_align_t") // @see rust-bindgen#550
+        .layout_tests(false)
+        .derive_partialeq(true)
+        .derive_eq(true)
+        .generate()
+        .expect("unable to generate bindings");
+
+    bindings
+        .write_to_file(PathBuf::from("src/bindgen.rs"))
+        .expect("can not write the bindings!");
+}
diff --git a/rust/common/tvm-sys/src/lib.rs b/rust/common/tvm-sys/src/lib.rs
new file mode 100644 (file)
index 0000000..15f1ea3
--- /dev/null
@@ -0,0 +1,9 @@
+#![allow(
+    non_camel_case_types,
+    non_snake_case,
+    non_upper_case_globals,
+    dead_code,
+    improper_ctypes
+)]
+
+include!("bindgen.rs");
diff --git a/rust/frontend/.gitignore b/rust/frontend/.gitignore
new file mode 100644 (file)
index 0000000..2430329
--- /dev/null
@@ -0,0 +1,7 @@
+target
+**/*.rs.bk
+Cargo.lock
+/tests/basics/add_*
+/examples/resnet/deploy_*
+/examples/resnet/*.png
+/examples/resnet/synset.*
diff --git a/rust/frontend/.travis.yml b/rust/frontend/.travis.yml
new file mode 100644 (file)
index 0000000..63a3d02
--- /dev/null
@@ -0,0 +1,5 @@
+language: rust
+rust:
+  - nightly
+matrix:
+  fast_finish: true
diff --git a/rust/frontend/Cargo.toml b/rust/frontend/Cargo.toml
new file mode 100644 (file)
index 0000000..db26155
--- /dev/null
@@ -0,0 +1,25 @@
+[package]
+name = "tvm-frontend"
+version = "0.1.0"
+license = "Apache-2.0"
+description = "Rust frontend support for TVM"
+repository = "https://github.com/dmlc/tvm"
+homepage = "https://github.com/dmlc/tvm"
+readme = "README.md"
+keywords = ["rust", "tvm", "nnvm"]
+categories = ["api-bindings", "science"]
+authors = ["TVM Contributors"]
+
+[lib]
+name = "tvm_frontend"
+crate-type = ["dylib"]
+
+[dependencies]
+error-chain = "0.12.0"
+lazy_static = "1.1.0"
+ndarray = "0.12.1"
+num-traits = "0.2"
+tvm-common = { version = "0.1.0", path = "../common/", features = ["frontend"] }
+
+[features]
+blas = ["ndarray/blas"]
diff --git a/rust/frontend/README.md b/rust/frontend/README.md
new file mode 100644 (file)
index 0000000..5bd4362
--- /dev/null
@@ -0,0 +1,219 @@
+# TVM Runtime Frontend Support
+
+This crate provides an idiomatic Rust API for [TVM](https://github.com/dmlc/tvm) runtime frontend. Currently this requires **Nightly Rust** and tested on `rustc 1.32.0-nightly`
+
+## What Does This Crate Offer?
+
+Here is a major workflow
+
+1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/)
+2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators.
+3. Deploy your models using **Rust** :heart:
+
+### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k
+
+Please checkout [examples/resnet](examples/resnet) for the complete end-to-end example.
+
+Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM
+
+```python
+block = get_model('resnet18_v1', pretrained=True)
+    
+sym, params = nnvm.frontend.from_mxnet(block)
+# add the softmax layer for prediction
+net = nnvm.sym.softmax(sym)
+# compile the model
+with nnvm.compiler.build_config(opt_level=opt_level):
+    graph, lib, params = nnvm.compiler.build(
+        net, target, shape={"data": data_shape}, params=params)
+# same the model artifacts
+lib.save(os.path.join(target_dir, "deploy_lib.o"))
+cc.create_shared(os.path.join(target_dir, "deploy_lib.so"),
+                [os.path.join(target_dir, "deploy_lib.o")])
+
+with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo:
+    fo.write(graph.json())
+with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo:
+    fo.write(nnvm.compiler.save_param_dict(params))
+```
+
+Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image
+
+![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true)
+
+as demostrated in the following Rust snippet
+
+```rust
+    let graph = fs::read_to_string("deploy_graph.json")?;
+    // load the built module
+    let lib = Module::load(&Path::new("deploy_lib.so"))?;
+    // get the global TVM graph runtime function
+    let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
+    let runtime_create_fn_ret = call_packed!(
+        runtime_create_fn,
+        &graph,
+        &lib,
+        &ctx.device_type,
+        &ctx.device_id
+    )?;
+    // get graph runtime module
+    let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?;
+    // get the registered `load_params` from runtime module
+    let ref load_param_fn = graph_runtime_module
+        .get_function("load_params", false)
+        .unwrap();
+    // parse parameters and convert to TVMByteArray
+    let params: Vec<u8> = fs::read("deploy_param.params")?;
+    let barr = TVMByteArray::from(&params);
+    // load the parameters
+    call_packed!(load_param_fn, &barr)?;
+    // get the set_input function
+    let ref set_input_fn = graph_runtime_module
+        .get_function("set_input", false)
+        .unwrap();
+
+    call_packed!(set_input_fn, "data", &input)?;
+    // get `run` function from runtime module
+    let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
+    // execute the run function. Note that it has no argument
+    call_packed!(run_fn,)?;
+    // prepare to get the output
+    let output_shape = &mut [1, 1000];
+    let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
+    // get the `get_output` function from runtime module
+    let ref get_output_fn = graph_runtime_module
+        .get_function("get_output", false)
+        .unwrap();
+    // execute the get output function
+    call_packed!(get_output_fn, &0, &output)?;
+    // flatten the output as Vec<f32>
+    let output = output.to_vec::<f32>()?;
+```
+
+and the model correctly predicts the input image as **tiger cat**.
+
+## Installations
+
+Please follow TVM [installations](https://docs.tvm.ai/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`.
+
+*Note:* To run the end-to-end examples and tests, `tvm`, `nnvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually.
+
+## Supported TVM Functionalities
+
+### Use TVM to Generate Shared Library
+
+One can use the following Python snippet to generate `add_gpu.so` which add two vectors on GPU.
+
+```python
+import os
+import tvm
+from tvm.contrib import cc
+
+def test_add(target_dir):
+    if not tvm.module.enabled("cuda"):
+        print(f"skip {__file__} because cuda is not enabled...")
+        return
+    n = tvm.var("n")
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
+    s = tvm.create_schedule(C.op)
+    bx, tx = s[C].split(C.op.axis[0], factor=64)
+    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+    fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
+
+    fadd_cuda.save(os.path.join(target_dir, "add_gpu.o"))
+    fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx"))
+    cc.create_shared(os.path.join(target_dir, "add_gpu.so"),
+            [os.path.join(target_dir, "add_gpu.o")])
+
+
+if __name__ == "__main__":
+    import sys
+    if len(sys.argv) != 2:
+        sys.exit(-1)
+    test_add(sys.argv[1])
+```
+
+### Run the Generated Shared Library
+
+The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust.
+
+```rust
+extern crate tvm_frontend as tvm;
+
+use tvm::*;
+
+fn main() {
+    let shape = &mut [2];
+    let mut data = vec![3f32, 4.0];
+    let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32"));
+    arr.copy_from_buffer(data.as_mut_slice());
+    let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32"));
+    let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap();
+    let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap();
+    assert!(fadd.enabled("gpu"));
+    fadd.import_module(fadd_dep);
+    fadd.entry();
+    function::Builder::from(&mut fadd)
+        .arg(&arr)
+        .arg(&arr)
+        .set_output(&mut ret)?
+        .invoke()
+        .unwrap();
+
+    assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]);
+}
+```
+
+**Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by
+`cargo:rustc-link-search=native=add_gpu`.
+
+See the tests and examples custom `build.rs` for more details.
+
+### Convert and Register a Rust Function as a TVM Packed Function
+
+One can use `register_global_func!` macro to convert and register a Rust
+function of type `fn(&[TVMArgValue]) -> Result<TVMRetValue>` to a global TVM **packed function** as follows
+
+```rust
+#[macro_use]
+extern crate tvm_frontend as tvm;
+use std::convert::TryInto;
+use tvm::*;
+
+fn main() {
+    register_global_func! {
+        fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+            let mut ret = 0f32;
+            let shape = &mut [2];
+            for arg in args.iter() {
+                let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+                let arg: NDArray = arg.try_into()?;
+                let arr = arg.copy_to_ndarray(e).unwrap();
+                let rnd: ArrayD<f32> = ArrayD::try_from(&arr).unwrap();
+                ret += rnd.scalar_sum();
+            }
+            let ret_val = TVMRetValue::from(&ret);
+            Ok(ret_val)
+        }
+    }
+
+    let shape = &mut [2];
+    let mut data = vec![3f32, 4.0];
+    let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+    arr.copy_from_buffer(data.as_mut_slice());
+    let mut registered = function::Builder::default();
+    let ret: f64 = registered
+        .get_function("sum", true)
+        .arg(&arr)
+        .arg(&arr)
+        .invoke()
+        .unwrap()
+        .try_into()
+        .unwrap();
+
+    assert_eq!(ret, 14f64);
+    }
+```
diff --git a/rust/frontend/examples/resnet/Cargo.toml b/rust/frontend/examples/resnet/Cargo.toml
new file mode 100644 (file)
index 0000000..e8a3eb7
--- /dev/null
@@ -0,0 +1,12 @@
+[package]
+name = "resnet"
+version = "0.0.0"
+authors = ["TVM Contributors"]
+license = "Apache-2.0"
+build = "build.rs"
+
+[dependencies]
+ndarray = "0.12.1"
+tvm-frontend = { path = "../../" }
+image = "0.20.1"
+csv = "1"
diff --git a/rust/frontend/examples/resnet/README.md b/rust/frontend/examples/resnet/README.md
new file mode 100644 (file)
index 0000000..3d20d55
--- /dev/null
@@ -0,0 +1,15 @@
+## Resnet example
+
+This end-to-end example shows how to:
+* build `Resnet 18` with `tvm` and `nnvm` from Python
+* use the provided Rust frontend API to test for an input image
+
+To run the example, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet`
+and to install `tvm` and `nnvm` with `llvm` follow the [TVM installation guide](https://docs.tvm.ai/install/index.html).
+
+* **Build the example**: `cargo build`
+
+To have a successful build, note that it is required to instruct Rust compiler to link to the compiled shared library, for example with
+`println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details.
+
+* **Run the example**: `cargo run`
diff --git a/rust/frontend/examples/resnet/build.rs b/rust/frontend/examples/resnet/build.rs
new file mode 100644 (file)
index 0000000..f913bf8
--- /dev/null
@@ -0,0 +1,16 @@
+use std::process::Command;
+
+fn main() {
+    let output = Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
+        .output()
+        .expect("Failed to execute command");
+    assert!(
+        std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_lib.o")).exists(),
+        "Could not prepare demo: {}",
+        String::from_utf8(output.stderr).unwrap().trim()
+    );
+    println!(
+        "cargo:rustc-link-search=native={}",
+        env!("CARGO_MANIFEST_DIR")
+    );
+}
diff --git a/rust/frontend/examples/resnet/src/build_resnet.py b/rust/frontend/examples/resnet/src/build_resnet.py
new file mode 100755 (executable)
index 0000000..e5b76aa
--- /dev/null
@@ -0,0 +1,105 @@
+#!/usr/bin/env python3
+
+import argparse
+import csv
+import logging
+from os import path as osp
+import sys
+
+import numpy as np
+
+import mxnet as mx
+from mxnet.gluon.model_zoo.vision import get_model
+from mxnet.gluon.utils import download
+
+import tvm
+from tvm.contrib import graph_runtime, cc
+import nnvm
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+parser = argparse.ArgumentParser(description='Resnet build example')
+aa = parser.add_argument
+aa('--batch-size', type=int, default=1, help='input image batch size')
+aa('--opt-level', type=int, default=3,
+   help='level of optimization. 0 is unoptimized and 3 is the highest level')
+aa('--target', type=str, default='llvm', help='target context for compilation')
+aa('--image-shape', type=str, default='3,224,224', help='input image dimensions')
+aa('--image-name', type=str, default='cat.png', help='name of input image to download')
+args = parser.parse_args()
+
+target_dir = osp.dirname(osp.dirname(osp.realpath(__file__)))
+batch_size = args.batch_size
+opt_level = args.opt_level
+target = tvm.target.create(args.target)
+image_shape = tuple(map(int, args.image_shape.split(",")))
+data_shape = (batch_size,) + image_shape
+
+def build(target_dir):
+    """ Compiles resnet18 with TVM"""
+    deploy_lib = osp.join(target_dir, 'deploy_lib.o')
+    if osp.exists(deploy_lib):
+        return
+    # download the pretrained resnet18 trained on imagenet1k dataset for
+    # image classification task
+    block = get_model('resnet18_v1', pretrained=True)
+
+    sym, params = nnvm.frontend.from_mxnet(block)
+    # add the softmax layer for prediction
+    net = nnvm.sym.softmax(sym)
+    # compile the model
+    with nnvm.compiler.build_config(opt_level=opt_level):
+        graph, lib, params = nnvm.compiler.build(
+            net, target, shape={"data": data_shape}, params=params)
+    # save the model artifacts
+    lib.save(deploy_lib)
+    cc.create_shared(osp.join(target_dir, "deploy_lib.so"),
+                    [osp.join(target_dir, "deploy_lib.o")])
+
+    with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo:
+        fo.write(graph.json())
+
+    with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo:
+        fo.write(nnvm.compiler.save_param_dict(params))
+
+def download_img_labels():
+    """ Download an image and imagenet1k class labels for test"""
+    img_name = 'cat.png'
+    synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
+                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
+                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
+                      'imagenet1000_clsid_to_human.txt'])
+    synset_name = 'synset.txt'
+    download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name)
+    download(synset_url, synset_name)
+
+    with open(synset_name) as fin:
+        synset = eval(fin.read())
+
+    with open("synset.csv", "w") as fout:
+        w = csv.writer(fout)
+        w.writerows(synset.items())
+
+def test_build(target_dir):
+    """ Sanity check with random input"""
+    graph = open(osp.join(target_dir, "deploy_graph.json")).read()
+    lib = tvm.module.load(osp.join(target_dir, "deploy_lib.so"))
+    params = bytearray(open(osp.join(target_dir,"deploy_param.params"), "rb").read())
+    input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
+    ctx = tvm.cpu()
+    module = graph_runtime.create(graph, lib, ctx)
+    module.load_params(params)
+    module.run(data=input_data)
+    out = module.get_output(0).asnumpy()
+
+
+if __name__ == '__main__':
+    logger.info("building the model")
+    build(target_dir)
+    logger.info("build was successful")
+    logger.info("test the build artifacts")
+    test_build(target_dir)
+    logger.info("test was successful")
+    download_img_labels()
+    logger.info("image and synset downloads are successful")
diff --git a/rust/frontend/examples/resnet/src/main.rs b/rust/frontend/examples/resnet/src/main.rs
new file mode 100644 (file)
index 0000000..869a35b
--- /dev/null
@@ -0,0 +1,134 @@
+#![feature(try_from)]
+
+extern crate csv;
+extern crate image;
+extern crate ndarray;
+extern crate tvm_frontend as tvm;
+
+use std::{
+    collections::HashMap,
+    convert::TryInto,
+    fs::{self, File},
+    path::Path,
+};
+
+use image::{FilterType, GenericImageView};
+use ndarray::{Array, ArrayD, Axis};
+
+use tvm::*;
+
+fn main() {
+    let ctx = TVMContext::cpu(0);
+    let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap();
+    println!("original image dimensions: {:?}", img.dimensions());
+    // for bigger size images, one needs to first resize to 256x256
+    // with `img.resize_exact` method and then `image.crop` to 224x224
+    let img = img.resize(224, 224, FilterType::Nearest).to_rgb();
+    println!("resized image dimensions: {:?}", img.dimensions());
+    let mut pixels: Vec<f32> = vec![];
+    for pixel in img.pixels() {
+        let tmp = pixel.data;
+        // normalize the RGB channels using mean, std of imagenet1k
+        let tmp = [
+            (tmp[0] as f32 - 123.0) / 58.395, // R
+            (tmp[1] as f32 - 117.0) / 57.12,  // G
+            (tmp[2] as f32 - 104.0) / 57.375, // B
+        ];
+        for e in &tmp {
+            pixels.push(*e);
+        }
+    }
+
+    let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap();
+    let arr: ArrayD<f32> = arr.permuted_axes([2, 0, 1]).into_dyn();
+    // make arr shape as [1, 3, 224, 224] acceptable to resnet
+    let arr = arr.insert_axis(Axis(0));
+    // create input tensor from rust's ndarray
+    let input =
+        NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
+    println!(
+        "input size is {:?}",
+        input.shape().expect("cannot get the input shape")
+    );
+    let graph =
+        fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap();
+    // load the built module
+    let lib = Module::load(&Path::new(concat!(
+        env!("CARGO_MANIFEST_DIR"),
+        "/deploy_lib.so"
+    )))
+    .unwrap();
+    // get the global TVM graph runtime function
+    let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
+    let runtime_create_fn_ret = call_packed!(
+        runtime_create_fn,
+        &graph,
+        &lib,
+        &ctx.device_type,
+        &ctx.device_id
+    )
+    .unwrap();
+    // get graph runtime module
+    let graph_runtime_module: Module = runtime_create_fn_ret.try_into().unwrap();
+    // get the registered `load_params` from runtime module
+    let ref load_param_fn = graph_runtime_module
+        .get_function("load_params", false)
+        .unwrap();
+    // parse parameters and convert to TVMByteArray
+    let params: Vec<u8> =
+        fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap();
+    let barr = TVMByteArray::from(&params);
+    // load the parameters
+    call_packed!(load_param_fn, &barr).unwrap();
+    // get the set_input function
+    let ref set_input_fn = graph_runtime_module
+        .get_function("set_input", false)
+        .unwrap();
+
+    call_packed!(set_input_fn, "data", &input).unwrap();
+    // get `run` function from runtime module
+    let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
+    // execute the run function. Note that it has no argument
+    call_packed!(run_fn,).unwrap();
+    // prepare to get the output
+    let output_shape = &mut [1, 1000];
+    let output = NDArray::empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
+    // get the `get_output` function from runtime module
+    let ref get_output_fn = graph_runtime_module
+        .get_function("get_output", false)
+        .unwrap();
+    // execute the get output function
+    call_packed!(get_output_fn, &0, &output).unwrap();
+    // flatten the output as Vec<f32>
+    let output = output.to_vec::<f32>().unwrap();
+    // find the maximum entry in the output and its index
+    let mut argmax = -1;
+    let mut max_prob = 0.;
+    for i in 0..output.len() {
+        if output[i] > max_prob {
+            max_prob = output[i];
+            argmax = i as i32;
+        }
+    }
+    // create a hash map of (class id, class name)
+    let mut synset: HashMap<i32, String> = HashMap::new();
+    let file = File::open("synset.csv").unwrap();
+    let mut rdr = csv::ReaderBuilder::new()
+        .has_headers(true)
+        .from_reader(file);
+
+    for result in rdr.records() {
+        let record = result.unwrap();
+        let id: i32 = record[0].parse().unwrap();
+        let cls = record[1].to_string();
+        synset.insert(id, cls);
+    }
+
+    println!(
+        "input image belongs to the class `{}` with probability {}",
+        synset
+            .get(&argmax)
+            .expect("cannot find the class id for argmax"),
+        max_prob
+    );
+}
diff --git a/rust/frontend/src/bytearray.rs b/rust/frontend/src/bytearray.rs
new file mode 100644 (file)
index 0000000..395f34c
--- /dev/null
@@ -0,0 +1,72 @@
+//! Provides [`TVMByteArray`] used for passing the model parameters
+//! (stored as byte-array) to a runtime module.
+//!
+//! For more detail, please see the example `resnet` in `examples` repository.
+
+use std::os::raw::c_char;
+
+use crate::ts;
+
+/// A struct holding TVM byte-array.
+///
+/// ## Example
+///
+/// ```
+/// let v = b"hello".to_vec();
+/// let barr = TVMByteArray::from(&v);
+/// assert_eq!(barr.len(), v.len());
+/// assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]);
+/// ```
+#[derive(Debug, Clone)]
+pub struct TVMByteArray {
+    pub(crate) inner: ts::TVMByteArray,
+}
+
+impl TVMByteArray {
+    pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray {
+        TVMByteArray { inner: barr }
+    }
+
+    /// Gets the length of the underlying byte-array
+    pub fn len(&self) -> usize {
+        self.inner.size
+    }
+
+    /// Gets the underlying byte-array as `Vec<i8>`
+    pub fn data(&self) -> Vec<i8> {
+        unsafe {
+            let sz = self.len();
+            let mut ret_buf = Vec::with_capacity(sz);
+            ret_buf.set_len(sz);
+            self.inner.data.copy_to(ret_buf.as_mut_ptr(), sz);
+            ret_buf
+        }
+    }
+}
+
+impl<'a> From<&'a Vec<u8>> for TVMByteArray {
+    fn from(arg: &Vec<u8>) -> Self {
+        let barr = ts::TVMByteArray {
+            data: arg.as_ptr() as *const c_char,
+            size: arg.len(),
+        };
+        TVMByteArray::new(barr)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn convert() {
+        let v = vec![1u8, 2, 3];
+        let barr = TVMByteArray::from(&v);
+        assert_eq!(barr.len(), v.len());
+        assert_eq!(barr.data(), vec![1i8, 2, 3]);
+        let v = b"hello".to_vec();
+        let barr = TVMByteArray::from(&v);
+        assert_eq!(barr.len(), v.len());
+        assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]);
+    }
+}
diff --git a/rust/frontend/src/context.rs b/rust/frontend/src/context.rs
new file mode 100644 (file)
index 0000000..65e11d8
--- /dev/null
@@ -0,0 +1,286 @@
+//! Provides [`TVMContext`] and related device specific queries.
+//!
+//! Create a new context by device type (cpu is 1) and device id.
+//!
+//! # Example
+//!
+//! ```
+//! let ctx = TVMContext::new(1, 0);
+//! let cpu0 = TVMContext::cpu(0);
+//! assert_eq!(ctx, cpu0);
+//! ```
+//!
+//! Or from a supported device name.
+//!
+//! ```
+//! let cpu0 = TVMContext::from("cpu");
+//! println!("{}", cpu0);
+//! ```
+
+use std::{
+    fmt::{self, Display, Formatter},
+    os::raw::c_void,
+    ptr,
+};
+
+use crate::{function, ts, Result};
+
+/// Device type can be from a supported device name. See the supported devices
+/// in [TVM](https://github.com/dmlc/tvm).
+///
+/// ## Example
+///
+/// ```
+/// let cpu = TVMDeviceType::from("cpu");
+/// println!("device is: {}", cpu);
+///```
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct TVMDeviceType(pub usize);
+
+impl Default for TVMDeviceType {
+    /// default device is cpu.
+    fn default() -> Self {
+        TVMDeviceType(1)
+    }
+}
+
+impl From<TVMDeviceType> for ts::DLDeviceType {
+    fn from(device_type: TVMDeviceType) -> Self {
+        match device_type.0 {
+            1 => ts::DLDeviceType_kDLCPU,
+            2 => ts::DLDeviceType_kDLGPU,
+            3 => ts::DLDeviceType_kDLCPUPinned,
+            4 => ts::DLDeviceType_kDLOpenCL,
+            7 => ts::DLDeviceType_kDLVulkan,
+            8 => ts::DLDeviceType_kDLMetal,
+            9 => ts::DLDeviceType_kDLVPI,
+            10 => ts::DLDeviceType_kDLROCM,
+            12 => ts::DLDeviceType_kDLExtDev,
+            _ => panic!("device type not found!"),
+        }
+    }
+}
+
+impl From<ts::DLDeviceType> for TVMDeviceType {
+    fn from(device_type: ts::DLDeviceType) -> Self {
+        match device_type {
+            ts::DLDeviceType_kDLCPU => TVMDeviceType(1),
+            ts::DLDeviceType_kDLGPU => TVMDeviceType(2),
+            ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
+            ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
+            ts::DLDeviceType_kDLVulkan => TVMDeviceType(7),
+            ts::DLDeviceType_kDLMetal => TVMDeviceType(8),
+            ts::DLDeviceType_kDLVPI => TVMDeviceType(9),
+            ts::DLDeviceType_kDLROCM => TVMDeviceType(10),
+            ts::DLDeviceType_kDLExtDev => TVMDeviceType(12),
+            _ => panic!("device type not found!"),
+        }
+    }
+}
+
+impl Display for TVMDeviceType {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        write!(
+            f,
+            "{}",
+            match self {
+                TVMDeviceType(1) => "cpu",
+                TVMDeviceType(2) => "gpu",
+                TVMDeviceType(3) => "cpu_pinned",
+                TVMDeviceType(4) => "opencl",
+                TVMDeviceType(8) => "meta",
+                TVMDeviceType(9) => "vpi",
+                TVMDeviceType(10) => "rocm",
+                TVMDeviceType(_) => "rpc",
+            }
+        )
+    }
+}
+
+impl<'a> From<&'a str> for TVMDeviceType {
+    fn from(type_str: &'a str) -> Self {
+        match type_str {
+            "cpu" => TVMDeviceType(1),
+            "llvm" => TVMDeviceType(1),
+            "stackvm" => TVMDeviceType(1),
+            "gpu" => TVMDeviceType(2),
+            "cuda" => TVMDeviceType(2),
+            "nvptx" => TVMDeviceType(2),
+            "cl" => TVMDeviceType(4),
+            "opencl" => TVMDeviceType(4),
+            "metal" => TVMDeviceType(8),
+            "vpi" => TVMDeviceType(9),
+            "rocm" => TVMDeviceType(10),
+            _ => panic!("{:?} not supported!", type_str),
+        }
+    }
+}
+
+/// Represents the underlying device context. Default is cpu.
+///
+/// ## Examples
+///
+/// ```
+/// let ctx = TVMContext::from("gpu");
+/// assert!(ctx.exist());
+///
+/// ```
+///
+/// It is possible to query the underlying context as follows
+///
+/// ```
+/// println!("maximun threads per block: {}", ctx.max_threads_per_block());
+/// println!("compute version: {}", ctx.compute_version());
+/// ```
+#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
+pub struct TVMContext {
+    /// Supported device types
+    pub device_type: TVMDeviceType,
+    /// Device id
+    pub device_id: usize,
+}
+
+impl TVMContext {
+    /// Creates context from device type and id.
+    pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self {
+        TVMContext {
+            device_type: device_type,
+            device_id: device_id,
+        }
+    }
+}
+
+macro_rules! impl_ctxs {
+    ($(($ctx:ident, $dldevt:expr));+) => {
+        $(
+            impl TVMContext {
+                pub fn $ctx(device_id: usize) -> Self {
+                    Self::new(TVMDeviceType($dldevt), device_id)
+                }
+            }
+        )+
+    };
+}
+
+impl_ctxs!((cpu, 1);
+            (gpu, 2);
+            (nvptx, 2);
+            (cuda, 2);
+            (cpu_pinned, 3);
+            (cl, 4);
+            (opencl, 4);
+            (metal, 8);
+            (vpi, 9);
+            (rocm, 10);
+            (opengl, 11);
+            (ext_dev, 12));
+
+impl<'a> From<&'a str> for TVMContext {
+    fn from(target: &str) -> Self {
+        TVMContext::new(TVMDeviceType::from(target), 0)
+    }
+}
+
+impl TVMContext {
+    /// Checks whether the context exists or not.
+    pub fn exist(&self) -> bool {
+        let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
+            .expect("API function always exists");
+        let dt = self.device_type.0 as usize;
+        // `unwrap` is ok here because if there is any error,
+        // if would occure inside `call_packed!`
+        let ret = call_packed!(func, &dt, &self.device_id, &0)
+            .unwrap()
+            .prim_value;
+        ret != 0
+    }
+
+    /// Synchronize the context stream.
+    pub fn sync(&self) -> Result<()> {
+        check_call!(ts::TVMSynchronize(
+            self.device_type.0 as i32,
+            self.device_id as i32,
+            ptr::null_mut() as *mut c_void
+        ));
+        Ok(())
+    }
+}
+
+macro_rules! impl_device_attrs {
+    ($(($attr_name:ident, $attr_kind:expr));+) => {
+        $(
+            impl TVMContext {
+                pub fn $attr_name(&self) -> usize {
+                    let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
+                        .expect("API function always exists");
+                    let dt = self.device_type.0 as usize;
+                    // `unwrap` is ok here because if there is any error,
+                    // if would occur in function call.
+                    let ret = function::Builder::from(func)
+                        .args(&[dt, self.device_id, $attr_kind])
+                        .invoke()
+                        .unwrap();
+                    ret.prim_value as usize
+                }
+            }
+        )+
+    };
+}
+
+impl_device_attrs!((max_threads_per_block, 1);
+                (warp_size, 2);
+                (max_shared_memory_per_block, 3);
+                (compute_version, 4);
+                (device_name, 5);
+                (max_clock_rate, 6);
+                (multi_processor_count, 7);
+                (max_thread_dimensions, 8));
+
+impl From<ts::DLContext> for TVMContext {
+    fn from(ctx: ts::DLContext) -> Self {
+        TVMContext {
+            device_type: TVMDeviceType::from(ctx.device_type),
+            device_id: ctx.device_id as usize,
+        }
+    }
+}
+
+impl From<TVMContext> for ts::DLContext {
+    fn from(ctx: TVMContext) -> Self {
+        ts::DLContext {
+            device_type: ctx.device_type.into(),
+            device_id: ctx.device_id as i32,
+        }
+    }
+}
+
+impl Display for TVMContext {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        write!(f, "{}({})", self.device_type, self.device_id)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn context() {
+        let ctx = TVMContext::cpu(0);
+        println!("ctx: {}", ctx);
+        let default_ctx = TVMContext::new(TVMDeviceType(1), 0);
+        assert_eq!(ctx.clone(), default_ctx);
+        assert_ne!(ctx, TVMContext::gpu(0));
+
+        let str_ctx = TVMContext::new(TVMDeviceType::from("gpu"), 0);
+        assert_eq!(str_ctx.clone(), str_ctx);
+        assert_ne!(str_ctx, TVMContext::new(TVMDeviceType::from("cpu"), 0));
+    }
+
+    #[test]
+    fn sync() {
+        let ctx = TVMContext::cpu(0);
+        assert!(ctx.sync().is_ok())
+    }
+}
diff --git a/rust/frontend/src/errors.rs b/rust/frontend/src/errors.rs
new file mode 100644 (file)
index 0000000..a10f83c
--- /dev/null
@@ -0,0 +1,51 @@
+//! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types.
+
+use std::{ffi, option};
+
+use crate::{common_errors, rust_ndarray};
+
+error_chain! {
+    errors {
+        EmptyArray {
+            description("cannot convert from an empty array")
+        }
+
+        NullHandle(name: String) {
+            description("null handle")
+            display("requested `{}` handle is null", name)
+        }
+
+        FunctionNotFound {
+            description("function not found")
+            display("function was not set in `function::Builder`")
+        }
+
+        TypeMismatch(expected: String, found: String) {
+            description("type mismatch!")
+            display("expected type `{}`, but found `{}`", expected, found)
+        }
+
+        MissingShapeError {
+            description("ndarray `shape()` returns `None`")
+            display("called `Option::unwrap()` on a `None` value")
+        }
+
+        AtMostOneReturn {
+            description("TVM functions accept at most one return value")
+        }
+
+    }
+
+    foreign_links {
+        ShapeError(rust_ndarray::ShapeError);
+        NulError(ffi::NulError);
+        IntoStringError(ffi::IntoStringError);
+        CommonError(common_errors::Error);
+    }
+}
+
+impl From<option::NoneError> for Error {
+    fn from(_err: option::NoneError) -> Self {
+        ErrorKind::MissingShapeError.into()
+    }
+}
diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs
new file mode 100644 (file)
index 0000000..fa6bed1
--- /dev/null
@@ -0,0 +1,512 @@
+//! This module provides an idiomatic Rust API for creating and working with TVM functions.
+//!
+//! For calling an already registered TVM function use [`function::Builder`]
+//! To register a TVM packed function from Rust side either
+//! use [`function::register`] or the macro [`register_global_func`].
+//!
+//! See the tests and examples repository for more examples.
+
+use std::{
+    collections::BTreeMap,
+    ffi::{CStr, CString},
+    mem,
+    os::raw::{c_char, c_int, c_void},
+    ptr, slice, str,
+    sync::Mutex,
+};
+
+use crate::{ts, ErrorKind, Module, Result, TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue};
+
+lazy_static! {
+    static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = {
+        let mut out_size = 0 as c_int;
+        let name = ptr::null_mut() as *mut c_char;
+        let mut out_array = name as *mut _;
+        check_call!(ts::TVMFuncListGlobalNames(
+            &mut out_size as *mut _,
+            &mut out_array
+        ));
+        let names_list = unsafe { slice::from_raw_parts(out_array, out_size as usize) };
+        Mutex::new(
+            names_list
+                .into_iter()
+                .map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
+                .collect(),
+        )
+    };
+}
+
+/// Wrapper around TVM function handle which includes `is_global`
+/// indicating whether the function is global or not, `is_released`
+/// to hint dropping the function handle and `is_cloned` showing
+/// not to drop a cloned function from Rust side.
+/// The value of these fields can be accessed through their respective methods.
+#[derive(Debug, Hash)]
+pub struct Function {
+    pub(crate) handle: ts::TVMFunctionHandle,
+    // whether the registered function is global or not.
+    is_global: bool,
+    // whether the function has been dropped from frontend or not.
+    is_released: bool,
+    // whether the function has been cloned from frontend or not.
+    is_cloned: bool,
+}
+
+unsafe impl Send for Function {}
+unsafe impl Sync for Function {}
+
+impl Function {
+    pub(crate) fn new(handle: ts::TVMFunctionHandle, is_global: bool, is_released: bool) -> Self {
+        Function {
+            handle: handle,
+            is_global: is_global,
+            is_released: is_released,
+            is_cloned: false,
+        }
+    }
+
+    /// For a given function, it returns a function by name.
+    pub fn get<S: AsRef<str>>(name: S, is_global: bool) -> Option<&'static Function> {
+        let mut globals = GLOBAL_FUNCTIONS.lock().unwrap();
+        globals.get_mut(name.as_ref()).and_then(|maybe_func| {
+            if maybe_func.is_none() {
+                let name = CString::new(name.as_ref()).unwrap();
+                let mut handle = ptr::null_mut() as ts::TVMFunctionHandle;
+                check_call!(ts::TVMFuncGetGlobal(
+                    name.as_ptr() as *const c_char,
+                    &mut handle as *mut _
+                ));
+                maybe_func.replace(Function::new(
+                    handle, is_global, false, /* is_released */
+                ));
+            }
+            unsafe {
+                std::mem::transmute::<Option<&Function>, Option<&'static Function>>(
+                    maybe_func.as_ref(),
+                )
+            }
+        })
+    }
+
+    /// Returns the underlying TVM function handle.
+    pub fn handle(&self) -> ts::TVMFunctionHandle {
+        self.handle
+    }
+
+    /// Returns `true` if the underlying TVM function is global and `false` otherwise.
+    pub fn is_global(&self) -> bool {
+        self.is_global
+    }
+
+    /// Returns `true` if the underlying TVM function has been released
+    /// from the frontend and `false` otherwise.
+    pub fn is_released(&self) -> bool {
+        self.is_released
+    }
+
+    /// Returns `true` if the underlying TVM function has been cloned
+    /// from the frontend and `false` otherwise.
+    pub fn is_cloned(&self) -> bool {
+        self.is_cloned
+    }
+}
+
+impl Clone for Function {
+    fn clone(&self) -> Function {
+        if !self.is_released && !self.is_cloned {
+            Self {
+                handle: self.handle,
+                is_global: self.is_global,
+                is_released: self.is_released,
+                is_cloned: true,
+            }
+        } else {
+            Function::new(self.handle, self.is_global, self.is_released)
+        }
+    }
+}
+
+impl Drop for Function {
+    fn drop(&mut self) {
+        if !self.is_released && !self.is_global && !self.is_cloned {
+            check_call!(ts::TVMFuncFree(self.handle));
+            self.is_released = true;
+        }
+    }
+}
+
+/// Function builder in order to create and call functions.
+///
+/// *Note:* Currently TVM functions accept *at most* one return value.
+#[derive(Debug, Clone, Default)]
+pub struct Builder<'a, 'm> {
+    pub func: Option<&'m Function>,
+    pub arg_buf: Option<Box<[TVMArgValue<'a>]>>,
+    pub ret_buf: Option<TVMRetValue>,
+}
+
+impl<'a, 'm> Builder<'a, 'm> {
+    pub fn new(
+        func: Option<&'m Function>,
+        arg_buf: Option<Box<[TVMArgValue<'a>]>>,
+        ret_buf: Option<TVMRetValue>,
+    ) -> Self {
+        Self {
+            func,
+            arg_buf,
+            ret_buf,
+        }
+    }
+
+    pub fn get_function(&mut self, name: &'m str, is_global: bool) -> &mut Self {
+        self.func = Function::get(name, is_global);
+        self
+    }
+
+    /// Pushes a [`TVMArgValue`] into the function argument buffer.
+    pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self
+    where
+        TVMValue: From<&'b T>,
+        TVMTypeCode: From<&'b T>,
+    {
+        let tvm_arg = TVMArgValue::from(arg);
+        if self.arg_buf.is_none() {
+            self.arg_buf = Some(Box::new([tvm_arg]));
+        } else {
+            let new_arg_buf = self.arg_buf.take().map(|bbuf| {
+                let mut new_arg_buf = Vec::from(bbuf);
+                new_arg_buf.push(tvm_arg);
+                let new_len = new_arg_buf.len();
+                new_arg_buf.truncate(new_len);
+                new_arg_buf.into_boxed_slice()
+            });
+            self.arg_buf = new_arg_buf;
+        }
+        self
+    }
+
+    /// Pushes multiple [`TVMArgValue`]s into the function argument buffer.
+    pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self
+    where
+        I: IntoIterator<Item = &'b T>,
+        TVMValue: From<&'b T>,
+        TVMTypeCode: From<&'b T>,
+    {
+        for arg in args {
+            self.arg(&arg);
+        }
+        self
+    }
+
+    /// Sets an output for a function that requirs a mutable output to be provided.
+    /// See the `basics` in tests for an example.
+    pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> Result<&mut Self>
+    where
+        TVMValue: From<&'b T>,
+        TVMTypeCode: From<&'b T>,
+    {
+        if self.ret_buf.is_none() {
+            let tvm_ret =
+                unsafe { TVMRetValue::from_tvm_value(TVMValue::from(arg), TVMTypeCode::from(arg)) };
+            self.ret_buf = Some(tvm_ret);
+        } else {
+            bail!(ErrorKind::AtMostOneReturn)
+        }
+        Ok(self)
+    }
+
+    /// Calls the function that created from `Builder`.
+    pub fn invoke(&mut self) -> Result<TVMRetValue> {
+        self.clone()(())
+    }
+}
+
+impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> {
+    type Output = Result<TVMRetValue>;
+    extern "rust-call" fn call_once(self, _: ((),)) -> Self::Output {
+        if self.func.is_none() {
+            bail!("{}", ErrorKind::FunctionNotFound);
+        }
+
+        let mut ret_val = unsafe { mem::uninitialized::<ts::TVMValue>() };
+        let mut ret_type_code = 0 as c_int;
+        if self.arg_buf.is_some() {
+            let arg_buf = self.arg_buf?;
+            let mut num_args = arg_buf.len();
+            let mut values = arg_buf
+                .iter()
+                .map(|tav| tav.value.inner)
+                .collect::<Vec<ts::TVMValue>>();
+            let mut tcodes = arg_buf
+                .iter()
+                .map(|tav| tav.type_code as c_int)
+                .collect::<Vec<_>>();
+
+            if self.ret_buf.is_some() {
+                num_args = num_args + 1;
+                let ret_buf = self.ret_buf?;
+                let (ret_val, ret_type_code) = TVMRetValue::into_tvm_value(ret_buf);
+                values.append(&mut vec![ret_val.inner]);
+                tcodes.append(&mut vec![ret_type_code as c_int]);
+            }
+
+            values.truncate(num_args);
+            tcodes.truncate(num_args);
+            check_call!(ts::TVMFuncCall(
+                self.func?.handle,
+                values.as_mut_ptr(),
+                tcodes.as_mut_ptr(),
+                num_args as c_int,
+                &mut ret_val as *mut _,
+                &mut ret_type_code as *mut _
+            ));
+        } else {
+            check_call!(ts::TVMFuncCall(
+                self.func?.handle,
+                ptr::null_mut(),
+                ptr::null_mut(),
+                0 as c_int,
+                &mut ret_val as *mut _,
+                &mut ret_type_code as *mut _
+            ));
+        }
+
+        let ret = unsafe {
+            TVMRetValue::from_tvm_value(TVMValue::new(ret_val), (ret_type_code as i64).into())
+        };
+        Ok(ret)
+    }
+}
+
+/// Converts a [`Function`] to builder. Currently, this is the best way to work with
+/// TVM functions.
+impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> {
+    fn from(func: &'m Function) -> Self {
+        Builder::new(Some(func), None, None)
+    }
+}
+
+/// Converts a mutable reference of a [`Module`] to [`Builder`].
+impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> {
+    fn from(module: &'m mut Module) -> Self {
+        Builder::new(module.entry(), None, None)
+    }
+}
+
+unsafe extern "C" fn tvm_callback(
+    args: *mut ts::TVMValue,
+    type_codes: *mut c_int,
+    num_args: c_int,
+    ret: ts::TVMRetValueHandle,
+    fhandle: *mut c_void,
+) -> c_int {
+    // turning off the incorrect linter complaints
+    #![allow(unused_assignments)]
+    let len = num_args as usize;
+    let args_list = slice::from_raw_parts_mut(args, len);
+    let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
+    let mut local_args: Vec<TVMArgValue> = Vec::new();
+    let mut value = mem::uninitialized::<ts::TVMValue>();
+    let mut tcode = mem::uninitialized::<c_int>();
+    let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
+    for i in 0..len {
+        value = args_list[i];
+        tcode = type_codes_list[i];
+        if tcode == ts::TVMTypeCode_kNodeHandle as c_int
+            || tcode == ts::TVMTypeCode_kFuncHandle as c_int
+            || tcode == ts::TVMTypeCode_kModuleHandle as c_int
+        {
+            check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode));
+        }
+        local_args.push(TVMArgValue::new(
+            TVMValue::new(value),
+            (tcode as i64).into(),
+        ));
+    }
+
+    let rv = match rust_fn(local_args.as_slice()) {
+        Ok(v) => v,
+        Err(msg) => {
+            crate::set_last_error(&msg);
+            return -1;
+        }
+    };
+
+    let (ret_val, ret_tcode) = TVMRetValue::into_tvm_value(rv);
+    let mut ret_val = ret_val.inner;
+    let mut ret_type_code = ret_tcode as c_int;
+    check_call!(ts::TVMCFuncSetReturn(
+        ret,
+        &mut ret_val as *mut _,
+        &mut ret_type_code as *mut _,
+        1 as c_int
+    ));
+    0
+}
+
+unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) {
+    let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
+    mem::drop(rust_fn);
+}
+
+fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function {
+    let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
+    let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue>;
+    check_call!(ts::TVMFuncCreateFromCFunc(
+        Some(tvm_callback),
+        resource_handle as *mut c_void,
+        Some(tvm_callback_finalizer),
+        &mut fhandle as *mut _
+    ));
+    Function::new(fhandle, false, false)
+}
+
+/// Registers a Rust function with signature
+/// `fn(&[TVMArgValue]) -> Result<TVMRetValue>`
+/// as a **global TVM packed function** from frontend to TVM backend.
+///
+/// Use [`register_global_func`] if overriding an existing global TVM function
+/// is not required.
+///
+/// ## Example
+///
+/// ```
+/// use std::convert::TryInto;
+///
+/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+///     let mut ret = 0i64;
+///     for arg in args.iter() {
+///         let arg: i64 = arg.try_into()?;
+///         ret += arg;
+///     }
+///     let ret_val = TVMRetValue::from(&ret);
+///     Ok(ret_val)
+/// }
+///
+/// tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
+/// let mut registered = function::Builder::default();
+/// registered.get_function("mysum", true);
+/// assert!(registered.func.is_some());
+/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
+/// assert_eq!(ret, 60);
+/// ```
+pub fn register<S: AsRef<str>>(
+    f: fn(&[TVMArgValue]) -> Result<TVMRetValue>,
+    name: S,
+    override_: bool,
+) -> Result<()> {
+    let func = convert_to_tvm_func(f);
+    let name = CString::new(name.as_ref())?;
+    check_call!(ts::TVMFuncRegisterGlobal(
+        name.as_ref().as_ptr() as *const c_char,
+        func.handle(),
+        override_ as c_int
+    ));
+    mem::forget(name);
+    Ok(())
+}
+
+/// Convenient macro for registering functions from frontend to backend as global
+/// TVM packed functions without overriding. If overriding an existing function is needed
+/// use the [`function::register`] function instead.
+///
+/// ## Example
+///
+/// ```
+/// use std::convert::TryInto;
+///
+/// register_global_func! {
+///     fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+///         let mut ret = 0f64;
+///         for arg in args.iter() {
+///             let arg: f64 = arg.try_into()?;
+///             ret += arg;
+///         }
+///         let ret_val = TVMRetValue::from(&ret);
+///         Ok(ret_val)
+///     }
+/// }
+///
+/// let mut registered = function::Builder::default();
+/// registered.get_function("sum", true);
+/// assert!(registered.func.is_some());
+/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
+/// assert_eq!(ret, 60f64);
+/// ```
+#[macro_export]
+macro_rules! register_global_func {
+    {
+        $(#[$m:meta])*
+        fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue> {
+            $($code:tt)*
+        }
+    } => {{
+        $(#[$m])*
+        fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue> {
+            $($code)*
+        }
+
+        $crate::function::register($fn_name, stringify!($fn_name).to_owned(), false).unwrap();
+    }}
+}
+
+/// Convenient macro for calling TVM packed functions by providing a
+/// function identifier and some arguments. This macro outputs a `Result` type
+/// and let user to perform proper error handling.
+///
+/// **Note**: this macro does *not* expect an outside mutable output. To
+/// set mutable output use [`set_output`] directly in the builder pattern.
+///
+/// [`set_output`]:function/struct.Builder.html#method.set_output
+///
+/// ## Example
+///
+/// Instead of
+///
+/// ```
+/// function::Builder::from(func).arg(&a).arg(&b).invoke();
+/// ```
+///
+/// one can use
+///
+/// ```
+/// call_packed!(func, &a, &b);
+/// ```
+#[macro_export]
+macro_rules! call_packed {
+    ($fn_name:expr, $($arg:expr),*) => {{
+        let mut builder = $crate::function::Builder::from($fn_name);
+        $(
+            builder.arg($arg);
+        )*
+        builder.invoke()
+    }}
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    static CANARY: &str = "module._LoadFromFile";
+
+    #[test]
+    fn list_global_func() {
+        assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
+    }
+
+    #[test]
+    fn get_fn() {
+        assert!(Function::get(CANARY, true).is_some());
+        assert!(Function::get("does not exists!", false).is_none());
+    }
+
+    #[test]
+    fn provide_args() {
+        let mut func = Builder::default();
+        func.get_function("tvm.graph_runtime.remote_create", true)
+            .args(&[10, 20])
+            .arg(&"test".to_owned());
+        assert!(func.arg_buf.is_some());
+        assert_eq!(func.arg_buf.take().map(|bv| Vec::from(bv).len()), Some(3));
+    }
+}
diff --git a/rust/frontend/src/lib.rs b/rust/frontend/src/lib.rs
new file mode 100644 (file)
index 0000000..6e15e4f
--- /dev/null
@@ -0,0 +1,115 @@
+//! [TVM](https://github.com/dmlc/tvm) is a compiler stack for deep learning systems.
+//!
+//! This crate provides an idiomatic Rust API for TVM runtime frontend.
+//!
+//! One particular use case is that given optimized deep learning model artifacts,
+//! (compiled with TVM) which include a shared library
+//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them
+//! in Rust idomatically to create a TVM Graph Runtime and
+//! run the model for some inputs and get the
+//! desired predictions *all in Rust*.
+//!
+//! Checkout the `examples` repository for more details.
+
+#![crate_name = "tvm_frontend"]
+#![recursion_limit = "1024"]
+#![allow(non_camel_case_types, unused_unsafe)]
+#![feature(
+    try_from,
+    try_trait,
+    fn_traits,
+    unboxed_closures,
+    box_syntax,
+    option_replace
+)]
+
+#[macro_use]
+extern crate error_chain;
+extern crate tvm_common as common;
+#[macro_use]
+extern crate lazy_static;
+extern crate ndarray as rust_ndarray;
+extern crate num_traits;
+
+use std::{
+    ffi::{CStr, CString},
+    str,
+};
+
+use crate::common::ffi::ts;
+
+// Macro to check the return call to TVM runtime shared library.
+macro_rules! check_call {
+    ($e:expr) => {{
+        if unsafe { $e } != 0 {
+            panic!("{}", $crate::get_last_error());
+        }
+    }};
+}
+
+/// Gets the last error message.
+pub fn get_last_error() -> &'static str {
+    unsafe {
+        match CStr::from_ptr(ts::TVMGetLastError()).to_str() {
+            Ok(s) => s,
+            Err(_) => "Invalid UTF-8 message",
+        }
+    }
+}
+
+pub(crate) fn set_last_error(err: &Error) {
+    let c_string = CString::new(err.to_string()).unwrap();
+    unsafe {
+        ts::TVMAPISetLastError(c_string.as_ptr());
+    }
+}
+
+#[macro_use]
+pub mod function;
+pub mod bytearray;
+pub mod context;
+pub mod errors;
+pub mod module;
+pub mod ndarray;
+pub mod ty;
+pub mod value;
+
+pub use crate::{
+    bytearray::TVMByteArray,
+    common::{
+        errors as common_errors,
+        ty::TVMTypeCode,
+        value::{TVMArgValue, TVMRetValue, TVMValue},
+    },
+    context::{TVMContext, TVMDeviceType},
+    errors::*,
+    function::Function,
+    module::Module,
+    ndarray::NDArray,
+    ty::TVMType,
+};
+
+/// Outputs the current TVM version.
+pub fn version() -> &'static str {
+    match str::from_utf8(ts::TVM_VERSION) {
+        Ok(s) => s,
+        Err(_) => "Invalid UTF-8 string",
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn print_version() {
+        println!("TVM version: {}", version());
+    }
+
+    #[test]
+    fn set_error() {
+        let err = ErrorKind::EmptyArray;
+        set_last_error(&err.into());
+        assert_eq!(get_last_error().trim(), ErrorKind::EmptyArray.to_string());
+    }
+}
diff --git a/rust/frontend/src/module.rs b/rust/frontend/src/module.rs
new file mode 100644 (file)
index 0000000..c12d9d4
--- /dev/null
@@ -0,0 +1,105 @@
+//! Provides the [`Module`] type and methods for working with runtime TVM modules.
+
+use std::{
+    convert::TryInto,
+    ffi::CString,
+    os::raw::{c_char, c_int},
+    path::Path,
+    ptr,
+};
+
+use crate::ts;
+
+use crate::{function::Function, ErrorKind, Result};
+
+const ENTRY_FUNC: &'static str = "__tvm_main__";
+
+/// Wrapper around TVM module handle which contains an entry function.
+/// The entry function can be applied to an imported module through [`entry_func`].
+/// Also [`is_released`] shows whether the module is dropped or not.
+///
+/// [`entry_func`]:struct.Module.html#method.entry_func
+/// [`is_released`]:struct.Module.html#method.is_released
+#[derive(Debug, Clone)]
+pub struct Module {
+    pub(crate) handle: ts::TVMModuleHandle,
+    is_released: bool,
+    entry_func: Option<Function>,
+}
+
+impl Module {
+    pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self {
+        Self {
+            handle,
+            is_released,
+            entry_func: None,
+        }
+    }
+
+    pub fn entry(&mut self) -> Option<&Function> {
+        if self.entry_func.is_none() {
+            self.entry_func = self.get_function(ENTRY_FUNC, false).ok();
+        }
+        self.entry_func.as_ref()
+    }
+
+    /// Gets a function by name from a registered module.
+    pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function> {
+        let name = CString::new(name)?;
+        let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
+        check_call!(ts::TVMModGetFunction(
+            self.handle,
+            name.as_ptr() as *const c_char,
+            query_import as c_int,
+            &mut fhandle as *mut _
+        ));
+        if fhandle.is_null() {
+            bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?)))
+        } else {
+            Ok(Function::new(fhandle, false, false))
+        }
+    }
+
+    /// Imports a dependent module such as `.ptx` for gpu.
+    pub fn import_module(&self, dependent_module: Module) {
+        check_call!(ts::TVMModImport(self.handle, dependent_module.handle))
+    }
+
+    /// Loads a module shared library from path.
+    pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module> {
+        let ext = path.as_ref().extension()?.to_str()?;
+        let func = Function::get("module._LoadFromFile", true /* is_global */)
+            .expect("API function always exists");
+        let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?;
+        Ok(ret)
+    }
+
+    /// Checks if a target device is enabled for a module.
+    pub fn enabled(&self, target: &str) -> bool {
+        let func = Function::get("module._Enabled", true /* is_global */)
+            .expect("API function always exists");
+        // `unwrap` is safe here because if there is any error during the
+        // function call, it would occur in `call_packed!`.
+        let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap();
+        ret != 0
+    }
+
+    /// Returns the underlying module handle.
+    pub fn handle(&self) -> ts::TVMModuleHandle {
+        self.handle
+    }
+
+    /// Returns true if the underlying module has been dropped and false otherwise.
+    pub fn is_released(&self) -> bool {
+        self.is_released
+    }
+}
+
+impl Drop for Module {
+    fn drop(&mut self) {
+        if !self.is_released {
+            check_call!(ts::TVMModFree(self.handle));
+            self.is_released = true;
+        }
+    }
+}
diff --git a/rust/frontend/src/ndarray.rs b/rust/frontend/src/ndarray.rs
new file mode 100644 (file)
index 0000000..44dfcca
--- /dev/null
@@ -0,0 +1,363 @@
+//! This module implements the [`NDArray`] type for working with *TVM tensors* or
+//! coverting from a Rust's ndarray to TVM `NDArray`.
+//!
+//! One can create an empty NDArray given the shape, context and dtype using [`empty`].
+//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`].
+//! To copy an NDArray to different context use [`copy_to_ctx`].
+//!
+//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows:
+//!
+//! # Example
+//!
+//! ```
+//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
+//!     .unwrap()
+//!     .into_dyn(); // Rust's ndarray
+//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
+//! assert_eq!(nd.shape(), Some(&mut [2, 2]));
+//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
+//! assert!(rnd.all_close(&a, 1e-8f32));
+//! ```
+//!
+//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/
+//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
+//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
+
+use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice};
+
+use crate::rust_ndarray::{Array, ArrayD};
+use num_traits::Num;
+
+use crate::ts;
+
+use crate::{Error, ErrorKind, Result, TVMByteArray, TVMContext, TVMType};
+
+/// See the [`module-level documentation`](../ndarray/index.html) for more details.
+///
+/// Wrapper around TVM array handle.
+#[derive(Debug)]
+pub struct NDArray {
+    pub(crate) handle: ts::TVMArrayHandle,
+    is_view: bool,
+}
+
+impl NDArray {
+    pub(crate) fn new(handle: ts::TVMArrayHandle, is_view: bool) -> Self {
+        NDArray {
+            handle: handle,
+            is_view: is_view,
+        }
+    }
+
+    /// Returns the underlying array handle.
+    pub fn handle(&self) -> ts::TVMArrayHandle {
+        self.handle
+    }
+
+    pub fn is_view(&self) -> bool {
+        self.is_view
+    }
+
+    /// Returns the shape of the NDArray.
+    pub fn shape(&self) -> Option<&mut [usize]> {
+        let arr = unsafe { *(self.handle) };
+        if arr.shape.is_null() || arr.data.is_null() {
+            return None;
+        };
+        let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) };
+        Some(slc)
+    }
+
+    /// Returns the total number of entries of the NDArray.
+    pub fn size(&self) -> Option<usize> {
+        self.shape()
+            .map(|v| v.into_iter().fold(1, |acc, &mut e| acc * e))
+    }
+
+    /// Returns the context which the NDArray was defined.
+    pub fn ctx(&self) -> TVMContext {
+        unsafe { (*self.handle).ctx.into() }
+    }
+
+    /// Returns the type of the entries of the NDArray.
+    pub fn dtype(&self) -> TVMType {
+        unsafe { (*self.handle).dtype.into() }
+    }
+
+    /// Returns the number of dimensions of the NDArray.
+    pub fn ndim(&self) -> usize {
+        unsafe { (*self.handle).ndim as usize }
+    }
+
+    /// Returns the strides of the underlying NDArray.
+    pub fn strides(&self) -> Option<&[usize]> {
+        unsafe {
+            let sz = self.ndim() * mem::size_of::<usize>();
+            let slc = slice::from_raw_parts((*self.handle).strides as *const usize, sz);
+            Some(slc)
+        }
+    }
+
+    /// Shows whether the underlying ndarray is contiguous in memory or not.
+    pub fn is_contiguous(&self) -> Result<bool> {
+        Ok(match self.strides() {
+            None => true,
+            Some(strides) => {
+                // MissingShapeError in case shape is not determined
+                self.shape()?
+                    .iter()
+                    .zip(strides)
+                    .rfold(
+                        (true, 1),
+                        |(is_contig, expected_stride), (shape, stride)| {
+                            (
+                                is_contig && *stride == expected_stride,
+                                expected_stride * (*shape as usize),
+                            )
+                        },
+                    )
+                    .0
+            }
+        })
+    }
+
+    pub fn byte_offset(&self) -> isize {
+        unsafe { (*self.handle).byte_offset as isize }
+    }
+
+    /// Flattens the NDArray to a `Vec` of the same type in cpu.
+    ///
+    /// ## Example
+    ///
+    /// ```
+    /// let shape = &mut [4];
+    /// let mut data = vec![1i32, 2, 3, 4];
+    /// let ctx = TVMContext::cpu(0);
+    /// let mut ndarray = empty(shape, ctx, TVMType::from("int32"));
+    /// ndarray.copy_from_buffer(&mut data);
+    /// assert_eq!(ndarray.shape(), Some(shape));
+    /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
+    /// ```
+    pub fn to_vec<T>(&self) -> Result<Vec<T>> {
+        if self.shape().is_none() {
+            bail!("{}", ErrorKind::EmptyArray);
+        }
+        let earr = NDArray::empty(self.shape()?, TVMContext::cpu(0), self.dtype());
+        let target = self.copy_to_ndarray(earr)?;
+        let arr = unsafe { *(target.handle) };
+        let sz = self.size()? as usize;
+        let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
+        unsafe {
+            v.as_mut_ptr()
+                .copy_from_nonoverlapping(arr.data as *const T, sz);
+            v.set_len(sz);
+        }
+        Ok(v)
+    }
+
+    /// Converts the NDArray to [`TVMByteArray`].
+    pub fn to_bytearray(&self) -> Result<TVMByteArray> {
+        let v = self.to_vec::<u8>()?;
+        Ok(TVMByteArray::from(&v))
+    }
+
+    /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu.
+    ///
+    /// ## Example
+    ///
+    /// ```
+    /// let shape = &mut [2];
+    /// let mut data = vec![1f32, 2];
+    /// let ctx = TVMContext::gpu(0);
+    /// let mut ndarray = empty(shape, ctx, TVMType::from("int32"));
+    /// ndarray.copy_from_buffer(&mut data);
+    /// ```
+    ///
+    /// *Note*: if something goes wrong during the copy, it will panic
+    /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
+    pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
+        check_call!(ts::TVMArrayCopyFromBytes(
+            self.handle,
+            data.as_ptr() as *mut _,
+            data.len() * mem::size_of::<T>()
+        ));
+    }
+
+    /// Copies the NDArray to another target NDArray.
+    pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray> {
+        if self.dtype() != target.dtype() {
+            bail!(
+                "{}",
+                ErrorKind::TypeMismatch(
+                    format!("{}", self.dtype().to_string()),
+                    format!("{}", target.dtype().to_string()),
+                )
+            );
+        }
+        check_call!(ts::TVMArrayCopyFromTo(
+            self.handle,
+            target.handle,
+            ptr::null_mut() as ts::TVMStreamHandle
+        ));
+        Ok(target)
+    }
+
+    /// Copies the NDArray to a target context.
+    pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray> {
+        let tmp = NDArray::empty(self.shape()?, target.clone(), self.dtype());
+        let copy = self.copy_to_ndarray(tmp)?;
+        Ok(copy)
+    }
+
+    /// Converts a Rust's ndarray to TVM NDArray.
+    pub fn from_rust_ndarray<T: Num32 + Copy>(
+        rnd: &ArrayD<T>,
+        ctx: TVMContext,
+        dtype: TVMType,
+    ) -> Result<Self> {
+        let mut shape = rnd.shape().to_vec();
+        let mut nd = NDArray::empty(&mut shape, ctx, dtype);
+        let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T));
+        nd.copy_from_buffer(buf.as_slice_mut()?);
+        Ok(nd)
+    }
+
+    /// Allocates and creates an empty NDArray given the shape, context and dtype.
+    pub fn empty(shape: &[usize], ctx: TVMContext, dtype: TVMType) -> NDArray {
+        let mut handle = ptr::null_mut() as ts::TVMArrayHandle;
+        check_call!(ts::TVMArrayAlloc(
+            shape.as_ptr() as *const i64,
+            shape.len() as c_int,
+            dtype.inner.code as c_int,
+            dtype.inner.bits as c_int,
+            dtype.inner.lanes as c_int,
+            ctx.device_type.0 as c_int,
+            ctx.device_id as c_int,
+            &mut handle as *mut _,
+        ));
+        NDArray::new(handle, false)
+    }
+}
+
+macro_rules! impl_from_ndarray_rustndarray {
+    ($type:ty, $type_name:tt) => {
+        impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> {
+            type Error = Error;
+            fn try_from(nd: &NDArray) -> Result<ArrayD<$type>> {
+                if nd.shape().is_none() {
+                    bail!("{}", ErrorKind::EmptyArray);
+                }
+                assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch");
+                Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?)
+            }
+        }
+
+        impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> {
+            type Error = Error;
+            fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>> {
+                if nd.shape().is_none() {
+                    bail!("{}", ErrorKind::EmptyArray);
+                }
+                assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch");
+                Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?)
+            }
+        }
+    };
+}
+
+impl_from_ndarray_rustndarray!(i32, "int");
+impl_from_ndarray_rustndarray!(u32, "uint");
+impl_from_ndarray_rustndarray!(f32, "float");
+
+impl Drop for NDArray {
+    fn drop(&mut self) {
+        if !self.is_view {
+            check_call!(ts::TVMArrayFree(self.handle));
+        }
+    }
+}
+
+mod sealed {
+    /// Private trait to prevent other traits from being implemeneted in downstream crates.
+    pub trait Sealed {}
+}
+
+/// A trait for the supported 32-bits numerical types in frontend.
+pub trait Num32: Num + sealed::Sealed {
+    const BITS: u8 = 32;
+}
+
+macro_rules! impl_num32 {
+    ($($type:ty),+) => {
+        $(
+            impl sealed::Sealed for $type {}
+            impl Num32 for $type {}
+        )+
+    };
+}
+
+impl_num32!(i32, u32, f32);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn basics() {
+        let shape = &mut [1, 2, 3];
+        let ctx = TVMContext::cpu(0);
+        let ndarray = NDArray::empty(shape, ctx, TVMType::from("int32"));
+        assert_eq!(ndarray.shape().unwrap(), shape);
+        assert_eq!(
+            ndarray.size().unwrap(),
+            shape.to_vec().into_iter().product()
+        );
+        assert_eq!(ndarray.ndim(), 3);
+        assert!(ndarray.strides().is_none());
+        assert_eq!(ndarray.byte_offset(), 0);
+    }
+
+    #[test]
+    fn copy() {
+        let shape = &mut [4];
+        let mut data = vec![1i32, 2, 3, 4];
+        let ctx = TVMContext::cpu(0);
+        let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("int32"));
+        assert!(ndarray.to_vec::<i32>().is_ok());
+        ndarray.copy_from_buffer(&mut data);
+        assert_eq!(ndarray.shape().unwrap(), shape);
+        assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
+        assert_eq!(ndarray.ndim(), 1);
+        assert!(ndarray.is_contiguous().is_ok());
+        assert_eq!(ndarray.byte_offset(), 0);
+        let mut shape = vec![4];
+        let e = NDArray::empty(&mut shape, TVMContext::cpu(0), TVMType::from("int32"));
+        let nd = ndarray.copy_to_ndarray(e);
+        assert!(nd.is_ok());
+        assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
+    }
+
+    #[test]
+    #[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
+    fn copy_wrong_dtype() {
+        let mut shape = vec![4];
+        let mut data = vec![1f32, 2., 3., 4.];
+        let ctx = TVMContext::cpu(0);
+        let mut nd_float = NDArray::empty(&mut shape, ctx.clone(), TVMType::from("float32"));
+        nd_float.copy_from_buffer(&mut data);
+        let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from("int32"));
+        nd_float.copy_to_ndarray(empty_int).unwrap();
+    }
+
+    #[test]
+    fn rust_ndarray() {
+        let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
+            .unwrap()
+            .into_dyn();
+        let nd =
+            NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
+        assert_eq!(nd.shape().unwrap(), &mut [2, 2]);
+        let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
+        assert!(rnd.all_close(&a, 1e-8f32));
+    }
+}
diff --git a/rust/frontend/src/ty.rs b/rust/frontend/src/ty.rs
new file mode 100644 (file)
index 0000000..7e912a5
--- /dev/null
@@ -0,0 +1,150 @@
+//! This module implements the required conversions from Rust types to TVM types.
+//!
+//! In TVM frontend only conversions from Rust's 32-bits (POD) numeric types (i32, u32, f32)
+//! and 64-bits pointers are supported.
+
+use std::{
+    fmt::{self, Display, Formatter},
+    ops::{Deref, DerefMut},
+};
+
+use crate::ts;
+
+use crate::{Function, Module, NDArray, TVMByteArray, TVMContext, TVMDeviceType, TVMTypeCode};
+
+macro_rules! impl_prim_type {
+    ($type:ty, $variant:ident) => {
+        impl From<$type> for TVMTypeCode {
+            fn from(_arg: $type) -> Self {
+                TVMTypeCode::$variant
+            }
+        }
+
+        impl<'a> From<&'a $type> for TVMTypeCode {
+            fn from(_arg: &$type) -> Self {
+                TVMTypeCode::$variant
+            }
+        }
+
+        impl<'a> From<&'a mut $type> for TVMTypeCode {
+            fn from(_arg: &mut $type) -> Self {
+                TVMTypeCode::$variant
+            }
+        }
+    };
+}
+
+impl_prim_type!(TVMDeviceType, kDLInt);
+impl_prim_type!(TVMContext, kTVMContext);
+impl_prim_type!(TVMType, kTVMType);
+impl_prim_type!(Function, kFuncHandle);
+impl_prim_type!(Module, kModuleHandle);
+impl_prim_type!(NDArray, kArrayHandle);
+impl_prim_type!(TVMByteArray, kBytes);
+
+/// See the [module-level documentation](../ty/index.html) for more details.
+///
+/// Wrapper around underlying TVMType
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub struct TVMType {
+    // inner fields are (code: u8, bits: u8, lanes: u16)
+    pub inner: ts::TVMType,
+}
+
+impl TVMType {
+    pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
+        TVMType {
+            inner: ts::TVMType {
+                code: type_code,
+                bits: bits,
+                lanes: lanes,
+            },
+        }
+    }
+}
+
+/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
+/// such as "int32", "float32" or with lane "float32x1".
+impl<'a> From<&'a str> for TVMType {
+    fn from(type_str: &'a str) -> Self {
+        if type_str == "bool" {
+            return TVMType::new(1, 1, 1);
+        }
+
+        let mut type_lanes = type_str.split("x");
+        let typ = type_lanes.next().expect("Missing dtype");
+        let lanes = type_lanes
+            .next()
+            .map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l)))
+            .unwrap_or(1);
+        let (type_name, bits) = match typ.find(char::is_numeric) {
+            Some(idx) => {
+                let (name, bits_str) = typ.split_at(idx);
+                (
+                    name,
+                    u8::from_str_radix(bits_str, 10)
+                        .expect(&format!("Bad dtype bits: {}", bits_str)),
+                )
+            }
+            None => (typ, 32),
+        };
+
+        let type_code = match type_name {
+            "int" => 0,
+            "uint" => 1,
+            "float" => 2,
+            "handle" => 3,
+            _ => unimplemented!(),
+        };
+
+        TVMType::new(type_code, bits, lanes)
+    }
+}
+
+impl Display for TVMType {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        let ts::TVMType { code, bits, lanes } = self.inner;
+        if bits == 1 && lanes == 1 {
+            return write!(f, "bool");
+        }
+        let mut tcode_str = match code {
+            0 => "int",
+            1 => "uint",
+            2 => "float",
+            4 => "handle",
+            _ => "Unknown",
+        }
+        .to_string();
+
+        tcode_str += &bits.to_string();
+        if lanes > 1 {
+            tcode_str += &format!("x{}", lanes.to_string());
+        }
+        f.write_str(&tcode_str)
+    }
+}
+
+impl From<TVMType> for ts::DLDataType {
+    fn from(dtype: TVMType) -> Self {
+        dtype.inner
+    }
+}
+
+impl From<ts::DLDataType> for TVMType {
+    fn from(dtype: ts::DLDataType) -> Self {
+        Self::new(dtype.code, dtype.bits, dtype.lanes)
+    }
+}
+
+impl Deref for TVMType {
+    type Target = ts::TVMType;
+    fn deref(&self) -> &Self::Target {
+        &self.inner
+    }
+}
+
+impl DerefMut for TVMType {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.inner
+    }
+}
diff --git a/rust/frontend/src/value.rs b/rust/frontend/src/value.rs
new file mode 100644 (file)
index 0000000..9fad7de
--- /dev/null
@@ -0,0 +1,241 @@
+//! This module implements [`TVMArgValue`] and [`TVMRetValue`] types
+//! and their conversions needed for the types used in frontend crate.
+//! `TVMRetValue` is the owned version of `TVMPODValue`.
+
+use std::{convert::TryFrom, mem, os::raw::c_void};
+
+use crate::{
+    common_errors::*, ts, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext,
+    TVMDeviceType, TVMRetValue, TVMType, TVMTypeCode, TVMValue,
+};
+
+macro_rules! impl_tvm_val_from_handle {
+    ($($ty:ty),+) => {
+        $(
+            impl<'a> From<&'a $ty> for TVMValue {
+                fn from(arg: &$ty) -> Self {
+                    let inner = ts::TVMValue {
+                        v_handle: arg.handle as *mut _ as *mut c_void,
+                    };
+                    Self::new(inner)
+                }
+            }
+        )+
+    }
+}
+
+impl_tvm_val_from_handle!(Module, Function, NDArray);
+
+impl<'a> From<&'a TVMType> for TVMValue {
+    fn from(ty: &TVMType) -> Self {
+        let inner = ts::TVMValue { v_type: ty.inner };
+        Self::new(inner)
+    }
+}
+
+impl<'a> From<&'a TVMContext> for TVMValue {
+    fn from(ctx: &TVMContext) -> Self {
+        let inner = ts::TVMValue {
+            v_ctx: ctx.clone().into(),
+        };
+        Self::new(inner)
+    }
+}
+
+impl<'a> From<&'a TVMDeviceType> for TVMValue {
+    fn from(dev: &TVMDeviceType) -> Self {
+        let inner = ts::TVMValue {
+            v_int64: dev.0 as i64,
+        };
+        Self::new(inner)
+    }
+}
+
+impl<'a> From<&'a TVMByteArray> for TVMValue {
+    fn from(barr: &TVMByteArray) -> Self {
+        let inner = ts::TVMValue {
+            v_handle: &barr.inner as *const ts::TVMByteArray as *mut c_void,
+        };
+        Self::new(inner)
+    }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for NDArray {
+    type Error = Error;
+    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+        if arg.type_code == TVMTypeCode::kArrayHandle {
+            let handle = unsafe { arg.value.inner.v_handle };
+            let arr_handle = unsafe { mem::transmute::<*mut c_void, ts::TVMArrayHandle>(handle) };
+            Ok(Self::new(arr_handle, true))
+        } else {
+            bail!(ErrorKind::TryFromTVMArgValueError(
+                stringify!(NDArray).to_string(),
+                arg.type_code.to_string()
+            ))
+        }
+    }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for Module {
+    type Error = Error;
+    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+        if arg.type_code == TVMTypeCode::kModuleHandle {
+            let handle = unsafe { arg.value.inner.v_handle };
+            Ok(Self::new(handle, false))
+        } else {
+            bail!(ErrorKind::TryFromTVMArgValueError(
+                stringify!(Module).to_string(),
+                arg.type_code.to_string()
+            ))
+        }
+    }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMByteArray {
+    type Error = Error;
+    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+        if arg.type_code == TVMTypeCode::kBytes {
+            unsafe {
+                let barr_ptr =
+                    mem::transmute::<*mut c_void, *mut ts::TVMByteArray>(arg.value.inner.v_handle);
+                Ok(Self::new(*barr_ptr))
+            }
+        } else {
+            bail!(ErrorKind::TryFromTVMArgValueError(
+                stringify!(TVMByteArray).to_string(),
+                arg.type_code.to_string()
+            ))
+        }
+    }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMType {
+    type Error = Error;
+    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+        if arg.type_code == TVMTypeCode::kTVMType {
+            let ty = unsafe { arg.value.inner.v_type };
+            Ok(TVMType::from(ty))
+        } else {
+            bail!(ErrorKind::TryFromTVMArgValueError(
+                stringify!(TVMType).to_string(),
+                arg.type_code.to_string()
+            ))
+        }
+    }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMContext {
+    type Error = Error;
+    fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+        if arg.type_code == TVMTypeCode::kTVMContext {
+            let ty = unsafe { arg.value.inner.v_ctx };
+            Ok(TVMContext::from(ty))
+        } else {
+            bail!(ErrorKind::TryFromTVMArgValueError(
+                stringify!(TVMContext).to_string(),
+                arg.type_code.to_string()
+            ))
+        }
+    }
+}
+
+macro_rules! impl_boxed_ret_value {
+    ($type:ty, $code:expr) => {
+        impl From<$type> for TVMRetValue {
+            fn from(val: $type) -> Self {
+                TVMRetValue {
+                    prim_value: 0,
+                    box_value: box val,
+                    type_code: $code,
+                }
+            }
+        }
+        impl TryFrom<TVMRetValue> for $type {
+            type Error = Error;
+            fn try_from(ret: TVMRetValue) -> Result<$type> {
+                if let Ok(val) = ret.box_value.downcast::<$type>() {
+                    Ok(*val)
+                } else {
+                    bail!(ErrorKind::TryFromTVMRetValueError(
+                        stringify!($type).to_string(),
+                        ret.type_code.to_string()
+                    ))
+                }
+            }
+        }
+    };
+}
+
+impl_boxed_ret_value!(TVMType, TVMTypeCode::kTVMType);
+impl_boxed_ret_value!(TVMContext, TVMTypeCode::kTVMContext);
+impl_boxed_ret_value!(TVMByteArray, TVMTypeCode::kBytes);
+
+impl TryFrom<TVMRetValue> for Module {
+    type Error = Error;
+    fn try_from(ret: TVMRetValue) -> Result<Module> {
+        if let Ok(handle) = ret.box_value.downcast::<ts::TVMModuleHandle>() {
+            Ok(Module::new(*handle, false))
+        } else {
+            bail!(ErrorKind::TryFromTVMRetValueError(
+                stringify!(TVMTypeCode::kModuleHandle).to_string(),
+                ret.type_code.to_string()
+            ))
+        }
+    }
+}
+
+impl TryFrom<TVMRetValue> for Function {
+    type Error = Error;
+    fn try_from(ret: TVMRetValue) -> Result<Function> {
+        if let Ok(handle) = ret.box_value.downcast::<ts::TVMFunctionHandle>() {
+            Ok(Function::new(*handle, false, false))
+        } else {
+            bail!(ErrorKind::TryFromTVMRetValueError(
+                stringify!(TVMTypeCode::kFuncHandle).to_string(),
+                ret.type_code.to_string()
+            ))
+        }
+    }
+}
+
+impl TryFrom<TVMRetValue> for NDArray {
+    type Error = Error;
+    fn try_from(ret: TVMRetValue) -> Result<NDArray> {
+        if let Ok(handle) = ret.box_value.downcast::<ts::TVMArrayHandle>() {
+            Ok(NDArray::new(*handle, false))
+        } else {
+            bail!(ErrorKind::TryFromTVMRetValueError(
+                stringify!(TVMTypeCode::kArrayHandle).to_string(),
+                ret.type_code.to_string()
+            ))
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::convert::TryInto;
+
+    #[test]
+    fn bytearray() {
+        let w = vec![1u8, 2, 3, 4, 5];
+        let v = TVMByteArray::from(&w);
+        let tvm: TVMByteArray = TVMRetValue::from(v).try_into().unwrap();
+        assert_eq!(tvm.data(), w.iter().map(|e| *e as i8).collect::<Vec<i8>>());
+    }
+
+    #[test]
+    fn ty() {
+        let t = TVMType::from("int32");
+        let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap();
+        assert_eq!(tvm, t);
+    }
+
+    #[test]
+    fn ctx() {
+        let c = TVMContext::from("gpu");
+        let tvm: TVMContext = TVMRetValue::from(c).try_into().unwrap();
+        assert_eq!(tvm, c);
+    }
+}
diff --git a/rust/frontend/tests/basics/.gitignore b/rust/frontend/tests/basics/.gitignore
new file mode 100644 (file)
index 0000000..10a4b22
--- /dev/null
@@ -0,0 +1,7 @@
+/target
+**/*.rs.bk
+Cargo.lock
+*.o
+*.so
+*.ptx
+*.json
diff --git a/rust/frontend/tests/basics/Cargo.toml b/rust/frontend/tests/basics/Cargo.toml
new file mode 100644 (file)
index 0000000..496c0dd
--- /dev/null
@@ -0,0 +1,15 @@
+[package]
+name = "basics"
+version = "0.0.0"
+authors = ["TVM Contributors"]
+license = "Apache-2.0"
+build = "build.rs"
+
+[dependencies]
+ndarray = "0.12.1"
+tvm-frontend = { path = "../../" }
+
+[features]
+default = ["cpu"]
+cpu = []
+gpu = []
diff --git a/rust/frontend/tests/basics/build.rs b/rust/frontend/tests/basics/build.rs
new file mode 100644 (file)
index 0000000..67c21e0
--- /dev/null
@@ -0,0 +1,27 @@
+fn main() {
+    let out_dir = std::env::var("OUT_DIR").unwrap();
+
+    let output = std::process::Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/tvm_add.py"))
+        .args(&[
+            if cfg!(feature = "cpu") {
+                "llvm"
+            } else {
+                "cuda"
+            },
+            &std::env::var("OUT_DIR").unwrap(),
+        ])
+        .output()
+        .expect("Failed to execute command");
+    assert!(
+        std::path::Path::new(&format!("{}/test_add.so", out_dir)).exists(),
+        "Could not build tvm lib: {}",
+        String::from_utf8(output.stderr)
+            .unwrap()
+            .trim()
+            .split("\n")
+            .last()
+            .unwrap_or("")
+    );
+
+    println!("cargo:rustc-link-search=native={}", out_dir);
+}
diff --git a/rust/frontend/tests/basics/src/main.rs b/rust/frontend/tests/basics/src/main.rs
new file mode 100644 (file)
index 0000000..69b948e
--- /dev/null
@@ -0,0 +1,35 @@
+extern crate ndarray as rust_ndarray;
+extern crate tvm_frontend as tvm;
+
+use tvm::*;
+
+fn main() {
+    let shape = &mut [2];
+    let mut data = vec![3f32, 4.0];
+
+    let (ctx, ctx_name) = if cfg!(feature = "cpu") {
+        (TVMContext::cpu(0), "cpu")
+    } else {
+        (TVMContext::gpu(0), "gpu")
+    };
+    let dtype = TVMType::from("float32");
+    let mut arr = NDArray::empty(shape, ctx, dtype);
+    arr.copy_from_buffer(data.as_mut_slice());
+    let mut ret = NDArray::empty(shape, ctx, dtype);
+    let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap();
+    if !fadd.enabled(ctx_name) {
+        return;
+    }
+    if cfg!(feature = "gpu") {
+        fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap());
+    }
+    function::Builder::from(&mut fadd)
+        .arg(&arr)
+        .arg(&arr)
+        .set_output(&mut ret)
+        .unwrap()
+        .invoke()
+        .unwrap();
+
+    assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]);
+}
diff --git a/rust/frontend/tests/basics/src/tvm_add.py b/rust/frontend/tests/basics/src/tvm_add.py
new file mode 100755 (executable)
index 0000000..2f3b7c8
--- /dev/null
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+
+import os.path as osp
+import sys
+
+import tvm
+from tvm.contrib import cc
+
+
+def main(target, out_dir):
+    n = tvm.var('n')
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
+    s = tvm.create_schedule(C.op)
+
+    if target == 'cuda':
+        bx, tx = s[C].split(C.op.axis[0], factor=64)
+        s[C].bind(bx, tvm.thread_axis('blockIdx.x'))
+        s[C].bind(tx, tvm.thread_axis('threadIdx.x'))
+
+    fadd = tvm.build(s, [A, B, C], target, target_host='llvm', name='myadd')
+
+    fadd.save(osp.join(out_dir, 'test_add.o'))
+    if target == 'cuda':
+        fadd.imported_modules[0].save(os.path.join(out_dir, 'test_add.ptx'))
+    cc.create_shared(
+        osp.join(out_dir, 'test_add.so'), [osp.join(out_dir, 'test_add.o')])
+
+
+if __name__ == '__main__':
+    main(sys.argv[1], sys.argv[2])
+
diff --git a/rust/frontend/tests/callback/Cargo.toml b/rust/frontend/tests/callback/Cargo.toml
new file mode 100644 (file)
index 0000000..1795c57
--- /dev/null
@@ -0,0 +1,8 @@
+[package]
+name = "callback"
+version = "0.0.0"
+authors = ["TVM Contributors"]
+
+[dependencies]
+ndarray = "0.12.1"
+tvm-frontend = { path = "../../" }
diff --git a/rust/frontend/tests/callback/src/bin/array.rs b/rust/frontend/tests/callback/src/bin/array.rs
new file mode 100644 (file)
index 0000000..81dcadc
--- /dev/null
@@ -0,0 +1,44 @@
+#![feature(extern_crate_item_prelude, try_from)]
+#![allow(unused_imports)]
+
+extern crate ndarray as rust_ndarray;
+#[macro_use]
+extern crate tvm_frontend as tvm;
+
+use rust_ndarray::ArrayD;
+use std::convert::{TryFrom, TryInto};
+
+use tvm::*;
+
+fn main() {
+    register_global_func! {
+        fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+            let mut ret = 0f32;
+            let shape = &mut [2];
+            for arg in args.iter() {
+                let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+                let arg: NDArray = arg.try_into()?;
+                let arr = arg.copy_to_ndarray(e)?;
+                let rnd: ArrayD<f32> = ArrayD::try_from(&arr)?;
+                ret += rnd.scalar_sum();
+            }
+            Ok(TVMRetValue::from(ret))
+        }
+    }
+
+    let shape = &mut [2];
+    let mut data = vec![3f32, 4.0];
+    let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+    arr.copy_from_buffer(data.as_mut_slice());
+
+    let mut registered = function::Builder::default();
+    let ret: f32 = registered
+        .get_function("sum", true)
+        .arg(&arr)
+        .arg(&arr)
+        .invoke()
+        .unwrap()
+        .try_into()
+        .unwrap();
+    assert_eq!(ret, 14f32);
+}
diff --git a/rust/frontend/tests/callback/src/bin/error.rs b/rust/frontend/tests/callback/src/bin/error.rs
new file mode 100644 (file)
index 0000000..f40f0f1
--- /dev/null
@@ -0,0 +1,43 @@
+#![feature(extern_crate_item_prelude, panic_info_message)]
+#![allow(unused_imports)]
+
+use std::panic;
+
+#[macro_use]
+extern crate tvm_frontend as tvm;
+
+use tvm::*;
+
+fn main() {
+    register_global_func! {
+        fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue> {
+            Err(ErrorKind::TypeMismatch(
+                format!("{}", "i64".to_string()),
+                format!("{}", "f64".to_string()),
+            ).into())
+        }
+    }
+
+    let mut registered = function::Builder::default();
+    registered.get_function("error", true);
+    assert!(registered.func.is_some());
+    registered.args(&[10, 20]);
+
+    println!("expected error message is:");
+    panic::set_hook(Box::new(|panic_info| {
+        if let Some(msg) = panic_info.message() {
+            println!("{:?}", msg);
+        }
+        if let Some(location) = panic_info.location() {
+            println!(
+                "panic occurred in file '{}' at line {}",
+                location.file(),
+                location.line()
+            );
+        } else {
+            println!("panic occurred but can't get location information");
+        }
+    }));
+
+    let _result = registered.invoke();
+}
diff --git a/rust/frontend/tests/callback/src/bin/float.rs b/rust/frontend/tests/callback/src/bin/float.rs
new file mode 100644 (file)
index 0000000..3070552
--- /dev/null
@@ -0,0 +1,32 @@
+#![feature(extern_crate_item_prelude, try_from)]
+#![allow(unused_imports)]
+
+#[macro_use]
+extern crate tvm_frontend as tvm;
+
+use std::convert::TryInto;
+use tvm::*;
+
+fn main() {
+    register_global_func! {
+        fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+            let mut ret = 0.0;
+            for arg in args.iter() {
+                let val: f64 = arg.try_into()?;
+                ret += val;
+            }
+            Ok(TVMRetValue::from(&ret))
+        }
+    }
+
+    let mut registered = function::Builder::default();
+    registered.get_function("sum", true);
+    assert!(registered.func.is_some());
+    let ret: f64 = registered
+        .args(&[10.0f64, 20.0, 30.0])
+        .invoke()
+        .unwrap()
+        .try_into()
+        .unwrap();
+    assert_eq!(ret, 60f64);
+}
diff --git a/rust/frontend/tests/callback/src/bin/int.rs b/rust/frontend/tests/callback/src/bin/int.rs
new file mode 100644 (file)
index 0000000..3018822
--- /dev/null
@@ -0,0 +1,31 @@
+#![feature(extern_crate_item_prelude, try_from)]
+#![allow(unused_imports)]
+
+extern crate tvm_frontend as tvm;
+
+use std::convert::TryInto;
+use tvm::*;
+
+fn main() {
+    fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+        let mut ret = 0i64;
+        for arg in args.iter() {
+            let val: i64 = arg.try_into()?;
+            ret += val;
+        }
+        Ok(TVMRetValue::from(&ret))
+    }
+
+    tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
+
+    let mut registered = function::Builder::default();
+    registered.get_function("mysum", true);
+    assert!(registered.func.is_some());
+    let ret: i64 = registered
+        .args(&[10, 20, 30])
+        .invoke()
+        .unwrap()
+        .try_into()
+        .unwrap();
+    assert_eq!(ret, 60);
+}
diff --git a/rust/frontend/tests/callback/src/bin/string.rs b/rust/frontend/tests/callback/src/bin/string.rs
new file mode 100644 (file)
index 0000000..eafee31
--- /dev/null
@@ -0,0 +1,34 @@
+#![feature(extern_crate_item_prelude, try_from)]
+#![allow(unused_imports)]
+
+#[macro_use]
+extern crate tvm_frontend as tvm;
+use std::convert::TryInto;
+use tvm::*;
+
+// FIXME
+fn main() {
+    register_global_func! {
+        fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+            let mut ret = "".to_string();
+            for arg in args.iter() {
+                let val: String = arg.try_into()?;
+                ret += val.as_str();
+            }
+            Ok(TVMRetValue::from(ret))
+        }
+    }
+    let mut registered = function::Builder::default();
+    registered.get_function("concate_str", true);
+    assert!(registered.func.is_some());
+    let a = "a".to_string();
+    let b = "b".to_string();
+    let c = "c".to_string();
+    let ret: String = registered
+        .args(&[a, b, c])
+        .invoke()
+        .unwrap()
+        .try_into()
+        .unwrap();
+    assert_eq!(ret, "abc".to_owned());
+}
diff --git a/rust/runtime/.gitignore b/rust/runtime/.gitignore
new file mode 100644 (file)
index 0000000..230ab66
--- /dev/null
@@ -0,0 +1,3 @@
+Cargo.lock
+target/
+**/*.rs.bk
diff --git a/rust/runtime/.travis.yml b/rust/runtime/.travis.yml
new file mode 100644 (file)
index 0000000..63a3d02
--- /dev/null
@@ -0,0 +1,5 @@
+language: rust
+rust:
+  - nightly
+matrix:
+  fast_finish: true
diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml
new file mode 100644 (file)
index 0000000..d48c0d9
--- /dev/null
@@ -0,0 +1,29 @@
+[package]
+name = "tvm-runtime"
+version = "0.1.0"
+license = "Apache-2.0"
+description = "A static TVM runtime"
+repository = "https://github.com/dmlc/tvm"
+readme = "README.md"
+keywords = ["tvm", "nnvm"]
+categories = ["api-bindings", "science"]
+authors = ["TVM Contributors"]
+
+[features]
+default = ["nom/std"]
+sgx = ["nom/alloc"]
+
+[dependencies]
+bounded-spsc-queue = "0.4.0"
+error-chain = { version = "0.12.0", default-features = false }
+itertools = "0.7.8"
+lazy_static = "1.1.0"
+ndarray = "0.11.2"
+nom = {version = "4.0.0", default-features = false }
+serde = "1.0.59"
+serde_derive = "1.0.79"
+serde_json = "1.0.17"
+tvm-common = { version = "0.1.0", path = "../common/", features = ["runtime"] }
+
+[target.'cfg(not(target_env = "sgx"))'.dependencies]
+num_cpus = "1.8.0"
diff --git a/rust/runtime/src/allocator.rs b/rust/runtime/src/allocator.rs
new file mode 100644 (file)
index 0000000..5f77037
--- /dev/null
@@ -0,0 +1,52 @@
+#[cfg(target_env = "sgx")]
+use alloc::alloc::{self, Layout};
+#[cfg(not(target_env = "sgx"))]
+use std::alloc::{self, Layout};
+
+use crate::errors::*;
+
+const DEFAULT_ALIGN_BYTES: usize = 4;
+
+#[derive(PartialEq, Eq)]
+pub struct Allocation {
+    layout: Layout,
+    ptr: *mut u8,
+}
+
+impl Allocation {
+    /// Allocates a chunk of memory of `size` bytes with optional alignment.
+    pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
+        let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
+        let layout = Layout::from_size_align(size, alignment)?;
+        let ptr = unsafe { alloc::alloc(layout.clone()) };
+        if ptr.is_null() {
+            alloc::handle_alloc_error(layout);
+        }
+        Ok(Self {
+            ptr: ptr,
+            layout: layout,
+        })
+    }
+
+    pub fn as_mut_ptr(&self) -> *mut u8 {
+        self.ptr
+    }
+
+    /// Returns the size of the Allocation in bytes.
+    pub fn size(&self) -> usize {
+        self.layout.size()
+    }
+
+    /// Returns the byte alignment of the Allocation.
+    pub fn align(&self) -> usize {
+        self.layout.align()
+    }
+}
+
+impl Drop for Allocation {
+    fn drop(&mut self) {
+        unsafe {
+            alloc::dealloc(self.ptr, self.layout.clone());
+        }
+    }
+}
diff --git a/rust/runtime/src/array.rs b/rust/runtime/src/array.rs
new file mode 100644 (file)
index 0000000..5c49515
--- /dev/null
@@ -0,0 +1,507 @@
+use std::{
+    any::TypeId,
+    convert::TryFrom,
+    mem,
+    ops::{Deref, DerefMut},
+    os::raw::{c_int, c_void},
+    ptr, slice,
+};
+
+use ndarray;
+
+use crate::{
+    allocator::Allocation,
+    errors::*,
+    ffi::runtime::{
+        DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt,
+        DLDataTypeCode_kDLUInt, DLDeviceType_kDLCPU, DLTensor as _DLTensor,
+    },
+};
+
+/// A `Storage` is a container which holds `Tensor` data.
+#[derive(PartialEq)]
+pub enum Storage<'a> {
+    /// A `Storage` which owns its contained bytes.
+    Owned(Allocation),
+
+    /// A view of an existing `Storage`.
+    View(&'a mut [u8], usize), // ptr, align
+}
+
+impl<'a> Storage<'a> {
+    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
+        Ok(Storage::Owned(Allocation::new(size, align)?))
+    }
+
+    pub fn as_mut_ptr(&self) -> *mut u8 {
+        match self {
+            Storage::Owned(alloc) => alloc.as_mut_ptr(),
+            Storage::View(slice, _) => slice.as_ptr() as *mut u8,
+        }
+    }
+
+    pub fn size(&self) -> usize {
+        match self {
+            Storage::Owned(alloc) => alloc.size(),
+            Storage::View(slice, _) => slice.len(),
+        }
+    }
+
+    pub fn align(&self) -> usize {
+        match self {
+            Storage::Owned(alloc) => alloc.align(),
+            Storage::View(_, align) => *align,
+        }
+    }
+
+    pub fn as_ptr(&self) -> *const u8 {
+        self.as_mut_ptr() as *const _
+    }
+
+    /// Returns a `Storage::View` which points to an owned `Storage::Owned`.
+    pub fn view(&self) -> Storage<'a> {
+        match self {
+            Storage::Owned(alloc) => Storage::View(
+                unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
+                self.align(),
+            ),
+            Storage::View(slice, _) => Storage::View(
+                unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
+                self.align(),
+            ),
+        }
+    }
+
+    pub fn is_owned(&self) -> bool {
+        match self {
+            Storage::Owned(_) => true,
+            _ => false,
+        }
+    }
+
+    /// Returns an owned version of this storage via cloning.
+    pub fn to_owned(&self) -> Storage<'static> {
+        let s = Storage::new(self.size(), Some(self.align())).unwrap();
+        unsafe {
+            s.as_mut_ptr()
+                .copy_from_nonoverlapping(self.as_ptr(), self.size());
+        }
+        s
+    }
+}
+
+impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
+    fn from(data: &'d [T]) -> Self {
+        let data = unsafe {
+            slice::from_raw_parts_mut(
+                data.as_ptr() as *const u8 as *mut u8,
+                data.len() * mem::size_of::<T>() as usize,
+            )
+        };
+        Storage::View(data, mem::align_of::<T>())
+    }
+}
+
+/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
+/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
+/// converted to `ndarray::Array` for non-TVM processing.
+///
+/// # Examples
+///
+/// ```
+/// extern crate ndarray;
+///
+/// let mut a_nd: ndarray::Array = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// let mut a: Tensor = a_nd.into();
+/// let mut a_dl: DLTensor = (&mut t).into();
+/// call_packed!(tvm_fn, &mut a_dl);
+///
+/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
+/// let mut a_nd = ndarray::Array::try_from(&a).unwrap();
+/// ```
+#[derive(PartialEq)]
+pub struct Tensor<'a> {
+    /// The bytes which contain the data this `Tensor` represents.
+    pub(crate) data: Storage<'a>,
+    pub(crate) ctx: TVMContext,
+    pub(crate) dtype: DataType,
+    pub(crate) shape: Vec<i64>,
+    // ^ not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
+    /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
+    pub(crate) strides: Option<Vec<usize>>,
+    pub(crate) byte_offset: isize,
+    /// The number of elements in the `Tensor`.
+    pub(crate) size: usize,
+}
+
+unsafe impl<'a> Send for Tensor<'a> {}
+
+impl<'a> Tensor<'a> {
+    pub fn shape(&self) -> Vec<i64> {
+        self.shape.clone()
+    }
+
+    /// Returns the data of this `Tensor` as a `Vec`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
+    pub fn to_vec<T: 'static + std::fmt::Debug + Clone>(&self) -> Vec<T> {
+        assert!(self.is_contiguous());
+        assert!(self.dtype.is_type::<T>());
+        unsafe { slice::from_raw_parts(self.data.as_ptr() as *const T, self.size).to_vec() }
+    }
+
+    /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
+    pub fn is_contiguous(&self) -> bool {
+        match self.strides {
+            None => true,
+            Some(ref strides) => {
+                // check that stride for each dimension is the
+                // product of all trailing dimensons' shapes
+                self.shape
+                    .iter()
+                    .zip(strides)
+                    .rfold(
+                        (true, 1),
+                        |(is_contig, expected_stride), (shape, stride)| {
+                            (
+                                is_contig && *stride == expected_stride,
+                                expected_stride * (*shape as usize),
+                            )
+                        },
+                    )
+                    .0
+            }
+        }
+    }
+
+    /// Returns a clone of this `Tensor`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
+    pub fn copy(&mut self, other: &Tensor) {
+        assert!(
+            self.dtype == other.dtype && self.size == other.size,
+            "Tensor shape/dtype mismatch."
+        );
+        assert!(
+      self.is_contiguous() && other.is_contiguous(),
+      "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
+      self.strides,
+      other.strides
+    );
+        unsafe {
+            self.data
+                .as_mut_ptr()
+                .offset(self.byte_offset as isize)
+                .copy_from_nonoverlapping(
+                    other.data.as_mut_ptr().offset(other.byte_offset),
+                    other.size * other.dtype.itemsize(),
+                );
+        }
+    }
+
+    /// Returns an owned version of this `Tensor` via cloning.
+    pub fn to_owned(&self) -> Tensor<'static> {
+        let t = Tensor {
+            data: self.data.to_owned(),
+            ctx: self.ctx.clone(),
+            dtype: self.dtype.clone(),
+            size: self.size.clone(),
+            shape: self.shape.clone(),
+            strides: None,
+            byte_offset: 0,
+        };
+        unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
+    }
+
+    fn from_array_storage<'s, T, D: ndarray::Dimension>(
+        arr: &ndarray::Array<T, D>,
+        storage: Storage<'s>,
+        type_code: usize,
+    ) -> Tensor<'s> {
+        let type_width = mem::size_of::<T>() as usize;
+        Tensor {
+            data: storage,
+            ctx: TVMContext::default(),
+            dtype: DataType {
+                code: type_code,
+                bits: 8 * type_width,
+                lanes: 1,
+            },
+            size: arr.len(),
+            shape: arr.shape().iter().map(|&v| v as i64).collect(),
+            strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
+            byte_offset: 0,
+        }
+    }
+}
+
+/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
+macro_rules! impl_ndarray_try_from_tensor {
+    ($type:ty, $dtype:expr) => {
+        impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
+            type Error = Error;
+            fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
+                ensure!(
+                    tensor.dtype == $dtype,
+                    "Cannot convert Tensor with dtype {:?} to ndarray",
+                    tensor.dtype
+                );
+                Ok(ndarray::Array::from_shape_vec(
+                    tensor
+                        .shape
+                        .iter()
+                        .map(|s| *s as usize)
+                        .collect::<Vec<usize>>(),
+                    tensor.to_vec::<$type>(),
+                )?)
+            }
+        }
+    };
+}
+
+impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
+impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
+impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
+impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
+
+pub struct DLTensor {
+    pub(crate) inner: _DLTensor,
+}
+
+impl Deref for DLTensor {
+    type Target = _DLTensor;
+    fn deref(&self) -> &Self::Target {
+        &self.inner
+    }
+}
+
+impl DerefMut for DLTensor {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.inner
+    }
+}
+
+impl DLTensor {
+    pub(crate) fn new(raw: _DLTensor) -> Self {
+        Self { inner: raw }
+    }
+
+    pub(crate) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
+        assert!(!flatten || tensor.is_contiguous());
+        Self {
+            inner: _DLTensor {
+                data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
+                ctx: DLContext::from(&tensor.ctx),
+                ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
+                dtype: DLDataType::from(&tensor.dtype),
+                shape: if flatten {
+                    &tensor.size as *const _ as *mut i64
+                } else {
+                    tensor.shape.as_ptr()
+                } as *mut i64,
+                strides: if flatten || tensor.is_contiguous() {
+                    ptr::null_mut()
+                } else {
+                    tensor.strides.as_ref().unwrap().as_ptr()
+                } as *mut i64,
+                byte_offset: 0,
+            },
+        }
+    }
+}
+
+impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
+    fn from(tensor: &'a Tensor<'t>) -> Self {
+        DLTensor::from_tensor(tensor, false /* flatten */)
+    }
+}
+
+impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
+    fn from(tensor: &'a mut Tensor<'t>) -> Self {
+        DLTensor::from_tensor(tensor, false /* flatten */)
+    }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct DataType {
+    pub(crate) code: usize,
+    pub(crate) bits: usize,
+    pub(crate) lanes: usize,
+}
+
+impl DataType {
+    /// Returns the number of bytes occupied by an element of this `DataType`.
+    pub fn itemsize(&self) -> usize {
+        (self.bits * self.lanes) >> 3
+    }
+
+    /// Returns whether this `DataType` represents primitive type `T`.
+    pub fn is_type<T: 'static>(&self) -> bool {
+        if self.lanes != 1 {
+            return false;
+        }
+        let typ = TypeId::of::<T>();
+        (typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
+            || (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
+            || (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
+            || (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
+            || (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
+            || (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
+    }
+}
+
+impl<'a> From<&'a DataType> for DLDataType {
+    fn from(dtype: &'a DataType) -> Self {
+        Self {
+            code: dtype.code as u8,
+            bits: dtype.bits as u8,
+            lanes: dtype.lanes as u16,
+        }
+    }
+}
+
+impl From<DLDataType> for DataType {
+    fn from(dtype: DLDataType) -> Self {
+        Self {
+            code: dtype.code as usize,
+            bits: dtype.bits as usize,
+            lanes: dtype.lanes as usize,
+        }
+    }
+}
+
+macro_rules! make_dtype_const {
+    ($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
+        const $name: DataType = DataType {
+            code: $code as usize,
+            bits: $bits,
+            lanes: $lanes,
+        };
+    };
+}
+
+make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
+make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
+// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
+make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
+make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct TVMContext {
+    pub(crate) device_type: usize,
+    pub(crate) device_id: usize,
+}
+
+impl<'a> From<&'a TVMContext> for DLContext {
+    fn from(ctx: &'a TVMContext) -> Self {
+        Self {
+            device_type: ctx.device_type as u32,
+            device_id: ctx.device_id as i32,
+        }
+    }
+}
+
+impl Default for TVMContext {
+    fn default() -> Self {
+        Self {
+            device_type: DLDeviceType_kDLCPU as usize,
+            device_id: 0,
+        }
+    }
+}
+
+impl<'a> From<DLTensor> for Tensor<'a> {
+    fn from(dlt: DLTensor) -> Self {
+        unsafe {
+            let dtype = DataType::from(dlt.dtype);
+            let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec();
+            let size = shape.iter().map(|v| *v as usize).product::<usize>() as usize;
+            let storage = Storage::from(slice::from_raw_parts(
+                dlt.data as *const u8,
+                dtype.itemsize() * size,
+            ));
+            Self {
+                data: storage,
+                ctx: TVMContext::default(),
+                dtype: dtype,
+                size: size,
+                shape: shape,
+                strides: if dlt.strides == ptr::null_mut() {
+                    None
+                } else {
+                    Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec())
+                },
+                byte_offset: dlt.byte_offset as isize,
+            }
+        }
+    }
+}
+
+/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
+///
+/// # Panics
+///
+/// Panics if the ndarray is not contiguous.
+macro_rules! impl_tensor_from_ndarray {
+    ($type:ty, $typecode:expr) => {
+        impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
+            fn from(arr: ndarray::Array<$type, D>) -> Self {
+                let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous"));
+                Tensor::from_array_storage(&arr, storage.to_owned(), $typecode as usize)
+            }
+        }
+        impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
+            fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
+                let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous"));
+                Tensor::from_array_storage(arr, storage, $typecode as usize)
+            }
+        }
+    };
+}
+
+/// `From` conversions to `DLTensor` for `ndarray::Array`.
+/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
+macro_rules! impl_dltensor_from_ndarray {
+    ($type:ty, $typecode:expr) => {
+        impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
+            fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
+                DLTensor {
+                    inner: _DLTensor {
+                        data: arr.as_mut_ptr() as *mut c_void,
+                        ctx: DLContext {
+                            device_type: DLDeviceType_kDLCPU,
+                            device_id: 0,
+                        },
+                        ndim: arr.ndim() as c_int,
+                        dtype: DLDataType {
+                            code: $typecode as u8,
+                            bits: 8 * mem::size_of::<$type>() as u8,
+                            lanes: 1,
+                        },
+                        shape: arr.shape().as_ptr() as *const i64 as *mut i64,
+                        strides: arr.strides().as_ptr() as *const isize as *mut i64,
+                        byte_offset: 0,
+                    },
+                }
+            }
+        }
+    };
+}
+
+impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
+impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
+impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
+impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
+impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
+impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
+
+impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
+impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
+impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
+impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
+impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
+impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
diff --git a/rust/runtime/src/errors.rs b/rust/runtime/src/errors.rs
new file mode 100644 (file)
index 0000000..cf77230
--- /dev/null
@@ -0,0 +1,36 @@
+#[cfg(target_env = "sgx")]
+use alloc::alloc;
+#[cfg(not(target_env = "sgx"))]
+use std::alloc;
+use std::num;
+
+use crate::common::errors as common_errors;
+use ndarray;
+use serde_json;
+
+error_chain! {
+  errors {
+    GraphFormatError(msg: String) {
+      description("unable to load graph")
+      display("could not load graph json: {}", msg)
+    }
+
+    LoadGraphParamsError(msg: String) {
+      description("unable to load graph params")
+      display("could not load graph params: {}", msg)
+    }
+  }
+  foreign_links {
+    Alloc(alloc::AllocErr);
+    GraphDeserialize(serde_json::Error);
+    ParseInt(num::ParseIntError);
+    ShapeError(ndarray::ShapeError);
+    CommonError(common_errors::Error);
+  }
+}
+
+impl From<alloc::LayoutErr> for Error {
+    fn from(_err: alloc::LayoutErr) -> Error {
+        Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
+    }
+}
diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs
new file mode 100644 (file)
index 0000000..0d5e281
--- /dev/null
@@ -0,0 +1,473 @@
+use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
+
+use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
+use serde;
+use serde_json;
+
+use super::{DLTensor, DataType, Module, Storage, TVMContext, Tensor};
+use crate::{
+    common::value::TVMArgValue,
+    errors::{Error, ErrorKind, Result},
+    ffi::runtime::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt},
+};
+
+// @see `kTVMNDArrayMagic` in `ndarray.h`
+const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
+// @see `kTVMNDArrayListMagic` in `graph_runtime.h`
+const _NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7;
+
+/// A TVM computation graph.
+///
+/// # Examples
+///
+/// ```
+/// let graph_json = fs::read_to_string("graph.json")).unwrap();
+/// let graph = Graph::try_from(&graph_json).unwrap();
+/// ```
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Graph {
+    pub nodes: Vec<Node>,
+    pub arg_nodes: Vec<usize>,
+    pub heads: Vec<Entry>,
+    pub node_row_ptr: Option<Vec<usize>>,
+    pub attrs: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Entry {
+    pub id: usize,
+    pub index: usize,
+    pub version: usize,
+}
+
+impl Graph {
+    fn entry_index(&self, entry: &Entry) -> Result<usize> {
+        self.node_row_ptr
+            .as_ref()
+            .map(|nrp| nrp[entry.id] + entry.index)
+            .ok_or("Missing node_row_ptr.".into())
+    }
+
+    /// Attempt to deserialize a JSON attribute to a type `T`.
+    fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
+        Ok(serde_json::from_value::<T>(
+            self.attrs
+                .as_ref()
+                .ok_or(ErrorKind::GraphFormatError(
+                    "Missing graph attrs".to_string(),
+                ))?
+                .get(attr)
+                .ok_or(ErrorKind::GraphFormatError(format!(
+                    "Missing {} attr",
+                    attr
+                )))?
+                .to_owned(),
+        )?)
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Node {
+    pub op: String,
+    pub name: String,
+    pub inputs: Vec<Entry>,
+    pub attrs: Option<HashMap<String, String>>,
+    pub control_deps: Option<Vec<Entry>>,
+}
+
+struct NodeAttrs {
+    func_name: String,
+    num_outputs: usize,
+    flatten_data: bool,
+}
+
+impl Node {
+    fn parse_attrs(&self) -> Result<NodeAttrs> {
+        let attrs = self
+            .attrs
+            .as_ref()
+            .ok_or(format!("Missing node.attrs for `{}`", self.name))?;
+        let func_name = attrs
+            .get("func_name")
+            .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
+            .to_string();
+        let num_outputs = attrs
+            .get("num_outputs")
+            .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
+            .parse::<usize>()?;
+        let flatten_data = attrs
+            .get("flatten_data")
+            .ok_or(format!(
+                "Node `{}` is missing attrs.flatten_data",
+                self.name
+            ))?
+            .parse::<u8>()?
+            == 1;
+        Ok(NodeAttrs {
+            func_name,
+            num_outputs,
+            flatten_data,
+        })
+    }
+}
+
+impl<'a> TryFrom<&'a String> for Graph {
+    type Error = Error;
+    fn try_from(graph_json: &String) -> Result<Self> {
+        let graph = serde_json::from_str(graph_json)?;
+        Ok(graph)
+    }
+}
+
+impl<'a> TryFrom<&'a str> for Graph {
+    type Error = Error;
+    fn try_from(graph_json: &'a str) -> Result<Self> {
+        let graph = serde_json::from_str(graph_json)?;
+        Ok(graph)
+    }
+}
+
+/// A executor for a TVM computation graph.
+///
+/// # Examples
+///
+/// ```
+/// use ndarray::Array;
+///
+/// let syslib = SystemLibModule::default(); // a provider of TVM functions
+///
+/// let mut params_bytes = Vec::new();
+/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
+/// let params = tvm::runtime::load_param_dict(&params_bytes).unwrap();
+///
+/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap();
+///
+/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
+/// exec.load_params(params);
+///
+/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// exec.set_input("data", x.into());
+/// exec.run();
+/// let output = exec.get_output(0).unwrap();
+///
+/// println!("{:#?}", Array::try_from(output).unwrap());
+/// ```
+pub struct GraphExecutor<'m, 't> {
+    graph: Graph,
+    op_execs: Vec<Box<Fn() + 'm>>,
+    tensors: Vec<Tensor<'t>>,
+}
+
+unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
+
+impl<'m, 't> GraphExecutor<'m, 't> {
+    pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
+        let tensors = Self::setup_storages(&graph)?;
+        Ok(GraphExecutor {
+            op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
+            tensors: tensors,
+            graph: graph,
+        })
+    }
+
+    /// Runs the computation graph.
+    pub fn run(&self) {
+        self.op_execs.iter().for_each(|op_exec| {
+            op_exec();
+        });
+    }
+
+    /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
+    fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>> {
+        let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
+        let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
+        let dtypes = graph
+            .get_attr::<(String, Vec<String>)>("dltype")?
+            .1
+            .iter()
+            .map(|dltype| {
+                if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
+                    Ok(dtype)
+                } else {
+                    Err(ErrorKind::GraphFormatError(
+                        format!("Invalid dltype: {}", dltype).to_string(),
+                    )
+                    .into())
+                }
+            })
+            .collect::<Result<Vec<DataType>>>()?;
+
+        let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
+        let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
+        for (i, &storage_id) in storage_ids.iter().enumerate() {
+            let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
+            let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
+            storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
+        }
+
+        let mut storages: Vec<Storage> = storage_num_bytes
+            .into_iter()
+            .map(|nbytes| Storage::new(nbytes, align))
+            .collect::<Result<Vec<Storage>>>()?;
+
+        let tensors = izip!(storage_ids, shapes, dtypes)
+            .map(|(storage_id, shape, dtype)| {
+                let storage = storages[storage_id].view();
+                Tensor {
+                    data: mem::replace(&mut storages[storage_id], storage),
+                    ctx: TVMContext::default(),
+                    dtype: dtype,
+                    size: shape.iter().product::<i64>() as usize,
+                    shape: shape,
+                    strides: None,
+                    byte_offset: 0,
+                }
+            })
+            .collect();
+
+        Ok(tensors)
+    }
+
+    /// Creates closures which represent the computation performed by this graph.
+    fn setup_op_execs<M: 'm + Module>(
+        graph: &Graph,
+        lib: &'m M,
+        tensors: &Vec<Tensor<'t>>,
+    ) -> Result<Vec<Box<Fn() + 'm>>> {
+        ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
+        let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
+
+        let mut op_execs = Vec::new();
+        for (i, node) in graph.nodes.iter().enumerate() {
+            if node.op == "null" {
+                continue;
+            }
+            ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
+            ensure!(node.attrs.is_some(), "Missing node attrs.");
+
+            let attrs = node.parse_attrs()?;
+
+            if attrs.func_name == "__nop" {
+                continue;
+            }
+
+            let func = lib
+                .get_function(&attrs.func_name)
+                .ok_or(format!("Missing function {}", attrs.func_name))?;
+            let arg_indices = node
+                .inputs
+                .iter()
+                .map(|entry| graph.entry_index(entry))
+                .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi)));
+
+            let dl_tensors = arg_indices
+                .map(|idx| {
+                    let tensor = &tensors[idx?];
+                    Ok(if attrs.flatten_data {
+                        DLTensor::from_tensor(tensor, true /* flatten */)
+                    } else {
+                        DLTensor::from(tensor)
+                    })
+                })
+                .collect::<Result<Vec<DLTensor>>>()
+                .unwrap();
+            let op: Box<Fn()> = box move || {
+                let args = dl_tensors
+                    .iter()
+                    .map(|t| t.into())
+                    .collect::<Vec<TVMArgValue>>();
+                func(args.as_slice());
+            };
+            op_execs.push(op);
+        }
+        Ok(op_execs)
+    }
+
+    pub fn load_params(&mut self, params: HashMap<String, Tensor>) {
+        params.into_iter().for_each(|(name, param)| {
+            self.set_input(name, param);
+        })
+    }
+
+    pub fn set_input<S: AsRef<str>>(&mut self, name: S, value: Tensor) {
+        if let Some(idx) = self.get_input_index(name.as_ref()) {
+            // TODO: consider `new_with_params` to avoid ever allocating
+            let ptr = self.tensors[idx].data.as_ptr();
+            let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr);
+            let mut owner = to_replace.nth(0).unwrap();
+            if value.data.is_owned() {
+                // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
+                // mem::replace(&mut (*owner), value);
+                // to_replace.for_each(|t| {
+                //   panic!("replacing");
+                //   t.data = owner.data.view();
+                // });
+                owner.copy(&value);
+            } else {
+                owner.copy(&value);
+            }
+        } else {
+            println!("Unexpected input `{}`", name.as_ref());
+        }
+    }
+
+    /// Returns the graph input with name `name`, if it exists.
+    pub fn get_input<S: AsRef<str>>(&mut self, name: S) -> Option<&Tensor> {
+        self.get_input_index(name.as_ref())
+            .and_then(move |idx| Some(&self.tensors[idx]))
+    }
+
+    /// Returns the graph output with index `index`, if it exists.
+    pub fn get_output(&self, idx: usize) -> Option<&Tensor> {
+        let graph = &self.graph;
+        graph.heads.get(idx).and_then(|entry| {
+            graph
+                .entry_index(entry)
+                .map(|idx| self.tensors.get(idx))
+                .unwrap_or(None)
+        })
+    }
+
+    /// Returns the index for graph input with name `name`, if it exists.
+    pub fn get_input_index<S: AsRef<str>>(&self, name: S) -> Option<usize> {
+        let graph = &self.graph;
+        (0..graph.nodes.len())
+            .skip_while(|&i| graph.nodes[i].name != name.as_ref())
+            .nth(0)
+            .and_then(|i| {
+                if graph.arg_nodes.iter().any(|&id| id == i) {
+                    graph.node_row_ptr.as_ref().map(|nrp| nrp[i])
+                } else {
+                    None
+                }
+            })
+    }
+}
+
+/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
+named!(
+  tvm_str_to_type<CompleteStr, DataType>,
+  do_parse!(
+    type_name: alpha1 >>
+    bits: digit1 >>
+    lanes: opt!(tuple!(tag!("x"), digit1)) >>
+    (DataType {
+      code: match type_name {
+        CompleteStr("int") => DLDataTypeCode_kDLInt,
+        CompleteStr("uint") => DLDataTypeCode_kDLUInt,
+        CompleteStr("float") => DLDataTypeCode_kDLFloat,
+        _ => DLDataTypeCode_kDLFloat,
+      } as usize,
+      bits: bits.parse::<u8>().unwrap() as usize,
+      lanes: match lanes {
+        Some(lanes) => lanes.1.parse::<u16>().unwrap() as usize,
+        None => 1,
+      },
+    })
+  )
+);
+
+/// Converts a bytes to String.
+named!(
+    name<String>,
+    map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
+        b.to_vec()
+    ))
+);
+
+/// Parses a TVMContext
+named!(
+  tvm_ctx<&[u8], TVMContext>,
+  do_parse!(
+    device_type: le_u32 >>
+    device_id: le_i32 >>
+    (TVMContext { device_type: device_type as usize, device_id: device_id as usize })
+  )
+);
+
+/// Parses a DataType
+named!(
+  data_type<&[u8], DataType>,
+  do_parse!(
+    code: le_u8 >>
+    bits: le_u8 >>
+    lanes: le_u16 >>
+    (DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize })
+  )
+);
+
+/// Parses a Tensor from a TVM array file.
+named!(
+    tensor<Tensor>,
+    do_parse!(
+        take!(8)
+            >> bits!(tag_bits!(u64, 64, 0))
+            >> ctx: tvm_ctx
+            >> ndim: le_u32
+            >> dtype: data_type
+            >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize)
+            >> length: le_i64
+            >> data: take!(length)
+            >> (Tensor {
+                data: Storage::from(data),
+                ctx: ctx,
+                dtype: dtype,
+                size: shape.iter().product::<i64>() as usize,
+                shape: shape,
+                strides: None,
+                byte_offset: 0,
+            })
+    )
+);
+
+/// Parses a graph params dict from a params binary file.
+named!(
+    parse_param_dict<HashMap<String, Tensor>>,
+    do_parse!(
+        take!(8)
+            >> bits!(tag_bits!(u64, 64, 0))
+            >> names: length_count!(le_u64, name)
+            >> tensors: length_count!(le_u64, tensor)
+            >> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter())))
+    )
+);
+
+/// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
+pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> {
+    if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
+        if remaining_bytes.len() > 0 {
+            bail!(ErrorKind::LoadGraphParamsError("extra input".to_string()))
+        } else {
+            Ok(param_dict)
+        }
+    } else {
+        bail!(ErrorKind::LoadGraphParamsError(
+            "invalid parameters file".to_string()
+        ))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_str_to_type() {
+        assert_eq!(
+            tvm_str_to_type(CompleteStr("float24")).unwrap().1,
+            DataType {
+                code: DLDataTypeCode_kDLFloat as usize,
+                bits: 24,
+                lanes: 1
+            }
+        );
+        assert_eq!(
+            tvm_str_to_type(CompleteStr("uint111x44")).unwrap().1,
+            DataType {
+                code: DLDataTypeCode_kDLUInt as usize,
+                bits: 111,
+                lanes: 44
+            }
+        );
+    }
+}
diff --git a/rust/runtime/src/lib.rs b/rust/runtime/src/lib.rs
new file mode 100644 (file)
index 0000000..da030bc
--- /dev/null
@@ -0,0 +1,74 @@
+//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`.
+//! It's mainly useful for compiling to WebAssembly and SGX,
+//! but also native if you prefer Rust to C++.
+//!
+//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`.
+//! Single-function modules are used via the `packed_func!` macro after obtaining
+//! the function from `runtime::SystemLibModule`
+//!
+//! The main entrypoints to this crate are `GraphExecutor`
+//! For examples of use, please refer to the multi-file tests in the `tests` directory.
+
+#![feature(
+    alloc,
+    allocator_api,
+    box_syntax,
+    fn_traits,
+    try_from,
+    unboxed_closures,
+    vec_remove_item
+)]
+
+#[cfg(target_env = "sgx")]
+extern crate alloc;
+extern crate bounded_spsc_queue;
+#[cfg(target_env = "sgx")]
+extern crate core;
+#[macro_use]
+extern crate error_chain;
+#[macro_use]
+extern crate itertools;
+#[macro_use]
+extern crate lazy_static;
+extern crate ndarray;
+#[macro_use]
+extern crate nom;
+#[cfg(not(target_env = "sgx"))]
+extern crate num_cpus;
+extern crate serde;
+#[macro_use]
+extern crate serde_derive;
+extern crate serde_json;
+extern crate tvm_common as common;
+
+mod allocator;
+mod array;
+pub mod errors;
+mod module;
+#[macro_use]
+mod packed_func;
+mod graph;
+#[cfg(target_env = "sgx")]
+#[macro_use]
+pub mod sgx;
+mod threading;
+mod workspace;
+
+pub use crate::common::{errors::*, ffi, TVMArgValue, TVMRetValue};
+
+pub use self::{
+    array::*, errors::*, graph::*, module::*, packed_func::*, threading::*, workspace::*,
+};
+
+#[cfg(target_env = "sgx")]
+use self::sgx::ocall_packed_func;
+
+#[no_mangle]
+pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
+    #[cfg(not(target_env = "sgx"))]
+    unsafe {
+        panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
+    }
+    #[cfg(target_env = "sgx")]
+    ocall_packed!("__sgx_set_last_error__", cmsg);
+}
diff --git a/rust/runtime/src/module.rs b/rust/runtime/src/module.rs
new file mode 100644 (file)
index 0000000..8e6f7d6
--- /dev/null
@@ -0,0 +1,48 @@
+use std::{
+    collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
+};
+
+use crate::{
+    ffi::runtime::BackendPackedCFunc,
+    packed_func::{wrap_backend_packed_func, PackedFunc},
+};
+
+pub trait Module {
+    fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
+}
+
+pub struct SystemLibModule;
+
+lazy_static! {
+    static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
+        Mutex::new(HashMap::new());
+}
+
+impl Module for SystemLibModule {
+    fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
+        SYSTEM_LIB_FUNCTIONS
+            .lock()
+            .unwrap()
+            .get(name.as_ref())
+            .map(|func| wrap_backend_packed_func(func.to_owned()))
+    }
+}
+
+impl Default for SystemLibModule {
+    fn default() -> Self {
+        SystemLibModule {}
+    }
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
+    cname: *const c_char,
+    func: BackendPackedCFunc,
+) -> i32 {
+    let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
+    SYSTEM_LIB_FUNCTIONS
+        .lock()
+        .unwrap()
+        .insert(name.to_string(), func);
+    return 0;
+}
diff --git a/rust/runtime/src/packed_func.rs b/rust/runtime/src/packed_func.rs
new file mode 100644 (file)
index 0000000..2fe0086
--- /dev/null
@@ -0,0 +1,118 @@
+use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void};
+
+use super::Tensor;
+use crate::ffi::runtime::{
+    BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle,
+    TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue,
+};
+
+use super::DLTensor;
+use crate::{
+    common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue},
+    errors::*,
+};
+
+pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
+
+/// Calls a packed function and returns a `TVMRetValue`.
+///
+/// # Example
+///
+/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
+#[macro_export]
+macro_rules! call_packed {
+  ($fn:expr, $($args:expr),+) => {
+    $fn(&[$($args.into(),)+])
+  };
+  ($fn:expr) => {
+    $fn(&Vec::new())
+  };
+}
+
+impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
+    fn from(arr: &'a DLTensor) -> Self {
+        let raw = _TVMValue {
+            v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
+        };
+        TVMArgValue {
+            value: TVMValue::new(raw),
+            type_code: TVMTypeCode::kArrayHandle,
+            lifetime: PhantomData,
+        }
+    }
+}
+
+impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
+    fn from(arr: &'a mut DLTensor) -> Self {
+        let raw = _TVMValue {
+            v_handle: arr as *mut _ as *mut c_void,
+        };
+        TVMArgValue {
+            value: TVMValue::new(raw),
+            type_code: TVMTypeCode::kArrayHandle,
+            lifetime: PhantomData,
+        }
+    }
+}
+
+impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
+    type Error = Error;
+    fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
+        ensure!(
+            val.type_code == TVMTypeCode::kArrayHandle
+                || val.type_code == TVMTypeCode::kNDArrayContainer,
+            "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
+            TVMTypeCode::kArrayHandle,
+            TVMTypeCode::kNDArrayContainer,
+            val.type_code,
+        );
+
+        let dlt = unsafe { *(val.value.v_handle as *mut _DLTensor as *const _DLTensor) };
+        Ok(DLTensor::new(dlt).into())
+    }
+}
+
+impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
+    fn from(val: &'t Tensor<'a>) -> Self {
+        TVMRetValue {
+            prim_value: 0,
+            box_value: box DLTensor::from(val),
+            type_code: TVMTypeCode::kNDArrayContainer,
+        }
+    }
+}
+
+impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
+    type Error = Error;
+    fn try_from(ret: TVMRetValue) -> Result<Self> {
+        ensure!(
+            ret.type_code == TVMTypeCode::kArrayHandle
+                || ret.type_code == TVMTypeCode::kNDArrayContainer,
+            "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
+            TVMTypeCode_kArrayHandle,
+            TVMTypeCode_kNDArrayContainer,
+            ret.type_code,
+        );
+
+        let dlt = unsafe { *(ret.prim_value as *mut _DLTensor as *const _DLTensor) };
+        Ok(DLTensor::new(dlt).into())
+    }
+}
+
+// @see `WrapPackedFunc` in `llvm_module.cc`.
+pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
+    box move |args: &[TVMArgValue]| {
+        func(
+            args.iter()
+                .map(|ref arg| arg.value.inner)
+                .collect::<Vec<_TVMValue>>()
+                .as_ptr(),
+            args.iter()
+                .map(|ref arg| arg.type_code as i32)
+                .collect::<Vec<i32>>()
+                .as_ptr() as *const i32,
+            args.len() as i32,
+        );
+        TVMRetValue::default()
+    }
+}
diff --git a/rust/runtime/src/sgx.rs b/rust/runtime/src/sgx.rs
new file mode 100644 (file)
index 0000000..1edf3ef
--- /dev/null
@@ -0,0 +1,80 @@
+use std::{
+    ffi::CString,
+    os::raw::{c_char, c_int},
+};
+
+use errors::Result;
+use ffi::runtime::TVMValue;
+use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
+
+pub use runtime::threading::tvm_run_worker as run_worker;
+
+#[macro_export]
+macro_rules! tvm_ocall {
+    ($func: expr) => {
+        match $func {
+            0 => Ok(()),
+            err => Err(format!("SGX error: {}", err)),
+        }
+    };
+}
+
+pub type SgxStatus = u32;
+
+#[cfg(target_env = "sgx")]
+extern "C" {
+    fn tvm_ocall_packed_func(
+        name: *const c_char,
+        arg_values: *const TVMValue,
+        type_codes: *const c_int,
+        num_args: c_int,
+        ret_val: *mut TVMValue,
+        ret_type_code: *mut c_int,
+    ) -> SgxStatus;
+}
+
+pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
+    let mut ret_val = TVMValue { v_int64: 0 };
+    let ret_type_code = 0i64;
+    unsafe {
+        tvm_ocall!(tvm_ocall_packed_func(
+            CString::new(fn_name.as_ref()).unwrap().as_ptr(),
+            args.iter()
+                .map(|ref arg| arg.value)
+                .collect::<Vec<TVMValue>>()
+                .as_ptr(),
+            args.iter()
+                .map(|ref arg| arg.type_code as i32)
+                .collect::<Vec<i32>>()
+                .as_ptr() as *const i32,
+            args.len() as i32,
+            &mut ret_val as *mut TVMValue,
+            &mut (ret_type_code as i32) as *mut c_int,
+        ))?;
+    }
+    Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64))
+}
+
+#[macro_export]
+macro_rules! ocall_packed {
+  ($fn_name:expr, $($args:expr),+) => {
+    ocall_packed_func($fn_name, &[$($args.into(),)+])
+      .expect(concat!("Error calling `", $fn_name, "`"))
+  };
+  ($fn_name:expr) => {
+    ocall_packed_func($fn_name, &Vec::new())
+      .expect(concat!("Error calling `", $fn_name, "`"))
+  }
+}
+
+pub fn shutdown() {
+    if env!("TVM_NUM_THREADS") != "0" {
+        sgx_join_threads()
+    }
+}
+
+impl Drop for SystemLibModule {
+    fn drop(&mut self) {
+        shutdown()
+    }
+}
diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs
new file mode 100644 (file)
index 0000000..38f4b7d
--- /dev/null
@@ -0,0 +1,336 @@
+use std::{
+    os::raw::{c_int, c_void},
+    sync::{
+        atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT},
+        Arc, Barrier,
+    },
+};
+
+#[cfg(not(target_env = "sgx"))]
+use num_cpus;
+#[cfg(not(target_env = "sgx"))]
+use std::{
+    env,
+    thread::{self, JoinHandle},
+};
+
+#[cfg(target_env = "sgx")]
+use std::{collections::VecDeque, ptr, sync::Mutex};
+
+use bounded_spsc_queue::{self, Producer};
+
+use crate::{errors::*, ffi::runtime::TVMParallelGroupEnv};
+
+#[cfg(target_env = "sgx")]
+use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue};
+
+type FTVMParallelLambda =
+    extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
+
+/// Holds a parallel job request made by a TVM library function.
+struct Job {
+    cb: FTVMParallelLambda,
+    cdata: *const c_void,
+    req_num_tasks: usize,
+    pending: Arc<AtomicUsize>,
+}
+
+impl Job {
+    /// Splits this job into a number of `Task`s which can be scheduled.
+    fn tasks(&self, num_workers: usize) -> Vec<Task> {
+        let num_tasks = if self.req_num_tasks == 0 {
+            num_workers
+        } else {
+            self.req_num_tasks.min(num_workers)
+        };
+        self.pending.store(num_tasks, Ordering::SeqCst);
+
+        let barrier = Arc::new(Barrier::new(num_tasks));
+
+        (0..num_tasks)
+            .map(move |i| Task {
+                id: i,
+                flambda: self.cb,
+                penv: TVMParallelGroupEnv {
+                    sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void,
+                    num_task: num_tasks as i32,
+                },
+                cdata: self.cdata,
+                pending: Arc::clone(&self.pending),
+            })
+            .collect()
+    }
+
+    /// Waits for all tasks in this `Job` to be completed.
+    fn wait(&self) -> Result<()> {
+        while self.pending.load(Ordering::Acquire) > 0 {
+            #[cfg(not(target_env = "sgx"))]
+            thread::yield_now();
+        }
+        Ok(())
+    }
+}
+
+/// A chunk of work requested by a TVM function.
+struct Task {
+    id: usize,
+    flambda: FTVMParallelLambda,
+    penv: TVMParallelGroupEnv,
+    cdata: *const c_void,
+    pending: Arc<AtomicUsize>,
+}
+unsafe impl Send for Task {}
+unsafe impl Sync for Task {}
+
+impl FnOnce<()> for Task {
+    type Output = i32;
+    extern "rust-call" fn call_once(self, _args: ()) -> Self::Output {
+        let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata);
+        self.pending.fetch_sub(1, Ordering::AcqRel);
+        status
+    }
+}
+
+#[derive(Default)]
+struct Threads {
+    #[allow(unused)]
+    #[cfg(not(target_env = "sgx"))]
+    handles: Vec<JoinHandle<()>>,
+    queues: Vec<Producer<Task>>,
+}
+
+impl<'a> Threads {
+    #[cfg(not(target_env = "sgx"))]
+    fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
+        num_threads: usize,
+        cb: F,
+    ) -> Self {
+        let (handles, queues) = (0..num_threads)
+            .map(|_| {
+                let (p, c) = bounded_spsc_queue::make(2);
+                let handle = thread::spawn(move || cb(c.into()));
+                (handle, p)
+            })
+            .unzip();
+        Threads {
+            handles: handles,
+            queues: queues,
+        }
+    }
+
+    #[cfg(target_env = "sgx")]
+    fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
+        num_threads: usize,
+        _cb: F,
+    ) -> Self {
+        let mut consumer_queues = SGX_QUEUES.lock().unwrap();
+        let queues = (0..num_threads)
+            .map(|_| {
+                let (p, c) = bounded_spsc_queue::make(2);
+                consumer_queues.push_back(c.into());
+                p
+            })
+            .collect();
+        ocall_packed!("__sgx_thread_group_launch__", num_threads as u64);
+        Threads { queues: queues }
+    }
+}
+
+struct ThreadPool {
+    num_workers: usize,
+    #[allow(unused)]
+    threads: Threads,
+}
+
+thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new());
+
+impl ThreadPool {
+    fn new() -> Self {
+        let num_workers = max_concurrency();
+        ThreadPool {
+            num_workers: num_workers,
+            threads: Threads::launch(num_workers, ThreadPool::run_worker),
+        }
+    }
+
+    fn launch(&self, job: Job) {
+        let mut tasks = job.tasks(self.num_workers + 1);
+
+        for (i, task) in tasks.split_off(1).into_iter().enumerate() {
+            self.threads.queues[i].push(task);
+        }
+
+        tasks.pop().unwrap()();
+        job.wait().unwrap();
+    }
+
+    fn run_worker(queue: Consumer<Task>) {
+        loop {
+            let task = queue.pop();
+            let result = task();
+            if result == <i32>::min_value() {
+                break;
+            } else if result != 0 {
+                panic!("Error running task.");
+            }
+        }
+    }
+}
+
+// Send + Sync wrapper for bounded_spsc_queue::Consumer
+struct Consumer<T> {
+    consumer: bounded_spsc_queue::Consumer<T>,
+}
+impl<T> From<bounded_spsc_queue::Consumer<T>> for Consumer<T> {
+    fn from(c: bounded_spsc_queue::Consumer<T>) -> Self {
+        Consumer { consumer: c }
+    }
+}
+impl<T> Consumer<T> {
+    fn pop(&self) -> T {
+        self.consumer.pop()
+    }
+}
+unsafe impl<T> Send for Consumer<T> {}
+unsafe impl<T> Sync for Consumer<T> {}
+
+#[cfg(target_env = "sgx")]
+lazy_static! {
+  /// Holds tasks for untrusted threads which re-enter the enclave to execute.
+  static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new());
+}
+
+#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))]
+fn max_concurrency() -> usize {
+    if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) {
+        if let Ok(threads) = usize::from_str_radix(&threads_str, 10) {
+            return threads;
+        }
+    }
+    num_cpus::get_physical()
+}
+
+#[cfg(target_env = "sgx")]
+fn max_concurrency() -> usize {
+    usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1)
+}
+
+#[cfg(target_arch = "wasm32")]
+fn max_concurrency() -> usize {
+    0 // wasm doesn't support threads yet
+}
+
+#[cfg(target_env = "sgx")]
+pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue {
+    let q = {
+        let mut qs = SGX_QUEUES.lock().unwrap();
+        qs.pop_front()
+        // `qs: MutexGuard` needs to be dropped here since `run_worker` won't return
+    };
+    if let Some(q) = q {
+        ThreadPool::run_worker(q);
+    }
+    TVMRetValue::default()
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendParallelLaunch(
+    cb: FTVMParallelLambda,
+    cdata: *const c_void,
+    num_task: usize,
+) -> c_int {
+    if max_concurrency() == 0 {
+        let penv = TVMParallelGroupEnv {
+            sync_handle: 0 as *mut c_void,
+            num_task: 1,
+        };
+        cb(0, &penv as *const _, cdata);
+    } else {
+        THREAD_POOL.with(|pool| {
+            pool.launch(Job {
+                cb: cb,
+                cdata: cdata,
+                req_num_tasks: num_task,
+                pending: Arc::new(ATOMIC_USIZE_INIT),
+            });
+        });
+    }
+    return 0;
+}
+
+#[cfg(target_env = "sgx")]
+pub(crate) fn sgx_join_threads() {
+    extern "C" fn poison_pill(
+        _task_id: usize,
+        _penv: *const TVMParallelGroupEnv,
+        _cdata: *const c_void,
+    ) -> i32 {
+        <i32>::min_value()
+    }
+
+    THREAD_POOL.with(|pool| {
+        pool.launch(Job {
+            cb: poison_pill,
+            cdata: ptr::null(),
+            req_num_tasks: 0,
+            pending: Arc::new(ATOMIC_USIZE_INIT),
+        });
+    });
+    ocall_packed!("__sgx_thread_group_join__", 0);
+}
+
+// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used.
+#[no_mangle]
+pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) {
+    let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) };
+    barrier.wait();
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{ptr, thread, time::Duration};
+
+    use super::*;
+
+    #[test]
+    fn test_max_concurrency() {
+        env::set_var("TVM_NUM_THREADS", "42");
+        env::set_var("OMP_NUM_THREADS", "24");
+        assert_eq!(max_concurrency(), 42);
+        env::remove_var("TVM_NUM_THREADS");
+        assert_eq!(max_concurrency(), 24);
+    }
+
+    extern "C" fn flambda(
+        task_id: usize,
+        penv: *const TVMParallelGroupEnv,
+        cdata: *const c_void,
+    ) -> i32 {
+        if cdata == ptr::null() {
+            return 0;
+        }
+        unsafe {
+            let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize));
+            thread::sleep(Duration::from_millis(50 * task_id as u64));
+            counter.fetch_add(1, Ordering::SeqCst);
+            task_ids_sum.fetch_add(task_id, Ordering::SeqCst);
+            assert_eq!((*penv).num_task, 3);
+        }
+        0
+    }
+
+    #[test]
+    fn test_parallel_launch() {
+        TVMBackendParallelLaunch(flambda, ptr::null(), 6);
+        let counter = ATOMIC_USIZE_INIT;
+        let task_ids_sum = ATOMIC_USIZE_INIT;
+        let cdata = (counter, task_ids_sum);
+        let num_tasks = 3;
+        TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
+        assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks);
+        assert_eq!(
+            cdata.1.load(Ordering::SeqCst),
+            (0..num_tasks).sum::<usize>()
+        );
+    }
+}
diff --git a/rust/runtime/src/workspace.rs b/rust/runtime/src/workspace.rs
new file mode 100644 (file)
index 0000000..a12a27e
--- /dev/null
@@ -0,0 +1,117 @@
+use std::{
+    cell::RefCell,
+    os::raw::{c_int, c_void},
+    ptr,
+};
+
+use super::allocator::Allocation;
+use crate::errors::*;
+
+const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
+
+struct WorkspacePool {
+    workspaces: Vec<Allocation>,
+    free: Vec<usize>,
+    in_use: Vec<usize>,
+}
+
+impl WorkspacePool {
+    fn new() -> Self {
+        WorkspacePool {
+            workspaces: Vec::new(),
+            free: Vec::new(),
+            in_use: Vec::new(),
+        }
+    }
+
+    fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
+        self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
+        self.in_use.push(self.workspaces.len() - 1);
+        Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
+    }
+
+    fn alloc(&mut self, size: usize) -> Result<*mut u8> {
+        if self.free.len() == 0 {
+            return self.alloc_new(size);
+        }
+        let idx = self
+            .free
+            .iter()
+            .fold(None, |cur_ws_idx: Option<usize>, &idx| {
+                let ws_size = self.workspaces[idx].size();
+                if !ws_size >= size {
+                    return cur_ws_idx;
+                }
+                cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
+                    let cur_size = self.workspaces[cur_idx].size();
+                    Some(match ws_size <= cur_size {
+                        true => idx,
+                        false => cur_idx,
+                    })
+                })
+            });
+        match idx {
+            Some(idx) => {
+                self.free.remove_item(&idx).unwrap();
+                self.in_use.push(idx);
+                Ok(self.workspaces[idx].as_mut_ptr())
+            }
+            None => self.alloc_new(size),
+        }
+    }
+
+    fn free(&mut self, ptr: *mut u8) -> Result<()> {
+        let mut ws_idx = None;
+        for i in 0..self.in_use.len() {
+            let idx = self.in_use[i];
+            if self.workspaces[idx].as_mut_ptr() == ptr {
+                self.in_use.remove(i);
+                ws_idx = Some(idx);
+                break;
+            }
+        }
+        Ok(self
+            .free
+            .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?))
+    }
+}
+
+thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new()));
+
+const WORKSPACE_PAGE_SIZE: usize = 4 << 10;
+
+#[no_mangle]
+pub extern "C" fn TVMBackendAllocWorkspace(
+    _device_type: c_int,
+    _device_id: c_int,
+    size: u64,
+    _dtype_code_hint: c_int,
+    _dtype_bits_hint: c_int,
+) -> *mut c_void {
+    let nbytes = if size == 0 {
+        WORKSPACE_PAGE_SIZE
+    } else {
+        size as usize
+    };
+    WORKSPACE_POOL.with(|pool_cell| {
+        pool_cell
+            .borrow_mut()
+            .alloc(nbytes as usize)
+            .unwrap_or(ptr::null_mut()) as *mut c_void
+    })
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendFreeWorkspace(
+    _device_type: c_int,
+    _device_id: c_int,
+    ptr: *mut c_void,
+) -> c_int {
+    WORKSPACE_POOL.with(|pool_cell| {
+        (match pool_cell.borrow_mut().free(ptr as *mut u8) {
+            Ok(()) => 0,
+            Err(_) => -1,
+        }) as c_int
+    });
+    return 0;
+}
diff --git a/rust/runtime/tests/.gitignore b/rust/runtime/tests/.gitignore
new file mode 100644 (file)
index 0000000..8110767
--- /dev/null
@@ -0,0 +1,3 @@
+*.json
+*.params
+*.o
diff --git a/rust/runtime/tests/build_model.py b/rust/runtime/tests/build_model.py
new file mode 100755 (executable)
index 0000000..ea55ce4
--- /dev/null
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+
+"""Builds a simple NNVM graph for testing."""
+
+from os import path as osp
+
+import nnvm
+from nnvm import sym
+from nnvm.compiler import graph_util
+from nnvm.testing import init
+import numpy as np
+import tvm
+
+CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
+
+
+def _get_model(dshape):
+    data = sym.Variable('data', shape=dshape)
+    fc1 = sym.dense(data, units=dshape[-1]*2, use_bias=True)
+    left, right = sym.split(fc1, indices_or_sections=2, axis=1)
+    return sym.Group(((left + 1), (right - 1)))
+
+
+def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
+    if isinstance(graph, sym.Symbol):
+        graph = nnvm.graph.create(graph)
+    ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
+    param_shapes = dict(zip(graph.index.input_names, ishapes))
+    np.random.seed(seed)
+    params = {}
+    for param, shape in param_shapes.items():
+        if param in {'data', 'label'} or not shape:
+            continue
+        init_value = np.empty(shape).astype('float32')
+        initializer(param, init_value)
+        params[param] = tvm.nd.array(init_value)
+    return params
+
+def main():
+    dshape = (32, 16)
+    net = _get_model(dshape)
+    ishape_dict = {'data': dshape}
+    params = _init_params(net, ishape_dict)
+    graph, lib, params = nnvm.compiler.build(net, 'llvm',
+                                             shape=ishape_dict,
+                                             params=params,
+                                             dtype='float32')
+
+    with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet:
+        f_resnet.write(graph.json())
+    with open(osp.join(CWD, 'graph.params'), 'wb') as f_params:
+        f_params.write(nnvm.compiler.save_param_dict(params))
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/runtime/tests/test_graph_serde.rs b/rust/runtime/tests/test_graph_serde.rs
new file mode 100644 (file)
index 0000000..18ac19a
--- /dev/null
@@ -0,0 +1,39 @@
+#![feature(try_from)]
+
+extern crate serde;
+extern crate serde_json;
+
+extern crate tvm_runtime;
+
+use std::{convert::TryFrom, fs, io::Read};
+
+use tvm_runtime::Graph;
+
+#[test]
+fn test_load_graph() {
+    let mut params_bytes = Vec::new();
+    fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
+        .expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
+        .read_to_end(&mut params_bytes)
+        .unwrap();
+    let _params = tvm_runtime::load_param_dict(&params_bytes);
+
+    let graph = Graph::try_from(
+        &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(),
+    )
+    .unwrap();
+
+    assert_eq!(graph.nodes[3].op, "tvm_op");
+    assert_eq!(
+        graph.nodes[3]
+            .attrs
+            .as_ref()
+            .unwrap()
+            .get("func_name")
+            .unwrap(),
+        "fuse_dense"
+    );
+    assert_eq!(graph.nodes[5].inputs[0].index, 0);
+    assert_eq!(graph.nodes[6].inputs[0].index, 1);
+    assert_eq!(graph.heads.len(), 2);
+}
diff --git a/rust/runtime/tests/test_nnvm/Cargo.toml b/rust/runtime/tests/test_nnvm/Cargo.toml
new file mode 100644 (file)
index 0000000..14d0b39
--- /dev/null
@@ -0,0 +1,14 @@
+[package]
+name = "test-nnvm"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["TVM Contributors"]
+
+[dependencies]
+ndarray = "0.11.2"
+serde = "1.0.59"
+serde_json = "1.0.17"
+tvm-runtime = { path = "../../" }
+
+[build-dependencies]
+ar = "0.6.0"
diff --git a/rust/runtime/tests/test_nnvm/build.rs b/rust/runtime/tests/test_nnvm/build.rs
new file mode 100644 (file)
index 0000000..3a4fc0a
--- /dev/null
@@ -0,0 +1,33 @@
+extern crate ar;
+
+use std::{env, fs::File, path::Path, process::Command};
+
+use ar::Builder;
+
+fn main() {
+    let out_dir = env::var("OUT_DIR").unwrap();
+
+    let output = Command::new(concat!(
+        env!("CARGO_MANIFEST_DIR"),
+        "/src/build_test_graph.py"
+    ))
+    .arg(&out_dir)
+    .output()
+    .expect("Failed to execute command");
+    assert!(
+        Path::new(&format!("{}/graph.o", out_dir)).exists(),
+        "Could not build graph lib: {}",
+        String::from_utf8(output.stderr)
+            .unwrap()
+            .trim()
+            .split("\n")
+            .last()
+            .unwrap_or("")
+    );
+
+    let mut builder = Builder::new(File::create(format!("{}/libgraph.a", out_dir)).unwrap());
+    builder.append_path(format!("{}/graph.o", out_dir)).unwrap();
+
+    println!("cargo:rustc-link-lib=static=graph");
+    println!("cargo:rustc-link-search=native={}", out_dir);
+}
diff --git a/rust/runtime/tests/test_nnvm/src/build_test_graph.py b/rust/runtime/tests/test_nnvm/src/build_test_graph.py
new file mode 100755 (executable)
index 0000000..e9f74ec
--- /dev/null
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+
+"""Builds a simple NNVM graph for testing."""
+
+from os import path as osp
+import sys
+
+import nnvm
+from nnvm import sym
+from nnvm.compiler import graph_util
+from nnvm.testing import init
+import numpy as np
+import tvm
+
+
+def _get_model(dshape):
+    data = sym.Variable('data', shape=dshape)
+    fc = sym.dense(data, units=dshape[-1]*2, use_bias=True)
+    left, right = sym.split(fc, indices_or_sections=2, axis=1)
+    return sym.Group(((left + 1), (right - 1), fc))
+
+
+def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
+    if isinstance(graph, sym.Symbol):
+        graph = nnvm.graph.create(graph)
+
+    ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
+    param_shapes = dict(zip(graph.index.input_names, ishapes))
+    np.random.seed(seed)
+    params = {}
+    for param, shape in param_shapes.items():
+        if param in {'data', 'label'} or not shape:
+            continue
+
+        init_value = np.arange(np.product(shape), 0, -1).reshape(*shape).astype('float32')
+        if param.endswith('_bias'):
+            params[param] = tvm.nd.array(init_value)
+            continue
+
+        init_value = np.empty(shape).astype('float32')
+        initializer(param, init_value)
+        # init_value /= init_value.sum() + 1e-10
+        params[param] = tvm.nd.array(init_value)
+
+    return params
+
+def main():
+    dshape = (4, 8)
+    net = _get_model(dshape)
+    ishape_dict = {'data': dshape}
+    params = _init_params(net, ishape_dict)
+    graph, lib, params = nnvm.compiler.build(net, 'llvm --system-lib',
+                                             shape=ishape_dict,
+                                             params=params,
+                                             dtype='float32')
+
+    out_dir = sys.argv[1]
+    lib.save(osp.join(sys.argv[1], 'graph.o'))
+    with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet:
+        f_resnet.write(graph.json())
+        
+    with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params:
+        f_params.write(nnvm.compiler.save_param_dict(params))
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/runtime/tests/test_nnvm/src/main.rs b/rust/runtime/tests/test_nnvm/src/main.rs
new file mode 100644 (file)
index 0000000..5017979
--- /dev/null
@@ -0,0 +1,82 @@
+#![feature(try_from)]
+
+#[macro_use]
+extern crate ndarray;
+extern crate serde;
+extern crate serde_json;
+
+extern crate tvm_runtime;
+use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
+
+use ndarray::Array;
+use tvm_runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
+
+const BATCH_SIZE: usize = 4;
+const IN_DIM: usize = 8;
+
+macro_rules! check_sum {
+    ($e:expr, $a:ident, $b:ident) => {
+        let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap();
+        check_sum!(a, $b);
+    };
+    ($e:expr, $a:expr, $b:ident) => {
+        let a = Array::try_from($e.get_output($a).unwrap()).unwrap();
+        check_sum!(a, $b);
+    };
+    ($a:ident, $b:ident) => {
+        let a_sum: f32 = $a.scalar_sum();
+        let b_sum: f32 = $b.scalar_sum();
+        assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum);
+    };
+}
+
+fn main() {
+    let syslib = SystemLibModule::default();
+
+    let mut params_bytes = Vec::new();
+    fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
+        .unwrap()
+        .read_to_end(&mut params_bytes)
+        .unwrap();
+    let params = tvm_runtime::load_param_dict(&params_bytes)
+        .unwrap()
+        .into_iter()
+        .map(|(k, v)| (k, v.to_owned()))
+        .collect::<HashMap<String, Tensor<'static>>>();
+
+    let graph =
+        Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap())
+            .unwrap();
+    let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
+
+    let x = Array::from_shape_vec(
+        (BATCH_SIZE, IN_DIM),
+        (0..BATCH_SIZE * IN_DIM)
+            .map(|x| x as f32)
+            .collect::<Vec<f32>>(),
+    )
+    .unwrap();
+    let w = Array::try_from(params.get("dense0_weight").unwrap())
+        .unwrap()
+        .into_shape((IN_DIM * 2, IN_DIM))
+        .unwrap();
+    let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap();
+    let dense = x.dot(&w.t()) + &b;
+    let left = dense.slice(s![.., 0..IN_DIM]);
+    let right = dense.slice(s![.., IN_DIM..]);
+    let expected_o0 = &left + 1f32;
+    let expected_o1 = &right - 1f32;
+
+    exec.load_params(params);
+    exec.set_input("data", (&x).into());
+
+    check_sum!(exec, data, x);
+    check_sum!(exec, dense0_weight, w);
+    check_sum!(exec, dense0_bias, b);
+
+    exec.run();
+
+    check_sum!(exec, 0, expected_o0);
+    check_sum!(exec, 1, expected_o1);
+    check_sum!(exec, 2, dense);
+}
diff --git a/rust/runtime/tests/test_tvm_basic/Cargo.toml b/rust/runtime/tests/test_tvm_basic/Cargo.toml
new file mode 100644 (file)
index 0000000..2a753b4
--- /dev/null
@@ -0,0 +1,12 @@
+[package]
+name = "test-tvm-basic"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["TVM Contributors"]
+
+[dependencies]
+ndarray = "0.11.2"
+tvm-runtime = { path = "../../" }
+
+[build-dependencies]
+ar = "0.6.0"
diff --git a/rust/runtime/tests/test_tvm_basic/build.rs b/rust/runtime/tests/test_tvm_basic/build.rs
new file mode 100644 (file)
index 0000000..d877585
--- /dev/null
@@ -0,0 +1,34 @@
+extern crate ar;
+
+use std::{env, path::Path, process::Command};
+
+use ar::Builder;
+use std::fs::File;
+
+fn main() {
+    let out_dir = env::var("OUT_DIR").unwrap();
+
+    let output = Command::new(concat!(
+        env!("CARGO_MANIFEST_DIR"),
+        "/src/build_test_lib.py"
+    ))
+    .arg(&out_dir)
+    .output()
+    .expect("Failed to execute command");
+    assert!(
+        Path::new(&format!("{}/test.o", out_dir)).exists(),
+        "Could not build tvm lib: {}",
+        String::from_utf8(output.stderr)
+            .unwrap()
+            .trim()
+            .split("\n")
+            .last()
+            .unwrap_or("")
+    );
+
+    let mut builder = Builder::new(File::create(format!("{}/libtest.a", out_dir)).unwrap());
+    builder.append_path(format!("{}/test.o", out_dir)).unwrap();
+
+    println!("cargo:rustc-link-lib=static=test");
+    println!("cargo:rustc-link-search=native={}", out_dir);
+}
diff --git a/rust/runtime/tests/test_tvm_basic/src/build_test_lib.py b/rust/runtime/tests/test_tvm_basic/src/build_test_lib.py
new file mode 100755 (executable)
index 0000000..7289a77
--- /dev/null
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+
+"""Prepares a simple TVM library for testing."""
+
+from os import path as osp
+import sys
+
+import tvm
+
+def main():
+    n = tvm.var('n')
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    s = tvm.create_schedule(C.op)
+    s[C].parallel(s[C].op.axis[0])
+    print(tvm.lower(s, [A, B, C], simple_mode=True))
+    tvm.build(s, [A, B, C], 'llvm --system-lib').save(osp.join(sys.argv[1], 'test.o'))
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/runtime/tests/test_tvm_basic/src/main.rs b/rust/runtime/tests/test_tvm_basic/src/main.rs
new file mode 100644 (file)
index 0000000..f14fbec
--- /dev/null
@@ -0,0 +1,22 @@
+extern crate ndarray;
+#[macro_use]
+extern crate tvm_runtime;
+
+use ndarray::Array;
+use tvm_runtime::{DLTensor, Module, SystemLibModule};
+
+fn main() {
+    let syslib = SystemLibModule::default();
+    let add = syslib
+        .get_function("default_function")
+        .expect("main function not found");
+    let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
+    let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
+    let mut c = Array::from_vec(vec![0f32; 4]);
+    let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
+    let mut a_dl: DLTensor = (&mut a).into();
+    let mut b_dl: DLTensor = (&mut b).into();
+    let mut c_dl: DLTensor = (&mut c).into();
+    call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
+    assert!(c.all_close(&e, 1e-8f32));
+}
diff --git a/rust/src/errors.rs b/rust/src/errors.rs
deleted file mode 100644 (file)
index f9da718..0000000
+++ /dev/null
@@ -1,39 +0,0 @@
-#[cfg(target_env = "sgx")]
-use alloc::alloc;
-#[cfg(not(target_env = "sgx"))]
-use std::alloc;
-use std::num;
-
-use ndarray;
-use serde_json;
-
-error_chain! {
-  errors {
-    TryFromTVMRetValueError(expected: String, actual: i64) {
-      description("mismatched types while downcasting TVMRetValue")
-      display("invalid downcast: expected `{}` but was `{}`", expected, actual)
-    }
-
-    GraphFormatError(msg: String) {
-      description("unable to load graph")
-      display("could not load graph json: {}", msg)
-    }
-
-    LoadGraphParamsError(msg: String) {
-      description("unable to load graph params")
-      display("could not load graph params: {}", msg)
-    }
-  }
-  foreign_links {
-    Alloc(alloc::AllocErr);
-    GraphDeserialize(serde_json::Error);
-    ParseInt(num::ParseIntError);
-    ShapeError(ndarray::ShapeError);
-  }
-}
-
-impl From<alloc::LayoutErr> for Error {
-  fn from(_err: alloc::LayoutErr) -> Error {
-    Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
-  }
-}
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
deleted file mode 100644 (file)
index e17c669..0000000
+++ /dev/null
@@ -1,67 +0,0 @@
-//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`.
-//! It's mainly useful for compiling to WebAssembly and SGX,
-//! but also native if you prefer Rust to C++.
-//!
-//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`.
-//! Single-function modules are used via the `packed_func!` macro after obtaining
-//! the function from `runtime::SystemLibModule`
-//!
-//! The main entrypoints to this crate are `GraphExecutor`
-//! For examples of use, please refer to the multi-file tests in the `tests` directory.
-
-#![feature(
-  alloc,
-  allocator_api,
-  box_syntax,
-  fn_traits,
-  try_from,
-  unboxed_closures,
-  vec_remove_item
-)]
-
-#[cfg(target_env = "sgx")]
-extern crate alloc;
-extern crate bounded_spsc_queue;
-#[cfg(target_env = "sgx")]
-extern crate core;
-#[macro_use]
-extern crate error_chain;
-#[macro_use]
-extern crate itertools;
-#[macro_use]
-extern crate lazy_static;
-extern crate ndarray;
-#[macro_use]
-extern crate nom;
-#[cfg(not(target_env = "sgx"))]
-extern crate num_cpus;
-extern crate serde;
-#[macro_use]
-extern crate serde_derive;
-extern crate serde_json;
-
-pub mod ffi {
-  #![allow(
-    non_camel_case_types,
-    non_snake_case,
-    non_upper_case_globals,
-    unused
-  )]
-
-  pub mod runtime {
-    use std::os::raw::{c_char, c_int, c_void};
-
-    include!(concat!(
-      env!("CARGO_MANIFEST_DIR"),
-      "/src/runtime/c_runtime_api.rs"
-    ));
-
-    pub type BackendPackedCFunc =
-      extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
-  }
-}
-
-pub mod errors;
-pub mod runtime;
-
-pub use errors::*;
diff --git a/rust/src/runtime/allocator.rs b/rust/src/runtime/allocator.rs
deleted file mode 100644 (file)
index d704336..0000000
+++ /dev/null
@@ -1,52 +0,0 @@
-#[cfg(target_env = "sgx")]
-use alloc::alloc::{self, Layout};
-#[cfg(not(target_env = "sgx"))]
-use std::alloc::{self, Layout};
-
-use errors::*;
-
-const DEFAULT_ALIGN_BYTES: usize = 4;
-
-#[derive(PartialEq, Eq)]
-pub struct Allocation {
-  layout: Layout,
-  ptr: *mut u8,
-}
-
-impl Allocation {
-  /// Allocates a chunk of memory of `size` bytes with optional alignment.
-  pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
-    let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
-    let layout = Layout::from_size_align(size, alignment)?;
-    let ptr = unsafe { alloc::alloc(layout.clone()) };
-    if ptr.is_null() {
-      alloc::handle_alloc_error(layout);
-    }
-    Ok(Self {
-      ptr: ptr,
-      layout: layout,
-    })
-  }
-
-  pub fn as_mut_ptr(&self) -> *mut u8 {
-    self.ptr
-  }
-
-  /// Returns the size of the Allocation in bytes.
-  pub fn size(&self) -> usize {
-    self.layout.size()
-  }
-
-  /// Returns the byte alignment of the Allocation.
-  pub fn align(&self) -> usize {
-    self.layout.align()
-  }
-}
-
-impl Drop for Allocation {
-  fn drop(&mut self) {
-    unsafe {
-      alloc::dealloc(self.ptr, self.layout.clone());
-    }
-  }
-}
diff --git a/rust/src/runtime/array.rs b/rust/src/runtime/array.rs
deleted file mode 100644 (file)
index 100258d..0000000
+++ /dev/null
@@ -1,500 +0,0 @@
-use std::{
-  any::TypeId,
-  convert::TryFrom,
-  mem,
-  os::raw::{c_int, c_void},
-  ptr, slice,
-};
-
-use ndarray;
-
-use super::allocator::Allocation;
-use errors::*;
-use ffi::runtime::{
-  DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
-  DLDeviceType_kDLCPU, DLTensor,
-};
-
-/// A `Storage` is a container which holds `Tensor` data.
-#[derive(PartialEq)]
-pub enum Storage<'a> {
-  /// A `Storage` which owns its contained bytes.
-  Owned(Allocation),
-
-  /// A view of an existing `Storage`.
-  View(&'a mut [u8], usize), // ptr, align
-}
-
-impl<'a> Storage<'a> {
-  pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
-    Ok(Storage::Owned(Allocation::new(size, align)?))
-  }
-
-  pub fn as_mut_ptr(&self) -> *mut u8 {
-    match self {
-      Storage::Owned(alloc) => alloc.as_mut_ptr(),
-      Storage::View(slice, _) => slice.as_ptr() as *mut u8,
-    }
-  }
-
-  pub fn size(&self) -> usize {
-    match self {
-      Storage::Owned(alloc) => alloc.size(),
-      Storage::View(slice, _) => slice.len(),
-    }
-  }
-
-  pub fn align(&self) -> usize {
-    match self {
-      Storage::Owned(alloc) => alloc.align(),
-      Storage::View(_, align) => *align,
-    }
-  }
-
-  pub fn as_ptr(&self) -> *const u8 {
-    self.as_mut_ptr() as *const _
-  }
-
-  /// Returns a `Storage::View` which points to an owned `Storage::Owned`.
-  pub fn view(&self) -> Storage<'a> {
-    match self {
-      Storage::Owned(alloc) => Storage::View(
-        unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
-        self.align(),
-      ),
-      Storage::View(slice, _) => Storage::View(
-        unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
-        self.align(),
-      ),
-    }
-  }
-
-  pub fn is_owned(&self) -> bool {
-    match self {
-      Storage::Owned(_) => true,
-      _ => false,
-    }
-  }
-
-  /// Returns an owned version of this storage via cloning.
-  pub fn to_owned(&self) -> Storage<'static> {
-    let s = Storage::new(self.size(), Some(self.align())).unwrap();
-    unsafe {
-      s.as_mut_ptr()
-        .copy_from_nonoverlapping(self.as_ptr(), self.size())
-    }
-    s
-  }
-}
-
-impl<'a, T> From<&'a [T]> for Storage<'a> {
-  fn from(data: &'a [T]) -> Self {
-    let data = unsafe {
-      slice::from_raw_parts_mut(
-        data.as_ptr() as *const u8 as *mut u8,
-        data.len() * mem::size_of::<T>() as usize,
-      )
-    };
-    Storage::View(data, mem::align_of::<T>())
-  }
-}
-
-/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
-/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
-/// converted to `ndarray::Array` for non-TVM processing.
-///
-/// # Examples
-///
-/// ```
-/// extern crate ndarray;
-///
-/// let mut a_nd: ndarray::Array = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
-/// let mut a: Tensor = a_nd.into();
-/// let mut a_dl: DLTensor = (&mut t).into();
-/// call_packed!(tvm_fn, &mut a_dl);
-///
-/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
-/// let mut a_nd = ndarray::Array::try_from(&a).unwrap();
-/// ```
-#[derive(PartialEq)]
-pub struct Tensor<'a> {
-  /// The bytes which contain the data this `Tensor` represents.
-  pub(super) data: Storage<'a>,
-  pub(super) ctx: TVMContext,
-  pub(super) dtype: DataType,
-  pub(super) shape: Vec<i64>, // not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
-  /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
-  pub(super) strides: Option<Vec<usize>>,
-  pub(super) byte_offset: isize,
-  /// The number of elements in the `Tensor`.
-  pub(super) size: usize,
-}
-
-unsafe impl<'a> Send for Tensor<'a> {}
-
-impl<'a> Tensor<'a> {
-  pub fn shape(&self) -> Vec<i64> {
-    self.shape.clone()
-  }
-
-  /// Returns the data of this `Tensor` as a `Vec`.
-  ///
-  /// # Panics
-  ///
-  /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
-  pub fn to_vec<T: 'static>(&self) -> Vec<T> {
-    assert!(self.is_contiguous());
-    assert!(self.dtype.is_type::<T>());
-    let mut vec: Vec<T> = Vec::with_capacity(self.size * self.dtype.itemsize());
-    unsafe {
-      vec.as_mut_ptr().copy_from_nonoverlapping(
-        self.data.as_ptr().offset(self.byte_offset) as *const T,
-        self.size,
-      );
-      vec.set_len(self.size);
-    }
-    vec
-  }
-
-  /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
-  pub fn is_contiguous(&self) -> bool {
-    match self.strides {
-      None => true,
-      Some(ref strides) => {
-        // check that stride for each dimension is the product of all trailing dimensons' shapes
-        self
-          .shape
-          .iter()
-          .zip(strides)
-          .rfold(
-            (true, 1),
-            |(is_contig, expected_stride), (shape, stride)| {
-              (
-                is_contig && *stride == expected_stride,
-                expected_stride * (*shape as usize),
-              )
-            },
-          )
-          .0
-      }
-    }
-  }
-
-  /// Returns a clone of this `Tensor`.
-  ///
-  /// # Panics
-  ///
-  /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
-  pub fn copy(&mut self, other: &Tensor) {
-    assert!(
-      self.dtype == other.dtype && self.size == other.size,
-      "Tensor shape/dtype mismatch."
-    );
-    assert!(
-      self.is_contiguous() && other.is_contiguous(),
-      "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
-      self.strides,
-      other.strides
-    );
-    unsafe {
-      self
-        .data
-        .as_mut_ptr()
-        .offset(self.byte_offset as isize)
-        .copy_from_nonoverlapping(
-          other.data.as_mut_ptr().offset(other.byte_offset),
-          other.size * other.dtype.itemsize(),
-        );
-    }
-  }
-
-  /// Returns an owned version of this `Tensor` via cloning.
-  pub fn to_owned(&self) -> Tensor<'static> {
-    let t = Tensor {
-      data: self.data.to_owned(),
-      ctx: self.ctx.clone(),
-      dtype: self.dtype.clone(),
-      size: self.size.clone(),
-      shape: self.shape.clone(),
-      strides: None,
-      byte_offset: 0,
-    };
-    unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
-  }
-
-  fn from_array_storage<'s, T, D: ndarray::Dimension>(
-    arr: &ndarray::Array<T, D>,
-    storage: Storage<'s>,
-    type_code: usize,
-  ) -> Tensor<'s> {
-    let type_width = mem::size_of::<T>() as usize;
-    Tensor {
-      data: storage,
-      ctx: TVMContext::default(),
-      dtype: DataType {
-        code: type_code,
-        bits: 8 * type_width,
-        lanes: 1,
-      },
-      size: arr.len(),
-      shape: arr.shape().iter().map(|&v| v as i64).collect(),
-      strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
-      byte_offset: 0,
-    }
-  }
-}
-
-/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
-macro_rules! impl_ndarray_try_from_tensor {
-  ($type:ty, $dtype:expr) => {
-    impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
-      type Error = Error;
-      fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
-        ensure!(
-          tensor.dtype == $dtype,
-          "Cannot convert Tensor with dtype {:?} to ndarray",
-          tensor.dtype
-        );
-        Ok(ndarray::Array::from_shape_vec(
-          tensor
-            .shape
-            .iter()
-            .map(|s| *s as usize)
-            .collect::<Vec<usize>>(),
-          tensor.to_vec::<$type>(),
-        )?)
-      }
-    }
-  };
-}
-
-impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
-impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
-impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
-impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
-
-impl DLTensor {
-  pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
-    assert!(!flatten || tensor.is_contiguous());
-    Self {
-      data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
-      ctx: DLContext::from(&tensor.ctx),
-      ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
-      dtype: DLDataType::from(&tensor.dtype),
-      shape: if flatten {
-        &tensor.size as *const _ as *mut i64
-      } else {
-        tensor.shape.as_ptr()
-      } as *mut i64,
-      strides: if flatten || tensor.is_contiguous() {
-        ptr::null_mut()
-      } else {
-        tensor.strides.as_ref().unwrap().as_ptr()
-      } as *mut i64,
-      byte_offset: 0,
-    }
-  }
-}
-
-impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
-  fn from(tensor: &'a Tensor<'t>) -> Self {
-    DLTensor::from_tensor(tensor, false /* flatten */)
-  }
-}
-
-impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
-  fn from(tensor: &'a mut Tensor<'t>) -> Self {
-    DLTensor::from_tensor(tensor, false /* flatten */)
-  }
-}
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-pub struct DataType {
-  pub(super) code: usize,
-  pub(super) bits: usize,
-  pub(super) lanes: usize,
-}
-
-impl DataType {
-  /// Returns the number of bytes occupied by an element of this `DataType`.
-  pub fn itemsize(&self) -> usize {
-    (self.bits * self.lanes) >> 3
-  }
-
-  /// Returns whether this `DataType` represents primitive type `T`.
-  pub fn is_type<T: 'static>(&self) -> bool {
-    if self.lanes != 1 {
-      return false;
-    }
-    let typ = TypeId::of::<T>();
-    (typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
-      || (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
-      || (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
-      || (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
-      || (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
-      || (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
-  }
-}
-
-impl<'a> From<&'a DataType> for DLDataType {
-  fn from(dtype: &'a DataType) -> Self {
-    Self {
-      code: dtype.code as u8,
-      bits: dtype.bits as u8,
-      lanes: dtype.lanes as u16,
-    }
-  }
-}
-
-impl From<DLDataType> for DataType {
-  fn from(dtype: DLDataType) -> Self {
-    Self {
-      code: dtype.code as usize,
-      bits: dtype.bits as usize,
-      lanes: dtype.lanes as usize,
-    }
-  }
-}
-
-macro_rules! make_dtype_const {
-  ($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
-    const $name: DataType = DataType {
-      code: $code as usize,
-      bits: $bits,
-      lanes: $lanes,
-    };
-  };
-}
-
-make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
-make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
-// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
-make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
-make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
-
-impl Default for DLContext {
-  fn default() -> Self {
-    DLContext {
-      device_type: DLDeviceType_kDLCPU,
-      device_id: 0,
-    }
-  }
-}
-
-#[derive(Debug, Clone, Copy, PartialEq)]
-pub struct TVMContext {
-  pub(super) device_type: usize,
-  pub(super) device_id: usize,
-}
-
-impl<'a> From<&'a TVMContext> for DLContext {
-  fn from(ctx: &'a TVMContext) -> Self {
-    Self {
-      device_type: ctx.device_type as u32,
-      device_id: ctx.device_id as i32,
-    }
-  }
-}
-
-impl Default for TVMContext {
-  fn default() -> Self {
-    Self {
-      device_type: DLDeviceType_kDLCPU as usize,
-      device_id: 0,
-    }
-  }
-}
-
-impl<'a> From<DLTensor> for Tensor<'a> {
-  fn from(dlt: DLTensor) -> Self {
-    unsafe {
-      let dtype = DataType::from(dlt.dtype);
-      let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec();
-      let size = shape.iter().map(|v| *v as usize).product::<usize>() as usize;
-      let storage = Storage::from(slice::from_raw_parts(
-        dlt.data as *const u8,
-        dtype.itemsize() * size,
-      ));
-      Self {
-        data: storage,
-        ctx: TVMContext::default(),
-        dtype: dtype,
-        size: size,
-        shape: shape,
-        strides: if dlt.strides == ptr::null_mut() {
-          None
-        } else {
-          Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec())
-        },
-        byte_offset: dlt.byte_offset as isize,
-      }
-    }
-  }
-}
-
-/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
-///
-/// # Panics
-///
-/// Panics if the ndarray is not contiguous.
-macro_rules! impl_tensor_from_ndarray {
-  ($type:ty, $typecode:expr) => {
-    impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
-      fn from(arr: ndarray::Array<$type, D>) -> Self {
-        assert!(arr.is_standard_layout(), "Array must be contiguous.");
-        let size = arr.len() * mem::size_of::<$type>() as usize;
-        let storage =
-          Storage::from(unsafe { slice::from_raw_parts(arr.as_ptr() as *const u8, size) });
-        Tensor::from_array_storage(&arr, storage, $typecode as usize)
-      }
-    }
-    impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
-      fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
-        assert!(arr.is_standard_layout(), "Array must be contiguous.");
-        Tensor::from_array_storage(
-          arr,
-          Storage::from(arr.as_slice().unwrap()),
-          $typecode as usize,
-        )
-      }
-    }
-  };
-}
-
-/// `From` conversions to `DLTensor` for `ndarray::Array`.
-/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
-macro_rules! impl_dltensor_from_ndarray {
-  ($type:ty, $typecode:expr) => {
-    impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
-      fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
-        DLTensor {
-          data: arr.as_mut_ptr() as *mut c_void,
-          ctx: DLContext::default(),
-          ndim: arr.ndim() as c_int,
-          dtype: DLDataType {
-            code: $typecode as u8,
-            bits: 8 * mem::size_of::<$type>() as u8,
-            lanes: 1,
-          },
-          shape: arr.shape().as_ptr() as *const i64 as *mut i64,
-          strides: arr.strides().as_ptr() as *const isize as *mut i64,
-          byte_offset: 0,
-        }
-      }
-    }
-  };
-}
-
-impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
-impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
-impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
-impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
-impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
-impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
-
-impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
-impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
-impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
-impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
-impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
-impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
diff --git a/rust/src/runtime/c_runtime_api.rs b/rust/src/runtime/c_runtime_api.rs
deleted file mode 100644 (file)
index 6facf9c..0000000
+++ /dev/null
@@ -1,770 +0,0 @@
-/* automatically generated by rust-bindgen for TVM revision 6292c78 */
-
-pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0";
-pub const DLPACK_VERSION: u32 = 8;
-pub const _STDINT_H: u32 = 1;
-pub const _FEATURES_H: u32 = 1;
-pub const _DEFAULT_SOURCE: u32 = 1;
-pub const __USE_ISOC11: u32 = 1;
-pub const __USE_ISOC99: u32 = 1;
-pub const __USE_ISOC95: u32 = 1;
-pub const __USE_POSIX_IMPLICITLY: u32 = 1;
-pub const _POSIX_SOURCE: u32 = 1;
-pub const _POSIX_C_SOURCE: u32 = 200809;
-pub const __USE_POSIX: u32 = 1;
-pub const __USE_POSIX2: u32 = 1;
-pub const __USE_POSIX199309: u32 = 1;
-pub const __USE_POSIX199506: u32 = 1;
-pub const __USE_XOPEN2K: u32 = 1;
-pub const __USE_XOPEN2K8: u32 = 1;
-pub const _ATFILE_SOURCE: u32 = 1;
-pub const __USE_MISC: u32 = 1;
-pub const __USE_ATFILE: u32 = 1;
-pub const __USE_FORTIFY_LEVEL: u32 = 0;
-pub const _STDC_PREDEF_H: u32 = 1;
-pub const __STDC_IEC_559__: u32 = 1;
-pub const __STDC_IEC_559_COMPLEX__: u32 = 1;
-pub const __STDC_ISO_10646__: u32 = 201505;
-pub const __STDC_NO_THREADS__: u32 = 1;
-pub const __GNU_LIBRARY__: u32 = 6;
-pub const __GLIBC__: u32 = 2;
-pub const __GLIBC_MINOR__: u32 = 23;
-pub const _SYS_CDEFS_H: u32 = 1;
-pub const __WORDSIZE: u32 = 64;
-pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1;
-pub const __SYSCALL_WORDSIZE: u32 = 64;
-pub const _BITS_WCHAR_H: u32 = 1;
-pub const INT8_MIN: i32 = -128;
-pub const INT16_MIN: i32 = -32768;
-pub const INT32_MIN: i32 = -2147483648;
-pub const INT8_MAX: u32 = 127;
-pub const INT16_MAX: u32 = 32767;
-pub const INT32_MAX: u32 = 2147483647;
-pub const UINT8_MAX: u32 = 255;
-pub const UINT16_MAX: u32 = 65535;
-pub const UINT32_MAX: u32 = 4294967295;
-pub const INT_LEAST8_MIN: i32 = -128;
-pub const INT_LEAST16_MIN: i32 = -32768;
-pub const INT_LEAST32_MIN: i32 = -2147483648;
-pub const INT_LEAST8_MAX: u32 = 127;
-pub const INT_LEAST16_MAX: u32 = 32767;
-pub const INT_LEAST32_MAX: u32 = 2147483647;
-pub const UINT_LEAST8_MAX: u32 = 255;
-pub const UINT_LEAST16_MAX: u32 = 65535;
-pub const UINT_LEAST32_MAX: u32 = 4294967295;
-pub const INT_FAST8_MIN: i32 = -128;
-pub const INT_FAST16_MIN: i64 = -9223372036854775808;
-pub const INT_FAST32_MIN: i64 = -9223372036854775808;
-pub const INT_FAST8_MAX: u32 = 127;
-pub const INT_FAST16_MAX: u64 = 9223372036854775807;
-pub const INT_FAST32_MAX: u64 = 9223372036854775807;
-pub const UINT_FAST8_MAX: u32 = 255;
-pub const UINT_FAST16_MAX: i32 = -1;
-pub const UINT_FAST32_MAX: i32 = -1;
-pub const INTPTR_MIN: i64 = -9223372036854775808;
-pub const INTPTR_MAX: u64 = 9223372036854775807;
-pub const UINTPTR_MAX: i32 = -1;
-pub const PTRDIFF_MIN: i64 = -9223372036854775808;
-pub const PTRDIFF_MAX: u64 = 9223372036854775807;
-pub const SIG_ATOMIC_MIN: i32 = -2147483648;
-pub const SIG_ATOMIC_MAX: u32 = 2147483647;
-pub const SIZE_MAX: i32 = -1;
-pub const WINT_MIN: u32 = 0;
-pub const WINT_MAX: u32 = 4294967295;
-pub type int_least8_t = ::std::os::raw::c_schar;
-pub type int_least16_t = ::std::os::raw::c_short;
-pub type int_least32_t = ::std::os::raw::c_int;
-pub type int_least64_t = ::std::os::raw::c_long;
-pub type uint_least8_t = ::std::os::raw::c_uchar;
-pub type uint_least16_t = ::std::os::raw::c_ushort;
-pub type uint_least32_t = ::std::os::raw::c_uint;
-pub type uint_least64_t = ::std::os::raw::c_ulong;
-pub type int_fast8_t = ::std::os::raw::c_schar;
-pub type int_fast16_t = ::std::os::raw::c_long;
-pub type int_fast32_t = ::std::os::raw::c_long;
-pub type int_fast64_t = ::std::os::raw::c_long;
-pub type uint_fast8_t = ::std::os::raw::c_uchar;
-pub type uint_fast16_t = ::std::os::raw::c_ulong;
-pub type uint_fast32_t = ::std::os::raw::c_ulong;
-pub type uint_fast64_t = ::std::os::raw::c_ulong;
-pub type intmax_t = ::std::os::raw::c_long;
-pub type uintmax_t = ::std::os::raw::c_ulong;
-pub type wchar_t = ::std::os::raw::c_int;
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct max_align_t {
-  pub __clang_max_align_nonce1: ::std::os::raw::c_longlong,
-  pub __bindgen_padding_0: u64,
-  pub __clang_max_align_nonce2: f64,
-}
-pub const DLDeviceType_kDLCPU: DLDeviceType = 1;
-pub const DLDeviceType_kDLGPU: DLDeviceType = 2;
-pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3;
-pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4;
-pub const DLDeviceType_kDLMetal: DLDeviceType = 8;
-pub const DLDeviceType_kDLVPI: DLDeviceType = 9;
-pub const DLDeviceType_kDLROCM: DLDeviceType = 10;
-/// \brief The device type in DLContext.
-pub type DLDeviceType = u32;
-/// \brief A Device context for Tensor and operator.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLContext {
-  /// \brief The device type used in the device.
-  pub device_type: DLDeviceType,
-  /// \brief The device index
-  pub device_id: ::std::os::raw::c_int,
-}
-pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0;
-pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1;
-pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2;
-/// \brief The type code options DLDataType.
-pub type DLDataTypeCode = u32;
-/// \brief The data type the tensor can hold.
-///
-/// Examples
-/// - float: type_code = 2, bits = 32, lanes=1
-/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
-/// - int8: type_code = 0, bits = 8, lanes=1
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLDataType {
-  /// \brief Type code of base types.
-  /// We keep it uint8_t instead of DLDataTypeCode for minimal memory
-  /// footprint, but the value should be one of DLDataTypeCode enum values.
-  ///
-  pub code: u8,
-  /// \brief Number of bits, common choices are 8, 16, 32.
-  pub bits: u8,
-  /// \brief Number of lanes in the type, used for vector types.
-  pub lanes: u16,
-}
-/// \brief Plain C Tensor object, does not manage memory.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLTensor {
-  /// \brief The opaque data pointer points to the allocated data.
-  /// This will be CUDA device pointer or cl_mem handle in OpenCL.
-  /// This pointer is always aligns to 256 bytes as in CUDA.
-  pub data: *mut ::std::os::raw::c_void,
-  /// \brief The device context of the tensor
-  pub ctx: DLContext,
-  /// \brief Number of dimensions
-  pub ndim: ::std::os::raw::c_int,
-  /// \brief The data type of the pointer
-  pub dtype: DLDataType,
-  /// \brief The shape of the tensor
-  pub shape: *mut i64,
-  /// \brief strides of the tensor,
-  /// can be NULL, indicating tensor is compact.
-  pub strides: *mut i64,
-  /// \brief The offset in bytes to the beginning pointer to data
-  pub byte_offset: u64,
-}
-/// \brief C Tensor object, manage memory of DLTensor. This data structure is
-/// intended to faciliate the borrowing of DLTensor by another framework. It is
-/// not meant to transfer the tensor. When the borrowing framework doesn't need
-/// the tensor, it should call the deleter to notify the host that the resource
-/// is no longer needed.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLManagedTensor {
-  /// \brief DLTensor which is being memory managed
-  pub dl_tensor: DLTensor,
-  /// \brief the context of the original host framework of DLManagedTensor in
-  /// which DLManagedTensor is used in the framework. It can also be NULL.
-  pub manager_ctx: *mut ::std::os::raw::c_void,
-  /// \brief Destructor signature void (*)(void*) - this should be called
-  /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
-  /// if there is no way for the caller to provide a reasonable destructor.
-  pub deleter: ::std::option::Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)>,
-}
-/// \brief type of array index.
-pub type tvm_index_t = i64;
-pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5;
-pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6;
-pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7;
-pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11;
-pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12;
-/// \brief Extension device types in TVM
-pub type TVMDeviceExtType = u32;
-pub const TVMTypeCode_kHandle: TVMTypeCode = 3;
-pub const TVMTypeCode_kNull: TVMTypeCode = 4;
-pub const TVMTypeCode_kTVMType: TVMTypeCode = 5;
-pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6;
-pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7;
-pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8;
-pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9;
-pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10;
-pub const TVMTypeCode_kStr: TVMTypeCode = 11;
-pub const TVMTypeCode_kBytes: TVMTypeCode = 12;
-pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13;
-pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15;
-pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16;
-pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20;
-pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64;
-pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128;
-/// \brief The type code in TVMType
-/// \note TVMType is used in two places.
-pub type TVMTypeCode = u32;
-/// \brief The data type used in TVM Runtime.
-///
-/// Examples
-/// - float: type_code = 2, bits = 32, lanes=1
-/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
-/// - int8: type_code = 0, bits = 8, lanes=1
-///
-/// \note Arguments TVM API function always takes bits=64 and lanes=1
-pub type TVMType = DLDataType;
-/// \brief The Device information, abstract away common device types.
-pub type TVMContext = DLContext;
-/// \brief The tensor array stucture to TVM API.
-pub type TVMArray = DLTensor;
-/// \brief the array handle
-pub type TVMArrayHandle = *mut TVMArray;
-/// \brief Union type of values
-/// being passed through API and function calls.
-#[repr(C)]
-#[derive(Copy, Clone)]
-pub union TVMValue {
-  pub v_int64: i64,
-  pub v_float64: f64,
-  pub v_handle: *mut ::std::os::raw::c_void,
-  pub v_str: *const ::std::os::raw::c_char,
-  pub v_type: TVMType,
-  pub v_ctx: TVMContext,
-  _bindgen_union_align: u64,
-}
-/// \brief Byte array type used to pass in byte array
-/// When kBytes is used as data type.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct TVMByteArray {
-  pub data: *const ::std::os::raw::c_char,
-  pub size: usize,
-}
-/// \brief Handle to TVM runtime modules.
-pub type TVMModuleHandle = *mut ::std::os::raw::c_void;
-/// \brief Handle to packed function handle.
-pub type TVMFunctionHandle = *mut ::std::os::raw::c_void;
-/// \brief Handle to hold return value.
-pub type TVMRetValueHandle = *mut ::std::os::raw::c_void;
-/// \brief The stream that is specific to device
-/// can be NULL, which indicates the default one.
-pub type TVMStreamHandle = *mut ::std::os::raw::c_void;
-extern "C" {
-  /// \brief Used for implementing C API function.
-  /// Set last error message before return.
-  /// \param msg The error message to be set.
-  pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char);
-}
-extern "C" {
-  /// \brief return str message of the last error
-  /// all function in this file will return 0 when success
-  /// and -1 when an error occured,
-  /// TVMGetLastError can be called to retrieve the error
-  ///
-  /// this function is threadsafe and can be called by different thread
-  /// \return error info
-  pub fn TVMGetLastError() -> *const ::std::os::raw::c_char;
-}
-extern "C" {
-  /// \brief Load module from file.
-  /// \param file_name The file name to load the module from.
-  /// \param format The format of the module.
-  /// \param out The result module
-  ///
-  /// \return 0 when success, -1 when failure happens
-  /// \note The resulting module do not contain import relation.
-  /// It can be reconstructed by TVMModImport.
-  pub fn TVMModLoadFromFile(
-    file_name: *const ::std::os::raw::c_char,
-    format: *const ::std::os::raw::c_char,
-    out: *mut TVMModuleHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Add dep to mod's dependency.
-  /// This allows functions in this module to use modules.
-  ///
-  /// \param mod The module handle.
-  /// \param dep The dependent module to be imported.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Get function from the module.
-  /// \param mod The module handle.
-  /// \param func_name The name of the function.
-  /// \param query_imports Whether to query imported modules
-  /// \param out The result function, can be NULL if it is not available.
-  /// \return 0 when no error is thrown, -1 when failure happens
-  pub fn TVMModGetFunction(
-    mod_: TVMModuleHandle,
-    func_name: *const ::std::os::raw::c_char,
-    query_imports: ::std::os::raw::c_int,
-    out: *mut TVMFunctionHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Free front-end extension type resource.
-  /// \param handle The extension handle.
-  /// \param type_code The type of of the extension type.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMExtTypeFree(
-    handle: *mut ::std::os::raw::c_void,
-    type_code: ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Free the Module
-  /// \param mod The module to be freed.
-  ///
-  /// \note This may not free up the module's resources.
-  /// If there is active TVMFunctionHandle uses the module
-  /// Or if this module is imported by another active module.
-  ///
-  /// The all functions remains valid until TVMFuncFree is called.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Free the function when it is no longer needed.
-  /// \param func The function handle
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Call a Packed TVM Function.
-  ///
-  /// \param func node handle of the function.
-  /// \param arg_values The arguments
-  /// \param type_codes The type codes of the arguments
-  /// \param num_args Number of arguments.
-  ///
-  /// \param ret_val The return value.
-  /// \param ret_type_code the type code of return value.
-  ///
-  /// \return 0 when success, -1 when failure happens
-  /// \note TVM calls always exchanges with type bits=64, lanes=1
-  ///
-  /// \note API calls always exchanges with type bits=64, lanes=1
-  /// If API call returns container handles (e.g. FunctionHandle)
-  /// these handles should be managed by the front-end.
-  /// The front-end need to call free function (e.g. TVMFuncFree)
-  /// to free these handles.
-  pub fn TVMFuncCall(
-    func: TVMFunctionHandle,
-    arg_values: *mut TVMValue,
-    type_codes: *mut ::std::os::raw::c_int,
-    num_args: ::std::os::raw::c_int,
-    ret_val: *mut TVMValue,
-    ret_type_code: *mut ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Set the return value of TVMPackedCFunc.
-  ///
-  /// This function is called by TVMPackedCFunc to set the return value.
-  /// When this function is not called, the function returns null by default.
-  ///
-  /// \param ret The return value handle, pass by ret in TVMPackedCFunc
-  /// \param value The value to be returned.
-  /// \param type_code The type of the value to be returned.
-  /// \param num_ret Number of return values, for now only 1 is supported.
-  pub fn TVMCFuncSetReturn(
-    ret: TVMRetValueHandle,
-    value: *mut TVMValue,
-    type_code: *mut ::std::os::raw::c_int,
-    num_ret: ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Inplace translate callback argument value to return value.
-  /// This is only needed for non-POD arguments.
-  ///
-  /// \param value The value to be translated.
-  /// \param code The type code to be translated.
-  /// \note This function will do a shallow copy when necessary.
-  ///
-  /// \return 0 when success, -1 when failure happens.
-  pub fn TVMCbArgToReturn(
-    value: *mut TVMValue,
-    code: ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
-/// \brief C type of packed function.
-///
-/// \param args The arguments
-/// \param type_codes The type codes of the arguments
-/// \param num_args Number of arguments.
-/// \param ret The return value handle.
-/// \param resource_handle The handle additional resouce handle from fron-end.
-/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
-/// \sa TVMCFuncSetReturn
-pub type TVMPackedCFunc = ::std::option::Option<
-  unsafe extern "C" fn(
-    args: *mut TVMValue,
-    type_codes: *mut ::std::os::raw::c_int,
-    num_args: ::std::os::raw::c_int,
-    ret: TVMRetValueHandle,
-    resource_handle: *mut ::std::os::raw::c_void,
-  ) -> ::std::os::raw::c_int,
->;
-/// \brief C callback to free the resource handle in C packed function.
-/// \param resource_handle The handle additional resouce handle from fron-end.
-pub type TVMPackedCFuncFinalizer =
-  ::std::option::Option<unsafe extern "C" fn(resource_handle: *mut ::std::os::raw::c_void)>;
-/// \brief Signature for extension function declarer.
-///
-/// TVM call this function to get the extension functions
-/// The declarer will call register_func to register function and their name.
-///
-/// \param register_func_handle The register function
-/// \return 0 if success, -1 if failure happens
-pub type TVMExtensionFuncDeclarer = ::std::option::Option<
-  unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int,
->;
-extern "C" {
-  /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
-  ///
-  /// The resource_handle will be managed by TVM API, until the function is no longer used.
-  ///
-  /// \param func The packed C function.
-  /// \param resource_handle The resource handle from front-end, can be NULL.
-  /// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
-  /// \param out the result function handle.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMFuncCreateFromCFunc(
-    func: TVMPackedCFunc,
-    resource_handle: *mut ::std::os::raw::c_void,
-    fin: TVMPackedCFuncFinalizer,
-    out: *mut TVMFunctionHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Register the function to runtime's global table.
-  ///
-  /// The registered function then can be pulled by the backend by the name.
-  ///
-  /// \param name The name of the function.
-  /// \param f The function to be registered.
-  /// \param override Whether allow override already registered function.
-  pub fn TVMFuncRegisterGlobal(
-    name: *const ::std::os::raw::c_char,
-    f: TVMFunctionHandle,
-    override_: ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Get a global function.
-  ///
-  /// \param name The name of the function.
-  /// \param out the result function pointer, NULL if it does not exist.
-  ///
-  /// \note The function handle of global function is managed by TVM runtime,
-  /// So TVMFuncFree is should not be called when it get deleted.
-  pub fn TVMFuncGetGlobal(
-    name: *const ::std::os::raw::c_char,
-    out: *mut TVMFunctionHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief List all the globally registered function name
-  /// \param out_size The number of functions
-  /// \param out_array The array of function names.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMFuncListGlobalNames(
-    out_size: *mut ::std::os::raw::c_int,
-    out_array: *mut *mut *const ::std::os::raw::c_char,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Allocate a nd-array's memory,
-  /// including space of shape, of given spec.
-  ///
-  /// \param shape The shape of the array, the data content will be copied to out
-  /// \param ndim The number of dimension of the array.
-  /// \param dtype_code The type code of the dtype
-  /// \param dtype_bits The number of bits of dtype
-  /// \param dtype_lanes The number of lanes in the dtype.
-  /// \param device_type The device type of context
-  /// \param device_id The device id of context.
-  /// \param out The output handle.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayAlloc(
-    shape: *const tvm_index_t,
-    ndim: ::std::os::raw::c_int,
-    dtype_code: ::std::os::raw::c_int,
-    dtype_bits: ::std::os::raw::c_int,
-    dtype_lanes: ::std::os::raw::c_int,
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    out: *mut TVMArrayHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Free the TVM Array.
-  /// \param handle The array handle to be freed.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Copy array data from CPU byte array.
-  /// \param handle The array handle.
-  /// \param data the data pointer
-  /// \param nbytes The number of bytes to copy.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayCopyFromBytes(
-    handle: TVMArrayHandle,
-    data: *mut ::std::os::raw::c_void,
-    nbytes: usize,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Copy array data to CPU byte array.
-  /// \param handle The array handle.
-  /// \param data the data pointer
-  /// \param nbytes The number of bytes to copy.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayCopyToBytes(
-    handle: TVMArrayHandle,
-    data: *mut ::std::os::raw::c_void,
-    nbytes: usize,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Copy the array, both from and to must be valid during the copy.
-  /// \param from The array to be copied from.
-  /// \param to The target space.
-  /// \param stream The stream where the copy happens, can be NULL.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayCopyFromTo(
-    from: TVMArrayHandle,
-    to: TVMArrayHandle,
-    stream: TVMStreamHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Produce an array from the DLManagedTensor that shares data memory
-  /// with the DLManagedTensor.
-  /// \param from The source DLManagedTensor.
-  /// \param out The output array handle.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayFromDLPack(
-    from: *mut DLManagedTensor,
-    out: *mut TVMArrayHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Produce a DLMangedTensor from the array that shares data memory with
-  /// the array.
-  /// \param from The source array.
-  /// \param out The DLManagedTensor handle.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMArrayToDLPack(
-    from: TVMArrayHandle,
-    out: *mut *mut DLManagedTensor,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Delete (free) a DLManagedTensor's data.
-  /// \param dltensor Pointer to the DLManagedTensor.
-  pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor);
-}
-extern "C" {
-  /// \brief Create a new runtime stream.
-  ///
-  /// \param device_type The device type of context
-  /// \param device_id The device id of context
-  /// \param out The new stream handle
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMStreamCreate(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    out: *mut TVMStreamHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Free a created stream handle.
-  ///
-  /// \param device_type The device type of context
-  /// \param device_id The device id of context
-  /// \param stream The stream to be freed
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMStreamFree(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    stream: TVMStreamHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Set the runtime stream of current thread to be stream.
-  /// The subsequent calls to the same device_type
-  /// will use the setted stream handle.
-  /// The specific type of stream is runtime device dependent.
-  ///
-  /// \param device_type The device type of context
-  /// \param device_id The device id of context.
-  /// \param handle The stream handle.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMSetStream(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    handle: TVMStreamHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Wait until all computations on stream completes.
-  ///
-  /// \param device_type The device type of context
-  /// \param device_id The device id of context.
-  /// \param stream The stream to be synchronized.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMSynchronize(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    stream: TVMStreamHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Synchronize two streams of execution.
-  ///
-  /// \param device_type The device type of context
-  /// \param device_id The device id of context
-  /// \param src The source stream to synchronize.
-  /// \param dst The destination stream to synchronize.
-  /// \return 0 when success, -1 when failure happens
-  pub fn TVMStreamStreamSynchronize(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    src: TVMStreamHandle,
-    dst: TVMStreamHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Backend function for modules to get function
-  /// from its environment mod_node (its imports and global function).
-  /// The user do should not call TVMFuncFree on func.
-  ///
-  /// \param mod_node The module handle.
-  /// \param func_name The name of the function.
-  /// \param out The result function.
-  /// \return 0 when no error is thrown, -1 when failure happens
-  pub fn TVMBackendGetFuncFromEnv(
-    mod_node: *mut ::std::os::raw::c_void,
-    func_name: *const ::std::os::raw::c_char,
-    out: *mut TVMFunctionHandle,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Backend function to register system-wide library symbol.
-  ///
-  /// \param name The name of the symbol
-  /// \param ptr The symbol address.
-  /// \return 0 when no error is thrown, -1 when failure happens
-  pub fn TVMBackendRegisterSystemLibSymbol(
-    name: *const ::std::os::raw::c_char,
-    ptr: *mut ::std::os::raw::c_void,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Backend function to allocate temporal workspace.
-  ///
-  /// \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
-  ///
-  /// \param nbytes The size of the space requested.
-  /// \param device_type The device type which the space will be allocated.
-  /// \param device_id The device id which the space will be allocated.
-  /// \param dtype_code_hint The type code of the array elements. Only used in
-  /// certain backends such as OpenGL.
-  /// \param dtype_bits_hint The type bits of the array elements. Only used in
-  /// certain backends such as OpenGL.
-  /// \return nullptr when error is thrown, a valid ptr if success
-  pub fn TVMBackendAllocWorkspace(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    nbytes: u64,
-    dtype_code_hint: ::std::os::raw::c_int,
-    dtype_bits_hint: ::std::os::raw::c_int,
-  ) -> *mut ::std::os::raw::c_void;
-}
-extern "C" {
-  /// \brief Backend function to free temporal workspace.
-  ///
-  /// \param ptr The result allocated space pointer.
-  /// \param device_type The device type which the space will be allocated.
-  /// \param device_id The device id which the space will be allocated.
-  /// \return 0 when no error is thrown, -1 when failure happens
-  ///
-  /// \sa TVMBackendAllocWorkspace
-  pub fn TVMBackendFreeWorkspace(
-    device_type: ::std::os::raw::c_int,
-    device_id: ::std::os::raw::c_int,
-    ptr: *mut ::std::os::raw::c_void,
-  ) -> ::std::os::raw::c_int;
-}
-/// \brief Environment for TVM parallel task.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct TVMParallelGroupEnv {
-  /// \brief Auxiliary used for synchronization
-  pub sync_handle: *mut ::std::os::raw::c_void,
-  /// \brief total amount of task
-  pub num_task: i32,
-}
-/// \brief The callback function to execute a parallel lambda
-/// \param task_id the task id of the function.
-/// \param penv The parallel environment backs the execution.
-/// \param cdata The supporting closure data.
-pub type FTVMParallelLambda = ::std::option::Option<
-  unsafe extern "C" fn(
-    task_id: ::std::os::raw::c_int,
-    penv: *mut TVMParallelGroupEnv,
-    cdata: *mut ::std::os::raw::c_void,
-  ) -> ::std::os::raw::c_int,
->;
-extern "C" {
-  /// \brief Backend function for running parallel jobs.
-  ///
-  /// \param flambda The parallel function to be launched.
-  /// \param cdata The closure data.
-  /// \param num_task Number of tasks to launch, can be 0, means launch
-  /// with all available threads.
-  ///
-  /// \return 0 when no error is thrown, -1 when failure happens
-  pub fn TVMBackendParallelLaunch(
-    flambda: FTVMParallelLambda,
-    cdata: *mut ::std::os::raw::c_void,
-    num_task: ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief BSP barrrier between parallel threads
-  /// \param task_id the task id of the function.
-  /// \param penv The parallel environment backs the execution.
-  /// \return 0 when no error is thrown, -1 when failure happens
-  pub fn TVMBackendParallelBarrier(
-    task_id: ::std::os::raw::c_int,
-    penv: *mut TVMParallelGroupEnv,
-  ) -> ::std::os::raw::c_int;
-}
-extern "C" {
-  /// \brief Simple static initialization function.
-  /// Run f once and set handle to be not null.
-  /// This function is mainly used for test purpose.
-  ///
-  /// \param handle An global address to indicate f
-  /// \param f The function to be ran
-  /// \param cdata The closure data to pass to the function.
-  /// \param nbytes Number of bytes in the closure data.
-  /// \return 0 when no error is thrown, -1 when failure happens
-  pub fn TVMBackendRunOnce(
-    handle: *mut *mut ::std::os::raw::c_void,
-    f: ::std::option::Option<
-      unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int,
-    >,
-    cdata: *mut ::std::os::raw::c_void,
-    nbytes: ::std::os::raw::c_int,
-  ) -> ::std::os::raw::c_int;
-}
diff --git a/rust/src/runtime/graph.rs b/rust/src/runtime/graph.rs
deleted file mode 100644 (file)
index 08fbd59..0000000
+++ /dev/null
@@ -1,472 +0,0 @@
-use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
-
-use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
-use serde;
-use serde_json;
-
-use super::{DataType, Module, Storage, TVMArgValue, TVMContext, Tensor};
-use errors::{Error, ErrorKind, Result};
-use ffi::runtime::{
-  DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor,
-};
-
-// Magic number for NDArray file. @see `kTVMNDArrayMagic` in `ndarray.h`
-const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
-// Magic number for NDArray list file. @see `kTVMNDArrayListMagic` in `graph_runtime.h`
-const _NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7;
-
-/// A TVM computation graph.
-///
-/// # Examples
-///
-/// ```
-/// let graph_json = fs::read_to_string("graph.json")).unwrap();
-/// let graph = Graph::try_from(&graph_json).unwrap();
-/// ```
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Graph {
-  pub nodes: Vec<Node>,
-  pub arg_nodes: Vec<usize>,
-  pub heads: Vec<Entry>,
-  pub node_row_ptr: Option<Vec<usize>>,
-  pub attrs: Option<HashMap<String, serde_json::Value>>,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Entry {
-  pub id: usize,
-  pub index: usize,
-  pub version: usize,
-}
-
-impl Graph {
-  fn entry_index(&self, entry: &Entry) -> Result<usize> {
-    self
-      .node_row_ptr
-      .as_ref()
-      .map(|nrp| nrp[entry.id] + entry.index)
-      .ok_or("Missing node_row_ptr.".into())
-  }
-
-  /// Attempt to deserialize a JSON attribute to a type `T`.
-  fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
-    Ok(serde_json::from_value::<T>(
-      self
-        .attrs
-        .as_ref()
-        .ok_or(ErrorKind::GraphFormatError(
-          "Missing graph attrs".to_string(),
-        ))?
-        .get(attr)
-        .ok_or(ErrorKind::GraphFormatError(format!(
-          "Missing {} attr",
-          attr
-        )))?
-        .to_owned(),
-    )?)
-  }
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Node {
-  pub op: String,
-  pub name: String,
-  pub inputs: Vec<Entry>,
-  pub attrs: Option<HashMap<String, String>>,
-  pub control_deps: Option<Vec<Entry>>,
-}
-
-struct NodeAttrs {
-  func_name: String,
-  num_outputs: usize,
-  flatten_data: bool,
-}
-
-impl Node {
-  fn parse_attrs(&self) -> Result<NodeAttrs> {
-    let attrs = self
-      .attrs
-      .as_ref()
-      .ok_or(format!("Missing node.attrs for `{}`", self.name))?;
-    let func_name = attrs
-      .get("func_name")
-      .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
-      .to_string();
-    let num_outputs = attrs
-      .get("num_outputs")
-      .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
-      .parse::<usize>()?;
-    let flatten_data = attrs
-      .get("flatten_data")
-      .ok_or(format!(
-        "Node `{}` is missing attrs.flatten_data",
-        self.name
-      ))?
-      .parse::<u8>()?
-      == 1;
-    Ok(NodeAttrs {
-      func_name,
-      num_outputs,
-      flatten_data,
-    })
-  }
-}
-
-impl<'a> TryFrom<&'a String> for Graph {
-  type Error = Error;
-  fn try_from(graph_json: &String) -> Result<Self> {
-    let graph = serde_json::from_str(graph_json)?;
-    Ok(graph)
-  }
-}
-
-impl<'a> TryFrom<&'a str> for Graph {
-  type Error = Error;
-  fn try_from(graph_json: &'a str) -> Result<Self> {
-    let graph = serde_json::from_str(graph_json)?;
-    Ok(graph)
-  }
-}
-
-/// A executor for a TVM computation graph.
-///
-/// # Examples
-///
-/// ```
-/// use ndarray::Array;
-///
-/// let syslib = SystemLibModule::default(); // a provider of TVM functions
-///
-/// let mut params_bytes = Vec::new();
-/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
-/// let params = tvm::runtime::load_param_dict(&params_bytes).unwrap();
-///
-/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap();
-///
-/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
-/// exec.load_params(params);
-///
-/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
-/// exec.set_input("data", x.into());
-/// exec.run();
-/// let output = exec.get_output(0).unwrap();
-///
-/// println!("{:#?}", Array::try_from(output).unwrap());
-/// ```
-pub struct GraphExecutor<'m, 't> {
-  graph: Graph,
-  op_execs: Vec<Box<Fn() + 'm>>,
-  tensors: Vec<Tensor<'t>>,
-}
-
-unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
-
-impl<'m, 't> GraphExecutor<'m, 't> {
-  pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
-    let tensors = Self::setup_storages(&graph)?;
-    Ok(GraphExecutor {
-      op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
-      tensors: tensors,
-      graph: graph,
-    })
-  }
-
-  /// Runs the computation graph.
-  pub fn run(&self) {
-    self.op_execs.iter().for_each(|op_exec| {
-      op_exec();
-    });
-  }
-
-  /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
-  fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>> {
-    let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
-    let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
-    let dtypes = graph
-      .get_attr::<(String, Vec<String>)>("dltype")?
-      .1
-      .iter()
-      .map(|dltype| {
-        if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
-          Ok(dtype)
-        } else {
-          Err(ErrorKind::GraphFormatError(format!("Invalid dltype: {}", dltype).to_string()).into())
-        }
-      })
-      .collect::<Result<Vec<DataType>>>()?;
-
-    let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
-    let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
-    for (i, &storage_id) in storage_ids.iter().enumerate() {
-      let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
-      let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
-      storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
-    }
-
-    let mut storages: Vec<Storage> = storage_num_bytes
-      .into_iter()
-      .map(|nbytes| Storage::new(nbytes, align))
-      .collect::<Result<Vec<Storage>>>()?;
-
-    let tensors = izip!(storage_ids, shapes, dtypes)
-      .map(|(storage_id, shape, dtype)| {
-        let storage = storages[storage_id].view();
-        Tensor {
-          data: mem::replace(&mut storages[storage_id], storage),
-          ctx: TVMContext::default(),
-          dtype: dtype,
-          size: shape.iter().product::<i64>() as usize,
-          shape: shape,
-          strides: None,
-          byte_offset: 0,
-        }
-      })
-      .collect();
-
-    Ok(tensors)
-  }
-
-  /// Creates closures which represent the computation performed by this graph.
-  fn setup_op_execs<M: 'm + Module>(
-    graph: &Graph,
-    lib: &'m M,
-    tensors: &Vec<Tensor<'t>>,
-  ) -> Result<Vec<Box<Fn() + 'm>>> {
-    ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
-    let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
-
-    let mut op_execs = Vec::new();
-    for (i, node) in graph.nodes.iter().enumerate() {
-      if node.op == "null" {
-        continue;
-      }
-      ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
-      ensure!(node.attrs.is_some(), "Missing node attrs.");
-
-      let attrs = node.parse_attrs()?;
-
-      if attrs.func_name == "__nop" {
-        continue;
-      }
-
-      let func = lib
-        .get_function(&attrs.func_name)
-        .ok_or(format!("Missing function {}", attrs.func_name))?;
-      let arg_indices = node
-        .inputs
-        .iter()
-        .map(|entry| graph.entry_index(entry))
-        .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi)));
-
-      let dl_tensors = arg_indices
-        .map(|idx| {
-          let tensor = &tensors[idx?];
-          Ok(if attrs.flatten_data {
-            DLTensor::from_tensor(tensor, true /* flatten */)
-          } else {
-            DLTensor::from(tensor)
-          })
-        })
-        .collect::<Result<Vec<DLTensor>>>()
-        .unwrap();
-      let op: Box<Fn()> = box move || {
-        let args = dl_tensors
-          .iter()
-          .map(|t| t.into())
-          .collect::<Vec<TVMArgValue>>();
-        func(args.as_slice());
-      };
-      op_execs.push(op);
-    }
-    Ok(op_execs)
-  }
-
-  pub fn load_params(&mut self, params: HashMap<String, Tensor<'t>>) {
-    params.into_iter().for_each(|(name, param)| {
-      self.set_input(name, param);
-    })
-  }
-
-  pub fn set_input<S: AsRef<str>>(&mut self, name: S, value: Tensor<'t>) {
-    if let Some(idx) = self.get_input_index(name.as_ref()) {
-      // TODO: consider `new_with_params` to avoid ever allocating
-      let ptr = self.tensors[idx].data.as_ptr();
-      let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr);
-      let mut owner = to_replace.nth(0).unwrap();
-      if value.data.is_owned() {
-        // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
-        // mem::replace(&mut (*owner), value);
-        // to_replace.for_each(|t| {
-        //   panic!("replacing");
-        //   t.data = owner.data.view();
-        // });
-        owner.copy(&value);
-      } else {
-        owner.copy(&value);
-      }
-    } else {
-      println!("Unexpected input `{}`", name.as_ref());
-    }
-  }
-
-  /// Returns the graph input with name `name`, if it exists.
-  pub fn get_input<S: AsRef<str>>(&mut self, name: S) -> Option<&Tensor> {
-    self
-      .get_input_index(name.as_ref())
-      .and_then(move |idx| Some(&self.tensors[idx]))
-  }
-
-  /// Returns the graph output with index `index`, if it exists.
-  pub fn get_output(&self, idx: usize) -> Option<&Tensor> {
-    let graph = &self.graph;
-    graph.heads.get(idx).and_then(|entry| {
-      graph
-        .entry_index(entry)
-        .map(|idx| self.tensors.get(idx))
-        .unwrap_or(None)
-    })
-  }
-
-  /// Returns the index for graph input with name `name`, if it exists.
-  pub fn get_input_index<S: AsRef<str>>(&self, name: S) -> Option<usize> {
-    let graph = &self.graph;
-    (0..graph.nodes.len())
-      .skip_while(|&i| graph.nodes[i].name != name.as_ref())
-      .nth(0)
-      .and_then(|i| {
-        if graph.arg_nodes.iter().any(|&id| id == i) {
-          graph.node_row_ptr.as_ref().map(|nrp| nrp[i])
-        } else {
-          None
-        }
-      })
-  }
-}
-
-/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
-named!(
-  tvm_str_to_type<CompleteStr, DataType>,
-  do_parse!(
-    type_name: alpha1 >>
-    bits: digit1 >>
-    lanes: opt!(tuple!(tag!("x"), digit1)) >>
-    (DataType {
-      code: match type_name {
-        CompleteStr("int") => DLDataTypeCode_kDLInt,
-        CompleteStr("uint") => DLDataTypeCode_kDLUInt,
-        CompleteStr("float") => DLDataTypeCode_kDLFloat,
-        _ => DLDataTypeCode_kDLFloat,
-      } as usize,
-      bits: bits.parse::<u8>().unwrap() as usize,
-      lanes: match lanes {
-        Some(lanes) => lanes.1.parse::<u16>().unwrap() as usize,
-        None => 1,
-      },
-    })
-  )
-);
-
-/// Converts a bytes to String.
-named!(
-  name<String>,
-  map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
-    b.to_vec()
-  ))
-);
-
-/// Parses a TVMContext
-named!(
-  tvm_ctx<&[u8], TVMContext>,
-  do_parse!(
-    device_type: le_u32 >>
-    device_id: le_i32 >>
-    (TVMContext { device_type: device_type as usize, device_id: device_id as usize })
-  )
-);
-
-/// Parses a DataType
-named!(
-  data_type<&[u8], DataType>,
-  do_parse!(
-    code: le_u8 >>
-    bits: le_u8 >>
-    lanes: le_u16 >>
-    (DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize })
-  )
-);
-
-/// Parses a Tensor from a TVM array file.
-named!(
-  tensor<Tensor>,
-  do_parse!(
-    take!(8)
-      >> bits!(tag_bits!(u64, 64, 0))
-      >> ctx: tvm_ctx
-      >> ndim: le_u32
-      >> dtype: data_type
-      >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize)
-      >> length: le_i64
-      >> data: take!(length)
-      >> (Tensor {
-        data: Storage::from(data),
-        ctx: ctx,
-        dtype: dtype,
-        size: shape.iter().product::<i64>() as usize,
-        shape: shape,
-        strides: None,
-        byte_offset: 0,
-      })
-  )
-);
-
-/// Parses a graph params dict from a params binary file.
-named!(
-  parse_param_dict<HashMap<String, Tensor>>,
-  do_parse!(
-    take!(8)
-      >> bits!(tag_bits!(u64, 64, 0))
-      >> names: length_count!(le_u64, name)
-      >> tensors: length_count!(le_u64, tensor)
-      >> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter())))
-  )
-);
-
-/// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
-pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> {
-  if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
-    if remaining_bytes.len() > 0 {
-      bail!(ErrorKind::LoadGraphParamsError("extra input".to_string()))
-    } else {
-      Ok(param_dict)
-    }
-  } else {
-    bail!(ErrorKind::LoadGraphParamsError(
-      "invalid parameters file".to_string()
-    ))
-  }
-}
-
-#[cfg(test)]
-mod tests {
-  use super::*;
-
-  #[test]
-  fn test_str_to_type() {
-    assert_eq!(
-      tvm_str_to_type(CompleteStr("float24")).unwrap().1,
-      DataType {
-        code: DLDataTypeCode_kDLFloat as usize,
-        bits: 24,
-        lanes: 1
-      }
-    );
-    assert_eq!(
-      tvm_str_to_type(CompleteStr("uint111x44")).unwrap().1,
-      DataType {
-        code: DLDataTypeCode_kDLUInt as usize,
-        bits: 111,
-        lanes: 44
-      }
-    );
-  }
-}
diff --git a/rust/src/runtime/mod.rs b/rust/src/runtime/mod.rs
deleted file mode 100644 (file)
index 1a9c5ba..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-mod allocator;
-mod array;
-mod module;
-#[macro_use]
-mod packed_func;
-mod graph;
-#[cfg(target_env = "sgx")]
-#[macro_use]
-pub mod sgx;
-mod threading;
-mod workspace;
-
-use std::os::raw::c_char;
-
-pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*};
-
-#[cfg(target_env = "sgx")]
-use self::sgx::ocall_packed_func;
-
-#[no_mangle]
-pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) {
-  #[cfg(not(target_env = "sgx"))]
-  unsafe {
-    panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
-  }
-  #[cfg(target_env = "sgx")]
-  ocall_packed!("__sgx_set_last_error__", cmsg);
-}
diff --git a/rust/src/runtime/module.rs b/rust/src/runtime/module.rs
deleted file mode 100644 (file)
index 2594756..0000000
+++ /dev/null
@@ -1,46 +0,0 @@
-use std::{
-  collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
-};
-
-use ffi::runtime::BackendPackedCFunc;
-use runtime::packed_func::{wrap_backend_packed_func, PackedFunc};
-
-pub trait Module {
-  fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
-}
-
-pub struct SystemLibModule;
-
-lazy_static! {
-  static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
-    Mutex::new(HashMap::new());
-}
-
-impl Module for SystemLibModule {
-  fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
-    SYSTEM_LIB_FUNCTIONS
-      .lock()
-      .unwrap()
-      .get(name.as_ref())
-      .map(|func| wrap_backend_packed_func(func.to_owned()))
-  }
-}
-
-impl Default for SystemLibModule {
-  fn default() -> Self {
-    SystemLibModule {}
-  }
-}
-
-#[no_mangle]
-pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
-  cname: *const c_char,
-  func: BackendPackedCFunc,
-) -> i32 {
-  let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
-  SYSTEM_LIB_FUNCTIONS
-    .lock()
-    .unwrap()
-    .insert(name.to_string(), func);
-  return 0;
-}
diff --git a/rust/src/runtime/packed_func.rs b/rust/src/runtime/packed_func.rs
deleted file mode 100644 (file)
index a6ad7fc..0000000
+++ /dev/null
@@ -1,342 +0,0 @@
-use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
-
-use super::Tensor;
-use ffi::runtime::{
-  BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor,
-  TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMTypeCode_kNDArrayContainer, TVMValue,
-};
-
-use errors::*;
-
-pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
-
-/// Calls a packed function and returns a `TVMRetValue`.
-///
-/// # Example
-///
-/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
-#[macro_export]
-macro_rules! call_packed {
-  ($fn:expr, $($args:expr),+) => {
-    $fn(&[$($args.into(),)+])
-  };
-  ($fn:expr) => {
-    $fn(&Vec::new())
-  };
-}
-
-/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
-/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
-#[derive(Clone, Copy)]
-pub struct TVMArgValue<'a> {
-  _lifetime: PhantomData<&'a ()>,
-  pub(crate) value: TVMValue,
-  pub(crate) type_code: i64,
-}
-
-impl<'a> TVMArgValue<'a> {
-  pub fn new(value: TVMValue, type_code: i64) -> Self {
-    TVMArgValue {
-      _lifetime: PhantomData,
-      value: value,
-      type_code: type_code,
-    }
-  }
-}
-
-/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
-macro_rules! impl_prim_tvm_arg {
-  ($type:ty, $field:ident, $code:expr, $as:ty) => {
-    impl<'a> From<$type> for TVMArgValue<'a> {
-      fn from(val: $type) -> Self {
-        TVMArgValue {
-          value: TVMValue { $field: val as $as },
-          type_code: $code as i64,
-          _lifetime: PhantomData,
-        }
-      }
-    }
-    impl<'a> TryFrom<TVMArgValue<'a>> for $type {
-      type Error = Error;
-      fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
-        ensure!(
-          val.type_code == $code as i64,
-          "Could not downcast arg. Expected `{}`, got `{}`",
-          $code,
-          val.type_code
-        );
-        Ok(unsafe { val.value.$field as $type })
-      }
-    }
-  };
-  ($type:ty, $field:ident, $code:expr) => {
-    impl_prim_tvm_arg!($type, $field, $code, $type);
-  };
-  ($type:ty,v_int64) => {
-    impl_prim_tvm_arg!($type, v_int64, DLDataTypeCode_kDLInt, i64);
-  };
-  ($type:ty,v_float64) => {
-    impl_prim_tvm_arg!($type, v_float64, DLDataTypeCode_kDLFloat, f64);
-  };
-}
-
-impl_prim_tvm_arg!(f32, v_float64);
-impl_prim_tvm_arg!(f64, v_float64);
-impl_prim_tvm_arg!(i8, v_int64);
-impl_prim_tvm_arg!(u8, v_int64);
-impl_prim_tvm_arg!(i32, v_int64);
-impl_prim_tvm_arg!(u32, v_int64);
-impl_prim_tvm_arg!(i64, v_int64);
-impl_prim_tvm_arg!(u64, v_int64);
-
-/// Creates a conversion to a `TVMArgValue` for an object handle.
-impl<'a, T> From<*const T> for TVMArgValue<'a> {
-  fn from(ptr: *const T) -> Self {
-    TVMArgValue {
-      value: TVMValue {
-        v_handle: ptr as *mut T as *mut c_void,
-      },
-      type_code: TVMTypeCode_kArrayHandle as i64,
-      _lifetime: PhantomData,
-    }
-  }
-}
-
-/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
-impl<'a, T> From<*mut T> for TVMArgValue<'a> {
-  fn from(ptr: *mut T) -> Self {
-    TVMArgValue {
-      value: TVMValue {
-        v_handle: ptr as *mut c_void,
-      },
-      type_code: TVMTypeCode_kHandle as i64,
-      _lifetime: PhantomData,
-    }
-  }
-}
-
-impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
-  fn from(arr: &'a mut DLTensor) -> Self {
-    TVMArgValue {
-      value: TVMValue {
-        v_handle: arr as *mut _ as *mut c_void,
-      },
-      type_code: TVMTypeCode_kArrayHandle as i64,
-      _lifetime: PhantomData,
-    }
-  }
-}
-
-impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
-  fn from(arr: &'a DLTensor) -> Self {
-    TVMArgValue {
-      value: TVMValue {
-        v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
-      },
-      type_code: TVMTypeCode_kArrayHandle as i64,
-      _lifetime: PhantomData,
-    }
-  }
-}
-
-impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
-  type Error = Error;
-  fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
-    ensure!(
-      val.type_code == TVMTypeCode_kArrayHandle as i64
-        || val.type_code == TVMTypeCode_kNDArrayContainer as i64,
-      "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
-      TVMTypeCode_kArrayHandle,
-      TVMTypeCode_kNDArrayContainer,
-      val.type_code,
-    );
-
-    let dlt = unsafe { *(val.value.v_handle as *mut DLTensor as *const DLTensor) };
-    Ok(dlt.into())
-  }
-}
-
-/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
-/// Can be downcasted using `try_from` if it contains the desired type.
-///
-/// # Example
-///
-/// ```
-/// let a = 42u32;
-/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
-///
-/// let s = "hello, world!";
-/// let t: TVMRetValue = s.into();
-/// assert_eq!(String::try_from(t).unwrap(), s);
-/// ```
-pub struct TVMRetValue {
-  /// A primitive return value, if any.
-  prim_value: u64,
-  /// An object return value, if any.
-  box_value: Box<Any>,
-  /// The DLDataTypeCode which determines whether `prim_value` or `box_value` is in use.
-  type_code: i64,
-}
-
-#[cfg(target_env = "sgx")]
-impl TVMRetValue {
-  pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
-    unsafe {
-      Self {
-        prim_value: match type_code {
-          0 | 1 => value.v_int64 as u64,
-          2 => value.v_float64 as u64,
-          3 | 7 | 8 | 9 | 10 => value.v_handle as u64,
-          11 | 12 => value.v_str as u64,
-          _ => 0,
-        } as u64,
-        box_value: box (),
-        type_code: type_code,
-      }
-    }
-  }
-
-  pub fn into_tvm_value(self) -> (TVMValue, i64) {
-    let val = match self.type_code {
-      0 | 1 => TVMValue {
-        v_int64: self.prim_value.clone() as i64,
-      },
-      2 => TVMValue {
-        v_float64: self.prim_value.clone() as f64,
-      },
-      3 | 7 | 8 | 9 | 10 | 13 => TVMValue {
-        v_handle: Box::into_raw(self.box_value) as *mut c_void,
-      },
-      11 | 12 => TVMValue {
-        v_str: Box::into_raw(self.box_value) as *const _,
-      },
-      _ => unreachable!(),
-    };
-    (val, self.type_code)
-  }
-}
-
-impl Default for TVMRetValue {
-  fn default() -> Self {
-    TVMRetValue {
-      prim_value: 0,
-      box_value: box (),
-      type_code: 0,
-    }
-  }
-}
-
-macro_rules! impl_prim_ret_value {
-  ($type:ty, $code:expr) => {
-    impl From<$type> for TVMRetValue {
-      fn from(val: $type) -> Self {
-        TVMRetValue {
-          prim_value: val as u64,
-          box_value: box (),
-          type_code: $code,
-        }
-      }
-    }
-    impl TryFrom<TVMRetValue> for $type {
-      type Error = Error;
-      fn try_from(ret: TVMRetValue) -> Result<$type> {
-        if ret.type_code == $code {
-          Ok(ret.prim_value as $type)
-        } else {
-          bail!(ErrorKind::TryFromTVMRetValueError(
-            stringify!($type).to_string(),
-            ret.type_code
-          ))
-        }
-      }
-    }
-  };
-}
-
-macro_rules! impl_boxed_ret_value {
-  ($type:ty, $code:expr) => {
-    impl From<$type> for TVMRetValue {
-      fn from(val: $type) -> Self {
-        TVMRetValue {
-          prim_value: 0,
-          box_value: box val,
-          type_code: $code,
-        }
-      }
-    }
-    impl TryFrom<TVMRetValue> for $type {
-      type Error = Error;
-      fn try_from(ret: TVMRetValue) -> Result<$type> {
-        if let Ok(val) = ret.box_value.downcast::<$type>() {
-          Ok(*val)
-        } else {
-          bail!(ErrorKind::TryFromTVMRetValueError(
-            stringify!($type).to_string(),
-            ret.type_code
-          ))
-        }
-      }
-    }
-  };
-}
-
-impl_prim_ret_value!(i8, 0);
-impl_prim_ret_value!(u8, 1);
-impl_prim_ret_value!(i16, 0);
-impl_prim_ret_value!(u16, 1);
-impl_prim_ret_value!(i32, 0);
-impl_prim_ret_value!(u32, 1);
-impl_prim_ret_value!(f32, 2);
-impl_prim_ret_value!(i64, 0);
-impl_prim_ret_value!(u64, 1);
-impl_prim_ret_value!(f64, 2);
-impl_prim_ret_value!(isize, 0);
-impl_prim_ret_value!(usize, 1);
-impl_boxed_ret_value!(String, 11);
-
-impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
-  fn from(val: &'t Tensor<'a>) -> Self {
-    TVMRetValue {
-      prim_value: 0,
-      box_value: box DLTensor::from(val),
-      type_code: TVMTypeCode_kNDArrayContainer as i64,
-    }
-  }
-}
-
-impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
-  type Error = Error;
-  fn try_from(ret: TVMRetValue) -> Result<Self> {
-    ensure!(
-      ret.type_code == TVMTypeCode_kArrayHandle as i64
-        || ret.type_code == TVMTypeCode_kNDArrayContainer as i64,
-      "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
-      TVMTypeCode_kArrayHandle,
-      TVMTypeCode_kNDArrayContainer,
-      ret.type_code,
-    );
-
-    let dlt = unsafe { *(ret.prim_value as *mut DLTensor as *const DLTensor) };
-    Ok(dlt.into())
-  }
-}
-
-// @see `WrapPackedFunc` in `llvm_module.cc`.
-pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
-  box move |args: &[TVMArgValue]| {
-    func(
-      args
-        .iter()
-        .map(|ref arg| arg.value)
-        .collect::<Vec<TVMValue>>()
-        .as_ptr(),
-      args
-        .iter()
-        .map(|ref arg| arg.type_code as i32)
-        .collect::<Vec<i32>>()
-        .as_ptr() as *const i32,
-      args.len() as i32,
-    );
-    TVMRetValue::default()
-  }
-}
diff --git a/rust/src/runtime/sgx.rs b/rust/src/runtime/sgx.rs
deleted file mode 100644 (file)
index 00be3ee..0000000
+++ /dev/null
@@ -1,82 +0,0 @@
-use std::{
-  ffi::CString,
-  os::raw::{c_char, c_int},
-};
-
-use errors::Result;
-use ffi::runtime::TVMValue;
-use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
-
-pub use runtime::threading::tvm_run_worker as run_worker;
-
-#[macro_export]
-macro_rules! tvm_ocall {
-  ($func: expr) => {
-    match $func {
-      0 => Ok(()),
-      err => Err(format!("SGX error: {}", err)),
-    }
-  };
-}
-
-pub type SgxStatus = u32;
-
-#[cfg(target_env = "sgx")]
-extern "C" {
-  fn tvm_ocall_packed_func(
-    name: *const c_char,
-    arg_values: *const TVMValue,
-    type_codes: *const c_int,
-    num_args: c_int,
-    ret_val: *mut TVMValue,
-    ret_type_code: *mut c_int,
-  ) -> SgxStatus;
-}
-
-pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
-  let mut ret_val = TVMValue { v_int64: 0 };
-  let ret_type_code = 0i64;
-  unsafe {
-    tvm_ocall!(tvm_ocall_packed_func(
-      CString::new(fn_name.as_ref()).unwrap().as_ptr(),
-      args
-        .iter()
-        .map(|ref arg| arg.value)
-        .collect::<Vec<TVMValue>>()
-        .as_ptr(),
-      args
-        .iter()
-        .map(|ref arg| arg.type_code as i32)
-        .collect::<Vec<i32>>()
-        .as_ptr() as *const i32,
-      args.len() as i32,
-      &mut ret_val as *mut TVMValue,
-      &mut (ret_type_code as i32) as *mut c_int,
-    ))?;
-  }
-  Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64))
-}
-
-#[macro_export]
-macro_rules! ocall_packed {
-  ($fn_name:expr, $($args:expr),+) => {
-    ocall_packed_func($fn_name, &[$($args.into(),)+])
-      .expect(concat!("Error calling `", $fn_name, "`"))
-  };
-  ($fn_name:expr) => {
-    ocall_packed_func($fn_name, &Vec::new())
-      .expect(concat!("Error calling `", $fn_name, "`"))
-  }
-}
-
-pub fn shutdown() {
-  if env!("TVM_NUM_THREADS") != "0" {
-    sgx_join_threads()
-  }
-}
-
-impl Drop for SystemLibModule {
-  fn drop(&mut self) {
-    shutdown()
-  }
-}
diff --git a/rust/src/runtime/threading.rs b/rust/src/runtime/threading.rs
deleted file mode 100644 (file)
index 1d6d7fc..0000000
+++ /dev/null
@@ -1,337 +0,0 @@
-use std::{
-  os::raw::{c_int, c_void},
-  sync::{
-    atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT},
-    Arc, Barrier,
-  },
-};
-
-#[cfg(not(target_env = "sgx"))]
-use num_cpus;
-#[cfg(not(target_env = "sgx"))]
-use std::{
-  env,
-  thread::{self, JoinHandle},
-};
-
-#[cfg(target_env = "sgx")]
-use std::{collections::VecDeque, ptr, sync::Mutex};
-
-use bounded_spsc_queue::{self, Producer};
-
-use super::super::errors::*;
-use ffi::runtime::TVMParallelGroupEnv;
-
-#[cfg(target_env = "sgx")]
-use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue};
-
-type FTVMParallelLambda =
-  extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
-
-/// Holds a parallel job request made by a TVM library function.
-struct Job {
-  cb: FTVMParallelLambda,
-  cdata: *const c_void,
-  req_num_tasks: usize,
-  pending: Arc<AtomicUsize>,
-}
-
-impl Job {
-  /// Splits this job into a number of `Task`s which can be scheduled.
-  fn tasks(&self, num_workers: usize) -> Vec<Task> {
-    let num_tasks = if self.req_num_tasks == 0 {
-      num_workers
-    } else {
-      self.req_num_tasks.min(num_workers)
-    };
-    self.pending.store(num_tasks, Ordering::SeqCst);
-
-    let barrier = Arc::new(Barrier::new(num_tasks));
-
-    (0..num_tasks)
-      .map(move |i| Task {
-        id: i,
-        flambda: self.cb,
-        penv: TVMParallelGroupEnv {
-          sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void,
-          num_task: num_tasks as i32,
-        },
-        cdata: self.cdata,
-        pending: Arc::clone(&self.pending),
-      })
-      .collect()
-  }
-
-  /// Waits for all tasks in this `Job` to be completed.
-  fn wait(&self) -> Result<()> {
-    while self.pending.load(Ordering::Acquire) > 0 {
-      #[cfg(not(target_env = "sgx"))]
-      thread::yield_now();
-    }
-    Ok(())
-  }
-}
-
-/// A chunk of work requested by a TVM function.
-struct Task {
-  id: usize,
-  flambda: FTVMParallelLambda,
-  penv: TVMParallelGroupEnv,
-  cdata: *const c_void,
-  pending: Arc<AtomicUsize>,
-}
-unsafe impl Send for Task {}
-unsafe impl Sync for Task {}
-
-impl FnOnce<()> for Task {
-  type Output = i32;
-  extern "rust-call" fn call_once(self, _args: ()) -> Self::Output {
-    let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata);
-    self.pending.fetch_sub(1, Ordering::AcqRel);
-    status
-  }
-}
-
-#[derive(Default)]
-struct Threads {
-  #[allow(unused)]
-  #[cfg(not(target_env = "sgx"))]
-  handles: Vec<JoinHandle<()>>,
-  queues: Vec<Producer<Task>>,
-}
-
-impl<'a> Threads {
-  #[cfg(not(target_env = "sgx"))]
-  fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
-    num_threads: usize,
-    cb: F,
-  ) -> Self {
-    let (handles, queues) = (0..num_threads)
-      .map(|_| {
-        let (p, c) = bounded_spsc_queue::make(2);
-        let handle = thread::spawn(move || cb(c.into()));
-        (handle, p)
-      })
-      .unzip();
-    Threads {
-      handles: handles,
-      queues: queues,
-    }
-  }
-
-  #[cfg(target_env = "sgx")]
-  fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
-    num_threads: usize,
-    _cb: F,
-  ) -> Self {
-    let mut consumer_queues = SGX_QUEUES.lock().unwrap();
-    let queues = (0..num_threads)
-      .map(|_| {
-        let (p, c) = bounded_spsc_queue::make(2);
-        consumer_queues.push_back(c.into());
-        p
-      })
-      .collect();
-    ocall_packed!("__sgx_thread_group_launch__", num_threads as u64);
-    Threads { queues: queues }
-  }
-}
-
-struct ThreadPool {
-  num_workers: usize,
-  #[allow(unused)]
-  threads: Threads,
-}
-
-thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new());
-
-impl ThreadPool {
-  fn new() -> Self {
-    let num_workers = max_concurrency();
-    ThreadPool {
-      num_workers: num_workers,
-      threads: Threads::launch(num_workers, ThreadPool::run_worker),
-    }
-  }
-
-  fn launch(&self, job: Job) {
-    let mut tasks = job.tasks(self.num_workers + 1);
-
-    for (i, task) in tasks.split_off(1).into_iter().enumerate() {
-      self.threads.queues[i].push(task);
-    }
-
-    tasks.pop().unwrap()();
-    job.wait().unwrap();
-  }
-
-  fn run_worker(queue: Consumer<Task>) {
-    loop {
-      let task = queue.pop();
-      let result = task();
-      if result == <i32>::min_value() {
-        break;
-      } else if result != 0 {
-        panic!("Error running task.");
-      }
-    }
-  }
-}
-
-// Send + Sync wrapper for bounded_spsc_queue::Consumer
-struct Consumer<T> {
-  consumer: bounded_spsc_queue::Consumer<T>,
-}
-impl<T> From<bounded_spsc_queue::Consumer<T>> for Consumer<T> {
-  fn from(c: bounded_spsc_queue::Consumer<T>) -> Self {
-    Consumer { consumer: c }
-  }
-}
-impl<T> Consumer<T> {
-  fn pop(&self) -> T {
-    self.consumer.pop()
-  }
-}
-unsafe impl<T> Send for Consumer<T> {}
-unsafe impl<T> Sync for Consumer<T> {}
-
-#[cfg(target_env = "sgx")]
-lazy_static! {
-  /// Holds tasks for untrusted threads which re-enter the enclave to execute.
-  static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new());
-}
-
-#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))]
-fn max_concurrency() -> usize {
-  if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) {
-    if let Ok(threads) = usize::from_str_radix(&threads_str, 10) {
-      return threads;
-    }
-  }
-  num_cpus::get_physical()
-}
-
-#[cfg(target_env = "sgx")]
-fn max_concurrency() -> usize {
-  usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1)
-}
-
-#[cfg(target_arch = "wasm32")]
-fn max_concurrency() -> usize {
-  0 // wasm doesn't support threads yet
-}
-
-#[cfg(target_env = "sgx")]
-pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue {
-  let q = {
-    let mut qs = SGX_QUEUES.lock().unwrap();
-    qs.pop_front()
-    // `qs: MutexGuard` needs to be dropped here since `run_worker` won't return
-  };
-  if let Some(q) = q {
-    ThreadPool::run_worker(q);
-  }
-  TVMRetValue::default()
-}
-
-#[no_mangle]
-pub extern "C" fn TVMBackendParallelLaunch(
-  cb: FTVMParallelLambda,
-  cdata: *const c_void,
-  num_task: usize,
-) -> c_int {
-  if max_concurrency() == 0 {
-    let penv = TVMParallelGroupEnv {
-      sync_handle: 0 as *mut c_void,
-      num_task: 1,
-    };
-    cb(0, &penv as *const _, cdata);
-  } else {
-    THREAD_POOL.with(|pool| {
-      pool.launch(Job {
-        cb: cb,
-        cdata: cdata,
-        req_num_tasks: num_task,
-        pending: Arc::new(ATOMIC_USIZE_INIT),
-      });
-    });
-  }
-  return 0;
-}
-
-#[cfg(target_env = "sgx")]
-pub(crate) fn sgx_join_threads() {
-  extern "C" fn poison_pill(
-    _task_id: usize,
-    _penv: *const TVMParallelGroupEnv,
-    _cdata: *const c_void,
-  ) -> i32 {
-    <i32>::min_value()
-  }
-
-  THREAD_POOL.with(|pool| {
-    pool.launch(Job {
-      cb: poison_pill,
-      cdata: ptr::null(),
-      req_num_tasks: 0,
-      pending: Arc::new(ATOMIC_USIZE_INIT),
-    });
-  });
-  ocall_packed!("__sgx_thread_group_join__", 0);
-}
-
-// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used.
-#[no_mangle]
-pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) {
-  let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) };
-  barrier.wait();
-}
-
-#[cfg(test)]
-mod tests {
-  use std::{ptr, thread, time::Duration};
-
-  use super::*;
-
-  #[test]
-  fn test_max_concurrency() {
-    env::set_var("TVM_NUM_THREADS", "42");
-    env::set_var("OMP_NUM_THREADS", "24");
-    assert_eq!(max_concurrency(), 42);
-    env::remove_var("TVM_NUM_THREADS");
-    assert_eq!(max_concurrency(), 24);
-  }
-
-  extern "C" fn flambda(
-    task_id: usize,
-    penv: *const TVMParallelGroupEnv,
-    cdata: *const c_void,
-  ) -> i32 {
-    if cdata == ptr::null() {
-      return 0;
-    }
-    unsafe {
-      let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize));
-      thread::sleep(Duration::from_millis(50 * task_id as u64));
-      counter.fetch_add(1, Ordering::SeqCst);
-      task_ids_sum.fetch_add(task_id, Ordering::SeqCst);
-      assert_eq!((*penv).num_task, 3);
-    }
-    0
-  }
-
-  #[test]
-  fn test_parallel_launch() {
-    TVMBackendParallelLaunch(flambda, ptr::null(), 6);
-    let counter = ATOMIC_USIZE_INIT;
-    let task_ids_sum = ATOMIC_USIZE_INIT;
-    let cdata = (counter, task_ids_sum);
-    let num_tasks = 3;
-    TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
-    assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks);
-    assert_eq!(
-      cdata.1.load(Ordering::SeqCst),
-      (0..num_tasks).sum::<usize>()
-    );
-  }
-}
diff --git a/rust/src/runtime/workspace.rs b/rust/src/runtime/workspace.rs
deleted file mode 100644 (file)
index d0e6d8c..0000000
+++ /dev/null
@@ -1,119 +0,0 @@
-use std::{
-  cell::RefCell,
-  os::raw::{c_int, c_void},
-  ptr,
-};
-
-use super::allocator::Allocation;
-use errors::*;
-
-const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
-
-struct WorkspacePool {
-  workspaces: Vec<Allocation>,
-  free: Vec<usize>,
-  in_use: Vec<usize>,
-}
-
-impl WorkspacePool {
-  fn new() -> Self {
-    WorkspacePool {
-      workspaces: Vec::new(),
-      free: Vec::new(),
-      in_use: Vec::new(),
-    }
-  }
-
-  fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
-    self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
-    self.in_use.push(self.workspaces.len() - 1);
-    Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
-  }
-
-  fn alloc(&mut self, size: usize) -> Result<*mut u8> {
-    if self.free.len() == 0 {
-      return self.alloc_new(size);
-    }
-    let idx = self
-      .free
-      .iter()
-      .fold(None, |cur_ws_idx: Option<usize>, &idx| {
-        let ws_size = self.workspaces[idx].size();
-        if !ws_size >= size {
-          return cur_ws_idx;
-        }
-        cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
-          let cur_size = self.workspaces[cur_idx].size();
-          Some(match ws_size <= cur_size {
-            true => idx,
-            false => cur_idx,
-          })
-        })
-      });
-    match idx {
-      Some(idx) => {
-        self.free.remove_item(&idx).unwrap();
-        self.in_use.push(idx);
-        Ok(self.workspaces[idx].as_mut_ptr())
-      }
-      None => self.alloc_new(size),
-    }
-  }
-
-  fn free(&mut self, ptr: *mut u8) -> Result<()> {
-    let mut ws_idx = None;
-    for i in 0..self.in_use.len() {
-      let idx = self.in_use[i];
-      if self.workspaces[idx].as_mut_ptr() == ptr {
-        self.in_use.remove(i);
-        ws_idx = Some(idx);
-        break;
-      }
-    }
-    Ok(
-      self
-        .free
-        .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?),
-    )
-  }
-}
-
-thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new()));
-
-const WORKSPACE_PAGE_SIZE: usize = 4 << 10;
-
-#[no_mangle]
-pub extern "C" fn TVMBackendAllocWorkspace(
-  _device_type: c_int,
-  _device_id: c_int,
-  size: u64,
-  _dtype_code_hint: c_int,
-  _dtype_bits_hint: c_int,
-) -> *mut c_void {
-  let nbytes = if size == 0 {
-    WORKSPACE_PAGE_SIZE
-  } else {
-    size as usize
-  };
-  WORKSPACE_POOL.with(|pool_cell| {
-    pool_cell
-      .borrow_mut()
-      .alloc(nbytes as usize)
-      .unwrap_or(ptr::null_mut()) as *mut c_void
-  })
-}
-
-#[no_mangle]
-pub extern "C" fn TVMBackendFreeWorkspace(
-  _device_type: c_int,
-  _device_id: c_int,
-  ptr: *mut c_void,
-) -> c_int {
-  WORKSPACE_POOL.with(|pool_cell| {
-    (match pool_cell.borrow_mut().free(ptr as *mut u8) {
-      Ok(()) => 0,
-      Err(_) => -1,
-    }) as c_int
-  });
-  return 0;
-}
diff --git a/rust/tests/.gitignore b/rust/tests/.gitignore
deleted file mode 100644 (file)
index 8110767..0000000
+++ /dev/null
@@ -1,3 +0,0 @@
-*.json
-*.params
-*.o
diff --git a/rust/tests/build_model.py b/rust/tests/build_model.py
deleted file mode 100644 (file)
index e0b9049..0000000
+++ /dev/null
@@ -1,53 +0,0 @@
-"""Builds a simple NNVM graph for testing."""
-
-from os import path as osp
-
-import nnvm
-from nnvm import sym
-from nnvm.compiler import graph_util
-from nnvm.testing import init
-import numpy as np
-import tvm
-
-CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
-
-
-def _get_model(dshape):
-    data = sym.Variable('data', shape=dshape)
-    fc1 = sym.dense(data, units=dshape[-1]*2, use_bias=True)
-    left, right = sym.split(fc1, indices_or_sections=2, axis=1)
-    return sym.Group(((left + 1), (right - 1)))
-
-
-def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
-    if isinstance(graph, sym.Symbol):
-        graph = nnvm.graph.create(graph)
-    ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
-    param_shapes = dict(zip(graph.index.input_names, ishapes))
-    np.random.seed(seed)
-    params = {}
-    for param, shape in param_shapes.items():
-        if param in {'data', 'label'} or not shape:
-            continue
-        init_value = np.empty(shape).astype('float32')
-        initializer(param, init_value)
-        params[param] = tvm.nd.array(init_value)
-    return params
-
-def main():
-    dshape = (32, 16)
-    net = _get_model(dshape)
-    ishape_dict = {'data': dshape}
-    params = _init_params(net, ishape_dict)
-    graph, lib, params = nnvm.compiler.build(net, 'llvm',
-                                             shape=ishape_dict,
-                                             params=params,
-                                             dtype='float32')
-
-    with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet:
-        f_resnet.write(graph.json())
-    with open(osp.join(CWD, 'graph.params'), 'wb') as f_params:
-        f_params.write(nnvm.compiler.save_param_dict(params))
-
-if __name__ == '__main__':
-    main()
diff --git a/rust/tests/test_graph_serde.rs b/rust/tests/test_graph_serde.rs
deleted file mode 100644 (file)
index b02c128..0000000
+++ /dev/null
@@ -1,39 +0,0 @@
-#![feature(try_from)]
-
-extern crate serde;
-extern crate serde_json;
-
-extern crate tvm;
-
-use std::{convert::TryFrom, fs, io::Read};
-
-use tvm::runtime::Graph;
-
-#[test]
-fn test_load_graph() {
-  let mut params_bytes = Vec::new();
-  fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
-    .expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
-    .read_to_end(&mut params_bytes)
-    .unwrap();
-  let _params = tvm::runtime::load_param_dict(&params_bytes);
-
-  let graph = Graph::try_from(
-    &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(),
-  )
-  .unwrap();
-
-  assert_eq!(graph.nodes[3].op, "tvm_op");
-  assert_eq!(
-    graph.nodes[3]
-      .attrs
-      .as_ref()
-      .unwrap()
-      .get("func_name")
-      .unwrap(),
-    "fuse_dense"
-  );
-  assert_eq!(graph.nodes[5].inputs[0].index, 0);
-  assert_eq!(graph.nodes[6].inputs[0].index, 1);
-  assert_eq!(graph.heads.len(), 2);
-}
diff --git a/rust/tests/test_nnvm/Cargo.toml b/rust/tests/test_nnvm/Cargo.toml
deleted file mode 100644 (file)
index 7e6ce5f..0000000
+++ /dev/null
@@ -1,14 +0,0 @@
-[package]
-name = "test-nnvm"
-version = "0.0.0"
-license = "Apache-2.0"
-authors = ["Nick Hynes <nhynes@berkeley.edu>"]
-
-[dependencies]
-ndarray = "0.11.2"
-tvm = { path = "../../" }
-serde = "1.0.59"
-serde_json = "1.0.17"
-
-[build-dependencies]
-ar = "0.6.0"
diff --git a/rust/tests/test_nnvm/build.rs b/rust/tests/test_nnvm/build.rs
deleted file mode 100644 (file)
index 4d9cd30..0000000
+++ /dev/null
@@ -1,40 +0,0 @@
-extern crate ar;
-
-use std::{
-  env,
-  fs::File,
-  path::{Path, PathBuf},
-  process::Command,
-};
-
-use ar::Builder;
-
-fn main() {
-  let out_dir = env::var("OUT_DIR").unwrap();
-
-  let output = Command::new(concat!(
-    env!("CARGO_MANIFEST_DIR"),
-    "/src/build_test_graph.py"
-  ))
-  .arg(&out_dir)
-  .output()
-  .expect("Failed to execute command");
-  assert!(
-    Path::new(&format!("{}/graph.o", out_dir)).exists(),
-    "Could not build graph lib: {}",
-    String::from_utf8(output.stderr)
-      .unwrap()
-      .trim()
-      .split("\n")
-      .last()
-      .unwrap_or("")
-  );
-
-  let in_path: PathBuf = [&out_dir, "graph.o"].iter().collect();
-  let out_path: PathBuf = [&out_dir, "libgraph.a"].iter().collect();
-  let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
-  builder.append_path(in_path.to_str().unwrap()).unwrap();
-
-  println!("cargo:rustc-link-lib=static=graph");
-  println!("cargo:rustc-link-search=native={}", out_dir);
-}
diff --git a/rust/tests/test_nnvm/src/build_test_graph.py b/rust/tests/test_nnvm/src/build_test_graph.py
deleted file mode 100755 (executable)
index 429cc21..0000000
+++ /dev/null
@@ -1,63 +0,0 @@
-#!/usr/bin/env python3
-
-"""Builds a simple NNVM graph for testing."""
-
-from os import path as osp
-import sys
-
-import nnvm
-from nnvm import sym
-from nnvm.compiler import graph_util
-from nnvm.testing import init
-import numpy as np
-import tvm
-
-
-def _get_model(dshape):
-    data = sym.Variable('data', shape=dshape)
-    fc = sym.dense(data, units=dshape[-1]*2, use_bias=True)
-    left, right = sym.split(fc, indices_or_sections=2, axis=1)
-    return sym.Group(((left + 1), (right - 1), fc))
-
-
-def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
-    if isinstance(graph, sym.Symbol):
-        graph = nnvm.graph.create(graph)
-    ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
-    param_shapes = dict(zip(graph.index.input_names, ishapes))
-    np.random.seed(seed)
-    params = {}
-    for param, shape in param_shapes.items():
-        if param in {'data', 'label'} or not shape:
-            continue
-
-        init_value = np.arange(np.product(shape), 0, -1).reshape(*shape).astype('float32')
-        if param.endswith('_bias'):
-            params[param] = tvm.nd.array(init_value)
-            continue
-
-        init_value = np.empty(shape).astype('float32')
-        initializer(param, init_value)
-        # init_value /= init_value.sum() + 1e-10
-        params[param] = tvm.nd.array(init_value)
-    return params
-
-def main():
-    dshape = (4, 8)
-    net = _get_model(dshape)
-    ishape_dict = {'data': dshape}
-    params = _init_params(net, ishape_dict)
-    graph, lib, params = nnvm.compiler.build(net, 'llvm --system-lib',
-                                             shape=ishape_dict,
-                                             params=params,
-                                             dtype='float32')
-
-    out_dir = sys.argv[1]
-    lib.save(osp.join(sys.argv[1], 'graph.o'))
-    with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet:
-        f_resnet.write(graph.json())
-    with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params:
-        f_params.write(nnvm.compiler.save_param_dict(params))
-
-if __name__ == '__main__':
-    main()
diff --git a/rust/tests/test_nnvm/src/main.rs b/rust/tests/test_nnvm/src/main.rs
deleted file mode 100644 (file)
index 0953ce2..0000000
+++ /dev/null
@@ -1,80 +0,0 @@
-#![feature(try_from)]
-
-#[macro_use]
-extern crate ndarray;
-extern crate serde;
-extern crate serde_json;
-
-extern crate tvm;
-use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
-
-use ndarray::Array;
-use tvm::runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
-
-const BATCH_SIZE: usize = 4;
-const IN_DIM: usize = 8;
-
-macro_rules! check_sum {
-  ($e:expr, $a:ident, $b:ident) => {
-    let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap();
-    check_sum!(a, $b);
-  };
-  ($e:expr, $a:expr, $b:ident) => {
-    let a = Array::try_from($e.get_output($a).unwrap()).unwrap();
-    check_sum!(a, $b);
-  };
-  ($a:ident, $b:ident) => {
-    let a_sum: f32 = $a.scalar_sum();
-    let b_sum: f32 = $b.scalar_sum();
-    assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum);
-  };
-}
-
-fn main() {
-  let syslib = SystemLibModule::default();
-
-  let mut params_bytes = Vec::new();
-  fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
-    .unwrap()
-    .read_to_end(&mut params_bytes)
-    .unwrap();
-  let params = tvm::runtime::load_param_dict(&params_bytes)
-    .unwrap()
-    .into_iter()
-    .map(|(k, v)| (k, v.to_owned()))
-    .collect::<HashMap<String, Tensor<'static>>>();
-
-  let graph =
-    Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()).unwrap();
-  let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
-
-  let x = Array::from_shape_vec(
-    (BATCH_SIZE, IN_DIM),
-    (0..BATCH_SIZE * IN_DIM)
-      .map(|x| x as f32)
-      .collect::<Vec<f32>>(),
-  ).unwrap();
-  let w = Array::try_from(params.get("dense0_weight").unwrap())
-    .unwrap()
-    .into_shape((IN_DIM * 2, IN_DIM))
-    .unwrap();
-  let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap();
-  let dense = x.dot(&w.t()) + &b;
-  let left = dense.slice(s![.., 0..IN_DIM]);
-  let right = dense.slice(s![.., IN_DIM..]);
-  let expected_o0 = &left + 1f32;
-  let expected_o1 = &right - 1f32;
-
-  exec.load_params(params);
-  exec.set_input("data", x.clone().into());
-
-  check_sum!(exec, data, x);
-  check_sum!(exec, dense0_weight, w);
-  check_sum!(exec, dense0_bias, b);
-
-  exec.run();
-
-  check_sum!(exec, 0, expected_o0);
-  check_sum!(exec, 1, expected_o1);
-  check_sum!(exec, 2, dense);
-}
diff --git a/rust/tests/test_tvm_basic/Cargo.toml b/rust/tests/test_tvm_basic/Cargo.toml
deleted file mode 100644 (file)
index bd4193b..0000000
+++ /dev/null
@@ -1,12 +0,0 @@
-[package]
-name = "test-tvm-basic"
-version = "0.0.0"
-license = "Apache-2.0"
-authors = ["Nick Hynes <nhynes@berkeley.edu>"]
-
-[dependencies]
-ndarray = "0.11.2"
-tvm = { path = "../../" }
-
-[build-dependencies]
-ar = "0.6.0"
diff --git a/rust/tests/test_tvm_basic/build.rs b/rust/tests/test_tvm_basic/build.rs
deleted file mode 100644 (file)
index 778dd1c..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-extern crate ar;
-
-use std::{env, path::PathBuf, process::Command};
-
-use ar::Builder;
-use std::fs::File;
-
-fn main() {
-  let out_dir = env::var("OUT_DIR").unwrap();
-
-  let output = Command::new(concat!(
-    env!("CARGO_MANIFEST_DIR"),
-    "/src/build_test_lib.py"
-  )).arg(&out_dir)
-    .output()
-    .expect("Failed to execute command");
-  if output.stderr.len() > 0 {
-    panic!(String::from_utf8(output.stderr).unwrap());
-  }
-
-  let in_path: PathBuf = [&out_dir, "test.o"].iter().collect();
-  let out_path: PathBuf = [&out_dir, "libtest.a"].iter().collect();
-  let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
-  builder.append_path(in_path.to_str().unwrap()).unwrap();
-
-  println!("cargo:rustc-link-lib=static=test");
-  println!("cargo:rustc-link-search=native={}", out_dir);
-}
diff --git a/rust/tests/test_tvm_basic/src/build_test_lib.py b/rust/tests/test_tvm_basic/src/build_test_lib.py
deleted file mode 100755 (executable)
index 7289a77..0000000
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/usr/bin/env python3
-
-"""Prepares a simple TVM library for testing."""
-
-from os import path as osp
-import sys
-
-import tvm
-
-def main():
-    n = tvm.var('n')
-    A = tvm.placeholder((n,), name='A')
-    B = tvm.placeholder((n,), name='B')
-    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
-    s = tvm.create_schedule(C.op)
-    s[C].parallel(s[C].op.axis[0])
-    print(tvm.lower(s, [A, B, C], simple_mode=True))
-    tvm.build(s, [A, B, C], 'llvm --system-lib').save(osp.join(sys.argv[1], 'test.o'))
-
-if __name__ == '__main__':
-    main()
diff --git a/rust/tests/test_tvm_basic/src/main.rs b/rust/tests/test_tvm_basic/src/main.rs
deleted file mode 100644 (file)
index b6c1145..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-extern crate ndarray;
-#[macro_use]
-extern crate tvm;
-
-use ndarray::Array;
-use tvm::{
-  ffi::runtime::DLTensor,
-  runtime::{Module, SystemLibModule},
-};
-
-fn main() {
-  let syslib = SystemLibModule::default();
-  let add = syslib
-    .get_function("default_function")
-    .expect("main function not found");
-  let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
-  let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
-  let mut c = Array::from_vec(vec![0f32; 4]);
-  let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
-  let mut a_dl: DLTensor = (&mut a).into();
-  let mut b_dl: DLTensor = (&mut b).into();
-  let mut c_dl: DLTensor = (&mut c).into();
-  call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
-  assert!(c.all_close(&e, 1e-8f32));
-}
index 8e66d1098946a331130832d5ea8570a42d460361..5d8c242f44dff8ec3402f941c3978a8ca9bf802a 100755 (executable)
@@ -2,24 +2,60 @@
 
 set -e
 
-export LD_LIBRARY_PATH=lib:$LD_LIBRARY_PATH
+export TVM_HOME="$(git rev-parse --show-toplevel)"
 
-tvm_root="$(git rev-parse --show-toplevel)"
-export PYTHONPATH="$tvm_root/python":"$tvm_root/nnvm/python":"$tvm_root/topi/python"
+export LD_LIBRARY_PATH="$TVM_HOME/lib":"$TVM_HOME/build":"$TVM_HOME/nnvm":$LD_LIBRARY_PATH
+export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/nnvm/python":"$TVM_HOME/topi/python"
+export RUST_DIR="$TVM_HOME/rust"
 
-#cd rust
-#cargo fmt -- --check
+cd $RUST_DIR
+cargo fmt -- --check
+
+# test common
+cd $RUST_DIR/common
+cargo build --features runtime
+cargo test --features runtime --tests
+
+cargo build --features frontend
+cargo test --features frontend --tests
+
+# test runtime
+cd $RUST_DIR/runtime
 
 # run basic tests
-#python3 tests/build_model.py
-#cargo test --tests
+python3 tests/build_model.py
+cargo test --tests
 
 # run TVM module test
-#cd tests/test_tvm_basic
-#cargo run
-#cd -
+cd tests/test_tvm_basic
+cargo run
+cd -
 
 # run NNVM graph test
-#cd tests/test_nnvm
-#cargo run
-#cd -
+cd tests/test_nnvm
+cargo run
+cd -
+
+# test frontend
+cd $RUST_DIR/frontend
+
+cargo test --tests -- --test-threads=1
+
+# run basic tests on cpu
+cd tests/basics
+cargo build --features cpu
+cargo run --features cpu
+# uncomment when have more CI resources
+# cargo build --features gpu
+# cargo run --features gpu
+# fi
+cd -
+
+# run callback tests separately: https://discuss.tvm.ai/t/are-global-functions-need-to-be-accessed-in-separate-processes/1075
+cd tests/callback
+cargo build
+cargo run --bin int
+cargo run --bin float
+cargo run --bin array
+cargo run --bin string
+cd -