+++ /dev/null
-Cargo.lock
-target/
-**/*.rs.bk
max_width = 100
hard_tabs = false
-tab_spaces = 2
+tab_spaces = 4
newline_style = "Auto"
use_small_heuristics = "Default"
indent_style = "Block"
match_block_trailing_comma = false
blank_lines_upper_bound = 1
blank_lines_lower_bound = 0
-edition = "2015"
+edition = "2018"
merge_derives = true
use_try_shorthand = true
use_field_init_shorthand = false
disable_all_formatting = false
skip_children = false
hide_parse_errors = false
-error_on_line_overflow = false
-error_on_unformatted = false
+error_on_line_overflow = true
+error_on_unformatted = true
report_todo = "Never"
report_fixme = "Never"
ignore = []
+++ /dev/null
-language: rust
-rust:
- - nightly
-matrix:
- fast_finish: true
-[package]
-name = "tvm"
-version = "0.1.0"
-license = "Apache-2.0"
-description = "TVM Rust runtime"
-repository = "https://github.com/dmlc/tvm"
-readme = "README.md"
-keywords = ["tvm", "nnvm"]
-categories = ["api-bindings", "science"]
-authors = ["TVM Contributors"]
-
-[features]
-default = ["nom/std"]
-sgx = ["nom/alloc"]
-
-[dependencies]
-bounded-spsc-queue = "0.4.0"
-error-chain = { version = "0.12.0", default-features = false }
-itertools = "0.7.8"
-lazy_static = "1.1.0"
-ndarray = "0.11.2"
-nom = {version = "4.0.0", default-features = false }
-serde = "1.0.59"
-serde_derive = "1.0.79"
-serde_json = "1.0.17"
-
-[target.'cfg(not(target_env = "sgx"))'.dependencies]
-num_cpus = "1.8.0"
+[workspace]
+members = [
+ "common",
+ "runtime",
+ "runtime/tests/test_tvm_basic",
+ "runtime/tests/test_nnvm",
+ "frontend",
+ "frontend/tests/basics",
+ "frontend/tests/callback",
+ "frontend/examples/resnet"
+]
--- /dev/null
+target
+**/*.rs.bk
+Cargo.lock
+/tvm-sys/src/bindgen.rs
--- /dev/null
+[package]
+name = "tvm-common"
+version = "0.1.0"
+authors = ["TVM Contributors"]
+license = "Apache-2.0"
+
+[features]
+runtime = []
+frontend = ["tvm-sys"]
+
+[dependencies]
+error-chain = { version = "0.12.0", default-features = false }
+tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
--- /dev/null
+/* automatically generated by rust-bindgen for TVM revision 6292c78 */
+
+pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0";
+pub const DLPACK_VERSION: u32 = 8;
+pub const _STDINT_H: u32 = 1;
+pub const _FEATURES_H: u32 = 1;
+pub const _DEFAULT_SOURCE: u32 = 1;
+pub const __USE_ISOC11: u32 = 1;
+pub const __USE_ISOC99: u32 = 1;
+pub const __USE_ISOC95: u32 = 1;
+pub const __USE_POSIX_IMPLICITLY: u32 = 1;
+pub const _POSIX_SOURCE: u32 = 1;
+pub const _POSIX_C_SOURCE: u32 = 200809;
+pub const __USE_POSIX: u32 = 1;
+pub const __USE_POSIX2: u32 = 1;
+pub const __USE_POSIX199309: u32 = 1;
+pub const __USE_POSIX199506: u32 = 1;
+pub const __USE_XOPEN2K: u32 = 1;
+pub const __USE_XOPEN2K8: u32 = 1;
+pub const _ATFILE_SOURCE: u32 = 1;
+pub const __USE_MISC: u32 = 1;
+pub const __USE_ATFILE: u32 = 1;
+pub const __USE_FORTIFY_LEVEL: u32 = 0;
+pub const _STDC_PREDEF_H: u32 = 1;
+pub const __STDC_IEC_559__: u32 = 1;
+pub const __STDC_IEC_559_COMPLEX__: u32 = 1;
+pub const __STDC_ISO_10646__: u32 = 201505;
+pub const __STDC_NO_THREADS__: u32 = 1;
+pub const __GNU_LIBRARY__: u32 = 6;
+pub const __GLIBC__: u32 = 2;
+pub const __GLIBC_MINOR__: u32 = 23;
+pub const _SYS_CDEFS_H: u32 = 1;
+pub const __WORDSIZE: u32 = 64;
+pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1;
+pub const __SYSCALL_WORDSIZE: u32 = 64;
+pub const _BITS_WCHAR_H: u32 = 1;
+pub const INT8_MIN: i32 = -128;
+pub const INT16_MIN: i32 = -32768;
+pub const INT32_MIN: i32 = -2147483648;
+pub const INT8_MAX: u32 = 127;
+pub const INT16_MAX: u32 = 32767;
+pub const INT32_MAX: u32 = 2147483647;
+pub const UINT8_MAX: u32 = 255;
+pub const UINT16_MAX: u32 = 65535;
+pub const UINT32_MAX: u32 = 4294967295;
+pub const INT_LEAST8_MIN: i32 = -128;
+pub const INT_LEAST16_MIN: i32 = -32768;
+pub const INT_LEAST32_MIN: i32 = -2147483648;
+pub const INT_LEAST8_MAX: u32 = 127;
+pub const INT_LEAST16_MAX: u32 = 32767;
+pub const INT_LEAST32_MAX: u32 = 2147483647;
+pub const UINT_LEAST8_MAX: u32 = 255;
+pub const UINT_LEAST16_MAX: u32 = 65535;
+pub const UINT_LEAST32_MAX: u32 = 4294967295;
+pub const INT_FAST8_MIN: i32 = -128;
+pub const INT_FAST16_MIN: i64 = -9223372036854775808;
+pub const INT_FAST32_MIN: i64 = -9223372036854775808;
+pub const INT_FAST8_MAX: u32 = 127;
+pub const INT_FAST16_MAX: u64 = 9223372036854775807;
+pub const INT_FAST32_MAX: u64 = 9223372036854775807;
+pub const UINT_FAST8_MAX: u32 = 255;
+pub const UINT_FAST16_MAX: i32 = -1;
+pub const UINT_FAST32_MAX: i32 = -1;
+pub const INTPTR_MIN: i64 = -9223372036854775808;
+pub const INTPTR_MAX: u64 = 9223372036854775807;
+pub const UINTPTR_MAX: i32 = -1;
+pub const PTRDIFF_MIN: i64 = -9223372036854775808;
+pub const PTRDIFF_MAX: u64 = 9223372036854775807;
+pub const SIG_ATOMIC_MIN: i32 = -2147483648;
+pub const SIG_ATOMIC_MAX: u32 = 2147483647;
+pub const SIZE_MAX: i32 = -1;
+pub const WINT_MIN: u32 = 0;
+pub const WINT_MAX: u32 = 4294967295;
+pub type int_least8_t = ::std::os::raw::c_schar;
+pub type int_least16_t = ::std::os::raw::c_short;
+pub type int_least32_t = ::std::os::raw::c_int;
+pub type int_least64_t = ::std::os::raw::c_long;
+pub type uint_least8_t = ::std::os::raw::c_uchar;
+pub type uint_least16_t = ::std::os::raw::c_ushort;
+pub type uint_least32_t = ::std::os::raw::c_uint;
+pub type uint_least64_t = ::std::os::raw::c_ulong;
+pub type int_fast8_t = ::std::os::raw::c_schar;
+pub type int_fast16_t = ::std::os::raw::c_long;
+pub type int_fast32_t = ::std::os::raw::c_long;
+pub type int_fast64_t = ::std::os::raw::c_long;
+pub type uint_fast8_t = ::std::os::raw::c_uchar;
+pub type uint_fast16_t = ::std::os::raw::c_ulong;
+pub type uint_fast32_t = ::std::os::raw::c_ulong;
+pub type uint_fast64_t = ::std::os::raw::c_ulong;
+pub type intmax_t = ::std::os::raw::c_long;
+pub type uintmax_t = ::std::os::raw::c_ulong;
+pub type wchar_t = ::std::os::raw::c_int;
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct max_align_t {
+ pub __clang_max_align_nonce1: ::std::os::raw::c_longlong,
+ pub __bindgen_padding_0: u64,
+ pub __clang_max_align_nonce2: f64,
+}
+pub const DLDeviceType_kDLCPU: DLDeviceType = 1;
+pub const DLDeviceType_kDLGPU: DLDeviceType = 2;
+pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3;
+pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4;
+pub const DLDeviceType_kDLMetal: DLDeviceType = 8;
+pub const DLDeviceType_kDLVPI: DLDeviceType = 9;
+pub const DLDeviceType_kDLROCM: DLDeviceType = 10;
+/// \brief The device type in DLContext.
+pub type DLDeviceType = u32;
+/// \brief A Device context for Tensor and operator.
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct DLContext {
+ /// \brief The device type used in the device.
+ pub device_type: DLDeviceType,
+ /// \brief The device index
+ pub device_id: ::std::os::raw::c_int,
+}
+pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0;
+pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1;
+pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2;
+/// \brief The type code options DLDataType.
+pub type DLDataTypeCode = u32;
+/// \brief The data type the tensor can hold.
+///
+/// Examples
+/// - float: type_code = 2, bits = 32, lanes=1
+/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
+/// - int8: type_code = 0, bits = 8, lanes=1
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct DLDataType {
+ /// \brief Type code of base types.
+ /// We keep it uint8_t instead of DLDataTypeCode for minimal memory
+ /// footprint, but the value should be one of DLDataTypeCode enum values.
+ ///
+ pub code: u8,
+ /// \brief Number of bits, common choices are 8, 16, 32.
+ pub bits: u8,
+ /// \brief Number of lanes in the type, used for vector types.
+ pub lanes: u16,
+}
+/// \brief Plain C Tensor object, does not manage memory.
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct DLTensor {
+ /// \brief The opaque data pointer points to the allocated data.
+ /// This will be CUDA device pointer or cl_mem handle in OpenCL.
+    /// This pointer always aligns to 256 bytes as in CUDA.
+ pub data: *mut ::std::os::raw::c_void,
+ /// \brief The device context of the tensor
+ pub ctx: DLContext,
+ /// \brief Number of dimensions
+ pub ndim: ::std::os::raw::c_int,
+ /// \brief The data type of the pointer
+ pub dtype: DLDataType,
+ /// \brief The shape of the tensor
+ pub shape: *mut i64,
+ /// \brief strides of the tensor,
+ /// can be NULL, indicating tensor is compact.
+ pub strides: *mut i64,
+ /// \brief The offset in bytes to the beginning pointer to data
+ pub byte_offset: u64,
+}
+/// \brief C Tensor object, manage memory of DLTensor. This data structure is
+/// intended to facilitate the borrowing of DLTensor by another framework. It is
+/// not meant to transfer the tensor. When the borrowing framework doesn't need
+/// the tensor, it should call the deleter to notify the host that the resource
+/// is no longer needed.
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct DLManagedTensor {
+ /// \brief DLTensor which is being memory managed
+ pub dl_tensor: DLTensor,
+ /// \brief the context of the original host framework of DLManagedTensor in
+ /// which DLManagedTensor is used in the framework. It can also be NULL.
+ pub manager_ctx: *mut ::std::os::raw::c_void,
+ /// \brief Destructor signature void (*)(void*) - this should be called
+ /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
+ /// if there is no way for the caller to provide a reasonable destructor.
+ pub deleter: ::std::option::Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)>,
+}
+/// \brief type of array index.
+pub type tvm_index_t = i64;
+pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5;
+pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6;
+pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7;
+pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11;
+pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12;
+/// \brief Extension device types in TVM
+pub type TVMDeviceExtType = u32;
+pub const TVMTypeCode_kHandle: TVMTypeCode = 3;
+pub const TVMTypeCode_kNull: TVMTypeCode = 4;
+pub const TVMTypeCode_kTVMType: TVMTypeCode = 5;
+pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6;
+pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7;
+pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8;
+pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9;
+pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10;
+pub const TVMTypeCode_kStr: TVMTypeCode = 11;
+pub const TVMTypeCode_kBytes: TVMTypeCode = 12;
+pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13;
+pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15;
+pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16;
+pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20;
+pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64;
+pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128;
+/// \brief The type code in TVMType
+/// \note TVMType is used in two places.
+pub type TVMTypeCode = u32;
+/// \brief The data type used in TVM Runtime.
+///
+/// Examples
+/// - float: type_code = 2, bits = 32, lanes=1
+/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
+/// - int8: type_code = 0, bits = 8, lanes=1
+///
+/// \note Arguments TVM API function always takes bits=64 and lanes=1
+pub type TVMType = DLDataType;
+/// \brief The Device information, abstract away common device types.
+pub type TVMContext = DLContext;
+/// \brief The tensor array structure to TVM API.
+pub type TVMArray = DLTensor;
+/// \brief the array handle
+pub type TVMArrayHandle = *mut TVMArray;
+/// \brief Union type of values
+/// being passed through API and function calls.
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub union TVMValue {
+ pub v_int64: i64,
+ pub v_float64: f64,
+ pub v_handle: *mut ::std::os::raw::c_void,
+ pub v_str: *const ::std::os::raw::c_char,
+ pub v_type: TVMType,
+ pub v_ctx: TVMContext,
+ _bindgen_union_align: u64,
+}
+/// \brief Byte array type used to pass in byte array
+/// When kBytes is used as data type.
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct TVMByteArray {
+ pub data: *const ::std::os::raw::c_char,
+ pub size: usize,
+}
+/// \brief Handle to TVM runtime modules.
+pub type TVMModuleHandle = *mut ::std::os::raw::c_void;
+/// \brief Handle to packed function handle.
+pub type TVMFunctionHandle = *mut ::std::os::raw::c_void;
+/// \brief Handle to hold return value.
+pub type TVMRetValueHandle = *mut ::std::os::raw::c_void;
+/// \brief The stream that is specific to device
+/// can be NULL, which indicates the default one.
+pub type TVMStreamHandle = *mut ::std::os::raw::c_void;
+extern "C" {
+ /// \brief Used for implementing C API function.
+ /// Set last error message before return.
+ /// \param msg The error message to be set.
+ pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char);
+}
+extern "C" {
+ /// \brief return str message of the last error
+ /// all function in this file will return 0 when success
+    /// and -1 when an error occurred,
+ /// TVMGetLastError can be called to retrieve the error
+ ///
+ /// this function is threadsafe and can be called by different thread
+ /// \return error info
+ pub fn TVMGetLastError() -> *const ::std::os::raw::c_char;
+}
+extern "C" {
+ /// \brief Load module from file.
+ /// \param file_name The file name to load the module from.
+ /// \param format The format of the module.
+ /// \param out The result module
+ ///
+ /// \return 0 when success, -1 when failure happens
+    /// \note The resulting module does not contain import relations.
+ /// It can be reconstructed by TVMModImport.
+ pub fn TVMModLoadFromFile(
+ file_name: *const ::std::os::raw::c_char,
+ format: *const ::std::os::raw::c_char,
+ out: *mut TVMModuleHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Add dep to mod's dependency.
+ /// This allows functions in this module to use modules.
+ ///
+ /// \param mod The module handle.
+ /// \param dep The dependent module to be imported.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Get function from the module.
+ /// \param mod The module handle.
+ /// \param func_name The name of the function.
+ /// \param query_imports Whether to query imported modules
+ /// \param out The result function, can be NULL if it is not available.
+ /// \return 0 when no error is thrown, -1 when failure happens
+ pub fn TVMModGetFunction(
+ mod_: TVMModuleHandle,
+ func_name: *const ::std::os::raw::c_char,
+ query_imports: ::std::os::raw::c_int,
+ out: *mut TVMFunctionHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Free front-end extension type resource.
+ /// \param handle The extension handle.
+ /// \param type_code The type of of the extension type.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMExtTypeFree(
+ handle: *mut ::std::os::raw::c_void,
+ type_code: ::std::os::raw::c_int,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Free the Module
+ /// \param mod The module to be freed.
+ ///
+ /// \note This may not free up the module's resources.
+    /// If there is an active TVMFunctionHandle that uses the module,
+    /// or if this module is imported by another active module,
+    /// all functions remain valid until TVMFuncFree is called.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Free the function when it is no longer needed.
+ /// \param func The function handle
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Call a Packed TVM Function.
+ ///
+ /// \param func node handle of the function.
+ /// \param arg_values The arguments
+ /// \param type_codes The type codes of the arguments
+ /// \param num_args Number of arguments.
+ ///
+ /// \param ret_val The return value.
+ /// \param ret_type_code the type code of return value.
+ ///
+ /// \return 0 when success, -1 when failure happens
+ /// \note TVM calls always exchanges with type bits=64, lanes=1
+ ///
+ /// \note API calls always exchanges with type bits=64, lanes=1
+ /// If API call returns container handles (e.g. FunctionHandle)
+ /// these handles should be managed by the front-end.
+ /// The front-end need to call free function (e.g. TVMFuncFree)
+ /// to free these handles.
+ pub fn TVMFuncCall(
+ func: TVMFunctionHandle,
+ arg_values: *mut TVMValue,
+ type_codes: *mut ::std::os::raw::c_int,
+ num_args: ::std::os::raw::c_int,
+ ret_val: *mut TVMValue,
+ ret_type_code: *mut ::std::os::raw::c_int,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Set the return value of TVMPackedCFunc.
+ ///
+ /// This function is called by TVMPackedCFunc to set the return value.
+ /// When this function is not called, the function returns null by default.
+ ///
+ /// \param ret The return value handle, pass by ret in TVMPackedCFunc
+ /// \param value The value to be returned.
+ /// \param type_code The type of the value to be returned.
+ /// \param num_ret Number of return values, for now only 1 is supported.
+ pub fn TVMCFuncSetReturn(
+ ret: TVMRetValueHandle,
+ value: *mut TVMValue,
+ type_code: *mut ::std::os::raw::c_int,
+ num_ret: ::std::os::raw::c_int,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Inplace translate callback argument value to return value.
+ /// This is only needed for non-POD arguments.
+ ///
+ /// \param value The value to be translated.
+ /// \param code The type code to be translated.
+ /// \note This function will do a shallow copy when necessary.
+ ///
+ /// \return 0 when success, -1 when failure happens.
+ pub fn TVMCbArgToReturn(
+ value: *mut TVMValue,
+ code: ::std::os::raw::c_int,
+ ) -> ::std::os::raw::c_int;
+}
+/// \brief C type of packed function.
+///
+/// \param args The arguments
+/// \param type_codes The type codes of the arguments
+/// \param num_args Number of arguments.
+/// \param ret The return value handle.
+/// \param resource_handle The additional resource handle from the front-end.
+/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
+/// \sa TVMCFuncSetReturn
+pub type TVMPackedCFunc = ::std::option::Option<
+ unsafe extern "C" fn(
+ args: *mut TVMValue,
+ type_codes: *mut ::std::os::raw::c_int,
+ num_args: ::std::os::raw::c_int,
+ ret: TVMRetValueHandle,
+ resource_handle: *mut ::std::os::raw::c_void,
+ ) -> ::std::os::raw::c_int,
+>;
+/// \brief C callback to free the resource handle in C packed function.
+/// \param resource_handle The additional resource handle from the front-end.
+pub type TVMPackedCFuncFinalizer =
+ ::std::option::Option<unsafe extern "C" fn(resource_handle: *mut ::std::os::raw::c_void)>;
+/// \brief Signature for extension function declarer.
+///
+/// TVM call this function to get the extension functions
+/// The declarer will call register_func to register function and their name.
+///
+/// \param register_func_handle The register function
+/// \return 0 if success, -1 if failure happens
+pub type TVMExtensionFuncDeclarer = ::std::option::Option<
+ unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int,
+>;
+extern "C" {
+ /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
+ ///
+ /// The resource_handle will be managed by TVM API, until the function is no longer used.
+ ///
+ /// \param func The packed C function.
+ /// \param resource_handle The resource handle from front-end, can be NULL.
+ /// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
+ /// \param out the result function handle.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMFuncCreateFromCFunc(
+ func: TVMPackedCFunc,
+ resource_handle: *mut ::std::os::raw::c_void,
+ fin: TVMPackedCFuncFinalizer,
+ out: *mut TVMFunctionHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Register the function to runtime's global table.
+ ///
+ /// The registered function then can be pulled by the backend by the name.
+ ///
+ /// \param name The name of the function.
+ /// \param f The function to be registered.
+ /// \param override Whether allow override already registered function.
+ pub fn TVMFuncRegisterGlobal(
+ name: *const ::std::os::raw::c_char,
+ f: TVMFunctionHandle,
+ override_: ::std::os::raw::c_int,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Get a global function.
+ ///
+ /// \param name The name of the function.
+ /// \param out the result function pointer, NULL if it does not exist.
+ ///
+ /// \note The function handle of global function is managed by TVM runtime,
+    /// so TVMFuncFree should not be called when it gets deleted.
+ pub fn TVMFuncGetGlobal(
+ name: *const ::std::os::raw::c_char,
+ out: *mut TVMFunctionHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief List all the globally registered function name
+ /// \param out_size The number of functions
+ /// \param out_array The array of function names.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMFuncListGlobalNames(
+ out_size: *mut ::std::os::raw::c_int,
+ out_array: *mut *mut *const ::std::os::raw::c_char,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Allocate a nd-array's memory,
+ /// including space of shape, of given spec.
+ ///
+ /// \param shape The shape of the array, the data content will be copied to out
+ /// \param ndim The number of dimension of the array.
+ /// \param dtype_code The type code of the dtype
+ /// \param dtype_bits The number of bits of dtype
+ /// \param dtype_lanes The number of lanes in the dtype.
+ /// \param device_type The device type of context
+ /// \param device_id The device id of context.
+ /// \param out The output handle.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMArrayAlloc(
+ shape: *const tvm_index_t,
+ ndim: ::std::os::raw::c_int,
+ dtype_code: ::std::os::raw::c_int,
+ dtype_bits: ::std::os::raw::c_int,
+ dtype_lanes: ::std::os::raw::c_int,
+ device_type: ::std::os::raw::c_int,
+ device_id: ::std::os::raw::c_int,
+ out: *mut TVMArrayHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Free the TVM Array.
+ /// \param handle The array handle to be freed.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Copy array data from CPU byte array.
+ /// \param handle The array handle.
+ /// \param data the data pointer
+ /// \param nbytes The number of bytes to copy.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMArrayCopyFromBytes(
+ handle: TVMArrayHandle,
+ data: *mut ::std::os::raw::c_void,
+ nbytes: usize,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Copy array data to CPU byte array.
+ /// \param handle The array handle.
+ /// \param data the data pointer
+ /// \param nbytes The number of bytes to copy.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMArrayCopyToBytes(
+ handle: TVMArrayHandle,
+ data: *mut ::std::os::raw::c_void,
+ nbytes: usize,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Copy the array, both from and to must be valid during the copy.
+ /// \param from The array to be copied from.
+ /// \param to The target space.
+ /// \param stream The stream where the copy happens, can be NULL.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMArrayCopyFromTo(
+ from: TVMArrayHandle,
+ to: TVMArrayHandle,
+ stream: TVMStreamHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Produce an array from the DLManagedTensor that shares data memory
+ /// with the DLManagedTensor.
+ /// \param from The source DLManagedTensor.
+ /// \param out The output array handle.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMArrayFromDLPack(
+ from: *mut DLManagedTensor,
+ out: *mut TVMArrayHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Produce a DLMangedTensor from the array that shares data memory with
+ /// the array.
+ /// \param from The source array.
+ /// \param out The DLManagedTensor handle.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMArrayToDLPack(
+ from: TVMArrayHandle,
+ out: *mut *mut DLManagedTensor,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Delete (free) a DLManagedTensor's data.
+ /// \param dltensor Pointer to the DLManagedTensor.
+ pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor);
+}
+extern "C" {
+ /// \brief Create a new runtime stream.
+ ///
+ /// \param device_type The device type of context
+ /// \param device_id The device id of context
+ /// \param out The new stream handle
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMStreamCreate(
+ device_type: ::std::os::raw::c_int,
+ device_id: ::std::os::raw::c_int,
+ out: *mut TVMStreamHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Free a created stream handle.
+ ///
+ /// \param device_type The device type of context
+ /// \param device_id The device id of context
+ /// \param stream The stream to be freed
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMStreamFree(
+ device_type: ::std::os::raw::c_int,
+ device_id: ::std::os::raw::c_int,
+ stream: TVMStreamHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Set the runtime stream of current thread to be stream.
+ /// The subsequent calls to the same device_type
+    /// will use the stream handle that was set.
+ /// The specific type of stream is runtime device dependent.
+ ///
+ /// \param device_type The device type of context
+ /// \param device_id The device id of context.
+ /// \param handle The stream handle.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMSetStream(
+ device_type: ::std::os::raw::c_int,
+ device_id: ::std::os::raw::c_int,
+ handle: TVMStreamHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Wait until all computations on stream completes.
+ ///
+ /// \param device_type The device type of context
+ /// \param device_id The device id of context.
+ /// \param stream The stream to be synchronized.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMSynchronize(
+ device_type: ::std::os::raw::c_int,
+ device_id: ::std::os::raw::c_int,
+ stream: TVMStreamHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Synchronize two streams of execution.
+ ///
+ /// \param device_type The device type of context
+ /// \param device_id The device id of context
+ /// \param src The source stream to synchronize.
+ /// \param dst The destination stream to synchronize.
+ /// \return 0 when success, -1 when failure happens
+ pub fn TVMStreamStreamSynchronize(
+ device_type: ::std::os::raw::c_int,
+ device_id: ::std::os::raw::c_int,
+ src: TVMStreamHandle,
+ dst: TVMStreamHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Backend function for modules to get function
+ /// from its environment mod_node (its imports and global function).
+    /// The user should not call TVMFuncFree on func.
+ ///
+ /// \param mod_node The module handle.
+ /// \param func_name The name of the function.
+ /// \param out The result function.
+ /// \return 0 when no error is thrown, -1 when failure happens
+ pub fn TVMBackendGetFuncFromEnv(
+ mod_node: *mut ::std::os::raw::c_void,
+ func_name: *const ::std::os::raw::c_char,
+ out: *mut TVMFunctionHandle,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Backend function to register system-wide library symbol.
+ ///
+ /// \param name The name of the symbol
+ /// \param ptr The symbol address.
+ /// \return 0 when no error is thrown, -1 when failure happens
+ pub fn TVMBackendRegisterSystemLibSymbol(
+ name: *const ::std::os::raw::c_char,
+ ptr: *mut ::std::os::raw::c_void,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Backend function to allocate temporal workspace.
+ ///
+    /// \note The resulting allocated space is guaranteed to be aligned to kTempAllocaAlignment.
+ ///
+ /// \param nbytes The size of the space requested.
+ /// \param device_type The device type which the space will be allocated.
+ /// \param device_id The device id which the space will be allocated.
+ /// \param dtype_code_hint The type code of the array elements. Only used in
+ /// certain backends such as OpenGL.
+ /// \param dtype_bits_hint The type bits of the array elements. Only used in
+ /// certain backends such as OpenGL.
+ /// \return nullptr when error is thrown, a valid ptr if success
+ pub fn TVMBackendAllocWorkspace(
+ device_type: ::std::os::raw::c_int,
+ device_id: ::std::os::raw::c_int,
+ nbytes: u64,
+ dtype_code_hint: ::std::os::raw::c_int,
+ dtype_bits_hint: ::std::os::raw::c_int,
+ ) -> *mut ::std::os::raw::c_void;
+}
+extern "C" {
+ /// \brief Backend function to free temporal workspace.
+ ///
+ /// \param ptr The result allocated space pointer.
+ /// \param device_type The device type which the space will be allocated.
+ /// \param device_id The device id which the space will be allocated.
+ /// \return 0 when no error is thrown, -1 when failure happens
+ ///
+ /// \sa TVMBackendAllocWorkspace
+ pub fn TVMBackendFreeWorkspace(
+ device_type: ::std::os::raw::c_int,
+ device_id: ::std::os::raw::c_int,
+ ptr: *mut ::std::os::raw::c_void,
+ ) -> ::std::os::raw::c_int;
+}
+/// \brief Environment for TVM parallel task.
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct TVMParallelGroupEnv {
+ /// \brief Auxiliary used for synchronization
+ pub sync_handle: *mut ::std::os::raw::c_void,
+ /// \brief total amount of task
+ pub num_task: i32,
+}
+/// \brief The callback function to execute a parallel lambda
+/// \param task_id the task id of the function.
+/// \param penv The parallel environment backs the execution.
+/// \param cdata The supporting closure data.
+pub type FTVMParallelLambda = ::std::option::Option<
+ unsafe extern "C" fn(
+ task_id: ::std::os::raw::c_int,
+ penv: *mut TVMParallelGroupEnv,
+ cdata: *mut ::std::os::raw::c_void,
+ ) -> ::std::os::raw::c_int,
+>;
+extern "C" {
+ /// \brief Backend function for running parallel jobs.
+ ///
+ /// \param flambda The parallel function to be launched.
+ /// \param cdata The closure data.
+ /// \param num_task Number of tasks to launch, can be 0, means launch
+ /// with all available threads.
+ ///
+ /// \return 0 when no error is thrown, -1 when failure happens
+ pub fn TVMBackendParallelLaunch(
+ flambda: FTVMParallelLambda,
+ cdata: *mut ::std::os::raw::c_void,
+ num_task: ::std::os::raw::c_int,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+    /// \brief BSP barrier between parallel threads
+ /// \param task_id the task id of the function.
+ /// \param penv The parallel environment backs the execution.
+ /// \return 0 when no error is thrown, -1 when failure happens
+ pub fn TVMBackendParallelBarrier(
+ task_id: ::std::os::raw::c_int,
+ penv: *mut TVMParallelGroupEnv,
+ ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+ /// \brief Simple static initialization function.
+ /// Run f once and set handle to be not null.
+ /// This function is mainly used for test purpose.
+ ///
+    /// \param handle A global address to indicate f
+    /// \param f The function to be run
+ /// \param cdata The closure data to pass to the function.
+ /// \param nbytes Number of bytes in the closure data.
+ /// \return 0 when no error is thrown, -1 when failure happens
+ pub fn TVMBackendRunOnce(
+ handle: *mut *mut ::std::os::raw::c_void,
+ f: ::std::option::Option<
+ unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int,
+ >,
+ cdata: *mut ::std::os::raw::c_void,
+ nbytes: ::std::os::raw::c_int,
+ ) -> ::std::os::raw::c_int;
+}
--- /dev/null
+//! Error types for `TVMArgValue` and `TVMRetValue` conversions.
+
+error_chain! {
+ errors {
+ TryFromTVMArgValueError(expected: String, actual: String) {
+ description("mismatched types while converting from TVMArgValue")
+ display("expected `{}` but given `{}`", expected, actual)
+ }
+
+ TryFromTVMRetValueError(expected: String, actual: String) {
+ description("mismatched types while downcasting TVMRetValue")
+ display("invalid downcast: expected `{}` but given `{}`", expected, actual)
+ }
+ }
+}
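+
+// A small sanity check, added for illustration, that the error messages render
+// exactly as declared in the `error_chain!` block above.
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn arg_value_error_display() {
+        let err: Error =
+            ErrorKind::TryFromTVMArgValueError("i64".to_string(), "float".to_string()).into();
+        assert_eq!(err.to_string(), "expected `i64` but given `float`");
+    }
+}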
--- /dev/null
+//! This crate contains the refactored basic components required
+//! for `runtime` and `frontend` TVM crates.
+
+#![crate_name = "tvm_common"]
+#![recursion_limit = "1024"]
+#![allow(non_camel_case_types, unused_imports)]
+#![feature(box_syntax, try_from)]
+
+#[macro_use]
+extern crate error_chain;
+
+/// Unified ffi module for both runtime and frontend crates.
+pub mod ffi {
+ #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]
+
+ #[cfg(feature = "frontend")]
+ pub extern crate tvm_sys as ts;
+
+ #[cfg(feature = "runtime")]
+ pub mod runtime {
+ use std::os::raw::{c_char, c_int, c_void};
+
+ include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
+
+ pub type BackendPackedCFunc = extern "C" fn(
+ args: *const TVMValue,
+ type_codes: *const c_int,
+ num_args: c_int,
+ ) -> c_int;
+ }
+}
+
+pub mod errors;
+pub mod ty;
+pub mod value;
+
+pub use errors::*;
+pub use ty::TVMTypeCode;
+pub use value::{TVMArgValue, TVMRetValue, TVMValue};
--- /dev/null
+//! This module contains `TVMTypeCode` and `TVMType` with some conversion methods.
+//!
+//! # Example
+//!
+//! ```
+//! let dtype = TVMType::from("float");
+//! println!("dtype is: {}", dtype);
+//! ```
+
+use std::{
+ ffi::{CStr, CString},
+ fmt::{self, Display, Formatter},
+};
+
+/// TVM type codes.
+#[repr(u32)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
+pub enum TVMTypeCode {
+ kDLInt = 0,
+ kDLUInt = 1,
+ kDLFloat = 2,
+ kHandle = 3,
+ kNull = 4,
+ kTVMType = 5,
+ kTVMContext = 6,
+ kArrayHandle = 7,
+ kNodeHandle = 8,
+ kModuleHandle = 9,
+ kFuncHandle = 10,
+ kStr = 11,
+ kBytes = 12,
+ kNDArrayContainer = 13,
+}
+
+impl Default for TVMTypeCode {
+ fn default() -> Self {
+ TVMTypeCode::kDLInt
+ }
+}
+
+impl From<TVMTypeCode> for i64 {
+ fn from(arg: TVMTypeCode) -> i64 {
+ match arg {
+ TVMTypeCode::kDLInt => 0,
+ TVMTypeCode::kDLUInt => 1,
+ TVMTypeCode::kDLFloat => 2,
+ TVMTypeCode::kHandle => 3,
+ TVMTypeCode::kNull => 4,
+ TVMTypeCode::kTVMType => 5,
+ TVMTypeCode::kTVMContext => 6,
+ TVMTypeCode::kArrayHandle => 7,
+ TVMTypeCode::kNodeHandle => 8,
+ TVMTypeCode::kModuleHandle => 9,
+ TVMTypeCode::kFuncHandle => 10,
+ TVMTypeCode::kStr => 11,
+ TVMTypeCode::kBytes => 12,
+ TVMTypeCode::kNDArrayContainer => 13,
+ }
+ }
+}
+
+impl Into<TVMTypeCode> for i64 {
+ fn into(self) -> TVMTypeCode {
+ match self {
+ 0 => TVMTypeCode::kDLInt,
+ 1 => TVMTypeCode::kDLUInt,
+ 2 => TVMTypeCode::kDLFloat,
+ 3 => TVMTypeCode::kHandle,
+ 4 => TVMTypeCode::kNull,
+ 5 => TVMTypeCode::kTVMType,
+ 6 => TVMTypeCode::kTVMContext,
+ 7 => TVMTypeCode::kArrayHandle,
+ 8 => TVMTypeCode::kNodeHandle,
+ 9 => TVMTypeCode::kModuleHandle,
+ 10 => TVMTypeCode::kFuncHandle,
+ 11 => TVMTypeCode::kStr,
+ 12 => TVMTypeCode::kBytes,
+ 13 => TVMTypeCode::kNDArrayContainer,
+ _ => unreachable!(),
+ }
+ }
+}
+
+impl Display for TVMTypeCode {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(
+ f,
+ "{}",
+ match self {
+ TVMTypeCode::kDLInt => "int",
+ TVMTypeCode::kDLUInt => "uint",
+ TVMTypeCode::kDLFloat => "float",
+ TVMTypeCode::kHandle => "handle",
+ TVMTypeCode::kNull => "null",
+ TVMTypeCode::kTVMType => "TVM type",
+ TVMTypeCode::kTVMContext => "TVM context",
+ TVMTypeCode::kArrayHandle => "Array handle",
+ TVMTypeCode::kNodeHandle => "Node handle",
+ TVMTypeCode::kModuleHandle => "Module handle",
+ TVMTypeCode::kFuncHandle => "Function handle",
+ TVMTypeCode::kStr => "string",
+ TVMTypeCode::kBytes => "bytes",
+ TVMTypeCode::kNDArrayContainer => "ndarray container",
+ }
+ )
+ }
+}
+
+macro_rules! impl_prim_type {
+ ($type:ty, $variant:ident) => {
+ impl<'a> From<&'a $type> for TVMTypeCode {
+ fn from(_arg: &$type) -> Self {
+ TVMTypeCode::$variant
+ }
+ }
+
+ impl<'a> From<&'a mut $type> for TVMTypeCode {
+ fn from(_arg: &mut $type) -> Self {
+ TVMTypeCode::$variant
+ }
+ }
+ };
+}
+
+impl_prim_type!(usize, kDLInt);
+impl_prim_type!(i64, kDLInt);
+impl_prim_type!(i32, kDLInt);
+impl_prim_type!(i16, kDLInt);
+impl_prim_type!(i8, kDLInt);
+
+impl_prim_type!(u64, kDLUInt);
+impl_prim_type!(u32, kDLUInt);
+impl_prim_type!(u16, kDLUInt);
+impl_prim_type!(u8, kDLUInt);
+
+impl_prim_type!(f64, kDLFloat);
+impl_prim_type!(f32, kDLFloat);
+
+impl_prim_type!(str, kStr);
+impl_prim_type!(CStr, kStr);
+impl_prim_type!(String, kStr);
+impl_prim_type!(CString, kStr);
+
+impl_prim_type!([u8], kBytes);
--- /dev/null
+//! This module provides the wrapped `TVMValue`, `TVMArgValue` and `TVMRetValue`
+//! required for using TVM functions.
+
+use std::{
+ any::Any,
+ convert::TryFrom,
+ ffi::{CStr, CString},
+ fmt::{self, Debug, Formatter},
+ marker::PhantomData,
+ mem,
+ ops::Deref,
+ os::raw::{c_char, c_void},
+};
+
+#[cfg(feature = "runtime")]
+use ffi::runtime::TVMValue as _TVMValue;
+
+#[cfg(feature = "frontend")]
+use ffi::ts::TVMValue as _TVMValue;
+
+use errors::*;
+
+use ty::TVMTypeCode;
+
+/// Wrapped TVMValue type.
+#[derive(Clone, Copy)]
+pub struct TVMValue {
+ pub inner: _TVMValue,
+}
+
+impl TVMValue {
+ /// Creates TVMValue from the raw part.
+ pub fn new(inner: _TVMValue) -> Self {
+ TVMValue { inner }
+ }
+
+ pub(crate) fn into_raw(self) -> _TVMValue {
+ self.inner
+ }
+}
+
+impl Debug for TVMValue {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ unsafe {
+ write!(
+ f,
+ "TVMValue: [v_int64: {:?}], [v_float64: {:?}], [v_handle: {:?}],\
+ [v_str: {:?}]",
+ self.inner.v_int64, self.inner.v_float64, self.inner.v_handle, self.inner.v_str
+ )
+ }
+ }
+}
+
+impl Deref for TVMValue {
+ type Target = _TVMValue;
+ fn deref(&self) -> &Self::Target {
+ &self.inner
+ }
+}
+
+macro_rules! impl_prim_val {
+ ($type:ty, $field:ident, $cast:ty) => {
+ impl From<$type> for TVMValue {
+ fn from(arg: $type) -> Self {
+ let inner = _TVMValue {
+ $field: arg as $cast,
+ };
+ Self::new(inner)
+ }
+ }
+
+ impl<'a> From<&'a $type> for TVMValue {
+ fn from(arg: &$type) -> Self {
+ let inner = _TVMValue {
+ $field: *arg as $cast,
+ };
+ Self::new(inner)
+ }
+ }
+
+ impl<'a> From<&'a mut $type> for TVMValue {
+ fn from(arg: &mut $type) -> Self {
+ let inner = _TVMValue {
+ $field: *arg as $cast,
+ };
+ Self::new(inner)
+ }
+ }
+
+ impl TryFrom<TVMValue> for $type {
+ type Error = Error;
+ fn try_from(val: TVMValue) -> Result<Self> {
+ Ok(unsafe { val.inner.$field as $type })
+ }
+ }
+
+ impl<'a> TryFrom<&'a TVMValue> for $type {
+ type Error = Error;
+ fn try_from(val: &TVMValue) -> Result<Self> {
+ Ok(unsafe { val.into_raw().$field as $type })
+ }
+ }
+
+ impl<'a> TryFrom<&'a mut TVMValue> for $type {
+ type Error = Error;
+ fn try_from(val: &mut TVMValue) -> Result<Self> {
+ Ok(unsafe { val.into_raw().$field as $type })
+ }
+ }
+ };
+}
+
+impl_prim_val!(isize, v_int64, i64);
+impl_prim_val!(i64, v_int64, i64);
+impl_prim_val!(i32, v_int64, i64);
+impl_prim_val!(i16, v_int64, i64);
+impl_prim_val!(i8, v_int64, i64);
+impl_prim_val!(usize, v_int64, i64);
+impl_prim_val!(u64, v_int64, i64);
+impl_prim_val!(u32, v_int64, i64);
+impl_prim_val!(u16, v_int64, i64);
+impl_prim_val!(u8, v_int64, i64);
+
+impl_prim_val!(f64, v_float64, f64);
+impl_prim_val!(f32, v_float64, f64);
+
+impl<'a> From<&'a str> for TVMValue {
+ fn from(arg: &str) -> TVMValue {
+ let arg = CString::new(arg).unwrap();
+ let inner = _TVMValue {
+ v_str: arg.as_ptr() as *const c_char,
+ };
+ mem::forget(arg);
+ Self::new(inner)
+ }
+}
+
+impl<'a> From<&'a String> for TVMValue {
+ fn from(arg: &String) -> TVMValue {
+ let arg = CString::new(arg.as_bytes()).unwrap();
+ let inner = _TVMValue {
+ v_str: arg.as_ptr() as *const c_char,
+ };
+ mem::forget(arg);
+ Self::new(inner)
+ }
+}
+
+impl<'a> From<&'a CString> for TVMValue {
+ fn from(arg: &CString) -> TVMValue {
+ let arg = arg.to_owned();
+ let inner = _TVMValue {
+ v_str: arg.as_ptr() as *const c_char,
+ };
+ mem::forget(arg);
+ Self::new(inner)
+ }
+}
+
+impl<'a> From<&'a [u8]> for TVMValue {
+ fn from(arg: &[u8]) -> TVMValue {
+ let arg = arg.to_owned();
+ let inner = _TVMValue {
+ v_handle: &arg as *const _ as *mut c_void,
+ };
+ mem::forget(arg);
+ Self::new(inner)
+ }
+}
+
+/// Captures both `TVMValue` and `TVMTypeCode` needed for TVM function.
+/// The preferred way to obtain a `TVMArgValue` is automatically via `call_packed!`,
+/// or, in the frontend crate, with `function::Builder`. Check out the methods for conversions.
+///
+/// ## Example
+///
+/// ```
+/// let s = "hello".to_string();
+/// let arg = TVMArgValue::from(&s);
+/// let s2: String = (&arg).try_into().unwrap();
+/// assert_eq!(s2, s);
+/// ```
+#[derive(Debug, Clone, Copy)]
+pub struct TVMArgValue<'a> {
+ /// The wrapped TVMValue
+ pub value: TVMValue,
+ /// The matching type code.
+ pub type_code: TVMTypeCode,
+ /// This is only exposed to runtime and frontend crates and is not meant to be used directly.
+ pub lifetime: PhantomData<&'a ()>,
+}
+
+impl<'a> TVMArgValue<'a> {
+ pub fn new(value: TVMValue, type_code: TVMTypeCode) -> Self {
+ TVMArgValue {
+ value: value,
+ type_code: type_code,
+ lifetime: PhantomData,
+ }
+ }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for i64 {
+ type Error = Error;
+ fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+ if (arg.type_code == TVMTypeCode::kDLInt)
+ | (arg.type_code == TVMTypeCode::kDLUInt)
+ | (arg.type_code == TVMTypeCode::kNull)
+ {
+ Ok(unsafe { arg.value.inner.v_int64 })
+ } else {
+ bail!(ErrorKind::TryFromTVMArgValueError(
+ stringify!(i64).to_string(),
+ arg.type_code.to_string()
+ ))
+ }
+ }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for f64 {
+ type Error = Error;
+ fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+ if arg.type_code == TVMTypeCode::kDLFloat {
+ Ok(unsafe { arg.value.inner.v_float64 })
+ } else {
+ bail!(ErrorKind::TryFromTVMArgValueError(
+ stringify!(f64).to_string(),
+ arg.type_code.to_string()
+ ))
+ }
+ }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for String {
+ type Error = Error;
+ fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+ if arg.type_code == TVMTypeCode::kStr {
+ let ret_str = unsafe {
+ match CStr::from_ptr(arg.value.inner.v_str).to_str() {
+ Ok(s) => s,
+ Err(_) => "Invalid UTF-8 message",
+ }
+ };
+ Ok(ret_str.to_string())
+ } else {
+ bail!(ErrorKind::TryFromTVMArgValueError(
+ stringify!(String).to_string(),
+ arg.type_code.to_string()
+ ))
+ }
+ }
+}
+
+/// Main way to create a TVMArgValue from suported Rust values.
+impl<'b, 'a: 'b, T: 'b + ?Sized> From<&'b T> for TVMArgValue<'a>
+where
+ TVMValue: From<&'b T>,
+ TVMTypeCode: From<&'b T>,
+{
+ fn from(arg: &'b T) -> Self {
+ TVMArgValue::new(TVMValue::from(arg), TVMTypeCode::from(arg))
+ }
+}
+
+/// Creates a conversion to a `TVMArgValue` for an object handle.
+impl<'a, T> From<*const T> for TVMArgValue<'a> {
+ fn from(ptr: *const T) -> Self {
+ let value = TVMValue::new(_TVMValue {
+ v_handle: ptr as *mut T as *mut c_void,
+ });
+
+ TVMArgValue::new(value, TVMTypeCode::kArrayHandle)
+ }
+}
+
+/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
+impl<'a, T> From<*mut T> for TVMArgValue<'a> {
+ fn from(ptr: *mut T) -> Self {
+ let value = TVMValue::new(_TVMValue {
+ v_handle: ptr as *mut c_void,
+ });
+
+ TVMArgValue::new(value, TVMTypeCode::kHandle)
+ }
+}
+
+/// An owned version of TVMPODValue. It can be converted from a variety of
+/// primitive and object types.
+/// It can be downcast using `try_from` if it contains the desired type.
+///
+/// # Example
+///
+/// ```
+/// let a = 42u32;
+/// let b: u32 = TVMRetValue::from(a).try_into().unwrap();
+/// assert_eq!(b, a);
+///
+/// let s = "hello, world!".to_string();
+/// let t: TVMRetValue = s.clone().into();
+/// assert_eq!(String::try_from(t).unwrap(), s);
+/// ```
+pub struct TVMRetValue {
+ /// A primitive return value, if any.
+ pub prim_value: usize,
+ /// An object return value, if any.
+ pub box_value: Box<Any>,
+ pub type_code: TVMTypeCode,
+}
+
+impl TVMRetValue {
+ fn new(prim_value: usize, box_value: Box<Any>, type_code: TVMTypeCode) -> Self {
+ Self {
+ prim_value,
+ box_value,
+ type_code,
+ }
+ }
+
+ /// unsafe function to create `TVMRetValue` from `TVMValue` and
+ /// its matching `TVMTypeCode`.
+ pub unsafe fn from_tvm_value(value: TVMValue, type_code: TVMTypeCode) -> Self {
+ let value = value.into_raw();
+ match type_code {
+ TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => {
+ Self::new(value.v_int64 as usize, box (), type_code)
+ }
+ TVMTypeCode::kDLFloat => Self::new(value.v_float64 as usize, box (), type_code),
+ TVMTypeCode::kHandle
+ | TVMTypeCode::kArrayHandle
+ | TVMTypeCode::kNodeHandle
+ | TVMTypeCode::kModuleHandle
+ | TVMTypeCode::kFuncHandle => {
+ Self::new(value.v_handle as usize, box value.v_handle, type_code)
+ }
+ TVMTypeCode::kStr | TVMTypeCode::kBytes => {
+ Self::new(value.v_str as usize, box (value.v_str), type_code)
+ }
+ _ => Self::new(0usize, box (), type_code),
+ }
+ }
+
+ /// Returns the underlying `TVMValue` and `TVMTypeCode`.
+ pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
+ let val = match self.type_code {
+ TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => TVMValue::new(_TVMValue {
+ v_int64: self.prim_value as i64,
+ }),
+ TVMTypeCode::kDLFloat => TVMValue::new(_TVMValue {
+ v_float64: self.prim_value as f64,
+ }),
+ TVMTypeCode::kHandle
+ | TVMTypeCode::kArrayHandle
+ | TVMTypeCode::kNodeHandle
+ | TVMTypeCode::kModuleHandle
+ | TVMTypeCode::kFuncHandle
+ | TVMTypeCode::kNDArrayContainer => TVMValue::new(_TVMValue {
+ v_handle: self.prim_value as *const c_void as *mut c_void,
+ }),
+ TVMTypeCode::kStr | TVMTypeCode::kBytes => TVMValue::new(_TVMValue {
+ v_str: self.prim_value as *const c_char,
+ }),
+ _ => unreachable!(),
+ };
+ (val, self.type_code)
+ }
+}
+
+impl Default for TVMRetValue {
+ fn default() -> Self {
+ TVMRetValue {
+ prim_value: 0usize,
+ box_value: box (),
+ type_code: TVMTypeCode::default(),
+ }
+ }
+}
+
+impl Clone for TVMRetValue {
+ fn clone(&self) -> Self {
+ match self.type_code {
+ TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt | TVMTypeCode::kDLFloat => {
+ Self::new(self.prim_value.clone(), box (), self.type_code.clone())
+ }
+ TVMTypeCode::kHandle
+ | TVMTypeCode::kArrayHandle
+ | TVMTypeCode::kNodeHandle
+ | TVMTypeCode::kModuleHandle
+ | TVMTypeCode::kFuncHandle
+ | TVMTypeCode::kNDArrayContainer => Self::new(
+ self.prim_value.clone(),
+ box (self.prim_value.clone() as *const c_void as *mut c_void),
+ self.type_code.clone(),
+ ),
+ TVMTypeCode::kStr | TVMTypeCode::kBytes => Self::new(
+ self.prim_value.clone(),
+ box (self.prim_value.clone() as *const c_char),
+ self.type_code.clone(),
+ ),
+ _ => unreachable!(),
+ }
+ }
+}
+
+impl Debug for TVMRetValue {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(
+ f,
+ "prim_value: {:?}, box_value: {:?}, type_code: {:?}",
+ self.prim_value, self.prim_value as *const c_void as *mut c_void, self.type_code
+ )
+ }
+}
+
+macro_rules! impl_prim_ret_value {
+ ($type:ty, $code:expr) => {
+ impl From<$type> for TVMRetValue {
+ fn from(val: $type) -> Self {
+ TVMRetValue {
+ prim_value: val as usize,
+ box_value: box (),
+ type_code: $code,
+ }
+ }
+ }
+
+ impl<'a> From<&'a $type> for TVMRetValue {
+ fn from(val: &$type) -> Self {
+ TVMRetValue {
+ prim_value: *val as usize,
+ box_value: box (),
+ type_code: $code,
+ }
+ }
+ }
+
+ impl<'a> From<&'a mut $type> for TVMRetValue {
+ fn from(val: &mut $type) -> Self {
+ TVMRetValue {
+ prim_value: *val as usize,
+ box_value: box (),
+ type_code: $code,
+ }
+ }
+ }
+
+ impl TryFrom<TVMRetValue> for $type {
+ type Error = Error;
+ fn try_from(ret: TVMRetValue) -> Result<$type> {
+ if ret.type_code == $code {
+ Ok(ret.prim_value as $type)
+ } else {
+ bail!(ErrorKind::TryFromTVMRetValueError(
+ stringify!($type).to_string(),
+ ret.type_code.to_string(),
+ ))
+ }
+ }
+ }
+ };
+}
+
+impl_prim_ret_value!(i8, TVMTypeCode::kDLInt);
+impl_prim_ret_value!(i16, TVMTypeCode::kDLInt);
+impl_prim_ret_value!(i32, TVMTypeCode::kDLInt);
+impl_prim_ret_value!(i64, TVMTypeCode::kDLInt);
+impl_prim_ret_value!(isize, TVMTypeCode::kDLInt);
+
+impl_prim_ret_value!(u8, TVMTypeCode::kDLUInt);
+impl_prim_ret_value!(u16, TVMTypeCode::kDLUInt);
+impl_prim_ret_value!(u32, TVMTypeCode::kDLUInt);
+impl_prim_ret_value!(u64, TVMTypeCode::kDLUInt);
+impl_prim_ret_value!(usize, TVMTypeCode::kDLUInt);
+
+impl_prim_ret_value!(f32, TVMTypeCode::kDLFloat);
+impl_prim_ret_value!(f64, TVMTypeCode::kDLFloat);
+
+macro_rules! impl_ptr_ret_value {
+ ($type:ty) => {
+ impl From<$type> for TVMRetValue {
+ fn from(ptr: $type) -> Self {
+ TVMRetValue {
+ prim_value: ptr as usize,
+ box_value: box (),
+ type_code: TVMTypeCode::kHandle,
+ }
+ }
+ }
+
+ impl TryFrom<TVMRetValue> for $type {
+ type Error = Error;
+ fn try_from(ret: TVMRetValue) -> Result<$type> {
+ if ret.type_code == TVMTypeCode::kHandle {
+ Ok(ret.prim_value as $type)
+ } else {
+ bail!(ErrorKind::TryFromTVMRetValueError(
+ stringify!($type).to_string(),
+ ret.type_code.to_string(),
+ ))
+ }
+ }
+ }
+ };
+}
+
+impl_ptr_ret_value!(*const c_void);
+impl_ptr_ret_value!(*mut c_void);
+
+impl From<String> for TVMRetValue {
+    fn from(val: String) -> Self {
+        // Go through `CString` so the stored pointer is NUL-terminated, as
+        // expected by the `CStr::from_ptr` read in the `TryFrom` impl below.
+        let val = CString::new(val).unwrap();
+        let pval = val.as_ptr() as usize;
+        let bval = box (val.as_ptr());
+        mem::forget(val);
+ TVMRetValue::new(pval, bval, TVMTypeCode::kStr)
+ }
+}
+
+impl TryFrom<TVMRetValue> for String {
+ type Error = Error;
+ fn try_from(ret: TVMRetValue) -> Result<String> {
+ // Note: simple downcast doesn't work for function call return values
+ let ret_str = unsafe {
+ match CStr::from_ptr(ret.prim_value as *const c_char).to_str() {
+ Ok(s) => s,
+ Err(_) => "Invalid UTF-8 message",
+ }
+ };
+
+ Ok(ret_str.to_string())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::convert::TryInto;
+
+ #[test]
+ fn numeric() {
+ macro_rules! arg_ret_tests {
+ ($v:expr; $($ty:ty),+) => {{
+ $(
+ let v = $v as $ty;
+ let b = TVMRetValue::from(&v);
+ let b: $ty = b.try_into().unwrap();
+ assert_eq!(b, v);
+ )+
+ }};
+ }
+
+ arg_ret_tests!(42; i8, i16, i32, i64, f32, f64);
+ }
+
+ #[test]
+ fn string() {
+ let s = "hello".to_string();
+ let tvm_arg: String = TVMRetValue::from(s.clone()).try_into().unwrap();
+ assert_eq!(tvm_arg, s);
+ }
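+
+    // Added for illustration: round-trip a primitive through the raw
+    // `TVMValue`/`TVMTypeCode` pair via `from_tvm_value`/`into_tvm_value`.
+    #[test]
+    fn tvm_value_round_trip() {
+        let val = TVMValue::from(7i64);
+        let ret = unsafe { TVMRetValue::from_tvm_value(val, TVMTypeCode::kDLInt) };
+        let (val, code) = ret.into_tvm_value();
+        assert_eq!(unsafe { val.inner.v_int64 }, 7);
+        assert_eq!(code, TVMTypeCode::kDLInt);
+    }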
+}
--- /dev/null
+[package]
+name = "tvm-sys"
+version = "0.1.0"
+authors = ["TVM Contributors"]
+license = "Apache-2.0"
+description = "Raw C API"
+
+[build-dependencies]
+bindgen = "0.37.4"
--- /dev/null
+extern crate bindgen;
+
+use std::path::PathBuf;
+
+fn main() {
+ println!("cargo:rerun-if-env-changed=TVM_HOME");
+ println!("cargo:rustc-link-lib=dylib=tvm_runtime");
+ println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
+ let bindings = bindgen::Builder::default()
+ .header(format!(
+ "{}/include/tvm/runtime/c_runtime_api.h",
+ env!("TVM_HOME")
+ ))
+ .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
+ .blacklist_type("max_align_t") // @see rust-bindgen#550
+ .layout_tests(false)
+ .derive_partialeq(true)
+ .derive_eq(true)
+ .generate()
+ .expect("unable to generate bindings");
+
+ bindings
+ .write_to_file(PathBuf::from("src/bindgen.rs"))
+ .expect("can not write the bindings!");
+}
--- /dev/null
+#![allow(
+ non_camel_case_types,
+ non_snake_case,
+ non_upper_case_globals,
+ dead_code,
+ improper_ctypes
+)]
+
+include!("bindgen.rs");
--- /dev/null
+target
+**/*.rs.bk
+Cargo.lock
+/tests/basics/add_*
+/examples/resnet/deploy_*
+/examples/resnet/*.png
+/examples/resnet/synset.*
--- /dev/null
+language: rust
+rust:
+ - nightly
+matrix:
+ fast_finish: true
--- /dev/null
+[package]
+name = "tvm-frontend"
+version = "0.1.0"
+license = "Apache-2.0"
+description = "Rust frontend support for TVM"
+repository = "https://github.com/dmlc/tvm"
+homepage = "https://github.com/dmlc/tvm"
+readme = "README.md"
+keywords = ["rust", "tvm", "nnvm"]
+categories = ["api-bindings", "science"]
+authors = ["TVM Contributors"]
+
+[lib]
+name = "tvm_frontend"
+crate-type = ["dylib"]
+
+[dependencies]
+error-chain = "0.12.0"
+lazy_static = "1.1.0"
+ndarray = "0.12.1"
+num-traits = "0.2"
+tvm-common = { version = "0.1.0", path = "../common/", features = ["frontend"] }
+
+[features]
+blas = ["ndarray/blas"]
--- /dev/null
+# TVM Runtime Frontend Support
+
+This crate provides an idiomatic Rust API for the [TVM](https://github.com/dmlc/tvm) runtime frontend. It currently requires **Nightly Rust** and has been tested on `rustc 1.32.0-nightly`.
+
+## What Does This Crate Offer?
+
+Here is the typical workflow:
+
+1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/)
+2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators.
+3. Deploy your models using **Rust** :heart:
+
+### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k
+
+Please check out [examples/resnet](examples/resnet) for the complete end-to-end example.
+
+Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM
+
+```python
+block = get_model('resnet18_v1', pretrained=True)
+
+sym, params = nnvm.frontend.from_mxnet(block)
+# add the softmax layer for prediction
+net = nnvm.sym.softmax(sym)
+# compile the model
+with nnvm.compiler.build_config(opt_level=opt_level):
+ graph, lib, params = nnvm.compiler.build(
+ net, target, shape={"data": data_shape}, params=params)
+# save the model artifacts
+lib.save(os.path.join(target_dir, "deploy_lib.o"))
+cc.create_shared(os.path.join(target_dir, "deploy_lib.so"),
+ [os.path.join(target_dir, "deploy_lib.o")])
+
+with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo:
+ fo.write(graph.json())
+with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo:
+ fo.write(nnvm.compiler.save_param_dict(params))
+```
+
+Now we need to load these artifacts to create and run the *Graph Runtime* and classify our input cat image,
+as demonstrated in the following Rust snippet
+
+```rust
+ let graph = fs::read_to_string("deploy_graph.json")?;
+ // load the built module
+ let lib = Module::load(&Path::new("deploy_lib.so"))?;
+ // get the global TVM graph runtime function
+ let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
+ let runtime_create_fn_ret = call_packed!(
+ runtime_create_fn,
+ &graph,
+ &lib,
+ &ctx.device_type,
+ &ctx.device_id
+ )?;
+ // get graph runtime module
+ let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?;
+ // get the registered `load_params` from runtime module
+ let ref load_param_fn = graph_runtime_module
+ .get_function("load_params", false)
+ .unwrap();
+ // parse parameters and convert to TVMByteArray
+ let params: Vec<u8> = fs::read("deploy_param.params")?;
+    let barr = TVMByteArray::from(&params);
+ // load the parameters
+ call_packed!(load_param_fn, &barr)?;
+ // get the set_input function
+ let ref set_input_fn = graph_runtime_module
+ .get_function("set_input", false)
+ .unwrap();
+
+ call_packed!(set_input_fn, "data", &input)?;
+ // get `run` function from runtime module
+ let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
+ // execute the run function. Note that it has no argument
+ call_packed!(run_fn,)?;
+ // prepare to get the output
+ let output_shape = &mut [1, 1000];
+ let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
+ // get the `get_output` function from runtime module
+ let ref get_output_fn = graph_runtime_module
+ .get_function("get_output", false)
+ .unwrap();
+ // execute the get output function
+ call_packed!(get_output_fn, &0, &output)?;
+ // flatten the output as Vec<f32>
+ let output = output.to_vec::<f32>()?;
+```
+
+and the model correctly predicts the input image as **tiger cat**.
+
+## Installation
+
+Please follow the TVM [installation guide](https://docs.tvm.ai/install/index.html), then `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`.
+
+*Note:* To run the end-to-end examples and tests, `tvm`, `nnvm` and `topi` need to be on your `PYTHONPATH`; this happens automatically if they are installed in an Anaconda environment.
+
+## Supported TVM Functionalities
+
+### Use TVM to Generate Shared Library
+
+One can use the following Python snippet to generate `add_gpu.so`, which adds two vectors on the GPU.
+
+```python
+import os
+import tvm
+from tvm.contrib import cc
+
+def test_add(target_dir):
+ if not tvm.module.enabled("cuda"):
+ print(f"skip {__file__} because cuda is not enabled...")
+ return
+ n = tvm.var("n")
+ A = tvm.placeholder((n,), name='A')
+ B = tvm.placeholder((n,), name='B')
+ C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
+ s = tvm.create_schedule(C.op)
+ bx, tx = s[C].split(C.op.axis[0], factor=64)
+ s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+ s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+ fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
+
+ fadd_cuda.save(os.path.join(target_dir, "add_gpu.o"))
+ fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx"))
+ cc.create_shared(os.path.join(target_dir, "add_gpu.so"),
+ [os.path.join(target_dir, "add_gpu.o")])
+
+
+if __name__ == "__main__":
+ import sys
+ if len(sys.argv) != 2:
+ sys.exit(-1)
+ test_add(sys.argv[1])
+```
+
+### Run the Generated Shared Library
+
+The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust.
+
+```rust
+extern crate tvm_frontend as tvm;
+
+use std::path::Path;
+
+use tvm::*;
+
+fn main() {
+ let shape = &mut [2];
+ let mut data = vec![3f32, 4.0];
+ let mut arr = NDArray::empty(shape, TVMContext::gpu(0), TVMType::from("float32"));
+ arr.copy_from_buffer(data.as_mut_slice());
+ let mut ret = NDArray::empty(shape, TVMContext::gpu(0), TVMType::from("float32"));
+ let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap();
+ let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap();
+ assert!(fadd.enabled("gpu"));
+ fadd.import_module(fadd_dep);
+ fadd.entry();
+ function::Builder::from(&mut fadd)
+ .arg(&arr)
+ .arg(&arr)
+ .set_output(&mut ret).unwrap()
+ .invoke()
+ .unwrap();
+
+ assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]);
+}
+```
+
+**Note:** it is required to instruct `rustc` to link to the generated `add_gpu.so` at runtime, for example via
+`cargo:rustc-link-search=native=add_gpu`.
+
+See the custom `build.rs` used by the tests and examples for more details.
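+
+For instance, a minimal `build.rs` along the following lines emits that search path for Cargo (the `add_gpu` directory here is only an assumption for illustration; the actual tests and examples use their own paths):
+
+```rust
+// build.rs -- a minimal sketch; assumes `add_gpu.so` has been copied into an
+// `add_gpu` directory inside the crate (adjust the path as needed).
+fn main() {
+    println!(
+        "cargo:rustc-link-search=native={}/add_gpu",
+        std::env::var("CARGO_MANIFEST_DIR").unwrap()
+    );
+}
+```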
+
+### Convert and Register a Rust Function as a TVM Packed Function
+
+One can use the `register_global_func!` macro to convert and register a Rust
+function of type `fn(&[TVMArgValue]) -> Result<TVMRetValue>` as a global TVM **packed function**, as follows
+
+```rust
+#[macro_use]
+extern crate tvm_frontend as tvm;
+use std::convert::{TryFrom, TryInto};
+use tvm::*;
+
+fn main() {
+ register_global_func! {
+ fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ let mut ret = 0f32;
+ let shape = &mut [2];
+ for arg in args.iter() {
+ let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+ let arg: NDArray = arg.try_into()?;
+ let arr = arg.copy_to_ndarray(e).unwrap();
+ let rnd: ArrayD<f32> = ArrayD::try_from(&arr).unwrap();
+ ret += rnd.scalar_sum();
+ }
+ let ret_val = TVMRetValue::from(&ret);
+ Ok(ret_val)
+ }
+ }
+
+ let shape = &mut [2];
+ let mut data = vec![3f32, 4.0];
+ let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+ arr.copy_from_buffer(data.as_mut_slice());
+ let mut registered = function::Builder::default();
+ let ret: f64 = registered
+ .get_function("sum", true)
+ .arg(&arr)
+ .arg(&arr)
+ .invoke()
+ .unwrap()
+ .try_into()
+ .unwrap();
+
+ assert_eq!(ret, 14f64);
+}
+```
--- /dev/null
+[package]
+name = "resnet"
+version = "0.0.0"
+authors = ["TVM Contributors"]
+license = "Apache-2.0"
+build = "build.rs"
+
+[dependencies]
+ndarray = "0.12.1"
+tvm-frontend = { path = "../../" }
+image = "0.20.1"
+csv = "1"
--- /dev/null
+## Resnet example
+
+This end-to-end example shows how to:
+* build `Resnet 18` with `tvm` and `nnvm` from Python
+* use the provided Rust frontend API to test for an input image
+
+To run the example, `tvm`, `nnvm` and `mxnet` must first be installed for the Python build. To install MXNet for CPU, run `pip install mxnet`,
+and to install `tvm` and `nnvm` with `llvm` support, follow the [TVM installation guide](https://docs.tvm.ai/install/index.html).
+
+* **Build the example**: `cargo build`
+
+For a successful build, note that the Rust compiler must be instructed to link against the compiled shared library, for example with
+`println!("cargo:rustc-link-search=native={}", build_path)`. See `build.rs` for more details.
+
+* **Run the example**: `cargo run`
--- /dev/null
+use std::process::Command;
+
+fn main() {
+ let output = Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
+ .output()
+ .expect("Failed to execute command");
+ assert!(
+ std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_lib.o")).exists(),
+ "Could not prepare demo: {}",
+ String::from_utf8(output.stderr).unwrap().trim()
+ );
+ println!(
+ "cargo:rustc-link-search=native={}",
+ env!("CARGO_MANIFEST_DIR")
+ );
+}
--- /dev/null
+#!/usr/bin/env python3
+
+import argparse
+import csv
+import logging
+from os import path as osp
+import sys
+
+import numpy as np
+
+import mxnet as mx
+from mxnet.gluon.model_zoo.vision import get_model
+from mxnet.gluon.utils import download
+
+import tvm
+from tvm.contrib import graph_runtime, cc
+import nnvm
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+parser = argparse.ArgumentParser(description='Resnet build example')
+aa = parser.add_argument
+aa('--batch-size', type=int, default=1, help='input image batch size')
+aa('--opt-level', type=int, default=3,
+ help='level of optimization. 0 is unoptimized and 3 is the highest level')
+aa('--target', type=str, default='llvm', help='target context for compilation')
+aa('--image-shape', type=str, default='3,224,224', help='input image dimensions')
+aa('--image-name', type=str, default='cat.png', help='name of input image to download')
+args = parser.parse_args()
+
+target_dir = osp.dirname(osp.dirname(osp.realpath(__file__)))
+batch_size = args.batch_size
+opt_level = args.opt_level
+target = tvm.target.create(args.target)
+image_shape = tuple(map(int, args.image_shape.split(",")))
+data_shape = (batch_size,) + image_shape
+
+def build(target_dir):
+ """ Compiles resnet18 with TVM"""
+ deploy_lib = osp.join(target_dir, 'deploy_lib.o')
+ if osp.exists(deploy_lib):
+ return
+ # download the pretrained resnet18 trained on imagenet1k dataset for
+ # image classification task
+ block = get_model('resnet18_v1', pretrained=True)
+
+ sym, params = nnvm.frontend.from_mxnet(block)
+ # add the softmax layer for prediction
+ net = nnvm.sym.softmax(sym)
+ # compile the model
+ with nnvm.compiler.build_config(opt_level=opt_level):
+ graph, lib, params = nnvm.compiler.build(
+ net, target, shape={"data": data_shape}, params=params)
+ # save the model artifacts
+ lib.save(deploy_lib)
+ cc.create_shared(osp.join(target_dir, "deploy_lib.so"),
+ [osp.join(target_dir, "deploy_lib.o")])
+
+ with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo:
+ fo.write(graph.json())
+
+ with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo:
+ fo.write(nnvm.compiler.save_param_dict(params))
+
+def download_img_labels():
+ """ Download an image and imagenet1k class labels for test"""
+ img_name = 'cat.png'
+ synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
+ '4d0b62f3d01426887599d4f7ede23ee5/raw/',
+ '596b27d23537e5a1b5751d2b0481ef172f58b539/',
+ 'imagenet1000_clsid_to_human.txt'])
+ synset_name = 'synset.txt'
+ download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name)
+ download(synset_url, synset_name)
+
+ with open(synset_name) as fin:
+ synset = eval(fin.read())
+
+ with open("synset.csv", "w") as fout:
+ w = csv.writer(fout)
+ w.writerows(synset.items())
+
+def test_build(target_dir):
+ """ Sanity check with random input"""
+ graph = open(osp.join(target_dir, "deploy_graph.json")).read()
+ lib = tvm.module.load(osp.join(target_dir, "deploy_lib.so"))
+ params = bytearray(open(osp.join(target_dir,"deploy_param.params"), "rb").read())
+ input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
+ ctx = tvm.cpu()
+ module = graph_runtime.create(graph, lib, ctx)
+ module.load_params(params)
+ module.run(data=input_data)
+ out = module.get_output(0).asnumpy()
+
+
+if __name__ == '__main__':
+ logger.info("building the model")
+ build(target_dir)
+ logger.info("build was successful")
+ logger.info("test the build artifacts")
+ test_build(target_dir)
+ logger.info("test was successful")
+ download_img_labels()
+ logger.info("image and synset downloads are successful")
--- /dev/null
+#![feature(try_from)]
+
+extern crate csv;
+extern crate image;
+extern crate ndarray;
+extern crate tvm_frontend as tvm;
+
+use std::{
+ collections::HashMap,
+ convert::TryInto,
+ fs::{self, File},
+ path::Path,
+};
+
+use image::{FilterType, GenericImageView};
+use ndarray::{Array, ArrayD, Axis};
+
+use tvm::*;
+
+fn main() {
+ let ctx = TVMContext::cpu(0);
+ let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap();
+ println!("original image dimensions: {:?}", img.dimensions());
+ // for bigger size images, one needs to first resize to 256x256
+ // with `img.resize_exact` method and then `image.crop` to 224x224
+ let img = img.resize(224, 224, FilterType::Nearest).to_rgb();
+ println!("resized image dimensions: {:?}", img.dimensions());
+ let mut pixels: Vec<f32> = vec![];
+ for pixel in img.pixels() {
+ let tmp = pixel.data;
+ // normalize the RGB channels using mean, std of imagenet1k
+ let tmp = [
+ (tmp[0] as f32 - 123.0) / 58.395, // R
+ (tmp[1] as f32 - 117.0) / 57.12, // G
+ (tmp[2] as f32 - 104.0) / 57.375, // B
+ ];
+ for e in &tmp {
+ pixels.push(*e);
+ }
+ }
+
+ let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap();
+ let arr: ArrayD<f32> = arr.permuted_axes([2, 0, 1]).into_dyn();
+ // make arr shape as [1, 3, 224, 224] acceptable to resnet
+ let arr = arr.insert_axis(Axis(0));
+ // create input tensor from rust's ndarray
+ let input =
+ NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
+ println!(
+ "input size is {:?}",
+ input.shape().expect("cannot get the input shape")
+ );
+ let graph =
+ fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap();
+ // load the built module
+ let lib = Module::load(&Path::new(concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/deploy_lib.so"
+ )))
+ .unwrap();
+ // get the global TVM graph runtime function
+ let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
+ let runtime_create_fn_ret = call_packed!(
+ runtime_create_fn,
+ &graph,
+ &lib,
+ &ctx.device_type,
+ &ctx.device_id
+ )
+ .unwrap();
+ // get graph runtime module
+ let graph_runtime_module: Module = runtime_create_fn_ret.try_into().unwrap();
+ // get the registered `load_params` from runtime module
+ let ref load_param_fn = graph_runtime_module
+ .get_function("load_params", false)
+ .unwrap();
+ // parse parameters and convert to TVMByteArray
+ let params: Vec<u8> =
+ fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap();
+ let barr = TVMByteArray::from(&params);
+ // load the parameters
+ call_packed!(load_param_fn, &barr).unwrap();
+ // get the set_input function
+ let ref set_input_fn = graph_runtime_module
+ .get_function("set_input", false)
+ .unwrap();
+
+ call_packed!(set_input_fn, "data", &input).unwrap();
+ // get `run` function from runtime module
+ let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
+ // execute the run function. Note that it has no argument
+ call_packed!(run_fn,).unwrap();
+ // prepare to get the output
+ let output_shape = &mut [1, 1000];
+ let output = NDArray::empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
+ // get the `get_output` function from runtime module
+ let ref get_output_fn = graph_runtime_module
+ .get_function("get_output", false)
+ .unwrap();
+ // execute the get output function
+ call_packed!(get_output_fn, &0, &output).unwrap();
+ // flatten the output as Vec<f32>
+ let output = output.to_vec::<f32>().unwrap();
+ // find the maximum entry in the output and its index
+ let mut argmax = -1;
+ let mut max_prob = 0.;
+ for i in 0..output.len() {
+ if output[i] > max_prob {
+ max_prob = output[i];
+ argmax = i as i32;
+ }
+ }
+ // create a hash map of (class id, class name)
+ let mut synset: HashMap<i32, String> = HashMap::new();
+ let file = File::open("synset.csv").unwrap();
+ let mut rdr = csv::ReaderBuilder::new()
+ .has_headers(true)
+ .from_reader(file);
+
+ for result in rdr.records() {
+ let record = result.unwrap();
+ let id: i32 = record[0].parse().unwrap();
+ let cls = record[1].to_string();
+ synset.insert(id, cls);
+ }
+
+ println!(
+ "input image belongs to the class `{}` with probability {}",
+ synset
+ .get(&argmax)
+ .expect("cannot find the class id for argmax"),
+ max_prob
+ );
+}
--- /dev/null
+//! Provides [`TVMByteArray`] used for passing the model parameters
+//! (stored as byte-array) to a runtime module.
+//!
+//! For more detail, please see the example `resnet` in `examples` repository.
+
+use std::os::raw::c_char;
+
+use crate::ts;
+
+/// A struct holding TVM byte-array.
+///
+/// ## Example
+///
+/// ```
+/// let v = b"hello".to_vec();
+/// let barr = TVMByteArray::from(&v);
+/// assert_eq!(barr.len(), v.len());
+/// assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]);
+/// ```
+#[derive(Debug, Clone)]
+pub struct TVMByteArray {
+ pub(crate) inner: ts::TVMByteArray,
+}
+
+impl TVMByteArray {
+ pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray {
+ TVMByteArray { inner: barr }
+ }
+
+ /// Gets the length of the underlying byte-array
+ pub fn len(&self) -> usize {
+ self.inner.size
+ }
+
+ /// Gets the underlying byte-array as `Vec<i8>`
+ pub fn data(&self) -> Vec<i8> {
+ unsafe {
+ let sz = self.len();
+ let mut ret_buf = Vec::with_capacity(sz);
+ ret_buf.set_len(sz);
+ self.inner.data.copy_to(ret_buf.as_mut_ptr(), sz);
+ ret_buf
+ }
+ }
+}
+
+impl<'a> From<&'a Vec<u8>> for TVMByteArray {
+ fn from(arg: &Vec<u8>) -> Self {
+ let barr = ts::TVMByteArray {
+ data: arg.as_ptr() as *const c_char,
+ size: arg.len(),
+ };
+ TVMByteArray::new(barr)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn convert() {
+ let v = vec![1u8, 2, 3];
+ let barr = TVMByteArray::from(&v);
+ assert_eq!(barr.len(), v.len());
+ assert_eq!(barr.data(), vec![1i8, 2, 3]);
+ let v = b"hello".to_vec();
+ let barr = TVMByteArray::from(&v);
+ assert_eq!(barr.len(), v.len());
+ assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]);
+ }
+}
--- /dev/null
+//! Provides [`TVMContext`] and related device specific queries.
+//!
+//! Create a new context by device type (cpu is 1) and device id.
+//!
+//! # Example
+//!
+//! ```
+//! let ctx = TVMContext::new(TVMDeviceType(1), 0);
+//! let cpu0 = TVMContext::cpu(0);
+//! assert_eq!(ctx, cpu0);
+//! ```
+//!
+//! Or from a supported device name.
+//!
+//! ```
+//! let cpu0 = TVMContext::from("cpu");
+//! println!("{}", cpu0);
+//! ```
+
+use std::{
+ fmt::{self, Display, Formatter},
+ os::raw::c_void,
+ ptr,
+};
+
+use crate::{function, ts, Result};
+
+/// Device type can be from a supported device name. See the supported devices
+/// in [TVM](https://github.com/dmlc/tvm).
+///
+/// ## Example
+///
+/// ```
+/// let cpu = TVMDeviceType::from("cpu");
+/// println!("device is: {}", cpu);
+/// ```
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct TVMDeviceType(pub usize);
+
+impl Default for TVMDeviceType {
+ /// default device is cpu.
+ fn default() -> Self {
+ TVMDeviceType(1)
+ }
+}
+
+impl From<TVMDeviceType> for ts::DLDeviceType {
+ fn from(device_type: TVMDeviceType) -> Self {
+ match device_type.0 {
+ 1 => ts::DLDeviceType_kDLCPU,
+ 2 => ts::DLDeviceType_kDLGPU,
+ 3 => ts::DLDeviceType_kDLCPUPinned,
+ 4 => ts::DLDeviceType_kDLOpenCL,
+ 7 => ts::DLDeviceType_kDLVulkan,
+ 8 => ts::DLDeviceType_kDLMetal,
+ 9 => ts::DLDeviceType_kDLVPI,
+ 10 => ts::DLDeviceType_kDLROCM,
+ 12 => ts::DLDeviceType_kDLExtDev,
+ _ => panic!("device type not found!"),
+ }
+ }
+}
+
+impl From<ts::DLDeviceType> for TVMDeviceType {
+ fn from(device_type: ts::DLDeviceType) -> Self {
+ match device_type {
+ ts::DLDeviceType_kDLCPU => TVMDeviceType(1),
+ ts::DLDeviceType_kDLGPU => TVMDeviceType(2),
+ ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
+ ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
+ ts::DLDeviceType_kDLVulkan => TVMDeviceType(7),
+ ts::DLDeviceType_kDLMetal => TVMDeviceType(8),
+ ts::DLDeviceType_kDLVPI => TVMDeviceType(9),
+ ts::DLDeviceType_kDLROCM => TVMDeviceType(10),
+ ts::DLDeviceType_kDLExtDev => TVMDeviceType(12),
+ _ => panic!("device type not found!"),
+ }
+ }
+}
+
+impl Display for TVMDeviceType {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(
+ f,
+ "{}",
+ match self {
+ TVMDeviceType(1) => "cpu",
+ TVMDeviceType(2) => "gpu",
+ TVMDeviceType(3) => "cpu_pinned",
+ TVMDeviceType(4) => "opencl",
+ TVMDeviceType(8) => "metal",
+ TVMDeviceType(9) => "vpi",
+ TVMDeviceType(10) => "rocm",
+ TVMDeviceType(_) => "rpc",
+ }
+ )
+ }
+}
+
+impl<'a> From<&'a str> for TVMDeviceType {
+ fn from(type_str: &'a str) -> Self {
+ match type_str {
+ "cpu" => TVMDeviceType(1),
+ "llvm" => TVMDeviceType(1),
+ "stackvm" => TVMDeviceType(1),
+ "gpu" => TVMDeviceType(2),
+ "cuda" => TVMDeviceType(2),
+ "nvptx" => TVMDeviceType(2),
+ "cl" => TVMDeviceType(4),
+ "opencl" => TVMDeviceType(4),
+ "metal" => TVMDeviceType(8),
+ "vpi" => TVMDeviceType(9),
+ "rocm" => TVMDeviceType(10),
+ _ => panic!("{:?} not supported!", type_str),
+ }
+ }
+}
+
+/// Represents the underlying device context. Default is cpu.
+///
+/// ## Examples
+///
+/// ```
+/// let ctx = TVMContext::from("gpu");
+/// assert!(ctx.exist());
+///
+/// ```
+///
+/// It is possible to query the underlying context as follows
+///
+/// ```
+/// println!("maximun threads per block: {}", ctx.max_threads_per_block());
+/// println!("compute version: {}", ctx.compute_version());
+/// ```
+#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
+pub struct TVMContext {
+ /// Supported device types
+ pub device_type: TVMDeviceType,
+ /// Device id
+ pub device_id: usize,
+}
+
+impl TVMContext {
+ /// Creates context from device type and id.
+ pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self {
+ TVMContext {
+ device_type: device_type,
+ device_id: device_id,
+ }
+ }
+}
+
+macro_rules! impl_ctxs {
+ ($(($ctx:ident, $dldevt:expr));+) => {
+ $(
+ impl TVMContext {
+ pub fn $ctx(device_id: usize) -> Self {
+ Self::new(TVMDeviceType($dldevt), device_id)
+ }
+ }
+ )+
+ };
+}
+
+impl_ctxs!((cpu, 1);
+ (gpu, 2);
+ (nvptx, 2);
+ (cuda, 2);
+ (cpu_pinned, 3);
+ (cl, 4);
+ (opencl, 4);
+ (metal, 8);
+ (vpi, 9);
+ (rocm, 10);
+ (opengl, 11);
+ (ext_dev, 12));
+
+impl<'a> From<&'a str> for TVMContext {
+ fn from(target: &str) -> Self {
+ TVMContext::new(TVMDeviceType::from(target), 0)
+ }
+}
+
+impl TVMContext {
+ /// Checks whether the context exists or not.
+ pub fn exist(&self) -> bool {
+ let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
+ .expect("API function always exists");
+ let dt = self.device_type.0 as usize;
+ // `unwrap` is ok here because if there is any error,
+ // it would occur inside `call_packed!`.
+ let ret = call_packed!(func, &dt, &self.device_id, &0)
+ .unwrap()
+ .prim_value;
+ ret != 0
+ }
+
+ /// Synchronize the context stream.
+ pub fn sync(&self) -> Result<()> {
+ check_call!(ts::TVMSynchronize(
+ self.device_type.0 as i32,
+ self.device_id as i32,
+ ptr::null_mut() as *mut c_void
+ ));
+ Ok(())
+ }
+}
+
+macro_rules! impl_device_attrs {
+ ($(($attr_name:ident, $attr_kind:expr));+) => {
+ $(
+ impl TVMContext {
+ pub fn $attr_name(&self) -> usize {
+ let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
+ .expect("API function always exists");
+ let dt = self.device_type.0 as usize;
+ // `unwrap` is ok here because if there is any error,
+ // it would occur in the function call.
+ let ret = function::Builder::from(func)
+ .args(&[dt, self.device_id, $attr_kind])
+ .invoke()
+ .unwrap();
+ ret.prim_value as usize
+ }
+ }
+ )+
+ };
+}
+
+impl_device_attrs!((max_threads_per_block, 1);
+ (warp_size, 2);
+ (max_shared_memory_per_block, 3);
+ (compute_version, 4);
+ (device_name, 5);
+ (max_clock_rate, 6);
+ (multi_processor_count, 7);
+ (max_thread_dimensions, 8));
+
+impl From<ts::DLContext> for TVMContext {
+ fn from(ctx: ts::DLContext) -> Self {
+ TVMContext {
+ device_type: TVMDeviceType::from(ctx.device_type),
+ device_id: ctx.device_id as usize,
+ }
+ }
+}
+
+impl From<TVMContext> for ts::DLContext {
+ fn from(ctx: TVMContext) -> Self {
+ ts::DLContext {
+ device_type: ctx.device_type.into(),
+ device_id: ctx.device_id as i32,
+ }
+ }
+}
+
+impl Display for TVMContext {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(f, "{}({})", self.device_type, self.device_id)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn context() {
+ let ctx = TVMContext::cpu(0);
+ println!("ctx: {}", ctx);
+ let default_ctx = TVMContext::new(TVMDeviceType(1), 0);
+ assert_eq!(ctx.clone(), default_ctx);
+ assert_ne!(ctx, TVMContext::gpu(0));
+
+ let str_ctx = TVMContext::new(TVMDeviceType::from("gpu"), 0);
+ assert_eq!(str_ctx.clone(), str_ctx);
+ assert_ne!(str_ctx, TVMContext::new(TVMDeviceType::from("cpu"), 0));
+ }
+
+ #[test]
+ fn sync() {
+ let ctx = TVMContext::cpu(0);
+ assert!(ctx.sync().is_ok())
+ }
+}
--- /dev/null
+//! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types.
+
+use std::{ffi, option};
+
+use crate::{common_errors, rust_ndarray};
+
+error_chain! {
+ errors {
+ EmptyArray {
+ description("cannot convert from an empty array")
+ }
+
+ NullHandle(name: String) {
+ description("null handle")
+ display("requested `{}` handle is null", name)
+ }
+
+ FunctionNotFound {
+ description("function not found")
+ display("function was not set in `function::Builder`")
+ }
+
+ TypeMismatch(expected: String, found: String) {
+ description("type mismatch!")
+ display("expected type `{}`, but found `{}`", expected, found)
+ }
+
+ MissingShapeError {
+ description("ndarray `shape()` returns `None`")
+ display("called `Option::unwrap()` on a `None` value")
+ }
+
+ AtMostOneReturn {
+ description("TVM functions accept at most one return value")
+ }
+
+ }
+
+ foreign_links {
+ ShapeError(rust_ndarray::ShapeError);
+ NulError(ffi::NulError);
+ IntoStringError(ffi::IntoStringError);
+ CommonError(common_errors::Error);
+ }
+}
+
+impl From<option::NoneError> for Error {
+ fn from(_err: option::NoneError) -> Self {
+ ErrorKind::MissingShapeError.into()
+ }
+}
--- /dev/null
+//! This module provides an idiomatic Rust API for creating and working with TVM functions.
+//!
+//! For calling an already registered TVM function, use [`function::Builder`].
+//! To register a TVM packed function from the Rust side, either
+//! use [`function::register`] or the macro [`register_global_func`].
+//!
+//! See the tests and examples repository for more examples.
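+//!
+//! ## Example
+//!
+//! A rough sketch of calling an already registered global function through [`Builder`],
+//! mirroring how [`Module::enabled`] queries the global `module._Enabled` function
+//! (error handling is elided with `unwrap`):
+//!
+//! ```
+//! use std::convert::TryInto;
+//!
+//! let func = Function::get("module._Enabled", true /* is_global */).unwrap();
+//! let ret: i64 = Builder::from(func)
+//!     .arg("llvm")
+//!     .invoke()
+//!     .unwrap()
+//!     .try_into()
+//!     .unwrap();
+//! let llvm_enabled = ret != 0;
+//! ```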
+
+use std::{
+ collections::BTreeMap,
+ ffi::{CStr, CString},
+ mem,
+ os::raw::{c_char, c_int, c_void},
+ ptr, slice, str,
+ sync::Mutex,
+};
+
+use crate::{ts, ErrorKind, Module, Result, TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue};
+
+lazy_static! {
+ static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = {
+ let mut out_size = 0 as c_int;
+ let name = ptr::null_mut() as *mut c_char;
+ let mut out_array = name as *mut _;
+ check_call!(ts::TVMFuncListGlobalNames(
+ &mut out_size as *mut _,
+ &mut out_array
+ ));
+ let names_list = unsafe { slice::from_raw_parts(out_array, out_size as usize) };
+ Mutex::new(
+ names_list
+ .into_iter()
+ .map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
+ .collect(),
+ )
+ };
+}
+
+/// Wrapper around a TVM function handle. It includes `is_global`,
+/// indicating whether the function is global, `is_released`, hinting
+/// whether the function handle has already been dropped, and `is_cloned`,
+/// marking that a function cloned on the Rust side must not be dropped again.
+/// The value of these fields can be accessed through their respective methods.
+#[derive(Debug, Hash)]
+pub struct Function {
+ pub(crate) handle: ts::TVMFunctionHandle,
+ // whether the registered function is global or not.
+ is_global: bool,
+ // whether the function has been dropped from frontend or not.
+ is_released: bool,
+ // whether the function has been cloned from frontend or not.
+ is_cloned: bool,
+}
+
+unsafe impl Send for Function {}
+unsafe impl Sync for Function {}
+
+impl Function {
+ pub(crate) fn new(handle: ts::TVMFunctionHandle, is_global: bool, is_released: bool) -> Self {
+ Function {
+ handle: handle,
+ is_global: is_global,
+ is_released: is_released,
+ is_cloned: false,
+ }
+ }
+
+ /// Returns a global TVM function by name, if it has been registered.
+ pub fn get<S: AsRef<str>>(name: S, is_global: bool) -> Option<&'static Function> {
+ let mut globals = GLOBAL_FUNCTIONS.lock().unwrap();
+ globals.get_mut(name.as_ref()).and_then(|maybe_func| {
+ if maybe_func.is_none() {
+ let name = CString::new(name.as_ref()).unwrap();
+ let mut handle = ptr::null_mut() as ts::TVMFunctionHandle;
+ check_call!(ts::TVMFuncGetGlobal(
+ name.as_ptr() as *const c_char,
+ &mut handle as *mut _
+ ));
+ maybe_func.replace(Function::new(
+ handle, is_global, false, /* is_released */
+ ));
+ }
+ unsafe {
+ std::mem::transmute::<Option<&Function>, Option<&'static Function>>(
+ maybe_func.as_ref(),
+ )
+ }
+ })
+ }
+
+ /// Returns the underlying TVM function handle.
+ pub fn handle(&self) -> ts::TVMFunctionHandle {
+ self.handle
+ }
+
+ /// Returns `true` if the underlying TVM function is global and `false` otherwise.
+ pub fn is_global(&self) -> bool {
+ self.is_global
+ }
+
+ /// Returns `true` if the underlying TVM function has been released
+ /// from the frontend and `false` otherwise.
+ pub fn is_released(&self) -> bool {
+ self.is_released
+ }
+
+ /// Returns `true` if the underlying TVM function has been cloned
+ /// from the frontend and `false` otherwise.
+ pub fn is_cloned(&self) -> bool {
+ self.is_cloned
+ }
+}
+
+impl Clone for Function {
+ fn clone(&self) -> Function {
+ if !self.is_released && !self.is_cloned {
+ Self {
+ handle: self.handle,
+ is_global: self.is_global,
+ is_released: self.is_released,
+ is_cloned: true,
+ }
+ } else {
+ Function::new(self.handle, self.is_global, self.is_released)
+ }
+ }
+}
+
+impl Drop for Function {
+ fn drop(&mut self) {
+ if !self.is_released && !self.is_global && !self.is_cloned {
+ check_call!(ts::TVMFuncFree(self.handle));
+ self.is_released = true;
+ }
+ }
+}
+
+/// Function builder in order to create and call functions.
+///
+/// *Note:* Currently TVM functions accept *at most* one return value.
+#[derive(Debug, Clone, Default)]
+pub struct Builder<'a, 'm> {
+ pub func: Option<&'m Function>,
+ pub arg_buf: Option<Box<[TVMArgValue<'a>]>>,
+ pub ret_buf: Option<TVMRetValue>,
+}
+
+impl<'a, 'm> Builder<'a, 'm> {
+ pub fn new(
+ func: Option<&'m Function>,
+ arg_buf: Option<Box<[TVMArgValue<'a>]>>,
+ ret_buf: Option<TVMRetValue>,
+ ) -> Self {
+ Self {
+ func,
+ arg_buf,
+ ret_buf,
+ }
+ }
+
+ pub fn get_function(&mut self, name: &'m str, is_global: bool) -> &mut Self {
+ self.func = Function::get(name, is_global);
+ self
+ }
+
+ /// Pushes a [`TVMArgValue`] into the function argument buffer.
+ pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self
+ where
+ TVMValue: From<&'b T>,
+ TVMTypeCode: From<&'b T>,
+ {
+ let tvm_arg = TVMArgValue::from(arg);
+ if self.arg_buf.is_none() {
+ self.arg_buf = Some(Box::new([tvm_arg]));
+ } else {
+ let new_arg_buf = self.arg_buf.take().map(|bbuf| {
+ let mut new_arg_buf = Vec::from(bbuf);
+ new_arg_buf.push(tvm_arg);
+ let new_len = new_arg_buf.len();
+ new_arg_buf.truncate(new_len);
+ new_arg_buf.into_boxed_slice()
+ });
+ self.arg_buf = new_arg_buf;
+ }
+ self
+ }
+
+ /// Pushes multiple [`TVMArgValue`]s into the function argument buffer.
+ pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self
+ where
+ I: IntoIterator<Item = &'b T>,
+ TVMValue: From<&'b T>,
+ TVMTypeCode: From<&'b T>,
+ {
+ for arg in args {
+ self.arg(&arg);
+ }
+ self
+ }
+
+ /// Sets an output for a function that requires a mutable output to be provided.
+ /// See the `basics` in tests for an example.
+ pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> Result<&mut Self>
+ where
+ TVMValue: From<&'b T>,
+ TVMTypeCode: From<&'b T>,
+ {
+ if self.ret_buf.is_none() {
+ let tvm_ret =
+ unsafe { TVMRetValue::from_tvm_value(TVMValue::from(arg), TVMTypeCode::from(arg)) };
+ self.ret_buf = Some(tvm_ret);
+ } else {
+ bail!(ErrorKind::AtMostOneReturn)
+ }
+ Ok(self)
+ }
+
+ /// Calls the function constructed by this `Builder`.
+ pub fn invoke(&mut self) -> Result<TVMRetValue> {
+ self.clone()(())
+ }
+}
+
+impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> {
+ type Output = Result<TVMRetValue>;
+ extern "rust-call" fn call_once(self, _: ((),)) -> Self::Output {
+ if self.func.is_none() {
+ bail!("{}", ErrorKind::FunctionNotFound);
+ }
+
+ let mut ret_val = unsafe { mem::uninitialized::<ts::TVMValue>() };
+ let mut ret_type_code = 0 as c_int;
+ if self.arg_buf.is_some() {
+ let arg_buf = self.arg_buf?;
+ let mut num_args = arg_buf.len();
+ let mut values = arg_buf
+ .iter()
+ .map(|tav| tav.value.inner)
+ .collect::<Vec<ts::TVMValue>>();
+ let mut tcodes = arg_buf
+ .iter()
+ .map(|tav| tav.type_code as c_int)
+ .collect::<Vec<_>>();
+
+ if self.ret_buf.is_some() {
+ num_args = num_args + 1;
+ let ret_buf = self.ret_buf?;
+ let (ret_val, ret_type_code) = TVMRetValue::into_tvm_value(ret_buf);
+ values.append(&mut vec![ret_val.inner]);
+ tcodes.append(&mut vec![ret_type_code as c_int]);
+ }
+
+ values.truncate(num_args);
+ tcodes.truncate(num_args);
+ check_call!(ts::TVMFuncCall(
+ self.func?.handle,
+ values.as_mut_ptr(),
+ tcodes.as_mut_ptr(),
+ num_args as c_int,
+ &mut ret_val as *mut _,
+ &mut ret_type_code as *mut _
+ ));
+ } else {
+ check_call!(ts::TVMFuncCall(
+ self.func?.handle,
+ ptr::null_mut(),
+ ptr::null_mut(),
+ 0 as c_int,
+ &mut ret_val as *mut _,
+ &mut ret_type_code as *mut _
+ ));
+ }
+
+ let ret = unsafe {
+ TVMRetValue::from_tvm_value(TVMValue::new(ret_val), (ret_type_code as i64).into())
+ };
+ Ok(ret)
+ }
+}
+
+/// Converts a [`Function`] to a builder. Currently, this is the best way to work with
+/// TVM functions.
+impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> {
+ fn from(func: &'m Function) -> Self {
+ Builder::new(Some(func), None, None)
+ }
+}
+
+/// Converts a mutable reference of a [`Module`] to [`Builder`].
+impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> {
+ fn from(module: &'m mut Module) -> Self {
+ Builder::new(module.entry(), None, None)
+ }
+}
+
+unsafe extern "C" fn tvm_callback(
+ args: *mut ts::TVMValue,
+ type_codes: *mut c_int,
+ num_args: c_int,
+ ret: ts::TVMRetValueHandle,
+ fhandle: *mut c_void,
+) -> c_int {
+ // turning off the incorrect linter complaints
+ #![allow(unused_assignments)]
+ let len = num_args as usize;
+ let args_list = slice::from_raw_parts_mut(args, len);
+ let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
+ let mut local_args: Vec<TVMArgValue> = Vec::new();
+ let mut value = mem::uninitialized::<ts::TVMValue>();
+ let mut tcode = mem::uninitialized::<c_int>();
+ let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
+ for i in 0..len {
+ value = args_list[i];
+ tcode = type_codes_list[i];
+ if tcode == ts::TVMTypeCode_kNodeHandle as c_int
+ || tcode == ts::TVMTypeCode_kFuncHandle as c_int
+ || tcode == ts::TVMTypeCode_kModuleHandle as c_int
+ {
+ check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode));
+ }
+ local_args.push(TVMArgValue::new(
+ TVMValue::new(value),
+ (tcode as i64).into(),
+ ));
+ }
+
+ let rv = match rust_fn(local_args.as_slice()) {
+ Ok(v) => v,
+ Err(msg) => {
+ crate::set_last_error(&msg);
+ return -1;
+ }
+ };
+
+ let (ret_val, ret_tcode) = TVMRetValue::into_tvm_value(rv);
+ let mut ret_val = ret_val.inner;
+ let mut ret_type_code = ret_tcode as c_int;
+ check_call!(ts::TVMCFuncSetReturn(
+ ret,
+ &mut ret_val as *mut _,
+ &mut ret_type_code as *mut _,
+ 1 as c_int
+ ));
+ 0
+}
+
+unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) {
+ let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
+ mem::drop(rust_fn);
+}
+
+fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function {
+ let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
+ let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue>;
+ check_call!(ts::TVMFuncCreateFromCFunc(
+ Some(tvm_callback),
+ resource_handle as *mut c_void,
+ Some(tvm_callback_finalizer),
+ &mut fhandle as *mut _
+ ));
+ Function::new(fhandle, false, false)
+}
+
+/// Registers a Rust function with signature
+/// `fn(&[TVMArgValue]) -> Result<TVMRetValue>`
+/// as a **global TVM packed function** from frontend to TVM backend.
+///
+/// Use [`register_global_func`] if overriding an existing global TVM function
+/// is not required.
+///
+/// ## Example
+///
+/// ```
+/// use std::convert::TryInto;
+///
+/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+/// let mut ret = 0i64;
+/// for arg in args.iter() {
+/// let arg: i64 = arg.try_into()?;
+/// ret += arg;
+/// }
+/// let ret_val = TVMRetValue::from(&ret);
+/// Ok(ret_val)
+/// }
+///
+/// tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
+/// let mut registered = function::Builder::default();
+/// registered.get_function("mysum", true);
+/// assert!(registered.func.is_some());
+/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
+/// assert_eq!(ret, 60);
+/// ```
+pub fn register<S: AsRef<str>>(
+ f: fn(&[TVMArgValue]) -> Result<TVMRetValue>,
+ name: S,
+ override_: bool,
+) -> Result<()> {
+ let func = convert_to_tvm_func(f);
+ let name = CString::new(name.as_ref())?;
+ check_call!(ts::TVMFuncRegisterGlobal(
+ name.as_ref().as_ptr() as *const c_char,
+ func.handle(),
+ override_ as c_int
+ ));
+ mem::forget(name);
+ Ok(())
+}
+
+/// Convenient macro for registering functions from frontend to backend as global
+/// TVM packed functions without overriding. If overriding an existing function is needed,
+/// use the [`function::register`] function instead.
+///
+/// ## Example
+///
+/// ```
+/// use std::convert::TryInto;
+///
+/// register_global_func! {
+/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+/// let mut ret = 0f64;
+/// for arg in args.iter() {
+/// let arg: f64 = arg.try_into()?;
+/// ret += arg;
+/// }
+/// let ret_val = TVMRetValue::from(&ret);
+/// Ok(ret_val)
+/// }
+/// }
+///
+/// let mut registered = function::Builder::default();
+/// registered.get_function("sum", true);
+/// assert!(registered.func.is_some());
+/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
+/// assert_eq!(ret, 60f64);
+/// ```
+#[macro_export]
+macro_rules! register_global_func {
+ {
+ $(#[$m:meta])*
+ fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue> {
+ $($code:tt)*
+ }
+ } => {{
+ $(#[$m])*
+ fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ $($code)*
+ }
+
+ $crate::function::register($fn_name, stringify!($fn_name).to_owned(), false).unwrap();
+ }}
+}
+
+/// Convenient macro for calling TVM packed functions by providing a
+/// function identifier and some arguments. This macro outputs a `Result` type
+/// and lets the user perform proper error handling.
+///
+/// **Note**: this macro does *not* expect an outside mutable output. To
+/// set mutable output use [`set_output`] directly in the builder pattern.
+///
+/// [`set_output`]:function/struct.Builder.html#method.set_output
+///
+/// ## Example
+///
+/// Instead of
+///
+/// ```
+/// function::Builder::from(func).arg(&a).arg(&b).invoke();
+/// ```
+///
+/// one can use
+///
+/// ```
+/// call_packed!(func, &a, &b);
+/// ```
+#[macro_export]
+macro_rules! call_packed {
+ ($fn_name:expr, $($arg:expr),*) => {{
+ let mut builder = $crate::function::Builder::from($fn_name);
+ $(
+ builder.arg($arg);
+ )*
+ builder.invoke()
+ }}
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ static CANARY: &str = "module._LoadFromFile";
+
+ #[test]
+ fn list_global_func() {
+ assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
+ }
+
+ #[test]
+ fn get_fn() {
+ assert!(Function::get(CANARY, true).is_some());
+ assert!(Function::get("does not exists!", false).is_none());
+ }
+
+ #[test]
+ fn provide_args() {
+ let mut func = Builder::default();
+ func.get_function("tvm.graph_runtime.remote_create", true)
+ .args(&[10, 20])
+ .arg(&"test".to_owned());
+ assert!(func.arg_buf.is_some());
+ assert_eq!(func.arg_buf.take().map(|bv| Vec::from(bv).len()), Some(3));
+ }
+}
--- /dev/null
+//! [TVM](https://github.com/dmlc/tvm) is a compiler stack for deep learning systems.
+//!
+//! This crate provides an idiomatic Rust API for the TVM runtime frontend.
+//!
+//! One particular use case is that, given optimized deep learning model artifacts
+//! compiled with TVM (a shared library `lib.so`, a `graph.json` and a byte-array
+//! `param.params`), one can load them idiomatically in Rust to create a TVM Graph
+//! Runtime, run the model on some inputs and get the desired predictions *all in Rust*.
+//!
+//! Check out the `examples` repository for more details.
+
+#![crate_name = "tvm_frontend"]
+#![recursion_limit = "1024"]
+#![allow(non_camel_case_types, unused_unsafe)]
+#![feature(
+ try_from,
+ try_trait,
+ fn_traits,
+ unboxed_closures,
+ box_syntax,
+ option_replace
+)]
+
+#[macro_use]
+extern crate error_chain;
+extern crate tvm_common as common;
+#[macro_use]
+extern crate lazy_static;
+extern crate ndarray as rust_ndarray;
+extern crate num_traits;
+
+use std::{
+ ffi::{CStr, CString},
+ str,
+};
+
+use crate::common::ffi::ts;
+
+// Macro to check the return code of a call into the TVM runtime shared library.
+macro_rules! check_call {
+ ($e:expr) => {{
+ if unsafe { $e } != 0 {
+ panic!("{}", $crate::get_last_error());
+ }
+ }};
+}
+
+/// Gets the last error message.
+pub fn get_last_error() -> &'static str {
+ unsafe {
+ match CStr::from_ptr(ts::TVMGetLastError()).to_str() {
+ Ok(s) => s,
+ Err(_) => "Invalid UTF-8 message",
+ }
+ }
+}
+
+pub(crate) fn set_last_error(err: &Error) {
+ let c_string = CString::new(err.to_string()).unwrap();
+ unsafe {
+ ts::TVMAPISetLastError(c_string.as_ptr());
+ }
+}
+
+#[macro_use]
+pub mod function;
+pub mod bytearray;
+pub mod context;
+pub mod errors;
+pub mod module;
+pub mod ndarray;
+pub mod ty;
+pub mod value;
+
+pub use crate::{
+ bytearray::TVMByteArray,
+ common::{
+ errors as common_errors,
+ ty::TVMTypeCode,
+ value::{TVMArgValue, TVMRetValue, TVMValue},
+ },
+ context::{TVMContext, TVMDeviceType},
+ errors::*,
+ function::Function,
+ module::Module,
+ ndarray::NDArray,
+ ty::TVMType,
+};
+
+/// Outputs the current TVM version.
+pub fn version() -> &'static str {
+ match str::from_utf8(ts::TVM_VERSION) {
+ Ok(s) => s,
+ Err(_) => "Invalid UTF-8 string",
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn print_version() {
+ println!("TVM version: {}", version());
+ }
+
+ #[test]
+ fn set_error() {
+ let err = ErrorKind::EmptyArray;
+ set_last_error(&err.into());
+ assert_eq!(get_last_error().trim(), ErrorKind::EmptyArray.to_string());
+ }
+}
--- /dev/null
+//! Provides the [`Module`] type and methods for working with runtime TVM modules.
+
+use std::{
+ convert::TryInto,
+ ffi::CString,
+ os::raw::{c_char, c_int},
+ path::Path,
+ ptr,
+};
+
+use crate::ts;
+
+use crate::{function::Function, ErrorKind, Result};
+
+const ENTRY_FUNC: &'static str = "__tvm_main__";
+
+/// Wrapper around a TVM module handle which contains an entry function.
+/// The entry function can be applied to an imported module through [`entry_func`].
+/// [`is_released`] shows whether the module has been dropped.
+///
+/// [`entry_func`]:struct.Module.html#method.entry_func
+/// [`is_released`]:struct.Module.html#method.is_released
+#[derive(Debug, Clone)]
+pub struct Module {
+ pub(crate) handle: ts::TVMModuleHandle,
+ is_released: bool,
+ entry_func: Option<Function>,
+}
+
+impl Module {
+ pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self {
+ Self {
+ handle,
+ is_released,
+ entry_func: None,
+ }
+ }
+
+ pub fn entry(&mut self) -> Option<&Function> {
+ if self.entry_func.is_none() {
+ self.entry_func = self.get_function(ENTRY_FUNC, false).ok();
+ }
+ self.entry_func.as_ref()
+ }
+
+ /// Gets a function by name from a registered module.
+ pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function> {
+ let name = CString::new(name)?;
+ let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
+ check_call!(ts::TVMModGetFunction(
+ self.handle,
+ name.as_ptr() as *const c_char,
+ query_import as c_int,
+ &mut fhandle as *mut _
+ ));
+ if fhandle.is_null() {
+ bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?)))
+ } else {
+ Ok(Function::new(fhandle, false, false))
+ }
+ }
+
+ /// Imports a dependent module such as `.ptx` for gpu.
+ pub fn import_module(&self, dependent_module: Module) {
+ check_call!(ts::TVMModImport(self.handle, dependent_module.handle))
+ }
+
+ /// Loads a module shared library from path.
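+ /// ## Example
+ ///
+ /// A sketch of loading a compiled module; the file name here is just an
+ /// example of a shared library produced by TVM.
+ ///
+ /// ```
+ /// let module = Module::load(&Path::new("deploy_lib.so")).unwrap();
+ /// ```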
+ pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module> {
+ let ext = path.as_ref().extension()?.to_str()?;
+ let func = Function::get("module._LoadFromFile", true /* is_global */)
+ .expect("API function always exists");
+ let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?;
+ Ok(ret)
+ }
+
+ /// Checks if a target device is enabled for a module.
+ pub fn enabled(&self, target: &str) -> bool {
+ let func = Function::get("module._Enabled", true /* is_global */)
+ .expect("API function always exists");
+ // `unwrap` is safe here because if there is any error during the
+ // function call, it would occur in `call_packed!`.
+ let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap();
+ ret != 0
+ }
+
+ /// Returns the underlying module handle.
+ pub fn handle(&self) -> ts::TVMModuleHandle {
+ self.handle
+ }
+
+ /// Returns true if the underlying module has been dropped and false otherwise.
+ pub fn is_released(&self) -> bool {
+ self.is_released
+ }
+}
+
+impl Drop for Module {
+ fn drop(&mut self) {
+ if !self.is_released {
+ check_call!(ts::TVMModFree(self.handle));
+ self.is_released = true;
+ }
+ }
+}
--- /dev/null
+//! This module implements the [`NDArray`] type for working with *TVM tensors* or
+//! converting from Rust's ndarray to a TVM `NDArray`.
+//!
+//! One can create an empty NDArray given the shape, context and dtype using [`empty`].
+//! To create an NDArray from a mutable buffer on cpu, use [`copy_from_buffer`].
+//! To copy an NDArray to a different context, use [`copy_to_ctx`].
+//!
+//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows:
+//!
+//! # Example
+//!
+//! ```
+//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
+//! .unwrap()
+//! .into_dyn(); // Rust's ndarray
+//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
+//! assert_eq!(nd.shape(), Some(&mut [2, 2]));
+//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
+//! assert!(rnd.all_close(&a, 1e-8f32));
+//! ```
+//!
+//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/
+//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
+//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
+
+use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice};
+
+use crate::rust_ndarray::{Array, ArrayD};
+use num_traits::Num;
+
+use crate::ts;
+
+use crate::{Error, ErrorKind, Result, TVMByteArray, TVMContext, TVMType};
+
+/// See the [`module-level documentation`](../ndarray/index.html) for more details.
+///
+/// Wrapper around TVM array handle.
+#[derive(Debug)]
+pub struct NDArray {
+ pub(crate) handle: ts::TVMArrayHandle,
+ is_view: bool,
+}
+
+impl NDArray {
+ pub(crate) fn new(handle: ts::TVMArrayHandle, is_view: bool) -> Self {
+ NDArray {
+ handle: handle,
+ is_view: is_view,
+ }
+ }
+
+ /// Returns the underlying array handle.
+ pub fn handle(&self) -> ts::TVMArrayHandle {
+ self.handle
+ }
+
+ pub fn is_view(&self) -> bool {
+ self.is_view
+ }
+
+ /// Returns the shape of the NDArray.
+ pub fn shape(&self) -> Option<&mut [usize]> {
+ let arr = unsafe { *(self.handle) };
+ if arr.shape.is_null() || arr.data.is_null() {
+ return None;
+ };
+ let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) };
+ Some(slc)
+ }
+
+ /// Returns the total number of entries of the NDArray.
+ pub fn size(&self) -> Option<usize> {
+ self.shape()
+ .map(|v| v.into_iter().fold(1, |acc, &mut e| acc * e))
+ }
+
+ /// Returns the context in which the NDArray was defined.
+ pub fn ctx(&self) -> TVMContext {
+ unsafe { (*self.handle).ctx.into() }
+ }
+
+ /// Returns the type of the entries of the NDArray.
+ pub fn dtype(&self) -> TVMType {
+ unsafe { (*self.handle).dtype.into() }
+ }
+
+ /// Returns the number of dimensions of the NDArray.
+ pub fn ndim(&self) -> usize {
+ unsafe { (*self.handle).ndim as usize }
+ }
+
+ /// Returns the strides of the underlying NDArray.
+ pub fn strides(&self) -> Option<&[usize]> {
+ unsafe {
+ let sz = self.ndim() * mem::size_of::<usize>();
+ let slc = slice::from_raw_parts((*self.handle).strides as *const usize, sz);
+ Some(slc)
+ }
+ }
+
+ /// Shows whether the underlying ndarray is contiguous in memory or not.
+ pub fn is_contiguous(&self) -> Result<bool> {
+ Ok(match self.strides() {
+ None => true,
+ Some(strides) => {
+ // MissingShapeError in case shape is not determined
+ self.shape()?
+ .iter()
+ .zip(strides)
+ .rfold(
+ (true, 1),
+ |(is_contig, expected_stride), (shape, stride)| {
+ (
+ is_contig && *stride == expected_stride,
+ expected_stride * (*shape as usize),
+ )
+ },
+ )
+ .0
+ }
+ })
+ }
+
+ pub fn byte_offset(&self) -> isize {
+ unsafe { (*self.handle).byte_offset as isize }
+ }
+
+ /// Flattens the NDArray to a `Vec` of the same type in cpu.
+ ///
+ /// ## Example
+ ///
+ /// ```
+ /// let shape = &mut [4];
+ /// let mut data = vec![1i32, 2, 3, 4];
+ /// let ctx = TVMContext::cpu(0);
+ /// let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("int32"));
+ /// ndarray.copy_from_buffer(&mut data);
+ /// assert_eq!(ndarray.shape(), Some(shape));
+ /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
+ /// ```
+ pub fn to_vec<T>(&self) -> Result<Vec<T>> {
+ if self.shape().is_none() {
+ bail!("{}", ErrorKind::EmptyArray);
+ }
+ let earr = NDArray::empty(self.shape()?, TVMContext::cpu(0), self.dtype());
+ let target = self.copy_to_ndarray(earr)?;
+ let arr = unsafe { *(target.handle) };
+ let sz = self.size()? as usize;
+ let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
+ unsafe {
+ v.as_mut_ptr()
+ .copy_from_nonoverlapping(arr.data as *const T, sz);
+ v.set_len(sz);
+ }
+ Ok(v)
+ }
+
+ /// Converts the NDArray to [`TVMByteArray`].
+ pub fn to_bytearray(&self) -> Result<TVMByteArray> {
+ let v = self.to_vec::<u8>()?;
+ Ok(TVMByteArray::from(&v))
+ }
+
+ /// Creates an NDArray from a mutable host buffer of types i32, u32 or f32.
+ ///
+ /// ## Example
+ ///
+ /// ```
+ /// let shape = &mut [2];
+ /// let mut data = vec![1f32, 2.];
+ /// let ctx = TVMContext::gpu(0);
+ /// let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("float32"));
+ /// ndarray.copy_from_buffer(&mut data);
+ /// ```
+ ///
+ /// *Note*: if something goes wrong during the copy, it will panic
+ /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
+ pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
+ check_call!(ts::TVMArrayCopyFromBytes(
+ self.handle,
+ data.as_ptr() as *mut _,
+ data.len() * mem::size_of::<T>()
+ ));
+ }
+
+ /// Copies the NDArray to another target NDArray.
+ pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray> {
+ if self.dtype() != target.dtype() {
+ bail!(
+ "{}",
+ ErrorKind::TypeMismatch(
+ format!("{}", self.dtype().to_string()),
+ format!("{}", target.dtype().to_string()),
+ )
+ );
+ }
+ check_call!(ts::TVMArrayCopyFromTo(
+ self.handle,
+ target.handle,
+ ptr::null_mut() as ts::TVMStreamHandle
+ ));
+ Ok(target)
+ }
+
+ /// Copies the NDArray to a target context.
+ pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray> {
+ let tmp = NDArray::empty(self.shape()?, target.clone(), self.dtype());
+ let copy = self.copy_to_ndarray(tmp)?;
+ Ok(copy)
+ }
+
+ /// Converts a Rust's ndarray to TVM NDArray.
+ pub fn from_rust_ndarray<T: Num32 + Copy>(
+ rnd: &ArrayD<T>,
+ ctx: TVMContext,
+ dtype: TVMType,
+ ) -> Result<Self> {
+ let mut shape = rnd.shape().to_vec();
+ let mut nd = NDArray::empty(&mut shape, ctx, dtype);
+ let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T));
+ nd.copy_from_buffer(buf.as_slice_mut()?);
+ Ok(nd)
+ }
+
+ /// Allocates and creates an empty NDArray given the shape, context and dtype.
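+ /// ## Example
+ ///
+ /// A minimal sketch of allocating an uninitialized 2x2 float32 tensor on cpu:
+ ///
+ /// ```
+ /// let nd = NDArray::empty(&[2, 2], TVMContext::cpu(0), TVMType::from("float32"));
+ /// assert_eq!(nd.ndim(), 2);
+ /// ```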
+ pub fn empty(shape: &[usize], ctx: TVMContext, dtype: TVMType) -> NDArray {
+ let mut handle = ptr::null_mut() as ts::TVMArrayHandle;
+ check_call!(ts::TVMArrayAlloc(
+ shape.as_ptr() as *const i64,
+ shape.len() as c_int,
+ dtype.inner.code as c_int,
+ dtype.inner.bits as c_int,
+ dtype.inner.lanes as c_int,
+ ctx.device_type.0 as c_int,
+ ctx.device_id as c_int,
+ &mut handle as *mut _,
+ ));
+ NDArray::new(handle, false)
+ }
+}
+
+macro_rules! impl_from_ndarray_rustndarray {
+ ($type:ty, $type_name:tt) => {
+ impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> {
+ type Error = Error;
+ fn try_from(nd: &NDArray) -> Result<ArrayD<$type>> {
+ if nd.shape().is_none() {
+ bail!("{}", ErrorKind::EmptyArray);
+ }
+ assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch");
+ Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?)
+ }
+ }
+
+ impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> {
+ type Error = Error;
+ fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>> {
+ if nd.shape().is_none() {
+ bail!("{}", ErrorKind::EmptyArray);
+ }
+ assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch");
+ Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?)
+ }
+ }
+ };
+}
+
+impl_from_ndarray_rustndarray!(i32, "int");
+impl_from_ndarray_rustndarray!(u32, "uint");
+impl_from_ndarray_rustndarray!(f32, "float");
+
+impl Drop for NDArray {
+ fn drop(&mut self) {
+ if !self.is_view {
+ check_call!(ts::TVMArrayFree(self.handle));
+ }
+ }
+}
+
+mod sealed {
+ /// Private trait to prevent other traits from being implemented in downstream crates.
+ pub trait Sealed {}
+}
+
+/// A trait for the supported 32-bit numerical types in the frontend.
+pub trait Num32: Num + sealed::Sealed {
+ const BITS: u8 = 32;
+}
+
+macro_rules! impl_num32 {
+ ($($type:ty),+) => {
+ $(
+ impl sealed::Sealed for $type {}
+ impl Num32 for $type {}
+ )+
+ };
+}
+
+impl_num32!(i32, u32, f32);
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn basics() {
+ let shape = &mut [1, 2, 3];
+ let ctx = TVMContext::cpu(0);
+ let ndarray = NDArray::empty(shape, ctx, TVMType::from("int32"));
+ assert_eq!(ndarray.shape().unwrap(), shape);
+ assert_eq!(
+ ndarray.size().unwrap(),
+ shape.to_vec().into_iter().product()
+ );
+ assert_eq!(ndarray.ndim(), 3);
+ assert!(ndarray.strides().is_none());
+ assert_eq!(ndarray.byte_offset(), 0);
+ }
+
+ #[test]
+ fn copy() {
+ let shape = &mut [4];
+ let mut data = vec![1i32, 2, 3, 4];
+ let ctx = TVMContext::cpu(0);
+ let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("int32"));
+ assert!(ndarray.to_vec::<i32>().is_ok());
+ ndarray.copy_from_buffer(&mut data);
+ assert_eq!(ndarray.shape().unwrap(), shape);
+ assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
+ assert_eq!(ndarray.ndim(), 1);
+ assert!(ndarray.is_contiguous().is_ok());
+ assert_eq!(ndarray.byte_offset(), 0);
+ let mut shape = vec![4];
+ let e = NDArray::empty(&mut shape, TVMContext::cpu(0), TVMType::from("int32"));
+ let nd = ndarray.copy_to_ndarray(e);
+ assert!(nd.is_ok());
+ assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
+ }
+
+ #[test]
+ #[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
+ fn copy_wrong_dtype() {
+ let mut shape = vec![4];
+ let mut data = vec![1f32, 2., 3., 4.];
+ let ctx = TVMContext::cpu(0);
+ let mut nd_float = NDArray::empty(&mut shape, ctx.clone(), TVMType::from("float32"));
+ nd_float.copy_from_buffer(&mut data);
+ let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from("int32"));
+ nd_float.copy_to_ndarray(empty_int).unwrap();
+ }
+
+ #[test]
+ fn rust_ndarray() {
+ let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
+ .unwrap()
+ .into_dyn();
+ let nd =
+ NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
+ assert_eq!(nd.shape().unwrap(), &mut [2, 2]);
+ let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
+ assert!(rnd.all_close(&a, 1e-8f32));
+ }
+}
--- /dev/null
+//! This module implements the required conversions from Rust types to TVM types.
+//!
+//! In the TVM frontend, only conversions from Rust's 32-bit (POD) numeric types (i32, u32, f32)
+//! and 64-bit pointers are supported.
+
+use std::{
+ fmt::{self, Display, Formatter},
+ ops::{Deref, DerefMut},
+};
+
+use crate::ts;
+
+use crate::{Function, Module, NDArray, TVMByteArray, TVMContext, TVMDeviceType, TVMTypeCode};
+
+macro_rules! impl_prim_type {
+ ($type:ty, $variant:ident) => {
+ impl From<$type> for TVMTypeCode {
+ fn from(_arg: $type) -> Self {
+ TVMTypeCode::$variant
+ }
+ }
+
+ impl<'a> From<&'a $type> for TVMTypeCode {
+ fn from(_arg: &$type) -> Self {
+ TVMTypeCode::$variant
+ }
+ }
+
+ impl<'a> From<&'a mut $type> for TVMTypeCode {
+ fn from(_arg: &mut $type) -> Self {
+ TVMTypeCode::$variant
+ }
+ }
+ };
+}
+
+impl_prim_type!(TVMDeviceType, kDLInt);
+impl_prim_type!(TVMContext, kTVMContext);
+impl_prim_type!(TVMType, kTVMType);
+impl_prim_type!(Function, kFuncHandle);
+impl_prim_type!(Module, kModuleHandle);
+impl_prim_type!(NDArray, kArrayHandle);
+impl_prim_type!(TVMByteArray, kBytes);
+
+/// See the [module-level documentation](../ty/index.html) for more details.
+///
+/// Wrapper around underlying TVMType
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub struct TVMType {
+ // inner fields are (code: u8, bits: u8, lanes: u16)
+ pub inner: ts::TVMType,
+}
+
+impl TVMType {
+ pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
+ TVMType {
+ inner: ts::TVMType {
+ code: type_code,
+ bits: bits,
+ lanes: lanes,
+ },
+ }
+ }
+}
+
+/// Implements TVMType conversion from `&str` of the general format `{dtype}{bits}x{lanes}`,
+/// such as "int32", "float32" or, with lanes, "float32x1".
+impl<'a> From<&'a str> for TVMType {
+ fn from(type_str: &'a str) -> Self {
+ if type_str == "bool" {
+ return TVMType::new(1, 1, 1);
+ }
+
+ let mut type_lanes = type_str.split("x");
+ let typ = type_lanes.next().expect("Missing dtype");
+ let lanes = type_lanes
+ .next()
+ .map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l)))
+ .unwrap_or(1);
+ let (type_name, bits) = match typ.find(char::is_numeric) {
+ Some(idx) => {
+ let (name, bits_str) = typ.split_at(idx);
+ (
+ name,
+ u8::from_str_radix(bits_str, 10)
+ .expect(&format!("Bad dtype bits: {}", bits_str)),
+ )
+ }
+ None => (typ, 32),
+ };
+
+ let type_code = match type_name {
+ "int" => 0,
+ "uint" => 1,
+ "float" => 2,
+ "handle" => 3,
+ _ => unimplemented!(),
+ };
+
+ TVMType::new(type_code, bits, lanes)
+ }
+}
+
+impl Display for TVMType {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ let ts::TVMType { code, bits, lanes } = self.inner;
+ if bits == 1 && lanes == 1 {
+ return write!(f, "bool");
+ }
+ let mut tcode_str = match code {
+ 0 => "int",
+ 1 => "uint",
+ 2 => "float",
+ 4 => "handle",
+ _ => "Unknown",
+ }
+ .to_string();
+
+ tcode_str += &bits.to_string();
+ if lanes > 1 {
+ tcode_str += &format!("x{}", lanes.to_string());
+ }
+ f.write_str(&tcode_str)
+ }
+}
+
+impl From<TVMType> for ts::DLDataType {
+ fn from(dtype: TVMType) -> Self {
+ dtype.inner
+ }
+}
+
+impl From<ts::DLDataType> for TVMType {
+ fn from(dtype: ts::DLDataType) -> Self {
+ Self::new(dtype.code, dtype.bits, dtype.lanes)
+ }
+}
+
+impl Deref for TVMType {
+ type Target = ts::TVMType;
+ fn deref(&self) -> &Self::Target {
+ &self.inner
+ }
+}
+
+impl DerefMut for TVMType {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.inner
+ }
+}
--- /dev/null
+//! This module implements [`TVMArgValue`] and [`TVMRetValue`] types
+//! and their conversions needed for the types used in the frontend crate.
+//! `TVMRetValue` is the owned version of `TVMPODValue`.
+
+use std::{convert::TryFrom, mem, os::raw::c_void};
+
+use crate::{
+ common_errors::*, ts, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext,
+ TVMDeviceType, TVMRetValue, TVMType, TVMTypeCode, TVMValue,
+};
+
+macro_rules! impl_tvm_val_from_handle {
+ ($($ty:ty),+) => {
+ $(
+ impl<'a> From<&'a $ty> for TVMValue {
+ fn from(arg: &$ty) -> Self {
+ let inner = ts::TVMValue {
+ v_handle: arg.handle as *mut _ as *mut c_void,
+ };
+ Self::new(inner)
+ }
+ }
+ )+
+ }
+}
+
+impl_tvm_val_from_handle!(Module, Function, NDArray);
+
+impl<'a> From<&'a TVMType> for TVMValue {
+ fn from(ty: &TVMType) -> Self {
+ let inner = ts::TVMValue { v_type: ty.inner };
+ Self::new(inner)
+ }
+}
+
+impl<'a> From<&'a TVMContext> for TVMValue {
+ fn from(ctx: &TVMContext) -> Self {
+ let inner = ts::TVMValue {
+ v_ctx: ctx.clone().into(),
+ };
+ Self::new(inner)
+ }
+}
+
+impl<'a> From<&'a TVMDeviceType> for TVMValue {
+ fn from(dev: &TVMDeviceType) -> Self {
+ let inner = ts::TVMValue {
+ v_int64: dev.0 as i64,
+ };
+ Self::new(inner)
+ }
+}
+
+impl<'a> From<&'a TVMByteArray> for TVMValue {
+ fn from(barr: &TVMByteArray) -> Self {
+ let inner = ts::TVMValue {
+ v_handle: &barr.inner as *const ts::TVMByteArray as *mut c_void,
+ };
+ Self::new(inner)
+ }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for NDArray {
+ type Error = Error;
+ fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+ if arg.type_code == TVMTypeCode::kArrayHandle {
+ let handle = unsafe { arg.value.inner.v_handle };
+ let arr_handle = unsafe { mem::transmute::<*mut c_void, ts::TVMArrayHandle>(handle) };
+ Ok(Self::new(arr_handle, true))
+ } else {
+ bail!(ErrorKind::TryFromTVMArgValueError(
+ stringify!(NDArray).to_string(),
+ arg.type_code.to_string()
+ ))
+ }
+ }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for Module {
+ type Error = Error;
+ fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+ if arg.type_code == TVMTypeCode::kModuleHandle {
+ let handle = unsafe { arg.value.inner.v_handle };
+ Ok(Self::new(handle, false))
+ } else {
+ bail!(ErrorKind::TryFromTVMArgValueError(
+ stringify!(Module).to_string(),
+ arg.type_code.to_string()
+ ))
+ }
+ }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMByteArray {
+ type Error = Error;
+ fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+ if arg.type_code == TVMTypeCode::kBytes {
+ unsafe {
+ let barr_ptr =
+ mem::transmute::<*mut c_void, *mut ts::TVMByteArray>(arg.value.inner.v_handle);
+ Ok(Self::new(*barr_ptr))
+ }
+ } else {
+ bail!(ErrorKind::TryFromTVMArgValueError(
+ stringify!(TVMByteArray).to_string(),
+ arg.type_code.to_string()
+ ))
+ }
+ }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMType {
+ type Error = Error;
+ fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+ if arg.type_code == TVMTypeCode::kTVMType {
+ let ty = unsafe { arg.value.inner.v_type };
+ Ok(TVMType::from(ty))
+ } else {
+ bail!(ErrorKind::TryFromTVMArgValueError(
+ stringify!(TVMType).to_string(),
+ arg.type_code.to_string()
+ ))
+ }
+ }
+}
+
+impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMContext {
+ type Error = Error;
+ fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
+ if arg.type_code == TVMTypeCode::kTVMContext {
+ let ty = unsafe { arg.value.inner.v_ctx };
+ Ok(TVMContext::from(ty))
+ } else {
+ bail!(ErrorKind::TryFromTVMArgValueError(
+ stringify!(TVMContext).to_string(),
+ arg.type_code.to_string()
+ ))
+ }
+ }
+}
+
+macro_rules! impl_boxed_ret_value {
+ ($type:ty, $code:expr) => {
+ impl From<$type> for TVMRetValue {
+ fn from(val: $type) -> Self {
+ TVMRetValue {
+ prim_value: 0,
+ box_value: box val,
+ type_code: $code,
+ }
+ }
+ }
+ impl TryFrom<TVMRetValue> for $type {
+ type Error = Error;
+ fn try_from(ret: TVMRetValue) -> Result<$type> {
+ if let Ok(val) = ret.box_value.downcast::<$type>() {
+ Ok(*val)
+ } else {
+ bail!(ErrorKind::TryFromTVMRetValueError(
+ stringify!($type).to_string(),
+ ret.type_code.to_string()
+ ))
+ }
+ }
+ }
+ };
+}
+
+impl_boxed_ret_value!(TVMType, TVMTypeCode::kTVMType);
+impl_boxed_ret_value!(TVMContext, TVMTypeCode::kTVMContext);
+impl_boxed_ret_value!(TVMByteArray, TVMTypeCode::kBytes);
+
+impl TryFrom<TVMRetValue> for Module {
+ type Error = Error;
+ fn try_from(ret: TVMRetValue) -> Result<Module> {
+ if let Ok(handle) = ret.box_value.downcast::<ts::TVMModuleHandle>() {
+ Ok(Module::new(*handle, false))
+ } else {
+ bail!(ErrorKind::TryFromTVMRetValueError(
+ stringify!(TVMTypeCode::kModuleHandle).to_string(),
+ ret.type_code.to_string()
+ ))
+ }
+ }
+}
+
+impl TryFrom<TVMRetValue> for Function {
+ type Error = Error;
+ fn try_from(ret: TVMRetValue) -> Result<Function> {
+ if let Ok(handle) = ret.box_value.downcast::<ts::TVMFunctionHandle>() {
+ Ok(Function::new(*handle, false, false))
+ } else {
+ bail!(ErrorKind::TryFromTVMRetValueError(
+ stringify!(TVMTypeCode::kFuncHandle).to_string(),
+ ret.type_code.to_string()
+ ))
+ }
+ }
+}
+
+impl TryFrom<TVMRetValue> for NDArray {
+ type Error = Error;
+ fn try_from(ret: TVMRetValue) -> Result<NDArray> {
+ if let Ok(handle) = ret.box_value.downcast::<ts::TVMArrayHandle>() {
+ Ok(NDArray::new(*handle, false))
+ } else {
+ bail!(ErrorKind::TryFromTVMRetValueError(
+ stringify!(TVMTypeCode::kArrayHandle).to_string(),
+ ret.type_code.to_string()
+ ))
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::convert::TryInto;
+
+ #[test]
+ fn bytearray() {
+ let w = vec![1u8, 2, 3, 4, 5];
+ let v = TVMByteArray::from(&w);
+ let tvm: TVMByteArray = TVMRetValue::from(v).try_into().unwrap();
+ assert_eq!(tvm.data(), w.iter().map(|e| *e as i8).collect::<Vec<i8>>());
+ }
+
+ #[test]
+ fn ty() {
+ let t = TVMType::from("int32");
+ let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap();
+ assert_eq!(tvm, t);
+ }
+
+ #[test]
+ fn ctx() {
+ let c = TVMContext::from("gpu");
+ let tvm: TVMContext = TVMRetValue::from(c).try_into().unwrap();
+ assert_eq!(tvm, c);
+ }
+}
--- /dev/null
+/target
+**/*.rs.bk
+Cargo.lock
+*.o
+*.so
+*.ptx
+*.json
--- /dev/null
+[package]
+name = "basics"
+version = "0.0.0"
+authors = ["TVM Contributors"]
+license = "Apache-2.0"
+build = "build.rs"
+
+[dependencies]
+ndarray = "0.12.1"
+tvm-frontend = { path = "../../" }
+
+[features]
+default = ["cpu"]
+cpu = []
+gpu = []
--- /dev/null
+fn main() {
+ let out_dir = std::env::var("OUT_DIR").unwrap();
+
+ let output = std::process::Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/tvm_add.py"))
+ .args(&[
+ if cfg!(feature = "cpu") {
+ "llvm"
+ } else {
+ "cuda"
+ },
+ &std::env::var("OUT_DIR").unwrap(),
+ ])
+ .output()
+ .expect("Failed to execute command");
+ assert!(
+ std::path::Path::new(&format!("{}/test_add.so", out_dir)).exists(),
+ "Could not build tvm lib: {}",
+ String::from_utf8(output.stderr)
+ .unwrap()
+ .trim()
+ .split("\n")
+ .last()
+ .unwrap_or("")
+ );
+
+ println!("cargo:rustc-link-search=native={}", out_dir);
+}
--- /dev/null
+extern crate ndarray as rust_ndarray;
+extern crate tvm_frontend as tvm;
+
+use tvm::*;
+
+fn main() {
+ let shape = &mut [2];
+ let mut data = vec![3f32, 4.0];
+
+ let (ctx, ctx_name) = if cfg!(feature = "cpu") {
+ (TVMContext::cpu(0), "cpu")
+ } else {
+ (TVMContext::gpu(0), "gpu")
+ };
+ let dtype = TVMType::from("float32");
+ let mut arr = NDArray::empty(shape, ctx, dtype);
+ arr.copy_from_buffer(data.as_mut_slice());
+ let mut ret = NDArray::empty(shape, ctx, dtype);
+ let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap();
+ if !fadd.enabled(ctx_name) {
+ return;
+ }
+ if cfg!(feature = "gpu") {
+ fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap());
+ }
+ function::Builder::from(&mut fadd)
+ .arg(&arr)
+ .arg(&arr)
+ .set_output(&mut ret)
+ .unwrap()
+ .invoke()
+ .unwrap();
+
+ assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]);
+}
--- /dev/null
+#!/usr/bin/env python3
+
+import os.path as osp
+import sys
+
+import tvm
+from tvm.contrib import cc
+
+
+def main(target, out_dir):
+ n = tvm.var('n')
+ A = tvm.placeholder((n,), name='A')
+ B = tvm.placeholder((n,), name='B')
+ C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
+ s = tvm.create_schedule(C.op)
+
+ if target == 'cuda':
+ bx, tx = s[C].split(C.op.axis[0], factor=64)
+ s[C].bind(bx, tvm.thread_axis('blockIdx.x'))
+ s[C].bind(tx, tvm.thread_axis('threadIdx.x'))
+
+ fadd = tvm.build(s, [A, B, C], target, target_host='llvm', name='myadd')
+
+ fadd.save(osp.join(out_dir, 'test_add.o'))
+ if target == 'cuda':
+ fadd.imported_modules[0].save(osp.join(out_dir, 'test_add.ptx'))
+ cc.create_shared(
+ osp.join(out_dir, 'test_add.so'), [osp.join(out_dir, 'test_add.o')])
+
+
+if __name__ == '__main__':
+ main(sys.argv[1], sys.argv[2])
+
--- /dev/null
+[package]
+name = "callback"
+version = "0.0.0"
+authors = ["TVM Contributors"]
+
+[dependencies]
+ndarray = "0.12.1"
+tvm-frontend = { path = "../../" }
--- /dev/null
+#![feature(extern_crate_item_prelude, try_from)]
+#![allow(unused_imports)]
+
+extern crate ndarray as rust_ndarray;
+#[macro_use]
+extern crate tvm_frontend as tvm;
+
+use rust_ndarray::ArrayD;
+use std::convert::{TryFrom, TryInto};
+
+use tvm::*;
+
+fn main() {
+ register_global_func! {
+ fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ let mut ret = 0f32;
+ let shape = &mut [2];
+ for arg in args.iter() {
+ let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+ let arg: NDArray = arg.try_into()?;
+ let arr = arg.copy_to_ndarray(e)?;
+ let rnd: ArrayD<f32> = ArrayD::try_from(&arr)?;
+ ret += rnd.scalar_sum();
+ }
+ Ok(TVMRetValue::from(ret))
+ }
+ }
+
+ let shape = &mut [2];
+ let mut data = vec![3f32, 4.0];
+ let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+ arr.copy_from_buffer(data.as_mut_slice());
+
+ let mut registered = function::Builder::default();
+ let ret: f32 = registered
+ .get_function("sum", true)
+ .arg(&arr)
+ .arg(&arr)
+ .invoke()
+ .unwrap()
+ .try_into()
+ .unwrap();
+ assert_eq!(ret, 14f32);
+}
--- /dev/null
+#![feature(extern_crate_item_prelude, panic_info_message)]
+#![allow(unused_imports)]
+
+use std::panic;
+
+#[macro_use]
+extern crate tvm_frontend as tvm;
+
+use tvm::*;
+
+fn main() {
+ register_global_func! {
+ fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ Err(ErrorKind::TypeMismatch(
+ format!("{}", "i64".to_string()),
+ format!("{}", "f64".to_string()),
+ ).into())
+ }
+ }
+
+ let mut registered = function::Builder::default();
+ registered.get_function("error", true);
+ assert!(registered.func.is_some());
+ registered.args(&[10, 20]);
+
+ println!("expected error message is:");
+ panic::set_hook(Box::new(|panic_info| {
+ if let Some(msg) = panic_info.message() {
+ println!("{:?}", msg);
+ }
+ if let Some(location) = panic_info.location() {
+ println!(
+ "panic occurred in file '{}' at line {}",
+ location.file(),
+ location.line()
+ );
+ } else {
+ println!("panic occurred but can't get location information");
+ }
+ }));
+
+ let _result = registered.invoke();
+}
--- /dev/null
+#![feature(extern_crate_item_prelude, try_from)]
+#![allow(unused_imports)]
+
+#[macro_use]
+extern crate tvm_frontend as tvm;
+
+use std::convert::TryInto;
+use tvm::*;
+
+fn main() {
+ register_global_func! {
+ fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ let mut ret = 0.0;
+ for arg in args.iter() {
+ let val: f64 = arg.try_into()?;
+ ret += val;
+ }
+ Ok(TVMRetValue::from(&ret))
+ }
+ }
+
+ let mut registered = function::Builder::default();
+ registered.get_function("sum", true);
+ assert!(registered.func.is_some());
+ let ret: f64 = registered
+ .args(&[10.0f64, 20.0, 30.0])
+ .invoke()
+ .unwrap()
+ .try_into()
+ .unwrap();
+ assert_eq!(ret, 60f64);
+}
--- /dev/null
+#![feature(extern_crate_item_prelude, try_from)]
+#![allow(unused_imports)]
+
+extern crate tvm_frontend as tvm;
+
+use std::convert::TryInto;
+use tvm::*;
+
+fn main() {
+ fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ let mut ret = 0i64;
+ for arg in args.iter() {
+ let val: i64 = arg.try_into()?;
+ ret += val;
+ }
+ Ok(TVMRetValue::from(&ret))
+ }
+
+ tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
+
+ let mut registered = function::Builder::default();
+ registered.get_function("mysum", true);
+ assert!(registered.func.is_some());
+ let ret: i64 = registered
+ .args(&[10, 20, 30])
+ .invoke()
+ .unwrap()
+ .try_into()
+ .unwrap();
+ assert_eq!(ret, 60);
+}
--- /dev/null
+#![feature(extern_crate_item_prelude, try_from)]
+#![allow(unused_imports)]
+
+#[macro_use]
+extern crate tvm_frontend as tvm;
+use std::convert::TryInto;
+use tvm::*;
+
+// FIXME
+fn main() {
+ register_global_func! {
+ fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ let mut ret = "".to_string();
+ for arg in args.iter() {
+ let val: String = arg.try_into()?;
+ ret += val.as_str();
+ }
+ Ok(TVMRetValue::from(ret))
+ }
+ }
+ let mut registered = function::Builder::default();
+ registered.get_function("concate_str", true);
+ assert!(registered.func.is_some());
+ let a = "a".to_string();
+ let b = "b".to_string();
+ let c = "c".to_string();
+ let ret: String = registered
+ .args(&[a, b, c])
+ .invoke()
+ .unwrap()
+ .try_into()
+ .unwrap();
+ assert_eq!(ret, "abc".to_owned());
+}
--- /dev/null
+Cargo.lock
+target/
+**/*.rs.bk
--- /dev/null
+language: rust
+rust:
+ - nightly
+matrix:
+ fast_finish: true
--- /dev/null
+[package]
+name = "tvm-runtime"
+version = "0.1.0"
+license = "Apache-2.0"
+description = "A static TVM runtime"
+repository = "https://github.com/dmlc/tvm"
+readme = "README.md"
+keywords = ["tvm", "nnvm"]
+categories = ["api-bindings", "science"]
+authors = ["TVM Contributors"]
+
+[features]
+default = ["nom/std"]
+sgx = ["nom/alloc"]
+
+[dependencies]
+bounded-spsc-queue = "0.4.0"
+error-chain = { version = "0.12.0", default-features = false }
+itertools = "0.7.8"
+lazy_static = "1.1.0"
+ndarray = "0.11.2"
+nom = {version = "4.0.0", default-features = false }
+serde = "1.0.59"
+serde_derive = "1.0.79"
+serde_json = "1.0.17"
+tvm-common = { version = "0.1.0", path = "../common/", features = ["runtime"] }
+
+[target.'cfg(not(target_env = "sgx"))'.dependencies]
+num_cpus = "1.8.0"
--- /dev/null
+#[cfg(target_env = "sgx")]
+use alloc::alloc::{self, Layout};
+#[cfg(not(target_env = "sgx"))]
+use std::alloc::{self, Layout};
+
+use crate::errors::*;
+
+const DEFAULT_ALIGN_BYTES: usize = 4;
+
+#[derive(PartialEq, Eq)]
+pub struct Allocation {
+ layout: Layout,
+ ptr: *mut u8,
+}
+
+impl Allocation {
+ /// Allocates a chunk of memory of `size` bytes with optional alignment.
+ pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
+ let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
+ let layout = Layout::from_size_align(size, alignment)?;
+ let ptr = unsafe { alloc::alloc(layout.clone()) };
+ if ptr.is_null() {
+ alloc::handle_alloc_error(layout);
+ }
+ Ok(Self {
+ ptr: ptr,
+ layout: layout,
+ })
+ }
+
+ pub fn as_mut_ptr(&self) -> *mut u8 {
+ self.ptr
+ }
+
+ /// Returns the size of the Allocation in bytes.
+ pub fn size(&self) -> usize {
+ self.layout.size()
+ }
+
+ /// Returns the byte alignment of the Allocation.
+ pub fn align(&self) -> usize {
+ self.layout.align()
+ }
+}
+
+impl Drop for Allocation {
+ fn drop(&mut self) {
+ unsafe {
+ alloc::dealloc(self.ptr, self.layout.clone());
+ }
+ }
+}
--- /dev/null
+use std::{
+ any::TypeId,
+ convert::TryFrom,
+ mem,
+ ops::{Deref, DerefMut},
+ os::raw::{c_int, c_void},
+ ptr, slice,
+};
+
+use ndarray;
+
+use crate::{
+ allocator::Allocation,
+ errors::*,
+ ffi::runtime::{
+ DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt,
+ DLDataTypeCode_kDLUInt, DLDeviceType_kDLCPU, DLTensor as _DLTensor,
+ },
+};
+
+/// A `Storage` is a container which holds `Tensor` data.
+#[derive(PartialEq)]
+pub enum Storage<'a> {
+ /// A `Storage` which owns its contained bytes.
+ Owned(Allocation),
+
+ /// A view of an existing `Storage`.
+ View(&'a mut [u8], usize), // ptr, align
+}
+
+impl<'a> Storage<'a> {
+ pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
+ Ok(Storage::Owned(Allocation::new(size, align)?))
+ }
+
+ pub fn as_mut_ptr(&self) -> *mut u8 {
+ match self {
+ Storage::Owned(alloc) => alloc.as_mut_ptr(),
+ Storage::View(slice, _) => slice.as_ptr() as *mut u8,
+ }
+ }
+
+ pub fn size(&self) -> usize {
+ match self {
+ Storage::Owned(alloc) => alloc.size(),
+ Storage::View(slice, _) => slice.len(),
+ }
+ }
+
+ pub fn align(&self) -> usize {
+ match self {
+ Storage::Owned(alloc) => alloc.align(),
+ Storage::View(_, align) => *align,
+ }
+ }
+
+ pub fn as_ptr(&self) -> *const u8 {
+ self.as_mut_ptr() as *const _
+ }
+
+ /// Returns a `Storage::View` over this `Storage`'s data, whether it is owned or already a view.
+ pub fn view(&self) -> Storage<'a> {
+ match self {
+ Storage::Owned(alloc) => Storage::View(
+ unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
+ self.align(),
+ ),
+ Storage::View(slice, _) => Storage::View(
+ unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
+ self.align(),
+ ),
+ }
+ }
+
+ pub fn is_owned(&self) -> bool {
+ match self {
+ Storage::Owned(_) => true,
+ _ => false,
+ }
+ }
+
+ /// Returns an owned version of this storage via cloning.
+ pub fn to_owned(&self) -> Storage<'static> {
+ let s = Storage::new(self.size(), Some(self.align())).unwrap();
+ unsafe {
+ s.as_mut_ptr()
+ .copy_from_nonoverlapping(self.as_ptr(), self.size());
+ }
+ s
+ }
+}
+
+impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
+ fn from(data: &'d [T]) -> Self {
+ let data = unsafe {
+ slice::from_raw_parts_mut(
+ data.as_ptr() as *const u8 as *mut u8,
+ data.len() * mem::size_of::<T>() as usize,
+ )
+ };
+ Storage::View(data, mem::align_of::<T>())
+ }
+}
+
+/// An n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
+/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
+/// converted to `ndarray::Array` for non-TVM processing.
+///
+/// # Examples
+///
+/// ```
+/// extern crate ndarray;
+///
+/// let mut a_nd = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// let mut a: Tensor = a_nd.into();
+/// let mut a_dl: DLTensor = (&mut a).into();
+/// call_packed!(tvm_fn, &mut a_dl);
+///
+/// // Tensor -> Array is mostly useful when post-processing TVM graph outputs.
+/// let mut a_nd = ndarray::Array::try_from(&a).unwrap();
+/// ```
+#[derive(PartialEq)]
+pub struct Tensor<'a> {
+ /// The bytes which contain the data this `Tensor` represents.
+ pub(crate) data: Storage<'a>,
+ pub(crate) ctx: TVMContext,
+ pub(crate) dtype: DataType,
+ pub(crate) shape: Vec<i64>,
+ // ^ not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
+ /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
+ pub(crate) strides: Option<Vec<usize>>,
+ pub(crate) byte_offset: isize,
+ /// The number of elements in the `Tensor`.
+ pub(crate) size: usize,
+}
+
+unsafe impl<'a> Send for Tensor<'a> {}
+
+impl<'a> Tensor<'a> {
+ pub fn shape(&self) -> Vec<i64> {
+ self.shape.clone()
+ }
+
+ /// Returns the data of this `Tensor` as a `Vec`.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
+ pub fn to_vec<T: 'static + std::fmt::Debug + Clone>(&self) -> Vec<T> {
+ assert!(self.is_contiguous());
+ assert!(self.dtype.is_type::<T>());
+ unsafe { slice::from_raw_parts(self.data.as_ptr() as *const T, self.size).to_vec() }
+ }
+
+ /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
+ pub fn is_contiguous(&self) -> bool {
+ match self.strides {
+ None => true,
+ Some(ref strides) => {
+ // check that stride for each dimension is the
+ // product of all trailing dimensions' shapes
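+ // e.g. a contiguous tensor of shape [2, 3, 4] has strides [12, 4, 1]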
+ self.shape
+ .iter()
+ .zip(strides)
+ .rfold(
+ (true, 1),
+ |(is_contig, expected_stride), (shape, stride)| {
+ (
+ is_contig && *stride == expected_stride,
+ expected_stride * (*shape as usize),
+ )
+ },
+ )
+ .0
+ }
+ }
+ }
+
+ /// Copies the data from `other` into this `Tensor`.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the tensors differ in dtype or size, or if either tensor is not contiguous.
+ pub fn copy(&mut self, other: &Tensor) {
+ assert!(
+ self.dtype == other.dtype && self.size == other.size,
+ "Tensor shape/dtype mismatch."
+ );
+ assert!(
+ self.is_contiguous() && other.is_contiguous(),
+ "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
+ self.strides,
+ other.strides
+ );
+ unsafe {
+ self.data
+ .as_mut_ptr()
+ .offset(self.byte_offset as isize)
+ .copy_from_nonoverlapping(
+ other.data.as_mut_ptr().offset(other.byte_offset),
+ other.size * other.dtype.itemsize(),
+ );
+ }
+ }
+
+ /// Returns an owned version of this `Tensor` via cloning.
+ pub fn to_owned(&self) -> Tensor<'static> {
+ let t = Tensor {
+ data: self.data.to_owned(),
+ ctx: self.ctx.clone(),
+ dtype: self.dtype.clone(),
+ size: self.size.clone(),
+ shape: self.shape.clone(),
+ strides: None,
+ byte_offset: 0,
+ };
+ unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
+ }
+
+ fn from_array_storage<'s, T, D: ndarray::Dimension>(
+ arr: &ndarray::Array<T, D>,
+ storage: Storage<'s>,
+ type_code: usize,
+ ) -> Tensor<'s> {
+ let type_width = mem::size_of::<T>() as usize;
+ Tensor {
+ data: storage,
+ ctx: TVMContext::default(),
+ dtype: DataType {
+ code: type_code,
+ bits: 8 * type_width,
+ lanes: 1,
+ },
+ size: arr.len(),
+ shape: arr.shape().iter().map(|&v| v as i64).collect(),
+ strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
+ byte_offset: 0,
+ }
+ }
+}
+
+/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
+macro_rules! impl_ndarray_try_from_tensor {
+ ($type:ty, $dtype:expr) => {
+ impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
+ type Error = Error;
+ fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
+ ensure!(
+ tensor.dtype == $dtype,
+ "Cannot convert Tensor with dtype {:?} to ndarray",
+ tensor.dtype
+ );
+ Ok(ndarray::Array::from_shape_vec(
+ tensor
+ .shape
+ .iter()
+ .map(|s| *s as usize)
+ .collect::<Vec<usize>>(),
+ tensor.to_vec::<$type>(),
+ )?)
+ }
+ }
+ };
+}
+
+impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
+impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
+impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
+impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
+
+pub struct DLTensor {
+ pub(crate) inner: _DLTensor,
+}
+
+impl Deref for DLTensor {
+ type Target = _DLTensor;
+ fn deref(&self) -> &Self::Target {
+ &self.inner
+ }
+}
+
+impl DerefMut for DLTensor {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.inner
+ }
+}
+
+impl DLTensor {
+ pub(crate) fn new(raw: _DLTensor) -> Self {
+ Self { inner: raw }
+ }
+
+ pub(crate) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
+ assert!(!flatten || tensor.is_contiguous());
+ Self {
+ inner: _DLTensor {
+ data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
+ ctx: DLContext::from(&tensor.ctx),
+ ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
+ dtype: DLDataType::from(&tensor.dtype),
+ shape: if flatten {
+ &tensor.size as *const _ as *mut i64
+ } else {
+ tensor.shape.as_ptr()
+ } as *mut i64,
+ strides: if flatten || tensor.is_contiguous() {
+ ptr::null_mut()
+ } else {
+ tensor.strides.as_ref().unwrap().as_ptr()
+ } as *mut i64,
+ byte_offset: 0,
+ },
+ }
+ }
+}
+
+impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
+ fn from(tensor: &'a Tensor<'t>) -> Self {
+ DLTensor::from_tensor(tensor, false /* flatten */)
+ }
+}
+
+impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
+ fn from(tensor: &'a mut Tensor<'t>) -> Self {
+ DLTensor::from_tensor(tensor, false /* flatten */)
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct DataType {
+ pub(crate) code: usize,
+ pub(crate) bits: usize,
+ pub(crate) lanes: usize,
+}
+
+impl DataType {
+ /// Returns the number of bytes occupied by an element of this `DataType`.
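+ ///
+ /// For example, a `float32x4` element occupies `(32 * 4) >> 3 = 16` bytes.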
+ pub fn itemsize(&self) -> usize {
+ (self.bits * self.lanes) >> 3
+ }
+
+ /// Returns whether this `DataType` represents primitive type `T`.
+ pub fn is_type<T: 'static>(&self) -> bool {
+ if self.lanes != 1 {
+ return false;
+ }
+ let typ = TypeId::of::<T>();
+ (typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
+ || (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
+ || (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
+ || (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
+ || (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
+ || (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
+ }
+}
+
+impl<'a> From<&'a DataType> for DLDataType {
+ fn from(dtype: &'a DataType) -> Self {
+ Self {
+ code: dtype.code as u8,
+ bits: dtype.bits as u8,
+ lanes: dtype.lanes as u16,
+ }
+ }
+}
+
+impl From<DLDataType> for DataType {
+ fn from(dtype: DLDataType) -> Self {
+ Self {
+ code: dtype.code as usize,
+ bits: dtype.bits as usize,
+ lanes: dtype.lanes as usize,
+ }
+ }
+}
+
+macro_rules! make_dtype_const {
+ ($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
+ const $name: DataType = DataType {
+ code: $code as usize,
+ bits: $bits,
+ lanes: $lanes,
+ };
+ };
+}
+
+make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
+make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
+// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
+make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
+make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct TVMContext {
+ pub(crate) device_type: usize,
+ pub(crate) device_id: usize,
+}
+
+impl<'a> From<&'a TVMContext> for DLContext {
+ fn from(ctx: &'a TVMContext) -> Self {
+ Self {
+ device_type: ctx.device_type as u32,
+ device_id: ctx.device_id as i32,
+ }
+ }
+}
+
+impl Default for TVMContext {
+ fn default() -> Self {
+ Self {
+ device_type: DLDeviceType_kDLCPU as usize,
+ device_id: 0,
+ }
+ }
+}
+
+impl<'a> From<DLTensor> for Tensor<'a> {
+ fn from(dlt: DLTensor) -> Self {
+ unsafe {
+ let dtype = DataType::from(dlt.dtype);
+ let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec();
+ let size = shape.iter().map(|v| *v as usize).product::<usize>() as usize;
+ let storage = Storage::from(slice::from_raw_parts(
+ dlt.data as *const u8,
+ dtype.itemsize() * size,
+ ));
+ Self {
+ data: storage,
+ ctx: TVMContext::default(),
+ dtype: dtype,
+ size: size,
+ shape: shape,
+ strides: if dlt.strides == ptr::null_mut() {
+ None
+ } else {
+ Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, dlt.ndim as usize).to_vec())
+ },
+ byte_offset: dlt.byte_offset as isize,
+ }
+ }
+ }
+}
+
+/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
+///
+/// # Panics
+///
+/// Panics if the ndarray is not contiguous.
+macro_rules! impl_tensor_from_ndarray {
+ ($type:ty, $typecode:expr) => {
+ impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
+ fn from(arr: ndarray::Array<$type, D>) -> Self {
+ let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous"));
+ Tensor::from_array_storage(&arr, storage.to_owned(), $typecode as usize)
+ }
+ }
+ impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
+ fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
+ let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous"));
+ Tensor::from_array_storage(arr, storage, $typecode as usize)
+ }
+ }
+ };
+}
+
+/// `From` conversions to `DLTensor` for `ndarray::Array`.
+/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
+macro_rules! impl_dltensor_from_ndarray {
+ ($type:ty, $typecode:expr) => {
+ impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
+ fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
+ DLTensor {
+ inner: _DLTensor {
+ data: arr.as_mut_ptr() as *mut c_void,
+ ctx: DLContext {
+ device_type: DLDeviceType_kDLCPU,
+ device_id: 0,
+ },
+ ndim: arr.ndim() as c_int,
+ dtype: DLDataType {
+ code: $typecode as u8,
+ bits: 8 * mem::size_of::<$type>() as u8,
+ lanes: 1,
+ },
+ shape: arr.shape().as_ptr() as *const i64 as *mut i64,
+ strides: arr.strides().as_ptr() as *const isize as *mut i64,
+ byte_offset: 0,
+ },
+ }
+ }
+ }
+ };
+}
+
+impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
+impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
+impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
+impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
+impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
+impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
+
+impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
+impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
+impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
+impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
+impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
+impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
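+
+// Illustrative use of the conversions generated above (a sketch, not from the original
+// sources); it assumes a contiguous `ndarray::Array`:
+//
+//     let mut a = ndarray::Array::from_vec(vec![1f32, 2., 3.]);
+//     let t: Tensor<'static> = a.clone().into(); // copies the data into an owned Storage
+//     let dlt = DLTensor::from(&mut a);          // borrows the ndarray's buffer in place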
--- /dev/null
+#[cfg(target_env = "sgx")]
+use alloc::alloc;
+#[cfg(not(target_env = "sgx"))]
+use std::alloc;
+use std::num;
+
+use crate::common::errors as common_errors;
+use ndarray;
+use serde_json;
+
+error_chain! {
+ errors {
+ GraphFormatError(msg: String) {
+ description("unable to load graph")
+ display("could not load graph json: {}", msg)
+ }
+
+ LoadGraphParamsError(msg: String) {
+ description("unable to load graph params")
+ display("could not load graph params: {}", msg)
+ }
+ }
+ foreign_links {
+ Alloc(alloc::AllocErr);
+ GraphDeserialize(serde_json::Error);
+ ParseInt(num::ParseIntError);
+ ShapeError(ndarray::ShapeError);
+ CommonError(common_errors::Error);
+ }
+}
+
+impl From<alloc::LayoutErr> for Error {
+ fn from(_err: alloc::LayoutErr) -> Error {
+ Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
+ }
+}
--- /dev/null
+use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
+
+use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
+use serde;
+use serde_json;
+
+use super::{DLTensor, DataType, Module, Storage, TVMContext, Tensor};
+use crate::{
+ common::value::TVMArgValue,
+ errors::{Error, ErrorKind, Result},
+ ffi::runtime::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt},
+};
+
+// @see `kTVMNDArrayMagic` in `ndarray.h`
+const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
+// @see `kTVMNDArrayListMagic` in `graph_runtime.h`
+const _NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7;
+
+/// A TVM computation graph.
+///
+/// # Examples
+///
+/// ```
+/// let graph_json = fs::read_to_string("graph.json").unwrap();
+/// let graph = Graph::try_from(&graph_json).unwrap();
+/// ```
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Graph {
+ pub nodes: Vec<Node>,
+ pub arg_nodes: Vec<usize>,
+ pub heads: Vec<Entry>,
+ pub node_row_ptr: Option<Vec<usize>>,
+ pub attrs: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Entry {
+ pub id: usize,
+ pub index: usize,
+ pub version: usize,
+}
+
+impl Graph {
+ fn entry_index(&self, entry: &Entry) -> Result<usize> {
+ self.node_row_ptr
+ .as_ref()
+ .map(|nrp| nrp[entry.id] + entry.index)
+ .ok_or("Missing node_row_ptr.".into())
+ }
+
+ /// Attempt to deserialize a JSON attribute to a type `T`.
+ fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
+ Ok(serde_json::from_value::<T>(
+ self.attrs
+ .as_ref()
+ .ok_or(ErrorKind::GraphFormatError(
+ "Missing graph attrs".to_string(),
+ ))?
+ .get(attr)
+ .ok_or(ErrorKind::GraphFormatError(format!(
+ "Missing {} attr",
+ attr
+ )))?
+ .to_owned(),
+ )?)
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Node {
+ pub op: String,
+ pub name: String,
+ pub inputs: Vec<Entry>,
+ pub attrs: Option<HashMap<String, String>>,
+ pub control_deps: Option<Vec<Entry>>,
+}
+
+struct NodeAttrs {
+ func_name: String,
+ num_outputs: usize,
+ flatten_data: bool,
+}
+
+impl Node {
+ fn parse_attrs(&self) -> Result<NodeAttrs> {
+ let attrs = self
+ .attrs
+ .as_ref()
+ .ok_or(format!("Missing node.attrs for `{}`", self.name))?;
+ let func_name = attrs
+ .get("func_name")
+ .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
+ .to_string();
+ let num_outputs = attrs
+ .get("num_outputs")
+ .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
+ .parse::<usize>()?;
+ let flatten_data = attrs
+ .get("flatten_data")
+ .ok_or(format!(
+ "Node `{}` is missing attrs.flatten_data",
+ self.name
+ ))?
+ .parse::<u8>()?
+ == 1;
+ Ok(NodeAttrs {
+ func_name,
+ num_outputs,
+ flatten_data,
+ })
+ }
+}
+
+impl<'a> TryFrom<&'a String> for Graph {
+ type Error = Error;
+ fn try_from(graph_json: &String) -> Result<Self> {
+ let graph = serde_json::from_str(graph_json)?;
+ Ok(graph)
+ }
+}
+
+impl<'a> TryFrom<&'a str> for Graph {
+ type Error = Error;
+ fn try_from(graph_json: &'a str) -> Result<Self> {
+ let graph = serde_json::from_str(graph_json)?;
+ Ok(graph)
+ }
+}
+
+/// An executor for a TVM computation graph.
+///
+/// # Examples
+///
+/// ```
+/// use ndarray::Array;
+///
+/// let syslib = SystemLibModule::default(); // a provider of TVM functions
+///
+/// let mut params_bytes = Vec::new();
+/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
+/// let params = tvm_runtime::load_param_dict(&params_bytes).unwrap();
+///
+/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap();
+///
+/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
+/// exec.load_params(params);
+///
+/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// exec.set_input("data", x.into());
+/// exec.run();
+/// let output = exec.get_output(0).unwrap();
+///
+/// println!("{:#?}", Array::try_from(output).unwrap());
+/// ```
+pub struct GraphExecutor<'m, 't> {
+ graph: Graph,
+ op_execs: Vec<Box<Fn() + 'm>>,
+ tensors: Vec<Tensor<'t>>,
+}
+
+unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
+
+impl<'m, 't> GraphExecutor<'m, 't> {
+ pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
+ let tensors = Self::setup_storages(&graph)?;
+ Ok(GraphExecutor {
+ op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
+ tensors: tensors,
+ graph: graph,
+ })
+ }
+
+ /// Runs the computation graph.
+ pub fn run(&self) {
+ self.op_execs.iter().for_each(|op_exec| {
+ op_exec();
+ });
+ }
+
+ /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
+ fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>> {
+ let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
+ let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
+ let dtypes = graph
+ .get_attr::<(String, Vec<String>)>("dltype")?
+ .1
+ .iter()
+ .map(|dltype| {
+ if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
+ Ok(dtype)
+ } else {
+ Err(ErrorKind::GraphFormatError(
+ format!("Invalid dltype: {}", dltype).to_string(),
+ )
+ .into())
+ }
+ })
+ .collect::<Result<Vec<DataType>>>()?;
+
+ let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
+ let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
+ for (i, &storage_id) in storage_ids.iter().enumerate() {
+ let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
+ let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
+ storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
+ }
+
+ let mut storages: Vec<Storage> = storage_num_bytes
+ .into_iter()
+ .map(|nbytes| Storage::new(nbytes, align))
+ .collect::<Result<Vec<Storage>>>()?;
+
+ let tensors = izip!(storage_ids, shapes, dtypes)
+ .map(|(storage_id, shape, dtype)| {
+ let storage = storages[storage_id].view();
+ Tensor {
+ data: mem::replace(&mut storages[storage_id], storage),
+ ctx: TVMContext::default(),
+ dtype: dtype,
+ size: shape.iter().product::<i64>() as usize,
+ shape: shape,
+ strides: None,
+ byte_offset: 0,
+ }
+ })
+ .collect();
+
+ Ok(tensors)
+ }
+
+ /// Creates closures which represent the computation performed by this graph.
+ fn setup_op_execs<M: 'm + Module>(
+ graph: &Graph,
+ lib: &'m M,
+ tensors: &Vec<Tensor<'t>>,
+ ) -> Result<Vec<Box<Fn() + 'm>>> {
+ ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
+ let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
+
+ let mut op_execs = Vec::new();
+ for (i, node) in graph.nodes.iter().enumerate() {
+ if node.op == "null" {
+ continue;
+ }
+ ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
+ ensure!(node.attrs.is_some(), "Missing node attrs.");
+
+ let attrs = node.parse_attrs()?;
+
+ if attrs.func_name == "__nop" {
+ continue;
+ }
+
+ let func = lib
+ .get_function(&attrs.func_name)
+ .ok_or(format!("Missing function {}", attrs.func_name))?;
+ let arg_indices = node
+ .inputs
+ .iter()
+ .map(|entry| graph.entry_index(entry))
+ .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi)));
+
+ let dl_tensors = arg_indices
+ .map(|idx| {
+ let tensor = &tensors[idx?];
+ Ok(if attrs.flatten_data {
+ DLTensor::from_tensor(tensor, true /* flatten */)
+ } else {
+ DLTensor::from(tensor)
+ })
+ })
+ .collect::<Result<Vec<DLTensor>>>()
+ .unwrap();
+ let op: Box<Fn()> = box move || {
+ let args = dl_tensors
+ .iter()
+ .map(|t| t.into())
+ .collect::<Vec<TVMArgValue>>();
+ func(args.as_slice());
+ };
+ op_execs.push(op);
+ }
+ Ok(op_execs)
+ }
+
+ pub fn load_params(&mut self, params: HashMap<String, Tensor>) {
+ params.into_iter().for_each(|(name, param)| {
+ self.set_input(name, param);
+ })
+ }
+
+ pub fn set_input<S: AsRef<str>>(&mut self, name: S, value: Tensor) {
+ if let Some(idx) = self.get_input_index(name.as_ref()) {
+ // TODO: consider `new_with_params` to avoid ever allocating
+ let ptr = self.tensors[idx].data.as_ptr();
+ let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr);
+ let mut owner = to_replace.nth(0).unwrap();
+ if value.data.is_owned() {
+ // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
+ // mem::replace(&mut (*owner), value);
+ // to_replace.for_each(|t| {
+ // panic!("replacing");
+ // t.data = owner.data.view();
+ // });
+ owner.copy(&value);
+ } else {
+ owner.copy(&value);
+ }
+ } else {
+ println!("Unexpected input `{}`", name.as_ref());
+ }
+ }
+
+ /// Returns the graph input with name `name`, if it exists.
+ pub fn get_input<S: AsRef<str>>(&mut self, name: S) -> Option<&Tensor> {
+ self.get_input_index(name.as_ref())
+ .and_then(move |idx| Some(&self.tensors[idx]))
+ }
+
+ /// Returns the graph output with index `index`, if it exists.
+ pub fn get_output(&self, idx: usize) -> Option<&Tensor> {
+ let graph = &self.graph;
+ graph.heads.get(idx).and_then(|entry| {
+ graph
+ .entry_index(entry)
+ .map(|idx| self.tensors.get(idx))
+ .unwrap_or(None)
+ })
+ }
+
+ /// Returns the index for graph input with name `name`, if it exists.
+ pub fn get_input_index<S: AsRef<str>>(&self, name: S) -> Option<usize> {
+ let graph = &self.graph;
+ (0..graph.nodes.len())
+ .skip_while(|&i| graph.nodes[i].name != name.as_ref())
+ .nth(0)
+ .and_then(|i| {
+ if graph.arg_nodes.iter().any(|&id| id == i) {
+ graph.node_row_ptr.as_ref().map(|nrp| nrp[i])
+ } else {
+ None
+ }
+ })
+ }
+}
+
+/// Parses a dtype string (e.g. "float32") into a `DataType`. @see `String2TVMType` in packed_func.h
+named!(
+ tvm_str_to_type<CompleteStr, DataType>,
+ do_parse!(
+ type_name: alpha1 >>
+ bits: digit1 >>
+ lanes: opt!(tuple!(tag!("x"), digit1)) >>
+ (DataType {
+ code: match type_name {
+ CompleteStr("int") => DLDataTypeCode_kDLInt,
+ CompleteStr("uint") => DLDataTypeCode_kDLUInt,
+ CompleteStr("float") => DLDataTypeCode_kDLFloat,
+ _ => DLDataTypeCode_kDLFloat,
+ } as usize,
+ bits: bits.parse::<u8>().unwrap() as usize,
+ lanes: match lanes {
+ Some(lanes) => lanes.1.parse::<u16>().unwrap() as usize,
+ None => 1,
+ },
+ })
+ )
+);
+
+/// Parses a length-prefixed byte string into a `String`.
+named!(
+ name<String>,
+ map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
+ b.to_vec()
+ ))
+);
+
+/// Parses a TVMContext
+named!(
+ tvm_ctx<&[u8], TVMContext>,
+ do_parse!(
+ device_type: le_u32 >>
+ device_id: le_i32 >>
+ (TVMContext { device_type: device_type as usize, device_id: device_id as usize })
+ )
+);
+
+/// Parses a DataType
+named!(
+ data_type<&[u8], DataType>,
+ do_parse!(
+ code: le_u8 >>
+ bits: le_u8 >>
+ lanes: le_u16 >>
+ (DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize })
+ )
+);
+
+/// Parses a Tensor from a TVM array file.
+named!(
+ tensor<Tensor>,
+ do_parse!(
+ take!(8)
+ >> bits!(tag_bits!(u64, 64, 0))
+ >> ctx: tvm_ctx
+ >> ndim: le_u32
+ >> dtype: data_type
+ >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize)
+ >> length: le_i64
+ >> data: take!(length)
+ >> (Tensor {
+ data: Storage::from(data),
+ ctx: ctx,
+ dtype: dtype,
+ size: shape.iter().product::<i64>() as usize,
+ shape: shape,
+ strides: None,
+ byte_offset: 0,
+ })
+ )
+);
+
+/// Parses a graph params dict from a params binary file.
+named!(
+ parse_param_dict<HashMap<String, Tensor>>,
+ do_parse!(
+ take!(8)
+ >> bits!(tag_bits!(u64, 64, 0))
+ >> names: length_count!(le_u64, name)
+ >> tensors: length_count!(le_u64, tensor)
+ >> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter())))
+ )
+);
+
+/// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
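+///
+/// # Example
+///
+/// A minimal sketch; the `graph.params` path is hypothetical and the file must have been
+/// written by `nnvm.compiler.save_param_dict`:
+///
+/// ```no_run
+/// let params_bytes = std::fs::read("graph.params").unwrap();
+/// let params = tvm_runtime::load_param_dict(&params_bytes).unwrap();
+/// for (name, tensor) in &params {
+///     println!("{}: {:?}", name, tensor.shape());
+/// }
+/// ```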
+pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> {
+ if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
+ if remaining_bytes.len() > 0 {
+ bail!(ErrorKind::LoadGraphParamsError("extra input".to_string()))
+ } else {
+ Ok(param_dict)
+ }
+ } else {
+ bail!(ErrorKind::LoadGraphParamsError(
+ "invalid parameters file".to_string()
+ ))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_str_to_type() {
+ assert_eq!(
+ tvm_str_to_type(CompleteStr("float24")).unwrap().1,
+ DataType {
+ code: DLDataTypeCode_kDLFloat as usize,
+ bits: 24,
+ lanes: 1
+ }
+ );
+ assert_eq!(
+ tvm_str_to_type(CompleteStr("uint111x44")).unwrap().1,
+ DataType {
+ code: DLDataTypeCode_kDLUInt as usize,
+ bits: 111,
+ lanes: 44
+ }
+ );
+ }
+}
--- /dev/null
+//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`.
+//! It's mainly useful for compiling to WebAssembly and SGX,
+//! but also native if you prefer Rust to C++.
+//!
+//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`.
+//! Single-function modules are used via the `call_packed!` macro after obtaining
+//! the function from `runtime::SystemLibModule`.
+//!
+//! For examples of use, please refer to the multi-file tests in the `tests` directory.
+
+#![feature(
+ alloc,
+ allocator_api,
+ box_syntax,
+ fn_traits,
+ try_from,
+ unboxed_closures,
+ vec_remove_item
+)]
+
+#[cfg(target_env = "sgx")]
+extern crate alloc;
+extern crate bounded_spsc_queue;
+#[cfg(target_env = "sgx")]
+extern crate core;
+#[macro_use]
+extern crate error_chain;
+#[macro_use]
+extern crate itertools;
+#[macro_use]
+extern crate lazy_static;
+extern crate ndarray;
+#[macro_use]
+extern crate nom;
+#[cfg(not(target_env = "sgx"))]
+extern crate num_cpus;
+extern crate serde;
+#[macro_use]
+extern crate serde_derive;
+extern crate serde_json;
+extern crate tvm_common as common;
+
+mod allocator;
+mod array;
+pub mod errors;
+mod module;
+#[macro_use]
+mod packed_func;
+mod graph;
+#[cfg(target_env = "sgx")]
+#[macro_use]
+pub mod sgx;
+mod threading;
+mod workspace;
+
+pub use crate::common::{errors::*, ffi, TVMArgValue, TVMRetValue};
+
+pub use self::{
+ array::*, errors::*, graph::*, module::*, packed_func::*, threading::*, workspace::*,
+};
+
+#[cfg(target_env = "sgx")]
+use self::sgx::ocall_packed_func;
+
+#[no_mangle]
+pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
+ #[cfg(not(target_env = "sgx"))]
+ unsafe {
+ panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
+ }
+ #[cfg(target_env = "sgx")]
+ ocall_packed!("__sgx_set_last_error__", cmsg);
+}
--- /dev/null
+use std::{
+ collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
+};
+
+use crate::{
+ ffi::runtime::BackendPackedCFunc,
+ packed_func::{wrap_backend_packed_func, PackedFunc},
+};
+
+pub trait Module {
+ fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
+}
+
+pub struct SystemLibModule;
+
+lazy_static! {
+ static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
+ Mutex::new(HashMap::new());
+}
+
+impl Module for SystemLibModule {
+ fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
+ SYSTEM_LIB_FUNCTIONS
+ .lock()
+ .unwrap()
+ .get(name.as_ref())
+ .map(|func| wrap_backend_packed_func(func.to_owned()))
+ }
+}
+
+impl Default for SystemLibModule {
+ fn default() -> Self {
+ SystemLibModule {}
+ }
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
+ cname: *const c_char,
+ func: BackendPackedCFunc,
+) -> i32 {
+ let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
+ SYSTEM_LIB_FUNCTIONS
+ .lock()
+ .unwrap()
+ .insert(name.to_string(), func);
+ return 0;
+}
--- /dev/null
+use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void};
+
+use super::Tensor;
+use crate::ffi::runtime::{
+ BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle,
+ TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue,
+};
+
+use super::DLTensor;
+use crate::{
+ common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue},
+ errors::*,
+};
+
+pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
+
+/// Calls a packed function and returns a `TVMRetValue`.
+///
+/// # Example
+///
+/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
+#[macro_export]
+macro_rules! call_packed {
+ ($fn:expr, $($args:expr),+) => {
+ $fn(&[$($args.into(),)+])
+ };
+ ($fn:expr) => {
+ $fn(&Vec::new())
+ };
+}
+
+impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
+ fn from(arr: &'a DLTensor) -> Self {
+ let raw = _TVMValue {
+ v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
+ };
+ TVMArgValue {
+ value: TVMValue::new(raw),
+ type_code: TVMTypeCode::kArrayHandle,
+ lifetime: PhantomData,
+ }
+ }
+}
+
+impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
+ fn from(arr: &'a mut DLTensor) -> Self {
+ let raw = _TVMValue {
+ v_handle: arr as *mut _ as *mut c_void,
+ };
+ TVMArgValue {
+ value: TVMValue::new(raw),
+ type_code: TVMTypeCode::kArrayHandle,
+ lifetime: PhantomData,
+ }
+ }
+}
+
+impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
+ type Error = Error;
+ fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
+ ensure!(
+ val.type_code == TVMTypeCode::kArrayHandle
+ || val.type_code == TVMTypeCode::kNDArrayContainer,
+ "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
+ TVMTypeCode::kArrayHandle,
+ TVMTypeCode::kNDArrayContainer,
+ val.type_code,
+ );
+
+ let dlt = unsafe { *(val.value.v_handle as *mut _DLTensor as *const _DLTensor) };
+ Ok(DLTensor::new(dlt).into())
+ }
+}
+
+impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
+ fn from(val: &'t Tensor<'a>) -> Self {
+ TVMRetValue {
+ prim_value: 0,
+ box_value: box DLTensor::from(val),
+ type_code: TVMTypeCode::kNDArrayContainer,
+ }
+ }
+}
+
+impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
+ type Error = Error;
+ fn try_from(ret: TVMRetValue) -> Result<Self> {
+ ensure!(
+ ret.type_code == TVMTypeCode::kArrayHandle
+ || ret.type_code == TVMTypeCode::kNDArrayContainer,
+ "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
+ TVMTypeCode_kArrayHandle,
+ TVMTypeCode_kNDArrayContainer,
+ ret.type_code,
+ );
+
+ let dlt = unsafe { *(ret.prim_value as *mut _DLTensor as *const _DLTensor) };
+ Ok(DLTensor::new(dlt).into())
+ }
+}
+
+// @see `WrapPackedFunc` in `llvm_module.cc`.
+pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
+ box move |args: &[TVMArgValue]| {
+ func(
+ args.iter()
+ .map(|ref arg| arg.value.inner)
+ .collect::<Vec<_TVMValue>>()
+ .as_ptr(),
+ args.iter()
+ .map(|ref arg| arg.type_code as i32)
+ .collect::<Vec<i32>>()
+ .as_ptr() as *const i32,
+ args.len() as i32,
+ );
+ TVMRetValue::default()
+ }
+}
--- /dev/null
+use std::{
+ ffi::CString,
+ os::raw::{c_char, c_int},
+};
+
+use crate::errors::Result;
+use crate::ffi::runtime::TVMValue;
+use crate::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
+
+pub use crate::threading::tvm_run_worker as run_worker;
+
+#[macro_export]
+macro_rules! tvm_ocall {
+ ($func: expr) => {
+ match $func {
+ 0 => Ok(()),
+ err => Err(format!("SGX error: {}", err)),
+ }
+ };
+}
+
+pub type SgxStatus = u32;
+
+#[cfg(target_env = "sgx")]
+extern "C" {
+ fn tvm_ocall_packed_func(
+ name: *const c_char,
+ arg_values: *const TVMValue,
+ type_codes: *const c_int,
+ num_args: c_int,
+ ret_val: *mut TVMValue,
+ ret_type_code: *mut c_int,
+ ) -> SgxStatus;
+}
+
+pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ let mut ret_val = TVMValue { v_int64: 0 };
+ let mut ret_type_code = 0 as c_int;
+ unsafe {
+ tvm_ocall!(tvm_ocall_packed_func(
+ CString::new(fn_name.as_ref()).unwrap().as_ptr(),
+ args.iter()
+ .map(|ref arg| arg.value)
+ .collect::<Vec<TVMValue>>()
+ .as_ptr(),
+ args.iter()
+ .map(|ref arg| arg.type_code as i32)
+ .collect::<Vec<i32>>()
+ .as_ptr() as *const i32,
+ args.len() as i32,
+ &mut ret_val as *mut TVMValue,
+ &mut (ret_type_code as i32) as *mut c_int,
+ ))?;
+ }
+ Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64))
+}
+
+#[macro_export]
+macro_rules! ocall_packed {
+ ($fn_name:expr, $($args:expr),+) => {
+ ocall_packed_func($fn_name, &[$($args.into(),)+])
+ .expect(concat!("Error calling `", $fn_name, "`"))
+ };
+ ($fn_name:expr) => {
+ ocall_packed_func($fn_name, &Vec::new())
+ .expect(concat!("Error calling `", $fn_name, "`"))
+ }
+}
+
+pub fn shutdown() {
+ if env!("TVM_NUM_THREADS") != "0" {
+ sgx_join_threads()
+ }
+}
+
+impl Drop for SystemLibModule {
+ fn drop(&mut self) {
+ shutdown()
+ }
+}
--- /dev/null
+use std::{
+ os::raw::{c_int, c_void},
+ sync::{
+ atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT},
+ Arc, Barrier,
+ },
+};
+
+#[cfg(not(target_env = "sgx"))]
+use num_cpus;
+#[cfg(not(target_env = "sgx"))]
+use std::{
+ env,
+ thread::{self, JoinHandle},
+};
+
+#[cfg(target_env = "sgx")]
+use std::{collections::VecDeque, ptr, sync::Mutex};
+
+use bounded_spsc_queue::{self, Producer};
+
+use crate::{errors::*, ffi::runtime::TVMParallelGroupEnv};
+
+#[cfg(target_env = "sgx")]
+use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue};
+
+type FTVMParallelLambda =
+ extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
+
+/// Holds a parallel job request made by a TVM library function.
+struct Job {
+ cb: FTVMParallelLambda,
+ cdata: *const c_void,
+ req_num_tasks: usize,
+ pending: Arc<AtomicUsize>,
+}
+
+impl Job {
+ /// Splits this job into a number of `Task`s which can be scheduled.
+ fn tasks(&self, num_workers: usize) -> Vec<Task> {
+ let num_tasks = if self.req_num_tasks == 0 {
+ num_workers
+ } else {
+ self.req_num_tasks.min(num_workers)
+ };
+ self.pending.store(num_tasks, Ordering::SeqCst);
+
+ let barrier = Arc::new(Barrier::new(num_tasks));
+
+ (0..num_tasks)
+ .map(move |i| Task {
+ id: i,
+ flambda: self.cb,
+ penv: TVMParallelGroupEnv {
+ sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void,
+ num_task: num_tasks as i32,
+ },
+ cdata: self.cdata,
+ pending: Arc::clone(&self.pending),
+ })
+ .collect()
+ }
+
+ /// Waits for all tasks in this `Job` to be completed.
+ fn wait(&self) -> Result<()> {
+ while self.pending.load(Ordering::Acquire) > 0 {
+ #[cfg(not(target_env = "sgx"))]
+ thread::yield_now();
+ }
+ Ok(())
+ }
+}
+
+/// A chunk of work requested by a TVM function.
+struct Task {
+ id: usize,
+ flambda: FTVMParallelLambda,
+ penv: TVMParallelGroupEnv,
+ cdata: *const c_void,
+ pending: Arc<AtomicUsize>,
+}
+unsafe impl Send for Task {}
+unsafe impl Sync for Task {}
+
+impl FnOnce<()> for Task {
+ type Output = i32;
+ extern "rust-call" fn call_once(self, _args: ()) -> Self::Output {
+ let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata);
+ self.pending.fetch_sub(1, Ordering::AcqRel);
+ status
+ }
+}
+
+#[derive(Default)]
+struct Threads {
+ #[allow(unused)]
+ #[cfg(not(target_env = "sgx"))]
+ handles: Vec<JoinHandle<()>>,
+ queues: Vec<Producer<Task>>,
+}
+
+impl<'a> Threads {
+ #[cfg(not(target_env = "sgx"))]
+ fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
+ num_threads: usize,
+ cb: F,
+ ) -> Self {
+ let (handles, queues) = (0..num_threads)
+ .map(|_| {
+ let (p, c) = bounded_spsc_queue::make(2);
+ let handle = thread::spawn(move || cb(c.into()));
+ (handle, p)
+ })
+ .unzip();
+ Threads {
+ handles: handles,
+ queues: queues,
+ }
+ }
+
+ #[cfg(target_env = "sgx")]
+ fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
+ num_threads: usize,
+ _cb: F,
+ ) -> Self {
+ let mut consumer_queues = SGX_QUEUES.lock().unwrap();
+ let queues = (0..num_threads)
+ .map(|_| {
+ let (p, c) = bounded_spsc_queue::make(2);
+ consumer_queues.push_back(c.into());
+ p
+ })
+ .collect();
+ ocall_packed!("__sgx_thread_group_launch__", num_threads as u64);
+ Threads { queues: queues }
+ }
+}
+
+struct ThreadPool {
+ num_workers: usize,
+ #[allow(unused)]
+ threads: Threads,
+}
+
+thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new());
+
+impl ThreadPool {
+ fn new() -> Self {
+ let num_workers = max_concurrency();
+ ThreadPool {
+ num_workers: num_workers,
+ threads: Threads::launch(num_workers, ThreadPool::run_worker),
+ }
+ }
+
+ fn launch(&self, job: Job) {
+ let mut tasks = job.tasks(self.num_workers + 1);
+
+ for (i, task) in tasks.split_off(1).into_iter().enumerate() {
+ self.threads.queues[i].push(task);
+ }
+
+ tasks.pop().unwrap()();
+ job.wait().unwrap();
+ }
+
+ fn run_worker(queue: Consumer<Task>) {
+ loop {
+ let task = queue.pop();
+ let result = task();
+ if result == <i32>::min_value() {
+ break;
+ } else if result != 0 {
+ panic!("Error running task.");
+ }
+ }
+ }
+}
+
+// Send + Sync wrapper for bounded_spsc_queue::Consumer
+struct Consumer<T> {
+ consumer: bounded_spsc_queue::Consumer<T>,
+}
+impl<T> From<bounded_spsc_queue::Consumer<T>> for Consumer<T> {
+ fn from(c: bounded_spsc_queue::Consumer<T>) -> Self {
+ Consumer { consumer: c }
+ }
+}
+impl<T> Consumer<T> {
+ fn pop(&self) -> T {
+ self.consumer.pop()
+ }
+}
+unsafe impl<T> Send for Consumer<T> {}
+unsafe impl<T> Sync for Consumer<T> {}
+
+#[cfg(target_env = "sgx")]
+lazy_static! {
+ /// Holds tasks for untrusted threads which re-enter the enclave to execute.
+ static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new());
+}
+
+#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))]
+fn max_concurrency() -> usize {
+ if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) {
+ if let Ok(threads) = usize::from_str_radix(&threads_str, 10) {
+ return threads;
+ }
+ }
+ num_cpus::get_physical()
+}
+
+#[cfg(target_env = "sgx")]
+fn max_concurrency() -> usize {
+ usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1)
+}
+
+#[cfg(target_arch = "wasm32")]
+fn max_concurrency() -> usize {
+ 0 // wasm doesn't support threads yet
+}
+
+#[cfg(target_env = "sgx")]
+pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue {
+ let q = {
+ let mut qs = SGX_QUEUES.lock().unwrap();
+ qs.pop_front()
+ // `qs: MutexGuard` needs to be dropped here since `run_worker` won't return
+ };
+ if let Some(q) = q {
+ ThreadPool::run_worker(q);
+ }
+ TVMRetValue::default()
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendParallelLaunch(
+ cb: FTVMParallelLambda,
+ cdata: *const c_void,
+ num_task: usize,
+) -> c_int {
+ if max_concurrency() == 0 {
+ let penv = TVMParallelGroupEnv {
+ sync_handle: 0 as *mut c_void,
+ num_task: 1,
+ };
+ cb(0, &penv as *const _, cdata);
+ } else {
+ THREAD_POOL.with(|pool| {
+ pool.launch(Job {
+ cb: cb,
+ cdata: cdata,
+ req_num_tasks: num_task,
+ pending: Arc::new(ATOMIC_USIZE_INIT),
+ });
+ });
+ }
+ return 0;
+}
+
+#[cfg(target_env = "sgx")]
+pub(crate) fn sgx_join_threads() {
+ extern "C" fn poison_pill(
+ _task_id: usize,
+ _penv: *const TVMParallelGroupEnv,
+ _cdata: *const c_void,
+ ) -> i32 {
+ <i32>::min_value()
+ }
+
+ THREAD_POOL.with(|pool| {
+ pool.launch(Job {
+ cb: poison_pill,
+ cdata: ptr::null(),
+ req_num_tasks: 0,
+ pending: Arc::new(ATOMIC_USIZE_INIT),
+ });
+ });
+ ocall_packed!("__sgx_thread_group_join__", 0);
+}
+
+// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used.
+#[no_mangle]
+pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) {
+ let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) };
+ barrier.wait();
+}
+
+#[cfg(test)]
+mod tests {
+ use std::{ptr, thread, time::Duration};
+
+ use super::*;
+
+ #[test]
+ fn test_max_concurrency() {
+ env::set_var("TVM_NUM_THREADS", "42");
+ env::set_var("OMP_NUM_THREADS", "24");
+ assert_eq!(max_concurrency(), 42);
+ env::remove_var("TVM_NUM_THREADS");
+ assert_eq!(max_concurrency(), 24);
+ }
+
+ extern "C" fn flambda(
+ task_id: usize,
+ penv: *const TVMParallelGroupEnv,
+ cdata: *const c_void,
+ ) -> i32 {
+ if cdata == ptr::null() {
+ return 0;
+ }
+ unsafe {
+ let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize));
+ thread::sleep(Duration::from_millis(50 * task_id as u64));
+ counter.fetch_add(1, Ordering::SeqCst);
+ task_ids_sum.fetch_add(task_id, Ordering::SeqCst);
+ assert_eq!((*penv).num_task, 3);
+ }
+ 0
+ }
+
+ #[test]
+ fn test_parallel_launch() {
+ TVMBackendParallelLaunch(flambda, ptr::null(), 6);
+ let counter = ATOMIC_USIZE_INIT;
+ let task_ids_sum = ATOMIC_USIZE_INIT;
+ let cdata = (counter, task_ids_sum);
+ let num_tasks = 3;
+ TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
+ assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks);
+ assert_eq!(
+ cdata.1.load(Ordering::SeqCst),
+ (0..num_tasks).sum::<usize>()
+ );
+ }
+}
--- /dev/null
+use std::{
+ cell::RefCell,
+ os::raw::{c_int, c_void},
+ ptr,
+};
+
+use super::allocator::Allocation;
+use crate::errors::*;
+
+const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
+
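+// A thread-local pool backing `TVMBackendAllocWorkspace`/`TVMBackendFreeWorkspace`. Workspaces
+// are recycled rather than deallocated: `alloc` reuses a free workspace when one is large
+// enough, falling back to a fresh allocation otherwise.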
+struct WorkspacePool {
+ workspaces: Vec<Allocation>,
+ free: Vec<usize>,
+ in_use: Vec<usize>,
+}
+
+impl WorkspacePool {
+ fn new() -> Self {
+ WorkspacePool {
+ workspaces: Vec::new(),
+ free: Vec::new(),
+ in_use: Vec::new(),
+ }
+ }
+
+ fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
+ self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
+ self.in_use.push(self.workspaces.len() - 1);
+ Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
+ }
+
+ fn alloc(&mut self, size: usize) -> Result<*mut u8> {
+ if self.free.len() == 0 {
+ return self.alloc_new(size);
+ }
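+ // Best-fit search: pick the smallest free workspace that can still hold `size`.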
+ let idx = self
+ .free
+ .iter()
+ .fold(None, |cur_ws_idx: Option<usize>, &idx| {
+ let ws_size = self.workspaces[idx].size();
+ if ws_size < size {
+ return cur_ws_idx;
+ }
+ cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
+ let cur_size = self.workspaces[cur_idx].size();
+ Some(match ws_size <= cur_size {
+ true => idx,
+ false => cur_idx,
+ })
+ })
+ });
+ match idx {
+ Some(idx) => {
+ self.free.remove_item(&idx).unwrap();
+ self.in_use.push(idx);
+ Ok(self.workspaces[idx].as_mut_ptr())
+ }
+ None => self.alloc_new(size),
+ }
+ }
+
+ fn free(&mut self, ptr: *mut u8) -> Result<()> {
+ let mut ws_idx = None;
+ for i in 0..self.in_use.len() {
+ let idx = self.in_use[i];
+ if self.workspaces[idx].as_mut_ptr() == ptr {
+ self.in_use.remove(i);
+ ws_idx = Some(idx);
+ break;
+ }
+ }
+ Ok(self
+ .free
+ .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?))
+ }
+}
+
+thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new()));
+
+const WORKSPACE_PAGE_SIZE: usize = 4 << 10;
+
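+// C ABI hook for requesting a temporary workspace; a `size` of 0 is treated as a request for
+// a single page (`WORKSPACE_PAGE_SIZE`).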
+#[no_mangle]
+pub extern "C" fn TVMBackendAllocWorkspace(
+ _device_type: c_int,
+ _device_id: c_int,
+ size: u64,
+ _dtype_code_hint: c_int,
+ _dtype_bits_hint: c_int,
+) -> *mut c_void {
+ let nbytes = if size == 0 {
+ WORKSPACE_PAGE_SIZE
+ } else {
+ size as usize
+ };
+ WORKSPACE_POOL.with(|pool_cell| {
+ pool_cell
+ .borrow_mut()
+ .alloc(nbytes as usize)
+ .unwrap_or(ptr::null_mut()) as *mut c_void
+ })
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendFreeWorkspace(
+ _device_type: c_int,
+ _device_id: c_int,
+ ptr: *mut c_void,
+) -> c_int {
+ WORKSPACE_POOL.with(|pool_cell| {
+ match pool_cell.borrow_mut().free(ptr as *mut u8) {
+ Ok(()) => 0,
+ Err(_) => -1,
+ }
+ })
+}
--- /dev/null
+*.json
+*.params
+*.o
--- /dev/null
+#!/usr/bin/env python3
+
+"""Builds a simple NNVM graph for testing."""
+
+from os import path as osp
+
+import nnvm
+from nnvm import sym
+from nnvm.compiler import graph_util
+from nnvm.testing import init
+import numpy as np
+import tvm
+
+CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
+
+
+def _get_model(dshape):
+ data = sym.Variable('data', shape=dshape)
+ fc1 = sym.dense(data, units=dshape[-1]*2, use_bias=True)
+ left, right = sym.split(fc1, indices_or_sections=2, axis=1)
+ return sym.Group(((left + 1), (right - 1)))
+
+
+def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
+ if isinstance(graph, sym.Symbol):
+ graph = nnvm.graph.create(graph)
+ ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
+ param_shapes = dict(zip(graph.index.input_names, ishapes))
+ np.random.seed(seed)
+ params = {}
+ for param, shape in param_shapes.items():
+ if param in {'data', 'label'} or not shape:
+ continue
+ init_value = np.empty(shape).astype('float32')
+ initializer(param, init_value)
+ params[param] = tvm.nd.array(init_value)
+ return params
+
+def main():
+ dshape = (32, 16)
+ net = _get_model(dshape)
+ ishape_dict = {'data': dshape}
+ params = _init_params(net, ishape_dict)
+ graph, lib, params = nnvm.compiler.build(net, 'llvm',
+ shape=ishape_dict,
+ params=params,
+ dtype='float32')
+
+ with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet:
+ f_resnet.write(graph.json())
+ with open(osp.join(CWD, 'graph.params'), 'wb') as f_params:
+ f_params.write(nnvm.compiler.save_param_dict(params))
+
+if __name__ == '__main__':
+ main()
--- /dev/null
+#![feature(try_from)]
+
+extern crate serde;
+extern crate serde_json;
+
+extern crate tvm_runtime;
+
+use std::{convert::TryFrom, fs, io::Read};
+
+use tvm_runtime::Graph;
+
+#[test]
+fn test_load_graph() {
+ let mut params_bytes = Vec::new();
+ fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
+ .expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
+ .read_to_end(&mut params_bytes)
+ .unwrap();
+ let _params = tvm_runtime::load_param_dict(&params_bytes);
+
+ let graph = Graph::try_from(
+ &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(),
+ )
+ .unwrap();
+
+ assert_eq!(graph.nodes[3].op, "tvm_op");
+ assert_eq!(
+ graph.nodes[3]
+ .attrs
+ .as_ref()
+ .unwrap()
+ .get("func_name")
+ .unwrap(),
+ "fuse_dense"
+ );
+ assert_eq!(graph.nodes[5].inputs[0].index, 0);
+ assert_eq!(graph.nodes[6].inputs[0].index, 1);
+ assert_eq!(graph.heads.len(), 2);
+}
--- /dev/null
+[package]
+name = "test-nnvm"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["TVM Contributors"]
+
+[dependencies]
+ndarray = "0.11.2"
+serde = "1.0.59"
+serde_json = "1.0.17"
+tvm-runtime = { path = "../../" }
+
+[build-dependencies]
+ar = "0.6.0"
--- /dev/null
+extern crate ar;
+
+use std::{env, fs::File, path::Path, process::Command};
+
+use ar::Builder;
+
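+// Runs build_test_graph.py to emit graph.o (compiled with `--system-lib`), archives it into
+// libgraph.a, and instructs Cargo to link the archive statically into the test binary.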
+fn main() {
+ let out_dir = env::var("OUT_DIR").unwrap();
+
+ let output = Command::new(concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/src/build_test_graph.py"
+ ))
+ .arg(&out_dir)
+ .output()
+ .expect("Failed to execute command");
+ assert!(
+ Path::new(&format!("{}/graph.o", out_dir)).exists(),
+ "Could not build graph lib: {}",
+ String::from_utf8(output.stderr)
+ .unwrap()
+ .trim()
+ .split("\n")
+ .last()
+ .unwrap_or("")
+ );
+
+ let mut builder = Builder::new(File::create(format!("{}/libgraph.a", out_dir)).unwrap());
+ builder.append_path(format!("{}/graph.o", out_dir)).unwrap();
+
+ println!("cargo:rustc-link-lib=static=graph");
+ println!("cargo:rustc-link-search=native={}", out_dir);
+}
--- /dev/null
+#!/usr/bin/env python3
+
+"""Builds a simple NNVM graph for testing."""
+
+from os import path as osp
+import sys
+
+import nnvm
+from nnvm import sym
+from nnvm.compiler import graph_util
+from nnvm.testing import init
+import numpy as np
+import tvm
+
+
+def _get_model(dshape):
+ data = sym.Variable('data', shape=dshape)
+ fc = sym.dense(data, units=dshape[-1]*2, use_bias=True)
+ left, right = sym.split(fc, indices_or_sections=2, axis=1)
+ return sym.Group(((left + 1), (right - 1), fc))
+
+
+def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
+ if isinstance(graph, sym.Symbol):
+ graph = nnvm.graph.create(graph)
+
+ ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
+ param_shapes = dict(zip(graph.index.input_names, ishapes))
+ np.random.seed(seed)
+ params = {}
+ for param, shape in param_shapes.items():
+ if param in {'data', 'label'} or not shape:
+ continue
+
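+ # Give biases a deterministic descending ramp; remaining params use the seeded initializer below.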
+ init_value = np.arange(np.product(shape), 0, -1).reshape(*shape).astype('float32')
+ if param.endswith('_bias'):
+ params[param] = tvm.nd.array(init_value)
+ continue
+
+ init_value = np.empty(shape).astype('float32')
+ initializer(param, init_value)
+ # init_value /= init_value.sum() + 1e-10
+ params[param] = tvm.nd.array(init_value)
+
+ return params
+
+def main():
+ dshape = (4, 8)
+ net = _get_model(dshape)
+ ishape_dict = {'data': dshape}
+ params = _init_params(net, ishape_dict)
+ graph, lib, params = nnvm.compiler.build(net, 'llvm --system-lib',
+ shape=ishape_dict,
+ params=params,
+ dtype='float32')
+
+ out_dir = sys.argv[1]
+ lib.save(osp.join(sys.argv[1], 'graph.o'))
+ with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet:
+ f_resnet.write(graph.json())
+
+ with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params:
+ f_params.write(nnvm.compiler.save_param_dict(params))
+
+if __name__ == '__main__':
+ main()
--- /dev/null
+#![feature(try_from)]
+
+#[macro_use]
+extern crate ndarray;
+extern crate serde;
+extern crate serde_json;
+
+extern crate tvm_runtime;
+use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
+
+use ndarray::Array;
+use tvm_runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
+
+const BATCH_SIZE: usize = 4;
+const IN_DIM: usize = 8;
+
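+// Asserts that the element sum of a graph input (looked up by name) or output (looked up by
+// index) matches the sum of a reference ndarray to within 1e-2.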
+macro_rules! check_sum {
+ ($e:expr, $a:ident, $b:ident) => {
+ let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap();
+ check_sum!(a, $b);
+ };
+ ($e:expr, $a:expr, $b:ident) => {
+ let a = Array::try_from($e.get_output($a).unwrap()).unwrap();
+ check_sum!(a, $b);
+ };
+ ($a:ident, $b:ident) => {
+ let a_sum: f32 = $a.scalar_sum();
+ let b_sum: f32 = $b.scalar_sum();
+ assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum);
+ };
+}
+
+fn main() {
+ let syslib = SystemLibModule::default();
+
+ let mut params_bytes = Vec::new();
+ fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
+ .unwrap()
+ .read_to_end(&mut params_bytes)
+ .unwrap();
+ let params = tvm_runtime::load_param_dict(&params_bytes)
+ .unwrap()
+ .into_iter()
+ .map(|(k, v)| (k, v.to_owned()))
+ .collect::<HashMap<String, Tensor<'static>>>();
+
+ let graph =
+ Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap())
+ .unwrap();
+ let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
+
+ let x = Array::from_shape_vec(
+ (BATCH_SIZE, IN_DIM),
+ (0..BATCH_SIZE * IN_DIM)
+ .map(|x| x as f32)
+ .collect::<Vec<f32>>(),
+ )
+ .unwrap();
+ let w = Array::try_from(params.get("dense0_weight").unwrap())
+ .unwrap()
+ .into_shape((IN_DIM * 2, IN_DIM))
+ .unwrap();
+ let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap();
+ let dense = x.dot(&w.t()) + &b;
+ let left = dense.slice(s![.., 0..IN_DIM]);
+ let right = dense.slice(s![.., IN_DIM..]);
+ let expected_o0 = &left + 1f32;
+ let expected_o1 = &right - 1f32;
+
+ exec.load_params(params);
+ exec.set_input("data", (&x).into());
+
+ check_sum!(exec, data, x);
+ check_sum!(exec, dense0_weight, w);
+ check_sum!(exec, dense0_bias, b);
+
+ exec.run();
+
+ check_sum!(exec, 0, expected_o0);
+ check_sum!(exec, 1, expected_o1);
+ check_sum!(exec, 2, dense);
+}
--- /dev/null
+[package]
+name = "test-tvm-basic"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["TVM Contributors"]
+
+[dependencies]
+ndarray = "0.11.2"
+tvm-runtime = { path = "../../" }
+
+[build-dependencies]
+ar = "0.6.0"
--- /dev/null
+extern crate ar;
+
+use std::{env, fs::File, path::Path, process::Command};
+
+use ar::Builder;
+
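+// Mirrors the NNVM test build script: compile test.o via build_test_lib.py, archive it as
+// libtest.a, and link it statically.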
+fn main() {
+ let out_dir = env::var("OUT_DIR").unwrap();
+
+ let output = Command::new(concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/src/build_test_lib.py"
+ ))
+ .arg(&out_dir)
+ .output()
+ .expect("Failed to execute command");
+ assert!(
+ Path::new(&format!("{}/test.o", out_dir)).exists(),
+ "Could not build tvm lib: {}",
+ String::from_utf8(output.stderr)
+ .unwrap()
+ .trim()
+ .split("\n")
+ .last()
+ .unwrap_or("")
+ );
+
+ let mut builder = Builder::new(File::create(format!("{}/libtest.a", out_dir)).unwrap());
+ builder.append_path(format!("{}/test.o", out_dir)).unwrap();
+
+ println!("cargo:rustc-link-lib=static=test");
+ println!("cargo:rustc-link-search=native={}", out_dir);
+}
--- /dev/null
+#!/usr/bin/env python3
+
+"""Prepares a simple TVM library for testing."""
+
+from os import path as osp
+import sys
+
+import tvm
+
+def main():
+ n = tvm.var('n')
+ A = tvm.placeholder((n,), name='A')
+ B = tvm.placeholder((n,), name='B')
+ C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+ s = tvm.create_schedule(C.op)
+ s[C].parallel(s[C].op.axis[0])
+ print(tvm.lower(s, [A, B, C], simple_mode=True))
+ tvm.build(s, [A, B, C], 'llvm --system-lib').save(osp.join(sys.argv[1], 'test.o'))
+
+if __name__ == '__main__':
+ main()
--- /dev/null
+extern crate ndarray;
+#[macro_use]
+extern crate tvm_runtime;
+
+use ndarray::Array;
+use tvm_runtime::{DLTensor, Module, SystemLibModule};
+
+fn main() {
+ let syslib = SystemLibModule::default();
+ let add = syslib
+ .get_function("default_function")
+ .expect("main function not found");
+ let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
+ let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
+ let mut c = Array::from_vec(vec![0f32; 4]);
+ let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
+ let mut a_dl: DLTensor = (&mut a).into();
+ let mut b_dl: DLTensor = (&mut b).into();
+ let mut c_dl: DLTensor = (&mut c).into();
+ call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
+ assert!(c.all_close(&e, 1e-8f32));
+}
+++ /dev/null
-#[cfg(target_env = "sgx")]
-use alloc::alloc;
-#[cfg(not(target_env = "sgx"))]
-use std::alloc;
-use std::num;
-
-use ndarray;
-use serde_json;
-
-error_chain! {
- errors {
- TryFromTVMRetValueError(expected: String, actual: i64) {
- description("mismatched types while downcasting TVMRetValue")
- display("invalid downcast: expected `{}` but was `{}`", expected, actual)
- }
-
- GraphFormatError(msg: String) {
- description("unable to load graph")
- display("could not load graph json: {}", msg)
- }
-
- LoadGraphParamsError(msg: String) {
- description("unable to load graph params")
- display("could not load graph params: {}", msg)
- }
- }
- foreign_links {
- Alloc(alloc::AllocErr);
- GraphDeserialize(serde_json::Error);
- ParseInt(num::ParseIntError);
- ShapeError(ndarray::ShapeError);
- }
-}
-
-impl From<alloc::LayoutErr> for Error {
- fn from(_err: alloc::LayoutErr) -> Error {
- Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
- }
-}
+++ /dev/null
-//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`.
-//! It's mainly useful for compiling to WebAssembly and SGX,
-//! but also native if you prefer Rust to C++.
-//!
-//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`.
-//! Single-function modules are used via the `packed_func!` macro after obtaining
-//! the function from `runtime::SystemLibModule`
-//!
-//! The main entrypoints to this crate are `GraphExecutor`
-//! For examples of use, please refer to the multi-file tests in the `tests` directory.
-
-#![feature(
- alloc,
- allocator_api,
- box_syntax,
- fn_traits,
- try_from,
- unboxed_closures,
- vec_remove_item
-)]
-
-#[cfg(target_env = "sgx")]
-extern crate alloc;
-extern crate bounded_spsc_queue;
-#[cfg(target_env = "sgx")]
-extern crate core;
-#[macro_use]
-extern crate error_chain;
-#[macro_use]
-extern crate itertools;
-#[macro_use]
-extern crate lazy_static;
-extern crate ndarray;
-#[macro_use]
-extern crate nom;
-#[cfg(not(target_env = "sgx"))]
-extern crate num_cpus;
-extern crate serde;
-#[macro_use]
-extern crate serde_derive;
-extern crate serde_json;
-
-pub mod ffi {
- #![allow(
- non_camel_case_types,
- non_snake_case,
- non_upper_case_globals,
- unused
- )]
-
- pub mod runtime {
- use std::os::raw::{c_char, c_int, c_void};
-
- include!(concat!(
- env!("CARGO_MANIFEST_DIR"),
- "/src/runtime/c_runtime_api.rs"
- ));
-
- pub type BackendPackedCFunc =
- extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
- }
-}
-
-pub mod errors;
-pub mod runtime;
-
-pub use errors::*;
+++ /dev/null
-#[cfg(target_env = "sgx")]
-use alloc::alloc::{self, Layout};
-#[cfg(not(target_env = "sgx"))]
-use std::alloc::{self, Layout};
-
-use errors::*;
-
-const DEFAULT_ALIGN_BYTES: usize = 4;
-
-#[derive(PartialEq, Eq)]
-pub struct Allocation {
- layout: Layout,
- ptr: *mut u8,
-}
-
-impl Allocation {
- /// Allocates a chunk of memory of `size` bytes with optional alignment.
- pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
- let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
- let layout = Layout::from_size_align(size, alignment)?;
- let ptr = unsafe { alloc::alloc(layout.clone()) };
- if ptr.is_null() {
- alloc::handle_alloc_error(layout);
- }
- Ok(Self {
- ptr: ptr,
- layout: layout,
- })
- }
-
- pub fn as_mut_ptr(&self) -> *mut u8 {
- self.ptr
- }
-
- /// Returns the size of the Allocation in bytes.
- pub fn size(&self) -> usize {
- self.layout.size()
- }
-
- /// Returns the byte alignment of the Allocation.
- pub fn align(&self) -> usize {
- self.layout.align()
- }
-}
-
-impl Drop for Allocation {
- fn drop(&mut self) {
- unsafe {
- alloc::dealloc(self.ptr, self.layout.clone());
- }
- }
-}
+++ /dev/null
-use std::{
- any::TypeId,
- convert::TryFrom,
- mem,
- os::raw::{c_int, c_void},
- ptr, slice,
-};
-
-use ndarray;
-
-use super::allocator::Allocation;
-use errors::*;
-use ffi::runtime::{
- DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
- DLDeviceType_kDLCPU, DLTensor,
-};
-
-/// A `Storage` is a container which holds `Tensor` data.
-#[derive(PartialEq)]
-pub enum Storage<'a> {
- /// A `Storage` which owns its contained bytes.
- Owned(Allocation),
-
- /// A view of an existing `Storage`.
- View(&'a mut [u8], usize), // ptr, align
-}
-
-impl<'a> Storage<'a> {
- pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
- Ok(Storage::Owned(Allocation::new(size, align)?))
- }
-
- pub fn as_mut_ptr(&self) -> *mut u8 {
- match self {
- Storage::Owned(alloc) => alloc.as_mut_ptr(),
- Storage::View(slice, _) => slice.as_ptr() as *mut u8,
- }
- }
-
- pub fn size(&self) -> usize {
- match self {
- Storage::Owned(alloc) => alloc.size(),
- Storage::View(slice, _) => slice.len(),
- }
- }
-
- pub fn align(&self) -> usize {
- match self {
- Storage::Owned(alloc) => alloc.align(),
- Storage::View(_, align) => *align,
- }
- }
-
- pub fn as_ptr(&self) -> *const u8 {
- self.as_mut_ptr() as *const _
- }
-
- /// Returns a `Storage::View` which points to an owned `Storage::Owned`.
- pub fn view(&self) -> Storage<'a> {
- match self {
- Storage::Owned(alloc) => Storage::View(
- unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
- self.align(),
- ),
- Storage::View(slice, _) => Storage::View(
- unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
- self.align(),
- ),
- }
- }
-
- pub fn is_owned(&self) -> bool {
- match self {
- Storage::Owned(_) => true,
- _ => false,
- }
- }
-
- /// Returns an owned version of this storage via cloning.
- pub fn to_owned(&self) -> Storage<'static> {
- let s = Storage::new(self.size(), Some(self.align())).unwrap();
- unsafe {
- s.as_mut_ptr()
- .copy_from_nonoverlapping(self.as_ptr(), self.size())
- }
- s
- }
-}
-
-impl<'a, T> From<&'a [T]> for Storage<'a> {
- fn from(data: &'a [T]) -> Self {
- let data = unsafe {
- slice::from_raw_parts_mut(
- data.as_ptr() as *const u8 as *mut u8,
- data.len() * mem::size_of::<T>() as usize,
- )
- };
- Storage::View(data, mem::align_of::<T>())
- }
-}
-
-/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
-/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
-/// converted to `ndarray::Array` for non-TVM processing.
-///
-/// # Examples
-///
-/// ```
-/// extern crate ndarray;
-///
-/// let mut a_nd: ndarray::Array = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
-/// let mut a: Tensor = a_nd.into();
-/// let mut a_dl: DLTensor = (&mut t).into();
-/// call_packed!(tvm_fn, &mut a_dl);
-///
-/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
-/// let mut a_nd = ndarray::Array::try_from(&a).unwrap();
-/// ```
-#[derive(PartialEq)]
-pub struct Tensor<'a> {
- /// The bytes which contain the data this `Tensor` represents.
- pub(super) data: Storage<'a>,
- pub(super) ctx: TVMContext,
- pub(super) dtype: DataType,
- pub(super) shape: Vec<i64>, // not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
- /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
- pub(super) strides: Option<Vec<usize>>,
- pub(super) byte_offset: isize,
- /// The number of elements in the `Tensor`.
- pub(super) size: usize,
-}
-
-unsafe impl<'a> Send for Tensor<'a> {}
-
-impl<'a> Tensor<'a> {
- pub fn shape(&self) -> Vec<i64> {
- self.shape.clone()
- }
-
- /// Returns the data of this `Tensor` as a `Vec`.
- ///
- /// # Panics
- ///
- /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
- pub fn to_vec<T: 'static>(&self) -> Vec<T> {
- assert!(self.is_contiguous());
- assert!(self.dtype.is_type::<T>());
- let mut vec: Vec<T> = Vec::with_capacity(self.size * self.dtype.itemsize());
- unsafe {
- vec.as_mut_ptr().copy_from_nonoverlapping(
- self.data.as_ptr().offset(self.byte_offset) as *const T,
- self.size,
- );
- vec.set_len(self.size);
- }
- vec
- }
-
- /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
- pub fn is_contiguous(&self) -> bool {
- match self.strides {
- None => true,
- Some(ref strides) => {
- // check that stride for each dimension is the product of all trailing dimensons' shapes
- self
- .shape
- .iter()
- .zip(strides)
- .rfold(
- (true, 1),
- |(is_contig, expected_stride), (shape, stride)| {
- (
- is_contig && *stride == expected_stride,
- expected_stride * (*shape as usize),
- )
- },
- )
- .0
- }
- }
- }
-
- /// Returns a clone of this `Tensor`.
- ///
- /// # Panics
- ///
- /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
- pub fn copy(&mut self, other: &Tensor) {
- assert!(
- self.dtype == other.dtype && self.size == other.size,
- "Tensor shape/dtype mismatch."
- );
- assert!(
- self.is_contiguous() && other.is_contiguous(),
- "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
- self.strides,
- other.strides
- );
- unsafe {
- self
- .data
- .as_mut_ptr()
- .offset(self.byte_offset as isize)
- .copy_from_nonoverlapping(
- other.data.as_mut_ptr().offset(other.byte_offset),
- other.size * other.dtype.itemsize(),
- );
- }
- }
-
- /// Returns an owned version of this `Tensor` via cloning.
- pub fn to_owned(&self) -> Tensor<'static> {
- let t = Tensor {
- data: self.data.to_owned(),
- ctx: self.ctx.clone(),
- dtype: self.dtype.clone(),
- size: self.size.clone(),
- shape: self.shape.clone(),
- strides: None,
- byte_offset: 0,
- };
- unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
- }
-
- fn from_array_storage<'s, T, D: ndarray::Dimension>(
- arr: &ndarray::Array<T, D>,
- storage: Storage<'s>,
- type_code: usize,
- ) -> Tensor<'s> {
- let type_width = mem::size_of::<T>() as usize;
- Tensor {
- data: storage,
- ctx: TVMContext::default(),
- dtype: DataType {
- code: type_code,
- bits: 8 * type_width,
- lanes: 1,
- },
- size: arr.len(),
- shape: arr.shape().iter().map(|&v| v as i64).collect(),
- strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
- byte_offset: 0,
- }
- }
-}
-
-/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
-macro_rules! impl_ndarray_try_from_tensor {
- ($type:ty, $dtype:expr) => {
- impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
- type Error = Error;
- fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
- ensure!(
- tensor.dtype == $dtype,
- "Cannot convert Tensor with dtype {:?} to ndarray",
- tensor.dtype
- );
- Ok(ndarray::Array::from_shape_vec(
- tensor
- .shape
- .iter()
- .map(|s| *s as usize)
- .collect::<Vec<usize>>(),
- tensor.to_vec::<$type>(),
- )?)
- }
- }
- };
-}
-
-impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
-impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
-impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
-impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
-
-impl DLTensor {
- pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
- assert!(!flatten || tensor.is_contiguous());
- Self {
- data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
- ctx: DLContext::from(&tensor.ctx),
- ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
- dtype: DLDataType::from(&tensor.dtype),
- shape: if flatten {
- &tensor.size as *const _ as *mut i64
- } else {
- tensor.shape.as_ptr()
- } as *mut i64,
- strides: if flatten || tensor.is_contiguous() {
- ptr::null_mut()
- } else {
- tensor.strides.as_ref().unwrap().as_ptr()
- } as *mut i64,
- byte_offset: 0,
- }
- }
-}
-
-impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
- fn from(tensor: &'a Tensor<'t>) -> Self {
- DLTensor::from_tensor(tensor, false /* flatten */)
- }
-}
-
-impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
- fn from(tensor: &'a mut Tensor<'t>) -> Self {
- DLTensor::from_tensor(tensor, false /* flatten */)
- }
-}
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-pub struct DataType {
- pub(super) code: usize,
- pub(super) bits: usize,
- pub(super) lanes: usize,
-}
-
-impl DataType {
- /// Returns the number of bytes occupied by an element of this `DataType`.
- pub fn itemsize(&self) -> usize {
- (self.bits * self.lanes) >> 3
- }
-
- /// Returns whether this `DataType` represents primitive type `T`.
- pub fn is_type<T: 'static>(&self) -> bool {
- if self.lanes != 1 {
- return false;
- }
- let typ = TypeId::of::<T>();
- (typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
- || (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
- || (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
- || (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
- || (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
- || (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
- }
-}
-
-impl<'a> From<&'a DataType> for DLDataType {
- fn from(dtype: &'a DataType) -> Self {
- Self {
- code: dtype.code as u8,
- bits: dtype.bits as u8,
- lanes: dtype.lanes as u16,
- }
- }
-}
-
-impl From<DLDataType> for DataType {
- fn from(dtype: DLDataType) -> Self {
- Self {
- code: dtype.code as usize,
- bits: dtype.bits as usize,
- lanes: dtype.lanes as usize,
- }
- }
-}
-
-macro_rules! make_dtype_const {
- ($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
- const $name: DataType = DataType {
- code: $code as usize,
- bits: $bits,
- lanes: $lanes,
- };
- };
-}
-
-make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
-make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
-// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
-make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
-make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
-
-impl Default for DLContext {
- fn default() -> Self {
- DLContext {
- device_type: DLDeviceType_kDLCPU,
- device_id: 0,
- }
- }
-}
-
-#[derive(Debug, Clone, Copy, PartialEq)]
-pub struct TVMContext {
- pub(super) device_type: usize,
- pub(super) device_id: usize,
-}
-
-impl<'a> From<&'a TVMContext> for DLContext {
- fn from(ctx: &'a TVMContext) -> Self {
- Self {
- device_type: ctx.device_type as u32,
- device_id: ctx.device_id as i32,
- }
- }
-}
-
-impl Default for TVMContext {
- fn default() -> Self {
- Self {
- device_type: DLDeviceType_kDLCPU as usize,
- device_id: 0,
- }
- }
-}
-
-impl<'a> From<DLTensor> for Tensor<'a> {
- fn from(dlt: DLTensor) -> Self {
- unsafe {
- let dtype = DataType::from(dlt.dtype);
- let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec();
- let size = shape.iter().map(|v| *v as usize).product::<usize>() as usize;
- let storage = Storage::from(slice::from_raw_parts(
- dlt.data as *const u8,
- dtype.itemsize() * size,
- ));
- Self {
- data: storage,
- ctx: TVMContext::default(),
- dtype: dtype,
- size: size,
- shape: shape,
- strides: if dlt.strides == ptr::null_mut() {
- None
- } else {
- Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec())
- },
- byte_offset: dlt.byte_offset as isize,
- }
- }
- }
-}
-
-/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
-///
-/// # Panics
-///
-/// Panics if the ndarray is not contiguous.
-macro_rules! impl_tensor_from_ndarray {
- ($type:ty, $typecode:expr) => {
- impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
- fn from(arr: ndarray::Array<$type, D>) -> Self {
- assert!(arr.is_standard_layout(), "Array must be contiguous.");
- let size = arr.len() * mem::size_of::<$type>() as usize;
- let storage =
- Storage::from(unsafe { slice::from_raw_parts(arr.as_ptr() as *const u8, size) });
- Tensor::from_array_storage(&arr, storage, $typecode as usize)
- }
- }
- impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
- fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
- assert!(arr.is_standard_layout(), "Array must be contiguous.");
- Tensor::from_array_storage(
- arr,
- Storage::from(arr.as_slice().unwrap()),
- $typecode as usize,
- )
- }
- }
- };
-}
-
-/// `From` conversions to `DLTensor` for `ndarray::Array`.
-/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
-macro_rules! impl_dltensor_from_ndarray {
- ($type:ty, $typecode:expr) => {
- impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
- fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
- DLTensor {
- data: arr.as_mut_ptr() as *mut c_void,
- ctx: DLContext::default(),
- ndim: arr.ndim() as c_int,
- dtype: DLDataType {
- code: $typecode as u8,
- bits: 8 * mem::size_of::<$type>() as u8,
- lanes: 1,
- },
- shape: arr.shape().as_ptr() as *const i64 as *mut i64,
- strides: arr.strides().as_ptr() as *const isize as *mut i64,
- byte_offset: 0,
- }
- }
- }
- };
-}
-
-impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
-impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
-impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
-impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
-impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
-impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
-
-impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
-impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
-impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
-impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
-impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
-impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
+++ /dev/null
-/* automatically generated by rust-bindgen for TVM revision 6292c78 */
-
-pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0";
-pub const DLPACK_VERSION: u32 = 8;
-pub const _STDINT_H: u32 = 1;
-pub const _FEATURES_H: u32 = 1;
-pub const _DEFAULT_SOURCE: u32 = 1;
-pub const __USE_ISOC11: u32 = 1;
-pub const __USE_ISOC99: u32 = 1;
-pub const __USE_ISOC95: u32 = 1;
-pub const __USE_POSIX_IMPLICITLY: u32 = 1;
-pub const _POSIX_SOURCE: u32 = 1;
-pub const _POSIX_C_SOURCE: u32 = 200809;
-pub const __USE_POSIX: u32 = 1;
-pub const __USE_POSIX2: u32 = 1;
-pub const __USE_POSIX199309: u32 = 1;
-pub const __USE_POSIX199506: u32 = 1;
-pub const __USE_XOPEN2K: u32 = 1;
-pub const __USE_XOPEN2K8: u32 = 1;
-pub const _ATFILE_SOURCE: u32 = 1;
-pub const __USE_MISC: u32 = 1;
-pub const __USE_ATFILE: u32 = 1;
-pub const __USE_FORTIFY_LEVEL: u32 = 0;
-pub const _STDC_PREDEF_H: u32 = 1;
-pub const __STDC_IEC_559__: u32 = 1;
-pub const __STDC_IEC_559_COMPLEX__: u32 = 1;
-pub const __STDC_ISO_10646__: u32 = 201505;
-pub const __STDC_NO_THREADS__: u32 = 1;
-pub const __GNU_LIBRARY__: u32 = 6;
-pub const __GLIBC__: u32 = 2;
-pub const __GLIBC_MINOR__: u32 = 23;
-pub const _SYS_CDEFS_H: u32 = 1;
-pub const __WORDSIZE: u32 = 64;
-pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1;
-pub const __SYSCALL_WORDSIZE: u32 = 64;
-pub const _BITS_WCHAR_H: u32 = 1;
-pub const INT8_MIN: i32 = -128;
-pub const INT16_MIN: i32 = -32768;
-pub const INT32_MIN: i32 = -2147483648;
-pub const INT8_MAX: u32 = 127;
-pub const INT16_MAX: u32 = 32767;
-pub const INT32_MAX: u32 = 2147483647;
-pub const UINT8_MAX: u32 = 255;
-pub const UINT16_MAX: u32 = 65535;
-pub const UINT32_MAX: u32 = 4294967295;
-pub const INT_LEAST8_MIN: i32 = -128;
-pub const INT_LEAST16_MIN: i32 = -32768;
-pub const INT_LEAST32_MIN: i32 = -2147483648;
-pub const INT_LEAST8_MAX: u32 = 127;
-pub const INT_LEAST16_MAX: u32 = 32767;
-pub const INT_LEAST32_MAX: u32 = 2147483647;
-pub const UINT_LEAST8_MAX: u32 = 255;
-pub const UINT_LEAST16_MAX: u32 = 65535;
-pub const UINT_LEAST32_MAX: u32 = 4294967295;
-pub const INT_FAST8_MIN: i32 = -128;
-pub const INT_FAST16_MIN: i64 = -9223372036854775808;
-pub const INT_FAST32_MIN: i64 = -9223372036854775808;
-pub const INT_FAST8_MAX: u32 = 127;
-pub const INT_FAST16_MAX: u64 = 9223372036854775807;
-pub const INT_FAST32_MAX: u64 = 9223372036854775807;
-pub const UINT_FAST8_MAX: u32 = 255;
-pub const UINT_FAST16_MAX: i32 = -1;
-pub const UINT_FAST32_MAX: i32 = -1;
-pub const INTPTR_MIN: i64 = -9223372036854775808;
-pub const INTPTR_MAX: u64 = 9223372036854775807;
-pub const UINTPTR_MAX: i32 = -1;
-pub const PTRDIFF_MIN: i64 = -9223372036854775808;
-pub const PTRDIFF_MAX: u64 = 9223372036854775807;
-pub const SIG_ATOMIC_MIN: i32 = -2147483648;
-pub const SIG_ATOMIC_MAX: u32 = 2147483647;
-pub const SIZE_MAX: i32 = -1;
-pub const WINT_MIN: u32 = 0;
-pub const WINT_MAX: u32 = 4294967295;
-pub type int_least8_t = ::std::os::raw::c_schar;
-pub type int_least16_t = ::std::os::raw::c_short;
-pub type int_least32_t = ::std::os::raw::c_int;
-pub type int_least64_t = ::std::os::raw::c_long;
-pub type uint_least8_t = ::std::os::raw::c_uchar;
-pub type uint_least16_t = ::std::os::raw::c_ushort;
-pub type uint_least32_t = ::std::os::raw::c_uint;
-pub type uint_least64_t = ::std::os::raw::c_ulong;
-pub type int_fast8_t = ::std::os::raw::c_schar;
-pub type int_fast16_t = ::std::os::raw::c_long;
-pub type int_fast32_t = ::std::os::raw::c_long;
-pub type int_fast64_t = ::std::os::raw::c_long;
-pub type uint_fast8_t = ::std::os::raw::c_uchar;
-pub type uint_fast16_t = ::std::os::raw::c_ulong;
-pub type uint_fast32_t = ::std::os::raw::c_ulong;
-pub type uint_fast64_t = ::std::os::raw::c_ulong;
-pub type intmax_t = ::std::os::raw::c_long;
-pub type uintmax_t = ::std::os::raw::c_ulong;
-pub type wchar_t = ::std::os::raw::c_int;
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct max_align_t {
- pub __clang_max_align_nonce1: ::std::os::raw::c_longlong,
- pub __bindgen_padding_0: u64,
- pub __clang_max_align_nonce2: f64,
-}
-pub const DLDeviceType_kDLCPU: DLDeviceType = 1;
-pub const DLDeviceType_kDLGPU: DLDeviceType = 2;
-pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3;
-pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4;
-pub const DLDeviceType_kDLMetal: DLDeviceType = 8;
-pub const DLDeviceType_kDLVPI: DLDeviceType = 9;
-pub const DLDeviceType_kDLROCM: DLDeviceType = 10;
-/// \brief The device type in DLContext.
-pub type DLDeviceType = u32;
-/// \brief A Device context for Tensor and operator.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLContext {
- /// \brief The device type used in the device.
- pub device_type: DLDeviceType,
- /// \brief The device index
- pub device_id: ::std::os::raw::c_int,
-}
-pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0;
-pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1;
-pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2;
-/// \brief The type code options DLDataType.
-pub type DLDataTypeCode = u32;
-/// \brief The data type the tensor can hold.
-///
-/// Examples
-/// - float: type_code = 2, bits = 32, lanes=1
-/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
-/// - int8: type_code = 0, bits = 8, lanes=1
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLDataType {
- /// \brief Type code of base types.
- /// We keep it uint8_t instead of DLDataTypeCode for minimal memory
- /// footprint, but the value should be one of DLDataTypeCode enum values.
- ///
- pub code: u8,
- /// \brief Number of bits, common choices are 8, 16, 32.
- pub bits: u8,
- /// \brief Number of lanes in the type, used for vector types.
- pub lanes: u16,
-}
-/// \brief Plain C Tensor object, does not manage memory.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLTensor {
- /// \brief The opaque data pointer points to the allocated data.
- /// This will be CUDA device pointer or cl_mem handle in OpenCL.
- /// This pointer is always aligns to 256 bytes as in CUDA.
- pub data: *mut ::std::os::raw::c_void,
- /// \brief The device context of the tensor
- pub ctx: DLContext,
- /// \brief Number of dimensions
- pub ndim: ::std::os::raw::c_int,
- /// \brief The data type of the pointer
- pub dtype: DLDataType,
- /// \brief The shape of the tensor
- pub shape: *mut i64,
- /// \brief strides of the tensor,
- /// can be NULL, indicating tensor is compact.
- pub strides: *mut i64,
- /// \brief The offset in bytes to the beginning pointer to data
- pub byte_offset: u64,
-}
-/// \brief C Tensor object, manage memory of DLTensor. This data structure is
-/// intended to faciliate the borrowing of DLTensor by another framework. It is
-/// not meant to transfer the tensor. When the borrowing framework doesn't need
-/// the tensor, it should call the deleter to notify the host that the resource
-/// is no longer needed.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLManagedTensor {
- /// \brief DLTensor which is being memory managed
- pub dl_tensor: DLTensor,
- /// \brief the context of the original host framework of DLManagedTensor in
- /// which DLManagedTensor is used in the framework. It can also be NULL.
- pub manager_ctx: *mut ::std::os::raw::c_void,
- /// \brief Destructor signature void (*)(void*) - this should be called
- /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
- /// if there is no way for the caller to provide a reasonable destructor.
- pub deleter: ::std::option::Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)>,
-}
-/// \brief type of array index.
-pub type tvm_index_t = i64;
-pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5;
-pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6;
-pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7;
-pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11;
-pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12;
-/// \brief Extension device types in TVM
-pub type TVMDeviceExtType = u32;
-pub const TVMTypeCode_kHandle: TVMTypeCode = 3;
-pub const TVMTypeCode_kNull: TVMTypeCode = 4;
-pub const TVMTypeCode_kTVMType: TVMTypeCode = 5;
-pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6;
-pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7;
-pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8;
-pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9;
-pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10;
-pub const TVMTypeCode_kStr: TVMTypeCode = 11;
-pub const TVMTypeCode_kBytes: TVMTypeCode = 12;
-pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13;
-pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15;
-pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16;
-pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20;
-pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64;
-pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128;
-/// \brief The type code in TVMType
-/// \note TVMType is used in two places.
-pub type TVMTypeCode = u32;
-/// \brief The data type used in TVM Runtime.
-///
-/// Examples
-/// - float: type_code = 2, bits = 32, lanes=1
-/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
-/// - int8: type_code = 0, bits = 8, lanes=1
-///
-/// \note Arguments TVM API function always takes bits=64 and lanes=1
-pub type TVMType = DLDataType;
-/// \brief The Device information, abstract away common device types.
-pub type TVMContext = DLContext;
-/// \brief The tensor array stucture to TVM API.
-pub type TVMArray = DLTensor;
-/// \brief the array handle
-pub type TVMArrayHandle = *mut TVMArray;
-/// \brief Union type of values
-/// being passed through API and function calls.
-#[repr(C)]
-#[derive(Copy, Clone)]
-pub union TVMValue {
- pub v_int64: i64,
- pub v_float64: f64,
- pub v_handle: *mut ::std::os::raw::c_void,
- pub v_str: *const ::std::os::raw::c_char,
- pub v_type: TVMType,
- pub v_ctx: TVMContext,
- _bindgen_union_align: u64,
-}
-/// \brief Byte array type used to pass in byte array
-/// When kBytes is used as data type.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct TVMByteArray {
- pub data: *const ::std::os::raw::c_char,
- pub size: usize,
-}
-/// \brief Handle to TVM runtime modules.
-pub type TVMModuleHandle = *mut ::std::os::raw::c_void;
-/// \brief Handle to packed function handle.
-pub type TVMFunctionHandle = *mut ::std::os::raw::c_void;
-/// \brief Handle to hold return value.
-pub type TVMRetValueHandle = *mut ::std::os::raw::c_void;
-/// \brief The stream that is specific to device
-/// can be NULL, which indicates the default one.
-pub type TVMStreamHandle = *mut ::std::os::raw::c_void;
-extern "C" {
- /// \brief Used for implementing C API function.
- /// Set last error message before return.
- /// \param msg The error message to be set.
- pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char);
-}
-extern "C" {
- /// \brief return str message of the last error
- /// all function in this file will return 0 when success
- /// and -1 when an error occured,
- /// TVMGetLastError can be called to retrieve the error
- ///
- /// this function is threadsafe and can be called by different thread
- /// \return error info
- pub fn TVMGetLastError() -> *const ::std::os::raw::c_char;
-}
-extern "C" {
- /// \brief Load module from file.
- /// \param file_name The file name to load the module from.
- /// \param format The format of the module.
- /// \param out The result module
- ///
- /// \return 0 when success, -1 when failure happens
- /// \note The resulting module do not contain import relation.
- /// It can be reconstructed by TVMModImport.
- pub fn TVMModLoadFromFile(
- file_name: *const ::std::os::raw::c_char,
- format: *const ::std::os::raw::c_char,
- out: *mut TVMModuleHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Add dep to mod's dependency.
- /// This allows functions in this module to use modules.
- ///
- /// \param mod The module handle.
- /// \param dep The dependent module to be imported.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Get function from the module.
- /// \param mod The module handle.
- /// \param func_name The name of the function.
- /// \param query_imports Whether to query imported modules
- /// \param out The result function, can be NULL if it is not available.
- /// \return 0 when no error is thrown, -1 when failure happens
- pub fn TVMModGetFunction(
- mod_: TVMModuleHandle,
- func_name: *const ::std::os::raw::c_char,
- query_imports: ::std::os::raw::c_int,
- out: *mut TVMFunctionHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Free front-end extension type resource.
- /// \param handle The extension handle.
- /// \param type_code The type of of the extension type.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMExtTypeFree(
- handle: *mut ::std::os::raw::c_void,
- type_code: ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Free the Module
- /// \param mod The module to be freed.
- ///
- /// \note This may not free up the module's resources.
- /// If there is active TVMFunctionHandle uses the module
- /// Or if this module is imported by another active module.
- ///
- /// The all functions remains valid until TVMFuncFree is called.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Free the function when it is no longer needed.
- /// \param func The function handle
- /// \return 0 when success, -1 when failure happens
- pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Call a Packed TVM Function.
- ///
- /// \param func node handle of the function.
- /// \param arg_values The arguments
- /// \param type_codes The type codes of the arguments
- /// \param num_args Number of arguments.
- ///
- /// \param ret_val The return value.
- /// \param ret_type_code the type code of return value.
- ///
- /// \return 0 when success, -1 when failure happens
- /// \note TVM calls always exchanges with type bits=64, lanes=1
- ///
- /// \note API calls always exchanges with type bits=64, lanes=1
- /// If API call returns container handles (e.g. FunctionHandle)
- /// these handles should be managed by the front-end.
- /// The front-end need to call free function (e.g. TVMFuncFree)
- /// to free these handles.
- pub fn TVMFuncCall(
- func: TVMFunctionHandle,
- arg_values: *mut TVMValue,
- type_codes: *mut ::std::os::raw::c_int,
- num_args: ::std::os::raw::c_int,
- ret_val: *mut TVMValue,
- ret_type_code: *mut ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Set the return value of TVMPackedCFunc.
- ///
- /// This function is called by TVMPackedCFunc to set the return value.
- /// When this function is not called, the function returns null by default.
- ///
- /// \param ret The return value handle, pass by ret in TVMPackedCFunc
- /// \param value The value to be returned.
- /// \param type_code The type of the value to be returned.
- /// \param num_ret Number of return values, for now only 1 is supported.
- pub fn TVMCFuncSetReturn(
- ret: TVMRetValueHandle,
- value: *mut TVMValue,
- type_code: *mut ::std::os::raw::c_int,
- num_ret: ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Inplace translate callback argument value to return value.
- /// This is only needed for non-POD arguments.
- ///
- /// \param value The value to be translated.
- /// \param code The type code to be translated.
- /// \note This function will do a shallow copy when necessary.
- ///
- /// \return 0 when success, -1 when failure happens.
- pub fn TVMCbArgToReturn(
- value: *mut TVMValue,
- code: ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-/// \brief C type of packed function.
-///
-/// \param args The arguments
-/// \param type_codes The type codes of the arguments
-/// \param num_args Number of arguments.
-/// \param ret The return value handle.
-/// \param resource_handle The handle additional resouce handle from fron-end.
-/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
-/// \sa TVMCFuncSetReturn
-pub type TVMPackedCFunc = ::std::option::Option<
- unsafe extern "C" fn(
- args: *mut TVMValue,
- type_codes: *mut ::std::os::raw::c_int,
- num_args: ::std::os::raw::c_int,
- ret: TVMRetValueHandle,
- resource_handle: *mut ::std::os::raw::c_void,
- ) -> ::std::os::raw::c_int,
->;
-/// \brief C callback to free the resource handle in C packed function.
-/// \param resource_handle The handle additional resouce handle from fron-end.
-pub type TVMPackedCFuncFinalizer =
- ::std::option::Option<unsafe extern "C" fn(resource_handle: *mut ::std::os::raw::c_void)>;
-/// \brief Signature for extension function declarer.
-///
-/// TVM call this function to get the extension functions
-/// The declarer will call register_func to register function and their name.
-///
-/// \param register_func_handle The register function
-/// \return 0 if success, -1 if failure happens
-pub type TVMExtensionFuncDeclarer = ::std::option::Option<
- unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int,
->;
-extern "C" {
- /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
- ///
- /// The resource_handle will be managed by TVM API, until the function is no longer used.
- ///
- /// \param func The packed C function.
- /// \param resource_handle The resource handle from front-end, can be NULL.
- /// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
- /// \param out the result function handle.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMFuncCreateFromCFunc(
- func: TVMPackedCFunc,
- resource_handle: *mut ::std::os::raw::c_void,
- fin: TVMPackedCFuncFinalizer,
- out: *mut TVMFunctionHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Register the function to runtime's global table.
- ///
- /// The registered function then can be pulled by the backend by the name.
- ///
- /// \param name The name of the function.
- /// \param f The function to be registered.
- /// \param override Whether allow override already registered function.
- pub fn TVMFuncRegisterGlobal(
- name: *const ::std::os::raw::c_char,
- f: TVMFunctionHandle,
- override_: ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Get a global function.
- ///
- /// \param name The name of the function.
- /// \param out the result function pointer, NULL if it does not exist.
- ///
- /// \note The function handle of global function is managed by TVM runtime,
- /// So TVMFuncFree is should not be called when it get deleted.
- pub fn TVMFuncGetGlobal(
- name: *const ::std::os::raw::c_char,
- out: *mut TVMFunctionHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief List all the globally registered function name
- /// \param out_size The number of functions
- /// \param out_array The array of function names.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMFuncListGlobalNames(
- out_size: *mut ::std::os::raw::c_int,
- out_array: *mut *mut *const ::std::os::raw::c_char,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Allocate a nd-array's memory,
- /// including space of shape, of given spec.
- ///
- /// \param shape The shape of the array, the data content will be copied to out
- /// \param ndim The number of dimension of the array.
- /// \param dtype_code The type code of the dtype
- /// \param dtype_bits The number of bits of dtype
- /// \param dtype_lanes The number of lanes in the dtype.
- /// \param device_type The device type of context
- /// \param device_id The device id of context.
- /// \param out The output handle.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayAlloc(
- shape: *const tvm_index_t,
- ndim: ::std::os::raw::c_int,
- dtype_code: ::std::os::raw::c_int,
- dtype_bits: ::std::os::raw::c_int,
- dtype_lanes: ::std::os::raw::c_int,
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- out: *mut TVMArrayHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Free the TVM Array.
- /// \param handle The array handle to be freed.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Copy array data from CPU byte array.
- /// \param handle The array handle.
- /// \param data the data pointer
- /// \param nbytes The number of bytes to copy.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayCopyFromBytes(
- handle: TVMArrayHandle,
- data: *mut ::std::os::raw::c_void,
- nbytes: usize,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Copy array data to CPU byte array.
- /// \param handle The array handle.
- /// \param data the data pointer
- /// \param nbytes The number of bytes to copy.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayCopyToBytes(
- handle: TVMArrayHandle,
- data: *mut ::std::os::raw::c_void,
- nbytes: usize,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Copy the array, both from and to must be valid during the copy.
- /// \param from The array to be copied from.
- /// \param to The target space.
- /// \param stream The stream where the copy happens, can be NULL.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayCopyFromTo(
- from: TVMArrayHandle,
- to: TVMArrayHandle,
- stream: TVMStreamHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Produce an array from the DLManagedTensor that shares data memory
- /// with the DLManagedTensor.
- /// \param from The source DLManagedTensor.
- /// \param out The output array handle.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayFromDLPack(
- from: *mut DLManagedTensor,
- out: *mut TVMArrayHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Produce a DLMangedTensor from the array that shares data memory with
- /// the array.
- /// \param from The source array.
- /// \param out The DLManagedTensor handle.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayToDLPack(
- from: TVMArrayHandle,
- out: *mut *mut DLManagedTensor,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Delete (free) a DLManagedTensor's data.
- /// \param dltensor Pointer to the DLManagedTensor.
- pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor);
-}
-extern "C" {
- /// \brief Create a new runtime stream.
- ///
- /// \param device_type The device type of context
- /// \param device_id The device id of context
- /// \param out The new stream handle
- /// \return 0 when success, -1 when failure happens
- pub fn TVMStreamCreate(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- out: *mut TVMStreamHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Free a created stream handle.
- ///
- /// \param device_type The device type of context
- /// \param device_id The device id of context
- /// \param stream The stream to be freed
- /// \return 0 when success, -1 when failure happens
- pub fn TVMStreamFree(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- stream: TVMStreamHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Set the runtime stream of current thread to be stream.
- /// The subsequent calls to the same device_type
- /// will use the setted stream handle.
- /// The specific type of stream is runtime device dependent.
- ///
- /// \param device_type The device type of context
- /// \param device_id The device id of context.
- /// \param handle The stream handle.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMSetStream(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- handle: TVMStreamHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Wait until all computations on stream completes.
- ///
- /// \param device_type The device type of context
- /// \param device_id The device id of context.
- /// \param stream The stream to be synchronized.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMSynchronize(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- stream: TVMStreamHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Synchronize two streams of execution.
- ///
- /// \param device_type The device type of context
- /// \param device_id The device id of context
- /// \param src The source stream to synchronize.
- /// \param dst The destination stream to synchronize.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMStreamStreamSynchronize(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- src: TVMStreamHandle,
- dst: TVMStreamHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Backend function for modules to get function
- /// from its environment mod_node (its imports and global function).
- /// The user do should not call TVMFuncFree on func.
- ///
- /// \param mod_node The module handle.
- /// \param func_name The name of the function.
- /// \param out The result function.
- /// \return 0 when no error is thrown, -1 when failure happens
- pub fn TVMBackendGetFuncFromEnv(
- mod_node: *mut ::std::os::raw::c_void,
- func_name: *const ::std::os::raw::c_char,
- out: *mut TVMFunctionHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Backend function to register system-wide library symbol.
- ///
- /// \param name The name of the symbol
- /// \param ptr The symbol address.
- /// \return 0 when no error is thrown, -1 when failure happens
- pub fn TVMBackendRegisterSystemLibSymbol(
- name: *const ::std::os::raw::c_char,
- ptr: *mut ::std::os::raw::c_void,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Backend function to allocate temporary workspace.
- ///
- /// \note The resulting allocated space is guaranteed to be aligned to kTempAllocaAlignment.
- ///
- /// \param nbytes The size of the space requested.
- /// \param device_type The device type on which the space will be allocated.
- /// \param device_id The device id on which the space will be allocated.
- /// \param dtype_code_hint The type code of the array elements. Only used in
- /// certain backends such as OpenGL.
- /// \param dtype_bits_hint The type bits of the array elements. Only used in
- /// certain backends such as OpenGL.
- /// \return nullptr when error is thrown, a valid ptr if success
- pub fn TVMBackendAllocWorkspace(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- nbytes: u64,
- dtype_code_hint: ::std::os::raw::c_int,
- dtype_bits_hint: ::std::os::raw::c_int,
- ) -> *mut ::std::os::raw::c_void;
-}
-extern "C" {
- /// \brief Backend function to free temporary workspace.
- ///
- /// \param ptr The previously allocated space pointer.
- /// \param device_type The device type on which the space was allocated.
- /// \param device_id The device id on which the space was allocated.
- /// \return 0 when no error is thrown, -1 when failure happens
- ///
- /// \sa TVMBackendAllocWorkspace
- pub fn TVMBackendFreeWorkspace(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- ptr: *mut ::std::os::raw::c_void,
- ) -> ::std::os::raw::c_int;
-}
-/// \brief Environment for TVM parallel task.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct TVMParallelGroupEnv {
- /// \brief Auxiliary handle used for synchronization
- pub sync_handle: *mut ::std::os::raw::c_void,
- /// \brief Total number of tasks
- pub num_task: i32,
-}
-/// \brief The callback function to execute a parallel lambda
-/// \param task_id the task id of the function.
- /// \param penv The parallel environment backing the execution.
-/// \param cdata The supporting closure data.
-pub type FTVMParallelLambda = ::std::option::Option<
- unsafe extern "C" fn(
- task_id: ::std::os::raw::c_int,
- penv: *mut TVMParallelGroupEnv,
- cdata: *mut ::std::os::raw::c_void,
- ) -> ::std::os::raw::c_int,
->;
-extern "C" {
- /// \brief Backend function for running parallel jobs.
- ///
- /// \param flambda The parallel function to be launched.
- /// \param cdata The closure data.
- /// \param num_task Number of tasks to launch; can be 0, which means launch
- /// with all available threads.
- ///
- /// \return 0 when no error is thrown, -1 when failure happens
- pub fn TVMBackendParallelLaunch(
- flambda: FTVMParallelLambda,
- cdata: *mut ::std::os::raw::c_void,
- num_task: ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief BSP barrier between parallel threads
- /// \param task_id the task id of the function.
- /// \param penv The parallel environment backing the execution.
- /// \return 0 when no error is thrown, -1 when failure happens
- pub fn TVMBackendParallelBarrier(
- task_id: ::std::os::raw::c_int,
- penv: *mut TVMParallelGroupEnv,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Simple static initialization function.
- /// Runs f once and sets handle to be non-null.
- /// This function is mainly used for testing purposes.
- ///
- /// \param handle A global address used to indicate whether f has run.
- /// \param f The function to be run.
- /// \param cdata The closure data to pass to the function.
- /// \param nbytes Number of bytes in the closure data.
- /// \return 0 when no error is thrown, -1 when failure happens
- pub fn TVMBackendRunOnce(
- handle: *mut *mut ::std::os::raw::c_void,
- f: ::std::option::Option<
- unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int,
- >,
- cdata: *mut ::std::os::raw::c_void,
- nbytes: ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
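A minimal usage sketch (not part of the diff) for the workspace FFI pair declared above: allocate a scratch buffer, use it, and release it. The device_type value 1 (CPU) and device_id 0 are illustrative assumptions; in practice these entry points are driven by generated TVM kernels.

// Sketch only: assumes the extern "C" declarations above are in scope.
fn workspace_roundtrip() {
    unsafe {
        // request 1 KiB of temporary workspace on CPU (device_type = 1, device_id = 0)
        let ws = TVMBackendAllocWorkspace(1, 0, 1024, 0, 32);
        assert!(!ws.is_null(), "workspace allocation failed");
        // ... use the 1024-byte scratch region here ...
        assert_eq!(TVMBackendFreeWorkspace(1, 0, ws), 0);
    }
}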
+++ /dev/null
-use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
-
-use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
-use serde;
-use serde_json;
-
-use super::{DataType, Module, Storage, TVMArgValue, TVMContext, Tensor};
-use errors::{Error, ErrorKind, Result};
-use ffi::runtime::{
- DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor,
-};
-
-// Magic number for NDArray file. @see `kTVMNDArrayMagic` in `ndarray.h`
-const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
-// Magic number for NDArray list file. @see `kTVMNDArrayListMagic` in `graph_runtime.h`
-const _NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7;
-
-/// A TVM computation graph.
-///
-/// # Examples
-///
-/// ```
- /// use std::{convert::TryFrom, fs};
- /// use tvm::runtime::Graph;
- ///
- /// let graph_json = fs::read_to_string("graph.json").unwrap();
-/// let graph = Graph::try_from(&graph_json).unwrap();
-/// ```
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Graph {
- pub nodes: Vec<Node>,
- pub arg_nodes: Vec<usize>,
- pub heads: Vec<Entry>,
- pub node_row_ptr: Option<Vec<usize>>,
- pub attrs: Option<HashMap<String, serde_json::Value>>,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Entry {
- pub id: usize,
- pub index: usize,
- pub version: usize,
-}
-
-impl Graph {
- fn entry_index(&self, entry: &Entry) -> Result<usize> {
- self
- .node_row_ptr
- .as_ref()
- .map(|nrp| nrp[entry.id] + entry.index)
- .ok_or("Missing node_row_ptr.".into())
- }
-
- /// Attempt to deserialize a JSON attribute to a type `T`.
- fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
- Ok(serde_json::from_value::<T>(
- self
- .attrs
- .as_ref()
- .ok_or(ErrorKind::GraphFormatError(
- "Missing graph attrs".to_string(),
- ))?
- .get(attr)
- .ok_or(ErrorKind::GraphFormatError(format!(
- "Missing {} attr",
- attr
- )))?
- .to_owned(),
- )?)
- }
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Node {
- pub op: String,
- pub name: String,
- pub inputs: Vec<Entry>,
- pub attrs: Option<HashMap<String, String>>,
- pub control_deps: Option<Vec<Entry>>,
-}
-
-struct NodeAttrs {
- func_name: String,
- num_outputs: usize,
- flatten_data: bool,
-}
-
-impl Node {
- fn parse_attrs(&self) -> Result<NodeAttrs> {
- let attrs = self
- .attrs
- .as_ref()
- .ok_or(format!("Missing node.attrs for `{}`", self.name))?;
- let func_name = attrs
- .get("func_name")
- .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
- .to_string();
- let num_outputs = attrs
- .get("num_outputs")
- .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
- .parse::<usize>()?;
- let flatten_data = attrs
- .get("flatten_data")
- .ok_or(format!(
- "Node `{}` is missing attrs.flatten_data",
- self.name
- ))?
- .parse::<u8>()?
- == 1;
- Ok(NodeAttrs {
- func_name,
- num_outputs,
- flatten_data,
- })
- }
-}
-
-impl<'a> TryFrom<&'a String> for Graph {
- type Error = Error;
- fn try_from(graph_json: &String) -> Result<Self> {
- let graph = serde_json::from_str(graph_json)?;
- Ok(graph)
- }
-}
-
-impl<'a> TryFrom<&'a str> for Graph {
- type Error = Error;
- fn try_from(graph_json: &'a str) -> Result<Self> {
- let graph = serde_json::from_str(graph_json)?;
- Ok(graph)
- }
-}
-
- /// An executor for a TVM computation graph.
-///
-/// # Examples
-///
-/// ```
- /// use std::{convert::TryFrom, fs, io::Read};
- ///
- /// use ndarray::Array;
- /// use tvm::runtime::{Graph, GraphExecutor, SystemLibModule};
- ///
-/// let syslib = SystemLibModule::default(); // a provider of TVM functions
-///
-/// let mut params_bytes = Vec::new();
-/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
- /// let params = tvm::runtime::load_param_dict(&params_bytes).unwrap();
-///
-/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap();
-///
-/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
-/// exec.load_params(params);
-///
-/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
-/// exec.set_input("data", x.into());
-/// exec.run();
-/// let output = exec.get_output(0).unwrap();
-///
-/// println!("{:#?}", Array::try_from(output).unwrap());
-/// ```
-pub struct GraphExecutor<'m, 't> {
- graph: Graph,
- op_execs: Vec<Box<Fn() + 'm>>,
- tensors: Vec<Tensor<'t>>,
-}
-
-unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
-
-impl<'m, 't> GraphExecutor<'m, 't> {
- pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
- let tensors = Self::setup_storages(&graph)?;
- Ok(GraphExecutor {
- op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
- tensors: tensors,
- graph: graph,
- })
- }
-
- /// Runs the computation graph.
- pub fn run(&self) {
- self.op_execs.iter().for_each(|op_exec| {
- op_exec();
- });
- }
-
- /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
- fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>> {
- let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
- let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
- let dtypes = graph
- .get_attr::<(String, Vec<String>)>("dltype")?
- .1
- .iter()
- .map(|dltype| {
- if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
- Ok(dtype)
- } else {
- Err(ErrorKind::GraphFormatError(format!("Invalid dltype: {}", dltype).to_string()).into())
- }
- })
- .collect::<Result<Vec<DataType>>>()?;
-
- let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
- let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
- for (i, &storage_id) in storage_ids.iter().enumerate() {
- let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
- let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
- storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
- }
-
- let mut storages: Vec<Storage> = storage_num_bytes
- .into_iter()
- .map(|nbytes| Storage::new(nbytes, align))
- .collect::<Result<Vec<Storage>>>()?;
-
- let tensors = izip!(storage_ids, shapes, dtypes)
- .map(|(storage_id, shape, dtype)| {
- let storage = storages[storage_id].view();
- Tensor {
- data: mem::replace(&mut storages[storage_id], storage),
- ctx: TVMContext::default(),
- dtype: dtype,
- size: shape.iter().product::<i64>() as usize,
- shape: shape,
- strides: None,
- byte_offset: 0,
- }
- })
- .collect();
-
- Ok(tensors)
- }
-
- /// Creates closures which represent the computation performed by this graph.
- fn setup_op_execs<M: 'm + Module>(
- graph: &Graph,
- lib: &'m M,
- tensors: &Vec<Tensor<'t>>,
- ) -> Result<Vec<Box<Fn() + 'm>>> {
- ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
- let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
-
- let mut op_execs = Vec::new();
- for (i, node) in graph.nodes.iter().enumerate() {
- if node.op == "null" {
- continue;
- }
- ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
- ensure!(node.attrs.is_some(), "Missing node attrs.");
-
- let attrs = node.parse_attrs()?;
-
- if attrs.func_name == "__nop" {
- continue;
- }
-
- let func = lib
- .get_function(&attrs.func_name)
- .ok_or(format!("Missing function {}", attrs.func_name))?;
- let arg_indices = node
- .inputs
- .iter()
- .map(|entry| graph.entry_index(entry))
- .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi)));
-
- let dl_tensors = arg_indices
- .map(|idx| {
- let tensor = &tensors[idx?];
- Ok(if attrs.flatten_data {
- DLTensor::from_tensor(tensor, true /* flatten */)
- } else {
- DLTensor::from(tensor)
- })
- })
- .collect::<Result<Vec<DLTensor>>>()
- .unwrap();
- let op: Box<Fn()> = box move || {
- let args = dl_tensors
- .iter()
- .map(|t| t.into())
- .collect::<Vec<TVMArgValue>>();
- func(args.as_slice());
- };
- op_execs.push(op);
- }
- Ok(op_execs)
- }
-
- pub fn load_params(&mut self, params: HashMap<String, Tensor<'t>>) {
- params.into_iter().for_each(|(name, param)| {
- self.set_input(name, param);
- })
- }
-
- pub fn set_input<S: AsRef<str>>(&mut self, name: S, value: Tensor<'t>) {
- if let Some(idx) = self.get_input_index(name.as_ref()) {
- // TODO: consider `new_with_params` to avoid ever allocating
- let ptr = self.tensors[idx].data.as_ptr();
- let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr);
- let mut owner = to_replace.nth(0).unwrap();
- if value.data.is_owned() {
- // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
- // mem::replace(&mut (*owner), value);
- // to_replace.for_each(|t| {
- // panic!("replacing");
- // t.data = owner.data.view();
- // });
- owner.copy(&value);
- } else {
- owner.copy(&value);
- }
- } else {
- println!("Unexpected input `{}`", name.as_ref());
- }
- }
-
- /// Returns the graph input with name `name`, if it exists.
- pub fn get_input<S: AsRef<str>>(&mut self, name: S) -> Option<&Tensor> {
- self
- .get_input_index(name.as_ref())
- .and_then(move |idx| Some(&self.tensors[idx]))
- }
-
- /// Returns the graph output with index `index`, if it exists.
- pub fn get_output(&self, idx: usize) -> Option<&Tensor> {
- let graph = &self.graph;
- graph.heads.get(idx).and_then(|entry| {
- graph
- .entry_index(entry)
- .map(|idx| self.tensors.get(idx))
- .unwrap_or(None)
- })
- }
-
- /// Returns the index for graph input with name `name`, if it exists.
- pub fn get_input_index<S: AsRef<str>>(&self, name: S) -> Option<usize> {
- let graph = &self.graph;
- (0..graph.nodes.len())
- .skip_while(|&i| graph.nodes[i].name != name.as_ref())
- .nth(0)
- .and_then(|i| {
- if graph.arg_nodes.iter().any(|&id| id == i) {
- graph.node_row_ptr.as_ref().map(|nrp| nrp[i])
- } else {
- None
- }
- })
- }
-}
-
-/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
-named!(
- tvm_str_to_type<CompleteStr, DataType>,
- do_parse!(
- type_name: alpha1 >>
- bits: digit1 >>
- lanes: opt!(tuple!(tag!("x"), digit1)) >>
- (DataType {
- code: match type_name {
- CompleteStr("int") => DLDataTypeCode_kDLInt,
- CompleteStr("uint") => DLDataTypeCode_kDLUInt,
- CompleteStr("float") => DLDataTypeCode_kDLFloat,
- _ => DLDataTypeCode_kDLFloat,
- } as usize,
- bits: bits.parse::<u8>().unwrap() as usize,
- lanes: match lanes {
- Some(lanes) => lanes.1.parse::<u16>().unwrap() as usize,
- None => 1,
- },
- })
- )
-);
-
- /// Converts a length-prefixed byte string to a String.
-named!(
- name<String>,
- map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
- b.to_vec()
- ))
-);
-
-/// Parses a TVMContext
-named!(
- tvm_ctx<&[u8], TVMContext>,
- do_parse!(
- device_type: le_u32 >>
- device_id: le_i32 >>
- (TVMContext { device_type: device_type as usize, device_id: device_id as usize })
- )
-);
-
-/// Parses a DataType
-named!(
- data_type<&[u8], DataType>,
- do_parse!(
- code: le_u8 >>
- bits: le_u8 >>
- lanes: le_u16 >>
- (DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize })
- )
-);
-
-/// Parses a Tensor from a TVM array file.
-named!(
- tensor<Tensor>,
- do_parse!(
- take!(8)
- >> bits!(tag_bits!(u64, 64, 0))
- >> ctx: tvm_ctx
- >> ndim: le_u32
- >> dtype: data_type
- >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize)
- >> length: le_i64
- >> data: take!(length)
- >> (Tensor {
- data: Storage::from(data),
- ctx: ctx,
- dtype: dtype,
- size: shape.iter().product::<i64>() as usize,
- shape: shape,
- strides: None,
- byte_offset: 0,
- })
- )
-);
-
-/// Parses a graph params dict from a params binary file.
-named!(
- parse_param_dict<HashMap<String, Tensor>>,
- do_parse!(
- take!(8)
- >> bits!(tag_bits!(u64, 64, 0))
- >> names: length_count!(le_u64, name)
- >> tensors: length_count!(le_u64, tensor)
- >> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter())))
- )
-);
-
-/// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
-pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> {
- if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
- if remaining_bytes.len() > 0 {
- bail!(ErrorKind::LoadGraphParamsError("extra input".to_string()))
- } else {
- Ok(param_dict)
- }
- } else {
- bail!(ErrorKind::LoadGraphParamsError(
- "invalid parameters file".to_string()
- ))
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_str_to_type() {
- assert_eq!(
- tvm_str_to_type(CompleteStr("float24")).unwrap().1,
- DataType {
- code: DLDataTypeCode_kDLFloat as usize,
- bits: 24,
- lanes: 1
- }
- );
- assert_eq!(
- tvm_str_to_type(CompleteStr("uint111x44")).unwrap().1,
- DataType {
- code: DLDataTypeCode_kDLUInt as usize,
- bits: 111,
- lanes: 44
- }
- );
- }
-}
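A minimal sketch (not part of the diff) of the JSON shape that `Graph::try_from` above accepts: nodes, `arg_nodes`, `heads` as `[id, index, version]` triples, and `node_row_ptr`. The JSON below is a hand-written toy graph for illustration only, not an artifact of `nnvm.compiler.build`.

// Sketch only: assumes `use std::convert::TryFrom;` and the `Graph` type from this module.
fn parse_toy_graph() {
    let graph_json = r#"{
        "nodes": [
            {"op": "null", "name": "data", "inputs": [], "attrs": null, "control_deps": null}
        ],
        "arg_nodes": [0],
        "heads": [[0, 0, 0]],
        "node_row_ptr": [0, 1],
        "attrs": null
    }"#;
    let graph = Graph::try_from(graph_json).unwrap();
    assert_eq!(graph.nodes[0].op, "null");
    assert_eq!(graph.heads[0].index, 0);
}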
+++ /dev/null
-mod allocator;
-mod array;
-mod module;
-#[macro_use]
-mod packed_func;
-mod graph;
-#[cfg(target_env = "sgx")]
-#[macro_use]
-pub mod sgx;
-mod threading;
-mod workspace;
-
-use std::os::raw::c_char;
-
-pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*};
-
-#[cfg(target_env = "sgx")]
-use self::sgx::ocall_packed_func;
-
-#[no_mangle]
-pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) {
- #[cfg(not(target_env = "sgx"))]
- unsafe {
- panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
- }
- #[cfg(target_env = "sgx")]
- ocall_packed!("__sgx_set_last_error__", cmsg);
-}
+++ /dev/null
-use std::{
- collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
-};
-
-use ffi::runtime::BackendPackedCFunc;
-use runtime::packed_func::{wrap_backend_packed_func, PackedFunc};
-
-pub trait Module {
- fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
-}
-
-pub struct SystemLibModule;
-
-lazy_static! {
- static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
- Mutex::new(HashMap::new());
-}
-
-impl Module for SystemLibModule {
- fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
- SYSTEM_LIB_FUNCTIONS
- .lock()
- .unwrap()
- .get(name.as_ref())
- .map(|func| wrap_backend_packed_func(func.to_owned()))
- }
-}
-
-impl Default for SystemLibModule {
- fn default() -> Self {
- SystemLibModule {}
- }
-}
-
-#[no_mangle]
-pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
- cname: *const c_char,
- func: BackendPackedCFunc,
-) -> i32 {
- let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
- SYSTEM_LIB_FUNCTIONS
- .lock()
- .unwrap()
- .insert(name.to_string(), func);
- return 0;
-}
+++ /dev/null
-use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
-
-use super::Tensor;
-use ffi::runtime::{
- BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor,
- TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMTypeCode_kNDArrayContainer, TVMValue,
-};
-
-use errors::*;
-
-pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
-
-/// Calls a packed function and returns a `TVMRetValue`.
-///
-/// # Example
-///
-/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
-#[macro_export]
-macro_rules! call_packed {
- ($fn:expr, $($args:expr),+) => {
- $fn(&[$($args.into(),)+])
- };
- ($fn:expr) => {
- $fn(&Vec::new())
- };
-}
-
-/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
-/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
-#[derive(Clone, Copy)]
-pub struct TVMArgValue<'a> {
- _lifetime: PhantomData<&'a ()>,
- pub(crate) value: TVMValue,
- pub(crate) type_code: i64,
-}
-
-impl<'a> TVMArgValue<'a> {
- pub fn new(value: TVMValue, type_code: i64) -> Self {
- TVMArgValue {
- _lifetime: PhantomData,
- value: value,
- type_code: type_code,
- }
- }
-}
-
-/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
-macro_rules! impl_prim_tvm_arg {
- ($type:ty, $field:ident, $code:expr, $as:ty) => {
- impl<'a> From<$type> for TVMArgValue<'a> {
- fn from(val: $type) -> Self {
- TVMArgValue {
- value: TVMValue { $field: val as $as },
- type_code: $code as i64,
- _lifetime: PhantomData,
- }
- }
- }
- impl<'a> TryFrom<TVMArgValue<'a>> for $type {
- type Error = Error;
- fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
- ensure!(
- val.type_code == $code as i64,
- "Could not downcast arg. Expected `{}`, got `{}`",
- $code,
- val.type_code
- );
- Ok(unsafe { val.value.$field as $type })
- }
- }
- };
- ($type:ty, $field:ident, $code:expr) => {
- impl_prim_tvm_arg!($type, $field, $code, $type);
- };
- ($type:ty,v_int64) => {
- impl_prim_tvm_arg!($type, v_int64, DLDataTypeCode_kDLInt, i64);
- };
- ($type:ty,v_float64) => {
- impl_prim_tvm_arg!($type, v_float64, DLDataTypeCode_kDLFloat, f64);
- };
-}
-
-impl_prim_tvm_arg!(f32, v_float64);
-impl_prim_tvm_arg!(f64, v_float64);
-impl_prim_tvm_arg!(i8, v_int64);
-impl_prim_tvm_arg!(u8, v_int64);
-impl_prim_tvm_arg!(i32, v_int64);
-impl_prim_tvm_arg!(u32, v_int64);
-impl_prim_tvm_arg!(i64, v_int64);
-impl_prim_tvm_arg!(u64, v_int64);
-
-/// Creates a conversion to a `TVMArgValue` for an object handle.
-impl<'a, T> From<*const T> for TVMArgValue<'a> {
- fn from(ptr: *const T) -> Self {
- TVMArgValue {
- value: TVMValue {
- v_handle: ptr as *mut T as *mut c_void,
- },
- type_code: TVMTypeCode_kArrayHandle as i64,
- _lifetime: PhantomData,
- }
- }
-}
-
-/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
-impl<'a, T> From<*mut T> for TVMArgValue<'a> {
- fn from(ptr: *mut T) -> Self {
- TVMArgValue {
- value: TVMValue {
- v_handle: ptr as *mut c_void,
- },
- type_code: TVMTypeCode_kHandle as i64,
- _lifetime: PhantomData,
- }
- }
-}
-
-impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
- fn from(arr: &'a mut DLTensor) -> Self {
- TVMArgValue {
- value: TVMValue {
- v_handle: arr as *mut _ as *mut c_void,
- },
- type_code: TVMTypeCode_kArrayHandle as i64,
- _lifetime: PhantomData,
- }
- }
-}
-
-impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
- fn from(arr: &'a DLTensor) -> Self {
- TVMArgValue {
- value: TVMValue {
- v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
- },
- type_code: TVMTypeCode_kArrayHandle as i64,
- _lifetime: PhantomData,
- }
- }
-}
-
-impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
- type Error = Error;
- fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
- ensure!(
- val.type_code == TVMTypeCode_kArrayHandle as i64
- || val.type_code == TVMTypeCode_kNDArrayContainer as i64,
- "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
- TVMTypeCode_kArrayHandle,
- TVMTypeCode_kNDArrayContainer,
- val.type_code,
- );
-
- let dlt = unsafe { *(val.value.v_handle as *mut DLTensor as *const DLTensor) };
- Ok(dlt.into())
- }
-}
-
-/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
- /// Can be downcast using `try_from` if it contains the desired type.
-///
-/// # Example
-///
-/// ```
- /// use std::convert::{TryFrom, TryInto};
- ///
- /// let a = 42u32;
- /// let b: u64 = TVMRetValue::from(a).try_into().unwrap();
- ///
- /// let s = "hello, world!";
- /// let t: TVMRetValue = s.to_string().into();
- /// assert_eq!(String::try_from(t).unwrap(), s);
-/// ```
-pub struct TVMRetValue {
- /// A primitive return value, if any.
- prim_value: u64,
- /// An object return value, if any.
- box_value: Box<Any>,
- /// The type code which determines whether `prim_value` or `box_value` is in use.
- type_code: i64,
-}
-
-#[cfg(target_env = "sgx")]
-impl TVMRetValue {
- pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
- unsafe {
- Self {
- prim_value: match type_code {
- 0 | 1 => value.v_int64 as u64,
- 2 => value.v_float64 as u64,
- 3 | 7 | 8 | 9 | 10 => value.v_handle as u64,
- 11 | 12 => value.v_str as u64,
- _ => 0,
- } as u64,
- box_value: box (),
- type_code: type_code,
- }
- }
- }
-
- pub fn into_tvm_value(self) -> (TVMValue, i64) {
- let val = match self.type_code {
- 0 | 1 => TVMValue {
- v_int64: self.prim_value.clone() as i64,
- },
- 2 => TVMValue {
- v_float64: self.prim_value.clone() as f64,
- },
- 3 | 7 | 8 | 9 | 10 | 13 => TVMValue {
- v_handle: Box::into_raw(self.box_value) as *mut c_void,
- },
- 11 | 12 => TVMValue {
- v_str: Box::into_raw(self.box_value) as *const _,
- },
- _ => unreachable!(),
- };
- (val, self.type_code)
- }
-}
-
-impl Default for TVMRetValue {
- fn default() -> Self {
- TVMRetValue {
- prim_value: 0,
- box_value: box (),
- type_code: 0,
- }
- }
-}
-
-macro_rules! impl_prim_ret_value {
- ($type:ty, $code:expr) => {
- impl From<$type> for TVMRetValue {
- fn from(val: $type) -> Self {
- TVMRetValue {
- prim_value: val as u64,
- box_value: box (),
- type_code: $code,
- }
- }
- }
- impl TryFrom<TVMRetValue> for $type {
- type Error = Error;
- fn try_from(ret: TVMRetValue) -> Result<$type> {
- if ret.type_code == $code {
- Ok(ret.prim_value as $type)
- } else {
- bail!(ErrorKind::TryFromTVMRetValueError(
- stringify!($type).to_string(),
- ret.type_code
- ))
- }
- }
- }
- };
-}
-
-macro_rules! impl_boxed_ret_value {
- ($type:ty, $code:expr) => {
- impl From<$type> for TVMRetValue {
- fn from(val: $type) -> Self {
- TVMRetValue {
- prim_value: 0,
- box_value: box val,
- type_code: $code,
- }
- }
- }
- impl TryFrom<TVMRetValue> for $type {
- type Error = Error;
- fn try_from(ret: TVMRetValue) -> Result<$type> {
- if let Ok(val) = ret.box_value.downcast::<$type>() {
- Ok(*val)
- } else {
- bail!(ErrorKind::TryFromTVMRetValueError(
- stringify!($type).to_string(),
- ret.type_code
- ))
- }
- }
- }
- };
-}
-
-impl_prim_ret_value!(i8, 0);
-impl_prim_ret_value!(u8, 1);
-impl_prim_ret_value!(i16, 0);
-impl_prim_ret_value!(u16, 1);
-impl_prim_ret_value!(i32, 0);
-impl_prim_ret_value!(u32, 1);
-impl_prim_ret_value!(f32, 2);
-impl_prim_ret_value!(i64, 0);
-impl_prim_ret_value!(u64, 1);
-impl_prim_ret_value!(f64, 2);
-impl_prim_ret_value!(isize, 0);
-impl_prim_ret_value!(usize, 1);
-impl_boxed_ret_value!(String, 11);
-
-impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
- fn from(val: &'t Tensor<'a>) -> Self {
- TVMRetValue {
- prim_value: 0,
- box_value: box DLTensor::from(val),
- type_code: TVMTypeCode_kNDArrayContainer as i64,
- }
- }
-}
-
-impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
- type Error = Error;
- fn try_from(ret: TVMRetValue) -> Result<Self> {
- ensure!(
- ret.type_code == TVMTypeCode_kArrayHandle as i64
- || ret.type_code == TVMTypeCode_kNDArrayContainer as i64,
- "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
- TVMTypeCode_kArrayHandle,
- TVMTypeCode_kNDArrayContainer,
- ret.type_code,
- );
-
- let dlt = unsafe { *(ret.prim_value as *mut DLTensor as *const DLTensor) };
- Ok(dlt.into())
- }
-}
-
-// @see `WrapPackedFunc` in `llvm_module.cc`.
-pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
- box move |args: &[TVMArgValue]| {
- func(
- args
- .iter()
- .map(|ref arg| arg.value)
- .collect::<Vec<TVMValue>>()
- .as_ptr(),
- args
- .iter()
- .map(|ref arg| arg.type_code as i32)
- .collect::<Vec<i32>>()
- .as_ptr() as *const i32,
- args.len() as i32,
- );
- TVMRetValue::default()
- }
-}
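A minimal usage sketch (not part of the diff) for the types above: wrap a Rust closure as a `PackedFunc` and invoke it through `call_packed!`. The closure and argument values are illustrative.

// Sketch only: assumes `use std::convert::TryFrom;` plus `PackedFunc`, `TVMArgValue`,
// `TVMRetValue` and the `call_packed!` macro from this crate.
fn packed_func_roundtrip() {
    let double: PackedFunc = Box::new(|args: &[TVMArgValue]| {
        // downcast the first argument back to an i64, double it, and box it up again
        let x = i64::try_from(args[0]).unwrap();
        TVMRetValue::from(2 * x)
    });
    let ret = call_packed!(double, 21i64);
    assert_eq!(i64::try_from(ret).unwrap(), 42);
}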
+++ /dev/null
-use std::{
- ffi::CString,
- os::raw::{c_char, c_int},
-};
-
-use errors::Result;
-use ffi::runtime::TVMValue;
-use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
-
-pub use runtime::threading::tvm_run_worker as run_worker;
-
-#[macro_export]
-macro_rules! tvm_ocall {
- ($func: expr) => {
- match $func {
- 0 => Ok(()),
- err => Err(format!("SGX error: {}", err)),
- }
- };
-}
-
-pub type SgxStatus = u32;
-
-#[cfg(target_env = "sgx")]
-extern "C" {
- fn tvm_ocall_packed_func(
- name: *const c_char,
- arg_values: *const TVMValue,
- type_codes: *const c_int,
- num_args: c_int,
- ret_val: *mut TVMValue,
- ret_type_code: *mut c_int,
- ) -> SgxStatus;
-}
-
-pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
- let mut ret_val = TVMValue { v_int64: 0 };
- let mut ret_type_code: c_int = 0;
- unsafe {
- tvm_ocall!(tvm_ocall_packed_func(
- CString::new(fn_name.as_ref()).unwrap().as_ptr(),
- args
- .iter()
- .map(|ref arg| arg.value)
- .collect::<Vec<TVMValue>>()
- .as_ptr(),
- args
- .iter()
- .map(|ref arg| arg.type_code as i32)
- .collect::<Vec<i32>>()
- .as_ptr() as *const i32,
- args.len() as i32,
- &mut ret_val as *mut TVMValue,
- &mut ret_type_code as *mut c_int,
- ))?;
- }
- Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64))
-}
-
-#[macro_export]
-macro_rules! ocall_packed {
- ($fn_name:expr, $($args:expr),+) => {
- ocall_packed_func($fn_name, &[$($args.into(),)+])
- .expect(concat!("Error calling `", $fn_name, "`"))
- };
- ($fn_name:expr) => {
- ocall_packed_func($fn_name, &Vec::new())
- .expect(concat!("Error calling `", $fn_name, "`"))
- }
-}
-
-pub fn shutdown() {
- if env!("TVM_NUM_THREADS") != "0" {
- sgx_join_threads()
- }
-}
-
-impl Drop for SystemLibModule {
- fn drop(&mut self) {
- shutdown()
- }
-}
+++ /dev/null
-use std::{
- os::raw::{c_int, c_void},
- sync::{
- atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT},
- Arc, Barrier,
- },
-};
-
-#[cfg(not(target_env = "sgx"))]
-use num_cpus;
-#[cfg(not(target_env = "sgx"))]
-use std::{
- env,
- thread::{self, JoinHandle},
-};
-
-#[cfg(target_env = "sgx")]
-use std::{collections::VecDeque, ptr, sync::Mutex};
-
-use bounded_spsc_queue::{self, Producer};
-
-use super::super::errors::*;
-use ffi::runtime::TVMParallelGroupEnv;
-
-#[cfg(target_env = "sgx")]
-use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue};
-
-type FTVMParallelLambda =
- extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
-
-/// Holds a parallel job request made by a TVM library function.
-struct Job {
- cb: FTVMParallelLambda,
- cdata: *const c_void,
- req_num_tasks: usize,
- pending: Arc<AtomicUsize>,
-}
-
-impl Job {
- /// Splits this job into a number of `Task`s which can be scheduled.
- fn tasks(&self, num_workers: usize) -> Vec<Task> {
- let num_tasks = if self.req_num_tasks == 0 {
- num_workers
- } else {
- self.req_num_tasks.min(num_workers)
- };
- self.pending.store(num_tasks, Ordering::SeqCst);
-
- let barrier = Arc::new(Barrier::new(num_tasks));
-
- (0..num_tasks)
- .map(move |i| Task {
- id: i,
- flambda: self.cb,
- penv: TVMParallelGroupEnv {
- sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void,
- num_task: num_tasks as i32,
- },
- cdata: self.cdata,
- pending: Arc::clone(&self.pending),
- })
- .collect()
- }
-
- /// Waits for all tasks in this `Job` to be completed.
- fn wait(&self) -> Result<()> {
- while self.pending.load(Ordering::Acquire) > 0 {
- #[cfg(not(target_env = "sgx"))]
- thread::yield_now();
- }
- Ok(())
- }
-}
-
-/// A chunk of work requested by a TVM function.
-struct Task {
- id: usize,
- flambda: FTVMParallelLambda,
- penv: TVMParallelGroupEnv,
- cdata: *const c_void,
- pending: Arc<AtomicUsize>,
-}
-unsafe impl Send for Task {}
-unsafe impl Sync for Task {}
-
-impl FnOnce<()> for Task {
- type Output = i32;
- extern "rust-call" fn call_once(self, _args: ()) -> Self::Output {
- let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata);
- self.pending.fetch_sub(1, Ordering::AcqRel);
- status
- }
-}
-
-#[derive(Default)]
-struct Threads {
- #[allow(unused)]
- #[cfg(not(target_env = "sgx"))]
- handles: Vec<JoinHandle<()>>,
- queues: Vec<Producer<Task>>,
-}
-
-impl<'a> Threads {
- #[cfg(not(target_env = "sgx"))]
- fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
- num_threads: usize,
- cb: F,
- ) -> Self {
- let (handles, queues) = (0..num_threads)
- .map(|_| {
- let (p, c) = bounded_spsc_queue::make(2);
- let handle = thread::spawn(move || cb(c.into()));
- (handle, p)
- })
- .unzip();
- Threads {
- handles: handles,
- queues: queues,
- }
- }
-
- #[cfg(target_env = "sgx")]
- fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
- num_threads: usize,
- _cb: F,
- ) -> Self {
- let mut consumer_queues = SGX_QUEUES.lock().unwrap();
- let queues = (0..num_threads)
- .map(|_| {
- let (p, c) = bounded_spsc_queue::make(2);
- consumer_queues.push_back(c.into());
- p
- })
- .collect();
- ocall_packed!("__sgx_thread_group_launch__", num_threads as u64);
- Threads { queues: queues }
- }
-}
-
-struct ThreadPool {
- num_workers: usize,
- #[allow(unused)]
- threads: Threads,
-}
-
-thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new());
-
-impl ThreadPool {
- fn new() -> Self {
- let num_workers = max_concurrency();
- ThreadPool {
- num_workers: num_workers,
- threads: Threads::launch(num_workers, ThreadPool::run_worker),
- }
- }
-
- fn launch(&self, job: Job) {
- let mut tasks = job.tasks(self.num_workers + 1);
-
- for (i, task) in tasks.split_off(1).into_iter().enumerate() {
- self.threads.queues[i].push(task);
- }
-
- tasks.pop().unwrap()();
- job.wait().unwrap();
- }
-
- fn run_worker(queue: Consumer<Task>) {
- loop {
- let task = queue.pop();
- let result = task();
- if result == <i32>::min_value() {
- break;
- } else if result != 0 {
- panic!("Error running task.");
- }
- }
- }
-}
-
-// Send + Sync wrapper for bounded_spsc_queue::Consumer
-struct Consumer<T> {
- consumer: bounded_spsc_queue::Consumer<T>,
-}
-impl<T> From<bounded_spsc_queue::Consumer<T>> for Consumer<T> {
- fn from(c: bounded_spsc_queue::Consumer<T>) -> Self {
- Consumer { consumer: c }
- }
-}
-impl<T> Consumer<T> {
- fn pop(&self) -> T {
- self.consumer.pop()
- }
-}
-unsafe impl<T> Send for Consumer<T> {}
-unsafe impl<T> Sync for Consumer<T> {}
-
-#[cfg(target_env = "sgx")]
-lazy_static! {
- /// Holds tasks for untrusted threads which re-enter the enclave to execute.
- static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new());
-}
-
-#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))]
-fn max_concurrency() -> usize {
- if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) {
- if let Ok(threads) = usize::from_str_radix(&threads_str, 10) {
- return threads;
- }
- }
- num_cpus::get_physical()
-}
-
-#[cfg(target_env = "sgx")]
-fn max_concurrency() -> usize {
- usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1)
-}
-
-#[cfg(target_arch = "wasm32")]
-fn max_concurrency() -> usize {
- 0 // wasm doesn't support threads yet
-}
-
-#[cfg(target_env = "sgx")]
-pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue {
- let q = {
- let mut qs = SGX_QUEUES.lock().unwrap();
- qs.pop_front()
- // `qs: MutexGuard` needs to be dropped here since `run_worker` won't return
- };
- if let Some(q) = q {
- ThreadPool::run_worker(q);
- }
- TVMRetValue::default()
-}
-
-#[no_mangle]
-pub extern "C" fn TVMBackendParallelLaunch(
- cb: FTVMParallelLambda,
- cdata: *const c_void,
- num_task: usize,
-) -> c_int {
- if max_concurrency() == 0 {
- let penv = TVMParallelGroupEnv {
- sync_handle: 0 as *mut c_void,
- num_task: 1,
- };
- cb(0, &penv as *const _, cdata);
- } else {
- THREAD_POOL.with(|pool| {
- pool.launch(Job {
- cb: cb,
- cdata: cdata,
- req_num_tasks: num_task,
- pending: Arc::new(ATOMIC_USIZE_INIT),
- });
- });
- }
- return 0;
-}
-
-#[cfg(target_env = "sgx")]
-pub(crate) fn sgx_join_threads() {
- extern "C" fn poison_pill(
- _task_id: usize,
- _penv: *const TVMParallelGroupEnv,
- _cdata: *const c_void,
- ) -> i32 {
- <i32>::min_value()
- }
-
- THREAD_POOL.with(|pool| {
- pool.launch(Job {
- cb: poison_pill,
- cdata: ptr::null(),
- req_num_tasks: 0,
- pending: Arc::new(ATOMIC_USIZE_INIT),
- });
- });
- ocall_packed!("__sgx_thread_group_join__", 0);
-}
-
-// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used.
-#[no_mangle]
-pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) {
- let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) };
- barrier.wait();
-}
-
-#[cfg(test)]
-mod tests {
- use std::{ptr, thread, time::Duration};
-
- use super::*;
-
- #[test]
- fn test_max_concurrency() {
- env::set_var("TVM_NUM_THREADS", "42");
- env::set_var("OMP_NUM_THREADS", "24");
- assert_eq!(max_concurrency(), 42);
- env::remove_var("TVM_NUM_THREADS");
- assert_eq!(max_concurrency(), 24);
- }
-
- extern "C" fn flambda(
- task_id: usize,
- penv: *const TVMParallelGroupEnv,
- cdata: *const c_void,
- ) -> i32 {
- if cdata == ptr::null() {
- return 0;
- }
- unsafe {
- let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize));
- thread::sleep(Duration::from_millis(50 * task_id as u64));
- counter.fetch_add(1, Ordering::SeqCst);
- task_ids_sum.fetch_add(task_id, Ordering::SeqCst);
- assert_eq!((*penv).num_task, 3);
- }
- 0
- }
-
- #[test]
- fn test_parallel_launch() {
- TVMBackendParallelLaunch(flambda, ptr::null(), 6);
- let counter = ATOMIC_USIZE_INIT;
- let task_ids_sum = ATOMIC_USIZE_INIT;
- let cdata = (counter, task_ids_sum);
- let num_tasks = 3;
- TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
- assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks);
- assert_eq!(
- cdata.1.load(Ordering::SeqCst),
- (0..num_tasks).sum::<usize>()
- );
- }
-}
+++ /dev/null
-use std::{
- cell::RefCell,
- os::raw::{c_int, c_void},
- ptr,
-};
-
-use super::allocator::Allocation;
-use errors::*;
-
-const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
-
-struct WorkspacePool {
- workspaces: Vec<Allocation>,
- free: Vec<usize>,
- in_use: Vec<usize>,
-}
-
-impl WorkspacePool {
- fn new() -> Self {
- WorkspacePool {
- workspaces: Vec::new(),
- free: Vec::new(),
- in_use: Vec::new(),
- }
- }
-
- fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
- self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
- self.in_use.push(self.workspaces.len() - 1);
- Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
- }
-
- fn alloc(&mut self, size: usize) -> Result<*mut u8> {
- if self.free.len() == 0 {
- return self.alloc_new(size);
- }
- let idx = self
- .free
- .iter()
- .fold(None, |cur_ws_idx: Option<usize>, &idx| {
- let ws_size = self.workspaces[idx].size();
- // skip workspaces that are too small for this request
- if ws_size < size {
- return cur_ws_idx;
- }
- cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
- let cur_size = self.workspaces[cur_idx].size();
- Some(match ws_size <= cur_size {
- true => idx,
- false => cur_idx,
- })
- })
- });
- match idx {
- Some(idx) => {
- self.free.remove_item(&idx).unwrap();
- self.in_use.push(idx);
- Ok(self.workspaces[idx].as_mut_ptr())
- }
- None => self.alloc_new(size),
- }
- }
-
- fn free(&mut self, ptr: *mut u8) -> Result<()> {
- let mut ws_idx = None;
- for i in 0..self.in_use.len() {
- let idx = self.in_use[i];
- if self.workspaces[idx].as_mut_ptr() == ptr {
- self.in_use.remove(i);
- ws_idx = Some(idx);
- break;
- }
- }
- Ok(
- self
- .free
- .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?),
- )
- }
-}
-
-thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new()));
-
-const WORKSPACE_PAGE_SIZE: usize = 4 << 10;
-
-#[no_mangle]
-pub extern "C" fn TVMBackendAllocWorkspace(
- _device_type: c_int,
- _device_id: c_int,
- size: u64,
- _dtype_code_hint: c_int,
- _dtype_bits_hint: c_int,
-) -> *mut c_void {
- let nbytes = if size == 0 {
- WORKSPACE_PAGE_SIZE
- } else {
- size as usize
- };
- WORKSPACE_POOL.with(|pool_cell| {
- pool_cell
- .borrow_mut()
- .alloc(nbytes as usize)
- .unwrap_or(ptr::null_mut()) as *mut c_void
- })
-}
-
-#[no_mangle]
-pub extern "C" fn TVMBackendFreeWorkspace(
- _device_type: c_int,
- _device_id: c_int,
- ptr: *mut c_void,
-) -> c_int {
- WORKSPACE_POOL.with(|pool_cell| {
- (match pool_cell.borrow_mut().free(ptr as *mut u8) {
- Ok(()) => 0,
- Err(_) => -1,
- }) as c_int
- });
- return 0;
-}
+++ /dev/null
-*.json
-*.params
-*.o
+++ /dev/null
-"""Builds a simple NNVM graph for testing."""
-
-from os import path as osp
-
-import nnvm
-from nnvm import sym
-from nnvm.compiler import graph_util
-from nnvm.testing import init
-import numpy as np
-import tvm
-
-CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
-
-
-def _get_model(dshape):
- data = sym.Variable('data', shape=dshape)
- fc1 = sym.dense(data, units=dshape[-1]*2, use_bias=True)
- left, right = sym.split(fc1, indices_or_sections=2, axis=1)
- return sym.Group(((left + 1), (right - 1)))
-
-
-def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
- if isinstance(graph, sym.Symbol):
- graph = nnvm.graph.create(graph)
- ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
- param_shapes = dict(zip(graph.index.input_names, ishapes))
- np.random.seed(seed)
- params = {}
- for param, shape in param_shapes.items():
- if param in {'data', 'label'} or not shape:
- continue
- init_value = np.empty(shape).astype('float32')
- initializer(param, init_value)
- params[param] = tvm.nd.array(init_value)
- return params
-
-def main():
- dshape = (32, 16)
- net = _get_model(dshape)
- ishape_dict = {'data': dshape}
- params = _init_params(net, ishape_dict)
- graph, lib, params = nnvm.compiler.build(net, 'llvm',
- shape=ishape_dict,
- params=params,
- dtype='float32')
-
- with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet:
- f_resnet.write(graph.json())
- with open(osp.join(CWD, 'graph.params'), 'wb') as f_params:
- f_params.write(nnvm.compiler.save_param_dict(params))
-
-if __name__ == '__main__':
- main()
+++ /dev/null
-#![feature(try_from)]
-
-extern crate serde;
-extern crate serde_json;
-
-extern crate tvm;
-
-use std::{convert::TryFrom, fs, io::Read};
-
-use tvm::runtime::Graph;
-
-#[test]
-fn test_load_graph() {
- let mut params_bytes = Vec::new();
- fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
- .expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
- .read_to_end(&mut params_bytes)
- .unwrap();
- let _params = tvm::runtime::load_param_dict(&params_bytes);
-
- let graph = Graph::try_from(
- &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(),
- )
- .unwrap();
-
- assert_eq!(graph.nodes[3].op, "tvm_op");
- assert_eq!(
- graph.nodes[3]
- .attrs
- .as_ref()
- .unwrap()
- .get("func_name")
- .unwrap(),
- "fuse_dense"
- );
- assert_eq!(graph.nodes[5].inputs[0].index, 0);
- assert_eq!(graph.nodes[6].inputs[0].index, 1);
- assert_eq!(graph.heads.len(), 2);
-}
+++ /dev/null
-[package]
-name = "test-nnvm"
-version = "0.0.0"
-license = "Apache-2.0"
-authors = ["Nick Hynes <nhynes@berkeley.edu>"]
-
-[dependencies]
-ndarray = "0.11.2"
-tvm = { path = "../../" }
-serde = "1.0.59"
-serde_json = "1.0.17"
-
-[build-dependencies]
-ar = "0.6.0"
+++ /dev/null
-extern crate ar;
-
-use std::{
- env,
- fs::File,
- path::{Path, PathBuf},
- process::Command,
-};
-
-use ar::Builder;
-
-fn main() {
- let out_dir = env::var("OUT_DIR").unwrap();
-
- let output = Command::new(concat!(
- env!("CARGO_MANIFEST_DIR"),
- "/src/build_test_graph.py"
- ))
- .arg(&out_dir)
- .output()
- .expect("Failed to execute command");
- assert!(
- Path::new(&format!("{}/graph.o", out_dir)).exists(),
- "Could not build graph lib: {}",
- String::from_utf8(output.stderr)
- .unwrap()
- .trim()
- .split("\n")
- .last()
- .unwrap_or("")
- );
-
- let in_path: PathBuf = [&out_dir, "graph.o"].iter().collect();
- let out_path: PathBuf = [&out_dir, "libgraph.a"].iter().collect();
- let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
- builder.append_path(in_path.to_str().unwrap()).unwrap();
-
- println!("cargo:rustc-link-lib=static=graph");
- println!("cargo:rustc-link-search=native={}", out_dir);
-}
+++ /dev/null
-#!/usr/bin/env python3
-
-"""Builds a simple NNVM graph for testing."""
-
-from os import path as osp
-import sys
-
-import nnvm
-from nnvm import sym
-from nnvm.compiler import graph_util
-from nnvm.testing import init
-import numpy as np
-import tvm
-
-
-def _get_model(dshape):
- data = sym.Variable('data', shape=dshape)
- fc = sym.dense(data, units=dshape[-1]*2, use_bias=True)
- left, right = sym.split(fc, indices_or_sections=2, axis=1)
- return sym.Group(((left + 1), (right - 1), fc))
-
-
-def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
- if isinstance(graph, sym.Symbol):
- graph = nnvm.graph.create(graph)
- ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
- param_shapes = dict(zip(graph.index.input_names, ishapes))
- np.random.seed(seed)
- params = {}
- for param, shape in param_shapes.items():
- if param in {'data', 'label'} or not shape:
- continue
-
- init_value = np.arange(np.product(shape), 0, -1).reshape(*shape).astype('float32')
- if param.endswith('_bias'):
- params[param] = tvm.nd.array(init_value)
- continue
-
- init_value = np.empty(shape).astype('float32')
- initializer(param, init_value)
- # init_value /= init_value.sum() + 1e-10
- params[param] = tvm.nd.array(init_value)
- return params
-
-def main():
- dshape = (4, 8)
- net = _get_model(dshape)
- ishape_dict = {'data': dshape}
- params = _init_params(net, ishape_dict)
- graph, lib, params = nnvm.compiler.build(net, 'llvm --system-lib',
- shape=ishape_dict,
- params=params,
- dtype='float32')
-
- out_dir = sys.argv[1]
- lib.save(osp.join(sys.argv[1], 'graph.o'))
- with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet:
- f_resnet.write(graph.json())
- with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params:
- f_params.write(nnvm.compiler.save_param_dict(params))
-
-if __name__ == '__main__':
- main()
+++ /dev/null
-#![feature(try_from)]
-
-#[macro_use]
-extern crate ndarray;
-extern crate serde;
-extern crate serde_json;
-
-extern crate tvm;
-use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
-
-use ndarray::Array;
-use tvm::runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
-
-const BATCH_SIZE: usize = 4;
-const IN_DIM: usize = 8;
-
-macro_rules! check_sum {
- ($e:expr, $a:ident, $b:ident) => {
- let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap();
- check_sum!(a, $b);
- };
- ($e:expr, $a:expr, $b:ident) => {
- let a = Array::try_from($e.get_output($a).unwrap()).unwrap();
- check_sum!(a, $b);
- };
- ($a:ident, $b:ident) => {
- let a_sum: f32 = $a.scalar_sum();
- let b_sum: f32 = $b.scalar_sum();
- assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum);
- };
-}
-
-fn main() {
- let syslib = SystemLibModule::default();
-
- let mut params_bytes = Vec::new();
- fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
- .unwrap()
- .read_to_end(&mut params_bytes)
- .unwrap();
- let params = tvm::runtime::load_param_dict(&params_bytes)
- .unwrap()
- .into_iter()
- .map(|(k, v)| (k, v.to_owned()))
- .collect::<HashMap<String, Tensor<'static>>>();
-
- let graph =
- Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()).unwrap();
- let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
-
- let x = Array::from_shape_vec(
- (BATCH_SIZE, IN_DIM),
- (0..BATCH_SIZE * IN_DIM)
- .map(|x| x as f32)
- .collect::<Vec<f32>>(),
- ).unwrap();
- let w = Array::try_from(params.get("dense0_weight").unwrap())
- .unwrap()
- .into_shape((IN_DIM * 2, IN_DIM))
- .unwrap();
- let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap();
- let dense = x.dot(&w.t()) + &b;
- let left = dense.slice(s![.., 0..IN_DIM]);
- let right = dense.slice(s![.., IN_DIM..]);
- let expected_o0 = &left + 1f32;
- let expected_o1 = &right - 1f32;
-
- exec.load_params(params);
- exec.set_input("data", x.clone().into());
-
- check_sum!(exec, data, x);
- check_sum!(exec, dense0_weight, w);
- check_sum!(exec, dense0_bias, b);
-
- exec.run();
-
- check_sum!(exec, 0, expected_o0);
- check_sum!(exec, 1, expected_o1);
- check_sum!(exec, 2, dense);
-}
+++ /dev/null
-[package]
-name = "test-tvm-basic"
-version = "0.0.0"
-license = "Apache-2.0"
-authors = ["Nick Hynes <nhynes@berkeley.edu>"]
-
-[dependencies]
-ndarray = "0.11.2"
-tvm = { path = "../../" }
-
-[build-dependencies]
-ar = "0.6.0"
+++ /dev/null
-extern crate ar;
-
-use std::{env, path::PathBuf, process::Command};
-
-use ar::Builder;
-use std::fs::File;
-
-fn main() {
- let out_dir = env::var("OUT_DIR").unwrap();
-
- let output = Command::new(concat!(
- env!("CARGO_MANIFEST_DIR"),
- "/src/build_test_lib.py"
- )).arg(&out_dir)
- .output()
- .expect("Failed to execute command");
- if output.stderr.len() > 0 {
- panic!(String::from_utf8(output.stderr).unwrap());
- }
-
- let in_path: PathBuf = [&out_dir, "test.o"].iter().collect();
- let out_path: PathBuf = [&out_dir, "libtest.a"].iter().collect();
- let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
- builder.append_path(in_path.to_str().unwrap()).unwrap();
-
- println!("cargo:rustc-link-lib=static=test");
- println!("cargo:rustc-link-search=native={}", out_dir);
-}
+++ /dev/null
-#!/usr/bin/env python3
-
-"""Prepares a simple TVM library for testing."""
-
-from os import path as osp
-import sys
-
-import tvm
-
-def main():
- n = tvm.var('n')
- A = tvm.placeholder((n,), name='A')
- B = tvm.placeholder((n,), name='B')
- C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
- s = tvm.create_schedule(C.op)
- s[C].parallel(s[C].op.axis[0])
- print(tvm.lower(s, [A, B, C], simple_mode=True))
- tvm.build(s, [A, B, C], 'llvm --system-lib').save(osp.join(sys.argv[1], 'test.o'))
-
-if __name__ == '__main__':
- main()
+++ /dev/null
-extern crate ndarray;
-#[macro_use]
-extern crate tvm;
-
-use ndarray::Array;
-use tvm::{
- ffi::runtime::DLTensor,
- runtime::{Module, SystemLibModule},
-};
-
-fn main() {
- let syslib = SystemLibModule::default();
- let add = syslib
- .get_function("default_function")
- .expect("main function not found");
- let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
- let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
- let mut c = Array::from_vec(vec![0f32; 4]);
- let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
- let mut a_dl: DLTensor = (&mut a).into();
- let mut b_dl: DLTensor = (&mut b).into();
- let mut c_dl: DLTensor = (&mut c).into();
- call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
- assert!(c.all_close(&e, 1e-8f32));
-}
set -e
-export LD_LIBRARY_PATH=lib:$LD_LIBRARY_PATH
+export TVM_HOME="$(git rev-parse --show-toplevel)"
-tvm_root="$(git rev-parse --show-toplevel)"
-export PYTHONPATH="$tvm_root/python":"$tvm_root/nnvm/python":"$tvm_root/topi/python"
+export LD_LIBRARY_PATH="$TVM_HOME/lib":"$TVM_HOME/build":"$TVM_HOME/nnvm":$LD_LIBRARY_PATH
+export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/nnvm/python":"$TVM_HOME/topi/python"
+export RUST_DIR="$TVM_HOME/rust"
-#cd rust
-#cargo fmt -- --check
+cd $RUST_DIR
+cargo fmt -- --check
+
+# test common
+cd $RUST_DIR/common
+cargo build --features runtime
+cargo test --features runtime --tests
+
+cargo build --features frontend
+cargo test --features frontend --tests
+
+# test runtime
+cd $RUST_DIR/runtime
# run basic tests
-#python3 tests/build_model.py
-#cargo test --tests
+python3 tests/build_model.py
+cargo test --tests
# run TVM module test
-#cd tests/test_tvm_basic
-#cargo run
-#cd -
+cd tests/test_tvm_basic
+cargo run
+cd -
# run NNVM graph test
-#cd tests/test_nnvm
-#cargo run
-#cd -
+cd tests/test_nnvm
+cargo run
+cd -
+
+# test frontend
+cd $RUST_DIR/frontend
+
+cargo test --tests -- --test-threads=1
+
+# run basic tests on cpu
+cd tests/basics
+cargo build --features cpu
+cargo run --features cpu
+# uncomment when we have more CI resources
+# cargo build --features gpu
+# cargo run --features gpu
+cd -
+
+# run callback tests separately: https://discuss.tvm.ai/t/are-global-functions-need-to-be-accessed-in-separate-processes/1075
+cd tests/callback
+cargo build
+cargo run --bin int
+cargo run --bin float
+cargo run --bin array
+cargo run --bin string
+cd -