From: Ehsan M. Kermani Date: Sun, 3 Feb 2019 03:56:11 +0000 (-0800) Subject: [RUST][FRONTEND] Add rust frontend v0.1 (#2292) X-Git-Tag: upstream/0.7.0~2811 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e2970b226e39ba196b5ced5fe98c604c0620f939;p=platform%2Fupstream%2Ftvm.git [RUST][FRONTEND] Add rust frontend v0.1 (#2292) --- diff --git a/rust/.gitignore b/rust/.gitignore deleted file mode 100644 index 230ab6610..000000000 --- a/rust/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -Cargo.lock -target/ -**/*.rs.bk diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml index 51e3cbfa7..9e52f9efa 100644 --- a/rust/.rustfmt.toml +++ b/rust/.rustfmt.toml @@ -1,6 +1,6 @@ max_width = 100 hard_tabs = false -tab_spaces = 2 +tab_spaces = 4 newline_style = "Auto" use_small_heuristics = "Default" indent_style = "Block" @@ -38,7 +38,7 @@ trailing_comma = "Vertical" match_block_trailing_comma = false blank_lines_upper_bound = 1 blank_lines_lower_bound = 0 -edition = "2015" +edition = "2018" merge_derives = true use_try_shorthand = true use_field_init_shorthand = false @@ -50,8 +50,8 @@ unstable_features = false disable_all_formatting = false skip_children = false hide_parse_errors = false -error_on_line_overflow = false -error_on_unformatted = false +error_on_line_overflow = true +error_on_unformatted = true report_todo = "Never" report_fixme = "Never" ignore = [] diff --git a/rust/.travis.yml b/rust/.travis.yml deleted file mode 100644 index 63a3d0277..000000000 --- a/rust/.travis.yml +++ /dev/null @@ -1,5 +0,0 @@ -language: rust -rust: - - nightly -matrix: - fast_finish: true diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 4dd793e41..448cbfe30 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,28 +1,11 @@ -[package] -name = "tvm" -version = "0.1.0" -license = "Apache-2.0" -description = "TVM Rust runtime" -repository = "https://github.com/dmlc/tvm" -readme = "README.md" -keywords = ["tvm", "nnvm"] -categories = ["api-bindings", "science"] -authors = ["TVM 
Contributors"] - -[features] -default = ["nom/std"] -sgx = ["nom/alloc"] - -[dependencies] -bounded-spsc-queue = "0.4.0" -error-chain = { version = "0.12.0", default-features = false } -itertools = "0.7.8" -lazy_static = "1.1.0" -ndarray = "0.11.2" -nom = {version = "4.0.0", default-features = false } -serde = "1.0.59" -serde_derive = "1.0.79" -serde_json = "1.0.17" - -[target.'cfg(not(target_env = "sgx"))'.dependencies] -num_cpus = "1.8.0" +[workspace] +members = [ + "common", + "runtime", + "runtime/tests/test_tvm_basic", + "runtime/tests/test_nnvm", + "frontend", + "frontend/tests/basics", + "frontend/tests/callback", + "frontend/examples/resnet" +] diff --git a/rust/common/.gitignore b/rust/common/.gitignore new file mode 100644 index 000000000..84c2ae990 --- /dev/null +++ b/rust/common/.gitignore @@ -0,0 +1,4 @@ +target +**/*.rs.bk +Cargo.lock +/tvm-sys/src/bindgen.rs diff --git a/rust/common/Cargo.toml b/rust/common/Cargo.toml new file mode 100644 index 000000000..bcba5ad62 --- /dev/null +++ b/rust/common/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "tvm-common" +version = "0.1.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" + +[features] +runtime = [] +frontend = ["tvm-sys"] + +[dependencies] +error-chain = { version = "0.12.0", default-features = false } +tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true } diff --git a/rust/common/src/c_runtime_api.rs b/rust/common/src/c_runtime_api.rs new file mode 100644 index 000000000..6facf9ca2 --- /dev/null +++ b/rust/common/src/c_runtime_api.rs @@ -0,0 +1,770 @@ +/* automatically generated by rust-bindgen for TVM revision 6292c78 */ + +pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0"; +pub const DLPACK_VERSION: u32 = 8; +pub const _STDINT_H: u32 = 1; +pub const _FEATURES_H: u32 = 1; +pub const _DEFAULT_SOURCE: u32 = 1; +pub const __USE_ISOC11: u32 = 1; +pub const __USE_ISOC99: u32 = 1; +pub const __USE_ISOC95: u32 = 1; +pub const __USE_POSIX_IMPLICITLY: u32 = 1; +pub const 
_POSIX_SOURCE: u32 = 1; +pub const _POSIX_C_SOURCE: u32 = 200809; +pub const __USE_POSIX: u32 = 1; +pub const __USE_POSIX2: u32 = 1; +pub const __USE_POSIX199309: u32 = 1; +pub const __USE_POSIX199506: u32 = 1; +pub const __USE_XOPEN2K: u32 = 1; +pub const __USE_XOPEN2K8: u32 = 1; +pub const _ATFILE_SOURCE: u32 = 1; +pub const __USE_MISC: u32 = 1; +pub const __USE_ATFILE: u32 = 1; +pub const __USE_FORTIFY_LEVEL: u32 = 0; +pub const _STDC_PREDEF_H: u32 = 1; +pub const __STDC_IEC_559__: u32 = 1; +pub const __STDC_IEC_559_COMPLEX__: u32 = 1; +pub const __STDC_ISO_10646__: u32 = 201505; +pub const __STDC_NO_THREADS__: u32 = 1; +pub const __GNU_LIBRARY__: u32 = 6; +pub const __GLIBC__: u32 = 2; +pub const __GLIBC_MINOR__: u32 = 23; +pub const _SYS_CDEFS_H: u32 = 1; +pub const __WORDSIZE: u32 = 64; +pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; +pub const __SYSCALL_WORDSIZE: u32 = 64; +pub const _BITS_WCHAR_H: u32 = 1; +pub const INT8_MIN: i32 = -128; +pub const INT16_MIN: i32 = -32768; +pub const INT32_MIN: i32 = -2147483648; +pub const INT8_MAX: u32 = 127; +pub const INT16_MAX: u32 = 32767; +pub const INT32_MAX: u32 = 2147483647; +pub const UINT8_MAX: u32 = 255; +pub const UINT16_MAX: u32 = 65535; +pub const UINT32_MAX: u32 = 4294967295; +pub const INT_LEAST8_MIN: i32 = -128; +pub const INT_LEAST16_MIN: i32 = -32768; +pub const INT_LEAST32_MIN: i32 = -2147483648; +pub const INT_LEAST8_MAX: u32 = 127; +pub const INT_LEAST16_MAX: u32 = 32767; +pub const INT_LEAST32_MAX: u32 = 2147483647; +pub const UINT_LEAST8_MAX: u32 = 255; +pub const UINT_LEAST16_MAX: u32 = 65535; +pub const UINT_LEAST32_MAX: u32 = 4294967295; +pub const INT_FAST8_MIN: i32 = -128; +pub const INT_FAST16_MIN: i64 = -9223372036854775808; +pub const INT_FAST32_MIN: i64 = -9223372036854775808; +pub const INT_FAST8_MAX: u32 = 127; +pub const INT_FAST16_MAX: u64 = 9223372036854775807; +pub const INT_FAST32_MAX: u64 = 9223372036854775807; +pub const UINT_FAST8_MAX: u32 = 255; +pub const UINT_FAST16_MAX: i32 
= -1; +pub const UINT_FAST32_MAX: i32 = -1; +pub const INTPTR_MIN: i64 = -9223372036854775808; +pub const INTPTR_MAX: u64 = 9223372036854775807; +pub const UINTPTR_MAX: i32 = -1; +pub const PTRDIFF_MIN: i64 = -9223372036854775808; +pub const PTRDIFF_MAX: u64 = 9223372036854775807; +pub const SIG_ATOMIC_MIN: i32 = -2147483648; +pub const SIG_ATOMIC_MAX: u32 = 2147483647; +pub const SIZE_MAX: i32 = -1; +pub const WINT_MIN: u32 = 0; +pub const WINT_MAX: u32 = 4294967295; +pub type int_least8_t = ::std::os::raw::c_schar; +pub type int_least16_t = ::std::os::raw::c_short; +pub type int_least32_t = ::std::os::raw::c_int; +pub type int_least64_t = ::std::os::raw::c_long; +pub type uint_least8_t = ::std::os::raw::c_uchar; +pub type uint_least16_t = ::std::os::raw::c_ushort; +pub type uint_least32_t = ::std::os::raw::c_uint; +pub type uint_least64_t = ::std::os::raw::c_ulong; +pub type int_fast8_t = ::std::os::raw::c_schar; +pub type int_fast16_t = ::std::os::raw::c_long; +pub type int_fast32_t = ::std::os::raw::c_long; +pub type int_fast64_t = ::std::os::raw::c_long; +pub type uint_fast8_t = ::std::os::raw::c_uchar; +pub type uint_fast16_t = ::std::os::raw::c_ulong; +pub type uint_fast32_t = ::std::os::raw::c_ulong; +pub type uint_fast64_t = ::std::os::raw::c_ulong; +pub type intmax_t = ::std::os::raw::c_long; +pub type uintmax_t = ::std::os::raw::c_ulong; +pub type wchar_t = ::std::os::raw::c_int; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct max_align_t { + pub __clang_max_align_nonce1: ::std::os::raw::c_longlong, + pub __bindgen_padding_0: u64, + pub __clang_max_align_nonce2: f64, +} +pub const DLDeviceType_kDLCPU: DLDeviceType = 1; +pub const DLDeviceType_kDLGPU: DLDeviceType = 2; +pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3; +pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4; +pub const DLDeviceType_kDLMetal: DLDeviceType = 8; +pub const DLDeviceType_kDLVPI: DLDeviceType = 9; +pub const DLDeviceType_kDLROCM: DLDeviceType = 10; +/// \brief The 
device type in DLContext. +pub type DLDeviceType = u32; +/// \brief A Device context for Tensor and operator. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLContext { + /// \brief The device type used in the device. + pub device_type: DLDeviceType, + /// \brief The device index + pub device_id: ::std::os::raw::c_int, +} +pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0; +pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1; +pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2; +/// \brief The type code options DLDataType. +pub type DLDataTypeCode = u32; +/// \brief The data type the tensor can hold. +/// +/// Examples +/// - float: type_code = 2, bits = 32, lanes=1 +/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 +/// - int8: type_code = 0, bits = 8, lanes=1 +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLDataType { + /// \brief Type code of base types. + /// We keep it uint8_t instead of DLDataTypeCode for minimal memory + /// footprint, but the value should be one of DLDataTypeCode enum values. + /// + pub code: u8, + /// \brief Number of bits, common choices are 8, 16, 32. + pub bits: u8, + /// \brief Number of lanes in the type, used for vector types. + pub lanes: u16, +} +/// \brief Plain C Tensor object, does not manage memory. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLTensor { + /// \brief The opaque data pointer points to the allocated data. + /// This will be CUDA device pointer or cl_mem handle in OpenCL. + /// This pointer is always aligns to 256 bytes as in CUDA. + pub data: *mut ::std::os::raw::c_void, + /// \brief The device context of the tensor + pub ctx: DLContext, + /// \brief Number of dimensions + pub ndim: ::std::os::raw::c_int, + /// \brief The data type of the pointer + pub dtype: DLDataType, + /// \brief The shape of the tensor + pub shape: *mut i64, + /// \brief strides of the tensor, + /// can be NULL, indicating tensor is compact. 
+ pub strides: *mut i64, + /// \brief The offset in bytes to the beginning pointer to data + pub byte_offset: u64, +} +/// \brief C Tensor object, manage memory of DLTensor. This data structure is +/// intended to facilitate the borrowing of DLTensor by another framework. It is +/// not meant to transfer the tensor. When the borrowing framework doesn't need +/// the tensor, it should call the deleter to notify the host that the resource +/// is no longer needed. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLManagedTensor { + /// \brief DLTensor which is being memory managed + pub dl_tensor: DLTensor, + /// \brief the context of the original host framework of DLManagedTensor in + /// which DLManagedTensor is used in the framework. It can also be NULL. + pub manager_ctx: *mut ::std::os::raw::c_void, + /// \brief Destructor signature void (*)(void*) - this should be called + /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL + /// if there is no way for the caller to provide a reasonable destructor. + pub deleter: ::std::option::Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)>, +} +/// \brief type of array index. 
+pub type tvm_index_t = i64; +pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5; +pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6; +pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7; +pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11; +pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12; +/// \brief Extension device types in TVM +pub type TVMDeviceExtType = u32; +pub const TVMTypeCode_kHandle: TVMTypeCode = 3; +pub const TVMTypeCode_kNull: TVMTypeCode = 4; +pub const TVMTypeCode_kTVMType: TVMTypeCode = 5; +pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6; +pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7; +pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8; +pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9; +pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10; +pub const TVMTypeCode_kStr: TVMTypeCode = 11; +pub const TVMTypeCode_kBytes: TVMTypeCode = 12; +pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13; +pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15; +pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16; +pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20; +pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64; +pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128; +/// \brief The type code in TVMType +/// \note TVMType is used in two places. +pub type TVMTypeCode = u32; +/// \brief The data type used in TVM Runtime. +/// +/// Examples +/// - float: type_code = 2, bits = 32, lanes=1 +/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 +/// - int8: type_code = 0, bits = 8, lanes=1 +/// +/// \note Arguments TVM API function always takes bits=64 and lanes=1 +pub type TVMType = DLDataType; +/// \brief The Device information, abstract away common device types. +pub type TVMContext = DLContext; +/// \brief The tensor array stucture to TVM API. 
+pub type TVMArray = DLTensor; +/// \brief the array handle +pub type TVMArrayHandle = *mut TVMArray; +/// \brief Union type of values +/// being passed through API and function calls. +#[repr(C)] +#[derive(Copy, Clone)] +pub union TVMValue { + pub v_int64: i64, + pub v_float64: f64, + pub v_handle: *mut ::std::os::raw::c_void, + pub v_str: *const ::std::os::raw::c_char, + pub v_type: TVMType, + pub v_ctx: TVMContext, + _bindgen_union_align: u64, +} +/// \brief Byte array type used to pass in byte array +/// When kBytes is used as data type. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct TVMByteArray { + pub data: *const ::std::os::raw::c_char, + pub size: usize, +} +/// \brief Handle to TVM runtime modules. +pub type TVMModuleHandle = *mut ::std::os::raw::c_void; +/// \brief Handle to packed function handle. +pub type TVMFunctionHandle = *mut ::std::os::raw::c_void; +/// \brief Handle to hold return value. +pub type TVMRetValueHandle = *mut ::std::os::raw::c_void; +/// \brief The stream that is specific to device +/// can be NULL, which indicates the default one. +pub type TVMStreamHandle = *mut ::std::os::raw::c_void; +extern "C" { + /// \brief Used for implementing C API function. + /// Set last error message before return. + /// \param msg The error message to be set. + pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char); +} +extern "C" { + /// \brief return str message of the last error + /// all function in this file will return 0 when success + /// and -1 when an error occured, + /// TVMGetLastError can be called to retrieve the error + /// + /// this function is threadsafe and can be called by different thread + /// \return error info + pub fn TVMGetLastError() -> *const ::std::os::raw::c_char; +} +extern "C" { + /// \brief Load module from file. + /// \param file_name The file name to load the module from. + /// \param format The format of the module. 
+ /// \param out The result module + /// + /// \return 0 when success, -1 when failure happens + /// \note The resulting module do not contain import relation. + /// It can be reconstructed by TVMModImport. + pub fn TVMModLoadFromFile( + file_name: *const ::std::os::raw::c_char, + format: *const ::std::os::raw::c_char, + out: *mut TVMModuleHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Add dep to mod's dependency. + /// This allows functions in this module to use modules. + /// + /// \param mod The module handle. + /// \param dep The dependent module to be imported. + /// \return 0 when success, -1 when failure happens + pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Get function from the module. + /// \param mod The module handle. + /// \param func_name The name of the function. + /// \param query_imports Whether to query imported modules + /// \param out The result function, can be NULL if it is not available. + /// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMModGetFunction( + mod_: TVMModuleHandle, + func_name: *const ::std::os::raw::c_char, + query_imports: ::std::os::raw::c_int, + out: *mut TVMFunctionHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free front-end extension type resource. + /// \param handle The extension handle. + /// \param type_code The type of of the extension type. + /// \return 0 when success, -1 when failure happens + pub fn TVMExtTypeFree( + handle: *mut ::std::os::raw::c_void, + type_code: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free the Module + /// \param mod The module to be freed. + /// + /// \note This may not free up the module's resources. + /// If there is active TVMFunctionHandle uses the module + /// Or if this module is imported by another active module. + /// + /// The all functions remains valid until TVMFuncFree is called. 
+ /// \return 0 when success, -1 when failure happens + pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free the function when it is no longer needed. + /// \param func The function handle + /// \return 0 when success, -1 when failure happens + pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Call a Packed TVM Function. + /// + /// \param func node handle of the function. + /// \param arg_values The arguments + /// \param type_codes The type codes of the arguments + /// \param num_args Number of arguments. + /// + /// \param ret_val The return value. + /// \param ret_type_code the type code of return value. + /// + /// \return 0 when success, -1 when failure happens + /// \note TVM calls always exchanges with type bits=64, lanes=1 + /// + /// \note API calls always exchanges with type bits=64, lanes=1 + /// If API call returns container handles (e.g. FunctionHandle) + /// these handles should be managed by the front-end. + /// The front-end need to call free function (e.g. TVMFuncFree) + /// to free these handles. + pub fn TVMFuncCall( + func: TVMFunctionHandle, + arg_values: *mut TVMValue, + type_codes: *mut ::std::os::raw::c_int, + num_args: ::std::os::raw::c_int, + ret_val: *mut TVMValue, + ret_type_code: *mut ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Set the return value of TVMPackedCFunc. + /// + /// This function is called by TVMPackedCFunc to set the return value. + /// When this function is not called, the function returns null by default. + /// + /// \param ret The return value handle, pass by ret in TVMPackedCFunc + /// \param value The value to be returned. + /// \param type_code The type of the value to be returned. + /// \param num_ret Number of return values, for now only 1 is supported. 
+ pub fn TVMCFuncSetReturn( + ret: TVMRetValueHandle, + value: *mut TVMValue, + type_code: *mut ::std::os::raw::c_int, + num_ret: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Inplace translate callback argument value to return value. + /// This is only needed for non-POD arguments. + /// + /// \param value The value to be translated. + /// \param code The type code to be translated. + /// \note This function will do a shallow copy when necessary. + /// + /// \return 0 when success, -1 when failure happens. + pub fn TVMCbArgToReturn( + value: *mut TVMValue, + code: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +/// \brief C type of packed function. +/// +/// \param args The arguments +/// \param type_codes The type codes of the arguments +/// \param num_args Number of arguments. +/// \param ret The return value handle. +/// \param resource_handle The handle additional resource handle from front-end. +/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. +/// \sa TVMCFuncSetReturn +pub type TVMPackedCFunc = ::std::option::Option< + unsafe extern "C" fn( + args: *mut TVMValue, + type_codes: *mut ::std::os::raw::c_int, + num_args: ::std::os::raw::c_int, + ret: TVMRetValueHandle, + resource_handle: *mut ::std::os::raw::c_void, + ) -> ::std::os::raw::c_int, +>; +/// \brief C callback to free the resource handle in C packed function. +/// \param resource_handle The handle additional resource handle from front-end. +pub type TVMPackedCFuncFinalizer = + ::std::option::Option<unsafe extern "C" fn(resource_handle: *mut ::std::os::raw::c_void)>; +/// \brief Signature for extension function declarer. +/// +/// TVM call this function to get the extension functions +/// The declarer will call register_func to register function and their name. 
+/// +/// \param register_func_handle The register function +/// \return 0 if success, -1 if failure happens +pub type TVMExtensionFuncDeclarer = ::std::option::Option< + unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int, +>; +extern "C" { + /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle. + /// + /// The resource_handle will be managed by TVM API, until the function is no longer used. + /// + /// \param func The packed C function. + /// \param resource_handle The resource handle from front-end, can be NULL. + /// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL + /// \param out the result function handle. + /// \return 0 when success, -1 when failure happens + pub fn TVMFuncCreateFromCFunc( + func: TVMPackedCFunc, + resource_handle: *mut ::std::os::raw::c_void, + fin: TVMPackedCFuncFinalizer, + out: *mut TVMFunctionHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Register the function to runtime's global table. + /// + /// The registered function then can be pulled by the backend by the name. + /// + /// \param name The name of the function. + /// \param f The function to be registered. + /// \param override Whether allow override already registered function. + pub fn TVMFuncRegisterGlobal( + name: *const ::std::os::raw::c_char, + f: TVMFunctionHandle, + override_: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Get a global function. + /// + /// \param name The name of the function. + /// \param out the result function pointer, NULL if it does not exist. + /// + /// \note The function handle of global function is managed by TVM runtime, + /// So TVMFuncFree is should not be called when it get deleted. 
+ pub fn TVMFuncGetGlobal( + name: *const ::std::os::raw::c_char, + out: *mut TVMFunctionHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief List all the globally registered function name + /// \param out_size The number of functions + /// \param out_array The array of function names. + /// \return 0 when success, -1 when failure happens + pub fn TVMFuncListGlobalNames( + out_size: *mut ::std::os::raw::c_int, + out_array: *mut *mut *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Allocate a nd-array's memory, + /// including space of shape, of given spec. + /// + /// \param shape The shape of the array, the data content will be copied to out + /// \param ndim The number of dimension of the array. + /// \param dtype_code The type code of the dtype + /// \param dtype_bits The number of bits of dtype + /// \param dtype_lanes The number of lanes in the dtype. + /// \param device_type The device type of context + /// \param device_id The device id of context. + /// \param out The output handle. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayAlloc( + shape: *const tvm_index_t, + ndim: ::std::os::raw::c_int, + dtype_code: ::std::os::raw::c_int, + dtype_bits: ::std::os::raw::c_int, + dtype_lanes: ::std::os::raw::c_int, + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + out: *mut TVMArrayHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free the TVM Array. + /// \param handle The array handle to be freed. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Copy array data from CPU byte array. + /// \param handle The array handle. + /// \param data the data pointer + /// \param nbytes The number of bytes to copy. 
+ /// \return 0 when success, -1 when failure happens + pub fn TVMArrayCopyFromBytes( + handle: TVMArrayHandle, + data: *mut ::std::os::raw::c_void, + nbytes: usize, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Copy array data to CPU byte array. + /// \param handle The array handle. + /// \param data the data pointer + /// \param nbytes The number of bytes to copy. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayCopyToBytes( + handle: TVMArrayHandle, + data: *mut ::std::os::raw::c_void, + nbytes: usize, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Copy the array, both from and to must be valid during the copy. + /// \param from The array to be copied from. + /// \param to The target space. + /// \param stream The stream where the copy happens, can be NULL. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayCopyFromTo( + from: TVMArrayHandle, + to: TVMArrayHandle, + stream: TVMStreamHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Produce an array from the DLManagedTensor that shares data memory + /// with the DLManagedTensor. + /// \param from The source DLManagedTensor. + /// \param out The output array handle. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayFromDLPack( + from: *mut DLManagedTensor, + out: *mut TVMArrayHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Produce a DLMangedTensor from the array that shares data memory with + /// the array. + /// \param from The source array. + /// \param out The DLManagedTensor handle. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayToDLPack( + from: TVMArrayHandle, + out: *mut *mut DLManagedTensor, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Delete (free) a DLManagedTensor's data. + /// \param dltensor Pointer to the DLManagedTensor. 
+ pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor); +} +extern "C" { + /// \brief Create a new runtime stream. + /// + /// \param device_type The device type of context + /// \param device_id The device id of context + /// \param out The new stream handle + /// \return 0 when success, -1 when failure happens + pub fn TVMStreamCreate( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + out: *mut TVMStreamHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free a created stream handle. + /// + /// \param device_type The device type of context + /// \param device_id The device id of context + /// \param stream The stream to be freed + /// \return 0 when success, -1 when failure happens + pub fn TVMStreamFree( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + stream: TVMStreamHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Set the runtime stream of current thread to be stream. + /// The subsequent calls to the same device_type + /// will use the setted stream handle. + /// The specific type of stream is runtime device dependent. + /// + /// \param device_type The device type of context + /// \param device_id The device id of context. + /// \param handle The stream handle. + /// \return 0 when success, -1 when failure happens + pub fn TVMSetStream( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + handle: TVMStreamHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Wait until all computations on stream completes. + /// + /// \param device_type The device type of context + /// \param device_id The device id of context. + /// \param stream The stream to be synchronized. 
+ /// \return 0 when success, -1 when failure happens + pub fn TVMSynchronize( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + stream: TVMStreamHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Synchronize two streams of execution. + /// + /// \param device_type The device type of context + /// \param device_id The device id of context + /// \param src The source stream to synchronize. + /// \param dst The destination stream to synchronize. + /// \return 0 when success, -1 when failure happens + pub fn TVMStreamStreamSynchronize( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + src: TVMStreamHandle, + dst: TVMStreamHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Backend function for modules to get function + /// from its environment mod_node (its imports and global function). + /// The user do should not call TVMFuncFree on func. + /// + /// \param mod_node The module handle. + /// \param func_name The name of the function. + /// \param out The result function. + /// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMBackendGetFuncFromEnv( + mod_node: *mut ::std::os::raw::c_void, + func_name: *const ::std::os::raw::c_char, + out: *mut TVMFunctionHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Backend function to register system-wide library symbol. + /// + /// \param name The name of the symbol + /// \param ptr The symbol address. + /// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMBackendRegisterSystemLibSymbol( + name: *const ::std::os::raw::c_char, + ptr: *mut ::std::os::raw::c_void, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Backend function to allocate temporal workspace. + /// + /// \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment. + /// + /// \param nbytes The size of the space requested. 
+ /// \param device_type The device type which the space will be allocated. + /// \param device_id The device id which the space will be allocated. + /// \param dtype_code_hint The type code of the array elements. Only used in + /// certain backends such as OpenGL. + /// \param dtype_bits_hint The type bits of the array elements. Only used in + /// certain backends such as OpenGL. + /// \return nullptr when error is thrown, a valid ptr if success + pub fn TVMBackendAllocWorkspace( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + nbytes: u64, + dtype_code_hint: ::std::os::raw::c_int, + dtype_bits_hint: ::std::os::raw::c_int, + ) -> *mut ::std::os::raw::c_void; +} +extern "C" { + /// \brief Backend function to free temporal workspace. + /// + /// \param ptr The result allocated space pointer. + /// \param device_type The device type which the space will be allocated. + /// \param device_id The device id which the space will be allocated. + /// \return 0 when no error is thrown, -1 when failure happens + /// + /// \sa TVMBackendAllocWorkspace + pub fn TVMBackendFreeWorkspace( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + ptr: *mut ::std::os::raw::c_void, + ) -> ::std::os::raw::c_int; +} +/// \brief Environment for TVM parallel task. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct TVMParallelGroupEnv { + /// \brief Auxiliary used for synchronization + pub sync_handle: *mut ::std::os::raw::c_void, + /// \brief total amount of task + pub num_task: i32, +} +/// \brief The callback function to execute a parallel lambda +/// \param task_id the task id of the function. +/// \param penv The parallel environment backs the execution. +/// \param cdata The supporting closure data. 
+pub type FTVMParallelLambda = ::std::option::Option< + unsafe extern "C" fn( + task_id: ::std::os::raw::c_int, + penv: *mut TVMParallelGroupEnv, + cdata: *mut ::std::os::raw::c_void, + ) -> ::std::os::raw::c_int, +>; +extern "C" { + /// \brief Backend function for running parallel jobs. + /// + /// \param flambda The parallel function to be launched. + /// \param cdata The closure data. + /// \param num_task Number of tasks to launch, can be 0, means launch + /// with all available threads. + /// + /// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMBackendParallelLaunch( + flambda: FTVMParallelLambda, + cdata: *mut ::std::os::raw::c_void, + num_task: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief BSP barrrier between parallel threads + /// \param task_id the task id of the function. + /// \param penv The parallel environment backs the execution. + /// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMBackendParallelBarrier( + task_id: ::std::os::raw::c_int, + penv: *mut TVMParallelGroupEnv, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Simple static initialization function. + /// Run f once and set handle to be not null. + /// This function is mainly used for test purpose. + /// + /// \param handle An global address to indicate f + /// \param f The function to be ran + /// \param cdata The closure data to pass to the function. + /// \param nbytes Number of bytes in the closure data. 
+ /// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMBackendRunOnce( + handle: *mut *mut ::std::os::raw::c_void, + f: ::std::option::Option< + unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int, + >, + cdata: *mut ::std::os::raw::c_void, + nbytes: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} diff --git a/rust/common/src/errors.rs b/rust/common/src/errors.rs new file mode 100644 index 000000000..a81fab9f8 --- /dev/null +++ b/rust/common/src/errors.rs @@ -0,0 +1,15 @@ +//! Error types for `TVMArgValue` and `TVMRetValue` conversions. + +error_chain! { + errors { + TryFromTVMArgValueError(expected: String, actual: String) { + description("mismatched types while converting from TVMArgValue") + display("expected `{}` but given `{}`", expected, actual) + } + + TryFromTVMRetValueError(expected: String, actual: String) { + description("mismatched types while downcasting TVMRetValue") + display("invalid downcast: expected `{}` but given `{}`", expected, actual) + } + } +} diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs new file mode 100644 index 000000000..ad4c4f235 --- /dev/null +++ b/rust/common/src/lib.rs @@ -0,0 +1,39 @@ +//! This crate contains the refactored basic components required +//! for `runtime` and `frontend` TVM crates. + +#![crate_name = "tvm_common"] +#![recursion_limit = "1024"] +#![allow(non_camel_case_types, unused_imports)] +#![feature(box_syntax, try_from)] + +#[macro_use] +extern crate error_chain; + +/// Unified ffi module for both runtime and frontend crates. 
+pub mod ffi { + #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)] + + #[cfg(feature = "frontend")] + pub extern crate tvm_sys as ts; + + #[cfg(feature = "runtime")] + pub mod runtime { + use std::os::raw::{c_char, c_int, c_void}; + + include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); + + pub type BackendPackedCFunc = extern "C" fn( + args: *const TVMValue, + type_codes: *const c_int, + num_args: c_int, + ) -> c_int; + } +} + +pub mod errors; +pub mod ty; +pub mod value; + +pub use errors::*; +pub use ty::TVMTypeCode; +pub use value::{TVMArgValue, TVMRetValue, TVMValue}; diff --git a/rust/common/src/ty.rs b/rust/common/src/ty.rs new file mode 100644 index 000000000..126bcd445 --- /dev/null +++ b/rust/common/src/ty.rs @@ -0,0 +1,144 @@ +//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods. +//! +//! # Example +//! +//! ``` +//! let dtype = TVMType::from("float"); +//! println!("dtype is: {}", dtype); +//! ``` + +use std::{ + ffi::{CStr, CString}, + fmt::{self, Display, Formatter}, +}; + +/// TVM type codes. 
+#[repr(u32)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum TVMTypeCode { + kDLInt = 0, + kDLUInt = 1, + kDLFloat = 2, + kHandle = 3, + kNull = 4, + kTVMType = 5, + kTVMContext = 6, + kArrayHandle = 7, + kNodeHandle = 8, + kModuleHandle = 9, + kFuncHandle = 10, + kStr = 11, + kBytes = 12, + kNDArrayContainer = 13, +} + +impl Default for TVMTypeCode { + fn default() -> Self { + TVMTypeCode::kDLInt + } +} + +impl From for i64 { + fn from(arg: TVMTypeCode) -> i64 { + match arg { + TVMTypeCode::kDLInt => 0, + TVMTypeCode::kDLUInt => 1, + TVMTypeCode::kDLFloat => 2, + TVMTypeCode::kHandle => 3, + TVMTypeCode::kNull => 4, + TVMTypeCode::kTVMType => 5, + TVMTypeCode::kTVMContext => 6, + TVMTypeCode::kArrayHandle => 7, + TVMTypeCode::kNodeHandle => 8, + TVMTypeCode::kModuleHandle => 9, + TVMTypeCode::kFuncHandle => 10, + TVMTypeCode::kStr => 11, + TVMTypeCode::kBytes => 12, + TVMTypeCode::kNDArrayContainer => 13, + } + } +} + +impl Into for i64 { + fn into(self) -> TVMTypeCode { + match self { + 0 => TVMTypeCode::kDLInt, + 1 => TVMTypeCode::kDLUInt, + 2 => TVMTypeCode::kDLFloat, + 3 => TVMTypeCode::kHandle, + 4 => TVMTypeCode::kNull, + 5 => TVMTypeCode::kTVMType, + 6 => TVMTypeCode::kTVMContext, + 7 => TVMTypeCode::kArrayHandle, + 8 => TVMTypeCode::kNodeHandle, + 9 => TVMTypeCode::kModuleHandle, + 10 => TVMTypeCode::kFuncHandle, + 11 => TVMTypeCode::kStr, + 12 => TVMTypeCode::kBytes, + 13 => TVMTypeCode::kNDArrayContainer, + _ => unreachable!(), + } + } +} + +impl Display for TVMTypeCode { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "{}", + match self { + TVMTypeCode::kDLInt => "int", + TVMTypeCode::kDLUInt => "uint", + TVMTypeCode::kDLFloat => "float", + TVMTypeCode::kHandle => "handle", + TVMTypeCode::kNull => "null", + TVMTypeCode::kTVMType => "TVM type", + TVMTypeCode::kTVMContext => "TVM context", + TVMTypeCode::kArrayHandle => "Array handle", + TVMTypeCode::kNodeHandle => "Node handle", + TVMTypeCode::kModuleHandle => "Module 
handle", + TVMTypeCode::kFuncHandle => "Function handle", + TVMTypeCode::kStr => "string", + TVMTypeCode::kBytes => "bytes", + TVMTypeCode::kNDArrayContainer => "ndarray container", + } + ) + } +} + +macro_rules! impl_prim_type { + ($type:ty, $variant:ident) => { + impl<'a> From<&'a $type> for TVMTypeCode { + fn from(_arg: &$type) -> Self { + TVMTypeCode::$variant + } + } + + impl<'a> From<&'a mut $type> for TVMTypeCode { + fn from(_arg: &mut $type) -> Self { + TVMTypeCode::$variant + } + } + }; +} + +impl_prim_type!(usize, kDLInt); +impl_prim_type!(i64, kDLInt); +impl_prim_type!(i32, kDLInt); +impl_prim_type!(i16, kDLInt); +impl_prim_type!(i8, kDLInt); + +impl_prim_type!(u64, kDLUInt); +impl_prim_type!(u32, kDLUInt); +impl_prim_type!(u16, kDLUInt); +impl_prim_type!(u8, kDLUInt); + +impl_prim_type!(f64, kDLFloat); +impl_prim_type!(f32, kDLFloat); + +impl_prim_type!(str, kStr); +impl_prim_type!(CStr, kStr); +impl_prim_type!(String, kStr); +impl_prim_type!(CString, kStr); + +impl_prim_type!([u8], kBytes); diff --git a/rust/common/src/value.rs b/rust/common/src/value.rs new file mode 100644 index 000000000..6da8b27e8 --- /dev/null +++ b/rust/common/src/value.rs @@ -0,0 +1,559 @@ +//! This module provides the the wrapped `TVMValue`, `TVMArgValue` and `TVMRetValue` +//! required for using TVM functions. + +use std::{ + any::Any, + convert::TryFrom, + ffi::{CStr, CString}, + fmt::{self, Debug, Formatter}, + marker::PhantomData, + mem, + ops::Deref, + os::raw::{c_char, c_void}, +}; + +#[cfg(feature = "runtime")] +use ffi::runtime::TVMValue as _TVMValue; + +#[cfg(feature = "frontend")] +use ffi::ts::TVMValue as _TVMValue; + +use errors::*; + +use ty::TVMTypeCode; + +/// Wrapped TVMValue type. +#[derive(Clone, Copy)] +pub struct TVMValue { + pub inner: _TVMValue, +} + +impl TVMValue { + /// Creates TVMValue from the raw part. 
+ pub fn new(inner: _TVMValue) -> Self { + TVMValue { inner } + } + + pub(crate) fn into_raw(self) -> _TVMValue { + self.inner + } +} + +impl Debug for TVMValue { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + unsafe { + write!( + f, + "TVMValue: [v_int64: {:?}], [v_float64: {:?}], [v_handle: {:?}],\ + [v_str: {:?}]", + self.inner.v_int64, self.inner.v_float64, self.inner.v_handle, self.inner.v_str + ) + } + } +} + +impl Deref for TVMValue { + type Target = _TVMValue; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +macro_rules! impl_prim_val { + ($type:ty, $field:ident, $cast:ty) => { + impl From<$type> for TVMValue { + fn from(arg: $type) -> Self { + let inner = _TVMValue { + $field: arg as $cast, + }; + Self::new(inner) + } + } + + impl<'a> From<&'a $type> for TVMValue { + fn from(arg: &$type) -> Self { + let inner = _TVMValue { + $field: *arg as $cast, + }; + Self::new(inner) + } + } + + impl<'a> From<&'a mut $type> for TVMValue { + fn from(arg: &mut $type) -> Self { + let inner = _TVMValue { + $field: *arg as $cast, + }; + Self::new(inner) + } + } + + impl TryFrom for $type { + type Error = Error; + fn try_from(val: TVMValue) -> Result { + Ok(unsafe { val.inner.$field as $type }) + } + } + + impl<'a> TryFrom<&'a TVMValue> for $type { + type Error = Error; + fn try_from(val: &TVMValue) -> Result { + Ok(unsafe { val.into_raw().$field as $type }) + } + } + + impl<'a> TryFrom<&'a mut TVMValue> for $type { + type Error = Error; + fn try_from(val: &mut TVMValue) -> Result { + Ok(unsafe { val.into_raw().$field as $type }) + } + } + }; +} + +impl_prim_val!(isize, v_int64, i64); +impl_prim_val!(i64, v_int64, i64); +impl_prim_val!(i32, v_int64, i64); +impl_prim_val!(i16, v_int64, i64); +impl_prim_val!(i8, v_int64, i64); +impl_prim_val!(usize, v_int64, i64); +impl_prim_val!(u64, v_int64, i64); +impl_prim_val!(u32, v_int64, i64); +impl_prim_val!(u16, v_int64, i64); +impl_prim_val!(u8, v_int64, i64); + +impl_prim_val!(f64, v_float64, f64); 
+impl_prim_val!(f32, v_float64, f64); + +impl<'a> From<&'a str> for TVMValue { + fn from(arg: &str) -> TVMValue { + let arg = CString::new(arg).unwrap(); + let inner = _TVMValue { + v_str: arg.as_ptr() as *const c_char, + }; + mem::forget(arg); + Self::new(inner) + } +} + +impl<'a> From<&'a String> for TVMValue { + fn from(arg: &String) -> TVMValue { + let arg = CString::new(arg.as_bytes()).unwrap(); + let inner = _TVMValue { + v_str: arg.as_ptr() as *const c_char, + }; + mem::forget(arg); + Self::new(inner) + } +} + +impl<'a> From<&'a CString> for TVMValue { + fn from(arg: &CString) -> TVMValue { + let arg = arg.to_owned(); + let inner = _TVMValue { + v_str: arg.as_ptr() as *const c_char, + }; + mem::forget(arg); + Self::new(inner) + } +} + +impl<'a> From<&'a [u8]> for TVMValue { + fn from(arg: &[u8]) -> TVMValue { + let arg = arg.to_owned(); + let inner = _TVMValue { + v_handle: &arg as *const _ as *mut c_void, + }; + mem::forget(arg); + Self::new(inner) + } +} + +/// Captures both `TVMValue` and `TVMTypeCode` needed for TVM function. +/// The preferred way to obtain a `TVMArgValue` is automatically via `call_packed!`. +/// or in the frontend crate, with `function::Builder`. Checkout the methods for conversions. +/// +/// ## Example +/// +/// ``` +/// let s = "hello".to_string(); +/// let arg = TVMArgValue::from(&s); +/// let tvm: String = arg.try_into().unwrap(); +/// assert_eq!(arg, s); +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct TVMArgValue<'a> { + /// The wrapped TVMValue + pub value: TVMValue, + /// The matching type code. + pub type_code: TVMTypeCode, + /// This is only exposed to runtime and frontend crates and is not meant to be used directly. 
+ pub lifetime: PhantomData<&'a ()>, +} + +impl<'a> TVMArgValue<'a> { + pub fn new(value: TVMValue, type_code: TVMTypeCode) -> Self { + TVMArgValue { + value: value, + type_code: type_code, + lifetime: PhantomData, + } + } +} + +impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for i64 { + type Error = Error; + fn try_from(arg: &TVMArgValue<'a>) -> Result { + if (arg.type_code == TVMTypeCode::kDLInt) + | (arg.type_code == TVMTypeCode::kDLUInt) + | (arg.type_code == TVMTypeCode::kNull) + { + Ok(unsafe { arg.value.inner.v_int64 }) + } else { + bail!(ErrorKind::TryFromTVMArgValueError( + stringify!(i64).to_string(), + arg.type_code.to_string() + )) + } + } +} + +impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for f64 { + type Error = Error; + fn try_from(arg: &TVMArgValue<'a>) -> Result { + if arg.type_code == TVMTypeCode::kDLFloat { + Ok(unsafe { arg.value.inner.v_float64 }) + } else { + bail!(ErrorKind::TryFromTVMArgValueError( + stringify!(f64).to_string(), + arg.type_code.to_string() + )) + } + } +} + +impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for String { + type Error = Error; + fn try_from(arg: &TVMArgValue<'a>) -> Result { + if arg.type_code == TVMTypeCode::kStr { + let ret_str = unsafe { + match CStr::from_ptr(arg.value.inner.v_str).to_str() { + Ok(s) => s, + Err(_) => "Invalid UTF-8 message", + } + }; + Ok(ret_str.to_string()) + } else { + bail!(ErrorKind::TryFromTVMArgValueError( + stringify!(String).to_string(), + arg.type_code.to_string() + )) + } + } +} + +/// Main way to create a TVMArgValue from suported Rust values. +impl<'b, 'a: 'b, T: 'b + ?Sized> From<&'b T> for TVMArgValue<'a> +where + TVMValue: From<&'b T>, + TVMTypeCode: From<&'b T>, +{ + fn from(arg: &'b T) -> Self { + TVMArgValue::new(TVMValue::from(arg), TVMTypeCode::from(arg)) + } +} + +/// Creates a conversion to a `TVMArgValue` for an object handle. 
+impl<'a, T> From<*const T> for TVMArgValue<'a> { + fn from(ptr: *const T) -> Self { + let value = TVMValue::new(_TVMValue { + v_handle: ptr as *mut T as *mut c_void, + }); + + TVMArgValue::new(value, TVMTypeCode::kArrayHandle) + } +} + +/// Creates a conversion to a `TVMArgValue` for a mutable object handle. +impl<'a, T> From<*mut T> for TVMArgValue<'a> { + fn from(ptr: *mut T) -> Self { + let value = TVMValue::new(_TVMValue { + v_handle: ptr as *mut c_void, + }); + + TVMArgValue::new(value, TVMTypeCode::kHandle) + } +} + +/// An owned version of TVMPODValue. It can be converted from varieties of +/// primitive and object types. +/// It can be downcasted using `try_from` if it contains the desired type. +/// +/// # Example +/// +/// ``` +/// let a = 42u32; +/// let b: i64 = TVMRetValue::from(a).try_into().unwrap(); +/// +/// let s = "hello, world!"; +/// let t: TVMRetValue = s.into(); +/// assert_eq!(String::try_from(t).unwrap(), s); +/// ``` +pub struct TVMRetValue { + /// A primitive return value, if any. + pub prim_value: usize, + /// An object return value, if any. + pub box_value: Box, + pub type_code: TVMTypeCode, +} + +impl TVMRetValue { + fn new(prim_value: usize, box_value: Box, type_code: TVMTypeCode) -> Self { + Self { + prim_value, + box_value, + type_code, + } + } + + /// unsafe function to create `TVMRetValue` from `TVMValue` and + /// its matching `TVMTypeCode`. 
+ pub unsafe fn from_tvm_value(value: TVMValue, type_code: TVMTypeCode) -> Self { + let value = value.into_raw(); + match type_code { + TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => { + Self::new(value.v_int64 as usize, box (), type_code) + } + TVMTypeCode::kDLFloat => Self::new(value.v_float64 as usize, box (), type_code), + TVMTypeCode::kHandle + | TVMTypeCode::kArrayHandle + | TVMTypeCode::kNodeHandle + | TVMTypeCode::kModuleHandle + | TVMTypeCode::kFuncHandle => { + Self::new(value.v_handle as usize, box value.v_handle, type_code) + } + TVMTypeCode::kStr | TVMTypeCode::kBytes => { + Self::new(value.v_str as usize, box (value.v_str), type_code) + } + _ => Self::new(0usize, box (), type_code), + } + } + + /// Returns the underlying `TVMValue` and `TVMTypeCode`. + pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) { + let val = match self.type_code { + TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => TVMValue::new(_TVMValue { + v_int64: self.prim_value as i64, + }), + TVMTypeCode::kDLFloat => TVMValue::new(_TVMValue { + v_float64: self.prim_value as f64, + }), + TVMTypeCode::kHandle + | TVMTypeCode::kArrayHandle + | TVMTypeCode::kNodeHandle + | TVMTypeCode::kModuleHandle + | TVMTypeCode::kFuncHandle + | TVMTypeCode::kNDArrayContainer => TVMValue::new(_TVMValue { + v_handle: self.prim_value as *const c_void as *mut c_void, + }), + TVMTypeCode::kStr | TVMTypeCode::kBytes => TVMValue::new(_TVMValue { + v_str: self.prim_value as *const c_char, + }), + _ => unreachable!(), + }; + (val, self.type_code) + } +} + +impl Default for TVMRetValue { + fn default() -> Self { + TVMRetValue { + prim_value: 0usize, + box_value: box (), + type_code: TVMTypeCode::default(), + } + } +} + +impl Clone for TVMRetValue { + fn clone(&self) -> Self { + match self.type_code { + TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt | TVMTypeCode::kDLFloat => { + Self::new(self.prim_value.clone(), box (), self.type_code.clone()) + } + TVMTypeCode::kHandle + | TVMTypeCode::kArrayHandle + | 
TVMTypeCode::kNodeHandle + | TVMTypeCode::kModuleHandle + | TVMTypeCode::kFuncHandle + | TVMTypeCode::kNDArrayContainer => Self::new( + self.prim_value.clone(), + box (self.prim_value.clone() as *const c_void as *mut c_void), + self.type_code.clone(), + ), + TVMTypeCode::kStr | TVMTypeCode::kBytes => Self::new( + self.prim_value.clone(), + box (self.prim_value.clone() as *const c_char), + self.type_code.clone(), + ), + _ => unreachable!(), + } + } +} + +impl Debug for TVMRetValue { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "prim_value: {:?}, box_value: {:?}, type_code: {:?}", + self.prim_value, self.prim_value as *const c_void as *mut c_void, self.type_code + ) + } +} + +macro_rules! impl_prim_ret_value { + ($type:ty, $code:expr) => { + impl From<$type> for TVMRetValue { + fn from(val: $type) -> Self { + TVMRetValue { + prim_value: val as usize, + box_value: box (), + type_code: $code, + } + } + } + + impl<'a> From<&'a $type> for TVMRetValue { + fn from(val: &$type) -> Self { + TVMRetValue { + prim_value: *val as usize, + box_value: box (), + type_code: $code, + } + } + } + + impl<'a> From<&'a mut $type> for TVMRetValue { + fn from(val: &mut $type) -> Self { + TVMRetValue { + prim_value: *val as usize, + box_value: box (), + type_code: $code, + } + } + } + + impl TryFrom for $type { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$type> { + if ret.type_code == $code { + Ok(ret.prim_value as $type) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!($type).to_string(), + ret.type_code.to_string(), + )) + } + } + } + }; +} + +impl_prim_ret_value!(i8, TVMTypeCode::kDLInt); +impl_prim_ret_value!(i16, TVMTypeCode::kDLInt); +impl_prim_ret_value!(i32, TVMTypeCode::kDLInt); +impl_prim_ret_value!(i64, TVMTypeCode::kDLInt); +impl_prim_ret_value!(isize, TVMTypeCode::kDLInt); + +impl_prim_ret_value!(u8, TVMTypeCode::kDLUInt); +impl_prim_ret_value!(u16, TVMTypeCode::kDLUInt); +impl_prim_ret_value!(u32, 
TVMTypeCode::kDLUInt); +impl_prim_ret_value!(u64, TVMTypeCode::kDLUInt); +impl_prim_ret_value!(usize, TVMTypeCode::kDLUInt); + +impl_prim_ret_value!(f32, TVMTypeCode::kDLFloat); +impl_prim_ret_value!(f64, TVMTypeCode::kDLFloat); + +macro_rules! impl_ptr_ret_value { + ($type:ty) => { + impl From<$type> for TVMRetValue { + fn from(ptr: $type) -> Self { + TVMRetValue { + prim_value: ptr as usize, + box_value: box (), + type_code: TVMTypeCode::kHandle, + } + } + } + + impl TryFrom<TVMRetValue> for $type { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$type> { + if ret.type_code == TVMTypeCode::kHandle { + Ok(ret.prim_value as $type) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!($type).to_string(), + ret.type_code.to_string(), + )) + } + } + } + }; +} + +impl_ptr_ret_value!(*const c_void); +impl_ptr_ret_value!(*mut c_void); + +impl From<String> for TVMRetValue { + fn from(val: String) -> Self { + // A Rust `String` is not NUL-terminated, so move it into a `CString` and leak + // that buffer instead; `String::try_from` reads it back with `CStr::from_ptr`. + let ptr = CString::new(val).unwrap().into_raw() as *const c_char; + TVMRetValue::new(ptr as usize, box ptr, TVMTypeCode::kStr) + } +} + +impl TryFrom<TVMRetValue> for String { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<String> { + // Note: simple downcast doesn't work for function call return values + let ret_str = unsafe { + match CStr::from_ptr(ret.prim_value as *const c_char).to_str() { + Ok(s) => s, + Err(_) => "Invalid UTF-8 message", + } + }; + + Ok(ret_str.to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::convert::TryInto; + + #[test] + fn numeric() { + macro_rules!
arg_ret_tests { + ($v:expr; $($ty:ty),+) => {{ + $( + let v = $v as $ty; + let b = TVMRetValue::from(&v); + let b: $ty = b.try_into().unwrap(); + assert_eq!(b, v); + )+ + }}; + } + + arg_ret_tests!(42; i8, i16, i32, i64, f32, f64); + } + + #[test] + fn string() { + let s = "hello".to_string(); + let tvm_arg: String = TVMRetValue::from(s.clone()).try_into().unwrap(); + assert_eq!(tvm_arg, s); + } +} diff --git a/rust/common/tvm-sys/Cargo.toml b/rust/common/tvm-sys/Cargo.toml new file mode 100644 index 000000000..117d174b4 --- /dev/null +++ b/rust/common/tvm-sys/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "tvm-sys" +version = "0.1.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" +description = "Raw C API" + +[build-dependencies] +bindgen = "0.37.4" diff --git a/rust/common/tvm-sys/build.rs b/rust/common/tvm-sys/build.rs new file mode 100644 index 000000000..f842043a1 --- /dev/null +++ b/rust/common/tvm-sys/build.rs @@ -0,0 +1,25 @@ +extern crate bindgen; + +use std::path::PathBuf; + +fn main() { + println!("cargo:rerun-if-env-changed=TVM_HOME"); + println!("cargo:rustc-link-lib=dylib=tvm_runtime"); + println!("cargo:rustc-link-search={}/build", env!("TVM_HOME")); + let bindings = bindgen::Builder::default() + .header(format!( + "{}/include/tvm/runtime/c_runtime_api.h", + env!("TVM_HOME") + )) + .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME"))) + .blacklist_type("max_align_t") // @see rust-bindgen#550 + .layout_tests(false) + .derive_partialeq(true) + .derive_eq(true) + .generate() + .expect("unable to generate bindings"); + + bindings + .write_to_file(PathBuf::from("src/bindgen.rs")) + .expect("can not write the bindings!"); +} diff --git a/rust/common/tvm-sys/src/lib.rs b/rust/common/tvm-sys/src/lib.rs new file mode 100644 index 000000000..15f1ea3a6 --- /dev/null +++ b/rust/common/tvm-sys/src/lib.rs @@ -0,0 +1,9 @@ +#![allow( + non_camel_case_types, + non_snake_case, + non_upper_case_globals, + dead_code, + improper_ctypes +)] + 
+include!("bindgen.rs"); diff --git a/rust/frontend/.gitignore b/rust/frontend/.gitignore new file mode 100644 index 000000000..2430329c7 --- /dev/null +++ b/rust/frontend/.gitignore @@ -0,0 +1,7 @@ +target +**/*.rs.bk +Cargo.lock +/tests/basics/add_* +/examples/resnet/deploy_* +/examples/resnet/*.png +/examples/resnet/synset.* diff --git a/rust/frontend/.travis.yml b/rust/frontend/.travis.yml new file mode 100644 index 000000000..63a3d0277 --- /dev/null +++ b/rust/frontend/.travis.yml @@ -0,0 +1,5 @@ +language: rust +rust: + - nightly +matrix: + fast_finish: true diff --git a/rust/frontend/Cargo.toml b/rust/frontend/Cargo.toml new file mode 100644 index 000000000..db261551e --- /dev/null +++ b/rust/frontend/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "tvm-frontend" +version = "0.1.0" +license = "Apache-2.0" +description = "Rust frontend support for TVM" +repository = "https://github.com/dmlc/tvm" +homepage = "https://github.com/dmlc/tvm" +readme = "README.md" +keywords = ["rust", "tvm", "nnvm"] +categories = ["api-bindings", "science"] +authors = ["TVM Contributors"] + +[lib] +name = "tvm_frontend" +crate-type = ["dylib"] + +[dependencies] +error-chain = "0.12.0" +lazy_static = "1.1.0" +ndarray = "0.12.1" +num-traits = "0.2" +tvm-common = { version = "0.1.0", path = "../common/", features = ["frontend"] } + +[features] +blas = ["ndarray/blas"] diff --git a/rust/frontend/README.md b/rust/frontend/README.md new file mode 100644 index 000000000..5bd4362ae --- /dev/null +++ b/rust/frontend/README.md @@ -0,0 +1,219 @@ +# TVM Runtime Frontend Support + +This crate provides an idiomatic Rust API for [TVM](https://github.com/dmlc/tvm) runtime frontend. Currently this requires **Nightly Rust** and tested on `rustc 1.32.0-nightly` + +## What Does This Crate Offer? + +Here is a major workflow + +1. 
Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/) +2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators. +3. Deploy your models using **Rust** :heart: + +### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k + +Please check out [examples/resnet](examples/resnet) for the complete end-to-end example. + +Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM + +```python +block = get_model('resnet18_v1', pretrained=True) + +sym, params = nnvm.frontend.from_mxnet(block) +# add the softmax layer for prediction +net = nnvm.sym.softmax(sym) +# compile the model +with nnvm.compiler.build_config(opt_level=opt_level): + graph, lib, params = nnvm.compiler.build( + net, target, shape={"data": data_shape}, params=params) +# save the model artifacts +lib.save(os.path.join(target_dir, "deploy_lib.o")) +cc.create_shared(os.path.join(target_dir, "deploy_lib.so"), + [os.path.join(target_dir, "deploy_lib.o")]) + +with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo: + fo.write(graph.json()) +with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo: + fo.write(nnvm.compiler.save_param_dict(params)) +``` + +Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image + +![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true) + +as demonstrated in the following Rust snippet + +```rust + let graph = fs::read_to_string("deploy_graph.json")?; + // load the built module + let lib = Module::load(&Path::new("deploy_lib.so"))?; + // get the global TVM graph runtime function + let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); + let runtime_create_fn_ret = call_packed!( +
runtime_create_fn, + &graph, + &lib, + &ctx.device_type, + &ctx.device_id + )?; + // get graph runtime module + let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?; + // get the registered `load_params` from runtime module + let ref load_param_fn = graph_runtime_module + .get_function("load_params", false) + .unwrap(); + // parse parameters and convert to TVMByteArray + let params: Vec<u8> = fs::read("deploy_param.params")?; + let barr = TVMByteArray::from(&params); + // load the parameters + call_packed!(load_param_fn, &barr)?; + // get the set_input function + let ref set_input_fn = graph_runtime_module + .get_function("set_input", false) + .unwrap(); + + call_packed!(set_input_fn, "data", &input)?; + // get `run` function from runtime module + let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); + // execute the run function. Note that it has no argument + call_packed!(run_fn,)?; + // prepare to get the output + let output_shape = &mut [1, 1000]; + let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); + // get the `get_output` function from runtime module + let ref get_output_fn = graph_runtime_module + .get_function("get_output", false) + .unwrap(); + // execute the get output function + call_packed!(get_output_fn, &0, &output)?; + // flatten the output as Vec<f32> + let output = output.to_vec::<f32>()?; +``` + +and the model correctly predicts the input image as **tiger cat**. + +## Installations + +Please follow TVM [installations](https://docs.tvm.ai/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. + +*Note:* To run the end-to-end examples and tests, `tvm`, `nnvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually.
+ +## Supported TVM Functionalities + +### Use TVM to Generate Shared Library + +One can use the following Python snippet to generate `add_gpu.so` which add two vectors on GPU. + +```python +import os +import tvm +from tvm.contrib import cc + +def test_add(target_dir): + if not tvm.module.enabled("cuda"): + print(f"skip {__file__} because cuda is not enabled...") + return + n = tvm.var("n") + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C") + s = tvm.create_schedule(C.op) + bx, tx = s[C].split(C.op.axis[0], factor=64) + s[C].bind(bx, tvm.thread_axis("blockIdx.x")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd") + + fadd_cuda.save(os.path.join(target_dir, "add_gpu.o")) + fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx")) + cc.create_shared(os.path.join(target_dir, "add_gpu.so"), + [os.path.join(target_dir, "add_gpu.o")]) + + +if __name__ == "__main__": + import sys + if len(sys.argv) != 2: + sys.exit(-1) + test_add(sys.argv[1]) +``` + +### Run the Generated Shared Library + +The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust. + +```rust +extern crate tvm_frontend as tvm; + +use tvm::*; + +fn main() { + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap(); + let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap(); + assert!(fadd.enabled("gpu")); + fadd.import_module(fadd_dep); + fadd.entry(); + function::Builder::from(&mut fadd) + .arg(&arr) + .arg(&arr) + .set_output(&mut ret)? 
+ .invoke() + .unwrap(); + + assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]); +} +``` + +**Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by +`cargo:rustc-link-search=native=add_gpu`. + +See the tests and examples custom `build.rs` for more details. + +### Convert and Register a Rust Function as a TVM Packed Function + +One can use `register_global_func!` macro to convert and register a Rust +function of type `fn(&[TVMArgValue]) -> Result<TVMRetValue>` to a global TVM **packed function** as follows + +```rust +#[macro_use] +extern crate tvm_frontend as tvm; +use std::convert::TryInto; +use tvm::*; + +fn main() { + register_global_func! { + fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> { + let mut ret = 0f32; + let shape = &mut [2]; + for arg in args.iter() { + let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + let arg: NDArray = arg.try_into()?; + let arr = arg.copy_to_ndarray(e).unwrap(); + let rnd: ArrayD<f32> = ArrayD::try_from(&arr).unwrap(); + ret += rnd.scalar_sum(); + } + let ret_val = TVMRetValue::from(&ret); + Ok(ret_val) + } + } + + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut registered = function::Builder::default(); + let ret: f64 = registered + .get_function("sum", true) + .arg(&arr) + .arg(&arr) + .invoke() + .unwrap() + .try_into() + .unwrap(); + + assert_eq!(ret, 14f64); + } +``` diff --git a/rust/frontend/examples/resnet/Cargo.toml b/rust/frontend/examples/resnet/Cargo.toml new file mode 100644 index 000000000..e8a3eb7f5 --- /dev/null +++ b/rust/frontend/examples/resnet/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "resnet" +version = "0.0.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" +build = "build.rs" + +[dependencies] +ndarray = "0.12.1" +tvm-frontend = { path = "../../" } +image = "0.20.1" +csv = "1" diff --git
a/rust/frontend/examples/resnet/README.md b/rust/frontend/examples/resnet/README.md new file mode 100644 index 000000000..3d20d55a8 --- /dev/null +++ b/rust/frontend/examples/resnet/README.md @@ -0,0 +1,15 @@ +## Resnet example + +This end-to-end example shows how to: +* build `Resnet 18` with `tvm` and `nnvm` from Python +* use the provided Rust frontend API to test for an input image + +To run the example, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` +and to install `tvm` and `nnvm` with `llvm` follow the [TVM installation guide](https://docs.tvm.ai/install/index.html). + +* **Build the example**: `cargo build` + +To have a successful build, note that it is required to instruct Rust compiler to link to the compiled shared library, for example with +`println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details. + +* **Run the example**: `cargo run` diff --git a/rust/frontend/examples/resnet/build.rs b/rust/frontend/examples/resnet/build.rs new file mode 100644 index 000000000..f913bf8b0 --- /dev/null +++ b/rust/frontend/examples/resnet/build.rs @@ -0,0 +1,16 @@ +use std::process::Command; + +fn main() { + let output = Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .output() + .expect("Failed to execute command"); + assert!( + std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_lib.o")).exists(), + "Could not prepare demo: {}", + String::from_utf8(output.stderr).unwrap().trim() + ); + println!( + "cargo:rustc-link-search=native={}", + env!("CARGO_MANIFEST_DIR") + ); +} diff --git a/rust/frontend/examples/resnet/src/build_resnet.py b/rust/frontend/examples/resnet/src/build_resnet.py new file mode 100755 index 000000000..e5b76aa82 --- /dev/null +++ b/rust/frontend/examples/resnet/src/build_resnet.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 + +import argparse +import csv +import logging +from os import path as osp 
+import sys + +import numpy as np + +import mxnet as mx +from mxnet.gluon.model_zoo.vision import get_model +from mxnet.gluon.utils import download + +import tvm +from tvm.contrib import graph_runtime, cc +import nnvm + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +parser = argparse.ArgumentParser(description='Resnet build example') +aa = parser.add_argument +aa('--batch-size', type=int, default=1, help='input image batch size') +aa('--opt-level', type=int, default=3, + help='level of optimization. 0 is unoptimized and 3 is the highest level') +aa('--target', type=str, default='llvm', help='target context for compilation') +aa('--image-shape', type=str, default='3,224,224', help='input image dimensions') +aa('--image-name', type=str, default='cat.png', help='name of input image to download') +args = parser.parse_args() + +target_dir = osp.dirname(osp.dirname(osp.realpath(__file__))) +batch_size = args.batch_size +opt_level = args.opt_level +target = tvm.target.create(args.target) +image_shape = tuple(map(int, args.image_shape.split(","))) +data_shape = (batch_size,) + image_shape + +def build(target_dir): + """ Compiles resnet18 with TVM""" + deploy_lib = osp.join(target_dir, 'deploy_lib.o') + if osp.exists(deploy_lib): + return + # download the pretrained resnet18 trained on imagenet1k dataset for + # image classification task + block = get_model('resnet18_v1', pretrained=True) + + sym, params = nnvm.frontend.from_mxnet(block) + # add the softmax layer for prediction + net = nnvm.sym.softmax(sym) + # compile the model + with nnvm.compiler.build_config(opt_level=opt_level): + graph, lib, params = nnvm.compiler.build( + net, target, shape={"data": data_shape}, params=params) + # save the model artifacts + lib.save(deploy_lib) + cc.create_shared(osp.join(target_dir, "deploy_lib.so"), + [osp.join(target_dir, "deploy_lib.o")]) + + with open(osp.join(target_dir, 
"deploy_graph.json"), "w") as fo: + fo.write(graph.json()) + + with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo: + fo.write(nnvm.compiler.save_param_dict(params)) + +def download_img_labels(): + """ Download an image and imagenet1k class labels for test""" + img_name = 'cat.png' + synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', + '4d0b62f3d01426887599d4f7ede23ee5/raw/', + '596b27d23537e5a1b5751d2b0481ef172f58b539/', + 'imagenet1000_clsid_to_human.txt']) + synset_name = 'synset.txt' + download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name) + download(synset_url, synset_name) + + with open(synset_name) as fin: + synset = eval(fin.read()) + + with open("synset.csv", "w") as fout: + w = csv.writer(fout) + w.writerows(synset.items()) + +def test_build(target_dir): + """ Sanity check with random input""" + graph = open(osp.join(target_dir, "deploy_graph.json")).read() + lib = tvm.module.load(osp.join(target_dir, "deploy_lib.so")) + params = bytearray(open(osp.join(target_dir,"deploy_param.params"), "rb").read()) + input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32")) + ctx = tvm.cpu() + module = graph_runtime.create(graph, lib, ctx) + module.load_params(params) + module.run(data=input_data) + out = module.get_output(0).asnumpy() + + +if __name__ == '__main__': + logger.info("building the model") + build(target_dir) + logger.info("build was successful") + logger.info("test the build artifacts") + test_build(target_dir) + logger.info("test was successful") + download_img_labels() + logger.info("image and synset downloads are successful") diff --git a/rust/frontend/examples/resnet/src/main.rs b/rust/frontend/examples/resnet/src/main.rs new file mode 100644 index 000000000..869a35b3a --- /dev/null +++ b/rust/frontend/examples/resnet/src/main.rs @@ -0,0 +1,134 @@ +#![feature(try_from)] + +extern crate csv; +extern crate image; +extern crate ndarray; +extern crate tvm_frontend as 
tvm; + +use std::{ + collections::HashMap, + convert::TryInto, + fs::{self, File}, + path::Path, +}; + +use image::{FilterType, GenericImageView}; +use ndarray::{Array, ArrayD, Axis}; + +use tvm::*; + +fn main() { + let ctx = TVMContext::cpu(0); + let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap(); + println!("original image dimensions: {:?}", img.dimensions()); + // for bigger size images, one needs to first resize to 256x256 + // with `img.resize_exact` method and then `image.crop` to 224x224 + let img = img.resize(224, 224, FilterType::Nearest).to_rgb(); + println!("resized image dimensions: {:?}", img.dimensions()); + let mut pixels: Vec = vec![]; + for pixel in img.pixels() { + let tmp = pixel.data; + // normalize the RGB channels using mean, std of imagenet1k + let tmp = [ + (tmp[0] as f32 - 123.0) / 58.395, // R + (tmp[1] as f32 - 117.0) / 57.12, // G + (tmp[2] as f32 - 104.0) / 57.375, // B + ]; + for e in &tmp { + pixels.push(*e); + } + } + + let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap(); + let arr: ArrayD = arr.permuted_axes([2, 0, 1]).into_dyn(); + // make arr shape as [1, 3, 224, 224] acceptable to resnet + let arr = arr.insert_axis(Axis(0)); + // create input tensor from rust's ndarray + let input = + NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float32")).unwrap(); + println!( + "input size is {:?}", + input.shape().expect("cannot get the input shape") + ); + let graph = + fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(); + // load the built module + let lib = Module::load(&Path::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/deploy_lib.so" + ))) + .unwrap(); + // get the global TVM graph runtime function + let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); + let runtime_create_fn_ret = call_packed!( + runtime_create_fn, + &graph, + &lib, + &ctx.device_type, + &ctx.device_id + ) + .unwrap(); + // get graph runtime 
module + let graph_runtime_module: Module = runtime_create_fn_ret.try_into().unwrap(); + // get the registered `load_params` from runtime module + let ref load_param_fn = graph_runtime_module + .get_function("load_params", false) + .unwrap(); + // parse parameters and convert to TVMByteArray + let params: Vec = + fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap(); + let barr = TVMByteArray::from(¶ms); + // load the parameters + call_packed!(load_param_fn, &barr).unwrap(); + // get the set_input function + let ref set_input_fn = graph_runtime_module + .get_function("set_input", false) + .unwrap(); + + call_packed!(set_input_fn, "data", &input).unwrap(); + // get `run` function from runtime module + let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); + // execute the run function. Note that it has no argument + call_packed!(run_fn,).unwrap(); + // prepare to get the output + let output_shape = &mut [1, 1000]; + let output = NDArray::empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); + // get the `get_output` function from runtime module + let ref get_output_fn = graph_runtime_module + .get_function("get_output", false) + .unwrap(); + // execute the get output function + call_packed!(get_output_fn, &0, &output).unwrap(); + // flatten the output as Vec + let output = output.to_vec::().unwrap(); + // find the maximum entry in the output and its index + let mut argmax = -1; + let mut max_prob = 0.; + for i in 0..output.len() { + if output[i] > max_prob { + max_prob = output[i]; + argmax = i as i32; + } + } + // create a hash map of (class id, class name) + let mut synset: HashMap = HashMap::new(); + let file = File::open("synset.csv").unwrap(); + let mut rdr = csv::ReaderBuilder::new() + .has_headers(true) + .from_reader(file); + + for result in rdr.records() { + let record = result.unwrap(); + let id: i32 = record[0].parse().unwrap(); + let cls = record[1].to_string(); + synset.insert(id, cls); + } + + 
println!( + "input image belongs to the class `{}` with probability {}", + synset + .get(&argmax) + .expect("cannot find the class id for argmax"), + max_prob + ); +} diff --git a/rust/frontend/src/bytearray.rs b/rust/frontend/src/bytearray.rs new file mode 100644 index 000000000..395f34c24 --- /dev/null +++ b/rust/frontend/src/bytearray.rs @@ -0,0 +1,72 @@ +//! Provides [`TVMByteArray`] used for passing the model parameters +//! (stored as byte-array) to a runtime module. +//! +//! For more detail, please see the example `resnet` in `examples` repository. + +use std::os::raw::c_char; + +use crate::ts; + +/// A struct holding TVM byte-array. +/// +/// ## Example +/// +/// ``` +/// let v = b"hello".to_vec(); +/// let barr = TVMByteArray::from(&v); +/// assert_eq!(barr.len(), v.len()); +/// assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]); +/// ``` +#[derive(Debug, Clone)] +pub struct TVMByteArray { + pub(crate) inner: ts::TVMByteArray, +} + +impl TVMByteArray { + pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray { + TVMByteArray { inner: barr } + } + + /// Gets the length of the underlying byte-array + pub fn len(&self) -> usize { + self.inner.size + } + + /// Gets the underlying byte-array as `Vec` + pub fn data(&self) -> Vec { + unsafe { + let sz = self.len(); + let mut ret_buf = Vec::with_capacity(sz); + ret_buf.set_len(sz); + self.inner.data.copy_to(ret_buf.as_mut_ptr(), sz); + ret_buf + } + } +} + +impl<'a> From<&'a Vec> for TVMByteArray { + fn from(arg: &Vec) -> Self { + let barr = ts::TVMByteArray { + data: arg.as_ptr() as *const c_char, + size: arg.len(), + }; + TVMByteArray::new(barr) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn convert() { + let v = vec![1u8, 2, 3]; + let barr = TVMByteArray::from(&v); + assert_eq!(barr.len(), v.len()); + assert_eq!(barr.data(), vec![1i8, 2, 3]); + let v = b"hello".to_vec(); + let barr = TVMByteArray::from(&v); + assert_eq!(barr.len(), v.len()); + assert_eq!(barr.data(), vec![104i8, 
101, 108, 108, 111]); + } +} diff --git a/rust/frontend/src/context.rs b/rust/frontend/src/context.rs new file mode 100644 index 000000000..65e11d82e --- /dev/null +++ b/rust/frontend/src/context.rs @@ -0,0 +1,286 @@ +//! Provides [`TVMContext`] and related device specific queries. +//! +//! Create a new context by device type (cpu is 1) and device id. +//! +//! # Example +//! +//! ``` +//! let ctx = TVMContext::new(1, 0); +//! let cpu0 = TVMContext::cpu(0); +//! assert_eq!(ctx, cpu0); +//! ``` +//! +//! Or from a supported device name. +//! +//! ``` +//! let cpu0 = TVMContext::from("cpu"); +//! println!("{}", cpu0); +//! ``` + +use std::{ + fmt::{self, Display, Formatter}, + os::raw::c_void, + ptr, +}; + +use crate::{function, ts, Result}; + +/// Device type can be from a supported device name. See the supported devices +/// in [TVM](https://github.com/dmlc/tvm). +/// +/// ## Example +/// +/// ``` +/// let cpu = TVMDeviceType::from("cpu"); +/// println!("device is: {}", cpu); +///``` + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TVMDeviceType(pub usize); + +impl Default for TVMDeviceType { + /// default device is cpu. 
+ fn default() -> Self { + TVMDeviceType(1) + } +} + +impl From for ts::DLDeviceType { + fn from(device_type: TVMDeviceType) -> Self { + match device_type.0 { + 1 => ts::DLDeviceType_kDLCPU, + 2 => ts::DLDeviceType_kDLGPU, + 3 => ts::DLDeviceType_kDLCPUPinned, + 4 => ts::DLDeviceType_kDLOpenCL, + 7 => ts::DLDeviceType_kDLVulkan, + 8 => ts::DLDeviceType_kDLMetal, + 9 => ts::DLDeviceType_kDLVPI, + 10 => ts::DLDeviceType_kDLROCM, + 12 => ts::DLDeviceType_kDLExtDev, + _ => panic!("device type not found!"), + } + } +} + +impl From for TVMDeviceType { + fn from(device_type: ts::DLDeviceType) -> Self { + match device_type { + ts::DLDeviceType_kDLCPU => TVMDeviceType(1), + ts::DLDeviceType_kDLGPU => TVMDeviceType(2), + ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3), + ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4), + ts::DLDeviceType_kDLVulkan => TVMDeviceType(7), + ts::DLDeviceType_kDLMetal => TVMDeviceType(8), + ts::DLDeviceType_kDLVPI => TVMDeviceType(9), + ts::DLDeviceType_kDLROCM => TVMDeviceType(10), + ts::DLDeviceType_kDLExtDev => TVMDeviceType(12), + _ => panic!("device type not found!"), + } + } +} + +impl Display for TVMDeviceType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "{}", + match self { + TVMDeviceType(1) => "cpu", + TVMDeviceType(2) => "gpu", + TVMDeviceType(3) => "cpu_pinned", + TVMDeviceType(4) => "opencl", + TVMDeviceType(8) => "meta", + TVMDeviceType(9) => "vpi", + TVMDeviceType(10) => "rocm", + TVMDeviceType(_) => "rpc", + } + ) + } +} + +impl<'a> From<&'a str> for TVMDeviceType { + fn from(type_str: &'a str) -> Self { + match type_str { + "cpu" => TVMDeviceType(1), + "llvm" => TVMDeviceType(1), + "stackvm" => TVMDeviceType(1), + "gpu" => TVMDeviceType(2), + "cuda" => TVMDeviceType(2), + "nvptx" => TVMDeviceType(2), + "cl" => TVMDeviceType(4), + "opencl" => TVMDeviceType(4), + "metal" => TVMDeviceType(8), + "vpi" => TVMDeviceType(9), + "rocm" => TVMDeviceType(10), + _ => panic!("{:?} not supported!", type_str), + } + } 
+} + +/// Represents the underlying device context. Default is cpu. +/// +/// ## Examples +/// +/// ``` +/// let ctx = TVMContext::from("gpu"); +/// assert!(ctx.exist()); +/// +/// ``` +/// +/// It is possible to query the underlying context as follows +/// +/// ``` +/// println!("maximun threads per block: {}", ctx.max_threads_per_block()); +/// println!("compute version: {}", ctx.compute_version()); +/// ``` +#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)] +pub struct TVMContext { + /// Supported device types + pub device_type: TVMDeviceType, + /// Device id + pub device_id: usize, +} + +impl TVMContext { + /// Creates context from device type and id. + pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self { + TVMContext { + device_type: device_type, + device_id: device_id, + } + } +} + +macro_rules! impl_ctxs { + ($(($ctx:ident, $dldevt:expr));+) => { + $( + impl TVMContext { + pub fn $ctx(device_id: usize) -> Self { + Self::new(TVMDeviceType($dldevt), device_id) + } + } + )+ + }; +} + +impl_ctxs!((cpu, 1); + (gpu, 2); + (nvptx, 2); + (cuda, 2); + (cpu_pinned, 3); + (cl, 4); + (opencl, 4); + (metal, 8); + (vpi, 9); + (rocm, 10); + (opengl, 11); + (ext_dev, 12)); + +impl<'a> From<&'a str> for TVMContext { + fn from(target: &str) -> Self { + TVMContext::new(TVMDeviceType::from(target), 0) + } +} + +impl TVMContext { + /// Checks whether the context exists or not. + pub fn exist(&self) -> bool { + let func = function::Function::get("_GetDeviceAttr", true /* is_global */) + .expect("API function always exists"); + let dt = self.device_type.0 as usize; + // `unwrap` is ok here because if there is any error, + // if would occure inside `call_packed!` + let ret = call_packed!(func, &dt, &self.device_id, &0) + .unwrap() + .prim_value; + ret != 0 + } + + /// Synchronize the context stream. 
+ pub fn sync(&self) -> Result<()> { + check_call!(ts::TVMSynchronize( + self.device_type.0 as i32, + self.device_id as i32, + ptr::null_mut() as *mut c_void + )); + Ok(()) + } +} + +macro_rules! impl_device_attrs { + ($(($attr_name:ident, $attr_kind:expr));+) => { + $( + impl TVMContext { + pub fn $attr_name(&self) -> usize { + let func = function::Function::get("_GetDeviceAttr", true /* is_global */) + .expect("API function always exists"); + let dt = self.device_type.0 as usize; + // `unwrap` is ok here because if there is any error, + // if would occur in function call. + let ret = function::Builder::from(func) + .args(&[dt, self.device_id, $attr_kind]) + .invoke() + .unwrap(); + ret.prim_value as usize + } + } + )+ + }; +} + +impl_device_attrs!((max_threads_per_block, 1); + (warp_size, 2); + (max_shared_memory_per_block, 3); + (compute_version, 4); + (device_name, 5); + (max_clock_rate, 6); + (multi_processor_count, 7); + (max_thread_dimensions, 8)); + +impl From for TVMContext { + fn from(ctx: ts::DLContext) -> Self { + TVMContext { + device_type: TVMDeviceType::from(ctx.device_type), + device_id: ctx.device_id as usize, + } + } +} + +impl From for ts::DLContext { + fn from(ctx: TVMContext) -> Self { + ts::DLContext { + device_type: ctx.device_type.into(), + device_id: ctx.device_id as i32, + } + } +} + +impl Display for TVMContext { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}({})", self.device_type, self.device_id) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn context() { + let ctx = TVMContext::cpu(0); + println!("ctx: {}", ctx); + let default_ctx = TVMContext::new(TVMDeviceType(1), 0); + assert_eq!(ctx.clone(), default_ctx); + assert_ne!(ctx, TVMContext::gpu(0)); + + let str_ctx = TVMContext::new(TVMDeviceType::from("gpu"), 0); + assert_eq!(str_ctx.clone(), str_ctx); + assert_ne!(str_ctx, TVMContext::new(TVMDeviceType::from("cpu"), 0)); + } + + #[test] + fn sync() { + let ctx = TVMContext::cpu(0); + 
assert!(ctx.sync().is_ok()) + } +} diff --git a/rust/frontend/src/errors.rs b/rust/frontend/src/errors.rs new file mode 100644 index 000000000..a10f83c41 --- /dev/null +++ b/rust/frontend/src/errors.rs @@ -0,0 +1,51 @@ +//! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types. + +use std::{ffi, option}; + +use crate::{common_errors, rust_ndarray}; + +error_chain! { + errors { + EmptyArray { + description("cannot convert from an empty array") + } + + NullHandle(name: String) { + description("null handle") + display("requested `{}` handle is null", name) + } + + FunctionNotFound { + description("function not found") + display("function was not set in `function::Builder`") + } + + TypeMismatch(expected: String, found: String) { + description("type mismatch!") + display("expected type `{}`, but found `{}`", expected, found) + } + + MissingShapeError { + description("ndarray `shape()` returns `None`") + display("called `Option::unwrap()` on a `None` value") + } + + AtMostOneReturn { + description("TVM functions accept at most one return value") + } + + } + + foreign_links { + ShapeError(rust_ndarray::ShapeError); + NulError(ffi::NulError); + IntoStringError(ffi::IntoStringError); + CommonError(common_errors::Error); + } +} + +impl From for Error { + fn from(_err: option::NoneError) -> Self { + ErrorKind::MissingShapeError.into() + } +} diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs new file mode 100644 index 000000000..fa6bed141 --- /dev/null +++ b/rust/frontend/src/function.rs @@ -0,0 +1,512 @@ +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. 
+ +use std::{ + collections::BTreeMap, + ffi::{CStr, CString}, + mem, + os::raw::{c_char, c_int, c_void}, + ptr, slice, str, + sync::Mutex, +}; + +use crate::{ts, ErrorKind, Module, Result, TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue}; + +lazy_static! { + static ref GLOBAL_FUNCTIONS: Mutex>> = { + let mut out_size = 0 as c_int; + let name = ptr::null_mut() as *mut c_char; + let mut out_array = name as *mut _; + check_call!(ts::TVMFuncListGlobalNames( + &mut out_size as *mut _, + &mut out_array + )); + let names_list = unsafe { slice::from_raw_parts(out_array, out_size as usize) }; + Mutex::new( + names_list + .into_iter() + .map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None)) + .collect(), + ) + }; +} + +/// Wrapper around TVM function handle which includes `is_global` +/// indicating whether the function is global or not, `is_released` +/// to hint dropping the function handle and `is_cloned` showing +/// not to drop a cloned function from Rust side. +/// The value of these fields can be accessed through their respective methods. +#[derive(Debug, Hash)] +pub struct Function { + pub(crate) handle: ts::TVMFunctionHandle, + // whether the registered function is global or not. + is_global: bool, + // whether the function has been dropped from frontend or not. + is_released: bool, + // whether the function has been cloned from frontend or not. + is_cloned: bool, +} + +unsafe impl Send for Function {} +unsafe impl Sync for Function {} + +impl Function { + pub(crate) fn new(handle: ts::TVMFunctionHandle, is_global: bool, is_released: bool) -> Self { + Function { + handle: handle, + is_global: is_global, + is_released: is_released, + is_cloned: false, + } + } + + /// For a given function, it returns a function by name. 
+ pub fn get>(name: S, is_global: bool) -> Option<&'static Function> { + let mut globals = GLOBAL_FUNCTIONS.lock().unwrap(); + globals.get_mut(name.as_ref()).and_then(|maybe_func| { + if maybe_func.is_none() { + let name = CString::new(name.as_ref()).unwrap(); + let mut handle = ptr::null_mut() as ts::TVMFunctionHandle; + check_call!(ts::TVMFuncGetGlobal( + name.as_ptr() as *const c_char, + &mut handle as *mut _ + )); + maybe_func.replace(Function::new( + handle, is_global, false, /* is_released */ + )); + } + unsafe { + std::mem::transmute::, Option<&'static Function>>( + maybe_func.as_ref(), + ) + } + }) + } + + /// Returns the underlying TVM function handle. + pub fn handle(&self) -> ts::TVMFunctionHandle { + self.handle + } + + /// Returns `true` if the underlying TVM function is global and `false` otherwise. + pub fn is_global(&self) -> bool { + self.is_global + } + + /// Returns `true` if the underlying TVM function has been released + /// from the frontend and `false` otherwise. + pub fn is_released(&self) -> bool { + self.is_released + } + + /// Returns `true` if the underlying TVM function has been cloned + /// from the frontend and `false` otherwise. + pub fn is_cloned(&self) -> bool { + self.is_cloned + } +} + +impl Clone for Function { + fn clone(&self) -> Function { + if !self.is_released && !self.is_cloned { + Self { + handle: self.handle, + is_global: self.is_global, + is_released: self.is_released, + is_cloned: true, + } + } else { + Function::new(self.handle, self.is_global, self.is_released) + } + } +} + +impl Drop for Function { + fn drop(&mut self) { + if !self.is_released && !self.is_global && !self.is_cloned { + check_call!(ts::TVMFuncFree(self.handle)); + self.is_released = true; + } + } +} + +/// Function builder in order to create and call functions. +/// +/// *Note:* Currently TVM functions accept *at most* one return value. 
+#[derive(Debug, Clone, Default)] +pub struct Builder<'a, 'm> { + pub func: Option<&'m Function>, + pub arg_buf: Option]>>, + pub ret_buf: Option, +} + +impl<'a, 'm> Builder<'a, 'm> { + pub fn new( + func: Option<&'m Function>, + arg_buf: Option]>>, + ret_buf: Option, + ) -> Self { + Self { + func, + arg_buf, + ret_buf, + } + } + + pub fn get_function(&mut self, name: &'m str, is_global: bool) -> &mut Self { + self.func = Function::get(name, is_global); + self + } + + /// Pushes a [`TVMArgValue`] into the function argument buffer. + pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self + where + TVMValue: From<&'b T>, + TVMTypeCode: From<&'b T>, + { + let tvm_arg = TVMArgValue::from(arg); + if self.arg_buf.is_none() { + self.arg_buf = Some(Box::new([tvm_arg])); + } else { + let new_arg_buf = self.arg_buf.take().map(|bbuf| { + let mut new_arg_buf = Vec::from(bbuf); + new_arg_buf.push(tvm_arg); + let new_len = new_arg_buf.len(); + new_arg_buf.truncate(new_len); + new_arg_buf.into_boxed_slice() + }); + self.arg_buf = new_arg_buf; + } + self + } + + /// Pushes multiple [`TVMArgValue`]s into the function argument buffer. + pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self + where + I: IntoIterator, + TVMValue: From<&'b T>, + TVMTypeCode: From<&'b T>, + { + for arg in args { + self.arg(&arg); + } + self + } + + /// Sets an output for a function that requirs a mutable output to be provided. + /// See the `basics` in tests for an example. + pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> Result<&mut Self> + where + TVMValue: From<&'b T>, + TVMTypeCode: From<&'b T>, + { + if self.ret_buf.is_none() { + let tvm_ret = + unsafe { TVMRetValue::from_tvm_value(TVMValue::from(arg), TVMTypeCode::from(arg)) }; + self.ret_buf = Some(tvm_ret); + } else { + bail!(ErrorKind::AtMostOneReturn) + } + Ok(self) + } + + /// Calls the function that created from `Builder`. 
+ pub fn invoke(&mut self) -> Result { + self.clone()(()) + } +} + +impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> { + type Output = Result; + extern "rust-call" fn call_once(self, _: ((),)) -> Self::Output { + if self.func.is_none() { + bail!("{}", ErrorKind::FunctionNotFound); + } + + let mut ret_val = unsafe { mem::uninitialized::() }; + let mut ret_type_code = 0 as c_int; + if self.arg_buf.is_some() { + let arg_buf = self.arg_buf?; + let mut num_args = arg_buf.len(); + let mut values = arg_buf + .iter() + .map(|tav| tav.value.inner) + .collect::>(); + let mut tcodes = arg_buf + .iter() + .map(|tav| tav.type_code as c_int) + .collect::>(); + + if self.ret_buf.is_some() { + num_args = num_args + 1; + let ret_buf = self.ret_buf?; + let (ret_val, ret_type_code) = TVMRetValue::into_tvm_value(ret_buf); + values.append(&mut vec![ret_val.inner]); + tcodes.append(&mut vec![ret_type_code as c_int]); + } + + values.truncate(num_args); + tcodes.truncate(num_args); + check_call!(ts::TVMFuncCall( + self.func?.handle, + values.as_mut_ptr(), + tcodes.as_mut_ptr(), + num_args as c_int, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _ + )); + } else { + check_call!(ts::TVMFuncCall( + self.func?.handle, + ptr::null_mut(), + ptr::null_mut(), + 0 as c_int, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _ + )); + } + + let ret = unsafe { + TVMRetValue::from_tvm_value(TVMValue::new(ret_val), (ret_type_code as i64).into()) + }; + Ok(ret) + } +} + +/// Converts a [`Function`] to builder. Currently, this is the best way to work with +/// TVM functions. +impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> { + fn from(func: &'m Function) -> Self { + Builder::new(Some(func), None, None) + } +} + +/// Converts a mutable reference of a [`Module`] to [`Builder`]. 
+impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> { + fn from(module: &'m mut Module) -> Self { + Builder::new(module.entry(), None, None) + } +} + +unsafe extern "C" fn tvm_callback( + args: *mut ts::TVMValue, + type_codes: *mut c_int, + num_args: c_int, + ret: ts::TVMRetValueHandle, + fhandle: *mut c_void, +) -> c_int { + // turning off the incorrect linter complaints + #![allow(unused_assignments)] + let len = num_args as usize; + let args_list = slice::from_raw_parts_mut(args, len); + let type_codes_list = slice::from_raw_parts_mut(type_codes, len); + let mut local_args: Vec = Vec::new(); + let mut value = mem::uninitialized::(); + let mut tcode = mem::uninitialized::(); + let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); + for i in 0..len { + value = args_list[i]; + tcode = type_codes_list[i]; + if tcode == ts::TVMTypeCode_kNodeHandle as c_int + || tcode == ts::TVMTypeCode_kFuncHandle as c_int + || tcode == ts::TVMTypeCode_kModuleHandle as c_int + { + check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode)); + } + local_args.push(TVMArgValue::new( + TVMValue::new(value), + (tcode as i64).into(), + )); + } + + let rv = match rust_fn(local_args.as_slice()) { + Ok(v) => v, + Err(msg) => { + crate::set_last_error(&msg); + return -1; + } + }; + + let (ret_val, ret_tcode) = TVMRetValue::into_tvm_value(rv); + let mut ret_val = ret_val.inner; + let mut ret_type_code = ret_tcode as c_int; + check_call!(ts::TVMCFuncSetReturn( + ret, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _, + 1 as c_int + )); + 0 +} + +unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) { + let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); + mem::drop(rust_fn); +} + +fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function { + let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle; + let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result; + 
check_call!(ts::TVMFuncCreateFromCFunc( + Some(tvm_callback), + resource_handle as *mut c_void, + Some(tvm_callback_finalizer), + &mut fhandle as *mut _ + )); + Function::new(fhandle, false, false) +} + +/// Registers a Rust function with signature +/// `fn(&[TVMArgValue]) -> Result` +/// as a **global TVM packed function** from frontend to TVM backend. +/// +/// Use [`register_global_func`] if overriding an existing global TVM function +/// is not required. +/// +/// ## Example +/// +/// ``` +/// use std::convert::TryInto; +/// +/// fn sum(args: &[TVMArgValue]) -> Result { +/// let mut ret = 0i64; +/// for arg in args.iter() { +/// let arg: i64 = arg.try_into()?; +/// ret += arg; +/// } +/// let ret_val = TVMRetValue::from(&ret); +/// Ok(ret_val) +/// } +/// +/// tvm::function::register(sum, "mysum".to_owned(), false).unwrap(); +/// let mut registered = function::Builder::default(); +/// registered.get_function("mysum", true); +/// assert!(registered.func.is_some()); +/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap(); +/// assert_eq!(ret, 60); +/// ``` +pub fn register>( + f: fn(&[TVMArgValue]) -> Result, + name: S, + override_: bool, +) -> Result<()> { + let func = convert_to_tvm_func(f); + let name = CString::new(name.as_ref())?; + check_call!(ts::TVMFuncRegisterGlobal( + name.as_ref().as_ptr() as *const c_char, + func.handle(), + override_ as c_int + )); + mem::forget(name); + Ok(()) +} + +/// Convenient macro for registering functions from frontend to backend as global +/// TVM packed functions without overriding. If overriding an existing function is needed +/// use the [`function::register`] function instead. +/// +/// ## Example +/// +/// ``` +/// use std::convert::TryInto; +/// +/// register_global_func! 
{ +/// fn sum(args: &[TVMArgValue]) -> Result { +/// let mut ret = 0f64; +/// for arg in args.iter() { +/// let arg: f64 = arg.try_into()?; +/// ret += arg; +/// } +/// let ret_val = TVMRetValue::from(&ret); +/// Ok(ret_val) +/// } +/// } +/// +/// let mut registered = function::Builder::default(); +/// registered.get_function("sum", true); +/// assert!(registered.func.is_some()); +/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap(); +/// assert_eq!(ret, 60f64); +/// ``` +#[macro_export] +macro_rules! register_global_func { + { + $(#[$m:meta])* + fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result { + $($code:tt)* + } + } => {{ + $(#[$m])* + fn $fn_name($args: &[TVMArgValue]) -> Result { + $($code)* + } + + $crate::function::register($fn_name, stringify!($fn_name).to_owned(), false).unwrap(); + }} +} + +/// Convenient macro for calling TVM packed functions by providing a +/// function identifier and some arguments. This macro outputs a `Result` type +/// and let user to perform proper error handling. +/// +/// **Note**: this macro does *not* expect an outside mutable output. To +/// set mutable output use [`set_output`] directly in the builder pattern. +/// +/// [`set_output`]:function/struct.Builder.html#method.set_output +/// +/// ## Example +/// +/// Instead of +/// +/// ``` +/// function::Builder::from(func).arg(&a).arg(&b).invoke(); +/// ``` +/// +/// one can use +/// +/// ``` +/// call_packed!(func, &a, &b); +/// ``` +#[macro_export] +macro_rules! 
call_packed { + ($fn_name:expr, $($arg:expr),*) => {{ + let mut builder = $crate::function::Builder::from($fn_name); + $( + builder.arg($arg); + )* + builder.invoke() + }} +} + +#[cfg(test)] +mod tests { + use super::*; + + static CANARY: &str = "module._LoadFromFile"; + + #[test] + fn list_global_func() { + assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY)); + } + + #[test] + fn get_fn() { + assert!(Function::get(CANARY, true).is_some()); + assert!(Function::get("does not exists!", false).is_none()); + } + + #[test] + fn provide_args() { + let mut func = Builder::default(); + func.get_function("tvm.graph_runtime.remote_create", true) + .args(&[10, 20]) + .arg(&"test".to_owned()); + assert!(func.arg_buf.is_some()); + assert_eq!(func.arg_buf.take().map(|bv| Vec::from(bv).len()), Some(3)); + } +} diff --git a/rust/frontend/src/lib.rs b/rust/frontend/src/lib.rs new file mode 100644 index 000000000..6e15e4f8d --- /dev/null +++ b/rust/frontend/src/lib.rs @@ -0,0 +1,115 @@ +//! [TVM](https://github.com/dmlc/tvm) is a compiler stack for deep learning systems. +//! +//! This crate provides an idiomatic Rust API for TVM runtime frontend. +//! +//! One particular use case is that given optimized deep learning model artifacts, +//! (compiled with TVM) which include a shared library +//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them +//! in Rust idomatically to create a TVM Graph Runtime and +//! run the model for some inputs and get the +//! desired predictions *all in Rust*. +//! +//! Checkout the `examples` repository for more details. 
+ +#![crate_name = "tvm_frontend"] +#![recursion_limit = "1024"] +#![allow(non_camel_case_types, unused_unsafe)] +#![feature( + try_from, + try_trait, + fn_traits, + unboxed_closures, + box_syntax, + option_replace +)] + +#[macro_use] +extern crate error_chain; +extern crate tvm_common as common; +#[macro_use] +extern crate lazy_static; +extern crate ndarray as rust_ndarray; +extern crate num_traits; + +use std::{ + ffi::{CStr, CString}, + str, +}; + +use crate::common::ffi::ts; + +// Macro to check the return call to TVM runtime shared library. +macro_rules! check_call { + ($e:expr) => {{ + if unsafe { $e } != 0 { + panic!("{}", $crate::get_last_error()); + } + }}; +} + +/// Gets the last error message. +pub fn get_last_error() -> &'static str { + unsafe { + match CStr::from_ptr(ts::TVMGetLastError()).to_str() { + Ok(s) => s, + Err(_) => "Invalid UTF-8 message", + } + } +} + +pub(crate) fn set_last_error(err: &Error) { + let c_string = CString::new(err.to_string()).unwrap(); + unsafe { + ts::TVMAPISetLastError(c_string.as_ptr()); + } +} + +#[macro_use] +pub mod function; +pub mod bytearray; +pub mod context; +pub mod errors; +pub mod module; +pub mod ndarray; +pub mod ty; +pub mod value; + +pub use crate::{ + bytearray::TVMByteArray, + common::{ + errors as common_errors, + ty::TVMTypeCode, + value::{TVMArgValue, TVMRetValue, TVMValue}, + }, + context::{TVMContext, TVMDeviceType}, + errors::*, + function::Function, + module::Module, + ndarray::NDArray, + ty::TVMType, +}; + +/// Outputs the current TVM version. 
+pub fn version() -> &'static str { + match str::from_utf8(ts::TVM_VERSION) { + Ok(s) => s, + Err(_) => "Invalid UTF-8 string", + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn print_version() { + println!("TVM version: {}", version()); + } + + #[test] + fn set_error() { + let err = ErrorKind::EmptyArray; + set_last_error(&err.into()); + assert_eq!(get_last_error().trim(), ErrorKind::EmptyArray.to_string()); + } +} diff --git a/rust/frontend/src/module.rs b/rust/frontend/src/module.rs new file mode 100644 index 000000000..c12d9d48c --- /dev/null +++ b/rust/frontend/src/module.rs @@ -0,0 +1,105 @@ +//! Provides the [`Module`] type and methods for working with runtime TVM modules. + +use std::{ + convert::TryInto, + ffi::CString, + os::raw::{c_char, c_int}, + path::Path, + ptr, +}; + +use crate::ts; + +use crate::{function::Function, ErrorKind, Result}; + +const ENTRY_FUNC: &'static str = "__tvm_main__"; + +/// Wrapper around TVM module handle which contains an entry function. +/// The entry function can be applied to an imported module through [`entry_func`]. +/// Also [`is_released`] shows whether the module is dropped or not. +/// +/// [`entry_func`]:struct.Module.html#method.entry_func +/// [`is_released`]:struct.Module.html#method.is_released +#[derive(Debug, Clone)] +pub struct Module { + pub(crate) handle: ts::TVMModuleHandle, + is_released: bool, + entry_func: Option, +} + +impl Module { + pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self { + Self { + handle, + is_released, + entry_func: None, + } + } + + pub fn entry(&mut self) -> Option<&Function> { + if self.entry_func.is_none() { + self.entry_func = self.get_function(ENTRY_FUNC, false).ok(); + } + self.entry_func.as_ref() + } + + /// Gets a function by name from a registered module. 
+ pub fn get_function(&self, name: &str, query_import: bool) -> Result { + let name = CString::new(name)?; + let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle; + check_call!(ts::TVMModGetFunction( + self.handle, + name.as_ptr() as *const c_char, + query_import as c_int, + &mut fhandle as *mut _ + )); + if fhandle.is_null() { + bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?))) + } else { + Ok(Function::new(fhandle, false, false)) + } + } + + /// Imports a dependent module such as `.ptx` for gpu. + pub fn import_module(&self, dependent_module: Module) { + check_call!(ts::TVMModImport(self.handle, dependent_module.handle)) + } + + /// Loads a module shared library from path. + pub fn load>(path: &P) -> Result { + let ext = path.as_ref().extension()?.to_str()?; + let func = Function::get("module._LoadFromFile", true /* is_global */) + .expect("API function always exists"); + let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?; + Ok(ret) + } + + /// Checks if a target device is enabled for a module. + pub fn enabled(&self, target: &str) -> bool { + let func = Function::get("module._Enabled", true /* is_global */) + .expect("API function always exists"); + // `unwrap` is safe here because if there is any error during the + // function call, it would occur in `call_packed!`. + let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap(); + ret != 0 + } + + /// Returns the underlying module handle. + pub fn handle(&self) -> ts::TVMModuleHandle { + self.handle + } + + /// Returns true if the underlying module has been dropped and false otherwise. 
+ pub fn is_released(&self) -> bool { + self.is_released + } +} + +impl Drop for Module { + fn drop(&mut self) { + if !self.is_released { + check_call!(ts::TVMModFree(self.handle)); + self.is_released = true; + } + } +} diff --git a/rust/frontend/src/ndarray.rs b/rust/frontend/src/ndarray.rs new file mode 100644 index 000000000..44dfcca3b --- /dev/null +++ b/rust/frontend/src/ndarray.rs @@ -0,0 +1,363 @@ +//! This module implements the [`NDArray`] type for working with *TVM tensors* or +//! converting from Rust's ndarray to TVM `NDArray`. +//! +//! One can create an empty NDArray given the shape, context and dtype using [`empty`]. +//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. +//! To copy an NDArray to a different context use [`copy_to_ctx`]. +//! +//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: +//! +//! # Example +//! +//! ``` +//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) +//! .unwrap() +//! .into_dyn(); // Rust's ndarray +//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap(); +//! assert_eq!(nd.shape(), Some(&mut [2, 2])); +//! let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); +//! assert!(rnd.all_close(&a, 1e-8f32)); +//! ``` +//! +//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ +//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer +//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx + +use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice}; + +use crate::rust_ndarray::{Array, ArrayD}; +use num_traits::Num; + +use crate::ts; + +use crate::{Error, ErrorKind, Result, TVMByteArray, TVMContext, TVMType}; + +/// See the [`module-level documentation`](../ndarray/index.html) for more details. +/// +/// Wrapper around TVM array handle. 
+#[derive(Debug)] +pub struct NDArray { + pub(crate) handle: ts::TVMArrayHandle, + is_view: bool, +} + +impl NDArray { + pub(crate) fn new(handle: ts::TVMArrayHandle, is_view: bool) -> Self { + NDArray { + handle: handle, + is_view: is_view, + } + } + + /// Returns the underlying array handle. + pub fn handle(&self) -> ts::TVMArrayHandle { + self.handle + } + + pub fn is_view(&self) -> bool { + self.is_view + } + + /// Returns the shape of the NDArray. + pub fn shape(&self) -> Option<&mut [usize]> { + let arr = unsafe { *(self.handle) }; + if arr.shape.is_null() || arr.data.is_null() { + return None; + }; + let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; + Some(slc) + } + + /// Returns the total number of entries of the NDArray. + pub fn size(&self) -> Option { + self.shape() + .map(|v| v.into_iter().fold(1, |acc, &mut e| acc * e)) + } + + /// Returns the context which the NDArray was defined. + pub fn ctx(&self) -> TVMContext { + unsafe { (*self.handle).ctx.into() } + } + + /// Returns the type of the entries of the NDArray. + pub fn dtype(&self) -> TVMType { + unsafe { (*self.handle).dtype.into() } + } + + /// Returns the number of dimensions of the NDArray. + pub fn ndim(&self) -> usize { + unsafe { (*self.handle).ndim as usize } + } + + /// Returns the strides of the underlying NDArray, or `None` when its `strides` pointer is null. + pub fn strides(&self) -> Option<&[usize]> { + unsafe { + let ptr = (*self.handle).strides as *const usize; + // Never build a slice from a null pointer; the slice length is `ndim` elements, not bytes. + if ptr.is_null() { None } else { Some(slice::from_raw_parts(ptr, self.ndim())) } + } + } + + /// Shows whether the underlying ndarray is contiguous in memory or not. + pub fn is_contiguous(&self) -> Result { + Ok(match self.strides() { + None => true, + Some(strides) => { + // MissingShapeError in case shape is not determined + self.shape()? 
+ .iter() + .zip(strides) + .rfold( + (true, 1), + |(is_contig, expected_stride), (shape, stride)| { + ( + is_contig && *stride == expected_stride, + expected_stride * (*shape as usize), + ) + }, + ) + .0 + } + }) + } + + pub fn byte_offset(&self) -> isize { + unsafe { (*self.handle).byte_offset as isize } + } + + /// Flattens the NDArray to a `Vec` of the same type in cpu. + /// + /// ## Example + /// + /// ``` + /// let shape = &mut [4]; + /// let mut data = vec![1i32, 2, 3, 4]; + /// let ctx = TVMContext::cpu(0); + /// let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("int32")); + /// ndarray.copy_from_buffer(&mut data); + /// assert_eq!(ndarray.shape(), Some(shape)); + /// assert_eq!(ndarray.to_vec::().unwrap(), data); + /// ``` + pub fn to_vec(&self) -> Result> { + if self.shape().is_none() { + bail!("{}", ErrorKind::EmptyArray); + } + let earr = NDArray::empty(self.shape()?, TVMContext::cpu(0), self.dtype()); + let target = self.copy_to_ndarray(earr)?; + let arr = unsafe { *(target.handle) }; + let sz = self.size()? as usize; + let mut v: Vec = Vec::with_capacity(sz); + unsafe { + v.as_mut_ptr() + .copy_from_nonoverlapping(arr.data as *const T, sz); + v.set_len(sz); + } + Ok(v) + } + + /// Converts the NDArray to [`TVMByteArray`]. + pub fn to_bytearray(&self) -> Result { + let v = self.to_vec::()?; + Ok(TVMByteArray::from(&v)) + } + + /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. + /// + /// ## Example + /// + /// ``` + /// let shape = &mut [2]; + /// let mut data = vec![1f32, 2.0]; + /// let ctx = TVMContext::gpu(0); + /// let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("float32")); + /// ndarray.copy_from_buffer(&mut data); + /// ``` + /// + /// *Note*: if something goes wrong during the copy, it will panic + /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. 
+ pub fn copy_from_buffer(&mut self, data: &mut [T]) { + check_call!(ts::TVMArrayCopyFromBytes( + self.handle, + data.as_ptr() as *mut _, + data.len() * mem::size_of::() + )); + } + + /// Copies the NDArray to another target NDArray. + pub fn copy_to_ndarray(&self, target: NDArray) -> Result { + if self.dtype() != target.dtype() { + bail!( + "{}", + ErrorKind::TypeMismatch( + format!("{}", self.dtype().to_string()), + format!("{}", target.dtype().to_string()), + ) + ); + } + check_call!(ts::TVMArrayCopyFromTo( + self.handle, + target.handle, + ptr::null_mut() as ts::TVMStreamHandle + )); + Ok(target) + } + + /// Copies the NDArray to a target context. + pub fn copy_to_ctx(&self, target: &TVMContext) -> Result { + let tmp = NDArray::empty(self.shape()?, target.clone(), self.dtype()); + let copy = self.copy_to_ndarray(tmp)?; + Ok(copy) + } + + /// Converts a Rust's ndarray to TVM NDArray. + pub fn from_rust_ndarray( + rnd: &ArrayD, + ctx: TVMContext, + dtype: TVMType, + ) -> Result { + let mut shape = rnd.shape().to_vec(); + let mut nd = NDArray::empty(&mut shape, ctx, dtype); + let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); + nd.copy_from_buffer(buf.as_slice_mut()?); + Ok(nd) + } + + /// Allocates and creates an empty NDArray given the shape, context and dtype. + pub fn empty(shape: &[usize], ctx: TVMContext, dtype: TVMType) -> NDArray { + let mut handle = ptr::null_mut() as ts::TVMArrayHandle; + check_call!(ts::TVMArrayAlloc( + shape.as_ptr() as *const i64, + shape.len() as c_int, + dtype.inner.code as c_int, + dtype.inner.bits as c_int, + dtype.inner.lanes as c_int, + ctx.device_type.0 as c_int, + ctx.device_id as c_int, + &mut handle as *mut _, + )); + NDArray::new(handle, false) + } +} + +macro_rules! 
impl_from_ndarray_rustndarray { + ($type:ty, $type_name:tt) => { + impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { + type Error = Error; + fn try_from(nd: &NDArray) -> Result> { + if nd.shape().is_none() { + bail!("{}", ErrorKind::EmptyArray); + } + assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch"); + Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?) + } + } + + impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { + type Error = Error; + fn try_from(nd: &mut NDArray) -> Result> { + if nd.shape().is_none() { + bail!("{}", ErrorKind::EmptyArray); + } + assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch"); + Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?) + } + } + }; +} + +impl_from_ndarray_rustndarray!(i32, "int"); +impl_from_ndarray_rustndarray!(u32, "uint"); +impl_from_ndarray_rustndarray!(f32, "float"); + +impl Drop for NDArray { + fn drop(&mut self) { + if !self.is_view { + check_call!(ts::TVMArrayFree(self.handle)); + } + } +} + +mod sealed { + /// Private supertrait that prevents `Num32` from being implemented in downstream crates. + pub trait Sealed {} +} + +/// A trait for the supported 32-bit numerical types in frontend. +pub trait Num32: Num + sealed::Sealed { + const BITS: u8 = 32; +} + +macro_rules! 
impl_num32 { + ($($type:ty),+) => { + $( + impl sealed::Sealed for $type {} + impl Num32 for $type {} + )+ + }; +} + +impl_num32!(i32, u32, f32); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basics() { + let shape = &mut [1, 2, 3]; + let ctx = TVMContext::cpu(0); + let ndarray = NDArray::empty(shape, ctx, TVMType::from("int32")); + assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!( + ndarray.size().unwrap(), + shape.to_vec().into_iter().product() + ); + assert_eq!(ndarray.ndim(), 3); + assert!(ndarray.strides().is_none()); + assert_eq!(ndarray.byte_offset(), 0); + } + + #[test] + fn copy() { + let shape = &mut [4]; + let mut data = vec![1i32, 2, 3, 4]; + let ctx = TVMContext::cpu(0); + let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("int32")); + assert!(ndarray.to_vec::().is_ok()); + ndarray.copy_from_buffer(&mut data); + assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!(ndarray.to_vec::().unwrap(), data); + assert_eq!(ndarray.ndim(), 1); + assert!(ndarray.is_contiguous().is_ok()); + assert_eq!(ndarray.byte_offset(), 0); + let mut shape = vec![4]; + let e = NDArray::empty(&mut shape, TVMContext::cpu(0), TVMType::from("int32")); + let nd = ndarray.copy_to_ndarray(e); + assert!(nd.is_ok()); + assert_eq!(nd.unwrap().to_vec::().unwrap(), data); + } + + #[test] + #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] + fn copy_wrong_dtype() { + let mut shape = vec![4]; + let mut data = vec![1f32, 2., 3., 4.]; + let ctx = TVMContext::cpu(0); + let mut nd_float = NDArray::empty(&mut shape, ctx.clone(), TVMType::from("float32")); + nd_float.copy_from_buffer(&mut data); + let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from("int32")); + nd_float.copy_to_ndarray(empty_int).unwrap(); + } + + #[test] + fn rust_ndarray() { + let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) + .unwrap() + .into_dyn(); + let nd = + NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap(); + 
assert_eq!(nd.shape().unwrap(), &mut [2, 2]); + let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); + assert!(rnd.all_close(&a, 1e-8f32)); + } +} diff --git a/rust/frontend/src/ty.rs b/rust/frontend/src/ty.rs new file mode 100644 index 000000000..7e912a517 --- /dev/null +++ b/rust/frontend/src/ty.rs @@ -0,0 +1,150 @@ +//! This module implements the required conversions from Rust types to TVM types. +//! +//! In TVM frontend only conversions from Rust's 32-bits (POD) numeric types (i32, u32, f32) +//! and 64-bits pointers are supported. + +use std::{ + fmt::{self, Display, Formatter}, + ops::{Deref, DerefMut}, +}; + +use crate::ts; + +use crate::{Function, Module, NDArray, TVMByteArray, TVMContext, TVMDeviceType, TVMTypeCode}; + +macro_rules! impl_prim_type { + ($type:ty, $variant:ident) => { + impl From<$type> for TVMTypeCode { + fn from(_arg: $type) -> Self { + TVMTypeCode::$variant + } + } + + impl<'a> From<&'a $type> for TVMTypeCode { + fn from(_arg: &$type) -> Self { + TVMTypeCode::$variant + } + } + + impl<'a> From<&'a mut $type> for TVMTypeCode { + fn from(_arg: &mut $type) -> Self { + TVMTypeCode::$variant + } + } + }; +} + +impl_prim_type!(TVMDeviceType, kDLInt); +impl_prim_type!(TVMContext, kTVMContext); +impl_prim_type!(TVMType, kTVMType); +impl_prim_type!(Function, kFuncHandle); +impl_prim_type!(Module, kModuleHandle); +impl_prim_type!(NDArray, kArrayHandle); +impl_prim_type!(TVMByteArray, kBytes); + +/// See the [module-level documentation](../ty/index.html) for more details. 
+/// +/// Wrapper around underlying TVMType +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct TVMType { + // inner fields are (code: u8, bits: u8, lanes: u16) + pub inner: ts::TVMType, +} + +impl TVMType { + pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self { + TVMType { + inner: ts::TVMType { + code: type_code, + bits: bits, + lanes: lanes, + }, + } + } +} + +/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}` +/// such as "int32", "float32" or with lane "float32x1". +impl<'a> From<&'a str> for TVMType { + fn from(type_str: &'a str) -> Self { + if type_str == "bool" { + return TVMType::new(1, 1, 1); + } + + let mut type_lanes = type_str.split("x"); + let typ = type_lanes.next().expect("Missing dtype"); + let lanes = type_lanes + .next() + .map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l))) + .unwrap_or(1); + let (type_name, bits) = match typ.find(char::is_numeric) { + Some(idx) => { + let (name, bits_str) = typ.split_at(idx); + ( + name, + u8::from_str_radix(bits_str, 10) + .expect(&format!("Bad dtype bits: {}", bits_str)), + ) + } + None => (typ, 32), + }; + + let type_code = match type_name { + "int" => 0, + "uint" => 1, + "float" => 2, + "handle" => 3, + _ => unimplemented!(), + }; + + TVMType::new(type_code, bits, lanes) + } +} + +impl Display for TVMType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let ts::TVMType { code, bits, lanes } = self.inner; + if bits == 1 && lanes == 1 { + return write!(f, "bool"); + } + let mut tcode_str = match code { + 0 => "int", + 1 => "uint", + 2 => "float", + 3 => "handle", + _ => "Unknown", + } + .to_string(); + + tcode_str += &bits.to_string(); + if lanes > 1 { + tcode_str += &format!("x{}", lanes.to_string()); + } + f.write_str(&tcode_str) + } +} + +impl From for ts::DLDataType { + fn from(dtype: TVMType) -> Self { + dtype.inner + } +} + +impl From for TVMType { + fn from(dtype: ts::DLDataType) -> Self { + Self::new(dtype.code, 
dtype.bits, dtype.lanes) + } +} + +impl Deref for TVMType { + type Target = ts::TVMType; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for TVMType { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} diff --git a/rust/frontend/src/value.rs b/rust/frontend/src/value.rs new file mode 100644 index 000000000..9fad7de49 --- /dev/null +++ b/rust/frontend/src/value.rs @@ -0,0 +1,241 @@ +//! This module implements [`TVMArgValue`] and [`TVMRetValue`] types +//! and their conversions needed for the types used in frontend crate. +//! `TVMRetValue` is the owned version of `TVMPODValue`. + +use std::{convert::TryFrom, mem, os::raw::c_void}; + +use crate::{ + common_errors::*, ts, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext, + TVMDeviceType, TVMRetValue, TVMType, TVMTypeCode, TVMValue, +}; + +macro_rules! impl_tvm_val_from_handle { + ($($ty:ty),+) => { + $( + impl<'a> From<&'a $ty> for TVMValue { + fn from(arg: &$ty) -> Self { + let inner = ts::TVMValue { + v_handle: arg.handle as *mut _ as *mut c_void, + }; + Self::new(inner) + } + } + )+ + } +} + +impl_tvm_val_from_handle!(Module, Function, NDArray); + +impl<'a> From<&'a TVMType> for TVMValue { + fn from(ty: &TVMType) -> Self { + let inner = ts::TVMValue { v_type: ty.inner }; + Self::new(inner) + } +} + +impl<'a> From<&'a TVMContext> for TVMValue { + fn from(ctx: &TVMContext) -> Self { + let inner = ts::TVMValue { + v_ctx: ctx.clone().into(), + }; + Self::new(inner) + } +} + +impl<'a> From<&'a TVMDeviceType> for TVMValue { + fn from(dev: &TVMDeviceType) -> Self { + let inner = ts::TVMValue { + v_int64: dev.0 as i64, + }; + Self::new(inner) + } +} + +impl<'a> From<&'a TVMByteArray> for TVMValue { + fn from(barr: &TVMByteArray) -> Self { + let inner = ts::TVMValue { + v_handle: &barr.inner as *const ts::TVMByteArray as *mut c_void, + }; + Self::new(inner) + } +} + +impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for NDArray { + type Error = Error; + fn 
try_from(arg: &TVMArgValue<'a>) -> Result { + if arg.type_code == TVMTypeCode::kArrayHandle { + let handle = unsafe { arg.value.inner.v_handle }; + let arr_handle = unsafe { mem::transmute::<*mut c_void, ts::TVMArrayHandle>(handle) }; + Ok(Self::new(arr_handle, true)) + } else { + bail!(ErrorKind::TryFromTVMArgValueError( + stringify!(NDArray).to_string(), + arg.type_code.to_string() + )) + } + } +} + +impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for Module { + type Error = Error; + fn try_from(arg: &TVMArgValue<'a>) -> Result { + if arg.type_code == TVMTypeCode::kModuleHandle { + let handle = unsafe { arg.value.inner.v_handle }; + Ok(Self::new(handle, false)) + } else { + bail!(ErrorKind::TryFromTVMArgValueError( + stringify!(Module).to_string(), + arg.type_code.to_string() + )) + } + } +} + +impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMByteArray { + type Error = Error; + fn try_from(arg: &TVMArgValue<'a>) -> Result { + if arg.type_code == TVMTypeCode::kBytes { + unsafe { + let barr_ptr = + mem::transmute::<*mut c_void, *mut ts::TVMByteArray>(arg.value.inner.v_handle); + Ok(Self::new(*barr_ptr)) + } + } else { + bail!(ErrorKind::TryFromTVMArgValueError( + stringify!(TVMByteArray).to_string(), + arg.type_code.to_string() + )) + } + } +} + +impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMType { + type Error = Error; + fn try_from(arg: &TVMArgValue<'a>) -> Result { + if arg.type_code == TVMTypeCode::kTVMType { + let ty = unsafe { arg.value.inner.v_type }; + Ok(TVMType::from(ty)) + } else { + bail!(ErrorKind::TryFromTVMArgValueError( + stringify!(TVMType).to_string(), + arg.type_code.to_string() + )) + } + } +} + +impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMContext { + type Error = Error; + fn try_from(arg: &TVMArgValue<'a>) -> Result { + if arg.type_code == TVMTypeCode::kTVMContext { + let ty = unsafe { arg.value.inner.v_ctx }; + Ok(TVMContext::from(ty)) + } else { + bail!(ErrorKind::TryFromTVMArgValueError( + stringify!(TVMContext).to_string(), + 
arg.type_code.to_string() + )) + } + } +} + +macro_rules! impl_boxed_ret_value { + ($type:ty, $code:expr) => { + impl From<$type> for TVMRetValue { + fn from(val: $type) -> Self { + TVMRetValue { + prim_value: 0, + box_value: box val, + type_code: $code, + } + } + } + impl TryFrom for $type { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$type> { + if let Ok(val) = ret.box_value.downcast::<$type>() { + Ok(*val) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!($type).to_string(), + ret.type_code.to_string() + )) + } + } + } + }; +} + +impl_boxed_ret_value!(TVMType, TVMTypeCode::kTVMType); +impl_boxed_ret_value!(TVMContext, TVMTypeCode::kTVMContext); +impl_boxed_ret_value!(TVMByteArray, TVMTypeCode::kBytes); + +impl TryFrom for Module { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result { + if let Ok(handle) = ret.box_value.downcast::() { + Ok(Module::new(*handle, false)) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!(TVMTypeCode::kModuleHandle).to_string(), + ret.type_code.to_string() + )) + } + } +} + +impl TryFrom for Function { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result { + if let Ok(handle) = ret.box_value.downcast::() { + Ok(Function::new(*handle, false, false)) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!(TVMTypeCode::kFuncHandle).to_string(), + ret.type_code.to_string() + )) + } + } +} + +impl TryFrom for NDArray { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result { + if let Ok(handle) = ret.box_value.downcast::() { + Ok(NDArray::new(*handle, false)) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!(TVMTypeCode::kArrayHandle).to_string(), + ret.type_code.to_string() + )) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::convert::TryInto; + + #[test] + fn bytearray() { + let w = vec![1u8, 2, 3, 4, 5]; + let v = TVMByteArray::from(&w); + let tvm: TVMByteArray = 
TVMRetValue::from(v).try_into().unwrap(); + assert_eq!(tvm.data(), w.iter().map(|e| *e as i8).collect::>()); + } + + #[test] + fn ty() { + let t = TVMType::from("int32"); + let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap(); + assert_eq!(tvm, t); + } + + #[test] + fn ctx() { + let c = TVMContext::from("gpu"); + let tvm: TVMContext = TVMRetValue::from(c).try_into().unwrap(); + assert_eq!(tvm, c); + } +} diff --git a/rust/frontend/tests/basics/.gitignore b/rust/frontend/tests/basics/.gitignore new file mode 100644 index 000000000..10a4b225a --- /dev/null +++ b/rust/frontend/tests/basics/.gitignore @@ -0,0 +1,7 @@ +/target +**/*.rs.bk +Cargo.lock +*.o +*.so +*.ptx +*.json diff --git a/rust/frontend/tests/basics/Cargo.toml b/rust/frontend/tests/basics/Cargo.toml new file mode 100644 index 000000000..496c0dd02 --- /dev/null +++ b/rust/frontend/tests/basics/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "basics" +version = "0.0.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" +build = "build.rs" + +[dependencies] +ndarray = "0.12.1" +tvm-frontend = { path = "../../" } + +[features] +default = ["cpu"] +cpu = [] +gpu = [] diff --git a/rust/frontend/tests/basics/build.rs b/rust/frontend/tests/basics/build.rs new file mode 100644 index 000000000..67c21e004 --- /dev/null +++ b/rust/frontend/tests/basics/build.rs @@ -0,0 +1,27 @@ +fn main() { + let out_dir = std::env::var("OUT_DIR").unwrap(); + + let output = std::process::Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/tvm_add.py")) + .args(&[ + if cfg!(feature = "cpu") { + "llvm" + } else { + "cuda" + }, + &std::env::var("OUT_DIR").unwrap(), + ]) + .output() + .expect("Failed to execute command"); + assert!( + std::path::Path::new(&format!("{}/test_add.so", out_dir)).exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + println!("cargo:rustc-link-search=native={}", out_dir); +} diff --git 
a/rust/frontend/tests/basics/src/main.rs b/rust/frontend/tests/basics/src/main.rs new file mode 100644 index 000000000..69b948e91 --- /dev/null +++ b/rust/frontend/tests/basics/src/main.rs @@ -0,0 +1,35 @@ +extern crate ndarray as rust_ndarray; +extern crate tvm_frontend as tvm; + +use tvm::*; + +fn main() { + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + + let (ctx, ctx_name) = if cfg!(feature = "cpu") { + (TVMContext::cpu(0), "cpu") + } else { + (TVMContext::gpu(0), "gpu") + }; + let dtype = TVMType::from("float32"); + let mut arr = NDArray::empty(shape, ctx, dtype); + arr.copy_from_buffer(data.as_mut_slice()); + let mut ret = NDArray::empty(shape, ctx, dtype); + let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap(); + if !fadd.enabled(ctx_name) { + return; + } + if cfg!(feature = "gpu") { + fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap()); + } + function::Builder::from(&mut fadd) + .arg(&arr) + .arg(&arr) + .set_output(&mut ret) + .unwrap() + .invoke() + .unwrap(); + + assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); +} diff --git a/rust/frontend/tests/basics/src/tvm_add.py b/rust/frontend/tests/basics/src/tvm_add.py new file mode 100755 index 000000000..2f3b7c8a8 --- /dev/null +++ b/rust/frontend/tests/basics/src/tvm_add.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +import os.path as osp +import sys + +import tvm +from tvm.contrib import cc + + +def main(target, out_dir): + n = tvm.var('n') + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C') + s = tvm.create_schedule(C.op) + + if target == 'cuda': + bx, tx = s[C].split(C.op.axis[0], factor=64) + s[C].bind(bx, tvm.thread_axis('blockIdx.x')) + s[C].bind(tx, tvm.thread_axis('threadIdx.x')) + + fadd = tvm.build(s, [A, B, C], target, target_host='llvm', name='myadd') + + fadd.save(osp.join(out_dir, 'test_add.o')) + if target == 'cuda': + 
fadd.imported_modules[0].save(os.path.join(out_dir, 'test_add.ptx')) + cc.create_shared( + osp.join(out_dir, 'test_add.so'), [osp.join(out_dir, 'test_add.o')]) + + +if __name__ == '__main__': + main(sys.argv[1], sys.argv[2]) + diff --git a/rust/frontend/tests/callback/Cargo.toml b/rust/frontend/tests/callback/Cargo.toml new file mode 100644 index 000000000..1795c5745 --- /dev/null +++ b/rust/frontend/tests/callback/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "callback" +version = "0.0.0" +authors = ["TVM Contributors"] + +[dependencies] +ndarray = "0.12.1" +tvm-frontend = { path = "../../" } diff --git a/rust/frontend/tests/callback/src/bin/array.rs b/rust/frontend/tests/callback/src/bin/array.rs new file mode 100644 index 000000000..81dcadc30 --- /dev/null +++ b/rust/frontend/tests/callback/src/bin/array.rs @@ -0,0 +1,44 @@ +#![feature(extern_crate_item_prelude, try_from)] +#![allow(unused_imports)] + +extern crate ndarray as rust_ndarray; +#[macro_use] +extern crate tvm_frontend as tvm; + +use rust_ndarray::ArrayD; +use std::convert::{TryFrom, TryInto}; + +use tvm::*; + +fn main() { + register_global_func! 
{ + fn sum(args: &[TVMArgValue]) -> Result { + let mut ret = 0f32; + let shape = &mut [2]; + for arg in args.iter() { + let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + let arg: NDArray = arg.try_into()?; + let arr = arg.copy_to_ndarray(e)?; + let rnd: ArrayD = ArrayD::try_from(&arr)?; + ret += rnd.scalar_sum(); + } + Ok(TVMRetValue::from(ret)) + } + } + + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + + let mut registered = function::Builder::default(); + let ret: f32 = registered + .get_function("sum", true) + .arg(&arr) + .arg(&arr) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, 14f32); +} diff --git a/rust/frontend/tests/callback/src/bin/error.rs b/rust/frontend/tests/callback/src/bin/error.rs new file mode 100644 index 000000000..f40f0f157 --- /dev/null +++ b/rust/frontend/tests/callback/src/bin/error.rs @@ -0,0 +1,43 @@ +#![feature(extern_crate_item_prelude, panic_info_message)] +#![allow(unused_imports)] + +use std::panic; + +#[macro_use] +extern crate tvm_frontend as tvm; + +use tvm::*; + +fn main() { + register_global_func! 
{ + fn error(_args: &[TVMArgValue]) -> Result { + Err(ErrorKind::TypeMismatch( + format!("{}", "i64".to_string()), + format!("{}", "f64".to_string()), + ).into()) + } + } + + let mut registered = function::Builder::default(); + registered.get_function("error", true); + assert!(registered.func.is_some()); + registered.args(&[10, 20]); + + println!("expected error message is:"); + panic::set_hook(Box::new(|panic_info| { + if let Some(msg) = panic_info.message() { + println!("{:?}", msg); + } + if let Some(location) = panic_info.location() { + println!( + "panic occurred in file '{}' at line {}", + location.file(), + location.line() + ); + } else { + println!("panic occurred but can't get location information"); + } + })); + + let _result = registered.invoke(); +} diff --git a/rust/frontend/tests/callback/src/bin/float.rs b/rust/frontend/tests/callback/src/bin/float.rs new file mode 100644 index 000000000..307055284 --- /dev/null +++ b/rust/frontend/tests/callback/src/bin/float.rs @@ -0,0 +1,32 @@ +#![feature(extern_crate_item_prelude, try_from)] +#![allow(unused_imports)] + +#[macro_use] +extern crate tvm_frontend as tvm; + +use std::convert::TryInto; +use tvm::*; + +fn main() { + register_global_func! 
{ + fn sum(args: &[TVMArgValue]) -> Result { + let mut ret = 0.0; + for arg in args.iter() { + let val: f64 = arg.try_into()?; + ret += val; + } + Ok(TVMRetValue::from(&ret)) + } + } + + let mut registered = function::Builder::default(); + registered.get_function("sum", true); + assert!(registered.func.is_some()); + let ret: f64 = registered + .args(&[10.0f64, 20.0, 30.0]) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, 60f64); +} diff --git a/rust/frontend/tests/callback/src/bin/int.rs b/rust/frontend/tests/callback/src/bin/int.rs new file mode 100644 index 000000000..301882220 --- /dev/null +++ b/rust/frontend/tests/callback/src/bin/int.rs @@ -0,0 +1,31 @@ +#![feature(extern_crate_item_prelude, try_from)] +#![allow(unused_imports)] + +extern crate tvm_frontend as tvm; + +use std::convert::TryInto; +use tvm::*; + +fn main() { + fn sum(args: &[TVMArgValue]) -> Result { + let mut ret = 0i64; + for arg in args.iter() { + let val: i64 = arg.try_into()?; + ret += val; + } + Ok(TVMRetValue::from(&ret)) + } + + tvm::function::register(sum, "mysum".to_owned(), false).unwrap(); + + let mut registered = function::Builder::default(); + registered.get_function("mysum", true); + assert!(registered.func.is_some()); + let ret: i64 = registered + .args(&[10, 20, 30]) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, 60); +} diff --git a/rust/frontend/tests/callback/src/bin/string.rs b/rust/frontend/tests/callback/src/bin/string.rs new file mode 100644 index 000000000..eafee3179 --- /dev/null +++ b/rust/frontend/tests/callback/src/bin/string.rs @@ -0,0 +1,34 @@ +#![feature(extern_crate_item_prelude, try_from)] +#![allow(unused_imports)] + +#[macro_use] +extern crate tvm_frontend as tvm; +use std::convert::TryInto; +use tvm::*; + +// FIXME +fn main() { + register_global_func! 
{ + fn concate_str(args: &[TVMArgValue]) -> Result { + let mut ret = "".to_string(); + for arg in args.iter() { + let val: String = arg.try_into()?; + ret += val.as_str(); + } + Ok(TVMRetValue::from(ret)) + } + } + let mut registered = function::Builder::default(); + registered.get_function("concate_str", true); + assert!(registered.func.is_some()); + let a = "a".to_string(); + let b = "b".to_string(); + let c = "c".to_string(); + let ret: String = registered + .args(&[a, b, c]) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, "abc".to_owned()); +} diff --git a/rust/runtime/.gitignore b/rust/runtime/.gitignore new file mode 100644 index 000000000..230ab6610 --- /dev/null +++ b/rust/runtime/.gitignore @@ -0,0 +1,3 @@ +Cargo.lock +target/ +**/*.rs.bk diff --git a/rust/runtime/.travis.yml b/rust/runtime/.travis.yml new file mode 100644 index 000000000..63a3d0277 --- /dev/null +++ b/rust/runtime/.travis.yml @@ -0,0 +1,5 @@ +language: rust +rust: + - nightly +matrix: + fast_finish: true diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml new file mode 100644 index 000000000..d48c0d98c --- /dev/null +++ b/rust/runtime/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "tvm-runtime" +version = "0.1.0" +license = "Apache-2.0" +description = "A static TVM runtime" +repository = "https://github.com/dmlc/tvm" +readme = "README.md" +keywords = ["tvm", "nnvm"] +categories = ["api-bindings", "science"] +authors = ["TVM Contributors"] + +[features] +default = ["nom/std"] +sgx = ["nom/alloc"] + +[dependencies] +bounded-spsc-queue = "0.4.0" +error-chain = { version = "0.12.0", default-features = false } +itertools = "0.7.8" +lazy_static = "1.1.0" +ndarray = "0.11.2" +nom = {version = "4.0.0", default-features = false } +serde = "1.0.59" +serde_derive = "1.0.79" +serde_json = "1.0.17" +tvm-common = { version = "0.1.0", path = "../common/", features = ["runtime"] } + +[target.'cfg(not(target_env = "sgx"))'.dependencies] +num_cpus = "1.8.0" diff --git 
a/rust/runtime/src/allocator.rs b/rust/runtime/src/allocator.rs new file mode 100644 index 000000000..5f77037e2 --- /dev/null +++ b/rust/runtime/src/allocator.rs @@ -0,0 +1,52 @@ +#[cfg(target_env = "sgx")] +use alloc::alloc::{self, Layout}; +#[cfg(not(target_env = "sgx"))] +use std::alloc::{self, Layout}; + +use crate::errors::*; + +const DEFAULT_ALIGN_BYTES: usize = 4; + +#[derive(PartialEq, Eq)] +pub struct Allocation { + layout: Layout, + ptr: *mut u8, +} + +impl Allocation { + /// Allocates a chunk of memory of `size` bytes with optional alignment. + pub fn new(size: usize, align: Option) -> Result { + let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); + let layout = Layout::from_size_align(size, alignment)?; + let ptr = unsafe { alloc::alloc(layout.clone()) }; + if ptr.is_null() { + alloc::handle_alloc_error(layout); + } + Ok(Self { + ptr: ptr, + layout: layout, + }) + } + + pub fn as_mut_ptr(&self) -> *mut u8 { + self.ptr + } + + /// Returns the size of the Allocation in bytes. + pub fn size(&self) -> usize { + self.layout.size() + } + + /// Returns the byte alignment of the Allocation. + pub fn align(&self) -> usize { + self.layout.align() + } +} + +impl Drop for Allocation { + fn drop(&mut self) { + unsafe { + alloc::dealloc(self.ptr, self.layout.clone()); + } + } +} diff --git a/rust/runtime/src/array.rs b/rust/runtime/src/array.rs new file mode 100644 index 000000000..5c49515a0 --- /dev/null +++ b/rust/runtime/src/array.rs @@ -0,0 +1,507 @@ +use std::{ + any::TypeId, + convert::TryFrom, + mem, + ops::{Deref, DerefMut}, + os::raw::{c_int, c_void}, + ptr, slice, +}; + +use ndarray; + +use crate::{ + allocator::Allocation, + errors::*, + ffi::runtime::{ + DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, + DLDataTypeCode_kDLUInt, DLDeviceType_kDLCPU, DLTensor as _DLTensor, + }, +}; + +/// A `Storage` is a container which holds `Tensor` data. 
+#[derive(PartialEq)] +pub enum Storage<'a> { + /// A `Storage` which owns its contained bytes. + Owned(Allocation), + + /// A view of an existing `Storage`. + View(&'a mut [u8], usize), // ptr, align +} + +impl<'a> Storage<'a> { + pub fn new(size: usize, align: Option) -> Result> { + Ok(Storage::Owned(Allocation::new(size, align)?)) + } + + pub fn as_mut_ptr(&self) -> *mut u8 { + match self { + Storage::Owned(alloc) => alloc.as_mut_ptr(), + Storage::View(slice, _) => slice.as_ptr() as *mut u8, + } + } + + pub fn size(&self) -> usize { + match self { + Storage::Owned(alloc) => alloc.size(), + Storage::View(slice, _) => slice.len(), + } + } + + pub fn align(&self) -> usize { + match self { + Storage::Owned(alloc) => alloc.align(), + Storage::View(_, align) => *align, + } + } + + pub fn as_ptr(&self) -> *const u8 { + self.as_mut_ptr() as *const _ + } + + /// Returns a `Storage::View` which points to an owned `Storage::Owned`. + pub fn view(&self) -> Storage<'a> { + match self { + Storage::Owned(alloc) => Storage::View( + unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) }, + self.align(), + ), + Storage::View(slice, _) => Storage::View( + unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) }, + self.align(), + ), + } + } + + pub fn is_owned(&self) -> bool { + match self { + Storage::Owned(_) => true, + _ => false, + } + } + + /// Returns an owned version of this storage via cloning. 
+ pub fn to_owned(&self) -> Storage<'static> { + let s = Storage::new(self.size(), Some(self.align())).unwrap(); + unsafe { + s.as_mut_ptr() + .copy_from_nonoverlapping(self.as_ptr(), self.size()); + } + s + } +} + +impl<'d, 's, T> From<&'d [T]> for Storage<'s> { + fn from(data: &'d [T]) -> Self { + let data = unsafe { + slice::from_raw_parts_mut( + data.as_ptr() as *const u8 as *mut u8, + data.len() * mem::size_of::() as usize, + ) + }; + Storage::View(data, mem::align_of::()) + } +} + +/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`. +/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or +/// converted to `ndarray::Array` for non-TVM processing. +/// +/// # Examples +/// +/// ``` +/// extern crate ndarray; +/// +/// let mut a_nd: ndarray::Array = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]); +/// let mut a: Tensor = a_nd.into(); +/// let mut a_dl: DLTensor = (&mut t).into(); +/// call_packed!(tvm_fn, &mut a_dl); +/// +/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs. +/// let mut a_nd = ndarray::Array::try_from(&a).unwrap(); +/// ``` +#[derive(PartialEq)] +pub struct Tensor<'a> { + /// The bytes which contain the data this `Tensor` represents. + pub(crate) data: Storage<'a>, + pub(crate) ctx: TVMContext, + pub(crate) dtype: DataType, + pub(crate) shape: Vec, + // ^ not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h + /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous. + pub(crate) strides: Option>, + pub(crate) byte_offset: isize, + /// The number of elements in the `Tensor`. + pub(crate) size: usize, +} + +unsafe impl<'a> Send for Tensor<'a> {} + +impl<'a> Tensor<'a> { + pub fn shape(&self) -> Vec { + self.shape.clone() + } + + /// Returns the data of this `Tensor` as a `Vec`. + /// + /// # Panics + /// + /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`. 
+ pub fn to_vec(&self) -> Vec { + assert!(self.is_contiguous()); + assert!(self.dtype.is_type::()); + unsafe { slice::from_raw_parts(self.data.as_ptr() as *const T, self.size).to_vec() } + } + + /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory. + pub fn is_contiguous(&self) -> bool { + match self.strides { + None => true, + Some(ref strides) => { + // check that stride for each dimension is the + // product of all trailing dimensons' shapes + self.shape + .iter() + .zip(strides) + .rfold( + (true, 1), + |(is_contig, expected_stride), (shape, stride)| { + ( + is_contig && *stride == expected_stride, + expected_stride * (*shape as usize), + ) + }, + ) + .0 + } + } + } + + /// Returns a clone of this `Tensor`. + /// + /// # Panics + /// + /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`. + pub fn copy(&mut self, other: &Tensor) { + assert!( + self.dtype == other.dtype && self.size == other.size, + "Tensor shape/dtype mismatch." + ); + assert!( + self.is_contiguous() && other.is_contiguous(), + "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`", + self.strides, + other.strides + ); + unsafe { + self.data + .as_mut_ptr() + .offset(self.byte_offset as isize) + .copy_from_nonoverlapping( + other.data.as_mut_ptr().offset(other.byte_offset), + other.size * other.dtype.itemsize(), + ); + } + } + + /// Returns an owned version of this `Tensor` via cloning. 
+ pub fn to_owned(&self) -> Tensor<'static> { + let t = Tensor { + data: self.data.to_owned(), + ctx: self.ctx.clone(), + dtype: self.dtype.clone(), + size: self.size.clone(), + shape: self.shape.clone(), + strides: None, + byte_offset: 0, + }; + unsafe { mem::transmute::, Tensor<'static>>(t) } + } + + fn from_array_storage<'s, T, D: ndarray::Dimension>( + arr: &ndarray::Array, + storage: Storage<'s>, + type_code: usize, + ) -> Tensor<'s> { + let type_width = mem::size_of::() as usize; + Tensor { + data: storage, + ctx: TVMContext::default(), + dtype: DataType { + code: type_code, + bits: 8 * type_width, + lanes: 1, + }, + size: arr.len(), + shape: arr.shape().iter().map(|&v| v as i64).collect(), + strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()), + byte_offset: 0, + } + } +} + +/// Conversions to `ndarray::Array` from `Tensor`, if the types match. +macro_rules! impl_ndarray_try_from_tensor { + ($type:ty, $dtype:expr) => { + impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> { + type Error = Error; + fn try_from(tensor: &'a Tensor) -> Result> { + ensure!( + tensor.dtype == $dtype, + "Cannot convert Tensor with dtype {:?} to ndarray", + tensor.dtype + ); + Ok(ndarray::Array::from_shape_vec( + tensor + .shape + .iter() + .map(|s| *s as usize) + .collect::>(), + tensor.to_vec::<$type>(), + )?) 
+ } + } + }; +} + +impl_ndarray_try_from_tensor!(i32, DTYPE_INT32); +impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32); +impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32); +impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64); + +pub struct DLTensor { + pub(crate) inner: _DLTensor, +} + +impl Deref for DLTensor { + type Target = _DLTensor; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for DLTensor { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl DLTensor { + pub(crate) fn new(raw: _DLTensor) -> Self { + Self { inner: raw } + } + + pub(crate) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self { + assert!(!flatten || tensor.is_contiguous()); + Self { + inner: _DLTensor { + data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void, + ctx: DLContext::from(&tensor.ctx), + ndim: if flatten { 1 } else { tensor.shape.len() } as i32, + dtype: DLDataType::from(&tensor.dtype), + shape: if flatten { + &tensor.size as *const _ as *mut i64 + } else { + tensor.shape.as_ptr() + } as *mut i64, + strides: if flatten || tensor.is_contiguous() { + ptr::null_mut() + } else { + tensor.strides.as_ref().unwrap().as_ptr() + } as *mut i64, + byte_offset: 0, + }, + } + } +} + +impl<'a, 't> From<&'a Tensor<'t>> for DLTensor { + fn from(tensor: &'a Tensor<'t>) -> Self { + DLTensor::from_tensor(tensor, false /* flatten */) + } +} + +impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor { + fn from(tensor: &'a mut Tensor<'t>) -> Self { + DLTensor::from_tensor(tensor, false /* flatten */) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct DataType { + pub(crate) code: usize, + pub(crate) bits: usize, + pub(crate) lanes: usize, +} + +impl DataType { + /// Returns the number of bytes occupied by an element of this `DataType`. + pub fn itemsize(&self) -> usize { + (self.bits * self.lanes) >> 3 + } + + /// Returns whether this `DataType` represents primitive type `T`. 
+ pub fn is_type(&self) -> bool { + if self.lanes != 1 { + return false; + } + let typ = TypeId::of::(); + (typ == TypeId::of::() && self.code == 0 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 0 && self.bits == 64) + || (typ == TypeId::of::() && self.code == 1 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 1 && self.bits == 64) + || (typ == TypeId::of::() && self.code == 2 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 2 && self.bits == 64) + } +} + +impl<'a> From<&'a DataType> for DLDataType { + fn from(dtype: &'a DataType) -> Self { + Self { + code: dtype.code as u8, + bits: dtype.bits as u8, + lanes: dtype.lanes as u16, + } + } +} + +impl From for DataType { + fn from(dtype: DLDataType) -> Self { + Self { + code: dtype.code as usize, + bits: dtype.bits as usize, + lanes: dtype.lanes as usize, + } + } +} + +macro_rules! make_dtype_const { + ($name: ident, $code: ident, $bits: expr, $lanes: expr) => { + const $name: DataType = DataType { + code: $code as usize, + bits: $bits, + lanes: $lanes, + }; + }; +} + +make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1); +make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1); +// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1); +make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1); +make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1); + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TVMContext { + pub(crate) device_type: usize, + pub(crate) device_id: usize, +} + +impl<'a> From<&'a TVMContext> for DLContext { + fn from(ctx: &'a TVMContext) -> Self { + Self { + device_type: ctx.device_type as u32, + device_id: ctx.device_id as i32, + } + } +} + +impl Default for TVMContext { + fn default() -> Self { + Self { + device_type: DLDeviceType_kDLCPU as usize, + device_id: 0, + } + } +} + +impl<'a> From for Tensor<'a> { + fn from(dlt: DLTensor) -> Self { + unsafe { + let dtype = DataType::from(dlt.dtype); + let 
shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec(); + let size = shape.iter().map(|v| *v as usize).product::() as usize; + let storage = Storage::from(slice::from_raw_parts( + dlt.data as *const u8, + dtype.itemsize() * size, + )); + Self { + data: storage, + ctx: TVMContext::default(), + dtype: dtype, + size: size, + shape: shape, + strides: if dlt.strides == ptr::null_mut() { + None + } else { + Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec()) + }, + byte_offset: dlt.byte_offset as isize, + } + } + } +} + +/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`. +/// +/// # Panics +/// +/// Panics if the ndarray is not contiguous. +macro_rules! impl_tensor_from_ndarray { + ($type:ty, $typecode:expr) => { + impl From> for Tensor<'static> { + fn from(arr: ndarray::Array<$type, D>) -> Self { + let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous")); + Tensor::from_array_storage(&arr, storage.to_owned(), $typecode as usize) + } + } + impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> { + fn from(arr: &'a ndarray::Array<$type, D>) -> Self { + let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous")); + Tensor::from_array_storage(arr, storage, $typecode as usize) + } + } + }; +} + +/// `From` conversions to `DLTensor` for `ndarray::Array`. +/// Takes a reference to the `ndarray` since `DLTensor` is not owned. +macro_rules! 
impl_dltensor_from_ndarray { + ($type:ty, $typecode:expr) => { + impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { + fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { + DLTensor { + inner: _DLTensor { + data: arr.as_mut_ptr() as *mut c_void, + ctx: DLContext { + device_type: DLDeviceType_kDLCPU, + device_id: 0, + }, + ndim: arr.ndim() as c_int, + dtype: DLDataType { + code: $typecode as u8, + bits: 8 * mem::size_of::<$type>() as u8, + lanes: 1, + }, + shape: arr.shape().as_ptr() as *const i64 as *mut i64, + strides: arr.strides().as_ptr() as *const isize as *mut i64, + byte_offset: 0, + }, + } + } + } + }; +} + +impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); + +impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); +impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); diff --git a/rust/runtime/src/errors.rs b/rust/runtime/src/errors.rs new file mode 100644 index 000000000..cf7723034 --- /dev/null +++ b/rust/runtime/src/errors.rs @@ -0,0 +1,36 @@ +#[cfg(target_env = "sgx")] +use alloc::alloc; +#[cfg(not(target_env = "sgx"))] +use std::alloc; +use std::num; + +use crate::common::errors as common_errors; +use ndarray; +use serde_json; + +error_chain! 
{ + errors { + GraphFormatError(msg: String) { + description("unable to load graph") + display("could not load graph json: {}", msg) + } + + LoadGraphParamsError(msg: String) { + description("unable to load graph params") + display("could not load graph params: {}", msg) + } + } + foreign_links { + Alloc(alloc::AllocErr); + GraphDeserialize(serde_json::Error); + ParseInt(num::ParseIntError); + ShapeError(ndarray::ShapeError); + CommonError(common_errors::Error); + } +} + +impl From for Error { + fn from(_err: alloc::LayoutErr) -> Error { + Error::from_kind(ErrorKind::Msg("Layout error".to_string())) + } +} diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs new file mode 100644 index 000000000..0d5e281f3 --- /dev/null +++ b/rust/runtime/src/graph.rs @@ -0,0 +1,473 @@ +use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str}; + +use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr}; +use serde; +use serde_json; + +use super::{DLTensor, DataType, Module, Storage, TVMContext, Tensor}; +use crate::{ + common::value::TVMArgValue, + errors::{Error, ErrorKind, Result}, + ffi::runtime::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt}, +}; + +// @see `kTVMNDArrayMagic` in `ndarray.h` +const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F; +// @see `kTVMNDArrayListMagic` in `graph_runtime.h` +const _NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7; + +/// A TVM computation graph. 
+/// +/// # Examples +/// +/// ``` +/// let graph_json = fs::read_to_string("graph.json")).unwrap(); +/// let graph = Graph::try_from(&graph_json).unwrap(); +/// ``` +#[derive(Serialize, Deserialize, Debug)] +pub struct Graph { + pub nodes: Vec, + pub arg_nodes: Vec, + pub heads: Vec, + pub node_row_ptr: Option>, + pub attrs: Option>, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Entry { + pub id: usize, + pub index: usize, + pub version: usize, +} + +impl Graph { + fn entry_index(&self, entry: &Entry) -> Result { + self.node_row_ptr + .as_ref() + .map(|nrp| nrp[entry.id] + entry.index) + .ok_or("Missing node_row_ptr.".into()) + } + + /// Attempt to deserialize a JSON attribute to a type `T`. + fn get_attr(&self, attr: &str) -> Result { + Ok(serde_json::from_value::( + self.attrs + .as_ref() + .ok_or(ErrorKind::GraphFormatError( + "Missing graph attrs".to_string(), + ))? + .get(attr) + .ok_or(ErrorKind::GraphFormatError(format!( + "Missing {} attr", + attr + )))? + .to_owned(), + )?) + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Node { + pub op: String, + pub name: String, + pub inputs: Vec, + pub attrs: Option>, + pub control_deps: Option>, +} + +struct NodeAttrs { + func_name: String, + num_outputs: usize, + flatten_data: bool, +} + +impl Node { + fn parse_attrs(&self) -> Result { + let attrs = self + .attrs + .as_ref() + .ok_or(format!("Missing node.attrs for `{}`", self.name))?; + let func_name = attrs + .get("func_name") + .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))? + .to_string(); + let num_outputs = attrs + .get("num_outputs") + .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))? + .parse::()?; + let flatten_data = attrs + .get("flatten_data") + .ok_or(format!( + "Node `{}` is missing attrs.flatten_data", + self.name + ))? + .parse::()? 
+ == 1; + Ok(NodeAttrs { + func_name, + num_outputs, + flatten_data, + }) + } +} + +impl<'a> TryFrom<&'a String> for Graph { + type Error = Error; + fn try_from(graph_json: &String) -> Result { + let graph = serde_json::from_str(graph_json)?; + Ok(graph) + } +} + +impl<'a> TryFrom<&'a str> for Graph { + type Error = Error; + fn try_from(graph_json: &'a str) -> Result { + let graph = serde_json::from_str(graph_json)?; + Ok(graph) + } +} + +/// A executor for a TVM computation graph. +/// +/// # Examples +/// +/// ``` +/// use ndarray::Array; +/// +/// let syslib = SystemLibModule::default(); // a provider of TVM functions +/// +/// let mut params_bytes = Vec::new(); +/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap(); +/// let params = tvm::runtime::load_param_dict(¶ms_bytes).unwrap(); +/// +/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap(); +/// +/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); +/// exec.load_params(params); +/// +/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]); +/// exec.set_input("data", x.into()); +/// exec.run(); +/// let output = exec.get_output(0).unwrap(); +/// +/// println!("{:#?}", Array::try_from(output).unwrap()); +/// ``` +pub struct GraphExecutor<'m, 't> { + graph: Graph, + op_execs: Vec>, + tensors: Vec>, +} + +unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {} + +impl<'m, 't> GraphExecutor<'m, 't> { + pub fn new(graph: Graph, lib: &'m M) -> Result { + let tensors = Self::setup_storages(&graph)?; + Ok(GraphExecutor { + op_execs: Self::setup_op_execs(&graph, lib, &tensors)?, + tensors: tensors, + graph: graph, + }) + } + + /// Runs the computation graph. + pub fn run(&self) { + self.op_execs.iter().for_each(|op_exec| { + op_exec(); + }); + } + + /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output. 
+ fn setup_storages<'a>(graph: &'a Graph) -> Result>> { + let storage_ids = graph.get_attr::<(String, Vec)>("storage_id")?.1; + let shapes = graph.get_attr::<(String, Vec>)>("shape")?.1; + let dtypes = graph + .get_attr::<(String, Vec)>("dltype")? + .1 + .iter() + .map(|dltype| { + if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) { + Ok(dtype) + } else { + Err(ErrorKind::GraphFormatError( + format!("Invalid dltype: {}", dltype).to_string(), + ) + .into()) + } + }) + .collect::>>()?; + + let align = dtypes.iter().map(|dtype| dtype.bits as usize).max(); + let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; + for (i, &storage_id) in storage_ids.iter().enumerate() { + let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3; + let nbytes = dtype_size * shapes[i].iter().product::() as usize; + storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]); + } + + let mut storages: Vec = storage_num_bytes + .into_iter() + .map(|nbytes| Storage::new(nbytes, align)) + .collect::>>()?; + + let tensors = izip!(storage_ids, shapes, dtypes) + .map(|(storage_id, shape, dtype)| { + let storage = storages[storage_id].view(); + Tensor { + data: mem::replace(&mut storages[storage_id], storage), + ctx: TVMContext::default(), + dtype: dtype, + size: shape.iter().product::() as usize, + shape: shape, + strides: None, + byte_offset: 0, + } + }) + .collect(); + + Ok(tensors) + } + + /// Creates closures which represent the computation performed by this graph. 
+ fn setup_op_execs( + graph: &Graph, + lib: &'m M, + tensors: &Vec>, + ) -> Result>> { + ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr."); + let node_row_ptr = graph.node_row_ptr.as_ref().unwrap(); + + let mut op_execs = Vec::new(); + for (i, node) in graph.nodes.iter().enumerate() { + if node.op == "null" { + continue; + } + ensure!(node.op == "tvm_op", "Only TVM ops are supported."); + ensure!(node.attrs.is_some(), "Missing node attrs."); + + let attrs = node.parse_attrs()?; + + if attrs.func_name == "__nop" { + continue; + } + + let func = lib + .get_function(&attrs.func_name) + .ok_or(format!("Missing function {}", attrs.func_name))?; + let arg_indices = node + .inputs + .iter() + .map(|entry| graph.entry_index(entry)) + .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi))); + + let dl_tensors = arg_indices + .map(|idx| { + let tensor = &tensors[idx?]; + Ok(if attrs.flatten_data { + DLTensor::from_tensor(tensor, true /* flatten */) + } else { + DLTensor::from(tensor) + }) + }) + .collect::>>() + .unwrap(); + let op: Box = box move || { + let args = dl_tensors + .iter() + .map(|t| t.into()) + .collect::>(); + func(args.as_slice()); + }; + op_execs.push(op); + } + Ok(op_execs) + } + + pub fn load_params(&mut self, params: HashMap) { + params.into_iter().for_each(|(name, param)| { + self.set_input(name, param); + }) + } + + pub fn set_input>(&mut self, name: S, value: Tensor) { + if let Some(idx) = self.get_input_index(name.as_ref()) { + // TODO: consider `new_with_params` to avoid ever allocating + let ptr = self.tensors[idx].data.as_ptr(); + let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr); + let mut owner = to_replace.nth(0).unwrap(); + if value.data.is_owned() { + // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr + // mem::replace(&mut (*owner), value); + // to_replace.for_each(|t| { + // panic!("replacing"); + // t.data = owner.data.view(); + // }); + owner.copy(&value); + } 
else { + owner.copy(&value); + } + } else { + println!("Unexpected input `{}`", name.as_ref()); + } + } + + /// Returns the graph input with name `name`, if it exists. + pub fn get_input>(&mut self, name: S) -> Option<&Tensor> { + self.get_input_index(name.as_ref()) + .and_then(move |idx| Some(&self.tensors[idx])) + } + + /// Returns the graph output with index `index`, if it exists. + pub fn get_output(&self, idx: usize) -> Option<&Tensor> { + let graph = &self.graph; + graph.heads.get(idx).and_then(|entry| { + graph + .entry_index(entry) + .map(|idx| self.tensors.get(idx)) + .unwrap_or(None) + }) + } + + /// Returns the index for graph input with name `name`, if it exists. + pub fn get_input_index>(&self, name: S) -> Option { + let graph = &self.graph; + (0..graph.nodes.len()) + .skip_while(|&i| graph.nodes[i].name != name.as_ref()) + .nth(0) + .and_then(|i| { + if graph.arg_nodes.iter().any(|&id| id == i) { + graph.node_row_ptr.as_ref().map(|nrp| nrp[i]) + } else { + None + } + }) + } +} + +/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h +named!( + tvm_str_to_type, + do_parse!( + type_name: alpha1 >> + bits: digit1 >> + lanes: opt!(tuple!(tag!("x"), digit1)) >> + (DataType { + code: match type_name { + CompleteStr("int") => DLDataTypeCode_kDLInt, + CompleteStr("uint") => DLDataTypeCode_kDLUInt, + CompleteStr("float") => DLDataTypeCode_kDLFloat, + _ => DLDataTypeCode_kDLFloat, + } as usize, + bits: bits.parse::().unwrap() as usize, + lanes: match lanes { + Some(lanes) => lanes.1.parse::().unwrap() as usize, + None => 1, + }, + }) + ) +); + +/// Converts a bytes to String. 
+named!( + name, + map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8( + b.to_vec() + )) +); + +/// Parses a TVMContext +named!( + tvm_ctx<&[u8], TVMContext>, + do_parse!( + device_type: le_u32 >> + device_id: le_i32 >> + (TVMContext { device_type: device_type as usize, device_id: device_id as usize }) + ) +); + +/// Parses a DataType +named!( + data_type<&[u8], DataType>, + do_parse!( + code: le_u8 >> + bits: le_u8 >> + lanes: le_u16 >> + (DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize }) + ) +); + +/// Parses a Tensor from a TVM array file. +named!( + tensor, + do_parse!( + take!(8) + >> bits!(tag_bits!(u64, 64, 0)) + >> ctx: tvm_ctx + >> ndim: le_u32 + >> dtype: data_type + >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize) + >> length: le_i64 + >> data: take!(length) + >> (Tensor { + data: Storage::from(data), + ctx: ctx, + dtype: dtype, + size: shape.iter().product::() as usize, + shape: shape, + strides: None, + byte_offset: 0, + }) + ) +); + +/// Parses a graph params dict from a params binary file. +named!( + parse_param_dict>, + do_parse!( + take!(8) + >> bits!(tag_bits!(u64, 64, 0)) + >> names: length_count!(le_u64, name) + >> tensors: length_count!(le_u64, tensor) + >> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter()))) + ) +); + +/// Loads a param dict saved using `nnvm.compiler.save_param_dict`. 
+pub fn load_param_dict(bytes: &[u8]) -> Result> { + if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) { + if remaining_bytes.len() > 0 { + bail!(ErrorKind::LoadGraphParamsError("extra input".to_string())) + } else { + Ok(param_dict) + } + } else { + bail!(ErrorKind::LoadGraphParamsError( + "invalid parameters file".to_string() + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_str_to_type() { + assert_eq!( + tvm_str_to_type(CompleteStr("float24")).unwrap().1, + DataType { + code: DLDataTypeCode_kDLFloat as usize, + bits: 24, + lanes: 1 + } + ); + assert_eq!( + tvm_str_to_type(CompleteStr("uint111x44")).unwrap().1, + DataType { + code: DLDataTypeCode_kDLUInt as usize, + bits: 111, + lanes: 44 + } + ); + } +} diff --git a/rust/runtime/src/lib.rs b/rust/runtime/src/lib.rs new file mode 100644 index 000000000..da030bc4b --- /dev/null +++ b/rust/runtime/src/lib.rs @@ -0,0 +1,74 @@ +//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`. +//! It's mainly useful for compiling to WebAssembly and SGX, +//! but also native if you prefer Rust to C++. +//! +//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`. +//! Single-function modules are used via the `packed_func!` macro after obtaining +//! the function from `runtime::SystemLibModule` +//! +//! The main entrypoints to this crate are `GraphExecutor` +//! For examples of use, please refer to the multi-file tests in the `tests` directory. 
+ +#![feature( + alloc, + allocator_api, + box_syntax, + fn_traits, + try_from, + unboxed_closures, + vec_remove_item +)] + +#[cfg(target_env = "sgx")] +extern crate alloc; +extern crate bounded_spsc_queue; +#[cfg(target_env = "sgx")] +extern crate core; +#[macro_use] +extern crate error_chain; +#[macro_use] +extern crate itertools; +#[macro_use] +extern crate lazy_static; +extern crate ndarray; +#[macro_use] +extern crate nom; +#[cfg(not(target_env = "sgx"))] +extern crate num_cpus; +extern crate serde; +#[macro_use] +extern crate serde_derive; +extern crate serde_json; +extern crate tvm_common as common; + +mod allocator; +mod array; +pub mod errors; +mod module; +#[macro_use] +mod packed_func; +mod graph; +#[cfg(target_env = "sgx")] +#[macro_use] +pub mod sgx; +mod threading; +mod workspace; + +pub use crate::common::{errors::*, ffi, TVMArgValue, TVMRetValue}; + +pub use self::{ + array::*, errors::*, graph::*, module::*, packed_func::*, threading::*, workspace::*, +}; + +#[cfg(target_env = "sgx")] +use self::sgx::ocall_packed_func; + +#[no_mangle] +pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) { + #[cfg(not(target_env = "sgx"))] + unsafe { + panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap()); + } + #[cfg(target_env = "sgx")] + ocall_packed!("__sgx_set_last_error__", cmsg); +} diff --git a/rust/runtime/src/module.rs b/rust/runtime/src/module.rs new file mode 100644 index 000000000..8e6f7d665 --- /dev/null +++ b/rust/runtime/src/module.rs @@ -0,0 +1,48 @@ +use std::{ + collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, +}; + +use crate::{ + ffi::runtime::BackendPackedCFunc, + packed_func::{wrap_backend_packed_func, PackedFunc}, +}; + +pub trait Module { + fn get_function>(&self, name: S) -> Option; +} + +pub struct SystemLibModule; + +lazy_static! 
{ + static ref SYSTEM_LIB_FUNCTIONS: Mutex> = + Mutex::new(HashMap::new()); +} + +impl Module for SystemLibModule { + fn get_function>(&self, name: S) -> Option { + SYSTEM_LIB_FUNCTIONS + .lock() + .unwrap() + .get(name.as_ref()) + .map(|func| wrap_backend_packed_func(func.to_owned())) + } +} + +impl Default for SystemLibModule { + fn default() -> Self { + SystemLibModule {} + } +} + +#[no_mangle] +pub extern "C" fn TVMBackendRegisterSystemLibSymbol( + cname: *const c_char, + func: BackendPackedCFunc, +) -> i32 { + let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; + SYSTEM_LIB_FUNCTIONS + .lock() + .unwrap() + .insert(name.to_string(), func); + return 0; +} diff --git a/rust/runtime/src/packed_func.rs b/rust/runtime/src/packed_func.rs new file mode 100644 index 000000000..2fe0086e9 --- /dev/null +++ b/rust/runtime/src/packed_func.rs @@ -0,0 +1,118 @@ +use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void}; + +use super::Tensor; +use crate::ffi::runtime::{ + BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle, + TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue, +}; + +use super::DLTensor; +use crate::{ + common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue}, + errors::*, +}; + +pub type PackedFunc = Box TVMRetValue + Send + Sync>; + +/// Calls a packed function and returns a `TVMRetValue`. +/// +/// # Example +/// +/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` +#[macro_export] +macro_rules! 
call_packed { + ($fn:expr, $($args:expr),+) => { + $fn(&[$($args.into(),)+]) + }; + ($fn:expr) => { + $fn(&Vec::new()) + }; +} + +impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { + fn from(arr: &'a DLTensor) -> Self { + let raw = _TVMValue { + v_handle: arr as *const _ as *mut DLTensor as *mut c_void, + }; + TVMArgValue { + value: TVMValue::new(raw), + type_code: TVMTypeCode::kArrayHandle, + lifetime: PhantomData, + } + } +} + +impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> { + fn from(arr: &'a mut DLTensor) -> Self { + let raw = _TVMValue { + v_handle: arr as *mut _ as *mut c_void, + }; + TVMArgValue { + value: TVMValue::new(raw), + type_code: TVMTypeCode::kArrayHandle, + lifetime: PhantomData, + } + } +} + +impl<'a> TryFrom> for Tensor<'a> { + type Error = Error; + fn try_from(val: TVMArgValue<'a>) -> Result { + ensure!( + val.type_code == TVMTypeCode::kArrayHandle + || val.type_code == TVMTypeCode::kNDArrayContainer, + "Could not downcast arg. Expected `{}` or `{}`, but got `{}`", + TVMTypeCode::kArrayHandle, + TVMTypeCode::kNDArrayContainer, + val.type_code, + ); + + let dlt = unsafe { *(val.value.v_handle as *mut _DLTensor as *const _DLTensor) }; + Ok(DLTensor::new(dlt).into()) + } +} + +impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue { + fn from(val: &'t Tensor<'a>) -> Self { + TVMRetValue { + prim_value: 0, + box_value: box DLTensor::from(val), + type_code: TVMTypeCode::kNDArrayContainer, + } + } +} + +impl<'a> TryFrom for Tensor<'a> { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result { + ensure!( + ret.type_code == TVMTypeCode::kArrayHandle + || ret.type_code == TVMTypeCode::kNDArrayContainer, + "Could not downcast arg. Expected `{}` or `{}`, but got `{}`", + TVMTypeCode_kArrayHandle, + TVMTypeCode_kNDArrayContainer, + ret.type_code, + ); + + let dlt = unsafe { *(ret.prim_value as *mut _DLTensor as *const _DLTensor) }; + Ok(DLTensor::new(dlt).into()) + } +} + +// @see `WrapPackedFunc` in `llvm_module.cc`. 
+pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { + box move |args: &[TVMArgValue]| { + func( + args.iter() + .map(|ref arg| arg.value.inner) + .collect::>() + .as_ptr(), + args.iter() + .map(|ref arg| arg.type_code as i32) + .collect::>() + .as_ptr() as *const i32, + args.len() as i32, + ); + TVMRetValue::default() + } +} diff --git a/rust/runtime/src/sgx.rs b/rust/runtime/src/sgx.rs new file mode 100644 index 000000000..1edf3ef49 --- /dev/null +++ b/rust/runtime/src/sgx.rs @@ -0,0 +1,80 @@ +use std::{ + ffi::CString, + os::raw::{c_char, c_int}, +}; + +use errors::Result; +use ffi::runtime::TVMValue; +use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue}; + +pub use runtime::threading::tvm_run_worker as run_worker; + +#[macro_export] +macro_rules! tvm_ocall { + ($func: expr) => { + match $func { + 0 => Ok(()), + err => Err(format!("SGX error: {}", err)), + } + }; +} + +pub type SgxStatus = u32; + +#[cfg(target_env = "sgx")] +extern "C" { + fn tvm_ocall_packed_func( + name: *const c_char, + arg_values: *const TVMValue, + type_codes: *const c_int, + num_args: c_int, + ret_val: *mut TVMValue, + ret_type_code: *mut c_int, + ) -> SgxStatus; +} + +pub fn ocall_packed_func>(fn_name: S, args: &[TVMArgValue]) -> Result { + let mut ret_val = TVMValue { v_int64: 0 }; + let ret_type_code = 0i64; + unsafe { + tvm_ocall!(tvm_ocall_packed_func( + CString::new(fn_name.as_ref()).unwrap().as_ptr(), + args.iter() + .map(|ref arg| arg.value) + .collect::>() + .as_ptr(), + args.iter() + .map(|ref arg| arg.type_code as i32) + .collect::>() + .as_ptr() as *const i32, + args.len() as i32, + &mut ret_val as *mut TVMValue, + &mut (ret_type_code as i32) as *mut c_int, + ))?; + } + Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64)) +} + +#[macro_export] +macro_rules! 
ocall_packed { + ($fn_name:expr, $($args:expr),+) => { + ocall_packed_func($fn_name, &[$($args.into(),)+]) + .expect(concat!("Error calling `", $fn_name, "`")) + }; + ($fn_name:expr) => { + ocall_packed_func($fn_name, &Vec::new()) + .expect(concat!("Error calling `", $fn_name, "`")) + } +} + +pub fn shutdown() { + if env!("TVM_NUM_THREADS") != "0" { + sgx_join_threads() + } +} + +impl Drop for SystemLibModule { + fn drop(&mut self) { + shutdown() + } +} diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs new file mode 100644 index 000000000..38f4b7d23 --- /dev/null +++ b/rust/runtime/src/threading.rs @@ -0,0 +1,336 @@ +use std::{ + os::raw::{c_int, c_void}, + sync::{ + atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT}, + Arc, Barrier, + }, +}; + +#[cfg(not(target_env = "sgx"))] +use num_cpus; +#[cfg(not(target_env = "sgx"))] +use std::{ + env, + thread::{self, JoinHandle}, +}; + +#[cfg(target_env = "sgx")] +use std::{collections::VecDeque, ptr, sync::Mutex}; + +use bounded_spsc_queue::{self, Producer}; + +use crate::{errors::*, ffi::runtime::TVMParallelGroupEnv}; + +#[cfg(target_env = "sgx")] +use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue}; + +type FTVMParallelLambda = + extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; + +/// Holds a parallel job request made by a TVM library function. +struct Job { + cb: FTVMParallelLambda, + cdata: *const c_void, + req_num_tasks: usize, + pending: Arc, +} + +impl Job { + /// Splits this job into a number of `Task`s which can be scheduled. 
+ fn tasks(&self, num_workers: usize) -> Vec { + let num_tasks = if self.req_num_tasks == 0 { + num_workers + } else { + self.req_num_tasks.min(num_workers) + }; + self.pending.store(num_tasks, Ordering::SeqCst); + + let barrier = Arc::new(Barrier::new(num_tasks)); + + (0..num_tasks) + .map(move |i| Task { + id: i, + flambda: self.cb, + penv: TVMParallelGroupEnv { + sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void, + num_task: num_tasks as i32, + }, + cdata: self.cdata, + pending: Arc::clone(&self.pending), + }) + .collect() + } + + /// Waits for all tasks in this `Job` to be completed. + fn wait(&self) -> Result<()> { + while self.pending.load(Ordering::Acquire) > 0 { + #[cfg(not(target_env = "sgx"))] + thread::yield_now(); + } + Ok(()) + } +} + +/// A chunk of work requested by a TVM function. +struct Task { + id: usize, + flambda: FTVMParallelLambda, + penv: TVMParallelGroupEnv, + cdata: *const c_void, + pending: Arc, +} +unsafe impl Send for Task {} +unsafe impl Sync for Task {} + +impl FnOnce<()> for Task { + type Output = i32; + extern "rust-call" fn call_once(self, _args: ()) -> Self::Output { + let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata); + self.pending.fetch_sub(1, Ordering::AcqRel); + status + } +} + +#[derive(Default)] +struct Threads { + #[allow(unused)] + #[cfg(not(target_env = "sgx"))] + handles: Vec>, + queues: Vec>, +} + +impl<'a> Threads { + #[cfg(not(target_env = "sgx"))] + fn launch) + 'static + Copy>( + num_threads: usize, + cb: F, + ) -> Self { + let (handles, queues) = (0..num_threads) + .map(|_| { + let (p, c) = bounded_spsc_queue::make(2); + let handle = thread::spawn(move || cb(c.into())); + (handle, p) + }) + .unzip(); + Threads { + handles: handles, + queues: queues, + } + } + + #[cfg(target_env = "sgx")] + fn launch) + 'static + Copy>( + num_threads: usize, + _cb: F, + ) -> Self { + let mut consumer_queues = SGX_QUEUES.lock().unwrap(); + let queues = (0..num_threads) + .map(|_| { + let (p, c) = 
bounded_spsc_queue::make(2); + consumer_queues.push_back(c.into()); + p + }) + .collect(); + ocall_packed!("__sgx_thread_group_launch__", num_threads as u64); + Threads { queues: queues } + } +} + +struct ThreadPool { + num_workers: usize, + #[allow(unused)] + threads: Threads, +} + +thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new()); + +impl ThreadPool { + fn new() -> Self { + let num_workers = max_concurrency(); + ThreadPool { + num_workers: num_workers, + threads: Threads::launch(num_workers, ThreadPool::run_worker), + } + } + + fn launch(&self, job: Job) { + let mut tasks = job.tasks(self.num_workers + 1); + + for (i, task) in tasks.split_off(1).into_iter().enumerate() { + self.threads.queues[i].push(task); + } + + tasks.pop().unwrap()(); + job.wait().unwrap(); + } + + fn run_worker(queue: Consumer) { + loop { + let task = queue.pop(); + let result = task(); + if result == ::min_value() { + break; + } else if result != 0 { + panic!("Error running task."); + } + } + } +} + +// Send + Sync wrapper for bounded_spsc_queue::Consumer +struct Consumer { + consumer: bounded_spsc_queue::Consumer, +} +impl From> for Consumer { + fn from(c: bounded_spsc_queue::Consumer) -> Self { + Consumer { consumer: c } + } +} +impl Consumer { + fn pop(&self) -> T { + self.consumer.pop() + } +} +unsafe impl Send for Consumer {} +unsafe impl Sync for Consumer {} + +#[cfg(target_env = "sgx")] +lazy_static! { + /// Holds tasks for untrusted threads which re-enter the enclave to execute. 
+ static ref SGX_QUEUES: Mutex>> = Mutex::new(VecDeque::new()); +} + +#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))] +fn max_concurrency() -> usize { + if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) { + if let Ok(threads) = usize::from_str_radix(&threads_str, 10) { + return threads; + } + } + num_cpus::get_physical() +} + +#[cfg(target_env = "sgx")] +fn max_concurrency() -> usize { + usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1) +} + +#[cfg(target_arch = "wasm32")] +fn max_concurrency() -> usize { + 0 // wasm doesn't support threads yet +} + +#[cfg(target_env = "sgx")] +pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue { + let q = { + let mut qs = SGX_QUEUES.lock().unwrap(); + qs.pop_front() + // `qs: MutexGuard` needs to be dropped here since `run_worker` won't return + }; + if let Some(q) = q { + ThreadPool::run_worker(q); + } + TVMRetValue::default() +} + +#[no_mangle] +pub extern "C" fn TVMBackendParallelLaunch( + cb: FTVMParallelLambda, + cdata: *const c_void, + num_task: usize, +) -> c_int { + if max_concurrency() == 0 { + let penv = TVMParallelGroupEnv { + sync_handle: 0 as *mut c_void, + num_task: 1, + }; + cb(0, &penv as *const _, cdata); + } else { + THREAD_POOL.with(|pool| { + pool.launch(Job { + cb: cb, + cdata: cdata, + req_num_tasks: num_task, + pending: Arc::new(ATOMIC_USIZE_INIT), + }); + }); + } + return 0; +} + +#[cfg(target_env = "sgx")] +pub(crate) fn sgx_join_threads() { + extern "C" fn poison_pill( + _task_id: usize, + _penv: *const TVMParallelGroupEnv, + _cdata: *const c_void, + ) -> i32 { + ::min_value() + } + + THREAD_POOL.with(|pool| { + pool.launch(Job { + cb: poison_pill, + cdata: ptr::null(), + req_num_tasks: 0, + pending: Arc::new(ATOMIC_USIZE_INIT), + }); + }); + ocall_packed!("__sgx_thread_group_join__", 0); +} + +// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used. 
+#[no_mangle] +pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) { + let barrier: &Arc = unsafe { &*((*penv).sync_handle as *const Arc) }; + barrier.wait(); +} + +#[cfg(test)] +mod tests { + use std::{ptr, thread, time::Duration}; + + use super::*; + + #[test] + fn test_max_concurrency() { + env::set_var("TVM_NUM_THREADS", "42"); + env::set_var("OMP_NUM_THREADS", "24"); + assert_eq!(max_concurrency(), 42); + env::remove_var("TVM_NUM_THREADS"); + assert_eq!(max_concurrency(), 24); + } + + extern "C" fn flambda( + task_id: usize, + penv: *const TVMParallelGroupEnv, + cdata: *const c_void, + ) -> i32 { + if cdata == ptr::null() { + return 0; + } + unsafe { + let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize)); + thread::sleep(Duration::from_millis(50 * task_id as u64)); + counter.fetch_add(1, Ordering::SeqCst); + task_ids_sum.fetch_add(task_id, Ordering::SeqCst); + assert_eq!((*penv).num_task, 3); + } + 0 + } + + #[test] + fn test_parallel_launch() { + TVMBackendParallelLaunch(flambda, ptr::null(), 6); + let counter = ATOMIC_USIZE_INIT; + let task_ids_sum = ATOMIC_USIZE_INIT; + let cdata = (counter, task_ids_sum); + let num_tasks = 3; + TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks); + assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks); + assert_eq!( + cdata.1.load(Ordering::SeqCst), + (0..num_tasks).sum::() + ); + } +} diff --git a/rust/runtime/src/workspace.rs b/rust/runtime/src/workspace.rs new file mode 100644 index 000000000..a12a27e4c --- /dev/null +++ b/rust/runtime/src/workspace.rs @@ -0,0 +1,117 @@ +use std::{ + cell::RefCell, + os::raw::{c_int, c_void}, + ptr, +}; + +use super::allocator::Allocation; +use crate::errors::*; + +const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h` + +struct WorkspacePool { + workspaces: Vec, + free: Vec, + in_use: Vec, +} + +impl WorkspacePool { + fn new() -> Self { + WorkspacePool { 
+ workspaces: Vec::new(), + free: Vec::new(), + in_use: Vec::new(), + } + } + + fn alloc_new(&mut self, size: usize) -> Result<*mut u8> { + self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?); + self.in_use.push(self.workspaces.len() - 1); + Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) + } + + fn alloc(&mut self, size: usize) -> Result<*mut u8> { + if self.free.len() == 0 { + return self.alloc_new(size); + } + let idx = self + .free + .iter() + .fold(None, |cur_ws_idx: Option, &idx| { + let ws_size = self.workspaces[idx].size(); + if !ws_size >= size { + return cur_ws_idx; + } + cur_ws_idx.or(Some(idx)).and_then(|cur_idx| { + let cur_size = self.workspaces[cur_idx].size(); + Some(match ws_size <= cur_size { + true => idx, + false => cur_idx, + }) + }) + }); + match idx { + Some(idx) => { + self.free.remove_item(&idx).unwrap(); + self.in_use.push(idx); + Ok(self.workspaces[idx].as_mut_ptr()) + } + None => self.alloc_new(size), + } + } + + fn free(&mut self, ptr: *mut u8) -> Result<()> { + let mut ws_idx = None; + for i in 0..self.in_use.len() { + let idx = self.in_use[i]; + if self.workspaces[idx].as_mut_ptr() == ptr { + self.in_use.remove(i); + ws_idx = Some(idx); + break; + } + } + Ok(self + .free + .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?)) + } +} + +thread_local!(static WORKSPACE_POOL: RefCell = RefCell::new(WorkspacePool::new())); + +const WORKSPACE_PAGE_SIZE: usize = 4 << 10; + +#[no_mangle] +pub extern "C" fn TVMBackendAllocWorkspace( + _device_type: c_int, + _device_id: c_int, + size: u64, + _dtype_code_hint: c_int, + _dtype_bits_hint: c_int, +) -> *mut c_void { + let nbytes = if size == 0 { + WORKSPACE_PAGE_SIZE + } else { + size as usize + }; + WORKSPACE_POOL.with(|pool_cell| { + pool_cell + .borrow_mut() + .alloc(nbytes as usize) + .unwrap_or(ptr::null_mut()) as *mut c_void + }) +} + +#[no_mangle] +pub extern "C" fn TVMBackendFreeWorkspace( + _device_type: c_int, + _device_id: c_int, + ptr: *mut c_void, +) -> 
c_int { + WORKSPACE_POOL.with(|pool_cell| { + (match pool_cell.borrow_mut().free(ptr as *mut u8) { + Ok(()) => 0, + Err(_) => -1, + }) as c_int + }); + return 0; +} diff --git a/rust/runtime/tests/.gitignore b/rust/runtime/tests/.gitignore new file mode 100644 index 000000000..811076739 --- /dev/null +++ b/rust/runtime/tests/.gitignore @@ -0,0 +1,3 @@ +*.json +*.params +*.o diff --git a/rust/runtime/tests/build_model.py b/rust/runtime/tests/build_model.py new file mode 100755 index 000000000..ea55ce434 --- /dev/null +++ b/rust/runtime/tests/build_model.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 + +"""Builds a simple NNVM graph for testing.""" + +from os import path as osp + +import nnvm +from nnvm import sym +from nnvm.compiler import graph_util +from nnvm.testing import init +import numpy as np +import tvm + +CWD = osp.dirname(osp.abspath(osp.expanduser(__file__))) + + +def _get_model(dshape): + data = sym.Variable('data', shape=dshape) + fc1 = sym.dense(data, units=dshape[-1]*2, use_bias=True) + left, right = sym.split(fc1, indices_or_sections=2, axis=1) + return sym.Group(((left + 1), (right - 1))) + + +def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10): + if isinstance(graph, sym.Symbol): + graph = nnvm.graph.create(graph) + ishapes, _ = graph_util.infer_shape(graph, **input_shapes) + param_shapes = dict(zip(graph.index.input_names, ishapes)) + np.random.seed(seed) + params = {} + for param, shape in param_shapes.items(): + if param in {'data', 'label'} or not shape: + continue + init_value = np.empty(shape).astype('float32') + initializer(param, init_value) + params[param] = tvm.nd.array(init_value) + return params + +def main(): + dshape = (32, 16) + net = _get_model(dshape) + ishape_dict = {'data': dshape} + params = _init_params(net, ishape_dict) + graph, lib, params = nnvm.compiler.build(net, 'llvm', + shape=ishape_dict, + params=params, + dtype='float32') + + with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet: + 
f_resnet.write(graph.json()) + with open(osp.join(CWD, 'graph.params'), 'wb') as f_params: + f_params.write(nnvm.compiler.save_param_dict(params)) + +if __name__ == '__main__': + main() diff --git a/rust/runtime/tests/test_graph_serde.rs b/rust/runtime/tests/test_graph_serde.rs new file mode 100644 index 000000000..18ac19a79 --- /dev/null +++ b/rust/runtime/tests/test_graph_serde.rs @@ -0,0 +1,39 @@ +#![feature(try_from)] + +extern crate serde; +extern crate serde_json; + +extern crate tvm_runtime; + +use std::{convert::TryFrom, fs, io::Read}; + +use tvm_runtime::Graph; + +#[test] +fn test_load_graph() { + let mut params_bytes = Vec::new(); + fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params")) + .expect("Could not find TVM graph. Did you run `tests/build_model.py`?") + .read_to_end(&mut params_bytes) + .unwrap(); + let _params = tvm_runtime::load_param_dict(¶ms_bytes); + + let graph = Graph::try_from( + &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(), + ) + .unwrap(); + + assert_eq!(graph.nodes[3].op, "tvm_op"); + assert_eq!( + graph.nodes[3] + .attrs + .as_ref() + .unwrap() + .get("func_name") + .unwrap(), + "fuse_dense" + ); + assert_eq!(graph.nodes[5].inputs[0].index, 0); + assert_eq!(graph.nodes[6].inputs[0].index, 1); + assert_eq!(graph.heads.len(), 2); +} diff --git a/rust/runtime/tests/test_nnvm/Cargo.toml b/rust/runtime/tests/test_nnvm/Cargo.toml new file mode 100644 index 000000000..14d0b3961 --- /dev/null +++ b/rust/runtime/tests/test_nnvm/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "test-nnvm" +version = "0.0.0" +license = "Apache-2.0" +authors = ["TVM Contributors"] + +[dependencies] +ndarray = "0.11.2" +serde = "1.0.59" +serde_json = "1.0.17" +tvm-runtime = { path = "../../" } + +[build-dependencies] +ar = "0.6.0" diff --git a/rust/runtime/tests/test_nnvm/build.rs b/rust/runtime/tests/test_nnvm/build.rs new file mode 100644 index 000000000..3a4fc0a09 --- /dev/null +++ 
b/rust/runtime/tests/test_nnvm/build.rs @@ -0,0 +1,33 @@ +extern crate ar; + +use std::{env, fs::File, path::Path, process::Command}; + +use ar::Builder; + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_graph.py" + )) + .arg(&out_dir) + .output() + .expect("Failed to execute command"); + assert!( + Path::new(&format!("{}/graph.o", out_dir)).exists(), + "Could not build graph lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + let mut builder = Builder::new(File::create(format!("{}/libgraph.a", out_dir)).unwrap()); + builder.append_path(format!("{}/graph.o", out_dir)).unwrap(); + + println!("cargo:rustc-link-lib=static=graph"); + println!("cargo:rustc-link-search=native={}", out_dir); +} diff --git a/rust/runtime/tests/test_nnvm/src/build_test_graph.py b/rust/runtime/tests/test_nnvm/src/build_test_graph.py new file mode 100755 index 000000000..e9f74ecca --- /dev/null +++ b/rust/runtime/tests/test_nnvm/src/build_test_graph.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 + +"""Builds a simple NNVM graph for testing.""" + +from os import path as osp +import sys + +import nnvm +from nnvm import sym +from nnvm.compiler import graph_util +from nnvm.testing import init +import numpy as np +import tvm + + +def _get_model(dshape): + data = sym.Variable('data', shape=dshape) + fc = sym.dense(data, units=dshape[-1]*2, use_bias=True) + left, right = sym.split(fc, indices_or_sections=2, axis=1) + return sym.Group(((left + 1), (right - 1), fc)) + + +def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10): + if isinstance(graph, sym.Symbol): + graph = nnvm.graph.create(graph) + + ishapes, _ = graph_util.infer_shape(graph, **input_shapes) + param_shapes = dict(zip(graph.index.input_names, ishapes)) + np.random.seed(seed) + params = {} + for param, shape in param_shapes.items(): + if param in {'data', 
'label'} or not shape: + continue + + init_value = np.arange(np.product(shape), 0, -1).reshape(*shape).astype('float32') + if param.endswith('_bias'): + params[param] = tvm.nd.array(init_value) + continue + + init_value = np.empty(shape).astype('float32') + initializer(param, init_value) + # init_value /= init_value.sum() + 1e-10 + params[param] = tvm.nd.array(init_value) + + return params + +def main(): + dshape = (4, 8) + net = _get_model(dshape) + ishape_dict = {'data': dshape} + params = _init_params(net, ishape_dict) + graph, lib, params = nnvm.compiler.build(net, 'llvm --system-lib', + shape=ishape_dict, + params=params, + dtype='float32') + + out_dir = sys.argv[1] + lib.save(osp.join(sys.argv[1], 'graph.o')) + with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet: + f_resnet.write(graph.json()) + + with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params: + f_params.write(nnvm.compiler.save_param_dict(params)) + +if __name__ == '__main__': + main() diff --git a/rust/runtime/tests/test_nnvm/src/main.rs b/rust/runtime/tests/test_nnvm/src/main.rs new file mode 100644 index 000000000..50179798c --- /dev/null +++ b/rust/runtime/tests/test_nnvm/src/main.rs @@ -0,0 +1,82 @@ +#![feature(try_from)] + +#[macro_use] +extern crate ndarray; +extern crate serde; +extern crate serde_json; + +extern crate tvm_runtime; +use std::{collections::HashMap, convert::TryFrom, fs, io::Read}; + +use ndarray::Array; +use tvm_runtime::{Graph, GraphExecutor, SystemLibModule, Tensor}; + +const BATCH_SIZE: usize = 4; +const IN_DIM: usize = 8; + +macro_rules! 
check_sum { + ($e:expr, $a:ident, $b:ident) => { + let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap(); + check_sum!(a, $b); + }; + ($e:expr, $a:expr, $b:ident) => { + let a = Array::try_from($e.get_output($a).unwrap()).unwrap(); + check_sum!(a, $b); + }; + ($a:ident, $b:ident) => { + let a_sum: f32 = $a.scalar_sum(); + let b_sum: f32 = $b.scalar_sum(); + assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum); + }; +} + +fn main() { + let syslib = SystemLibModule::default(); + + let mut params_bytes = Vec::new(); + fs::File::open(concat!(env!("OUT_DIR"), "/graph.params")) + .unwrap() + .read_to_end(&mut params_bytes) + .unwrap(); + let params = tvm_runtime::load_param_dict(¶ms_bytes) + .unwrap() + .into_iter() + .map(|(k, v)| (k, v.to_owned())) + .collect::>>(); + + let graph = + Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()) + .unwrap(); + let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); + + let x = Array::from_shape_vec( + (BATCH_SIZE, IN_DIM), + (0..BATCH_SIZE * IN_DIM) + .map(|x| x as f32) + .collect::>(), + ) + .unwrap(); + let w = Array::try_from(params.get("dense0_weight").unwrap()) + .unwrap() + .into_shape((IN_DIM * 2, IN_DIM)) + .unwrap(); + let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap(); + let dense = x.dot(&w.t()) + &b; + let left = dense.slice(s![.., 0..IN_DIM]); + let right = dense.slice(s![.., IN_DIM..]); + let expected_o0 = &left + 1f32; + let expected_o1 = &right - 1f32; + + exec.load_params(params); + exec.set_input("data", (&x).into()); + + check_sum!(exec, data, x); + check_sum!(exec, dense0_weight, w); + check_sum!(exec, dense0_bias, b); + + exec.run(); + + check_sum!(exec, 0, expected_o0); + check_sum!(exec, 1, expected_o1); + check_sum!(exec, 2, dense); +} diff --git a/rust/runtime/tests/test_tvm_basic/Cargo.toml b/rust/runtime/tests/test_tvm_basic/Cargo.toml new file mode 100644 index 000000000..2a753b430 --- /dev/null +++ 
b/rust/runtime/tests/test_tvm_basic/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "test-tvm-basic" +version = "0.0.0" +license = "Apache-2.0" +authors = ["TVM Contributors"] + +[dependencies] +ndarray = "0.11.2" +tvm-runtime = { path = "../../" } + +[build-dependencies] +ar = "0.6.0" diff --git a/rust/runtime/tests/test_tvm_basic/build.rs b/rust/runtime/tests/test_tvm_basic/build.rs new file mode 100644 index 000000000..d8775857d --- /dev/null +++ b/rust/runtime/tests/test_tvm_basic/build.rs @@ -0,0 +1,34 @@ +extern crate ar; + +use std::{env, path::Path, process::Command}; + +use ar::Builder; +use std::fs::File; + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_lib.py" + )) + .arg(&out_dir) + .output() + .expect("Failed to execute command"); + assert!( + Path::new(&format!("{}/test.o", out_dir)).exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + let mut builder = Builder::new(File::create(format!("{}/libtest.a", out_dir)).unwrap()); + builder.append_path(format!("{}/test.o", out_dir)).unwrap(); + + println!("cargo:rustc-link-lib=static=test"); + println!("cargo:rustc-link-search=native={}", out_dir); +} diff --git a/rust/runtime/tests/test_tvm_basic/src/build_test_lib.py b/rust/runtime/tests/test_tvm_basic/src/build_test_lib.py new file mode 100755 index 000000000..7289a778f --- /dev/null +++ b/rust/runtime/tests/test_tvm_basic/src/build_test_lib.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +"""Prepares a simple TVM library for testing.""" + +from os import path as osp +import sys + +import tvm + +def main(): + n = tvm.var('n') + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.create_schedule(C.op) + s[C].parallel(s[C].op.axis[0]) + print(tvm.lower(s, [A, B, C], 
simple_mode=True)) + tvm.build(s, [A, B, C], 'llvm --system-lib').save(osp.join(sys.argv[1], 'test.o')) + +if __name__ == '__main__': + main() diff --git a/rust/runtime/tests/test_tvm_basic/src/main.rs b/rust/runtime/tests/test_tvm_basic/src/main.rs new file mode 100644 index 000000000..f14fbec8c --- /dev/null +++ b/rust/runtime/tests/test_tvm_basic/src/main.rs @@ -0,0 +1,22 @@ +extern crate ndarray; +#[macro_use] +extern crate tvm_runtime; + +use ndarray::Array; +use tvm_runtime::{DLTensor, Module, SystemLibModule}; + +fn main() { + let syslib = SystemLibModule::default(); + let add = syslib + .get_function("default_function") + .expect("main function not found"); + let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); + let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); + let mut c = Array::from_vec(vec![0f32; 4]); + let e = Array::from_vec(vec![2f32, 2., 4., 4.]); + let mut a_dl: DLTensor = (&mut a).into(); + let mut b_dl: DLTensor = (&mut b).into(); + let mut c_dl: DLTensor = (&mut c).into(); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl); + assert!(c.all_close(&e, 1e-8f32)); +} diff --git a/rust/src/errors.rs b/rust/src/errors.rs deleted file mode 100644 index f9da7180b..000000000 --- a/rust/src/errors.rs +++ /dev/null @@ -1,39 +0,0 @@ -#[cfg(target_env = "sgx")] -use alloc::alloc; -#[cfg(not(target_env = "sgx"))] -use std::alloc; -use std::num; - -use ndarray; -use serde_json; - -error_chain! 
{ - errors { - TryFromTVMRetValueError(expected: String, actual: i64) { - description("mismatched types while downcasting TVMRetValue") - display("invalid downcast: expected `{}` but was `{}`", expected, actual) - } - - GraphFormatError(msg: String) { - description("unable to load graph") - display("could not load graph json: {}", msg) - } - - LoadGraphParamsError(msg: String) { - description("unable to load graph params") - display("could not load graph params: {}", msg) - } - } - foreign_links { - Alloc(alloc::AllocErr); - GraphDeserialize(serde_json::Error); - ParseInt(num::ParseIntError); - ShapeError(ndarray::ShapeError); - } -} - -impl From for Error { - fn from(_err: alloc::LayoutErr) -> Error { - Error::from_kind(ErrorKind::Msg("Layout error".to_string())) - } -} diff --git a/rust/src/lib.rs b/rust/src/lib.rs deleted file mode 100644 index e17c66911..000000000 --- a/rust/src/lib.rs +++ /dev/null @@ -1,67 +0,0 @@ -//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`. -//! It's mainly useful for compiling to WebAssembly and SGX, -//! but also native if you prefer Rust to C++. -//! -//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`. -//! Single-function modules are used via the `packed_func!` macro after obtaining -//! the function from `runtime::SystemLibModule` -//! -//! The main entrypoints to this crate are `GraphExecutor` -//! For examples of use, please refer to the multi-file tests in the `tests` directory. 
- -#![feature( - alloc, - allocator_api, - box_syntax, - fn_traits, - try_from, - unboxed_closures, - vec_remove_item -)] - -#[cfg(target_env = "sgx")] -extern crate alloc; -extern crate bounded_spsc_queue; -#[cfg(target_env = "sgx")] -extern crate core; -#[macro_use] -extern crate error_chain; -#[macro_use] -extern crate itertools; -#[macro_use] -extern crate lazy_static; -extern crate ndarray; -#[macro_use] -extern crate nom; -#[cfg(not(target_env = "sgx"))] -extern crate num_cpus; -extern crate serde; -#[macro_use] -extern crate serde_derive; -extern crate serde_json; - -pub mod ffi { - #![allow( - non_camel_case_types, - non_snake_case, - non_upper_case_globals, - unused - )] - - pub mod runtime { - use std::os::raw::{c_char, c_int, c_void}; - - include!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/runtime/c_runtime_api.rs" - )); - - pub type BackendPackedCFunc = - extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; - } -} - -pub mod errors; -pub mod runtime; - -pub use errors::*; diff --git a/rust/src/runtime/allocator.rs b/rust/src/runtime/allocator.rs deleted file mode 100644 index d704336bf..000000000 --- a/rust/src/runtime/allocator.rs +++ /dev/null @@ -1,52 +0,0 @@ -#[cfg(target_env = "sgx")] -use alloc::alloc::{self, Layout}; -#[cfg(not(target_env = "sgx"))] -use std::alloc::{self, Layout}; - -use errors::*; - -const DEFAULT_ALIGN_BYTES: usize = 4; - -#[derive(PartialEq, Eq)] -pub struct Allocation { - layout: Layout, - ptr: *mut u8, -} - -impl Allocation { - /// Allocates a chunk of memory of `size` bytes with optional alignment. 
- pub fn new(size: usize, align: Option) -> Result { - let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); - let layout = Layout::from_size_align(size, alignment)?; - let ptr = unsafe { alloc::alloc(layout.clone()) }; - if ptr.is_null() { - alloc::handle_alloc_error(layout); - } - Ok(Self { - ptr: ptr, - layout: layout, - }) - } - - pub fn as_mut_ptr(&self) -> *mut u8 { - self.ptr - } - - /// Returns the size of the Allocation in bytes. - pub fn size(&self) -> usize { - self.layout.size() - } - - /// Returns the byte alignment of the Allocation. - pub fn align(&self) -> usize { - self.layout.align() - } -} - -impl Drop for Allocation { - fn drop(&mut self) { - unsafe { - alloc::dealloc(self.ptr, self.layout.clone()); - } - } -} diff --git a/rust/src/runtime/array.rs b/rust/src/runtime/array.rs deleted file mode 100644 index 100258d9a..000000000 --- a/rust/src/runtime/array.rs +++ /dev/null @@ -1,500 +0,0 @@ -use std::{ - any::TypeId, - convert::TryFrom, - mem, - os::raw::{c_int, c_void}, - ptr, slice, -}; - -use ndarray; - -use super::allocator::Allocation; -use errors::*; -use ffi::runtime::{ - DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, - DLDeviceType_kDLCPU, DLTensor, -}; - -/// A `Storage` is a container which holds `Tensor` data. -#[derive(PartialEq)] -pub enum Storage<'a> { - /// A `Storage` which owns its contained bytes. - Owned(Allocation), - - /// A view of an existing `Storage`. 
- View(&'a mut [u8], usize), // ptr, align -} - -impl<'a> Storage<'a> { - pub fn new(size: usize, align: Option) -> Result> { - Ok(Storage::Owned(Allocation::new(size, align)?)) - } - - pub fn as_mut_ptr(&self) -> *mut u8 { - match self { - Storage::Owned(alloc) => alloc.as_mut_ptr(), - Storage::View(slice, _) => slice.as_ptr() as *mut u8, - } - } - - pub fn size(&self) -> usize { - match self { - Storage::Owned(alloc) => alloc.size(), - Storage::View(slice, _) => slice.len(), - } - } - - pub fn align(&self) -> usize { - match self { - Storage::Owned(alloc) => alloc.align(), - Storage::View(_, align) => *align, - } - } - - pub fn as_ptr(&self) -> *const u8 { - self.as_mut_ptr() as *const _ - } - - /// Returns a `Storage::View` which points to an owned `Storage::Owned`. - pub fn view(&self) -> Storage<'a> { - match self { - Storage::Owned(alloc) => Storage::View( - unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) }, - self.align(), - ), - Storage::View(slice, _) => Storage::View( - unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) }, - self.align(), - ), - } - } - - pub fn is_owned(&self) -> bool { - match self { - Storage::Owned(_) => true, - _ => false, - } - } - - /// Returns an owned version of this storage via cloning. - pub fn to_owned(&self) -> Storage<'static> { - let s = Storage::new(self.size(), Some(self.align())).unwrap(); - unsafe { - s.as_mut_ptr() - .copy_from_nonoverlapping(self.as_ptr(), self.size()) - } - s - } -} - -impl<'a, T> From<&'a [T]> for Storage<'a> { - fn from(data: &'a [T]) -> Self { - let data = unsafe { - slice::from_raw_parts_mut( - data.as_ptr() as *const u8 as *mut u8, - data.len() * mem::size_of::() as usize, - ) - }; - Storage::View(data, mem::align_of::()) - } -} - -/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`. 
-/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or -/// converted to `ndarray::Array` for non-TVM processing. -/// -/// # Examples -/// -/// ``` -/// extern crate ndarray; -/// -/// let mut a_nd: ndarray::Array = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]); -/// let mut a: Tensor = a_nd.into(); -/// let mut a_dl: DLTensor = (&mut t).into(); -/// call_packed!(tvm_fn, &mut a_dl); -/// -/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs. -/// let mut a_nd = ndarray::Array::try_from(&a).unwrap(); -/// ``` -#[derive(PartialEq)] -pub struct Tensor<'a> { - /// The bytes which contain the data this `Tensor` represents. - pub(super) data: Storage<'a>, - pub(super) ctx: TVMContext, - pub(super) dtype: DataType, - pub(super) shape: Vec, // not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h - /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous. - pub(super) strides: Option>, - pub(super) byte_offset: isize, - /// The number of elements in the `Tensor`. - pub(super) size: usize, -} - -unsafe impl<'a> Send for Tensor<'a> {} - -impl<'a> Tensor<'a> { - pub fn shape(&self) -> Vec { - self.shape.clone() - } - - /// Returns the data of this `Tensor` as a `Vec`. - /// - /// # Panics - /// - /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`. - pub fn to_vec(&self) -> Vec { - assert!(self.is_contiguous()); - assert!(self.dtype.is_type::()); - let mut vec: Vec = Vec::with_capacity(self.size * self.dtype.itemsize()); - unsafe { - vec.as_mut_ptr().copy_from_nonoverlapping( - self.data.as_ptr().offset(self.byte_offset) as *const T, - self.size, - ); - vec.set_len(self.size); - } - vec - } - - /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory. 
- pub fn is_contiguous(&self) -> bool { - match self.strides { - None => true, - Some(ref strides) => { - // check that stride for each dimension is the product of all trailing dimensons' shapes - self - .shape - .iter() - .zip(strides) - .rfold( - (true, 1), - |(is_contig, expected_stride), (shape, stride)| { - ( - is_contig && *stride == expected_stride, - expected_stride * (*shape as usize), - ) - }, - ) - .0 - } - } - } - - /// Returns a clone of this `Tensor`. - /// - /// # Panics - /// - /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`. - pub fn copy(&mut self, other: &Tensor) { - assert!( - self.dtype == other.dtype && self.size == other.size, - "Tensor shape/dtype mismatch." - ); - assert!( - self.is_contiguous() && other.is_contiguous(), - "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`", - self.strides, - other.strides - ); - unsafe { - self - .data - .as_mut_ptr() - .offset(self.byte_offset as isize) - .copy_from_nonoverlapping( - other.data.as_mut_ptr().offset(other.byte_offset), - other.size * other.dtype.itemsize(), - ); - } - } - - /// Returns an owned version of this `Tensor` via cloning. 
- pub fn to_owned(&self) -> Tensor<'static> { - let t = Tensor { - data: self.data.to_owned(), - ctx: self.ctx.clone(), - dtype: self.dtype.clone(), - size: self.size.clone(), - shape: self.shape.clone(), - strides: None, - byte_offset: 0, - }; - unsafe { mem::transmute::, Tensor<'static>>(t) } - } - - fn from_array_storage<'s, T, D: ndarray::Dimension>( - arr: &ndarray::Array, - storage: Storage<'s>, - type_code: usize, - ) -> Tensor<'s> { - let type_width = mem::size_of::() as usize; - Tensor { - data: storage, - ctx: TVMContext::default(), - dtype: DataType { - code: type_code, - bits: 8 * type_width, - lanes: 1, - }, - size: arr.len(), - shape: arr.shape().iter().map(|&v| v as i64).collect(), - strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()), - byte_offset: 0, - } - } -} - -/// Conversions to `ndarray::Array` from `Tensor`, if the types match. -macro_rules! impl_ndarray_try_from_tensor { - ($type:ty, $dtype:expr) => { - impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> { - type Error = Error; - fn try_from(tensor: &'a Tensor) -> Result> { - ensure!( - tensor.dtype == $dtype, - "Cannot convert Tensor with dtype {:?} to ndarray", - tensor.dtype - ); - Ok(ndarray::Array::from_shape_vec( - tensor - .shape - .iter() - .map(|s| *s as usize) - .collect::>(), - tensor.to_vec::<$type>(), - )?) 
- } - } - }; -} - -impl_ndarray_try_from_tensor!(i32, DTYPE_INT32); -impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32); -impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32); -impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64); - -impl DLTensor { - pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self { - assert!(!flatten || tensor.is_contiguous()); - Self { - data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void, - ctx: DLContext::from(&tensor.ctx), - ndim: if flatten { 1 } else { tensor.shape.len() } as i32, - dtype: DLDataType::from(&tensor.dtype), - shape: if flatten { - &tensor.size as *const _ as *mut i64 - } else { - tensor.shape.as_ptr() - } as *mut i64, - strides: if flatten || tensor.is_contiguous() { - ptr::null_mut() - } else { - tensor.strides.as_ref().unwrap().as_ptr() - } as *mut i64, - byte_offset: 0, - } - } -} - -impl<'a, 't> From<&'a Tensor<'t>> for DLTensor { - fn from(tensor: &'a Tensor<'t>) -> Self { - DLTensor::from_tensor(tensor, false /* flatten */) - } -} - -impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor { - fn from(tensor: &'a mut Tensor<'t>) -> Self { - DLTensor::from_tensor(tensor, false /* flatten */) - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct DataType { - pub(super) code: usize, - pub(super) bits: usize, - pub(super) lanes: usize, -} - -impl DataType { - /// Returns the number of bytes occupied by an element of this `DataType`. - pub fn itemsize(&self) -> usize { - (self.bits * self.lanes) >> 3 - } - - /// Returns whether this `DataType` represents primitive type `T`. 
- pub fn is_type(&self) -> bool { - if self.lanes != 1 { - return false; - } - let typ = TypeId::of::(); - (typ == TypeId::of::() && self.code == 0 && self.bits == 32) - || (typ == TypeId::of::() && self.code == 0 && self.bits == 64) - || (typ == TypeId::of::() && self.code == 1 && self.bits == 32) - || (typ == TypeId::of::() && self.code == 1 && self.bits == 64) - || (typ == TypeId::of::() && self.code == 2 && self.bits == 32) - || (typ == TypeId::of::() && self.code == 2 && self.bits == 64) - } -} - -impl<'a> From<&'a DataType> for DLDataType { - fn from(dtype: &'a DataType) -> Self { - Self { - code: dtype.code as u8, - bits: dtype.bits as u8, - lanes: dtype.lanes as u16, - } - } -} - -impl From for DataType { - fn from(dtype: DLDataType) -> Self { - Self { - code: dtype.code as usize, - bits: dtype.bits as usize, - lanes: dtype.lanes as usize, - } - } -} - -macro_rules! make_dtype_const { - ($name: ident, $code: ident, $bits: expr, $lanes: expr) => { - const $name: DataType = DataType { - code: $code as usize, - bits: $bits, - lanes: $lanes, - }; - }; -} - -make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1); -make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1); -// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1); -make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1); -make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1); - -impl Default for DLContext { - fn default() -> Self { - DLContext { - device_type: DLDeviceType_kDLCPU, - device_id: 0, - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq)] -pub struct TVMContext { - pub(super) device_type: usize, - pub(super) device_id: usize, -} - -impl<'a> From<&'a TVMContext> for DLContext { - fn from(ctx: &'a TVMContext) -> Self { - Self { - device_type: ctx.device_type as u32, - device_id: ctx.device_id as i32, - } - } -} - -impl Default for TVMContext { - fn default() -> Self { - Self { - device_type: DLDeviceType_kDLCPU as usize, - device_id: 0, - } - } 
-} - -impl<'a> From for Tensor<'a> { - fn from(dlt: DLTensor) -> Self { - unsafe { - let dtype = DataType::from(dlt.dtype); - let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec(); - let size = shape.iter().map(|v| *v as usize).product::() as usize; - let storage = Storage::from(slice::from_raw_parts( - dlt.data as *const u8, - dtype.itemsize() * size, - )); - Self { - data: storage, - ctx: TVMContext::default(), - dtype: dtype, - size: size, - shape: shape, - strides: if dlt.strides == ptr::null_mut() { - None - } else { - Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec()) - }, - byte_offset: dlt.byte_offset as isize, - } - } - } -} - -/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`. -/// -/// # Panics -/// -/// Panics if the ndarray is not contiguous. -macro_rules! impl_tensor_from_ndarray { - ($type:ty, $typecode:expr) => { - impl From> for Tensor<'static> { - fn from(arr: ndarray::Array<$type, D>) -> Self { - assert!(arr.is_standard_layout(), "Array must be contiguous."); - let size = arr.len() * mem::size_of::<$type>() as usize; - let storage = - Storage::from(unsafe { slice::from_raw_parts(arr.as_ptr() as *const u8, size) }); - Tensor::from_array_storage(&arr, storage, $typecode as usize) - } - } - impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> { - fn from(arr: &'a ndarray::Array<$type, D>) -> Self { - assert!(arr.is_standard_layout(), "Array must be contiguous."); - Tensor::from_array_storage( - arr, - Storage::from(arr.as_slice().unwrap()), - $typecode as usize, - ) - } - } - }; -} - -/// `From` conversions to `DLTensor` for `ndarray::Array`. -/// Takes a reference to the `ndarray` since `DLTensor` is not owned. -macro_rules! 
impl_dltensor_from_ndarray { - ($type:ty, $typecode:expr) => { - impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { - fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { - DLTensor { - data: arr.as_mut_ptr() as *mut c_void, - ctx: DLContext::default(), - ndim: arr.ndim() as c_int, - dtype: DLDataType { - code: $typecode as u8, - bits: 8 * mem::size_of::<$type>() as u8, - lanes: 1, - }, - shape: arr.shape().as_ptr() as *const i64 as *mut i64, - strides: arr.strides().as_ptr() as *const isize as *mut i64, - byte_offset: 0, - } - } - } - }; -} - -impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); -impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); -impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); -impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); -impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); -impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); - -impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); -impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); -impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); -impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); -impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); -impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); diff --git a/rust/src/runtime/c_runtime_api.rs b/rust/src/runtime/c_runtime_api.rs deleted file mode 100644 index 6facf9ca2..000000000 --- a/rust/src/runtime/c_runtime_api.rs +++ /dev/null @@ -1,770 +0,0 @@ -/* automatically generated by rust-bindgen for TVM revision 6292c78 */ - -pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0"; -pub const DLPACK_VERSION: u32 = 8; -pub const _STDINT_H: u32 = 1; -pub const _FEATURES_H: u32 = 1; -pub const _DEFAULT_SOURCE: u32 = 1; -pub const __USE_ISOC11: u32 = 1; -pub const __USE_ISOC99: u32 = 1; -pub const __USE_ISOC95: u32 = 1; -pub const __USE_POSIX_IMPLICITLY: u32 = 1; -pub const _POSIX_SOURCE: u32 = 1; -pub const _POSIX_C_SOURCE: u32 = 200809; -pub 
const __USE_POSIX: u32 = 1; -pub const __USE_POSIX2: u32 = 1; -pub const __USE_POSIX199309: u32 = 1; -pub const __USE_POSIX199506: u32 = 1; -pub const __USE_XOPEN2K: u32 = 1; -pub const __USE_XOPEN2K8: u32 = 1; -pub const _ATFILE_SOURCE: u32 = 1; -pub const __USE_MISC: u32 = 1; -pub const __USE_ATFILE: u32 = 1; -pub const __USE_FORTIFY_LEVEL: u32 = 0; -pub const _STDC_PREDEF_H: u32 = 1; -pub const __STDC_IEC_559__: u32 = 1; -pub const __STDC_IEC_559_COMPLEX__: u32 = 1; -pub const __STDC_ISO_10646__: u32 = 201505; -pub const __STDC_NO_THREADS__: u32 = 1; -pub const __GNU_LIBRARY__: u32 = 6; -pub const __GLIBC__: u32 = 2; -pub const __GLIBC_MINOR__: u32 = 23; -pub const _SYS_CDEFS_H: u32 = 1; -pub const __WORDSIZE: u32 = 64; -pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; -pub const __SYSCALL_WORDSIZE: u32 = 64; -pub const _BITS_WCHAR_H: u32 = 1; -pub const INT8_MIN: i32 = -128; -pub const INT16_MIN: i32 = -32768; -pub const INT32_MIN: i32 = -2147483648; -pub const INT8_MAX: u32 = 127; -pub const INT16_MAX: u32 = 32767; -pub const INT32_MAX: u32 = 2147483647; -pub const UINT8_MAX: u32 = 255; -pub const UINT16_MAX: u32 = 65535; -pub const UINT32_MAX: u32 = 4294967295; -pub const INT_LEAST8_MIN: i32 = -128; -pub const INT_LEAST16_MIN: i32 = -32768; -pub const INT_LEAST32_MIN: i32 = -2147483648; -pub const INT_LEAST8_MAX: u32 = 127; -pub const INT_LEAST16_MAX: u32 = 32767; -pub const INT_LEAST32_MAX: u32 = 2147483647; -pub const UINT_LEAST8_MAX: u32 = 255; -pub const UINT_LEAST16_MAX: u32 = 65535; -pub const UINT_LEAST32_MAX: u32 = 4294967295; -pub const INT_FAST8_MIN: i32 = -128; -pub const INT_FAST16_MIN: i64 = -9223372036854775808; -pub const INT_FAST32_MIN: i64 = -9223372036854775808; -pub const INT_FAST8_MAX: u32 = 127; -pub const INT_FAST16_MAX: u64 = 9223372036854775807; -pub const INT_FAST32_MAX: u64 = 9223372036854775807; -pub const UINT_FAST8_MAX: u32 = 255; -pub const UINT_FAST16_MAX: i32 = -1; -pub const UINT_FAST32_MAX: i32 = -1; -pub const INTPTR_MIN: i64 
= -9223372036854775808; -pub const INTPTR_MAX: u64 = 9223372036854775807; -pub const UINTPTR_MAX: i32 = -1; -pub const PTRDIFF_MIN: i64 = -9223372036854775808; -pub const PTRDIFF_MAX: u64 = 9223372036854775807; -pub const SIG_ATOMIC_MIN: i32 = -2147483648; -pub const SIG_ATOMIC_MAX: u32 = 2147483647; -pub const SIZE_MAX: i32 = -1; -pub const WINT_MIN: u32 = 0; -pub const WINT_MAX: u32 = 4294967295; -pub type int_least8_t = ::std::os::raw::c_schar; -pub type int_least16_t = ::std::os::raw::c_short; -pub type int_least32_t = ::std::os::raw::c_int; -pub type int_least64_t = ::std::os::raw::c_long; -pub type uint_least8_t = ::std::os::raw::c_uchar; -pub type uint_least16_t = ::std::os::raw::c_ushort; -pub type uint_least32_t = ::std::os::raw::c_uint; -pub type uint_least64_t = ::std::os::raw::c_ulong; -pub type int_fast8_t = ::std::os::raw::c_schar; -pub type int_fast16_t = ::std::os::raw::c_long; -pub type int_fast32_t = ::std::os::raw::c_long; -pub type int_fast64_t = ::std::os::raw::c_long; -pub type uint_fast8_t = ::std::os::raw::c_uchar; -pub type uint_fast16_t = ::std::os::raw::c_ulong; -pub type uint_fast32_t = ::std::os::raw::c_ulong; -pub type uint_fast64_t = ::std::os::raw::c_ulong; -pub type intmax_t = ::std::os::raw::c_long; -pub type uintmax_t = ::std::os::raw::c_ulong; -pub type wchar_t = ::std::os::raw::c_int; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct max_align_t { - pub __clang_max_align_nonce1: ::std::os::raw::c_longlong, - pub __bindgen_padding_0: u64, - pub __clang_max_align_nonce2: f64, -} -pub const DLDeviceType_kDLCPU: DLDeviceType = 1; -pub const DLDeviceType_kDLGPU: DLDeviceType = 2; -pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3; -pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4; -pub const DLDeviceType_kDLMetal: DLDeviceType = 8; -pub const DLDeviceType_kDLVPI: DLDeviceType = 9; -pub const DLDeviceType_kDLROCM: DLDeviceType = 10; -/// \brief The device type in DLContext. 
-pub type DLDeviceType = u32; -/// \brief A Device context for Tensor and operator. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct DLContext { - /// \brief The device type used in the device. - pub device_type: DLDeviceType, - /// \brief The device index - pub device_id: ::std::os::raw::c_int, -} -pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0; -pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1; -pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2; -/// \brief The type code options DLDataType. -pub type DLDataTypeCode = u32; -/// \brief The data type the tensor can hold. -/// -/// Examples -/// - float: type_code = 2, bits = 32, lanes=1 -/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 -/// - int8: type_code = 0, bits = 8, lanes=1 -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct DLDataType { - /// \brief Type code of base types. - /// We keep it uint8_t instead of DLDataTypeCode for minimal memory - /// footprint, but the value should be one of DLDataTypeCode enum values. - /// - pub code: u8, - /// \brief Number of bits, common choices are 8, 16, 32. - pub bits: u8, - /// \brief Number of lanes in the type, used for vector types. - pub lanes: u16, -} -/// \brief Plain C Tensor object, does not manage memory. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct DLTensor { - /// \brief The opaque data pointer points to the allocated data. - /// This will be CUDA device pointer or cl_mem handle in OpenCL. - /// This pointer is always aligns to 256 bytes as in CUDA. - pub data: *mut ::std::os::raw::c_void, - /// \brief The device context of the tensor - pub ctx: DLContext, - /// \brief Number of dimensions - pub ndim: ::std::os::raw::c_int, - /// \brief The data type of the pointer - pub dtype: DLDataType, - /// \brief The shape of the tensor - pub shape: *mut i64, - /// \brief strides of the tensor, - /// can be NULL, indicating tensor is compact. 
- pub strides: *mut i64, - /// \brief The offset in bytes to the beginning pointer to data - pub byte_offset: u64, -} -/// \brief C Tensor object, manage memory of DLTensor. This data structure is -/// intended to faciliate the borrowing of DLTensor by another framework. It is -/// not meant to transfer the tensor. When the borrowing framework doesn't need -/// the tensor, it should call the deleter to notify the host that the resource -/// is no longer needed. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct DLManagedTensor { - /// \brief DLTensor which is being memory managed - pub dl_tensor: DLTensor, - /// \brief the context of the original host framework of DLManagedTensor in - /// which DLManagedTensor is used in the framework. It can also be NULL. - pub manager_ctx: *mut ::std::os::raw::c_void, - /// \brief Destructor signature void (*)(void*) - this should be called - /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL - /// if there is no way for the caller to provide a reasonable destructor. - pub deleter: ::std::option::Option, -} -/// \brief type of array index. 
-pub type tvm_index_t = i64; -pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5; -pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6; -pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7; -pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11; -pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12; -/// \brief Extension device types in TVM -pub type TVMDeviceExtType = u32; -pub const TVMTypeCode_kHandle: TVMTypeCode = 3; -pub const TVMTypeCode_kNull: TVMTypeCode = 4; -pub const TVMTypeCode_kTVMType: TVMTypeCode = 5; -pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6; -pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7; -pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8; -pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9; -pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10; -pub const TVMTypeCode_kStr: TVMTypeCode = 11; -pub const TVMTypeCode_kBytes: TVMTypeCode = 12; -pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13; -pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15; -pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16; -pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20; -pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64; -pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128; -/// \brief The type code in TVMType -/// \note TVMType is used in two places. -pub type TVMTypeCode = u32; -/// \brief The data type used in TVM Runtime. -/// -/// Examples -/// - float: type_code = 2, bits = 32, lanes=1 -/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 -/// - int8: type_code = 0, bits = 8, lanes=1 -/// -/// \note Arguments TVM API function always takes bits=64 and lanes=1 -pub type TVMType = DLDataType; -/// \brief The Device information, abstract away common device types. -pub type TVMContext = DLContext; -/// \brief The tensor array stucture to TVM API. 
-pub type TVMArray = DLTensor; -/// \brief the array handle -pub type TVMArrayHandle = *mut TVMArray; -/// \brief Union type of values -/// being passed through API and function calls. -#[repr(C)] -#[derive(Copy, Clone)] -pub union TVMValue { - pub v_int64: i64, - pub v_float64: f64, - pub v_handle: *mut ::std::os::raw::c_void, - pub v_str: *const ::std::os::raw::c_char, - pub v_type: TVMType, - pub v_ctx: TVMContext, - _bindgen_union_align: u64, -} -/// \brief Byte array type used to pass in byte array -/// When kBytes is used as data type. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct TVMByteArray { - pub data: *const ::std::os::raw::c_char, - pub size: usize, -} -/// \brief Handle to TVM runtime modules. -pub type TVMModuleHandle = *mut ::std::os::raw::c_void; -/// \brief Handle to packed function handle. -pub type TVMFunctionHandle = *mut ::std::os::raw::c_void; -/// \brief Handle to hold return value. -pub type TVMRetValueHandle = *mut ::std::os::raw::c_void; -/// \brief The stream that is specific to device -/// can be NULL, which indicates the default one. -pub type TVMStreamHandle = *mut ::std::os::raw::c_void; -extern "C" { - /// \brief Used for implementing C API function. - /// Set last error message before return. - /// \param msg The error message to be set. - pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char); -} -extern "C" { - /// \brief return str message of the last error - /// all function in this file will return 0 when success - /// and -1 when an error occured, - /// TVMGetLastError can be called to retrieve the error - /// - /// this function is threadsafe and can be called by different thread - /// \return error info - pub fn TVMGetLastError() -> *const ::std::os::raw::c_char; -} -extern "C" { - /// \brief Load module from file. - /// \param file_name The file name to load the module from. - /// \param format The format of the module. 
- /// \param out The result module - /// - /// \return 0 when success, -1 when failure happens - /// \note The resulting module do not contain import relation. - /// It can be reconstructed by TVMModImport. - pub fn TVMModLoadFromFile( - file_name: *const ::std::os::raw::c_char, - format: *const ::std::os::raw::c_char, - out: *mut TVMModuleHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Add dep to mod's dependency. - /// This allows functions in this module to use modules. - /// - /// \param mod The module handle. - /// \param dep The dependent module to be imported. - /// \return 0 when success, -1 when failure happens - pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Get function from the module. - /// \param mod The module handle. - /// \param func_name The name of the function. - /// \param query_imports Whether to query imported modules - /// \param out The result function, can be NULL if it is not available. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMModGetFunction( - mod_: TVMModuleHandle, - func_name: *const ::std::os::raw::c_char, - query_imports: ::std::os::raw::c_int, - out: *mut TVMFunctionHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free front-end extension type resource. - /// \param handle The extension handle. - /// \param type_code The type of of the extension type. - /// \return 0 when success, -1 when failure happens - pub fn TVMExtTypeFree( - handle: *mut ::std::os::raw::c_void, - type_code: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free the Module - /// \param mod The module to be freed. - /// - /// \note This may not free up the module's resources. - /// If there is active TVMFunctionHandle uses the module - /// Or if this module is imported by another active module. - /// - /// The all functions remains valid until TVMFuncFree is called. 
- /// \return 0 when success, -1 when failure happens - pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free the function when it is no longer needed. - /// \param func The function handle - /// \return 0 when success, -1 when failure happens - pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Call a Packed TVM Function. - /// - /// \param func node handle of the function. - /// \param arg_values The arguments - /// \param type_codes The type codes of the arguments - /// \param num_args Number of arguments. - /// - /// \param ret_val The return value. - /// \param ret_type_code the type code of return value. - /// - /// \return 0 when success, -1 when failure happens - /// \note TVM calls always exchanges with type bits=64, lanes=1 - /// - /// \note API calls always exchanges with type bits=64, lanes=1 - /// If API call returns container handles (e.g. FunctionHandle) - /// these handles should be managed by the front-end. - /// The front-end need to call free function (e.g. TVMFuncFree) - /// to free these handles. - pub fn TVMFuncCall( - func: TVMFunctionHandle, - arg_values: *mut TVMValue, - type_codes: *mut ::std::os::raw::c_int, - num_args: ::std::os::raw::c_int, - ret_val: *mut TVMValue, - ret_type_code: *mut ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Set the return value of TVMPackedCFunc. - /// - /// This function is called by TVMPackedCFunc to set the return value. - /// When this function is not called, the function returns null by default. - /// - /// \param ret The return value handle, pass by ret in TVMPackedCFunc - /// \param value The value to be returned. - /// \param type_code The type of the value to be returned. - /// \param num_ret Number of return values, for now only 1 is supported. 
- pub fn TVMCFuncSetReturn( - ret: TVMRetValueHandle, - value: *mut TVMValue, - type_code: *mut ::std::os::raw::c_int, - num_ret: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Inplace translate callback argument value to return value. - /// This is only needed for non-POD arguments. - /// - /// \param value The value to be translated. - /// \param code The type code to be translated. - /// \note This function will do a shallow copy when necessary. - /// - /// \return 0 when success, -1 when failure happens. - pub fn TVMCbArgToReturn( - value: *mut TVMValue, - code: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -/// \brief C type of packed function. -/// -/// \param args The arguments -/// \param type_codes The type codes of the arguments -/// \param num_args Number of arguments. -/// \param ret The return value handle. -/// \param resource_handle The handle additional resouce handle from fron-end. -/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. -/// \sa TVMCFuncSetReturn -pub type TVMPackedCFunc = ::std::option::Option< - unsafe extern "C" fn( - args: *mut TVMValue, - type_codes: *mut ::std::os::raw::c_int, - num_args: ::std::os::raw::c_int, - ret: TVMRetValueHandle, - resource_handle: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int, ->; -/// \brief C callback to free the resource handle in C packed function. -/// \param resource_handle The handle additional resouce handle from fron-end. -pub type TVMPackedCFuncFinalizer = - ::std::option::Option; -/// \brief Signature for extension function declarer. -/// -/// TVM call this function to get the extension functions -/// The declarer will call register_func to register function and their name. 
-/// -/// \param register_func_handle The register function -/// \return 0 if success, -1 if failure happens -pub type TVMExtensionFuncDeclarer = ::std::option::Option< - unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int, ->; -extern "C" { - /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle. - /// - /// The resource_handle will be managed by TVM API, until the function is no longer used. - /// - /// \param func The packed C function. - /// \param resource_handle The resource handle from front-end, can be NULL. - /// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL - /// \param out the result function handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMFuncCreateFromCFunc( - func: TVMPackedCFunc, - resource_handle: *mut ::std::os::raw::c_void, - fin: TVMPackedCFuncFinalizer, - out: *mut TVMFunctionHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Register the function to runtime's global table. - /// - /// The registered function then can be pulled by the backend by the name. - /// - /// \param name The name of the function. - /// \param f The function to be registered. - /// \param override Whether allow override already registered function. - pub fn TVMFuncRegisterGlobal( - name: *const ::std::os::raw::c_char, - f: TVMFunctionHandle, - override_: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Get a global function. - /// - /// \param name The name of the function. - /// \param out the result function pointer, NULL if it does not exist. - /// - /// \note The function handle of global function is managed by TVM runtime, - /// So TVMFuncFree is should not be called when it get deleted. 
- pub fn TVMFuncGetGlobal( - name: *const ::std::os::raw::c_char, - out: *mut TVMFunctionHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief List all the globally registered function name - /// \param out_size The number of functions - /// \param out_array The array of function names. - /// \return 0 when success, -1 when failure happens - pub fn TVMFuncListGlobalNames( - out_size: *mut ::std::os::raw::c_int, - out_array: *mut *mut *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Allocate a nd-array's memory, - /// including space of shape, of given spec. - /// - /// \param shape The shape of the array, the data content will be copied to out - /// \param ndim The number of dimension of the array. - /// \param dtype_code The type code of the dtype - /// \param dtype_bits The number of bits of dtype - /// \param dtype_lanes The number of lanes in the dtype. - /// \param device_type The device type of context - /// \param device_id The device id of context. - /// \param out The output handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayAlloc( - shape: *const tvm_index_t, - ndim: ::std::os::raw::c_int, - dtype_code: ::std::os::raw::c_int, - dtype_bits: ::std::os::raw::c_int, - dtype_lanes: ::std::os::raw::c_int, - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - out: *mut TVMArrayHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free the TVM Array. - /// \param handle The array handle to be freed. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Copy array data from CPU byte array. - /// \param handle The array handle. - /// \param data the data pointer - /// \param nbytes The number of bytes to copy. 
- /// \return 0 when success, -1 when failure happens - pub fn TVMArrayCopyFromBytes( - handle: TVMArrayHandle, - data: *mut ::std::os::raw::c_void, - nbytes: usize, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Copy array data to CPU byte array. - /// \param handle The array handle. - /// \param data the data pointer - /// \param nbytes The number of bytes to copy. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayCopyToBytes( - handle: TVMArrayHandle, - data: *mut ::std::os::raw::c_void, - nbytes: usize, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Copy the array, both from and to must be valid during the copy. - /// \param from The array to be copied from. - /// \param to The target space. - /// \param stream The stream where the copy happens, can be NULL. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayCopyFromTo( - from: TVMArrayHandle, - to: TVMArrayHandle, - stream: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Produce an array from the DLManagedTensor that shares data memory - /// with the DLManagedTensor. - /// \param from The source DLManagedTensor. - /// \param out The output array handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayFromDLPack( - from: *mut DLManagedTensor, - out: *mut TVMArrayHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Produce a DLMangedTensor from the array that shares data memory with - /// the array. - /// \param from The source array. - /// \param out The DLManagedTensor handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayToDLPack( - from: TVMArrayHandle, - out: *mut *mut DLManagedTensor, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Delete (free) a DLManagedTensor's data. - /// \param dltensor Pointer to the DLManagedTensor. 
- pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor); -} -extern "C" { - /// \brief Create a new runtime stream. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context - /// \param out The new stream handle - /// \return 0 when success, -1 when failure happens - pub fn TVMStreamCreate( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - out: *mut TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free a created stream handle. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context - /// \param stream The stream to be freed - /// \return 0 when success, -1 when failure happens - pub fn TVMStreamFree( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - stream: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Set the runtime stream of current thread to be stream. - /// The subsequent calls to the same device_type - /// will use the setted stream handle. - /// The specific type of stream is runtime device dependent. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context. - /// \param handle The stream handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMSetStream( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - handle: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Wait until all computations on stream completes. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context. - /// \param stream The stream to be synchronized. 
- /// \return 0 when success, -1 when failure happens - pub fn TVMSynchronize( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - stream: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Synchronize two streams of execution. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context - /// \param src The source stream to synchronize. - /// \param dst The destination stream to synchronize. - /// \return 0 when success, -1 when failure happens - pub fn TVMStreamStreamSynchronize( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - src: TVMStreamHandle, - dst: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Backend function for modules to get function - /// from its environment mod_node (its imports and global function). - /// The user do should not call TVMFuncFree on func. - /// - /// \param mod_node The module handle. - /// \param func_name The name of the function. - /// \param out The result function. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendGetFuncFromEnv( - mod_node: *mut ::std::os::raw::c_void, - func_name: *const ::std::os::raw::c_char, - out: *mut TVMFunctionHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Backend function to register system-wide library symbol. - /// - /// \param name The name of the symbol - /// \param ptr The symbol address. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendRegisterSystemLibSymbol( - name: *const ::std::os::raw::c_char, - ptr: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Backend function to allocate temporal workspace. - /// - /// \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment. - /// - /// \param nbytes The size of the space requested. 
- /// \param device_type The device type which the space will be allocated. - /// \param device_id The device id which the space will be allocated. - /// \param dtype_code_hint The type code of the array elements. Only used in - /// certain backends such as OpenGL. - /// \param dtype_bits_hint The type bits of the array elements. Only used in - /// certain backends such as OpenGL. - /// \return nullptr when error is thrown, a valid ptr if success - pub fn TVMBackendAllocWorkspace( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - nbytes: u64, - dtype_code_hint: ::std::os::raw::c_int, - dtype_bits_hint: ::std::os::raw::c_int, - ) -> *mut ::std::os::raw::c_void; -} -extern "C" { - /// \brief Backend function to free temporal workspace. - /// - /// \param ptr The result allocated space pointer. - /// \param device_type The device type which the space will be allocated. - /// \param device_id The device id which the space will be allocated. - /// \return 0 when no error is thrown, -1 when failure happens - /// - /// \sa TVMBackendAllocWorkspace - pub fn TVMBackendFreeWorkspace( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - ptr: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int; -} -/// \brief Environment for TVM parallel task. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct TVMParallelGroupEnv { - /// \brief Auxiliary used for synchronization - pub sync_handle: *mut ::std::os::raw::c_void, - /// \brief total amount of task - pub num_task: i32, -} -/// \brief The callback function to execute a parallel lambda -/// \param task_id the task id of the function. -/// \param penv The parallel environment backs the execution. -/// \param cdata The supporting closure data. 
-pub type FTVMParallelLambda = ::std::option::Option< - unsafe extern "C" fn( - task_id: ::std::os::raw::c_int, - penv: *mut TVMParallelGroupEnv, - cdata: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int, ->; -extern "C" { - /// \brief Backend function for running parallel jobs. - /// - /// \param flambda The parallel function to be launched. - /// \param cdata The closure data. - /// \param num_task Number of tasks to launch, can be 0, means launch - /// with all available threads. - /// - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendParallelLaunch( - flambda: FTVMParallelLambda, - cdata: *mut ::std::os::raw::c_void, - num_task: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief BSP barrrier between parallel threads - /// \param task_id the task id of the function. - /// \param penv The parallel environment backs the execution. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendParallelBarrier( - task_id: ::std::os::raw::c_int, - penv: *mut TVMParallelGroupEnv, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Simple static initialization function. - /// Run f once and set handle to be not null. - /// This function is mainly used for test purpose. - /// - /// \param handle An global address to indicate f - /// \param f The function to be ran - /// \param cdata The closure data to pass to the function. - /// \param nbytes Number of bytes in the closure data. 
- /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendRunOnce( - handle: *mut *mut ::std::os::raw::c_void, - f: ::std::option::Option< - unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int, - >, - cdata: *mut ::std::os::raw::c_void, - nbytes: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} diff --git a/rust/src/runtime/graph.rs b/rust/src/runtime/graph.rs deleted file mode 100644 index 08fbd5938..000000000 --- a/rust/src/runtime/graph.rs +++ /dev/null @@ -1,472 +0,0 @@ -use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str}; - -use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr}; -use serde; -use serde_json; - -use super::{DataType, Module, Storage, TVMArgValue, TVMContext, Tensor}; -use errors::{Error, ErrorKind, Result}; -use ffi::runtime::{ - DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor, -}; - -// Magic number for NDArray file. @see `kTVMNDArrayMagic` in `ndarray.h` -const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F; -// Magic number for NDArray list file. @see `kTVMNDArrayListMagic` in `graph_runtime.h` -const _NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7; - -/// A TVM computation graph. -/// -/// # Examples -/// -/// ``` -/// let graph_json = fs::read_to_string("graph.json")).unwrap(); -/// let graph = Graph::try_from(&graph_json).unwrap(); -/// ``` -#[derive(Serialize, Deserialize, Debug)] -pub struct Graph { - pub nodes: Vec, - pub arg_nodes: Vec, - pub heads: Vec, - pub node_row_ptr: Option>, - pub attrs: Option>, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct Entry { - pub id: usize, - pub index: usize, - pub version: usize, -} - -impl Graph { - fn entry_index(&self, entry: &Entry) -> Result { - self - .node_row_ptr - .as_ref() - .map(|nrp| nrp[entry.id] + entry.index) - .ok_or("Missing node_row_ptr.".into()) - } - - /// Attempt to deserialize a JSON attribute to a type `T`. 
- fn get_attr(&self, attr: &str) -> Result { - Ok(serde_json::from_value::( - self - .attrs - .as_ref() - .ok_or(ErrorKind::GraphFormatError( - "Missing graph attrs".to_string(), - ))? - .get(attr) - .ok_or(ErrorKind::GraphFormatError(format!( - "Missing {} attr", - attr - )))? - .to_owned(), - )?) - } -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct Node { - pub op: String, - pub name: String, - pub inputs: Vec, - pub attrs: Option>, - pub control_deps: Option>, -} - -struct NodeAttrs { - func_name: String, - num_outputs: usize, - flatten_data: bool, -} - -impl Node { - fn parse_attrs(&self) -> Result { - let attrs = self - .attrs - .as_ref() - .ok_or(format!("Missing node.attrs for `{}`", self.name))?; - let func_name = attrs - .get("func_name") - .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))? - .to_string(); - let num_outputs = attrs - .get("num_outputs") - .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))? - .parse::()?; - let flatten_data = attrs - .get("flatten_data") - .ok_or(format!( - "Node `{}` is missing attrs.flatten_data", - self.name - ))? - .parse::()? - == 1; - Ok(NodeAttrs { - func_name, - num_outputs, - flatten_data, - }) - } -} - -impl<'a> TryFrom<&'a String> for Graph { - type Error = Error; - fn try_from(graph_json: &String) -> Result { - let graph = serde_json::from_str(graph_json)?; - Ok(graph) - } -} - -impl<'a> TryFrom<&'a str> for Graph { - type Error = Error; - fn try_from(graph_json: &'a str) -> Result { - let graph = serde_json::from_str(graph_json)?; - Ok(graph) - } -} - -/// A executor for a TVM computation graph. 
-/// -/// # Examples -/// -/// ``` -/// use ndarray::Array; -/// -/// let syslib = SystemLibModule::default(); // a provider of TVM functions -/// -/// let mut params_bytes = Vec::new(); -/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap(); -/// let params = tvm::runtime::load_param_dict(¶ms_bytes).unwrap(); -/// -/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap(); -/// -/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); -/// exec.load_params(params); -/// -/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]); -/// exec.set_input("data", x.into()); -/// exec.run(); -/// let output = exec.get_output(0).unwrap(); -/// -/// println!("{:#?}", Array::try_from(output).unwrap()); -/// ``` -pub struct GraphExecutor<'m, 't> { - graph: Graph, - op_execs: Vec>, - tensors: Vec>, -} - -unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {} - -impl<'m, 't> GraphExecutor<'m, 't> { - pub fn new(graph: Graph, lib: &'m M) -> Result { - let tensors = Self::setup_storages(&graph)?; - Ok(GraphExecutor { - op_execs: Self::setup_op_execs(&graph, lib, &tensors)?, - tensors: tensors, - graph: graph, - }) - } - - /// Runs the computation graph. - pub fn run(&self) { - self.op_execs.iter().for_each(|op_exec| { - op_exec(); - }); - } - - /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output. - fn setup_storages<'a>(graph: &'a Graph) -> Result>> { - let storage_ids = graph.get_attr::<(String, Vec)>("storage_id")?.1; - let shapes = graph.get_attr::<(String, Vec>)>("shape")?.1; - let dtypes = graph - .get_attr::<(String, Vec)>("dltype")? 
- .1 - .iter() - .map(|dltype| { - if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) { - Ok(dtype) - } else { - Err(ErrorKind::GraphFormatError(format!("Invalid dltype: {}", dltype).to_string()).into()) - } - }) - .collect::>>()?; - - let align = dtypes.iter().map(|dtype| dtype.bits as usize).max(); - let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; - for (i, &storage_id) in storage_ids.iter().enumerate() { - let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3; - let nbytes = dtype_size * shapes[i].iter().product::() as usize; - storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]); - } - - let mut storages: Vec = storage_num_bytes - .into_iter() - .map(|nbytes| Storage::new(nbytes, align)) - .collect::>>()?; - - let tensors = izip!(storage_ids, shapes, dtypes) - .map(|(storage_id, shape, dtype)| { - let storage = storages[storage_id].view(); - Tensor { - data: mem::replace(&mut storages[storage_id], storage), - ctx: TVMContext::default(), - dtype: dtype, - size: shape.iter().product::() as usize, - shape: shape, - strides: None, - byte_offset: 0, - } - }) - .collect(); - - Ok(tensors) - } - - /// Creates closures which represent the computation performed by this graph. 
- fn setup_op_execs( - graph: &Graph, - lib: &'m M, - tensors: &Vec>, - ) -> Result>> { - ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr."); - let node_row_ptr = graph.node_row_ptr.as_ref().unwrap(); - - let mut op_execs = Vec::new(); - for (i, node) in graph.nodes.iter().enumerate() { - if node.op == "null" { - continue; - } - ensure!(node.op == "tvm_op", "Only TVM ops are supported."); - ensure!(node.attrs.is_some(), "Missing node attrs."); - - let attrs = node.parse_attrs()?; - - if attrs.func_name == "__nop" { - continue; - } - - let func = lib - .get_function(&attrs.func_name) - .ok_or(format!("Missing function {}", attrs.func_name))?; - let arg_indices = node - .inputs - .iter() - .map(|entry| graph.entry_index(entry)) - .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi))); - - let dl_tensors = arg_indices - .map(|idx| { - let tensor = &tensors[idx?]; - Ok(if attrs.flatten_data { - DLTensor::from_tensor(tensor, true /* flatten */) - } else { - DLTensor::from(tensor) - }) - }) - .collect::>>() - .unwrap(); - let op: Box = box move || { - let args = dl_tensors - .iter() - .map(|t| t.into()) - .collect::>(); - func(args.as_slice()); - }; - op_execs.push(op); - } - Ok(op_execs) - } - - pub fn load_params(&mut self, params: HashMap>) { - params.into_iter().for_each(|(name, param)| { - self.set_input(name, param); - }) - } - - pub fn set_input>(&mut self, name: S, value: Tensor<'t>) { - if let Some(idx) = self.get_input_index(name.as_ref()) { - // TODO: consider `new_with_params` to avoid ever allocating - let ptr = self.tensors[idx].data.as_ptr(); - let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr); - let mut owner = to_replace.nth(0).unwrap(); - if value.data.is_owned() { - // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr - // mem::replace(&mut (*owner), value); - // to_replace.for_each(|t| { - // panic!("replacing"); - // t.data = owner.data.view(); - // }); - owner.copy(&value); 
- } else { - owner.copy(&value); - } - } else { - println!("Unexpected input `{}`", name.as_ref()); - } - } - - /// Returns the graph input with name `name`, if it exists. - pub fn get_input>(&mut self, name: S) -> Option<&Tensor> { - self - .get_input_index(name.as_ref()) - .and_then(move |idx| Some(&self.tensors[idx])) - } - - /// Returns the graph output with index `index`, if it exists. - pub fn get_output(&self, idx: usize) -> Option<&Tensor> { - let graph = &self.graph; - graph.heads.get(idx).and_then(|entry| { - graph - .entry_index(entry) - .map(|idx| self.tensors.get(idx)) - .unwrap_or(None) - }) - } - - /// Returns the index for graph input with name `name`, if it exists. - pub fn get_input_index>(&self, name: S) -> Option { - let graph = &self.graph; - (0..graph.nodes.len()) - .skip_while(|&i| graph.nodes[i].name != name.as_ref()) - .nth(0) - .and_then(|i| { - if graph.arg_nodes.iter().any(|&id| id == i) { - graph.node_row_ptr.as_ref().map(|nrp| nrp[i]) - } else { - None - } - }) - } -} - -/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h -named!( - tvm_str_to_type, - do_parse!( - type_name: alpha1 >> - bits: digit1 >> - lanes: opt!(tuple!(tag!("x"), digit1)) >> - (DataType { - code: match type_name { - CompleteStr("int") => DLDataTypeCode_kDLInt, - CompleteStr("uint") => DLDataTypeCode_kDLUInt, - CompleteStr("float") => DLDataTypeCode_kDLFloat, - _ => DLDataTypeCode_kDLFloat, - } as usize, - bits: bits.parse::().unwrap() as usize, - lanes: match lanes { - Some(lanes) => lanes.1.parse::().unwrap() as usize, - None => 1, - }, - }) - ) -); - -/// Converts a bytes to String. 
-named!( - name, - map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8( - b.to_vec() - )) -); - -/// Parses a TVMContext -named!( - tvm_ctx<&[u8], TVMContext>, - do_parse!( - device_type: le_u32 >> - device_id: le_i32 >> - (TVMContext { device_type: device_type as usize, device_id: device_id as usize }) - ) -); - -/// Parses a DataType -named!( - data_type<&[u8], DataType>, - do_parse!( - code: le_u8 >> - bits: le_u8 >> - lanes: le_u16 >> - (DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize }) - ) -); - -/// Parses a Tensor from a TVM array file. -named!( - tensor, - do_parse!( - take!(8) - >> bits!(tag_bits!(u64, 64, 0)) - >> ctx: tvm_ctx - >> ndim: le_u32 - >> dtype: data_type - >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize) - >> length: le_i64 - >> data: take!(length) - >> (Tensor { - data: Storage::from(data), - ctx: ctx, - dtype: dtype, - size: shape.iter().product::() as usize, - shape: shape, - strides: None, - byte_offset: 0, - }) - ) -); - -/// Parses a graph params dict from a params binary file. -named!( - parse_param_dict>, - do_parse!( - take!(8) - >> bits!(tag_bits!(u64, 64, 0)) - >> names: length_count!(le_u64, name) - >> tensors: length_count!(le_u64, tensor) - >> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter()))) - ) -); - -/// Loads a param dict saved using `nnvm.compiler.save_param_dict`. 
-pub fn load_param_dict(bytes: &[u8]) -> Result> { - if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) { - if remaining_bytes.len() > 0 { - bail!(ErrorKind::LoadGraphParamsError("extra input".to_string())) - } else { - Ok(param_dict) - } - } else { - bail!(ErrorKind::LoadGraphParamsError( - "invalid parameters file".to_string() - )) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_str_to_type() { - assert_eq!( - tvm_str_to_type(CompleteStr("float24")).unwrap().1, - DataType { - code: DLDataTypeCode_kDLFloat as usize, - bits: 24, - lanes: 1 - } - ); - assert_eq!( - tvm_str_to_type(CompleteStr("uint111x44")).unwrap().1, - DataType { - code: DLDataTypeCode_kDLUInt as usize, - bits: 111, - lanes: 44 - } - ); - } -} diff --git a/rust/src/runtime/mod.rs b/rust/src/runtime/mod.rs deleted file mode 100644 index 1a9c5ba7c..000000000 --- a/rust/src/runtime/mod.rs +++ /dev/null @@ -1,28 +0,0 @@ -mod allocator; -mod array; -mod module; -#[macro_use] -mod packed_func; -mod graph; -#[cfg(target_env = "sgx")] -#[macro_use] -pub mod sgx; -mod threading; -mod workspace; - -use std::os::raw::c_char; - -pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*}; - -#[cfg(target_env = "sgx")] -use self::sgx::ocall_packed_func; - -#[no_mangle] -pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) { - #[cfg(not(target_env = "sgx"))] - unsafe { - panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap()); - } - #[cfg(target_env = "sgx")] - ocall_packed!("__sgx_set_last_error__", cmsg); -} diff --git a/rust/src/runtime/module.rs b/rust/src/runtime/module.rs deleted file mode 100644 index 2594756d9..000000000 --- a/rust/src/runtime/module.rs +++ /dev/null @@ -1,46 +0,0 @@ -use std::{ - collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, -}; - -use ffi::runtime::BackendPackedCFunc; -use runtime::packed_func::{wrap_backend_packed_func, PackedFunc}; - -pub trait Module { - fn 
get_function>(&self, name: S) -> Option; -} - -pub struct SystemLibModule; - -lazy_static! { - static ref SYSTEM_LIB_FUNCTIONS: Mutex> = - Mutex::new(HashMap::new()); -} - -impl Module for SystemLibModule { - fn get_function>(&self, name: S) -> Option { - SYSTEM_LIB_FUNCTIONS - .lock() - .unwrap() - .get(name.as_ref()) - .map(|func| wrap_backend_packed_func(func.to_owned())) - } -} - -impl Default for SystemLibModule { - fn default() -> Self { - SystemLibModule {} - } -} - -#[no_mangle] -pub extern "C" fn TVMBackendRegisterSystemLibSymbol( - cname: *const c_char, - func: BackendPackedCFunc, -) -> i32 { - let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; - SYSTEM_LIB_FUNCTIONS - .lock() - .unwrap() - .insert(name.to_string(), func); - return 0; -} diff --git a/rust/src/runtime/packed_func.rs b/rust/src/runtime/packed_func.rs deleted file mode 100644 index a6ad7fc35..000000000 --- a/rust/src/runtime/packed_func.rs +++ /dev/null @@ -1,342 +0,0 @@ -use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void}; - -use super::Tensor; -use ffi::runtime::{ - BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor, - TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMTypeCode_kNDArrayContainer, TVMValue, -}; - -use errors::*; - -pub type PackedFunc = Box TVMRetValue + Send + Sync>; - -/// Calls a packed function and returns a `TVMRetValue`. -/// -/// # Example -/// -/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` -#[macro_export] -macro_rules! call_packed { - ($fn:expr, $($args:expr),+) => { - $fn(&[$($args.into(),)+]) - }; - ($fn:expr) => { - $fn(&Vec::new()) - }; -} - -/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way -/// to obtain a `TVMArgValue` is automatically via `call_packed!`. 
-#[derive(Clone, Copy)] -pub struct TVMArgValue<'a> { - _lifetime: PhantomData<&'a ()>, - pub(crate) value: TVMValue, - pub(crate) type_code: i64, -} - -impl<'a> TVMArgValue<'a> { - pub fn new(value: TVMValue, type_code: i64) -> Self { - TVMArgValue { - _lifetime: PhantomData, - value: value, - type_code: type_code, - } - } -} - -/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode. -macro_rules! impl_prim_tvm_arg { - ($type:ty, $field:ident, $code:expr, $as:ty) => { - impl<'a> From<$type> for TVMArgValue<'a> { - fn from(val: $type) -> Self { - TVMArgValue { - value: TVMValue { $field: val as $as }, - type_code: $code as i64, - _lifetime: PhantomData, - } - } - } - impl<'a> TryFrom> for $type { - type Error = Error; - fn try_from(val: TVMArgValue<'a>) -> Result { - ensure!( - val.type_code == $code as i64, - "Could not downcast arg. Expected `{}`, got `{}`", - $code, - val.type_code - ); - Ok(unsafe { val.value.$field as $type }) - } - } - }; - ($type:ty, $field:ident, $code:expr) => { - impl_prim_tvm_arg!($type, $field, $code, $type); - }; - ($type:ty,v_int64) => { - impl_prim_tvm_arg!($type, v_int64, DLDataTypeCode_kDLInt, i64); - }; - ($type:ty,v_float64) => { - impl_prim_tvm_arg!($type, v_float64, DLDataTypeCode_kDLFloat, f64); - }; -} - -impl_prim_tvm_arg!(f32, v_float64); -impl_prim_tvm_arg!(f64, v_float64); -impl_prim_tvm_arg!(i8, v_int64); -impl_prim_tvm_arg!(u8, v_int64); -impl_prim_tvm_arg!(i32, v_int64); -impl_prim_tvm_arg!(u32, v_int64); -impl_prim_tvm_arg!(i64, v_int64); -impl_prim_tvm_arg!(u64, v_int64); - -/// Creates a conversion to a `TVMArgValue` for an object handle. -impl<'a, T> From<*const T> for TVMArgValue<'a> { - fn from(ptr: *const T) -> Self { - TVMArgValue { - value: TVMValue { - v_handle: ptr as *mut T as *mut c_void, - }, - type_code: TVMTypeCode_kArrayHandle as i64, - _lifetime: PhantomData, - } - } -} - -/// Creates a conversion to a `TVMArgValue` for a mutable object handle. 
-impl<'a, T> From<*mut T> for TVMArgValue<'a> { - fn from(ptr: *mut T) -> Self { - TVMArgValue { - value: TVMValue { - v_handle: ptr as *mut c_void, - }, - type_code: TVMTypeCode_kHandle as i64, - _lifetime: PhantomData, - } - } -} - -impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> { - fn from(arr: &'a mut DLTensor) -> Self { - TVMArgValue { - value: TVMValue { - v_handle: arr as *mut _ as *mut c_void, - }, - type_code: TVMTypeCode_kArrayHandle as i64, - _lifetime: PhantomData, - } - } -} - -impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { - fn from(arr: &'a DLTensor) -> Self { - TVMArgValue { - value: TVMValue { - v_handle: arr as *const _ as *mut DLTensor as *mut c_void, - }, - type_code: TVMTypeCode_kArrayHandle as i64, - _lifetime: PhantomData, - } - } -} - -impl<'a> TryFrom> for Tensor<'a> { - type Error = Error; - fn try_from(val: TVMArgValue<'a>) -> Result { - ensure!( - val.type_code == TVMTypeCode_kArrayHandle as i64 - || val.type_code == TVMTypeCode_kNDArrayContainer as i64, - "Could not downcast arg. Expected `{}` or `{}`, but got `{}`", - TVMTypeCode_kArrayHandle, - TVMTypeCode_kNDArrayContainer, - val.type_code, - ); - - let dlt = unsafe { *(val.value.v_handle as *mut DLTensor as *const DLTensor) }; - Ok(dlt.into()) - } -} - -/// An owned TVMPODValue. Can be converted from a variety of primitive and object types. -/// Can be downcasted using `try_from` if it contains the desired type. -/// -/// # Example -/// -/// ``` -/// let a = 42u32; -/// let b: i64 = TVMRetValue::from(a).try_into().unwrap(); -/// -/// let s = "hello, world!"; -/// let t: TVMRetValue = s.into(); -/// assert_eq!(String::try_from(t).unwrap(), s); -/// ``` -pub struct TVMRetValue { - /// A primitive return value, if any. - prim_value: u64, - /// An object return value, if any. - box_value: Box, - /// The DLDataTypeCode which determines whether `prim_value` or `box_value` is in use. 
- type_code: i64, -} - -#[cfg(target_env = "sgx")] -impl TVMRetValue { - pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self { - unsafe { - Self { - prim_value: match type_code { - 0 | 1 => value.v_int64 as u64, - 2 => value.v_float64 as u64, - 3 | 7 | 8 | 9 | 10 => value.v_handle as u64, - 11 | 12 => value.v_str as u64, - _ => 0, - } as u64, - box_value: box (), - type_code: type_code, - } - } - } - - pub fn into_tvm_value(self) -> (TVMValue, i64) { - let val = match self.type_code { - 0 | 1 => TVMValue { - v_int64: self.prim_value.clone() as i64, - }, - 2 => TVMValue { - v_float64: self.prim_value.clone() as f64, - }, - 3 | 7 | 8 | 9 | 10 | 13 => TVMValue { - v_handle: Box::into_raw(self.box_value) as *mut c_void, - }, - 11 | 12 => TVMValue { - v_str: Box::into_raw(self.box_value) as *const _, - }, - _ => unreachable!(), - }; - (val, self.type_code) - } -} - -impl Default for TVMRetValue { - fn default() -> Self { - TVMRetValue { - prim_value: 0, - box_value: box (), - type_code: 0, - } - } -} - -macro_rules! impl_prim_ret_value { - ($type:ty, $code:expr) => { - impl From<$type> for TVMRetValue { - fn from(val: $type) -> Self { - TVMRetValue { - prim_value: val as u64, - box_value: box (), - type_code: $code, - } - } - } - impl TryFrom for $type { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$type> { - if ret.type_code == $code { - Ok(ret.prim_value as $type) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!($type).to_string(), - ret.type_code - )) - } - } - } - }; -} - -macro_rules! 
impl_boxed_ret_value { - ($type:ty, $code:expr) => { - impl From<$type> for TVMRetValue { - fn from(val: $type) -> Self { - TVMRetValue { - prim_value: 0, - box_value: box val, - type_code: $code, - } - } - } - impl TryFrom for $type { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$type> { - if let Ok(val) = ret.box_value.downcast::<$type>() { - Ok(*val) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!($type).to_string(), - ret.type_code - )) - } - } - } - }; -} - -impl_prim_ret_value!(i8, 0); -impl_prim_ret_value!(u8, 1); -impl_prim_ret_value!(i16, 0); -impl_prim_ret_value!(u16, 1); -impl_prim_ret_value!(i32, 0); -impl_prim_ret_value!(u32, 1); -impl_prim_ret_value!(f32, 2); -impl_prim_ret_value!(i64, 0); -impl_prim_ret_value!(u64, 1); -impl_prim_ret_value!(f64, 2); -impl_prim_ret_value!(isize, 0); -impl_prim_ret_value!(usize, 1); -impl_boxed_ret_value!(String, 11); - -impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue { - fn from(val: &'t Tensor<'a>) -> Self { - TVMRetValue { - prim_value: 0, - box_value: box DLTensor::from(val), - type_code: TVMTypeCode_kNDArrayContainer as i64, - } - } -} - -impl<'a> TryFrom for Tensor<'a> { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { - ensure!( - ret.type_code == TVMTypeCode_kArrayHandle as i64 - || ret.type_code == TVMTypeCode_kNDArrayContainer as i64, - "Could not downcast arg. Expected `{}` or `{}`, but got `{}`", - TVMTypeCode_kArrayHandle, - TVMTypeCode_kNDArrayContainer, - ret.type_code, - ); - - let dlt = unsafe { *(ret.prim_value as *mut DLTensor as *const DLTensor) }; - Ok(dlt.into()) - } -} - -// @see `WrapPackedFunc` in `llvm_module.cc`. 
-pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { - box move |args: &[TVMArgValue]| { - func( - args - .iter() - .map(|ref arg| arg.value) - .collect::>() - .as_ptr(), - args - .iter() - .map(|ref arg| arg.type_code as i32) - .collect::>() - .as_ptr() as *const i32, - args.len() as i32, - ); - TVMRetValue::default() - } -} diff --git a/rust/src/runtime/sgx.rs b/rust/src/runtime/sgx.rs deleted file mode 100644 index 00be3ee3b..000000000 --- a/rust/src/runtime/sgx.rs +++ /dev/null @@ -1,82 +0,0 @@ -use std::{ - ffi::CString, - os::raw::{c_char, c_int}, -}; - -use errors::Result; -use ffi::runtime::TVMValue; -use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue}; - -pub use runtime::threading::tvm_run_worker as run_worker; - -#[macro_export] -macro_rules! tvm_ocall { - ($func: expr) => { - match $func { - 0 => Ok(()), - err => Err(format!("SGX error: {}", err)), - } - }; -} - -pub type SgxStatus = u32; - -#[cfg(target_env = "sgx")] -extern "C" { - fn tvm_ocall_packed_func( - name: *const c_char, - arg_values: *const TVMValue, - type_codes: *const c_int, - num_args: c_int, - ret_val: *mut TVMValue, - ret_type_code: *mut c_int, - ) -> SgxStatus; -} - -pub fn ocall_packed_func>(fn_name: S, args: &[TVMArgValue]) -> Result { - let mut ret_val = TVMValue { v_int64: 0 }; - let ret_type_code = 0i64; - unsafe { - tvm_ocall!(tvm_ocall_packed_func( - CString::new(fn_name.as_ref()).unwrap().as_ptr(), - args - .iter() - .map(|ref arg| arg.value) - .collect::>() - .as_ptr(), - args - .iter() - .map(|ref arg| arg.type_code as i32) - .collect::>() - .as_ptr() as *const i32, - args.len() as i32, - &mut ret_val as *mut TVMValue, - &mut (ret_type_code as i32) as *mut c_int, - ))?; - } - Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64)) -} - -#[macro_export] -macro_rules! 
ocall_packed { - ($fn_name:expr, $($args:expr),+) => { - ocall_packed_func($fn_name, &[$($args.into(),)+]) - .expect(concat!("Error calling `", $fn_name, "`")) - }; - ($fn_name:expr) => { - ocall_packed_func($fn_name, &Vec::new()) - .expect(concat!("Error calling `", $fn_name, "`")) - } -} - -pub fn shutdown() { - if env!("TVM_NUM_THREADS") != "0" { - sgx_join_threads() - } -} - -impl Drop for SystemLibModule { - fn drop(&mut self) { - shutdown() - } -} diff --git a/rust/src/runtime/threading.rs b/rust/src/runtime/threading.rs deleted file mode 100644 index 1d6d7fc78..000000000 --- a/rust/src/runtime/threading.rs +++ /dev/null @@ -1,337 +0,0 @@ -use std::{ - os::raw::{c_int, c_void}, - sync::{ - atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT}, - Arc, Barrier, - }, -}; - -#[cfg(not(target_env = "sgx"))] -use num_cpus; -#[cfg(not(target_env = "sgx"))] -use std::{ - env, - thread::{self, JoinHandle}, -}; - -#[cfg(target_env = "sgx")] -use std::{collections::VecDeque, ptr, sync::Mutex}; - -use bounded_spsc_queue::{self, Producer}; - -use super::super::errors::*; -use ffi::runtime::TVMParallelGroupEnv; - -#[cfg(target_env = "sgx")] -use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue}; - -type FTVMParallelLambda = - extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; - -/// Holds a parallel job request made by a TVM library function. -struct Job { - cb: FTVMParallelLambda, - cdata: *const c_void, - req_num_tasks: usize, - pending: Arc, -} - -impl Job { - /// Splits this job into a number of `Task`s which can be scheduled. 
- fn tasks(&self, num_workers: usize) -> Vec { - let num_tasks = if self.req_num_tasks == 0 { - num_workers - } else { - self.req_num_tasks.min(num_workers) - }; - self.pending.store(num_tasks, Ordering::SeqCst); - - let barrier = Arc::new(Barrier::new(num_tasks)); - - (0..num_tasks) - .map(move |i| Task { - id: i, - flambda: self.cb, - penv: TVMParallelGroupEnv { - sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void, - num_task: num_tasks as i32, - }, - cdata: self.cdata, - pending: Arc::clone(&self.pending), - }) - .collect() - } - - /// Waits for all tasks in this `Job` to be completed. - fn wait(&self) -> Result<()> { - while self.pending.load(Ordering::Acquire) > 0 { - #[cfg(not(target_env = "sgx"))] - thread::yield_now(); - } - Ok(()) - } -} - -/// A chunk of work requested by a TVM function. -struct Task { - id: usize, - flambda: FTVMParallelLambda, - penv: TVMParallelGroupEnv, - cdata: *const c_void, - pending: Arc, -} -unsafe impl Send for Task {} -unsafe impl Sync for Task {} - -impl FnOnce<()> for Task { - type Output = i32; - extern "rust-call" fn call_once(self, _args: ()) -> Self::Output { - let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata); - self.pending.fetch_sub(1, Ordering::AcqRel); - status - } -} - -#[derive(Default)] -struct Threads { - #[allow(unused)] - #[cfg(not(target_env = "sgx"))] - handles: Vec>, - queues: Vec>, -} - -impl<'a> Threads { - #[cfg(not(target_env = "sgx"))] - fn launch) + 'static + Copy>( - num_threads: usize, - cb: F, - ) -> Self { - let (handles, queues) = (0..num_threads) - .map(|_| { - let (p, c) = bounded_spsc_queue::make(2); - let handle = thread::spawn(move || cb(c.into())); - (handle, p) - }) - .unzip(); - Threads { - handles: handles, - queues: queues, - } - } - - #[cfg(target_env = "sgx")] - fn launch) + 'static + Copy>( - num_threads: usize, - _cb: F, - ) -> Self { - let mut consumer_queues = SGX_QUEUES.lock().unwrap(); - let queues = (0..num_threads) - .map(|_| { - let (p, c) = 
bounded_spsc_queue::make(2); - consumer_queues.push_back(c.into()); - p - }) - .collect(); - ocall_packed!("__sgx_thread_group_launch__", num_threads as u64); - Threads { queues: queues } - } -} - -struct ThreadPool { - num_workers: usize, - #[allow(unused)] - threads: Threads, -} - -thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new()); - -impl ThreadPool { - fn new() -> Self { - let num_workers = max_concurrency(); - ThreadPool { - num_workers: num_workers, - threads: Threads::launch(num_workers, ThreadPool::run_worker), - } - } - - fn launch(&self, job: Job) { - let mut tasks = job.tasks(self.num_workers + 1); - - for (i, task) in tasks.split_off(1).into_iter().enumerate() { - self.threads.queues[i].push(task); - } - - tasks.pop().unwrap()(); - job.wait().unwrap(); - } - - fn run_worker(queue: Consumer) { - loop { - let task = queue.pop(); - let result = task(); - if result == ::min_value() { - break; - } else if result != 0 { - panic!("Error running task."); - } - } - } -} - -// Send + Sync wrapper for bounded_spsc_queue::Consumer -struct Consumer { - consumer: bounded_spsc_queue::Consumer, -} -impl From> for Consumer { - fn from(c: bounded_spsc_queue::Consumer) -> Self { - Consumer { consumer: c } - } -} -impl Consumer { - fn pop(&self) -> T { - self.consumer.pop() - } -} -unsafe impl Send for Consumer {} -unsafe impl Sync for Consumer {} - -#[cfg(target_env = "sgx")] -lazy_static! { - /// Holds tasks for untrusted threads which re-enter the enclave to execute. 
- static ref SGX_QUEUES: Mutex>> = Mutex::new(VecDeque::new()); -} - -#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))] -fn max_concurrency() -> usize { - if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) { - if let Ok(threads) = usize::from_str_radix(&threads_str, 10) { - return threads; - } - } - num_cpus::get_physical() -} - -#[cfg(target_env = "sgx")] -fn max_concurrency() -> usize { - usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1) -} - -#[cfg(target_arch = "wasm32")] -fn max_concurrency() -> usize { - 0 // wasm doesn't support threads yet -} - -#[cfg(target_env = "sgx")] -pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue { - let q = { - let mut qs = SGX_QUEUES.lock().unwrap(); - qs.pop_front() - // `qs: MutexGuard` needs to be dropped here since `run_worker` won't return - }; - if let Some(q) = q { - ThreadPool::run_worker(q); - } - TVMRetValue::default() -} - -#[no_mangle] -pub extern "C" fn TVMBackendParallelLaunch( - cb: FTVMParallelLambda, - cdata: *const c_void, - num_task: usize, -) -> c_int { - if max_concurrency() == 0 { - let penv = TVMParallelGroupEnv { - sync_handle: 0 as *mut c_void, - num_task: 1, - }; - cb(0, &penv as *const _, cdata); - } else { - THREAD_POOL.with(|pool| { - pool.launch(Job { - cb: cb, - cdata: cdata, - req_num_tasks: num_task, - pending: Arc::new(ATOMIC_USIZE_INIT), - }); - }); - } - return 0; -} - -#[cfg(target_env = "sgx")] -pub(crate) fn sgx_join_threads() { - extern "C" fn poison_pill( - _task_id: usize, - _penv: *const TVMParallelGroupEnv, - _cdata: *const c_void, - ) -> i32 { - ::min_value() - } - - THREAD_POOL.with(|pool| { - pool.launch(Job { - cb: poison_pill, - cdata: ptr::null(), - req_num_tasks: 0, - pending: Arc::new(ATOMIC_USIZE_INIT), - }); - }); - ocall_packed!("__sgx_thread_group_join__", 0); -} - -// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used. 
-#[no_mangle] -pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) { - let barrier: &Arc = unsafe { &*((*penv).sync_handle as *const Arc) }; - barrier.wait(); -} - -#[cfg(test)] -mod tests { - use std::{ptr, thread, time::Duration}; - - use super::*; - - #[test] - fn test_max_concurrency() { - env::set_var("TVM_NUM_THREADS", "42"); - env::set_var("OMP_NUM_THREADS", "24"); - assert_eq!(max_concurrency(), 42); - env::remove_var("TVM_NUM_THREADS"); - assert_eq!(max_concurrency(), 24); - } - - extern "C" fn flambda( - task_id: usize, - penv: *const TVMParallelGroupEnv, - cdata: *const c_void, - ) -> i32 { - if cdata == ptr::null() { - return 0; - } - unsafe { - let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize)); - thread::sleep(Duration::from_millis(50 * task_id as u64)); - counter.fetch_add(1, Ordering::SeqCst); - task_ids_sum.fetch_add(task_id, Ordering::SeqCst); - assert_eq!((*penv).num_task, 3); - } - 0 - } - - #[test] - fn test_parallel_launch() { - TVMBackendParallelLaunch(flambda, ptr::null(), 6); - let counter = ATOMIC_USIZE_INIT; - let task_ids_sum = ATOMIC_USIZE_INIT; - let cdata = (counter, task_ids_sum); - let num_tasks = 3; - TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks); - assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks); - assert_eq!( - cdata.1.load(Ordering::SeqCst), - (0..num_tasks).sum::() - ); - } -} diff --git a/rust/src/runtime/workspace.rs b/rust/src/runtime/workspace.rs deleted file mode 100644 index d0e6d8c89..000000000 --- a/rust/src/runtime/workspace.rs +++ /dev/null @@ -1,119 +0,0 @@ -use std::{ - cell::RefCell, - os::raw::{c_int, c_void}, - ptr, -}; - -use super::allocator::Allocation; -use errors::*; - -const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h` - -struct WorkspacePool { - workspaces: Vec, - free: Vec, - in_use: Vec, -} - -impl WorkspacePool { - fn new() -> Self { - WorkspacePool { - 
workspaces: Vec::new(), - free: Vec::new(), - in_use: Vec::new(), - } - } - - fn alloc_new(&mut self, size: usize) -> Result<*mut u8> { - self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?); - self.in_use.push(self.workspaces.len() - 1); - Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) - } - - fn alloc(&mut self, size: usize) -> Result<*mut u8> { - if self.free.len() == 0 { - return self.alloc_new(size); - } - let idx = self - .free - .iter() - .fold(None, |cur_ws_idx: Option, &idx| { - let ws_size = self.workspaces[idx].size(); - if !ws_size >= size { - return cur_ws_idx; - } - cur_ws_idx.or(Some(idx)).and_then(|cur_idx| { - let cur_size = self.workspaces[cur_idx].size(); - Some(match ws_size <= cur_size { - true => idx, - false => cur_idx, - }) - }) - }); - match idx { - Some(idx) => { - self.free.remove_item(&idx).unwrap(); - self.in_use.push(idx); - Ok(self.workspaces[idx].as_mut_ptr()) - } - None => self.alloc_new(size), - } - } - - fn free(&mut self, ptr: *mut u8) -> Result<()> { - let mut ws_idx = None; - for i in 0..self.in_use.len() { - let idx = self.in_use[i]; - if self.workspaces[idx].as_mut_ptr() == ptr { - self.in_use.remove(i); - ws_idx = Some(idx); - break; - } - } - Ok( - self - .free - .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?), - ) - } -} - -thread_local!(static WORKSPACE_POOL: RefCell = RefCell::new(WorkspacePool::new())); - -const WORKSPACE_PAGE_SIZE: usize = 4 << 10; - -#[no_mangle] -pub extern "C" fn TVMBackendAllocWorkspace( - _device_type: c_int, - _device_id: c_int, - size: u64, - _dtype_code_hint: c_int, - _dtype_bits_hint: c_int, -) -> *mut c_void { - let nbytes = if size == 0 { - WORKSPACE_PAGE_SIZE - } else { - size as usize - }; - WORKSPACE_POOL.with(|pool_cell| { - pool_cell - .borrow_mut() - .alloc(nbytes as usize) - .unwrap_or(ptr::null_mut()) as *mut c_void - }) -} - -#[no_mangle] -pub extern "C" fn TVMBackendFreeWorkspace( - _device_type: c_int, - _device_id: c_int, - ptr: *mut c_void, -) 
-> c_int { - WORKSPACE_POOL.with(|pool_cell| { - (match pool_cell.borrow_mut().free(ptr as *mut u8) { - Ok(()) => 0, - Err(_) => -1, - }) as c_int - }); - return 0; -} diff --git a/rust/tests/.gitignore b/rust/tests/.gitignore deleted file mode 100644 index 811076739..000000000 --- a/rust/tests/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -*.json -*.params -*.o diff --git a/rust/tests/build_model.py b/rust/tests/build_model.py deleted file mode 100644 index e0b904951..000000000 --- a/rust/tests/build_model.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Builds a simple NNVM graph for testing.""" - -from os import path as osp - -import nnvm -from nnvm import sym -from nnvm.compiler import graph_util -from nnvm.testing import init -import numpy as np -import tvm - -CWD = osp.dirname(osp.abspath(osp.expanduser(__file__))) - - -def _get_model(dshape): - data = sym.Variable('data', shape=dshape) - fc1 = sym.dense(data, units=dshape[-1]*2, use_bias=True) - left, right = sym.split(fc1, indices_or_sections=2, axis=1) - return sym.Group(((left + 1), (right - 1))) - - -def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10): - if isinstance(graph, sym.Symbol): - graph = nnvm.graph.create(graph) - ishapes, _ = graph_util.infer_shape(graph, **input_shapes) - param_shapes = dict(zip(graph.index.input_names, ishapes)) - np.random.seed(seed) - params = {} - for param, shape in param_shapes.items(): - if param in {'data', 'label'} or not shape: - continue - init_value = np.empty(shape).astype('float32') - initializer(param, init_value) - params[param] = tvm.nd.array(init_value) - return params - -def main(): - dshape = (32, 16) - net = _get_model(dshape) - ishape_dict = {'data': dshape} - params = _init_params(net, ishape_dict) - graph, lib, params = nnvm.compiler.build(net, 'llvm', - shape=ishape_dict, - params=params, - dtype='float32') - - with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet: - f_resnet.write(graph.json()) - with open(osp.join(CWD, 'graph.params'), 'wb') 
as f_params: - f_params.write(nnvm.compiler.save_param_dict(params)) - -if __name__ == '__main__': - main() diff --git a/rust/tests/test_graph_serde.rs b/rust/tests/test_graph_serde.rs deleted file mode 100644 index b02c12889..000000000 --- a/rust/tests/test_graph_serde.rs +++ /dev/null @@ -1,39 +0,0 @@ -#![feature(try_from)] - -extern crate serde; -extern crate serde_json; - -extern crate tvm; - -use std::{convert::TryFrom, fs, io::Read}; - -use tvm::runtime::Graph; - -#[test] -fn test_load_graph() { - let mut params_bytes = Vec::new(); - fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params")) - .expect("Could not find TVM graph. Did you run `tests/build_model.py`?") - .read_to_end(&mut params_bytes) - .unwrap(); - let _params = tvm::runtime::load_param_dict(¶ms_bytes); - - let graph = Graph::try_from( - &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(), - ) - .unwrap(); - - assert_eq!(graph.nodes[3].op, "tvm_op"); - assert_eq!( - graph.nodes[3] - .attrs - .as_ref() - .unwrap() - .get("func_name") - .unwrap(), - "fuse_dense" - ); - assert_eq!(graph.nodes[5].inputs[0].index, 0); - assert_eq!(graph.nodes[6].inputs[0].index, 1); - assert_eq!(graph.heads.len(), 2); -} diff --git a/rust/tests/test_nnvm/Cargo.toml b/rust/tests/test_nnvm/Cargo.toml deleted file mode 100644 index 7e6ce5fb7..000000000 --- a/rust/tests/test_nnvm/Cargo.toml +++ /dev/null @@ -1,14 +0,0 @@ -[package] -name = "test-nnvm" -version = "0.0.0" -license = "Apache-2.0" -authors = ["Nick Hynes "] - -[dependencies] -ndarray = "0.11.2" -tvm = { path = "../../" } -serde = "1.0.59" -serde_json = "1.0.17" - -[build-dependencies] -ar = "0.6.0" diff --git a/rust/tests/test_nnvm/build.rs b/rust/tests/test_nnvm/build.rs deleted file mode 100644 index 4d9cd302b..000000000 --- a/rust/tests/test_nnvm/build.rs +++ /dev/null @@ -1,40 +0,0 @@ -extern crate ar; - -use std::{ - env, - fs::File, - path::{Path, PathBuf}, - process::Command, -}; - -use ar::Builder; 
- -fn main() { - let out_dir = env::var("OUT_DIR").unwrap(); - - let output = Command::new(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/build_test_graph.py" - )) - .arg(&out_dir) - .output() - .expect("Failed to execute command"); - assert!( - Path::new(&format!("{}/graph.o", out_dir)).exists(), - "Could not build graph lib: {}", - String::from_utf8(output.stderr) - .unwrap() - .trim() - .split("\n") - .last() - .unwrap_or("") - ); - - let in_path: PathBuf = [&out_dir, "graph.o"].iter().collect(); - let out_path: PathBuf = [&out_dir, "libgraph.a"].iter().collect(); - let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap()); - builder.append_path(in_path.to_str().unwrap()).unwrap(); - - println!("cargo:rustc-link-lib=static=graph"); - println!("cargo:rustc-link-search=native={}", out_dir); -} diff --git a/rust/tests/test_nnvm/src/build_test_graph.py b/rust/tests/test_nnvm/src/build_test_graph.py deleted file mode 100755 index 429cc2128..000000000 --- a/rust/tests/test_nnvm/src/build_test_graph.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python3 - -"""Builds a simple NNVM graph for testing.""" - -from os import path as osp -import sys - -import nnvm -from nnvm import sym -from nnvm.compiler import graph_util -from nnvm.testing import init -import numpy as np -import tvm - - -def _get_model(dshape): - data = sym.Variable('data', shape=dshape) - fc = sym.dense(data, units=dshape[-1]*2, use_bias=True) - left, right = sym.split(fc, indices_or_sections=2, axis=1) - return sym.Group(((left + 1), (right - 1), fc)) - - -def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10): - if isinstance(graph, sym.Symbol): - graph = nnvm.graph.create(graph) - ishapes, _ = graph_util.infer_shape(graph, **input_shapes) - param_shapes = dict(zip(graph.index.input_names, ishapes)) - np.random.seed(seed) - params = {} - for param, shape in param_shapes.items(): - if param in {'data', 'label'} or not shape: - continue - - init_value = 
np.arange(np.product(shape), 0, -1).reshape(*shape).astype('float32') - if param.endswith('_bias'): - params[param] = tvm.nd.array(init_value) - continue - - init_value = np.empty(shape).astype('float32') - initializer(param, init_value) - # init_value /= init_value.sum() + 1e-10 - params[param] = tvm.nd.array(init_value) - return params - -def main(): - dshape = (4, 8) - net = _get_model(dshape) - ishape_dict = {'data': dshape} - params = _init_params(net, ishape_dict) - graph, lib, params = nnvm.compiler.build(net, 'llvm --system-lib', - shape=ishape_dict, - params=params, - dtype='float32') - - out_dir = sys.argv[1] - lib.save(osp.join(sys.argv[1], 'graph.o')) - with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet: - f_resnet.write(graph.json()) - with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params: - f_params.write(nnvm.compiler.save_param_dict(params)) - -if __name__ == '__main__': - main() diff --git a/rust/tests/test_nnvm/src/main.rs b/rust/tests/test_nnvm/src/main.rs deleted file mode 100644 index 0953ce2a2..000000000 --- a/rust/tests/test_nnvm/src/main.rs +++ /dev/null @@ -1,80 +0,0 @@ -#![feature(try_from)] - -#[macro_use] -extern crate ndarray; -extern crate serde; -extern crate serde_json; - -extern crate tvm; -use std::{collections::HashMap, convert::TryFrom, fs, io::Read}; - -use ndarray::Array; -use tvm::runtime::{Graph, GraphExecutor, SystemLibModule, Tensor}; - -const BATCH_SIZE: usize = 4; -const IN_DIM: usize = 8; - -macro_rules! 
check_sum { - ($e:expr, $a:ident, $b:ident) => { - let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap(); - check_sum!(a, $b); - }; - ($e:expr, $a:expr, $b:ident) => { - let a = Array::try_from($e.get_output($a).unwrap()).unwrap(); - check_sum!(a, $b); - }; - ($a:ident, $b:ident) => { - let a_sum: f32 = $a.scalar_sum(); - let b_sum: f32 = $b.scalar_sum(); - assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum); - }; -} - -fn main() { - let syslib = SystemLibModule::default(); - - let mut params_bytes = Vec::new(); - fs::File::open(concat!(env!("OUT_DIR"), "/graph.params")) - .unwrap() - .read_to_end(&mut params_bytes) - .unwrap(); - let params = tvm::runtime::load_param_dict(¶ms_bytes) - .unwrap() - .into_iter() - .map(|(k, v)| (k, v.to_owned())) - .collect::>>(); - - let graph = - Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()).unwrap(); - let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); - - let x = Array::from_shape_vec( - (BATCH_SIZE, IN_DIM), - (0..BATCH_SIZE * IN_DIM) - .map(|x| x as f32) - .collect::>(), - ).unwrap(); - let w = Array::try_from(params.get("dense0_weight").unwrap()) - .unwrap() - .into_shape((IN_DIM * 2, IN_DIM)) - .unwrap(); - let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap(); - let dense = x.dot(&w.t()) + &b; - let left = dense.slice(s![.., 0..IN_DIM]); - let right = dense.slice(s![.., IN_DIM..]); - let expected_o0 = &left + 1f32; - let expected_o1 = &right - 1f32; - - exec.load_params(params); - exec.set_input("data", x.clone().into()); - - check_sum!(exec, data, x); - check_sum!(exec, dense0_weight, w); - check_sum!(exec, dense0_bias, b); - - exec.run(); - - check_sum!(exec, 0, expected_o0); - check_sum!(exec, 1, expected_o1); - check_sum!(exec, 2, dense); -} diff --git a/rust/tests/test_tvm_basic/Cargo.toml b/rust/tests/test_tvm_basic/Cargo.toml deleted file mode 100644 index bd4193bcb..000000000 --- a/rust/tests/test_tvm_basic/Cargo.toml +++ 
/dev/null @@ -1,12 +0,0 @@ -[package] -name = "test-tvm-basic" -version = "0.0.0" -license = "Apache-2.0" -authors = ["Nick Hynes "] - -[dependencies] -ndarray = "0.11.2" -tvm = { path = "../../" } - -[build-dependencies] -ar = "0.6.0" diff --git a/rust/tests/test_tvm_basic/build.rs b/rust/tests/test_tvm_basic/build.rs deleted file mode 100644 index 778dd1cab..000000000 --- a/rust/tests/test_tvm_basic/build.rs +++ /dev/null @@ -1,28 +0,0 @@ -extern crate ar; - -use std::{env, path::PathBuf, process::Command}; - -use ar::Builder; -use std::fs::File; - -fn main() { - let out_dir = env::var("OUT_DIR").unwrap(); - - let output = Command::new(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/build_test_lib.py" - )).arg(&out_dir) - .output() - .expect("Failed to execute command"); - if output.stderr.len() > 0 { - panic!(String::from_utf8(output.stderr).unwrap()); - } - - let in_path: PathBuf = [&out_dir, "test.o"].iter().collect(); - let out_path: PathBuf = [&out_dir, "libtest.a"].iter().collect(); - let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap()); - builder.append_path(in_path.to_str().unwrap()).unwrap(); - - println!("cargo:rustc-link-lib=static=test"); - println!("cargo:rustc-link-search=native={}", out_dir); -} diff --git a/rust/tests/test_tvm_basic/src/build_test_lib.py b/rust/tests/test_tvm_basic/src/build_test_lib.py deleted file mode 100755 index 7289a778f..000000000 --- a/rust/tests/test_tvm_basic/src/build_test_lib.py +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env python3 - -"""Prepares a simple TVM library for testing.""" - -from os import path as osp -import sys - -import tvm - -def main(): - n = tvm.var('n') - A = tvm.placeholder((n,), name='A') - B = tvm.placeholder((n,), name='B') - C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') - s = tvm.create_schedule(C.op) - s[C].parallel(s[C].op.axis[0]) - print(tvm.lower(s, [A, B, C], simple_mode=True)) - tvm.build(s, [A, B, C], 'llvm 
--system-lib').save(osp.join(sys.argv[1], 'test.o')) - -if __name__ == '__main__': - main() diff --git a/rust/tests/test_tvm_basic/src/main.rs b/rust/tests/test_tvm_basic/src/main.rs deleted file mode 100644 index b6c11451d..000000000 --- a/rust/tests/test_tvm_basic/src/main.rs +++ /dev/null @@ -1,25 +0,0 @@ -extern crate ndarray; -#[macro_use] -extern crate tvm; - -use ndarray::Array; -use tvm::{ - ffi::runtime::DLTensor, - runtime::{Module, SystemLibModule}, -}; - -fn main() { - let syslib = SystemLibModule::default(); - let add = syslib - .get_function("default_function") - .expect("main function not found"); - let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); - let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); - let mut c = Array::from_vec(vec![0f32; 4]); - let e = Array::from_vec(vec![2f32, 2., 4., 4.]); - let mut a_dl: DLTensor = (&mut a).into(); - let mut b_dl: DLTensor = (&mut b).into(); - let mut c_dl: DLTensor = (&mut c).into(); - call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl); - assert!(c.all_close(&e, 1e-8f32)); -} diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index 8e66d1098..5d8c242f4 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -2,24 +2,60 @@ set -e -export LD_LIBRARY_PATH=lib:$LD_LIBRARY_PATH +export TVM_HOME="$(git rev-parse --show-toplevel)" -tvm_root="$(git rev-parse --show-toplevel)" -export PYTHONPATH="$tvm_root/python":"$tvm_root/nnvm/python":"$tvm_root/topi/python" +export LD_LIBRARY_PATH="$TVM_HOME/lib":"$TVM_HOME/build":"$TVM_HOME/nnvm":$LD_LIBRARY_PATH +export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/nnvm/python":"$TVM_HOME/topi/python" +export RUST_DIR="$TVM_HOME/rust" -#cd rust -#cargo fmt -- --check +cd $RUST_DIR +cargo fmt -- --check + +# test common +cd $RUST_DIR/common +cargo build --features runtime +cargo test --features runtime --tests + +cargo build --features frontend +cargo test --features frontend --tests + +# test runtime +cd $RUST_DIR/runtime # run basic tests 
-#python3 tests/build_model.py -#cargo test --tests +python3 tests/build_model.py +cargo test --tests # run TVM module test -#cd tests/test_tvm_basic -#cargo run -#cd - +cd tests/test_tvm_basic +cargo run +cd - # run NNVM graph test -#cd tests/test_nnvm -#cargo run -#cd - +cd tests/test_nnvm +cargo run +cd - + +# test frontend +cd $RUST_DIR/frontend + +cargo test --tests -- --test-threads=1 + +# run basic tests on cpu +cd tests/basics +cargo build --features cpu +cargo run --features cpu +# uncomment when have more CI resources +# cargo build --features gpu +# cargo run --features gpu +# fi +cd - + +# run callback tests separately: https://discuss.tvm.ai/t/are-global-functions-need-to-be-accessed-in-separate-processes/1075 +cd tests/callback +cargo build +cargo run --bin int +cargo run --bin float +cargo run --bin array +cargo run --bin string +cd -