--- /dev/null
+target/
+*.rs.bk
+Cargo.lock
+c_runtime_api.rs
+++ /dev/null
-target
-**/*.rs.bk
-Cargo.lock
-/tvm-sys/src/bindgen.rs
license = "Apache-2.0"
[features]
-runtime = []
-frontend = ["tvm-sys"]
+bindings = []
[dependencies]
-error-chain = { version = "0.12.0", default-features = false }
-tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
+failure = "0.1.5"
+ndarray = "0.12.1"
+
+[build-dependencies]
+bindgen = "0.37.4"
--- /dev/null
+extern crate bindgen;
+
+use std::path::PathBuf;
+
+fn main() {
+ if cfg!(feature = "bindings") {
+ println!("cargo:rerun-if-env-changed=TVM_HOME");
+ println!("cargo:rustc-link-lib=dylib=tvm_runtime");
+ println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
+ }
+
+ // @see rust-bindgen#550 for `blacklist_type`
+ bindgen::Builder::default()
+ .header(format!(
+ "{}/include/tvm/runtime/c_runtime_api.h",
+ env!("TVM_HOME")
+ ))
+ .header(format!(
+ "{}/include/tvm/runtime/c_backend_api.h",
+ env!("TVM_HOME")
+ ))
+ .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
+ .blacklist_type("max_align_t")
+ .layout_tests(false)
+ .derive_partialeq(true)
+ .derive_eq(true)
+ .generate()
+ .expect("unable to generate bindings")
+ .write_to_file(PathBuf::from("src/c_runtime_api.rs"))
+ .expect("can not write the bindings!");
+}
--- /dev/null
+use std::{
+ any::TypeId,
+ mem,
+ os::raw::{c_int, c_void},
+};
+
+use crate::ffi::{
+ DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
+ DLDeviceType_kDLCPU, DLTensor,
+};
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct DataType {
+ pub code: usize,
+ pub bits: usize,
+ pub lanes: usize,
+}
+
+impl DataType {
+ /// Returns the number of bytes occupied by an element of this `DataType`.
+ pub fn itemsize(&self) -> usize {
+ (self.bits * self.lanes) >> 3
+ }
+
+ /// Returns whether this `DataType` represents primitive type `T`.
+ pub fn is_type<T: 'static>(&self) -> bool {
+ if self.lanes != 1 {
+ return false;
+ }
+ let typ = TypeId::of::<T>();
+ (typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
+ || (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
+ || (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
+ || (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
+ || (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
+ || (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
+ }
+
+ pub fn code(&self) -> usize {
+ self.code
+ }
+
+ pub fn bits(&self) -> usize {
+ self.bits
+ }
+
+ pub fn lanes(&self) -> usize {
+ self.lanes
+ }
+}
+
+impl<'a> From<&'a DataType> for DLDataType {
+ fn from(dtype: &'a DataType) -> Self {
+ Self {
+ code: dtype.code as u8,
+ bits: dtype.bits as u8,
+ lanes: dtype.lanes as u16,
+ }
+ }
+}
+
+impl From<DLDataType> for DataType {
+ fn from(dtype: DLDataType) -> Self {
+ Self {
+ code: dtype.code as usize,
+ bits: dtype.bits as usize,
+ lanes: dtype.lanes as usize,
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct TVMContext {
+ pub device_type: usize,
+ pub device_id: usize,
+}
+
+impl<'a> From<&'a TVMContext> for DLContext {
+ fn from(ctx: &'a TVMContext) -> Self {
+ Self {
+ device_type: ctx.device_type as u32,
+ device_id: ctx.device_id as i32,
+ }
+ }
+}
+
+impl Default for TVMContext {
+ fn default() -> Self {
+ Self {
+ device_type: DLDeviceType_kDLCPU as usize,
+ device_id: 0,
+ }
+ }
+}
+
+/// `From` conversions to `DLTensor` for `ndarray::Array`.
+/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
+macro_rules! impl_dltensor_from_ndarray {
+ ($type:ty, $typecode:expr) => {
+ impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
+ fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
+ DLTensor {
+ data: arr.as_mut_ptr() as *mut c_void,
+ ctx: DLContext {
+ device_type: DLDeviceType_kDLCPU,
+ device_id: 0,
+ },
+ ndim: arr.ndim() as c_int,
+ dtype: DLDataType {
+ code: $typecode as u8,
+ bits: 8 * mem::size_of::<$type>() as u8,
+ lanes: 1,
+ },
+ shape: arr.shape().as_ptr() as *const i64 as *mut i64,
+ strides: arr.strides().as_ptr() as *const isize as *mut i64,
+ byte_offset: 0,
+ }
+ }
+ }
+ };
+}
+
+impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
+impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
+impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
+impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
+impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
+impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
+++ /dev/null
-/* automatically generated by rust-bindgen for TVM revision 6292c78 */
-
-pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0";
-pub const DLPACK_VERSION: u32 = 8;
-pub const _STDINT_H: u32 = 1;
-pub const _FEATURES_H: u32 = 1;
-pub const _DEFAULT_SOURCE: u32 = 1;
-pub const __USE_ISOC11: u32 = 1;
-pub const __USE_ISOC99: u32 = 1;
-pub const __USE_ISOC95: u32 = 1;
-pub const __USE_POSIX_IMPLICITLY: u32 = 1;
-pub const _POSIX_SOURCE: u32 = 1;
-pub const _POSIX_C_SOURCE: u32 = 200809;
-pub const __USE_POSIX: u32 = 1;
-pub const __USE_POSIX2: u32 = 1;
-pub const __USE_POSIX199309: u32 = 1;
-pub const __USE_POSIX199506: u32 = 1;
-pub const __USE_XOPEN2K: u32 = 1;
-pub const __USE_XOPEN2K8: u32 = 1;
-pub const _ATFILE_SOURCE: u32 = 1;
-pub const __USE_MISC: u32 = 1;
-pub const __USE_ATFILE: u32 = 1;
-pub const __USE_FORTIFY_LEVEL: u32 = 0;
-pub const _STDC_PREDEF_H: u32 = 1;
-pub const __STDC_IEC_559__: u32 = 1;
-pub const __STDC_IEC_559_COMPLEX__: u32 = 1;
-pub const __STDC_ISO_10646__: u32 = 201505;
-pub const __STDC_NO_THREADS__: u32 = 1;
-pub const __GNU_LIBRARY__: u32 = 6;
-pub const __GLIBC__: u32 = 2;
-pub const __GLIBC_MINOR__: u32 = 23;
-pub const _SYS_CDEFS_H: u32 = 1;
-pub const __WORDSIZE: u32 = 64;
-pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1;
-pub const __SYSCALL_WORDSIZE: u32 = 64;
-pub const _BITS_WCHAR_H: u32 = 1;
-pub const INT8_MIN: i32 = -128;
-pub const INT16_MIN: i32 = -32768;
-pub const INT32_MIN: i32 = -2147483648;
-pub const INT8_MAX: u32 = 127;
-pub const INT16_MAX: u32 = 32767;
-pub const INT32_MAX: u32 = 2147483647;
-pub const UINT8_MAX: u32 = 255;
-pub const UINT16_MAX: u32 = 65535;
-pub const UINT32_MAX: u32 = 4294967295;
-pub const INT_LEAST8_MIN: i32 = -128;
-pub const INT_LEAST16_MIN: i32 = -32768;
-pub const INT_LEAST32_MIN: i32 = -2147483648;
-pub const INT_LEAST8_MAX: u32 = 127;
-pub const INT_LEAST16_MAX: u32 = 32767;
-pub const INT_LEAST32_MAX: u32 = 2147483647;
-pub const UINT_LEAST8_MAX: u32 = 255;
-pub const UINT_LEAST16_MAX: u32 = 65535;
-pub const UINT_LEAST32_MAX: u32 = 4294967295;
-pub const INT_FAST8_MIN: i32 = -128;
-pub const INT_FAST16_MIN: i64 = -9223372036854775808;
-pub const INT_FAST32_MIN: i64 = -9223372036854775808;
-pub const INT_FAST8_MAX: u32 = 127;
-pub const INT_FAST16_MAX: u64 = 9223372036854775807;
-pub const INT_FAST32_MAX: u64 = 9223372036854775807;
-pub const UINT_FAST8_MAX: u32 = 255;
-pub const UINT_FAST16_MAX: i32 = -1;
-pub const UINT_FAST32_MAX: i32 = -1;
-pub const INTPTR_MIN: i64 = -9223372036854775808;
-pub const INTPTR_MAX: u64 = 9223372036854775807;
-pub const UINTPTR_MAX: i32 = -1;
-pub const PTRDIFF_MIN: i64 = -9223372036854775808;
-pub const PTRDIFF_MAX: u64 = 9223372036854775807;
-pub const SIG_ATOMIC_MIN: i32 = -2147483648;
-pub const SIG_ATOMIC_MAX: u32 = 2147483647;
-pub const SIZE_MAX: i32 = -1;
-pub const WINT_MIN: u32 = 0;
-pub const WINT_MAX: u32 = 4294967295;
-pub type int_least8_t = ::std::os::raw::c_schar;
-pub type int_least16_t = ::std::os::raw::c_short;
-pub type int_least32_t = ::std::os::raw::c_int;
-pub type int_least64_t = ::std::os::raw::c_long;
-pub type uint_least8_t = ::std::os::raw::c_uchar;
-pub type uint_least16_t = ::std::os::raw::c_ushort;
-pub type uint_least32_t = ::std::os::raw::c_uint;
-pub type uint_least64_t = ::std::os::raw::c_ulong;
-pub type int_fast8_t = ::std::os::raw::c_schar;
-pub type int_fast16_t = ::std::os::raw::c_long;
-pub type int_fast32_t = ::std::os::raw::c_long;
-pub type int_fast64_t = ::std::os::raw::c_long;
-pub type uint_fast8_t = ::std::os::raw::c_uchar;
-pub type uint_fast16_t = ::std::os::raw::c_ulong;
-pub type uint_fast32_t = ::std::os::raw::c_ulong;
-pub type uint_fast64_t = ::std::os::raw::c_ulong;
-pub type intmax_t = ::std::os::raw::c_long;
-pub type uintmax_t = ::std::os::raw::c_ulong;
-pub type wchar_t = ::std::os::raw::c_int;
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct max_align_t {
- pub __clang_max_align_nonce1: ::std::os::raw::c_longlong,
- pub __bindgen_padding_0: u64,
- pub __clang_max_align_nonce2: f64,
-}
-pub const DLDeviceType_kDLCPU: DLDeviceType = 1;
-pub const DLDeviceType_kDLGPU: DLDeviceType = 2;
-pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3;
-pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4;
-pub const DLDeviceType_kDLMetal: DLDeviceType = 8;
-pub const DLDeviceType_kDLVPI: DLDeviceType = 9;
-pub const DLDeviceType_kDLROCM: DLDeviceType = 10;
-/// \brief The device type in DLContext.
-pub type DLDeviceType = u32;
-/// \brief A Device context for Tensor and operator.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLContext {
- /// \brief The device type used in the device.
- pub device_type: DLDeviceType,
- /// \brief The device index
- pub device_id: ::std::os::raw::c_int,
-}
-pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0;
-pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1;
-pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2;
-/// \brief The type code options DLDataType.
-pub type DLDataTypeCode = u32;
-/// \brief The data type the tensor can hold.
-///
-/// Examples
-/// - float: type_code = 2, bits = 32, lanes=1
-/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
-/// - int8: type_code = 0, bits = 8, lanes=1
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLDataType {
- /// \brief Type code of base types.
- /// We keep it uint8_t instead of DLDataTypeCode for minimal memory
- /// footprint, but the value should be one of DLDataTypeCode enum values.
- ///
- pub code: u8,
- /// \brief Number of bits, common choices are 8, 16, 32.
- pub bits: u8,
- /// \brief Number of lanes in the type, used for vector types.
- pub lanes: u16,
-}
-/// \brief Plain C Tensor object, does not manage memory.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLTensor {
- /// \brief The opaque data pointer points to the allocated data.
- /// This will be CUDA device pointer or cl_mem handle in OpenCL.
- /// This pointer is always aligns to 256 bytes as in CUDA.
- pub data: *mut ::std::os::raw::c_void,
- /// \brief The device context of the tensor
- pub ctx: DLContext,
- /// \brief Number of dimensions
- pub ndim: ::std::os::raw::c_int,
- /// \brief The data type of the pointer
- pub dtype: DLDataType,
- /// \brief The shape of the tensor
- pub shape: *mut i64,
- /// \brief strides of the tensor,
- /// can be NULL, indicating tensor is compact.
- pub strides: *mut i64,
- /// \brief The offset in bytes to the beginning pointer to data
- pub byte_offset: u64,
-}
-/// \brief C Tensor object, manage memory of DLTensor. This data structure is
-/// intended to faciliate the borrowing of DLTensor by another framework. It is
-/// not meant to transfer the tensor. When the borrowing framework doesn't need
-/// the tensor, it should call the deleter to notify the host that the resource
-/// is no longer needed.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct DLManagedTensor {
- /// \brief DLTensor which is being memory managed
- pub dl_tensor: DLTensor,
- /// \brief the context of the original host framework of DLManagedTensor in
- /// which DLManagedTensor is used in the framework. It can also be NULL.
- pub manager_ctx: *mut ::std::os::raw::c_void,
- /// \brief Destructor signature void (*)(void*) - this should be called
- /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
- /// if there is no way for the caller to provide a reasonable destructor.
- pub deleter: ::std::option::Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)>,
-}
-/// \brief type of array index.
-pub type tvm_index_t = i64;
-pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5;
-pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6;
-pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7;
-pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11;
-pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12;
-/// \brief Extension device types in TVM
-pub type TVMDeviceExtType = u32;
-pub const TVMTypeCode_kHandle: TVMTypeCode = 3;
-pub const TVMTypeCode_kNull: TVMTypeCode = 4;
-pub const TVMTypeCode_kTVMType: TVMTypeCode = 5;
-pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6;
-pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7;
-pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8;
-pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9;
-pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10;
-pub const TVMTypeCode_kStr: TVMTypeCode = 11;
-pub const TVMTypeCode_kBytes: TVMTypeCode = 12;
-pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13;
-pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15;
-pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16;
-pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20;
-pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64;
-pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128;
-/// \brief The type code in TVMType
-/// \note TVMType is used in two places.
-pub type TVMTypeCode = u32;
-/// \brief The data type used in TVM Runtime.
-///
-/// Examples
-/// - float: type_code = 2, bits = 32, lanes=1
-/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
-/// - int8: type_code = 0, bits = 8, lanes=1
-///
-/// \note Arguments TVM API function always takes bits=64 and lanes=1
-pub type TVMType = DLDataType;
-/// \brief The Device information, abstract away common device types.
-pub type TVMContext = DLContext;
-/// \brief The tensor array stucture to TVM API.
-pub type TVMArray = DLTensor;
-/// \brief the array handle
-pub type TVMArrayHandle = *mut TVMArray;
-/// \brief Union type of values
-/// being passed through API and function calls.
-#[repr(C)]
-#[derive(Copy, Clone)]
-pub union TVMValue {
- pub v_int64: i64,
- pub v_float64: f64,
- pub v_handle: *mut ::std::os::raw::c_void,
- pub v_str: *const ::std::os::raw::c_char,
- pub v_type: TVMType,
- pub v_ctx: TVMContext,
- _bindgen_union_align: u64,
-}
-/// \brief Byte array type used to pass in byte array
-/// When kBytes is used as data type.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct TVMByteArray {
- pub data: *const ::std::os::raw::c_char,
- pub size: usize,
-}
-/// \brief Handle to TVM runtime modules.
-pub type TVMModuleHandle = *mut ::std::os::raw::c_void;
-/// \brief Handle to packed function handle.
-pub type TVMFunctionHandle = *mut ::std::os::raw::c_void;
-/// \brief Handle to hold return value.
-pub type TVMRetValueHandle = *mut ::std::os::raw::c_void;
-/// \brief The stream that is specific to device
-/// can be NULL, which indicates the default one.
-pub type TVMStreamHandle = *mut ::std::os::raw::c_void;
-extern "C" {
- /// \brief Used for implementing C API function.
- /// Set last error message before return.
- /// \param msg The error message to be set.
- pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char);
-}
-extern "C" {
- /// \brief return str message of the last error
- /// all function in this file will return 0 when success
- /// and -1 when an error occured,
- /// TVMGetLastError can be called to retrieve the error
- ///
- /// this function is threadsafe and can be called by different thread
- /// \return error info
- pub fn TVMGetLastError() -> *const ::std::os::raw::c_char;
-}
-extern "C" {
- /// \brief Load module from file.
- /// \param file_name The file name to load the module from.
- /// \param format The format of the module.
- /// \param out The result module
- ///
- /// \return 0 when success, -1 when failure happens
- /// \note The resulting module do not contain import relation.
- /// It can be reconstructed by TVMModImport.
- pub fn TVMModLoadFromFile(
- file_name: *const ::std::os::raw::c_char,
- format: *const ::std::os::raw::c_char,
- out: *mut TVMModuleHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Add dep to mod's dependency.
- /// This allows functions in this module to use modules.
- ///
- /// \param mod The module handle.
- /// \param dep The dependent module to be imported.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Get function from the module.
- /// \param mod The module handle.
- /// \param func_name The name of the function.
- /// \param query_imports Whether to query imported modules
- /// \param out The result function, can be NULL if it is not available.
- /// \return 0 when no error is thrown, -1 when failure happens
- pub fn TVMModGetFunction(
- mod_: TVMModuleHandle,
- func_name: *const ::std::os::raw::c_char,
- query_imports: ::std::os::raw::c_int,
- out: *mut TVMFunctionHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Free front-end extension type resource.
- /// \param handle The extension handle.
- /// \param type_code The type of of the extension type.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMExtTypeFree(
- handle: *mut ::std::os::raw::c_void,
- type_code: ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Free the Module
- /// \param mod The module to be freed.
- ///
- /// \note This may not free up the module's resources.
- /// If there is active TVMFunctionHandle uses the module
- /// Or if this module is imported by another active module.
- ///
- /// The all functions remains valid until TVMFuncFree is called.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Free the function when it is no longer needed.
- /// \param func The function handle
- /// \return 0 when success, -1 when failure happens
- pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Call a Packed TVM Function.
- ///
- /// \param func node handle of the function.
- /// \param arg_values The arguments
- /// \param type_codes The type codes of the arguments
- /// \param num_args Number of arguments.
- ///
- /// \param ret_val The return value.
- /// \param ret_type_code the type code of return value.
- ///
- /// \return 0 when success, -1 when failure happens
- /// \note TVM calls always exchanges with type bits=64, lanes=1
- ///
- /// \note API calls always exchanges with type bits=64, lanes=1
- /// If API call returns container handles (e.g. FunctionHandle)
- /// these handles should be managed by the front-end.
- /// The front-end need to call free function (e.g. TVMFuncFree)
- /// to free these handles.
- pub fn TVMFuncCall(
- func: TVMFunctionHandle,
- arg_values: *mut TVMValue,
- type_codes: *mut ::std::os::raw::c_int,
- num_args: ::std::os::raw::c_int,
- ret_val: *mut TVMValue,
- ret_type_code: *mut ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Set the return value of TVMPackedCFunc.
- ///
- /// This function is called by TVMPackedCFunc to set the return value.
- /// When this function is not called, the function returns null by default.
- ///
- /// \param ret The return value handle, pass by ret in TVMPackedCFunc
- /// \param value The value to be returned.
- /// \param type_code The type of the value to be returned.
- /// \param num_ret Number of return values, for now only 1 is supported.
- pub fn TVMCFuncSetReturn(
- ret: TVMRetValueHandle,
- value: *mut TVMValue,
- type_code: *mut ::std::os::raw::c_int,
- num_ret: ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Inplace translate callback argument value to return value.
- /// This is only needed for non-POD arguments.
- ///
- /// \param value The value to be translated.
- /// \param code The type code to be translated.
- /// \note This function will do a shallow copy when necessary.
- ///
- /// \return 0 when success, -1 when failure happens.
- pub fn TVMCbArgToReturn(
- value: *mut TVMValue,
- code: ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-/// \brief C type of packed function.
-///
-/// \param args The arguments
-/// \param type_codes The type codes of the arguments
-/// \param num_args Number of arguments.
-/// \param ret The return value handle.
-/// \param resource_handle The handle additional resouce handle from fron-end.
-/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
-/// \sa TVMCFuncSetReturn
-pub type TVMPackedCFunc = ::std::option::Option<
- unsafe extern "C" fn(
- args: *mut TVMValue,
- type_codes: *mut ::std::os::raw::c_int,
- num_args: ::std::os::raw::c_int,
- ret: TVMRetValueHandle,
- resource_handle: *mut ::std::os::raw::c_void,
- ) -> ::std::os::raw::c_int,
->;
-/// \brief C callback to free the resource handle in C packed function.
-/// \param resource_handle The handle additional resouce handle from fron-end.
-pub type TVMPackedCFuncFinalizer =
- ::std::option::Option<unsafe extern "C" fn(resource_handle: *mut ::std::os::raw::c_void)>;
-/// \brief Signature for extension function declarer.
-///
-/// TVM call this function to get the extension functions
-/// The declarer will call register_func to register function and their name.
-///
-/// \param register_func_handle The register function
-/// \return 0 if success, -1 if failure happens
-pub type TVMExtensionFuncDeclarer = ::std::option::Option<
- unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int,
->;
-extern "C" {
- /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
- ///
- /// The resource_handle will be managed by TVM API, until the function is no longer used.
- ///
- /// \param func The packed C function.
- /// \param resource_handle The resource handle from front-end, can be NULL.
- /// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
- /// \param out the result function handle.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMFuncCreateFromCFunc(
- func: TVMPackedCFunc,
- resource_handle: *mut ::std::os::raw::c_void,
- fin: TVMPackedCFuncFinalizer,
- out: *mut TVMFunctionHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Register the function to runtime's global table.
- ///
- /// The registered function then can be pulled by the backend by the name.
- ///
- /// \param name The name of the function.
- /// \param f The function to be registered.
- /// \param override Whether allow override already registered function.
- pub fn TVMFuncRegisterGlobal(
- name: *const ::std::os::raw::c_char,
- f: TVMFunctionHandle,
- override_: ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Get a global function.
- ///
- /// \param name The name of the function.
- /// \param out the result function pointer, NULL if it does not exist.
- ///
- /// \note The function handle of global function is managed by TVM runtime,
- /// So TVMFuncFree is should not be called when it get deleted.
- pub fn TVMFuncGetGlobal(
- name: *const ::std::os::raw::c_char,
- out: *mut TVMFunctionHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief List all the globally registered function name
- /// \param out_size The number of functions
- /// \param out_array The array of function names.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMFuncListGlobalNames(
- out_size: *mut ::std::os::raw::c_int,
- out_array: *mut *mut *const ::std::os::raw::c_char,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Allocate a nd-array's memory,
- /// including space of shape, of given spec.
- ///
- /// \param shape The shape of the array, the data content will be copied to out
- /// \param ndim The number of dimension of the array.
- /// \param dtype_code The type code of the dtype
- /// \param dtype_bits The number of bits of dtype
- /// \param dtype_lanes The number of lanes in the dtype.
- /// \param device_type The device type of context
- /// \param device_id The device id of context.
- /// \param out The output handle.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayAlloc(
- shape: *const tvm_index_t,
- ndim: ::std::os::raw::c_int,
- dtype_code: ::std::os::raw::c_int,
- dtype_bits: ::std::os::raw::c_int,
- dtype_lanes: ::std::os::raw::c_int,
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- out: *mut TVMArrayHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Free the TVM Array.
- /// \param handle The array handle to be freed.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Copy array data from CPU byte array.
- /// \param handle The array handle.
- /// \param data the data pointer
- /// \param nbytes The number of bytes to copy.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayCopyFromBytes(
- handle: TVMArrayHandle,
- data: *mut ::std::os::raw::c_void,
- nbytes: usize,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Copy array data to CPU byte array.
- /// \param handle The array handle.
- /// \param data the data pointer
- /// \param nbytes The number of bytes to copy.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayCopyToBytes(
- handle: TVMArrayHandle,
- data: *mut ::std::os::raw::c_void,
- nbytes: usize,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Copy the array, both from and to must be valid during the copy.
- /// \param from The array to be copied from.
- /// \param to The target space.
- /// \param stream The stream where the copy happens, can be NULL.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayCopyFromTo(
- from: TVMArrayHandle,
- to: TVMArrayHandle,
- stream: TVMStreamHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Produce an array from the DLManagedTensor that shares data memory
- /// with the DLManagedTensor.
- /// \param from The source DLManagedTensor.
- /// \param out The output array handle.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayFromDLPack(
- from: *mut DLManagedTensor,
- out: *mut TVMArrayHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Produce a DLMangedTensor from the array that shares data memory with
- /// the array.
- /// \param from The source array.
- /// \param out The DLManagedTensor handle.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMArrayToDLPack(
- from: TVMArrayHandle,
- out: *mut *mut DLManagedTensor,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Delete (free) a DLManagedTensor's data.
- /// \param dltensor Pointer to the DLManagedTensor.
- pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor);
-}
-extern "C" {
- /// \brief Create a new runtime stream.
- ///
- /// \param device_type The device type of context
- /// \param device_id The device id of context
- /// \param out The new stream handle
- /// \return 0 when success, -1 when failure happens
- pub fn TVMStreamCreate(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- out: *mut TVMStreamHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Free a created stream handle.
- ///
- /// \param device_type The device type of context
- /// \param device_id The device id of context
- /// \param stream The stream to be freed
- /// \return 0 when success, -1 when failure happens
- pub fn TVMStreamFree(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- stream: TVMStreamHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Set the runtime stream of current thread to be stream.
- /// The subsequent calls to the same device_type
- /// will use the setted stream handle.
- /// The specific type of stream is runtime device dependent.
- ///
- /// \param device_type The device type of context
- /// \param device_id The device id of context.
- /// \param handle The stream handle.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMSetStream(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- handle: TVMStreamHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Wait until all computations on stream completes.
- ///
- /// \param device_type The device type of context
- /// \param device_id The device id of context.
- /// \param stream The stream to be synchronized.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMSynchronize(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- stream: TVMStreamHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Synchronize two streams of execution.
- ///
- /// \param device_type The device type of context
- /// \param device_id The device id of context
- /// \param src The source stream to synchronize.
- /// \param dst The destination stream to synchronize.
- /// \return 0 when success, -1 when failure happens
- pub fn TVMStreamStreamSynchronize(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- src: TVMStreamHandle,
- dst: TVMStreamHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Backend function for modules to get function
- /// from its environment mod_node (its imports and global function).
- /// The user do should not call TVMFuncFree on func.
- ///
- /// \param mod_node The module handle.
- /// \param func_name The name of the function.
- /// \param out The result function.
- /// \return 0 when no error is thrown, -1 when failure happens
- pub fn TVMBackendGetFuncFromEnv(
- mod_node: *mut ::std::os::raw::c_void,
- func_name: *const ::std::os::raw::c_char,
- out: *mut TVMFunctionHandle,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Backend function to register system-wide library symbol.
- ///
- /// \param name The name of the symbol
- /// \param ptr The symbol address.
- /// \return 0 when no error is thrown, -1 when failure happens
- pub fn TVMBackendRegisterSystemLibSymbol(
- name: *const ::std::os::raw::c_char,
- ptr: *mut ::std::os::raw::c_void,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Backend function to allocate temporal workspace.
- ///
- /// \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
- ///
- /// \param nbytes The size of the space requested.
- /// \param device_type The device type which the space will be allocated.
- /// \param device_id The device id which the space will be allocated.
- /// \param dtype_code_hint The type code of the array elements. Only used in
- /// certain backends such as OpenGL.
- /// \param dtype_bits_hint The type bits of the array elements. Only used in
- /// certain backends such as OpenGL.
- /// \return nullptr when error is thrown, a valid ptr if success
- pub fn TVMBackendAllocWorkspace(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- nbytes: u64,
- dtype_code_hint: ::std::os::raw::c_int,
- dtype_bits_hint: ::std::os::raw::c_int,
- ) -> *mut ::std::os::raw::c_void;
-}
-extern "C" {
- /// \brief Backend function to free temporal workspace.
- ///
- /// \param ptr The result allocated space pointer.
- /// \param device_type The device type which the space will be allocated.
- /// \param device_id The device id which the space will be allocated.
- /// \return 0 when no error is thrown, -1 when failure happens
- ///
- /// \sa TVMBackendAllocWorkspace
- pub fn TVMBackendFreeWorkspace(
- device_type: ::std::os::raw::c_int,
- device_id: ::std::os::raw::c_int,
- ptr: *mut ::std::os::raw::c_void,
- ) -> ::std::os::raw::c_int;
-}
-/// \brief Environment for TVM parallel task.
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct TVMParallelGroupEnv {
- /// \brief Auxiliary used for synchronization
- pub sync_handle: *mut ::std::os::raw::c_void,
- /// \brief total amount of task
- pub num_task: i32,
-}
-/// \brief The callback function to execute a parallel lambda
-/// \param task_id the task id of the function.
-/// \param penv The parallel environment backs the execution.
-/// \param cdata The supporting closure data.
-pub type FTVMParallelLambda = ::std::option::Option<
- unsafe extern "C" fn(
- task_id: ::std::os::raw::c_int,
- penv: *mut TVMParallelGroupEnv,
- cdata: *mut ::std::os::raw::c_void,
- ) -> ::std::os::raw::c_int,
->;
-extern "C" {
- /// \brief Backend function for running parallel jobs.
- ///
- /// \param flambda The parallel function to be launched.
- /// \param cdata The closure data.
- /// \param num_task Number of tasks to launch, can be 0, means launch
- /// with all available threads.
- ///
- /// \return 0 when no error is thrown, -1 when failure happens
- pub fn TVMBackendParallelLaunch(
- flambda: FTVMParallelLambda,
- cdata: *mut ::std::os::raw::c_void,
- num_task: ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief BSP barrrier between parallel threads
- /// \param task_id the task id of the function.
- /// \param penv The parallel environment backs the execution.
- /// \return 0 when no error is thrown, -1 when failure happens
- pub fn TVMBackendParallelBarrier(
- task_id: ::std::os::raw::c_int,
- penv: *mut TVMParallelGroupEnv,
- ) -> ::std::os::raw::c_int;
-}
-extern "C" {
- /// \brief Simple static initialization function.
- /// Run f once and set handle to be not null.
- /// This function is mainly used for test purpose.
- ///
- /// \param handle An global address to indicate f
- /// \param f The function to be ran
- /// \param cdata The closure data to pass to the function.
- /// \param nbytes Number of bytes in the closure data.
- /// \return 0 when no error is thrown, -1 when failure happens
- pub fn TVMBackendRunOnce(
- handle: *mut *mut ::std::os::raw::c_void,
- f: ::std::option::Option<
- unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int,
- >,
- cdata: *mut ::std::os::raw::c_void,
- nbytes: ::std::os::raw::c_int,
- ) -> ::std::os::raw::c_int;
-}
-//! Error types for `TVMArgValue` and `TVMRetValue` conversions.
+use std::fmt;
-error_chain! {
- errors {
- TryFromTVMArgValueError(expected: String, actual: String) {
- description("mismatched types while converting from TVMArgValue")
- display("expected `{}` but given `{}`", expected, actual)
+static TYPE_CODE_STRS: [&str; 15] = [
+ "int",
+ "uint",
+ "float",
+ "handle",
+ "null",
+ "TVMType",
+ "TVMContext",
+ "ArrayHandle",
+ "NodeHandle",
+ "ModuleHandle",
+ "FuncHandle",
+ "str",
+ "bytes",
+ "NDArrayContainer",
+ "ExtBegin",
+];
+
+#[derive(Debug, Fail)]
+pub struct ValueDowncastError {
+ actual_type_code: i64,
+ expected_type_code: i64,
+}
+
+impl ValueDowncastError {
+ pub fn new(actual_type_code: i64, expected_type_code: i64) -> Self {
+ Self {
+ actual_type_code,
+ expected_type_code,
}
+ }
+}
- TryFromTVMRetValueError(expected: String, actual: String) {
- description("mismatched types while downcasting TVMRetValue")
- display("invalid downcast: expected `{}` but given `{}`", expected, actual)
+impl fmt::Display for ValueDowncastError {
+ fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ formatter,
+ "Could not downcast TVMValue: expected `{}` but was {}",
+ TYPE_CODE_STRS[self.actual_type_code as usize],
+ TYPE_CODE_STRS[self.expected_type_code as usize]
+ )
+ }
+}
+
+#[derive(Debug, Fail)]
+#[fail(display = "Function call `{}` returned error: {}", context, message)]
+pub struct FuncCallError {
+ context: String,
+ message: String,
+}
+
+impl FuncCallError {
+ pub fn get_with_context(context: String) -> Self {
+ Self {
+ context,
+ message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) }
+ .to_str()
+ .expect("double fault")
+ .to_owned(),
}
}
}
+
+// error_chain! {
+// errors {
+// TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) {
+// description("mismatched types while downcasting TVMRetValue")
+// display("invalid downcast: expected `{}` but was `{}`",
+// expected_type, type_code_to_string(actual_type_code))
+// }
+// }
+// foreign_links {
+// IntoString(std::ffi::IntoStringError);
+// ParseInt(std::num::ParseIntError);
+// Utf8(std::str::Utf8Error);
+// }
+// }
//! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates.
-#![crate_name = "tvm_common"]
-#![recursion_limit = "1024"]
-#![allow(non_camel_case_types, unused_imports)]
-#![feature(box_syntax, try_from)]
+#![feature(box_syntax, trait_alias)]
#[macro_use]
-extern crate error_chain;
+extern crate failure;
/// Unified ffi module for both runtime and frontend crates.
pub mod ffi {
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]
- #[cfg(feature = "frontend")]
- pub extern crate tvm_sys as ts;
+ use std::os::raw::{c_char, c_int, c_void};
- #[cfg(feature = "runtime")]
- pub mod runtime {
- use std::os::raw::{c_char, c_int, c_void};
+ include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
- include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
-
- pub type BackendPackedCFunc = extern "C" fn(
- args: *const TVMValue,
- type_codes: *const c_int,
- num_args: c_int,
- ) -> c_int;
- }
+ pub type BackendPackedCFunc =
+ extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
}
+pub mod array;
pub mod errors;
-pub mod ty;
+#[macro_use]
+pub mod packed_func;
pub mod value;
pub use errors::*;
-pub use ty::TVMTypeCode;
-pub use value::{TVMArgValue, TVMRetValue, TVMValue};
+pub use ffi::{TVMContext, TVMType};
+pub use packed_func::{TVMArgValue, TVMRetValue};
--- /dev/null
+use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
+
+use failure::Error;
+
+pub use crate::ffi::TVMValue;
+use crate::ffi::*;
+
+pub trait PackedFunc =
+ Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync;
+
+/// Calls a packed function and returns a `TVMRetValue`.
+///
+/// # Example
+///
+/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
+#[macro_export]
+macro_rules! call_packed {
+ ($fn:expr, $($args:expr),+) => {
+ $fn(&[$($args.into(),)+])
+ };
+ ($fn:expr) => {
+ $fn(&Vec::new())
+ };
+}
+
+/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
+/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
+#[derive(Clone, Copy)]
+pub struct TVMArgValue<'a> {
+ pub _lifetime: PhantomData<&'a ()>,
+ pub value: TVMValue,
+ pub type_code: i64,
+}
+
+impl<'a> TVMArgValue<'a> {
+ pub fn new(value: TVMValue, type_code: i64) -> Self {
+ TVMArgValue {
+ _lifetime: PhantomData,
+ value: value,
+ type_code: type_code,
+ }
+ }
+}
+
+#[macro_export]
+macro_rules! ensure_type {
+ ($val:ident, $expected_type_code:expr) => {
+ ensure!(
+ $val.type_code == $expected_type_code as i64,
+ $crate::errors::ValueDowncastError::new(
+ $val.type_code as i64,
+ $expected_type_code as i64
+ )
+ );
+ };
+}
+
+/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
+macro_rules! impl_prim_tvm_arg {
+ ($type_code:ident, $field:ident, $field_type:ty, [ $( $type:ty ),+ ] ) => {
+ $(
+ impl From<$type> for TVMArgValue<'static> {
+ fn from(val: $type) -> Self {
+ TVMArgValue {
+ value: TVMValue { $field: val as $field_type },
+ type_code: $type_code as i64,
+ _lifetime: PhantomData,
+ }
+ }
+ }
+ impl<'a> From<&'a $type> for TVMArgValue<'a> {
+ fn from(val: &'a $type) -> Self {
+ TVMArgValue {
+ value: TVMValue {
+ $field: val.to_owned() as $field_type,
+ },
+ type_code: $type_code as i64,
+ _lifetime: PhantomData,
+ }
+ }
+ }
+ impl<'a> TryFrom<TVMArgValue<'a>> for $type {
+ type Error = Error;
+ fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> {
+ ensure_type!(val, $type_code);
+ Ok(unsafe { val.value.$field as $type })
+ }
+ }
+
+ impl<'a> TryFrom<&TVMArgValue<'a>> for $type {
+ type Error = Error;
+ fn try_from(val: &TVMArgValue<'a>) -> Result<Self, Self::Error> {
+ ensure_type!(val, $type_code);
+ Ok(unsafe { val.value.$field as $type })
+ }
+ }
+ )+
+ };
+}
+
+impl_prim_tvm_arg!(DLDataTypeCode_kDLFloat, v_float64, f64, [f32, f64]);
+impl_prim_tvm_arg!(
+ DLDataTypeCode_kDLInt,
+ v_int64,
+ i64,
+ [i8, i16, i32, i64, isize]
+);
+impl_prim_tvm_arg!(
+ DLDataTypeCode_kDLUInt,
+ v_int64,
+ i64,
+ [u8, u16, u32, u64, usize]
+);
+
+#[cfg(feature = "bindings")]
+// only allow this in bindings because pure-rust can't take ownership of leaked CString
+impl<'a> From<&String> for TVMArgValue<'a> {
+ fn from(string: &String) -> Self {
+ TVMArgValue {
+ value: TVMValue {
+ v_str: std::ffi::CString::new(string.clone()).unwrap().into_raw(),
+ },
+ type_code: TVMTypeCode_kStr as i64,
+ _lifetime: PhantomData,
+ }
+ }
+}
+
+impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> {
+ fn from(string: &std::ffi::CString) -> Self {
+ TVMArgValue {
+ value: TVMValue {
+ v_str: string.as_ptr(),
+ },
+ type_code: TVMTypeCode_kStr as i64,
+ _lifetime: PhantomData,
+ }
+ }
+}
+
+impl<'a> TryFrom<TVMArgValue<'a>> for &str {
+ type Error = Error;
+ fn try_from(arg: TVMArgValue<'a>) -> Result<Self, Self::Error> {
+ ensure_type!(arg, TVMTypeCode_kStr);
+ Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?)
+ }
+}
+
+impl<'a> TryFrom<&TVMArgValue<'a>> for &str {
+ type Error = Error;
+ fn try_from(arg: &TVMArgValue<'a>) -> Result<Self, Self::Error> {
+ ensure_type!(arg, TVMTypeCode_kStr);
+ Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?)
+ }
+}
+
+/// Creates a conversion to a `TVMArgValue` for an object handle.
+impl<'a, T> From<*const T> for TVMArgValue<'a> {
+ fn from(ptr: *const T) -> Self {
+ TVMArgValue {
+ value: TVMValue {
+ v_handle: ptr as *mut T as *mut c_void,
+ },
+ type_code: TVMTypeCode_kArrayHandle as i64,
+ _lifetime: PhantomData,
+ }
+ }
+}
+
+/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
+impl<'a, T> From<*mut T> for TVMArgValue<'a> {
+ fn from(ptr: *mut T) -> Self {
+ TVMArgValue {
+ value: TVMValue {
+ v_handle: ptr as *mut c_void,
+ },
+ type_code: TVMTypeCode_kHandle as i64,
+ _lifetime: PhantomData,
+ }
+ }
+}
+
+impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
+ fn from(arr: &'a mut DLTensor) -> Self {
+ TVMArgValue {
+ value: TVMValue {
+ v_handle: arr as *mut _ as *mut c_void,
+ },
+ type_code: TVMTypeCode_kArrayHandle as i64,
+ _lifetime: PhantomData,
+ }
+ }
+}
+
+impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
+ fn from(arr: &'a DLTensor) -> Self {
+ TVMArgValue {
+ value: TVMValue {
+ v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
+ },
+ type_code: TVMTypeCode_kArrayHandle as i64,
+ _lifetime: PhantomData,
+ }
+ }
+}
+
+impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMType {
+ type Error = Error;
+ fn try_from(arg: &'a TVMArgValue<'v>) -> Result<Self, Self::Error> {
+ ensure_type!(arg, TVMTypeCode_kTVMType);
+ Ok(unsafe { arg.value.v_type.into() })
+ }
+}
+
+/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
+/// Can be downcasted using `try_from` if it contains the desired type.
+///
+/// # Example
+///
+/// ```
+/// let a = 42u32;
+/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
+///
+/// let s = "hello, world!";
+/// let t: TVMRetValue = s.into();
+/// assert_eq!(String::try_from(t).unwrap(), s);
+/// ```
+pub struct TVMRetValue {
+ pub value: TVMValue,
+ pub box_value: Box<Any>,
+ pub type_code: i64,
+}
+
+impl TVMRetValue {
+ pub fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
+ Self {
+ value,
+ type_code,
+ box_value: box (),
+ }
+ }
+
+ pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
+ (self.value, self.type_code as TVMTypeCode)
+ }
+}
+
+impl Default for TVMRetValue {
+ fn default() -> Self {
+ TVMRetValue {
+ value: TVMValue { v_int64: 0 as i64 },
+ type_code: 0,
+ box_value: box (),
+ }
+ }
+}
+
+macro_rules! impl_pod_ret_value {
+ ($code:expr, [ $( $ty:ty ),+ ] ) => {
+ $(
+ impl From<$ty> for TVMRetValue {
+ fn from(val: $ty) -> Self {
+ Self {
+ value: val.into(),
+ type_code: $code as i64,
+ box_value: box (),
+ }
+ }
+ }
+
+ impl TryFrom<TVMRetValue> for $ty {
+ type Error = Error;
+ fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> {
+ ensure_type!(ret, $code);
+ Ok(ret.value.into())
+ }
+ }
+ )+
+ };
+}
+
+impl_pod_ret_value!(DLDataTypeCode_kDLInt, [i8, i16, i32, i64, isize]);
+impl_pod_ret_value!(DLDataTypeCode_kDLUInt, [u8, u16, u32, u64, usize]);
+impl_pod_ret_value!(DLDataTypeCode_kDLFloat, [f32, f64]);
+impl_pod_ret_value!(TVMTypeCode_kTVMType, [TVMType]);
+impl_pod_ret_value!(TVMTypeCode_kTVMContext, [TVMContext]);
+
+impl TryFrom<TVMRetValue> for String {
+ type Error = Error;
+ fn try_from(ret: TVMRetValue) -> Result<String, Self::Error> {
+ ensure_type!(ret, TVMTypeCode_kStr);
+ let cs = unsafe { std::ffi::CString::from_raw(ret.value.v_handle as *mut i8) };
+ let ret_str = cs.clone().into_string();
+ if cfg!(feature = "bindings") {
+ std::mem::forget(cs); // TVM C++ takes ownership of CString. (@see TVMFuncCall)
+ }
+ Ok(ret_str?)
+ }
+}
+
+impl From<String> for TVMRetValue {
+ fn from(s: String) -> Self {
+ let cs = std::ffi::CString::new(s).unwrap();
+ Self {
+ value: TVMValue {
+ v_str: cs.into_raw() as *mut i8,
+ },
+ box_value: box (),
+ type_code: TVMTypeCode_kStr as i64,
+ }
+ }
+}
-//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods.
-//!
-//! # Example
-//!
-//! ```
-//! let dtype = TVMType::from("float");
-//! println!("dtype is: {}", dtype);
-//! ```
-
-use std::{
- ffi::{CStr, CString},
- fmt::{self, Display, Formatter},
-};
-
-/// TVM type codes.
-#[repr(u32)]
-#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
-pub enum TVMTypeCode {
- kDLInt = 0,
- kDLUInt = 1,
- kDLFloat = 2,
- kHandle = 3,
- kNull = 4,
- kTVMType = 5,
- kTVMContext = 6,
- kArrayHandle = 7,
- kNodeHandle = 8,
- kModuleHandle = 9,
- kFuncHandle = 10,
- kStr = 11,
- kBytes = 12,
- kNDArrayContainer = 13,
-}
-
-impl Default for TVMTypeCode {
- fn default() -> Self {
- TVMTypeCode::kDLInt
- }
-}
-
-impl From<TVMTypeCode> for i64 {
- fn from(arg: TVMTypeCode) -> i64 {
- match arg {
- TVMTypeCode::kDLInt => 0,
- TVMTypeCode::kDLUInt => 1,
- TVMTypeCode::kDLFloat => 2,
- TVMTypeCode::kHandle => 3,
- TVMTypeCode::kNull => 4,
- TVMTypeCode::kTVMType => 5,
- TVMTypeCode::kTVMContext => 6,
- TVMTypeCode::kArrayHandle => 7,
- TVMTypeCode::kNodeHandle => 8,
- TVMTypeCode::kModuleHandle => 9,
- TVMTypeCode::kFuncHandle => 10,
- TVMTypeCode::kStr => 11,
- TVMTypeCode::kBytes => 12,
- TVMTypeCode::kNDArrayContainer => 13,
- }
- }
-}
-
-impl Into<TVMTypeCode> for i64 {
- fn into(self) -> TVMTypeCode {
- match self {
- 0 => TVMTypeCode::kDLInt,
- 1 => TVMTypeCode::kDLUInt,
- 2 => TVMTypeCode::kDLFloat,
- 3 => TVMTypeCode::kHandle,
- 4 => TVMTypeCode::kNull,
- 5 => TVMTypeCode::kTVMType,
- 6 => TVMTypeCode::kTVMContext,
- 7 => TVMTypeCode::kArrayHandle,
- 8 => TVMTypeCode::kNodeHandle,
- 9 => TVMTypeCode::kModuleHandle,
- 10 => TVMTypeCode::kFuncHandle,
- 11 => TVMTypeCode::kStr,
- 12 => TVMTypeCode::kBytes,
- 13 => TVMTypeCode::kNDArrayContainer,
- _ => unreachable!(),
- }
- }
-}
-
-impl Display for TVMTypeCode {
- fn fmt(&self, f: &mut Formatter) -> fmt::Result {
- write!(
- f,
- "{}",
- match self {
- TVMTypeCode::kDLInt => "int",
- TVMTypeCode::kDLUInt => "uint",
- TVMTypeCode::kDLFloat => "float",
- TVMTypeCode::kHandle => "handle",
- TVMTypeCode::kNull => "null",
- TVMTypeCode::kTVMType => "TVM type",
- TVMTypeCode::kTVMContext => "TVM context",
- TVMTypeCode::kArrayHandle => "Array handle",
- TVMTypeCode::kNodeHandle => "Node handle",
- TVMTypeCode::kModuleHandle => "Module handle",
- TVMTypeCode::kFuncHandle => "Function handle",
- TVMTypeCode::kStr => "string",
- TVMTypeCode::kBytes => "bytes",
- TVMTypeCode::kNDArrayContainer => "ndarray container",
- }
- )
- }
-}
-
-macro_rules! impl_prim_type {
- ($type:ty, $variant:ident) => {
- impl<'a> From<&'a $type> for TVMTypeCode {
- fn from(_arg: &$type) -> Self {
- TVMTypeCode::$variant
- }
- }
-
- impl<'a> From<&'a mut $type> for TVMTypeCode {
- fn from(_arg: &mut $type) -> Self {
- TVMTypeCode::$variant
- }
- }
- };
-}
-
-impl_prim_type!(usize, kDLInt);
-impl_prim_type!(i64, kDLInt);
-impl_prim_type!(i32, kDLInt);
-impl_prim_type!(i16, kDLInt);
-impl_prim_type!(i8, kDLInt);
-
-impl_prim_type!(u64, kDLUInt);
-impl_prim_type!(u32, kDLUInt);
-impl_prim_type!(u16, kDLUInt);
-impl_prim_type!(u8, kDLUInt);
-
-impl_prim_type!(f64, kDLFloat);
-impl_prim_type!(f32, kDLFloat);
-
-impl_prim_type!(str, kStr);
-impl_prim_type!(CStr, kStr);
-impl_prim_type!(String, kStr);
-impl_prim_type!(CString, kStr);
-
-impl_prim_type!([u8], kBytes);
-//! This module provides the the wrapped `TVMValue`, `TVMArgValue` and `TVMRetValue`
-//! required for using TVM functions.
+use std::str::FromStr;
-use std::{
- any::Any,
- convert::TryFrom,
- ffi::{CStr, CString},
- fmt::{self, Debug, Formatter},
- marker::PhantomData,
- mem,
- ops::Deref,
- os::raw::{c_char, c_void},
-};
+use failure::Error;
-#[cfg(feature = "runtime")]
-use ffi::runtime::TVMValue as _TVMValue;
+use crate::ffi::*;
-#[cfg(feature = "frontend")]
-use ffi::ts::TVMValue as _TVMValue;
-
-use errors::*;
-
-use ty::TVMTypeCode;
-
-/// Wrapped TVMValue type.
-#[derive(Clone, Copy)]
-pub struct TVMValue {
- pub inner: _TVMValue,
-}
-
-impl TVMValue {
- /// Creates TVMValue from the raw part.
- pub fn new(inner: _TVMValue) -> Self {
- TVMValue { inner }
- }
-
- pub(crate) fn into_raw(self) -> _TVMValue {
- self.inner
- }
-}
-
-impl Debug for TVMValue {
- fn fmt(&self, f: &mut Formatter) -> fmt::Result {
- unsafe {
- write!(
- f,
- "TVMValue: [v_int64: {:?}], [v_float64: {:?}], [v_handle: {:?}],\
- [v_str: {:?}]",
- self.inner.v_int64, self.inner.v_float64, self.inner.v_handle, self.inner.v_str
- )
+impl TVMType {
+ fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
+ Self {
+ code: type_code,
+ bits,
+ lanes,
}
}
}
-impl Deref for TVMValue {
- type Target = _TVMValue;
- fn deref(&self) -> &Self::Target {
- &self.inner
- }
-}
-
-macro_rules! impl_prim_val {
- ($type:ty, $field:ident, $cast:ty) => {
- impl From<$type> for TVMValue {
- fn from(arg: $type) -> Self {
- let inner = _TVMValue {
- $field: arg as $cast,
- };
- Self::new(inner)
- }
+/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
+/// such as "int32", "float32" or with lane "float32x1".
+impl FromStr for TVMType {
+ type Err = Error;
+ fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+ if type_str == "bool" {
+ return Ok(TVMType::new(1, 1, 1));
}
- impl<'a> From<&'a $type> for TVMValue {
- fn from(arg: &$type) -> Self {
- let inner = _TVMValue {
- $field: *arg as $cast,
- };
- Self::new(inner)
+ let mut type_lanes = type_str.split("x");
+ let typ = type_lanes.next().expect("Missing dtype");
+ let lanes = type_lanes
+ .next()
+ .map(|l| <u16>::from_str_radix(l, 10))
+ .unwrap_or(Ok(1))?;
+ let (type_name, bits) = match typ.find(char::is_numeric) {
+ Some(idx) => {
+ let (name, bits_str) = typ.split_at(idx);
+ (name, u8::from_str_radix(bits_str, 10)?)
}
- }
-
- impl<'a> From<&'a mut $type> for TVMValue {
- fn from(arg: &mut $type) -> Self {
- let inner = _TVMValue {
- $field: *arg as $cast,
- };
- Self::new(inner)
- }
- }
-
- impl TryFrom<TVMValue> for $type {
- type Error = Error;
- fn try_from(val: TVMValue) -> Result<Self> {
- Ok(unsafe { val.inner.$field as $type })
- }
- }
-
- impl<'a> TryFrom<&'a TVMValue> for $type {
- type Error = Error;
- fn try_from(val: &TVMValue) -> Result<Self> {
- Ok(unsafe { val.into_raw().$field as $type })
- }
- }
-
- impl<'a> TryFrom<&'a mut TVMValue> for $type {
- type Error = Error;
- fn try_from(val: &mut TVMValue) -> Result<Self> {
- Ok(unsafe { val.into_raw().$field as $type })
- }
- }
- };
-}
-
-impl_prim_val!(isize, v_int64, i64);
-impl_prim_val!(i64, v_int64, i64);
-impl_prim_val!(i32, v_int64, i64);
-impl_prim_val!(i16, v_int64, i64);
-impl_prim_val!(i8, v_int64, i64);
-impl_prim_val!(usize, v_int64, i64);
-impl_prim_val!(u64, v_int64, i64);
-impl_prim_val!(u32, v_int64, i64);
-impl_prim_val!(u16, v_int64, i64);
-impl_prim_val!(u8, v_int64, i64);
-
-impl_prim_val!(f64, v_float64, f64);
-impl_prim_val!(f32, v_float64, f64);
-
-impl<'a> From<&'a str> for TVMValue {
- fn from(arg: &str) -> TVMValue {
- let arg = CString::new(arg).unwrap();
- let inner = _TVMValue {
- v_str: arg.as_ptr() as *const c_char,
+ None => (typ, 32),
};
- mem::forget(arg);
- Self::new(inner)
- }
-}
-
-impl<'a> From<&'a String> for TVMValue {
- fn from(arg: &String) -> TVMValue {
- let arg = CString::new(arg.as_bytes()).unwrap();
- let inner = _TVMValue {
- v_str: arg.as_ptr() as *const c_char,
- };
- mem::forget(arg);
- Self::new(inner)
- }
-}
-
-impl<'a> From<&'a CString> for TVMValue {
- fn from(arg: &CString) -> TVMValue {
- let arg = arg.to_owned();
- let inner = _TVMValue {
- v_str: arg.as_ptr() as *const c_char,
- };
- mem::forget(arg);
- Self::new(inner)
- }
-}
-impl<'a> From<&'a [u8]> for TVMValue {
- fn from(arg: &[u8]) -> TVMValue {
- let arg = arg.to_owned();
- let inner = _TVMValue {
- v_handle: &arg as *const _ as *mut c_void,
+ let type_code = match type_name {
+ "int" => 0,
+ "uint" => 1,
+ "float" => 2,
+ "handle" => 3,
+ _ => return Err(format_err!("Unknown type {}", type_name)),
};
- mem::forget(arg);
- Self::new(inner)
- }
-}
-
-/// Captures both `TVMValue` and `TVMTypeCode` needed for TVM function.
-/// The preferred way to obtain a `TVMArgValue` is automatically via `call_packed!`.
-/// or in the frontend crate, with `function::Builder`. Checkout the methods for conversions.
-///
-/// ## Example
-///
-/// ```
-/// let s = "hello".to_string();
-/// let arg = TVMArgValue::from(&s);
-/// let tvm: String = arg.try_into().unwrap();
-/// assert_eq!(arg, s);
-/// ```
-#[derive(Debug, Clone, Copy)]
-pub struct TVMArgValue<'a> {
- /// The wrapped TVMValue
- pub value: TVMValue,
- /// The matching type code.
- pub type_code: TVMTypeCode,
- /// This is only exposed to runtime and frontend crates and is not meant to be used directly.
- pub lifetime: PhantomData<&'a ()>,
-}
-
-impl<'a> TVMArgValue<'a> {
- pub fn new(value: TVMValue, type_code: TVMTypeCode) -> Self {
- TVMArgValue {
- value: value,
- type_code: type_code,
- lifetime: PhantomData,
- }
- }
-}
-
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for i64 {
- type Error = Error;
- fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
- if (arg.type_code == TVMTypeCode::kDLInt)
- | (arg.type_code == TVMTypeCode::kDLUInt)
- | (arg.type_code == TVMTypeCode::kNull)
- {
- Ok(unsafe { arg.value.inner.v_int64 })
- } else {
- bail!(ErrorKind::TryFromTVMArgValueError(
- stringify!(i64).to_string(),
- arg.type_code.to_string()
- ))
- }
- }
-}
-
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for f64 {
- type Error = Error;
- fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
- if arg.type_code == TVMTypeCode::kDLFloat {
- Ok(unsafe { arg.value.inner.v_float64 })
- } else {
- bail!(ErrorKind::TryFromTVMArgValueError(
- stringify!(f64).to_string(),
- arg.type_code.to_string()
- ))
- }
- }
-}
-
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for String {
- type Error = Error;
- fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
- if arg.type_code == TVMTypeCode::kStr {
- let ret_str = unsafe {
- match CStr::from_ptr(arg.value.inner.v_str).to_str() {
- Ok(s) => s,
- Err(_) => "Invalid UTF-8 message",
- }
- };
- Ok(ret_str.to_string())
- } else {
- bail!(ErrorKind::TryFromTVMArgValueError(
- stringify!(String).to_string(),
- arg.type_code.to_string()
- ))
- }
- }
-}
-
-/// Main way to create a TVMArgValue from suported Rust values.
-impl<'b, 'a: 'b, T: 'b + ?Sized> From<&'b T> for TVMArgValue<'a>
-where
- TVMValue: From<&'b T>,
- TVMTypeCode: From<&'b T>,
-{
- fn from(arg: &'b T) -> Self {
- TVMArgValue::new(TVMValue::from(arg), TVMTypeCode::from(arg))
- }
-}
-
-/// Creates a conversion to a `TVMArgValue` for an object handle.
-impl<'a, T> From<*const T> for TVMArgValue<'a> {
- fn from(ptr: *const T) -> Self {
- let value = TVMValue::new(_TVMValue {
- v_handle: ptr as *mut T as *mut c_void,
- });
- TVMArgValue::new(value, TVMTypeCode::kArrayHandle)
+ Ok(TVMType::new(type_code, bits, lanes))
}
}
-/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
-impl<'a, T> From<*mut T> for TVMArgValue<'a> {
- fn from(ptr: *mut T) -> Self {
- let value = TVMValue::new(_TVMValue {
- v_handle: ptr as *mut c_void,
- });
-
- TVMArgValue::new(value, TVMTypeCode::kHandle)
- }
-}
-
-/// An owned version of TVMPODValue. It can be converted from varieties of
-/// primitive and object types.
-/// It can be downcasted using `try_from` if it contains the desired type.
-///
-/// # Example
-///
-/// ```
-/// let a = 42u32;
-/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
-///
-/// let s = "hello, world!";
-/// let t: TVMRetValue = s.into();
-/// assert_eq!(String::try_from(t).unwrap(), s);
-/// ```
-pub struct TVMRetValue {
- /// A primitive return value, if any.
- pub prim_value: usize,
- /// An object return value, if any.
- pub box_value: Box<Any>,
- pub type_code: TVMTypeCode,
-}
-
-impl TVMRetValue {
- fn new(prim_value: usize, box_value: Box<Any>, type_code: TVMTypeCode) -> Self {
- Self {
- prim_value,
- box_value,
- type_code,
+impl std::fmt::Display for TVMType {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ if self.bits == 1 && self.lanes == 1 {
+ return write!(f, "bool");
}
- }
-
- /// unsafe function to create `TVMRetValue` from `TVMValue` and
- /// its matching `TVMTypeCode`.
- pub unsafe fn from_tvm_value(value: TVMValue, type_code: TVMTypeCode) -> Self {
- let value = value.into_raw();
- match type_code {
- TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => {
- Self::new(value.v_int64 as usize, box (), type_code)
- }
- TVMTypeCode::kDLFloat => Self::new(value.v_float64 as usize, box (), type_code),
- TVMTypeCode::kHandle
- | TVMTypeCode::kArrayHandle
- | TVMTypeCode::kNodeHandle
- | TVMTypeCode::kModuleHandle
- | TVMTypeCode::kFuncHandle => {
- Self::new(value.v_handle as usize, box value.v_handle, type_code)
- }
- TVMTypeCode::kStr | TVMTypeCode::kBytes => {
- Self::new(value.v_str as usize, box (value.v_str), type_code)
- }
- _ => Self::new(0usize, box (), type_code),
+ let mut type_str = match self.code {
+ 0 => "int",
+ 1 => "uint",
+ 2 => "float",
+ 4 => "handle",
+ _ => "unknown",
}
- }
+ .to_string();
- /// Returns the underlying `TVMValue` and `TVMTypeCode`.
- pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
- let val = match self.type_code {
- TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => TVMValue::new(_TVMValue {
- v_int64: self.prim_value as i64,
- }),
- TVMTypeCode::kDLFloat => TVMValue::new(_TVMValue {
- v_float64: self.prim_value as f64,
- }),
- TVMTypeCode::kHandle
- | TVMTypeCode::kArrayHandle
- | TVMTypeCode::kNodeHandle
- | TVMTypeCode::kModuleHandle
- | TVMTypeCode::kFuncHandle
- | TVMTypeCode::kNDArrayContainer => TVMValue::new(_TVMValue {
- v_handle: self.prim_value as *const c_void as *mut c_void,
- }),
- TVMTypeCode::kStr | TVMTypeCode::kBytes => TVMValue::new(_TVMValue {
- v_str: self.prim_value as *const c_char,
- }),
- _ => unreachable!(),
- };
- (val, self.type_code)
- }
-}
-
-impl Default for TVMRetValue {
- fn default() -> Self {
- TVMRetValue {
- prim_value: 0usize,
- box_value: box (),
- type_code: TVMTypeCode::default(),
- }
- }
-}
-
-impl Clone for TVMRetValue {
- fn clone(&self) -> Self {
- match self.type_code {
- TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt | TVMTypeCode::kDLFloat => {
- Self::new(self.prim_value.clone(), box (), self.type_code.clone())
- }
- TVMTypeCode::kHandle
- | TVMTypeCode::kArrayHandle
- | TVMTypeCode::kNodeHandle
- | TVMTypeCode::kModuleHandle
- | TVMTypeCode::kFuncHandle
- | TVMTypeCode::kNDArrayContainer => Self::new(
- self.prim_value.clone(),
- box (self.prim_value.clone() as *const c_void as *mut c_void),
- self.type_code.clone(),
- ),
- TVMTypeCode::kStr | TVMTypeCode::kBytes => Self::new(
- self.prim_value.clone(),
- box (self.prim_value.clone() as *const c_char),
- self.type_code.clone(),
- ),
- _ => unreachable!(),
+ type_str += &self.bits.to_string();
+ if self.lanes > 1 {
+ type_str += &format!("x{}", self.lanes);
}
+ f.write_str(&type_str)
}
}
-impl Debug for TVMRetValue {
- fn fmt(&self, f: &mut Formatter) -> fmt::Result {
- write!(
- f,
- "prim_value: {:?}, box_value: {:?}, type_code: {:?}",
- self.prim_value, self.prim_value as *const c_void as *mut c_void, self.type_code
- )
- }
-}
-
-macro_rules! impl_prim_ret_value {
- ($type:ty, $code:expr) => {
- impl From<$type> for TVMRetValue {
- fn from(val: $type) -> Self {
- TVMRetValue {
- prim_value: val as usize,
- box_value: box (),
- type_code: $code,
- }
- }
- }
-
- impl<'a> From<&'a $type> for TVMRetValue {
- fn from(val: &$type) -> Self {
- TVMRetValue {
- prim_value: *val as usize,
- box_value: box (),
- type_code: $code,
+macro_rules! impl_pod_tvm_value {
+ ($field:ident, $field_ty:ty, $( $ty:ty ),+) => {
+ $(
+ impl From<$ty> for TVMValue {
+ fn from(val: $ty) -> Self {
+ TVMValue { $field: val as $field_ty }
}
}
- }
- impl<'a> From<&'a mut $type> for TVMRetValue {
- fn from(val: &mut $type) -> Self {
- TVMRetValue {
- prim_value: *val as usize,
- box_value: box (),
- type_code: $code,
+ impl From<TVMValue> for $ty {
+ fn from(val: TVMValue) -> Self {
+ unsafe { val.$field as $ty }
}
}
- }
-
- impl TryFrom<TVMRetValue> for $type {
- type Error = Error;
- fn try_from(ret: TVMRetValue) -> Result<$type> {
- if ret.type_code == $code {
- Ok(ret.prim_value as $type)
- } else {
- bail!(ErrorKind::TryFromTVMRetValueError(
- stringify!($type).to_string(),
- ret.type_code.to_string(),
- ))
- }
- }
- }
+ )+
};
+ ($field:ident, $ty:ty) => {
+ impl_pod_tvm_value!($field, $ty, $ty);
+ }
}
-impl_prim_ret_value!(i8, TVMTypeCode::kDLInt);
-impl_prim_ret_value!(i16, TVMTypeCode::kDLInt);
-impl_prim_ret_value!(i32, TVMTypeCode::kDLInt);
-impl_prim_ret_value!(i64, TVMTypeCode::kDLInt);
-impl_prim_ret_value!(isize, TVMTypeCode::kDLInt);
-
-impl_prim_ret_value!(u8, TVMTypeCode::kDLUInt);
-impl_prim_ret_value!(u16, TVMTypeCode::kDLUInt);
-impl_prim_ret_value!(u32, TVMTypeCode::kDLUInt);
-impl_prim_ret_value!(u64, TVMTypeCode::kDLUInt);
-impl_prim_ret_value!(usize, TVMTypeCode::kDLUInt);
-
-impl_prim_ret_value!(f32, TVMTypeCode::kDLFloat);
-impl_prim_ret_value!(f64, TVMTypeCode::kDLFloat);
+impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize);
+impl_pod_tvm_value!(v_float64, f64, f32, f64);
+impl_pod_tvm_value!(v_type, TVMType);
+impl_pod_tvm_value!(v_ctx, TVMContext);
-macro_rules! impl_ptr_ret_value {
- ($type:ty) => {
- impl From<$type> for TVMRetValue {
- fn from(ptr: $type) -> Self {
- TVMRetValue {
- prim_value: ptr as usize,
- box_value: box (),
- type_code: TVMTypeCode::kHandle,
- }
- }
- }
-
- impl TryFrom<TVMRetValue> for $type {
- type Error = Error;
- fn try_from(ret: TVMRetValue) -> Result<$type> {
- if ret.type_code == TVMTypeCode::kHandle {
- Ok(ret.prim_value as $type)
- } else {
- bail!(ErrorKind::TryFromTVMRetValueError(
- stringify!($type).to_string(),
- ret.type_code.to_string(),
- ))
- }
+macro_rules! impl_tvm_context {
+ ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => {
+ /// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev")
+ impl FromStr for TVMContext {
+ type Err = Error;
+ fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+ Ok(Self {
+ device_type: match type_str {
+ $( $( stringify!($dev_name) )|+ => $dev_type ),+,
+ _ => return Err(format_err!("device {} not supported", type_str).into()),
+ },
+ device_id: 0,
+ })
}
}
- };
-}
-
-impl_ptr_ret_value!(*const c_void);
-impl_ptr_ret_value!(*mut c_void);
-impl From<String> for TVMRetValue {
- fn from(val: String) -> Self {
- let pval = val.as_ptr() as *const c_char as usize;
- let bval = box (val.as_ptr() as *const c_char);
- mem::forget(val);
- TVMRetValue::new(pval, bval, TVMTypeCode::kStr)
- }
-}
-
-impl TryFrom<TVMRetValue> for String {
- type Error = Error;
- fn try_from(ret: TVMRetValue) -> Result<String> {
- // Note: simple downcast doesn't work for function call return values
- let ret_str = unsafe {
- match CStr::from_ptr(ret.prim_value as *const c_char).to_str() {
- Ok(s) => s,
- Err(_) => "Invalid UTF-8 message",
- }
- };
-
- Ok(ret_str.to_string())
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use std::convert::TryInto;
-
- #[test]
- fn numeric() {
- macro_rules! arg_ret_tests {
- ($v:expr; $($ty:ty),+) => {{
+ impl TVMContext {
+ $(
$(
- let v = $v as $ty;
- let b = TVMRetValue::from(&v);
- let b: $ty = b.try_into().unwrap();
- assert_eq!(b, v);
+ pub fn $dev_name(device_id: usize) -> Self {
+ Self {
+ device_type: $dev_type,
+ device_id: device_id as i32,
+ }
+ }
)+
- }};
+ )+
}
-
- arg_ret_tests!(42; i8, i16, i32, i64, f32, f64);
- }
-
- #[test]
- fn string() {
- let s = "hello".to_string();
- let tvm_arg: String = TVMRetValue::from(s.clone()).try_into().unwrap();
- assert_eq!(tvm_arg, s);
- }
+ };
}
+
+impl_tvm_context!(
+ DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
+ DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+ DLDeviceType_kDLOpenCL: [cl],
+ DLDeviceType_kDLMetal: [metal],
+ DLDeviceType_kDLVPI: [vpi],
+ DLDeviceType_kDLROCM: [rocm],
+ DLDeviceType_kDLExtDev: [ext_dev]
+);
+++ /dev/null
-[package]
-name = "tvm-sys"
-version = "0.1.0"
-authors = ["TVM Contributors"]
-license = "Apache-2.0"
-description = "Raw C API"
-
-[build-dependencies]
-bindgen = "0.37.4"
+++ /dev/null
-extern crate bindgen;
-
-use std::path::PathBuf;
-
-fn main() {
- println!("cargo:rerun-if-env-changed=TVM_HOME");
- println!("cargo:rustc-link-lib=dylib=tvm_runtime");
- println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
- let bindings = bindgen::Builder::default()
- .header(format!(
- "{}/include/tvm/runtime/c_runtime_api.h",
- env!("TVM_HOME")
- ))
- .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
- .blacklist_type("max_align_t") // @see rust-bindgen#550
- .layout_tests(false)
- .derive_partialeq(true)
- .derive_eq(true)
- .generate()
- .expect("unable to generate bindings");
-
- bindings
- .write_to_file(PathBuf::from("src/bindgen.rs"))
- .expect("can not write the bindings!");
-}
+++ /dev/null
-#![allow(
- non_camel_case_types,
- non_snake_case,
- non_upper_case_globals,
- dead_code,
- improper_ctypes
-)]
-
-include!("bindgen.rs");
crate-type = ["dylib"]
[dependencies]
-error-chain = "0.12.0"
+failure = "0.1.5"
lazy_static = "1.1.0"
ndarray = "0.12.1"
num-traits = "0.2"
-tvm-common = { version = "0.1.0", path = "../common/", features = ["frontend"] }
+tvm-common = { version = "0.1.0", path = "../common/", features = ["bindings"] }
[features]
blas = ["ndarray/blas"]
-#![feature(try_from)]
-
extern crate csv;
extern crate image;
extern crate ndarray;
convert::TryInto,
fs::{self, File},
path::Path,
+ str::FromStr,
};
use image::{FilterType, GenericImageView};
// make arr shape as [1, 3, 224, 224] acceptable to resnet
let arr = arr.insert_axis(Axis(0));
// create input tensor from rust's ndarray
- let input =
- NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
+ let input = NDArray::from_rust_ndarray(
+ &arr,
+ TVMContext::cpu(0),
+ TVMType::from_str("float32").unwrap(),
+ )
+ .unwrap();
println!(
"input size is {:?}",
input.shape().expect("cannot get the input shape")
)))
.unwrap();
// get the global TVM graph runtime function
- let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
+ let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap();
let runtime_create_fn_ret = call_packed!(
runtime_create_fn,
&graph,
.get_function("set_input", false)
.unwrap();
- call_packed!(set_input_fn, "data", &input).unwrap();
+ let data_str = "data".to_string();
+ call_packed!(set_input_fn, &data_str, &input).unwrap();
// get `run` function from runtime module
let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
// execute the run function. Note that it has no argument
call_packed!(run_fn,).unwrap();
// prepare to get the output
let output_shape = &mut [1, 1000];
- let output = NDArray::empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
+ let output = NDArray::empty(
+ output_shape,
+ TVMContext::cpu(0),
+ TVMType::from_str("float32").unwrap(),
+ );
// get the `get_output` function from runtime module
let ref get_output_fn = graph_runtime_module
.get_function("get_output", false)
//!
//! For more detail, please see the example `resnet` in `examples` repository.
-use std::os::raw::c_char;
+use std::os::raw::{c_char, c_void};
-use crate::ts;
+use tvm_common::{ffi, TVMArgValue};
/// A struct holding TVM byte-array.
///
/// ```
#[derive(Debug, Clone)]
pub struct TVMByteArray {
- pub(crate) inner: ts::TVMByteArray,
+ pub(crate) inner: ffi::TVMByteArray,
}
impl TVMByteArray {
- pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray {
+ pub(crate) fn new(barr: ffi::TVMByteArray) -> TVMByteArray {
TVMByteArray { inner: barr }
}
impl<'a> From<&'a Vec<u8>> for TVMByteArray {
fn from(arg: &Vec<u8>) -> Self {
- let barr = ts::TVMByteArray {
+ let barr = ffi::TVMByteArray {
data: arg.as_ptr() as *const c_char,
size: arg.len(),
};
}
}
+impl<'a> From<&TVMByteArray> for TVMArgValue<'a> {
+ fn from(arr: &TVMByteArray) -> Self {
+ Self {
+ value: ffi::TVMValue {
+ v_handle: &arr.inner as *const ffi::TVMByteArray as *const c_void as *mut c_void,
+ },
+ type_code: ffi::TVMTypeCode_kBytes as i64,
+ _lifetime: std::marker::PhantomData,
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
//! ```
use std::{
+ convert::TryInto,
fmt::{self, Display, Formatter},
os::raw::c_void,
ptr,
};
-use crate::{function, ts, Result};
+use failure::Error;
+
+use tvm_common::{
+ ffi::{self, TVMValue},
+ TVMArgValue,
+};
+
+use crate::function;
/// Device type can be from a supported device name. See the supported devices
/// in [TVM](https://github.com/dmlc/tvm).
}
}
-impl From<TVMDeviceType> for ts::DLDeviceType {
+impl From<TVMDeviceType> for ffi::DLDeviceType {
fn from(device_type: TVMDeviceType) -> Self {
match device_type.0 {
- 1 => ts::DLDeviceType_kDLCPU,
- 2 => ts::DLDeviceType_kDLGPU,
- 3 => ts::DLDeviceType_kDLCPUPinned,
- 4 => ts::DLDeviceType_kDLOpenCL,
- 7 => ts::DLDeviceType_kDLVulkan,
- 8 => ts::DLDeviceType_kDLMetal,
- 9 => ts::DLDeviceType_kDLVPI,
- 10 => ts::DLDeviceType_kDLROCM,
- 12 => ts::DLDeviceType_kDLExtDev,
+ 1 => ffi::DLDeviceType_kDLCPU,
+ 2 => ffi::DLDeviceType_kDLGPU,
+ 3 => ffi::DLDeviceType_kDLCPUPinned,
+ 4 => ffi::DLDeviceType_kDLOpenCL,
+ 7 => ffi::DLDeviceType_kDLVulkan,
+ 8 => ffi::DLDeviceType_kDLMetal,
+ 9 => ffi::DLDeviceType_kDLVPI,
+ 10 => ffi::DLDeviceType_kDLROCM,
+ 12 => ffi::DLDeviceType_kDLExtDev,
_ => panic!("device type not found!"),
}
}
}
-impl From<ts::DLDeviceType> for TVMDeviceType {
- fn from(device_type: ts::DLDeviceType) -> Self {
+impl From<ffi::DLDeviceType> for TVMDeviceType {
+ fn from(device_type: ffi::DLDeviceType) -> Self {
match device_type {
- ts::DLDeviceType_kDLCPU => TVMDeviceType(1),
- ts::DLDeviceType_kDLGPU => TVMDeviceType(2),
- ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
- ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
- ts::DLDeviceType_kDLVulkan => TVMDeviceType(7),
- ts::DLDeviceType_kDLMetal => TVMDeviceType(8),
- ts::DLDeviceType_kDLVPI => TVMDeviceType(9),
- ts::DLDeviceType_kDLROCM => TVMDeviceType(10),
- ts::DLDeviceType_kDLExtDev => TVMDeviceType(12),
+ ffi::DLDeviceType_kDLCPU => TVMDeviceType(1),
+ ffi::DLDeviceType_kDLGPU => TVMDeviceType(2),
+ ffi::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
+ ffi::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
+ ffi::DLDeviceType_kDLVulkan => TVMDeviceType(7),
+ ffi::DLDeviceType_kDLMetal => TVMDeviceType(8),
+ ffi::DLDeviceType_kDLVPI => TVMDeviceType(9),
+ ffi::DLDeviceType_kDLROCM => TVMDeviceType(10),
+ ffi::DLDeviceType_kDLExtDev => TVMDeviceType(12),
_ => panic!("device type not found!"),
}
}
}
}
+impl<'a> From<&'a TVMDeviceType> for TVMArgValue<'a> {
+ fn from(dev_type: &'a TVMDeviceType) -> Self {
+ Self {
+ value: TVMValue {
+ v_int64: dev_type.0 as i64,
+ },
+ type_code: ffi::DLDataTypeCode_kDLInt as i64,
+ _lifetime: std::marker::PhantomData,
+ }
+ }
+}
+
/// Represents the underlying device context. Default is cpu.
///
/// ## Examples
/// Supported device types
pub device_type: TVMDeviceType,
/// Device id
- pub device_id: usize,
+ pub device_id: i32,
}
impl TVMContext {
/// Creates context from device type and id.
- pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self {
+ pub fn new(device_type: TVMDeviceType, device_id: i32) -> Self {
TVMContext {
- device_type: device_type,
- device_id: device_id,
+ device_type,
+ device_id,
}
}
}
($(($ctx:ident, $dldevt:expr));+) => {
$(
impl TVMContext {
- pub fn $ctx(device_id: usize) -> Self {
+ pub fn $ctx(device_id: i32) -> Self {
Self::new(TVMDeviceType($dldevt), device_id)
}
}
impl TVMContext {
/// Checks whether the context exists or not.
pub fn exist(&self) -> bool {
- let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
- .expect("API function always exists");
+ let func = function::Function::get("_GetDeviceAttr").expect("API function always exists");
let dt = self.device_type.0 as usize;
// `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!`
- let ret = call_packed!(func, &dt, &self.device_id, &0)
+ let ret: u64 = call_packed!(func, &dt, &self.device_id, &0)
.unwrap()
- .prim_value;
+ .try_into()
+ .unwrap();
ret != 0
}
/// Synchronize the context stream.
- pub fn sync(&self) -> Result<()> {
- check_call!(ts::TVMSynchronize(
+ pub fn sync(&self) -> Result<(), Error> {
+ check_call!(ffi::TVMSynchronize(
self.device_type.0 as i32,
self.device_id as i32,
ptr::null_mut() as *mut c_void
$(
impl TVMContext {
pub fn $attr_name(&self) -> usize {
- let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
+ let func = function::Function::get("_GetDeviceAttr")
.expect("API function always exists");
let dt = self.device_type.0 as usize;
// `unwrap` is ok here because if there is any error,
// if would occur in function call.
- let ret = function::Builder::from(func)
- .args(&[dt, self.device_id, $attr_kind])
+ function::Builder::from(func)
+ .args(&[dt, self.device_id as usize, $attr_kind])
.invoke()
- .unwrap();
- ret.prim_value as usize
+ .unwrap()
+ .try_into()
+ .unwrap()
}
}
)+
(multi_processor_count, 7);
(max_thread_dimensions, 8));
-impl From<ts::DLContext> for TVMContext {
- fn from(ctx: ts::DLContext) -> Self {
+impl From<ffi::DLContext> for TVMContext {
+ fn from(ctx: ffi::DLContext) -> Self {
TVMContext {
device_type: TVMDeviceType::from(ctx.device_type),
- device_id: ctx.device_id as usize,
+ device_id: ctx.device_id,
}
}
}
-impl From<TVMContext> for ts::DLContext {
+impl From<TVMContext> for ffi::DLContext {
fn from(ctx: TVMContext) -> Self {
- ts::DLContext {
+ ffi::DLContext {
device_type: ctx.device_type.into(),
device_id: ctx.device_id as i32,
}
-//! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types.
+pub use failure::Error;
-use std::{ffi, option};
+#[derive(Debug, Fail)]
+#[fail(display = "Cannot convert from an empty array.")]
+pub struct EmptyArrayError;
-use crate::{common_errors, rust_ndarray};
-
-error_chain! {
- errors {
- EmptyArray {
- description("cannot convert from an empty array")
- }
-
- NullHandle(name: String) {
- description("null handle")
- display("requested `{}` handle is null", name)
- }
-
- FunctionNotFound {
- description("function not found")
- display("function was not set in `function::Builder`")
- }
-
- TypeMismatch(expected: String, found: String) {
- description("type mismatch!")
- display("expected type `{}`, but found `{}`", expected, found)
- }
-
- MissingShapeError {
- description("ndarray `shape()` returns `None`")
- display("called `Option::unwrap()` on a `None` value")
- }
-
- AtMostOneReturn {
- description("TVM functions accept at most one return value")
- }
+#[derive(Debug, Fail)]
+#[fail(display = "Handle `{}` is null.", name)]
+pub struct NullHandleError {
+ pub name: String,
+}
- }
+#[derive(Debug, Fail)]
+#[fail(display = "Function was not set in `function::Builder`")]
+pub struct FunctionNotFoundError;
- foreign_links {
- ShapeError(rust_ndarray::ShapeError);
- NulError(ffi::NulError);
- IntoStringError(ffi::IntoStringError);
- CommonError(common_errors::Error);
- }
+#[derive(Debug, Fail)]
+#[fail(display = "Expected type `{}` but found `{}`", expected, actual)]
+pub struct TypeMismatchError {
+ pub expected: String,
+ pub actual: String,
}
-impl From<option::NoneError> for Error {
- fn from(_err: option::NoneError) -> Self {
- ErrorKind::MissingShapeError.into()
- }
-}
+#[derive(Debug, Fail)]
+#[fail(display = "Missing NDArray shape.")]
+pub struct MissingShapeError;
sync::Mutex,
};
-use crate::{ts, ErrorKind, Module, Result, TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue};
+use failure::Error;
+
+use crate::{
+ errors,
+ ffi::{self, TVMValue},
+ Module, TVMArgValue, TVMRetValue,
+};
lazy_static! {
static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = {
let mut out_size = 0 as c_int;
let name = ptr::null_mut() as *mut c_char;
let mut out_array = name as *mut _;
- check_call!(ts::TVMFuncListGlobalNames(
+ check_call!(ffi::TVMFuncListGlobalNames(
&mut out_size as *mut _,
&mut out_array
));
}
/// Wrapper around TVM function handle which includes `is_global`
-/// indicating whether the function is global or not, `is_released`
-/// to hint dropping the function handle and `is_cloned` showing
+/// indicating whether the function is global or not, and `is_cloned` showing
/// not to drop a cloned function from Rust side.
/// The value of these fields can be accessed through their respective methods.
#[derive(Debug, Hash)]
pub struct Function {
- pub(crate) handle: ts::TVMFunctionHandle,
+ pub(crate) handle: ffi::TVMFunctionHandle,
// whether the registered function is global or not.
is_global: bool,
- // whether the function has been dropped from frontend or not.
- is_released: bool,
// whether the function has been cloned from frontend or not.
is_cloned: bool,
}
unsafe impl Sync for Function {}
impl Function {
- pub(crate) fn new(handle: ts::TVMFunctionHandle, is_global: bool, is_released: bool) -> Self {
+ pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
Function {
handle: handle,
- is_global: is_global,
- is_released: is_released,
+ is_global: false,
is_cloned: false,
}
}
/// For a given function, it returns a function by name.
- pub fn get<S: AsRef<str>>(name: S, is_global: bool) -> Option<&'static Function> {
+ pub fn get<S: AsRef<str>>(name: S) -> Option<&'static Function> {
let mut globals = GLOBAL_FUNCTIONS.lock().unwrap();
globals.get_mut(name.as_ref()).and_then(|maybe_func| {
if maybe_func.is_none() {
let name = CString::new(name.as_ref()).unwrap();
- let mut handle = ptr::null_mut() as ts::TVMFunctionHandle;
- check_call!(ts::TVMFuncGetGlobal(
+ let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle;
+ check_call!(ffi::TVMFuncGetGlobal(
name.as_ptr() as *const c_char,
&mut handle as *mut _
));
- maybe_func.replace(Function::new(
- handle, is_global, false, /* is_released */
- ));
+ maybe_func.replace(Function {
+ handle: handle,
+ is_global: true,
+ is_cloned: false,
+ });
}
unsafe {
std::mem::transmute::<Option<&Function>, Option<&'static Function>>(
}
/// Returns the underlying TVM function handle.
- pub fn handle(&self) -> ts::TVMFunctionHandle {
+ pub fn handle(&self) -> ffi::TVMFunctionHandle {
self.handle
}
self.is_global
}
- /// Returns `true` if the underlying TVM function has been released
- /// from the frontend and `false` otherwise.
- pub fn is_released(&self) -> bool {
- self.is_released
- }
-
/// Returns `true` if the underlying TVM function has been cloned
/// from the frontend and `false` otherwise.
pub fn is_cloned(&self) -> bool {
impl Clone for Function {
fn clone(&self) -> Function {
- if !self.is_released && !self.is_cloned {
- Self {
- handle: self.handle,
- is_global: self.is_global,
- is_released: self.is_released,
- is_cloned: true,
- }
- } else {
- Function::new(self.handle, self.is_global, self.is_released)
+ Self {
+ handle: self.handle,
+ is_global: self.is_global,
+ is_cloned: true,
}
}
}
impl Drop for Function {
fn drop(&mut self) {
- if !self.is_released && !self.is_global && !self.is_cloned {
- check_call!(ts::TVMFuncFree(self.handle));
- self.is_released = true;
+ if !self.is_global && !self.is_cloned {
+ check_call!(ffi::TVMFuncFree(self.handle));
}
}
}
/// Function builder in order to create and call functions.
///
/// *Note:* Currently TVM functions accept *at most* one return value.
-#[derive(Debug, Clone, Default)]
+#[derive(Default)]
pub struct Builder<'a, 'm> {
pub func: Option<&'m Function>,
- pub arg_buf: Option<Box<[TVMArgValue<'a>]>>,
+ pub arg_buf: Vec<TVMArgValue<'a>>,
pub ret_buf: Option<TVMRetValue>,
}
impl<'a, 'm> Builder<'a, 'm> {
pub fn new(
func: Option<&'m Function>,
- arg_buf: Option<Box<[TVMArgValue<'a>]>>,
+ arg_buf: Vec<TVMArgValue<'a>>,
ret_buf: Option<TVMRetValue>,
) -> Self {
Self {
}
}
- pub fn get_function(&mut self, name: &'m str, is_global: bool) -> &mut Self {
- self.func = Function::get(name, is_global);
+ pub fn get_function(&mut self, name: &'m str) -> &mut Self {
+ self.func = Function::get(name);
self
}
/// Pushes a [`TVMArgValue`] into the function argument buffer.
- pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self
+ pub fn arg<T: 'a>(&mut self, arg: &'a T) -> &mut Self
where
- TVMValue: From<&'b T>,
- TVMTypeCode: From<&'b T>,
+ TVMArgValue<'a>: From<&'a T>,
{
- let tvm_arg = TVMArgValue::from(arg);
- if self.arg_buf.is_none() {
- self.arg_buf = Some(Box::new([tvm_arg]));
- } else {
- let new_arg_buf = self.arg_buf.take().map(|bbuf| {
- let mut new_arg_buf = Vec::from(bbuf);
- new_arg_buf.push(tvm_arg);
- let new_len = new_arg_buf.len();
- new_arg_buf.truncate(new_len);
- new_arg_buf.into_boxed_slice()
- });
- self.arg_buf = new_arg_buf;
- }
+ self.arg_buf.push(arg.into());
self
}
/// Pushes multiple [`TVMArgValue`]s into the function argument buffer.
- pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self
+ pub fn args<T: 'a, I>(&mut self, args: I) -> &mut Self
where
- I: IntoIterator<Item = &'b T>,
- TVMValue: From<&'b T>,
- TVMTypeCode: From<&'b T>,
+ I: IntoIterator<Item = &'a T>,
+ TVMArgValue<'a>: From<&'a T>,
{
- for arg in args {
+ args.into_iter().for_each(|arg| {
self.arg(&arg);
- }
+ });
self
}
/// Sets an output for a function that requirs a mutable output to be provided.
/// See the `basics` in tests for an example.
- pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> Result<&mut Self>
+ pub fn set_output<T>(&mut self, ret: T) -> &mut Self
where
- TVMValue: From<&'b T>,
- TVMTypeCode: From<&'b T>,
+ TVMRetValue: From<T>,
{
- if self.ret_buf.is_none() {
- let tvm_ret =
- unsafe { TVMRetValue::from_tvm_value(TVMValue::from(arg), TVMTypeCode::from(arg)) };
- self.ret_buf = Some(tvm_ret);
- } else {
- bail!(ErrorKind::AtMostOneReturn)
- }
- Ok(self)
+ self.ret_buf = Some(ret.into());
+ self
}
/// Calls the function that created from `Builder`.
- pub fn invoke(&mut self) -> Result<TVMRetValue> {
- self.clone()(())
- }
-}
-
-impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> {
- type Output = Result<TVMRetValue>;
- extern "rust-call" fn call_once(self, _: ((),)) -> Self::Output {
- if self.func.is_none() {
- bail!("{}", ErrorKind::FunctionNotFound);
- }
-
- let mut ret_val = unsafe { mem::uninitialized::<ts::TVMValue>() };
- let mut ret_type_code = 0 as c_int;
- if self.arg_buf.is_some() {
- let arg_buf = self.arg_buf?;
- let mut num_args = arg_buf.len();
- let mut values = arg_buf
- .iter()
- .map(|tav| tav.value.inner)
- .collect::<Vec<ts::TVMValue>>();
- let mut tcodes = arg_buf
- .iter()
- .map(|tav| tav.type_code as c_int)
- .collect::<Vec<_>>();
-
- if self.ret_buf.is_some() {
- num_args = num_args + 1;
- let ret_buf = self.ret_buf?;
- let (ret_val, ret_type_code) = TVMRetValue::into_tvm_value(ret_buf);
- values.append(&mut vec![ret_val.inner]);
- tcodes.append(&mut vec![ret_type_code as c_int]);
- }
-
- values.truncate(num_args);
- tcodes.truncate(num_args);
- check_call!(ts::TVMFuncCall(
- self.func?.handle,
- values.as_mut_ptr(),
- tcodes.as_mut_ptr(),
- num_args as c_int,
- &mut ret_val as *mut _,
- &mut ret_type_code as *mut _
- ));
- } else {
- check_call!(ts::TVMFuncCall(
- self.func?.handle,
- ptr::null_mut(),
- ptr::null_mut(),
- 0 as c_int,
- &mut ret_val as *mut _,
- &mut ret_type_code as *mut _
- ));
- }
+ pub fn invoke(&mut self) -> Result<TVMRetValue, Error> {
+ #![allow(unused_unsafe)]
+ ensure!(self.func.is_some(), errors::FunctionNotFoundError);
+
+ let num_args = self.arg_buf.len();
+ let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) = self
+ .arg_buf
+ .iter()
+ .map(|tvm_arg| (tvm_arg.value, tvm_arg.type_code as ffi::TVMTypeCode))
+ .unzip();
+
+ let mut ret_val = unsafe { std::mem::uninitialized::<TVMValue>() };
+ let mut ret_type_code = 0;
+ check_call!(ffi::TVMFuncCall(
+ self.func.ok_or(errors::FunctionNotFoundError)?.handle,
+ values.as_mut_ptr(),
+ type_codes.as_mut_ptr() as *mut i32,
+ num_args as c_int,
+ &mut ret_val as *mut _,
+ &mut ret_type_code as *mut _
+ ));
- let ret = unsafe {
- TVMRetValue::from_tvm_value(TVMValue::new(ret_val), (ret_type_code as i64).into())
- };
- Ok(ret)
+ Ok(unsafe { TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64) })
}
}
/// TVM functions.
impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> {
fn from(func: &'m Function) -> Self {
- Builder::new(Some(func), None, None)
+ Builder::new(Some(func), Vec::new(), None)
}
}
/// Converts a mutable reference of a [`Module`] to [`Builder`].
impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> {
fn from(module: &'m mut Module) -> Self {
- Builder::new(module.entry(), None, None)
+ Builder::new(module.entry(), Vec::new(), None)
}
}
unsafe extern "C" fn tvm_callback(
- args: *mut ts::TVMValue,
+ args: *mut ffi::TVMValue,
type_codes: *mut c_int,
num_args: c_int,
- ret: ts::TVMRetValueHandle,
+ ret: ffi::TVMRetValueHandle,
fhandle: *mut c_void,
) -> c_int {
// turning off the incorrect linter complaints
- #![allow(unused_assignments)]
+ #![allow(unused_assignments, unused_unsafe)]
let len = num_args as usize;
let args_list = slice::from_raw_parts_mut(args, len);
let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
let mut local_args: Vec<TVMArgValue> = Vec::new();
- let mut value = mem::uninitialized::<ts::TVMValue>();
+ let mut value = mem::uninitialized::<ffi::TVMValue>();
let mut tcode = mem::uninitialized::<c_int>();
- let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
+ let rust_fn =
+ mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
for i in 0..len {
value = args_list[i];
tcode = type_codes_list[i];
- if tcode == ts::TVMTypeCode_kNodeHandle as c_int
- || tcode == ts::TVMTypeCode_kFuncHandle as c_int
- || tcode == ts::TVMTypeCode_kModuleHandle as c_int
+ if tcode == ffi::TVMTypeCode_kNodeHandle as c_int
+ || tcode == ffi::TVMTypeCode_kFuncHandle as c_int
+ || tcode == ffi::TVMTypeCode_kModuleHandle as c_int
{
- check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode));
+ check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode));
}
- local_args.push(TVMArgValue::new(
- TVMValue::new(value),
- (tcode as i64).into(),
- ));
+ local_args.push(TVMArgValue::new(value.into(), (tcode as i64).into()));
}
let rv = match rust_fn(local_args.as_slice()) {
}
};
- let (ret_val, ret_tcode) = TVMRetValue::into_tvm_value(rv);
- let mut ret_val = ret_val.inner;
+ let (mut ret_val, ret_tcode) = rv.into_tvm_value();
let mut ret_type_code = ret_tcode as c_int;
- check_call!(ts::TVMCFuncSetReturn(
+ check_call!(ffi::TVMCFuncSetReturn(
ret,
&mut ret_val as *mut _,
&mut ret_type_code as *mut _,
}
unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) {
- let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue>>(fhandle);
+ let rust_fn =
+ mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>>(fhandle);
mem::drop(rust_fn);
}
-fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue>) -> Function {
- let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
- let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue>;
- check_call!(ts::TVMFuncCreateFromCFunc(
+fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> Function {
+ let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
+ let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>;
+ check_call!(ffi::TVMFuncCreateFromCFunc(
Some(tvm_callback),
resource_handle as *mut c_void,
Some(tvm_callback_finalizer),
&mut fhandle as *mut _
));
- Function::new(fhandle, false, false)
+ Function::new(fhandle)
}
/// Registers a Rust function with signature
-/// `fn(&[TVMArgValue]) -> Result<TVMRetValue>`
+/// `fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>`
/// as a **global TVM packed function** from frontend to TVM backend.
///
/// Use [`register_global_func`] if overriding an existing global TVM function
/// ```
/// use std::convert::TryInto;
///
-/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
/// let mut ret = 0i64;
/// for arg in args.iter() {
/// let arg: i64 = arg.try_into()?;
/// assert_eq!(ret, 60);
/// ```
pub fn register<S: AsRef<str>>(
- f: fn(&[TVMArgValue]) -> Result<TVMRetValue>,
+ f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>,
name: S,
override_: bool,
-) -> Result<()> {
+) -> Result<(), Error> {
let func = convert_to_tvm_func(f);
let name = CString::new(name.as_ref())?;
- check_call!(ts::TVMFuncRegisterGlobal(
- name.as_ref().as_ptr() as *const c_char,
+ check_call!(ffi::TVMFuncRegisterGlobal(
+ name.into_raw(),
func.handle(),
override_ as c_int
));
- mem::forget(name);
Ok(())
}
/// use std::convert::TryInto;
///
/// register_global_func! {
-/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
/// let mut ret = 0f64;
/// for arg in args.iter() {
/// let arg: f64 = arg.try_into()?;
macro_rules! register_global_func {
{
$(#[$m:meta])*
- fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue> {
+ fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result<TVMRetValue, Error> {
$($code:tt)*
}
} => {{
$(#[$m])*
- fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ fn $fn_name($args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
$($code)*
}
#[test]
fn get_fn() {
- assert!(Function::get(CANARY, true).is_some());
- assert!(Function::get("does not exists!", false).is_none());
+ assert!(Function::get(CANARY).is_some());
+ assert!(Function::get("does not exists!").is_none());
}
#[test]
fn provide_args() {
+ let str_arg = CString::new("test").unwrap();
let mut func = Builder::default();
- func.get_function("tvm.graph_runtime.remote_create", true)
+ func.get_function("tvm.graph_runtime.remote_create")
.args(&[10, 20])
- .arg(&"test".to_owned());
- assert!(func.arg_buf.is_some());
- assert_eq!(func.arg_buf.take().map(|bv| Vec::from(bv).len()), Some(3));
+ .arg(&str_arg);
+ assert_eq!(func.arg_buf.len(), 3);
}
}
//!
//! Checkout the `examples` repository for more details.
-#![crate_name = "tvm_frontend"]
-#![recursion_limit = "1024"]
-#![allow(non_camel_case_types, unused_unsafe)]
-#![feature(
- try_from,
- try_trait,
- fn_traits,
- unboxed_closures,
- box_syntax,
- option_replace
-)]
+#![feature(box_syntax)]
#[macro_use]
-extern crate error_chain;
-extern crate tvm_common as common;
+extern crate failure;
#[macro_use]
extern crate lazy_static;
extern crate ndarray as rust_ndarray;
extern crate num_traits;
+extern crate tvm_common;
use std::{
ffi::{CStr, CString},
str,
};
-use crate::common::ffi::ts;
+use failure::Error;
+
+pub use crate::{
+ bytearray::TVMByteArray,
+ context::{TVMContext, TVMDeviceType},
+ errors::*,
+ function::Function,
+ module::Module,
+ ndarray::NDArray,
+ tvm_common::{
+ errors as common_errors,
+ ffi::{self, TVMType},
+ packed_func::{TVMArgValue, TVMRetValue},
+ },
+};
// Macro to check the return call to TVM runtime shared library.
macro_rules! check_call {
/// Gets the last error message.
pub fn get_last_error() -> &'static str {
unsafe {
- match CStr::from_ptr(ts::TVMGetLastError()).to_str() {
+ match CStr::from_ptr(ffi::TVMGetLastError()).to_str() {
Ok(s) => s,
Err(_) => "Invalid UTF-8 message",
}
pub(crate) fn set_last_error(err: &Error) {
let c_string = CString::new(err.to_string()).unwrap();
unsafe {
- ts::TVMAPISetLastError(c_string.as_ptr());
+ ffi::TVMAPISetLastError(c_string.as_ptr());
}
}
pub mod errors;
pub mod module;
pub mod ndarray;
-pub mod ty;
pub mod value;
-pub use crate::{
- bytearray::TVMByteArray,
- common::{
- errors as common_errors,
- ty::TVMTypeCode,
- value::{TVMArgValue, TVMRetValue, TVMValue},
- },
- context::{TVMContext, TVMDeviceType},
- errors::*,
- function::Function,
- module::Module,
- ndarray::NDArray,
- ty::TVMType,
-};
-
/// Outputs the current TVM version.
pub fn version() -> &'static str {
- match str::from_utf8(ts::TVM_VERSION) {
+ match str::from_utf8(ffi::TVM_VERSION) {
Ok(s) => s,
Err(_) => "Invalid UTF-8 string",
}
#[test]
fn set_error() {
- let err = ErrorKind::EmptyArray;
+ let err = errors::EmptyArrayError;
set_last_error(&err.into());
- assert_eq!(get_last_error().trim(), ErrorKind::EmptyArray.to_string());
+ assert_eq!(get_last_error().trim(), errors::EmptyArrayError.to_string());
}
}
ptr,
};
-use crate::ts;
+use failure::Error;
+use tvm_common::ffi;
-use crate::{function::Function, ErrorKind, Result};
+use crate::{errors, function::Function};
const ENTRY_FUNC: &'static str = "__tvm_main__";
/// Wrapper around TVM module handle which contains an entry function.
/// The entry function can be applied to an imported module through [`entry_func`].
-/// Also [`is_released`] shows whether the module is dropped or not.
///
/// [`entry_func`]:struct.Module.html#method.entry_func
-/// [`is_released`]:struct.Module.html#method.is_released
#[derive(Debug, Clone)]
pub struct Module {
- pub(crate) handle: ts::TVMModuleHandle,
- is_released: bool,
+ pub(crate) handle: ffi::TVMModuleHandle,
entry_func: Option<Function>,
}
impl Module {
- pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self {
+ pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self {
Self {
handle,
- is_released,
entry_func: None,
}
}
}
/// Gets a function by name from a registered module.
- pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function> {
+ pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function, Error> {
let name = CString::new(name)?;
- let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
- check_call!(ts::TVMModGetFunction(
+ let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
+ check_call!(ffi::TVMModGetFunction(
self.handle,
name.as_ptr() as *const c_char,
query_import as c_int,
&mut fhandle as *mut _
));
- if fhandle.is_null() {
- bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?)))
- } else {
- Ok(Function::new(fhandle, false, false))
- }
+ ensure!(
+ !fhandle.is_null(),
+ errors::NullHandleError {
+ name: format!("{}", name.into_string()?)
+ }
+ );
+ Ok(Function::new(fhandle))
}
/// Imports a dependent module such as `.ptx` for gpu.
pub fn import_module(&self, dependent_module: Module) {
- check_call!(ts::TVMModImport(self.handle, dependent_module.handle))
+ check_call!(ffi::TVMModImport(self.handle, dependent_module.handle))
}
/// Loads a module shared library from path.
- pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module> {
- let ext = path.as_ref().extension()?.to_str()?;
- let func = Function::get("module._LoadFromFile", true /* is_global */)
- .expect("API function always exists");
- let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?;
+ pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module, Error> {
+ let ext = CString::new(
+ path.as_ref()
+ .extension()
+ .unwrap_or(std::ffi::OsStr::new(""))
+ .to_str()
+ .ok_or_else(|| {
+ format_err!("Bad module load path: `{}`.", path.as_ref().display())
+ })?,
+ )?;
+ let func = Function::get("module._LoadFromFile").expect("API function always exists");
+ let cpath =
+ CString::new(path.as_ref().to_str().ok_or_else(|| {
+ format_err!("Bad module load path: `{}`.", path.as_ref().display())
+ })?)?;
+ let ret: Module = call_packed!(func, &cpath, &ext)?.try_into()?;
Ok(ret)
}
/// Checks if a target device is enabled for a module.
pub fn enabled(&self, target: &str) -> bool {
- let func = Function::get("module._Enabled", true /* is_global */)
- .expect("API function always exists");
+ let func = Function::get("module._Enabled").expect("API function always exists");
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
- let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap();
+ let tgt = CString::new(target).unwrap();
+ let ret: i64 = call_packed!(func, &tgt).unwrap().try_into().unwrap();
ret != 0
}
/// Returns the underlying module handle.
- pub fn handle(&self) -> ts::TVMModuleHandle {
+ pub fn handle(&self) -> ffi::TVMModuleHandle {
self.handle
}
-
- /// Returns true if the underlying module has been dropped and false otherwise.
- pub fn is_released(&self) -> bool {
- self.is_released
- }
}
impl Drop for Module {
fn drop(&mut self) {
- if !self.is_released {
- check_call!(ts::TVMModFree(self.handle));
- self.is_released = true;
- }
+ check_call!(ffi::TVMModFree(self.handle));
}
}
//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
-use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice};
+use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
-use crate::rust_ndarray::{Array, ArrayD};
+use failure::Error;
use num_traits::Num;
+use rust_ndarray::{Array, ArrayD};
+use tvm_common::{ffi, TVMType};
-use crate::ts;
-
-use crate::{Error, ErrorKind, Result, TVMByteArray, TVMContext, TVMType};
+use crate::{errors, TVMByteArray, TVMContext};
/// See the [`module-level documentation`](../ndarray/index.html) for more details.
///
/// Wrapper around TVM array handle.
#[derive(Debug)]
pub struct NDArray {
- pub(crate) handle: ts::TVMArrayHandle,
+ pub(crate) handle: ffi::TVMArrayHandle,
is_view: bool,
}
impl NDArray {
- pub(crate) fn new(handle: ts::TVMArrayHandle, is_view: bool) -> Self {
+ pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
NDArray {
handle: handle,
- is_view: is_view,
+ is_view: true,
}
}
/// Returns the underlying array handle.
- pub fn handle(&self) -> ts::TVMArrayHandle {
+ pub fn handle(&self) -> ffi::TVMArrayHandle {
self.handle
}
}
/// Shows whether the underlying ndarray is contiguous in memory or not.
- pub fn is_contiguous(&self) -> Result<bool> {
+ pub fn is_contiguous(&self) -> Result<bool, Error> {
Ok(match self.strides() {
None => true,
Some(strides) => {
- // MissingShapeError in case shape is not determined
- self.shape()?
+ // errors::MissingShapeError in case shape is not determined
+ self.shape()
+ .ok_or(errors::MissingShapeError)?
.iter()
.zip(strides)
.rfold(
/// assert_eq!(ndarray.shape(), Some(shape));
/// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
/// ```
- pub fn to_vec<T>(&self) -> Result<Vec<T>> {
- if self.shape().is_none() {
- bail!("{}", ErrorKind::EmptyArray);
- }
- let earr = NDArray::empty(self.shape()?, TVMContext::cpu(0), self.dtype());
+ pub fn to_vec<T>(&self) -> Result<Vec<T>, Error> {
+ ensure!(self.shape().is_some(), errors::EmptyArrayError);
+ let earr = NDArray::empty(
+ self.shape().ok_or(errors::MissingShapeError)?,
+ TVMContext::cpu(0),
+ self.dtype(),
+ );
let target = self.copy_to_ndarray(earr)?;
let arr = unsafe { *(target.handle) };
- let sz = self.size()? as usize;
+ let sz = self.size().ok_or(errors::MissingShapeError)?;
let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
unsafe {
v.as_mut_ptr()
}
/// Converts the NDArray to [`TVMByteArray`].
- pub fn to_bytearray(&self) -> Result<TVMByteArray> {
+ pub fn to_bytearray(&self) -> Result<TVMByteArray, Error> {
let v = self.to_vec::<u8>()?;
Ok(TVMByteArray::from(&v))
}
/// *Note*: if something goes wrong during the copy, it will panic
/// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
- check_call!(ts::TVMArrayCopyFromBytes(
+ check_call!(ffi::TVMArrayCopyFromBytes(
self.handle,
data.as_ptr() as *mut _,
data.len() * mem::size_of::<T>()
}
/// Copies the NDArray to another target NDArray.
- pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray> {
+ pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, Error> {
if self.dtype() != target.dtype() {
bail!(
"{}",
- ErrorKind::TypeMismatch(
- format!("{}", self.dtype().to_string()),
- format!("{}", target.dtype().to_string()),
- )
+ errors::TypeMismatchError {
+ expected: format!("{}", self.dtype().to_string()),
+ actual: format!("{}", target.dtype().to_string()),
+ }
);
}
- check_call!(ts::TVMArrayCopyFromTo(
+ check_call!(ffi::TVMArrayCopyFromTo(
self.handle,
target.handle,
- ptr::null_mut() as ts::TVMStreamHandle
+ ptr::null_mut() as ffi::TVMStreamHandle
));
Ok(target)
}
/// Copies the NDArray to a target context.
- pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray> {
- let tmp = NDArray::empty(self.shape()?, target.clone(), self.dtype());
+ pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray, Error> {
+ let tmp = NDArray::empty(
+ self.shape().ok_or(errors::MissingShapeError)?,
+ target.clone(),
+ self.dtype(),
+ );
let copy = self.copy_to_ndarray(tmp)?;
Ok(copy)
}
rnd: &ArrayD<T>,
ctx: TVMContext,
dtype: TVMType,
- ) -> Result<Self> {
+ ) -> Result<Self, Error> {
let mut shape = rnd.shape().to_vec();
let mut nd = NDArray::empty(&mut shape, ctx, dtype);
let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T));
- nd.copy_from_buffer(buf.as_slice_mut()?);
+ nd.copy_from_buffer(
+ buf.as_slice_mut()
+ .expect("Array from iter must be contiguous."),
+ );
Ok(nd)
}
/// Allocates and creates an empty NDArray given the shape, context and dtype.
pub fn empty(shape: &[usize], ctx: TVMContext, dtype: TVMType) -> NDArray {
- let mut handle = ptr::null_mut() as ts::TVMArrayHandle;
- check_call!(ts::TVMArrayAlloc(
+ let mut handle = ptr::null_mut() as ffi::TVMArrayHandle;
+ check_call!(ffi::TVMArrayAlloc(
shape.as_ptr() as *const i64,
shape.len() as c_int,
- dtype.inner.code as c_int,
- dtype.inner.bits as c_int,
- dtype.inner.lanes as c_int,
+ dtype.code as c_int,
+ dtype.bits as c_int,
+ dtype.lanes as c_int,
ctx.device_type.0 as c_int,
ctx.device_id as c_int,
&mut handle as *mut _,
));
- NDArray::new(handle, false)
+ NDArray {
+ handle,
+ is_view: false,
+ }
}
}
($type:ty, $type_name:tt) => {
impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> {
type Error = Error;
- fn try_from(nd: &NDArray) -> Result<ArrayD<$type>> {
- if nd.shape().is_none() {
- bail!("{}", ErrorKind::EmptyArray);
- }
- assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch");
- Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?)
+ fn try_from(nd: &NDArray) -> Result<ArrayD<$type>, Self::Error> {
+ ensure!(nd.shape().is_some(), errors::MissingShapeError);
+ assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch");
+ Ok(Array::from_shape_vec(
+ &*nd.shape().ok_or(errors::MissingShapeError)?,
+ nd.to_vec::<$type>()?,
+ )?)
}
}
impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> {
type Error = Error;
- fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>> {
- if nd.shape().is_none() {
- bail!("{}", ErrorKind::EmptyArray);
- }
- assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch");
- Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?)
+ fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>, Self::Error> {
+ ensure!(nd.shape().is_some(), errors::MissingShapeError);
+ assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch");
+ Ok(Array::from_shape_vec(
+ &*nd.shape().ok_or(errors::MissingShapeError)?,
+ nd.to_vec::<$type>()?,
+ )?)
}
}
};
impl Drop for NDArray {
fn drop(&mut self) {
if !self.is_view {
- check_call!(ts::TVMArrayFree(self.handle));
+ check_call!(ffi::TVMArrayFree(self.handle));
}
}
}
fn basics() {
let shape = &mut [1, 2, 3];
let ctx = TVMContext::cpu(0);
- let ndarray = NDArray::empty(shape, ctx, TVMType::from("int32"));
+ let ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap());
assert_eq!(ndarray.shape().unwrap(), shape);
assert_eq!(
ndarray.size().unwrap(),
let shape = &mut [4];
let mut data = vec![1i32, 2, 3, 4];
let ctx = TVMContext::cpu(0);
- let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("int32"));
+ let mut ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap());
assert!(ndarray.to_vec::<i32>().is_ok());
ndarray.copy_from_buffer(&mut data);
assert_eq!(ndarray.shape().unwrap(), shape);
assert!(ndarray.is_contiguous().is_ok());
assert_eq!(ndarray.byte_offset(), 0);
let mut shape = vec![4];
- let e = NDArray::empty(&mut shape, TVMContext::cpu(0), TVMType::from("int32"));
+ let e = NDArray::empty(
+ &mut shape,
+ TVMContext::cpu(0),
+ TVMType::from_str("int32").unwrap(),
+ );
let nd = ndarray.copy_to_ndarray(e);
assert!(nd.is_ok());
assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
let mut shape = vec![4];
let mut data = vec![1f32, 2., 3., 4.];
let ctx = TVMContext::cpu(0);
- let mut nd_float = NDArray::empty(&mut shape, ctx.clone(), TVMType::from("float32"));
+ let mut nd_float = NDArray::empty(
+ &mut shape,
+ ctx.clone(),
+ TVMType::from_str("float32").unwrap(),
+ );
nd_float.copy_from_buffer(&mut data);
- let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from("int32"));
+ let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from_str("int32").unwrap());
nd_float.copy_to_ndarray(empty_int).unwrap();
}
let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
.unwrap()
.into_dyn();
- let nd =
- NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
+ let nd = NDArray::from_rust_ndarray(
+ &a,
+ TVMContext::cpu(0),
+ TVMType::from_str("float32").unwrap(),
+ )
+ .unwrap();
assert_eq!(nd.shape().unwrap(), &mut [2, 2]);
let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
assert!(rnd.all_close(&a, 1e-8f32));
+++ /dev/null
-//! This module implements the required conversions from Rust types to TVM types.
-//!
-//! In TVM frontend only conversions from Rust's 32-bits (POD) numeric types (i32, u32, f32)
-//! and 64-bits pointers are supported.
-
-use std::{
- fmt::{self, Display, Formatter},
- ops::{Deref, DerefMut},
-};
-
-use crate::ts;
-
-use crate::{Function, Module, NDArray, TVMByteArray, TVMContext, TVMDeviceType, TVMTypeCode};
-
-macro_rules! impl_prim_type {
- ($type:ty, $variant:ident) => {
- impl From<$type> for TVMTypeCode {
- fn from(_arg: $type) -> Self {
- TVMTypeCode::$variant
- }
- }
-
- impl<'a> From<&'a $type> for TVMTypeCode {
- fn from(_arg: &$type) -> Self {
- TVMTypeCode::$variant
- }
- }
-
- impl<'a> From<&'a mut $type> for TVMTypeCode {
- fn from(_arg: &mut $type) -> Self {
- TVMTypeCode::$variant
- }
- }
- };
-}
-
-impl_prim_type!(TVMDeviceType, kDLInt);
-impl_prim_type!(TVMContext, kTVMContext);
-impl_prim_type!(TVMType, kTVMType);
-impl_prim_type!(Function, kFuncHandle);
-impl_prim_type!(Module, kModuleHandle);
-impl_prim_type!(NDArray, kArrayHandle);
-impl_prim_type!(TVMByteArray, kBytes);
-
-/// See the [module-level documentation](../ty/index.html) for more details.
-///
-/// Wrapper around underlying TVMType
-#[derive(Debug, Copy, Clone, PartialEq, Eq)]
-pub struct TVMType {
- // inner fields are (code: u8, bits: u8, lanes: u16)
- pub inner: ts::TVMType,
-}
-
-impl TVMType {
- pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
- TVMType {
- inner: ts::TVMType {
- code: type_code,
- bits: bits,
- lanes: lanes,
- },
- }
- }
-}
-
-/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
-/// such as "int32", "float32" or with lane "float32x1".
-impl<'a> From<&'a str> for TVMType {
- fn from(type_str: &'a str) -> Self {
- if type_str == "bool" {
- return TVMType::new(1, 1, 1);
- }
-
- let mut type_lanes = type_str.split("x");
- let typ = type_lanes.next().expect("Missing dtype");
- let lanes = type_lanes
- .next()
- .map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l)))
- .unwrap_or(1);
- let (type_name, bits) = match typ.find(char::is_numeric) {
- Some(idx) => {
- let (name, bits_str) = typ.split_at(idx);
- (
- name,
- u8::from_str_radix(bits_str, 10)
- .expect(&format!("Bad dtype bits: {}", bits_str)),
- )
- }
- None => (typ, 32),
- };
-
- let type_code = match type_name {
- "int" => 0,
- "uint" => 1,
- "float" => 2,
- "handle" => 3,
- _ => unimplemented!(),
- };
-
- TVMType::new(type_code, bits, lanes)
- }
-}
-
-impl Display for TVMType {
- fn fmt(&self, f: &mut Formatter) -> fmt::Result {
- let ts::TVMType { code, bits, lanes } = self.inner;
- if bits == 1 && lanes == 1 {
- return write!(f, "bool");
- }
- let mut tcode_str = match code {
- 0 => "int",
- 1 => "uint",
- 2 => "float",
- 4 => "handle",
- _ => "Unknown",
- }
- .to_string();
-
- tcode_str += &bits.to_string();
- if lanes > 1 {
- tcode_str += &format!("x{}", lanes.to_string());
- }
- f.write_str(&tcode_str)
- }
-}
-
-impl From<TVMType> for ts::DLDataType {
- fn from(dtype: TVMType) -> Self {
- dtype.inner
- }
-}
-
-impl From<ts::DLDataType> for TVMType {
- fn from(dtype: ts::DLDataType) -> Self {
- Self::new(dtype.code, dtype.bits, dtype.lanes)
- }
-}
-
-impl Deref for TVMType {
- type Target = ts::TVMType;
- fn deref(&self) -> &Self::Target {
- &self.inner
- }
-}
-
-impl DerefMut for TVMType {
- fn deref_mut(&mut self) -> &mut Self::Target {
- &mut self.inner
- }
-}
//! and their conversions needed for the types used in frontend crate.
//! `TVMRetValue` is the owned version of `TVMPODValue`.
-use std::{convert::TryFrom, mem, os::raw::c_void};
+use std::{convert::TryFrom, os::raw::c_void};
+
+use failure::Error;
+use tvm_common::{
+ ensure_type,
+ ffi::{self, TVMValue},
+};
use crate::{
- common_errors::*, ts, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext,
- TVMDeviceType, TVMRetValue, TVMType, TVMTypeCode, TVMValue,
+ common_errors::*, context::TVMContext, Function, Module, NDArray, TVMArgValue, TVMByteArray,
+ TVMRetValue,
};
macro_rules! impl_tvm_val_from_handle {
- ($($ty:ty),+) => {
- $(
- impl<'a> From<&'a $ty> for TVMValue {
- fn from(arg: &$ty) -> Self {
- let inner = ts::TVMValue {
+ ($ty:ident, $type_code:expr, $handle:ty) => {
+ impl<'a> From<&'a $ty> for TVMArgValue<'a> {
+ fn from(arg: &$ty) -> Self {
+ TVMArgValue {
+ value: TVMValue {
v_handle: arg.handle as *mut _ as *mut c_void,
- };
- Self::new(inner)
+ },
+ type_code: $type_code as i64,
+ _lifetime: std::marker::PhantomData,
}
}
- )+
- }
-}
-
-impl_tvm_val_from_handle!(Module, Function, NDArray);
-
-impl<'a> From<&'a TVMType> for TVMValue {
- fn from(ty: &TVMType) -> Self {
- let inner = ts::TVMValue { v_type: ty.inner };
- Self::new(inner)
- }
-}
-
-impl<'a> From<&'a TVMContext> for TVMValue {
- fn from(ctx: &TVMContext) -> Self {
- let inner = ts::TVMValue {
- v_ctx: ctx.clone().into(),
- };
- Self::new(inner)
- }
-}
-
-impl<'a> From<&'a TVMDeviceType> for TVMValue {
- fn from(dev: &TVMDeviceType) -> Self {
- let inner = ts::TVMValue {
- v_int64: dev.0 as i64,
- };
- Self::new(inner)
- }
-}
-
-impl<'a> From<&'a TVMByteArray> for TVMValue {
- fn from(barr: &TVMByteArray) -> Self {
- let inner = ts::TVMValue {
- v_handle: &barr.inner as *const ts::TVMByteArray as *mut c_void,
- };
- Self::new(inner)
- }
-}
+ }
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for NDArray {
- type Error = Error;
- fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
- if arg.type_code == TVMTypeCode::kArrayHandle {
- let handle = unsafe { arg.value.inner.v_handle };
- let arr_handle = unsafe { mem::transmute::<*mut c_void, ts::TVMArrayHandle>(handle) };
- Ok(Self::new(arr_handle, true))
- } else {
- bail!(ErrorKind::TryFromTVMArgValueError(
- stringify!(NDArray).to_string(),
- arg.type_code.to_string()
- ))
+ impl<'a> From<&'a mut $ty> for TVMArgValue<'a> {
+ fn from(arg: &mut $ty) -> Self {
+ TVMArgValue {
+ value: TVMValue {
+ v_handle: arg.handle as *mut _ as *mut c_void,
+ },
+ type_code: $type_code as i64,
+ _lifetime: std::marker::PhantomData,
+ }
+ }
}
- }
-}
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for Module {
- type Error = Error;
- fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
- if arg.type_code == TVMTypeCode::kModuleHandle {
- let handle = unsafe { arg.value.inner.v_handle };
- Ok(Self::new(handle, false))
- } else {
- bail!(ErrorKind::TryFromTVMArgValueError(
- stringify!(Module).to_string(),
- arg.type_code.to_string()
- ))
+ impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $ty {
+ type Error = Error;
+ fn try_from(arg: &TVMArgValue<'v>) -> Result<$ty, Self::Error> {
+ ensure_type!(arg, $type_code);
+ Ok($ty::new(unsafe { arg.value.v_handle as $handle }))
+ }
}
- }
-}
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMByteArray {
- type Error = Error;
- fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
- if arg.type_code == TVMTypeCode::kBytes {
- unsafe {
- let barr_ptr =
- mem::transmute::<*mut c_void, *mut ts::TVMByteArray>(arg.value.inner.v_handle);
- Ok(Self::new(*barr_ptr))
+ impl From<$ty> for TVMRetValue {
+ fn from(val: $ty) -> TVMRetValue {
+ TVMRetValue {
+ value: TVMValue {
+ v_handle: val.handle() as *mut c_void,
+ },
+ box_value: box val,
+ type_code: $type_code as i64,
+ }
}
- } else {
- bail!(ErrorKind::TryFromTVMArgValueError(
- stringify!(TVMByteArray).to_string(),
- arg.type_code.to_string()
- ))
}
- }
-}
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMType {
- type Error = Error;
- fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
- if arg.type_code == TVMTypeCode::kTVMType {
- let ty = unsafe { arg.value.inner.v_type };
- Ok(TVMType::from(ty))
- } else {
- bail!(ErrorKind::TryFromTVMArgValueError(
- stringify!(TVMType).to_string(),
- arg.type_code.to_string()
- ))
+ impl TryFrom<TVMRetValue> for $ty {
+ type Error = Error;
+ fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> {
+ ensure_type!(ret, $type_code);
+ Ok($ty::new(unsafe { ret.value.v_handle as $handle }))
+ }
}
- }
+ };
}
-impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMContext {
- type Error = Error;
- fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
- if arg.type_code == TVMTypeCode::kTVMContext {
- let ty = unsafe { arg.value.inner.v_ctx };
- Ok(TVMContext::from(ty))
- } else {
- bail!(ErrorKind::TryFromTVMArgValueError(
- stringify!(TVMContext).to_string(),
- arg.type_code.to_string()
- ))
+impl_tvm_val_from_handle!(
+ Function,
+ ffi::TVMTypeCode_kFuncHandle,
+ ffi::TVMFunctionHandle
+);
+impl_tvm_val_from_handle!(Module, ffi::TVMTypeCode_kModuleHandle, ffi::TVMModuleHandle);
+impl_tvm_val_from_handle!(NDArray, ffi::TVMTypeCode_kArrayHandle, ffi::TVMArrayHandle);
+
+impl<'a> From<&'a TVMByteArray> for TVMValue {
+ fn from(barr: &TVMByteArray) -> Self {
+ TVMValue {
+ v_handle: &barr.inner as *const ffi::TVMByteArray as *mut c_void,
}
}
}
impl From<$type> for TVMRetValue {
fn from(val: $type) -> Self {
TVMRetValue {
- prim_value: 0,
+ value: TVMValue { v_int64: 0 },
box_value: box val,
- type_code: $code,
+ type_code: $code as i64,
}
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = Error;
- fn try_from(ret: TVMRetValue) -> Result<$type> {
+ fn try_from(ret: TVMRetValue) -> Result<$type, Self::Error> {
if let Ok(val) = ret.box_value.downcast::<$type>() {
Ok(*val)
} else {
- bail!(ErrorKind::TryFromTVMRetValueError(
- stringify!($type).to_string(),
- ret.type_code.to_string()
- ))
+ bail!(ValueDowncastError::new($code as i64, ret.type_code as i64))
}
}
}
};
}
-impl_boxed_ret_value!(TVMType, TVMTypeCode::kTVMType);
-impl_boxed_ret_value!(TVMContext, TVMTypeCode::kTVMContext);
-impl_boxed_ret_value!(TVMByteArray, TVMTypeCode::kBytes);
-
-impl TryFrom<TVMRetValue> for Module {
- type Error = Error;
- fn try_from(ret: TVMRetValue) -> Result<Module> {
- if let Ok(handle) = ret.box_value.downcast::<ts::TVMModuleHandle>() {
- Ok(Module::new(*handle, false))
- } else {
- bail!(ErrorKind::TryFromTVMRetValueError(
- stringify!(TVMTypeCode::kModuleHandle).to_string(),
- ret.type_code.to_string()
- ))
- }
- }
-}
-
-impl TryFrom<TVMRetValue> for Function {
- type Error = Error;
- fn try_from(ret: TVMRetValue) -> Result<Function> {
- if let Ok(handle) = ret.box_value.downcast::<ts::TVMFunctionHandle>() {
- Ok(Function::new(*handle, false, false))
- } else {
- bail!(ErrorKind::TryFromTVMRetValueError(
- stringify!(TVMTypeCode::kFuncHandle).to_string(),
- ret.type_code.to_string()
- ))
- }
- }
-}
+impl_boxed_ret_value!(TVMContext, ffi::TVMTypeCode_kTVMContext);
+impl_boxed_ret_value!(TVMByteArray, ffi::TVMTypeCode_kBytes);
-impl TryFrom<TVMRetValue> for NDArray {
+impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMByteArray {
type Error = Error;
- fn try_from(ret: TVMRetValue) -> Result<NDArray> {
- if let Ok(handle) = ret.box_value.downcast::<ts::TVMArrayHandle>() {
- Ok(NDArray::new(*handle, false))
- } else {
- bail!(ErrorKind::TryFromTVMRetValueError(
- stringify!(TVMTypeCode::kArrayHandle).to_string(),
- ret.type_code.to_string()
- ))
- }
+ fn try_from(arg: &TVMArgValue<'v>) -> Result<Self, Self::Error> {
+ ensure_type!(arg, ffi::TVMTypeCode_kBytes);
+ Ok(TVMByteArray::new(unsafe {
+ *(arg.value.v_handle as *mut ffi::TVMByteArray)
+ }))
}
}
#[cfg(test)]
mod tests {
use super::*;
- use std::convert::TryInto;
+ use std::{convert::TryInto, str::FromStr};
+ use tvm_common::ffi::TVMType;
#[test]
fn bytearray() {
#[test]
fn ty() {
- let t = TVMType::from("int32");
+ let t = TVMType::from_str("int32").unwrap();
let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap();
assert_eq!(tvm, t);
}
extern crate ndarray as rust_ndarray;
extern crate tvm_frontend as tvm;
+use std::str::FromStr;
+
use tvm::*;
fn main() {
} else {
(TVMContext::gpu(0), "gpu")
};
- let dtype = TVMType::from("float32");
+ let dtype = TVMType::from_str("float32").unwrap();
let mut arr = NDArray::empty(shape, ctx, dtype);
arr.copy_from_buffer(data.as_mut_slice());
let mut ret = NDArray::empty(shape, ctx, dtype);
function::Builder::from(&mut fadd)
.arg(&arr)
.arg(&arr)
- .set_output(&mut ret)
- .unwrap()
+ .arg(&mut ret)
.invoke()
.unwrap();
-#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
extern crate ndarray as rust_ndarray;
extern crate tvm_frontend as tvm;
use rust_ndarray::ArrayD;
-use std::convert::{TryFrom, TryInto};
+use std::{
+ convert::{TryFrom, TryInto},
+ str::FromStr,
+};
-use tvm::*;
+use tvm::{errors::Error, *};
fn main() {
register_global_func! {
- fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = 0f32;
let shape = &mut [2];
for arg in args.iter() {
- let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+ let e = NDArray::empty(
+ shape, TVMContext::cpu(0),
+ TVMType::from_str("float32").unwrap()
+ );
let arg: NDArray = arg.try_into()?;
let arr = arg.copy_to_ndarray(e)?;
let rnd: ArrayD<f32> = ArrayD::try_from(&arr)?;
let shape = &mut [2];
let mut data = vec![3f32, 4.0];
- let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+ let mut arr = NDArray::empty(
+ shape,
+ TVMContext::cpu(0),
+ TVMType::from_str("float32").unwrap(),
+ );
arr.copy_from_buffer(data.as_mut_slice());
let mut registered = function::Builder::default();
let ret: f32 = registered
- .get_function("sum", true)
+ .get_function("sum")
.arg(&arr)
.arg(&arr)
.invoke()
-#![feature(extern_crate_item_prelude, panic_info_message)]
+#![feature(panic_info_message)]
#![allow(unused_imports)]
use std::panic;
#[macro_use]
extern crate tvm_frontend as tvm;
-use tvm::*;
+use tvm::{errors::Error, *};
fn main() {
register_global_func! {
- fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue> {
- Err(ErrorKind::TypeMismatch(
- format!("{}", "i64".to_string()),
- format!("{}", "f64".to_string()),
- ).into())
+ fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
+ Err(errors::TypeMismatchError{
+ expected: "i64".to_string(),
+ actual: "f64".to_string(),
+ }.into())
}
}
let mut registered = function::Builder::default();
- registered.get_function("error", true);
+ registered.get_function("error");
assert!(registered.func.is_some());
registered.args(&[10, 20]);
-#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
#[macro_use]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
-use tvm::*;
+use tvm::{errors::Error, *};
fn main() {
register_global_func! {
- fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = 0.0;
- for arg in args.iter() {
+ for arg in args.into_iter() {
let val: f64 = arg.try_into()?;
ret += val;
}
- Ok(TVMRetValue::from(&ret))
+ Ok(TVMRetValue::from(ret))
}
}
let mut registered = function::Builder::default();
- registered.get_function("sum", true);
+ registered.get_function("sum");
assert!(registered.func.is_some());
let ret: f64 = registered
.args(&[10.0f64, 20.0, 30.0])
-#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
-use tvm::*;
+use tvm::{errors::Error, *};
fn main() {
- fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = 0i64;
for arg in args.iter() {
let val: i64 = arg.try_into()?;
ret += val;
}
- Ok(TVMRetValue::from(&ret))
+ Ok(TVMRetValue::from(ret))
}
tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
let mut registered = function::Builder::default();
- registered.get_function("mysum", true);
+ registered.get_function("mysum");
assert!(registered.func.is_some());
let ret: i64 = registered
.args(&[10, 20, 30])
-#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
#[macro_use]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
-use tvm::*;
+use tvm::{errors::Error, *};
// FIXME
fn main() {
register_global_func! {
- fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
let mut ret = "".to_string();
for arg in args.iter() {
- let val: String = arg.try_into()?;
- ret += val.as_str();
+ let val: &str = arg.try_into()?;
+ ret += val;
}
Ok(TVMRetValue::from(ret))
}
}
+ let a = std::ffi::CString::new("a").unwrap();
+ let b = std::ffi::CString::new("b").unwrap();
+ let c = std::ffi::CString::new("c").unwrap();
let mut registered = function::Builder::default();
- registered.get_function("concate_str", true);
+ registered.get_function("concate_str");
assert!(registered.func.is_some());
- let a = "a".to_string();
- let b = "b".to_string();
- let c = "c".to_string();
let ret: String = registered
- .args(&[a, b, c])
+ .arg(&a)
+ .arg(&b)
+ .arg(&c)
.invoke()
.unwrap()
.try_into()
+++ /dev/null
-Cargo.lock
-target/
-**/*.rs.bk
[dependencies]
bounded-spsc-queue = "0.4.0"
-error-chain = { version = "0.12.0", default-features = false }
+failure = "0.1.5"
itertools = "0.7.8"
lazy_static = "1.1.0"
-ndarray = "0.11.2"
+ndarray="0.12.1"
nom = {version = "4.0.0", default-features = false }
serde = "1.0.59"
serde_derive = "1.0.79"
serde_json = "1.0.17"
-tvm-common = { version = "0.1.0", path = "../common/", features = ["runtime"] }
+tvm-common = { version = "0.1.0", path = "../common/" }
[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
#[cfg(target_env = "sgx")]
-use alloc::alloc::{self, Layout};
+use alloc::alloc::{self, Layout, LayoutErr};
#[cfg(not(target_env = "sgx"))]
-use std::alloc::{self, Layout};
-
-use crate::errors::*;
+use std::alloc::{self, Layout, LayoutErr};
const DEFAULT_ALIGN_BYTES: usize = 4;
impl Allocation {
/// Allocates a chunk of memory of `size` bytes with optional alignment.
- pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
+ pub fn new(size: usize, align: Option<usize>) -> Result<Self, LayoutErr> {
let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
let layout = Layout::from_size_align(size, alignment)?;
let ptr = unsafe { alloc::alloc(layout.clone()) };
-use std::{
- any::TypeId,
- convert::TryFrom,
- mem,
- ops::{Deref, DerefMut},
- os::raw::{c_int, c_void},
- ptr, slice,
-};
+use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
+use failure::Error;
use ndarray;
-
-use crate::{
- allocator::Allocation,
- errors::*,
- ffi::runtime::{
+use tvm_common::{
+ array::{DataType, TVMContext},
+ ffi::{
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt,
- DLDataTypeCode_kDLUInt, DLDeviceType_kDLCPU, DLTensor as _DLTensor,
+ DLDataTypeCode_kDLUInt, DLTensor,
},
};
+use crate::allocator::Allocation;
+
/// A `Storage` is a container which holds `Tensor` data.
#[derive(PartialEq)]
pub enum Storage<'a> {
}
impl<'a> Storage<'a> {
- pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
+ pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
Ok(Storage::Owned(Allocation::new(size, align)?))
}
byte_offset: 0,
}
}
+
+ pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor {
+ assert!(!flatten || self.is_contiguous());
+ DLTensor {
+ data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void,
+ ctx: DLContext::from(&self.ctx),
+ ndim: if flatten { 1 } else { self.shape.len() } as i32,
+ dtype: DLDataType::from(&self.dtype),
+ shape: if flatten {
+ &self.size as *const _ as *mut i64
+ } else {
+ self.shape.as_ptr()
+ } as *mut i64,
+ strides: if flatten || self.is_contiguous() {
+ ptr::null_mut()
+ } else {
+ self.strides.as_ref().unwrap().as_ptr()
+ } as *mut i64,
+ byte_offset: 0,
+ }
+ }
}
/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
($type:ty, $dtype:expr) => {
impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
type Error = Error;
- fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
+ fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
ensure!(
tensor.dtype == $dtype,
"Cannot convert Tensor with dtype {:?} to ndarray",
};
}
-impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
-impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
-impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
-impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
-
-pub struct DLTensor {
- pub(crate) inner: _DLTensor,
-}
-
-impl Deref for DLTensor {
- type Target = _DLTensor;
- fn deref(&self) -> &Self::Target {
- &self.inner
- }
-}
-
-impl DerefMut for DLTensor {
- fn deref_mut(&mut self) -> &mut Self::Target {
- &mut self.inner
- }
-}
-
-impl DLTensor {
- pub(crate) fn new(raw: _DLTensor) -> Self {
- Self { inner: raw }
- }
-
- pub(crate) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
- assert!(!flatten || tensor.is_contiguous());
- Self {
- inner: _DLTensor {
- data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
- ctx: DLContext::from(&tensor.ctx),
- ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
- dtype: DLDataType::from(&tensor.dtype),
- shape: if flatten {
- &tensor.size as *const _ as *mut i64
- } else {
- tensor.shape.as_ptr()
- } as *mut i64,
- strides: if flatten || tensor.is_contiguous() {
- ptr::null_mut()
- } else {
- tensor.strides.as_ref().unwrap().as_ptr()
- } as *mut i64,
- byte_offset: 0,
- },
- }
- }
-}
-
-impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
- fn from(tensor: &'a Tensor<'t>) -> Self {
- DLTensor::from_tensor(tensor, false /* flatten */)
- }
-}
-
-impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
- fn from(tensor: &'a mut Tensor<'t>) -> Self {
- DLTensor::from_tensor(tensor, false /* flatten */)
- }
-}
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-pub struct DataType {
- pub(crate) code: usize,
- pub(crate) bits: usize,
- pub(crate) lanes: usize,
-}
-
-impl DataType {
- /// Returns the number of bytes occupied by an element of this `DataType`.
- pub fn itemsize(&self) -> usize {
- (self.bits * self.lanes) >> 3
- }
-
- /// Returns whether this `DataType` represents primitive type `T`.
- pub fn is_type<T: 'static>(&self) -> bool {
- if self.lanes != 1 {
- return false;
- }
- let typ = TypeId::of::<T>();
- (typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
- || (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
- || (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
- || (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
- || (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
- || (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
- }
-}
-
-impl<'a> From<&'a DataType> for DLDataType {
- fn from(dtype: &'a DataType) -> Self {
- Self {
- code: dtype.code as u8,
- bits: dtype.bits as u8,
- lanes: dtype.lanes as u16,
- }
- }
-}
-
-impl From<DLDataType> for DataType {
- fn from(dtype: DLDataType) -> Self {
- Self {
- code: dtype.code as usize,
- bits: dtype.bits as usize,
- lanes: dtype.lanes as usize,
- }
- }
-}
-
macro_rules! make_dtype_const {
($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
- const $name: DataType = DataType {
+ pub const $name: DataType = DataType {
code: $code as usize,
bits: $bits,
lanes: $lanes,
// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
+impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
+impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
+impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
+impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
-#[derive(Debug, Clone, Copy, PartialEq)]
-pub struct TVMContext {
- pub(crate) device_type: usize,
- pub(crate) device_id: usize,
-}
-
-impl<'a> From<&'a TVMContext> for DLContext {
- fn from(ctx: &'a TVMContext) -> Self {
- Self {
- device_type: ctx.device_type as u32,
- device_id: ctx.device_id as i32,
- }
+impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
+ fn from(tensor: &'a Tensor<'t>) -> Self {
+ Tensor::as_dltensor(tensor, false /* flatten */)
}
}
-impl Default for TVMContext {
- fn default() -> Self {
- Self {
- device_type: DLDeviceType_kDLCPU as usize,
- device_id: 0,
- }
+impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
+ fn from(tensor: &'a mut Tensor<'t>) -> Self {
+ Tensor::as_dltensor(tensor, false /* flatten */)
}
}
};
}
-/// `From` conversions to `DLTensor` for `ndarray::Array`.
-/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
-macro_rules! impl_dltensor_from_ndarray {
- ($type:ty, $typecode:expr) => {
- impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
- fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
- DLTensor {
- inner: _DLTensor {
- data: arr.as_mut_ptr() as *mut c_void,
- ctx: DLContext {
- device_type: DLDeviceType_kDLCPU,
- device_id: 0,
- },
- ndim: arr.ndim() as c_int,
- dtype: DLDataType {
- code: $typecode as u8,
- bits: 8 * mem::size_of::<$type>() as u8,
- lanes: 1,
- },
- shape: arr.shape().as_ptr() as *const i64 as *mut i64,
- strides: arr.strides().as_ptr() as *const isize as *mut i64,
- byte_offset: 0,
- },
- }
- }
- }
- };
-}
-
-impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
-impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
-impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
-impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
-impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
-impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
-
impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
-#[cfg(target_env = "sgx")]
-use alloc::alloc;
-#[cfg(not(target_env = "sgx"))]
-use std::alloc;
-use std::num;
-
-use crate::common::errors as common_errors;
-use ndarray;
-use serde_json;
-
-error_chain! {
- errors {
- GraphFormatError(msg: String) {
- description("unable to load graph")
- display("could not load graph json: {}", msg)
- }
-
- LoadGraphParamsError(msg: String) {
- description("unable to load graph params")
- display("could not load graph params: {}", msg)
- }
- }
- foreign_links {
- Alloc(alloc::AllocErr);
- GraphDeserialize(serde_json::Error);
- ParseInt(num::ParseIntError);
- ShapeError(ndarray::ShapeError);
- CommonError(common_errors::Error);
- }
+#[derive(Debug, Fail)]
+pub enum GraphFormatError {
+ #[fail(display = "Could not parse graph json")]
+ Parse(#[fail(cause)] failure::Error),
+ #[fail(display = "Could not parse graph params")]
+ Params,
+ #[fail(display = "{} is missing attr: {}", 0, 1)]
+ MissingAttr(String, String),
+ #[fail(display = "Missing field: {}", 0)]
+ MissingField(&'static str),
+ #[fail(display = "Invalid DLType: {}", 0)]
+ InvalidDLType(String),
}
-impl From<alloc::LayoutErr> for Error {
- fn from(_err: alloc::LayoutErr) -> Error {
- Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
- }
+#[derive(Debug, Fail)]
+#[fail(display = "SGX error: 0x{:x}", code)]
+pub struct SgxError {
+ pub code: u32,
}
use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
+use failure::Error;
use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
use serde;
use serde_json;
-
-use super::{DLTensor, DataType, Module, Storage, TVMContext, Tensor};
-use crate::{
- common::value::TVMArgValue,
- errors::{Error, ErrorKind, Result},
- ffi::runtime::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt},
+use tvm_common::{
+ array::{DataType, TVMContext},
+ ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor},
+ TVMArgValue,
};
+use crate::{errors::GraphFormatError, Module, Storage, Tensor};
+
// @see `kTVMNDArrayMagic` in `ndarray.h`
const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
// @see `kTVMNDArrayListMagic` in `graph_runtime.h`
}
impl Graph {
- fn entry_index(&self, entry: &Entry) -> Result<usize> {
+ fn entry_index(&self, entry: &Entry) -> Result<usize, GraphFormatError> {
self.node_row_ptr
.as_ref()
.map(|nrp| nrp[entry.id] + entry.index)
- .ok_or("Missing node_row_ptr.".into())
+ .ok_or_else(|| GraphFormatError::MissingField("node_row_ptr"))
}
/// Attempt to deserialize a JSON attribute to a type `T`.
- fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
+ fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T, GraphFormatError> {
Ok(serde_json::from_value::<T>(
self.attrs
.as_ref()
- .ok_or(ErrorKind::GraphFormatError(
- "Missing graph attrs".to_string(),
- ))?
+ .ok_or(GraphFormatError::MissingField("attrs"))?
.get(attr)
- .ok_or(ErrorKind::GraphFormatError(format!(
- "Missing {} attr",
- attr
- )))?
+ .ok_or_else(|| {
+ GraphFormatError::MissingAttr("graph".to_string(), attr.to_string())
+ })?
.to_owned(),
- )?)
+ )
+ .map_err(|err| GraphFormatError::Parse(err.into()))?)
}
}
flatten_data: bool,
}
+macro_rules! get_node_attr {
+ ($node:expr, $attrs:ident, $attr:literal) => {
+ $attrs
+ .get($attr)
+ .ok_or_else(|| GraphFormatError::MissingAttr($node.to_owned(), $attr.to_owned()))
+ };
+}
+
impl Node {
- fn parse_attrs(&self) -> Result<NodeAttrs> {
+ fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
let attrs = self
.attrs
.as_ref()
- .ok_or(format!("Missing node.attrs for `{}`", self.name))?;
- let func_name = attrs
- .get("func_name")
- .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
- .to_string();
- let num_outputs = attrs
- .get("num_outputs")
- .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
- .parse::<usize>()?;
- let flatten_data = attrs
- .get("flatten_data")
- .ok_or(format!(
- "Node `{}` is missing attrs.flatten_data",
- self.name
- ))?
- .parse::<u8>()?
- == 1;
+ .ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?;
Ok(NodeAttrs {
- func_name,
- num_outputs,
- flatten_data,
+ func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(),
+ num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::<usize>()?,
+ flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::<u8>()? == 1,
})
}
}
impl<'a> TryFrom<&'a String> for Graph {
type Error = Error;
- fn try_from(graph_json: &String) -> Result<Self> {
+ fn try_from(graph_json: &String) -> Result<Self, self::Error> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
}
impl<'a> TryFrom<&'a str> for Graph {
type Error = Error;
- fn try_from(graph_json: &'a str) -> Result<Self> {
+ fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
}
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
impl<'m, 't> GraphExecutor<'m, 't> {
- pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
+ pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error> {
let tensors = Self::setup_storages(&graph)?;
Ok(GraphExecutor {
op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
}
/// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
- fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>> {
+ fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error> {
let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
let dtypes = graph
if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
Ok(dtype)
} else {
- Err(ErrorKind::GraphFormatError(
- format!("Invalid dltype: {}", dltype).to_string(),
- )
- .into())
+ Err(GraphFormatError::InvalidDLType(dltype.to_string()))
}
})
- .collect::<Result<Vec<DataType>>>()?;
+ .collect::<Result<Vec<DataType>, GraphFormatError>>()?;
- let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
+ let align = dtypes.iter().map(|dtype| dtype.bits() as usize).max();
let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
for (i, &storage_id) in storage_ids.iter().enumerate() {
- let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
+ let dtype_size = dtypes[i].bits() * dtypes[i].lanes() >> 3;
let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
}
let mut storages: Vec<Storage> = storage_num_bytes
.into_iter()
.map(|nbytes| Storage::new(nbytes, align))
- .collect::<Result<Vec<Storage>>>()?;
+ .collect::<Result<Vec<Storage>, Error>>()?;
let tensors = izip!(storage_ids, shapes, dtypes)
.map(|(storage_id, shape, dtype)| {
graph: &Graph,
lib: &'m M,
tensors: &Vec<Tensor<'t>>,
- ) -> Result<Vec<Box<Fn() + 'm>>> {
+ ) -> Result<Vec<Box<Fn() + 'm>>, Error> {
ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
continue;
}
- let func = lib
- .get_function(&attrs.func_name)
- .ok_or(format!("Missing function {}", attrs.func_name))?;
+ let func = lib.get_function(&attrs.func_name).ok_or(format_err!(
+ "Library is missing function {}",
+ attrs.func_name
+ ))?;
let arg_indices = node
.inputs
.iter()
.map(|idx| {
let tensor = &tensors[idx?];
Ok(if attrs.flatten_data {
- DLTensor::from_tensor(tensor, true /* flatten */)
+ Tensor::as_dltensor(tensor, true /* flatten */)
} else {
DLTensor::from(tensor)
})
})
- .collect::<Result<Vec<DLTensor>>>()
+ .collect::<Result<Vec<DLTensor>, Error>>()
.unwrap();
let op: Box<Fn()> = box move || {
let args = dl_tensors
.iter()
.map(|t| t.into())
.collect::<Vec<TVMArgValue>>();
- func(args.as_slice());
+ func(args.as_slice()).unwrap();
};
op_execs.push(op);
}
}
}
-/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
+// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
named!(
tvm_str_to_type<CompleteStr, DataType>,
do_parse!(
)
);
-/// Converts a bytes to String.
+// Converts a bytes to String.
named!(
name<String>,
map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
))
);
-/// Parses a TVMContext
+// Parses a TVMContext
named!(
tvm_ctx<&[u8], TVMContext>,
do_parse!(
)
);
-/// Parses a DataType
+// Parses a DataType
named!(
data_type<&[u8], DataType>,
do_parse!(
)
);
-/// Parses a Tensor from a TVM array file.
+// Parses a Tensor from a TVM array file.
named!(
tensor<Tensor>,
do_parse!(
)
);
-/// Parses a graph params dict from a params binary file.
+// Parses a graph params dict from a params binary file.
named!(
parse_param_dict<HashMap<String, Tensor>>,
do_parse!(
);
/// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
-pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> {
+pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>, GraphFormatError> {
if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
- if remaining_bytes.len() > 0 {
- bail!(ErrorKind::LoadGraphParamsError("extra input".to_string()))
- } else {
+ if remaining_bytes.len() == 0 {
Ok(param_dict)
+ } else {
+ Err(GraphFormatError::Params)
}
} else {
- bail!(ErrorKind::LoadGraphParamsError(
- "invalid parameters file".to_string()
- ))
+ Err(GraphFormatError::Params)
}
}
allocator_api,
box_syntax,
fn_traits,
- try_from,
unboxed_closures,
vec_remove_item
)]
#[cfg(target_env = "sgx")]
extern crate core;
#[macro_use]
-extern crate error_chain;
+extern crate failure;
#[macro_use]
extern crate itertools;
#[macro_use]
#[macro_use]
extern crate serde_derive;
extern crate serde_json;
-extern crate tvm_common as common;
+extern crate tvm_common;
mod allocator;
mod array;
pub mod errors;
-mod module;
-#[macro_use]
-mod packed_func;
mod graph;
+mod module;
#[cfg(target_env = "sgx")]
#[macro_use]
pub mod sgx;
mod threading;
mod workspace;
-pub use crate::common::{errors::*, ffi, TVMArgValue, TVMRetValue};
-
-pub use self::{
- array::*, errors::*, graph::*, module::*, packed_func::*, threading::*, workspace::*,
+pub use tvm_common::{
+ call_packed,
+ errors::*,
+ ffi::{self, DLTensor},
+ packed_func::{self, *},
+ TVMArgValue, TVMRetValue,
};
-#[cfg(target_env = "sgx")]
-use self::sgx::ocall_packed_func;
+pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*};
+
+lazy_static! {
+ static ref LAST_ERROR: std::sync::RwLock<Option<&'static std::ffi::CStr>> =
+ std::sync::RwLock::new(None);
+}
#[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
- #[cfg(not(target_env = "sgx"))]
- unsafe {
- panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
- }
+ *LAST_ERROR.write().unwrap() = Some(unsafe { std::ffi::CStr::from_ptr(cmsg) });
#[cfg(target_env = "sgx")]
ocall_packed!("__sgx_set_last_error__", cmsg);
}
+
+#[no_mangle]
+pub extern "C" fn TVMGetLastError() -> *const std::os::raw::c_char {
+ match *LAST_ERROR.read().unwrap() {
+ Some(err) => err.as_ptr(),
+ None => std::ptr::null(),
+ }
+}
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
};
-use crate::{
- ffi::runtime::BackendPackedCFunc,
- packed_func::{wrap_backend_packed_func, PackedFunc},
+use tvm_common::{
+ ffi::BackendPackedCFunc,
+ packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
};
pub trait Module {
- fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
+ fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
}
pub struct SystemLibModule;
lazy_static! {
- static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
+ static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, &'static (dyn PackedFunc)>> =
Mutex::new(HashMap::new());
}
impl Module for SystemLibModule {
- fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
+ fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
SYSTEM_LIB_FUNCTIONS
.lock()
.unwrap()
.get(name.as_ref())
- .map(|func| wrap_backend_packed_func(func.to_owned()))
+ .map(|f| *f)
}
}
}
}
+// @see `WrapPackedFunc` in `llvm_module.cc`.
+pub(super) fn wrap_backend_packed_func(
+ func_name: String,
+ func: BackendPackedCFunc,
+) -> Box<dyn PackedFunc> {
+ box move |args: &[TVMArgValue]| {
+ let exit_code = func(
+ args.iter()
+ .map(|ref arg| arg.value)
+ .collect::<Vec<TVMValue>>()
+ .as_ptr(),
+ args.iter()
+ .map(|ref arg| arg.type_code as i32)
+ .collect::<Vec<i32>>()
+ .as_ptr() as *const i32,
+ args.len() as i32,
+ );
+ if exit_code == 0 {
+ Ok(TVMRetValue::default())
+ } else {
+ Err(tvm_common::errors::FuncCallError::get_with_context(
+ func_name.clone(),
+ ))
+ }
+ }
+}
+
#[no_mangle]
pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
cname: *const c_char,
func: BackendPackedCFunc,
) -> i32 {
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
- SYSTEM_LIB_FUNCTIONS
- .lock()
- .unwrap()
- .insert(name.to_string(), func);
+ SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
+ name.to_string(),
+ &*Box::leak(wrap_backend_packed_func(name.to_string(), func)),
+ );
return 0;
}
+++ /dev/null
-use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void};
-
-use super::Tensor;
-use crate::ffi::runtime::{
- BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle,
- TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue,
-};
-
-use super::DLTensor;
-use crate::{
- common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue},
- errors::*,
-};
-
-pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
-
-/// Calls a packed function and returns a `TVMRetValue`.
-///
-/// # Example
-///
-/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
-#[macro_export]
-macro_rules! call_packed {
- ($fn:expr, $($args:expr),+) => {
- $fn(&[$($args.into(),)+])
- };
- ($fn:expr) => {
- $fn(&Vec::new())
- };
-}
-
-impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
- fn from(arr: &'a DLTensor) -> Self {
- let raw = _TVMValue {
- v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
- };
- TVMArgValue {
- value: TVMValue::new(raw),
- type_code: TVMTypeCode::kArrayHandle,
- lifetime: PhantomData,
- }
- }
-}
-
-impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
- fn from(arr: &'a mut DLTensor) -> Self {
- let raw = _TVMValue {
- v_handle: arr as *mut _ as *mut c_void,
- };
- TVMArgValue {
- value: TVMValue::new(raw),
- type_code: TVMTypeCode::kArrayHandle,
- lifetime: PhantomData,
- }
- }
-}
-
-impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
- type Error = Error;
- fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
- ensure!(
- val.type_code == TVMTypeCode::kArrayHandle
- || val.type_code == TVMTypeCode::kNDArrayContainer,
- "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
- TVMTypeCode::kArrayHandle,
- TVMTypeCode::kNDArrayContainer,
- val.type_code,
- );
-
- let dlt = unsafe { *(val.value.v_handle as *mut _DLTensor as *const _DLTensor) };
- Ok(DLTensor::new(dlt).into())
- }
-}
-
-impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
- fn from(val: &'t Tensor<'a>) -> Self {
- TVMRetValue {
- prim_value: 0,
- box_value: box DLTensor::from(val),
- type_code: TVMTypeCode::kNDArrayContainer,
- }
- }
-}
-
-impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
- type Error = Error;
- fn try_from(ret: TVMRetValue) -> Result<Self> {
- ensure!(
- ret.type_code == TVMTypeCode::kArrayHandle
- || ret.type_code == TVMTypeCode::kNDArrayContainer,
- "Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
- TVMTypeCode_kArrayHandle,
- TVMTypeCode_kNDArrayContainer,
- ret.type_code,
- );
-
- let dlt = unsafe { *(ret.prim_value as *mut _DLTensor as *const _DLTensor) };
- Ok(DLTensor::new(dlt).into())
- }
-}
-
-// @see `WrapPackedFunc` in `llvm_module.cc`.
-pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
- box move |args: &[TVMArgValue]| {
- func(
- args.iter()
- .map(|ref arg| arg.value.inner)
- .collect::<Vec<_TVMValue>>()
- .as_ptr(),
- args.iter()
- .map(|ref arg| arg.type_code as i32)
- .collect::<Vec<i32>>()
- .as_ptr() as *const i32,
- args.len() as i32,
- );
- TVMRetValue::default()
- }
-}
os::raw::{c_char, c_int},
};
-use errors::Result;
-use ffi::runtime::TVMValue;
-use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
-
-pub use runtime::threading::tvm_run_worker as run_worker;
+pub use crate::threading::tvm_run_worker as run_worker;
+use crate::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
+use errors::SgxError;
+use ffi::TVMValue;
#[macro_export]
macro_rules! tvm_ocall {
($func: expr) => {
match $func {
0 => Ok(()),
- err => Err(format!("SGX error: {}", err)),
+ code => Err(SgxError { code }),
}
};
}
) -> SgxStatus;
}
-pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
+pub fn ocall_packed_func<S: AsRef<str>>(
+ fn_name: S,
+ args: &[TVMArgValue],
+) -> Result<TVMRetValue, SgxError> {
let mut ret_val = TVMValue { v_int64: 0 };
let ret_type_code = 0i64;
unsafe {
#[macro_export]
macro_rules! ocall_packed {
($fn_name:expr, $($args:expr),+) => {
- ocall_packed_func($fn_name, &[$($args.into(),)+])
+ $crate::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+])
.expect(concat!("Error calling `", $fn_name, "`"))
};
($fn_name:expr) => {
- ocall_packed_func($fn_name, &Vec::new())
+ $crate::sgx::ocall_packed_func($fn_name, &Vec::new())
.expect(concat!("Error calling `", $fn_name, "`"))
}
}
use std::{
os::raw::{c_int, c_void},
sync::{
- atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT},
+ atomic::{AtomicUsize, Ordering},
Arc, Barrier,
},
};
use std::{collections::VecDeque, ptr, sync::Mutex};
use bounded_spsc_queue::{self, Producer};
-
-use crate::{errors::*, ffi::runtime::TVMParallelGroupEnv};
+use tvm_common::ffi::TVMParallelGroupEnv;
#[cfg(target_env = "sgx")]
-use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue};
+use super::{TVMArgValue, TVMRetValue};
type FTVMParallelLambda =
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
}
/// Waits for all tasks in this `Job` to be completed.
- fn wait(&self) -> Result<()> {
+ fn wait(&self) {
while self.pending.load(Ordering::Acquire) > 0 {
#[cfg(not(target_env = "sgx"))]
thread::yield_now();
}
- Ok(())
}
}
}
tasks.pop().unwrap()();
- job.wait().unwrap();
+ job.wait();
}
fn run_worker(queue: Consumer<Task>) {
cb: cb,
cdata: cdata,
req_num_tasks: num_task,
- pending: Arc::new(ATOMIC_USIZE_INIT),
+ pending: Arc::new(AtomicUsize::new(0)),
});
});
}
cb: poison_pill,
cdata: ptr::null(),
req_num_tasks: 0,
- pending: Arc::new(ATOMIC_USIZE_INIT),
+ pending: Arc::new(AtomicUsize::new(0)),
});
});
ocall_packed!("__sgx_thread_group_join__", 0);
#[test]
fn test_parallel_launch() {
TVMBackendParallelLaunch(flambda, ptr::null(), 6);
- let counter = ATOMIC_USIZE_INIT;
- let task_ids_sum = ATOMIC_USIZE_INIT;
+ let counter = AtomicUsize::new(0);
+ let task_ids_sum = AtomicUsize::new(0);
let cdata = (counter, task_ids_sum);
let num_tasks = 3;
TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
ptr,
};
-use super::allocator::Allocation;
-use crate::errors::*;
+use failure::Error;
+
+use crate::allocator::Allocation;
const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
}
}
- fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
+ fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> {
self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
self.in_use.push(self.workspaces.len() - 1);
Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
}
- fn alloc(&mut self, size: usize) -> Result<*mut u8> {
+ fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
if self.free.len() == 0 {
return self.alloc_new(size);
}
}
}
- fn free(&mut self, ptr: *mut u8) -> Result<()> {
+ fn free(&mut self, ptr: *mut u8) -> Result<(), Error> {
let mut ws_idx = None;
for i in 0..self.in_use.len() {
let idx = self.in_use[i];
}
Ok(self
.free
- .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?))
+ .push(ws_idx.ok_or(format_err!("Tried to free nonexistent workspace."))?))
}
}
authors = ["TVM Contributors"]
[dependencies]
-ndarray = "0.11.2"
+ndarray="0.12.1"
serde = "1.0.59"
serde_json = "1.0.17"
tvm-runtime = { path = "../../" }
authors = ["TVM Contributors"]
[dependencies]
-ndarray = "0.11.2"
+ndarray="0.12.1"
tvm-runtime = { path = "../../" }
[build-dependencies]
let mut a_dl: DLTensor = (&mut a).into();
let mut b_dl: DLTensor = (&mut b).into();
let mut c_dl: DLTensor = (&mut c).into();
- call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
+ call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
assert!(c.all_close(&e, 1e-8f32));
}
# test common
cd $RUST_DIR/common
-cargo build --features runtime
-cargo test --features runtime --tests
+cargo build
+cargo test --tests
-cargo build --features frontend
-cargo test --features frontend --tests
+cargo build --features bindings
+cargo test --features bindings --tests
# test runtime
cd $RUST_DIR/runtime