[Rust] Some rust cleanups (#6116)
authorJason Knight <binarybana@gmail.com>
Thu, 23 Jul 2020 21:04:30 +0000 (14:04 -0700)
committerGitHub <noreply@github.com>
Thu, 23 Jul 2020 21:04:30 +0000 (14:04 -0700)
* Some rust cleanups

* Turn off default features for bindgen
* Upgrade some deps for smaller total dep tree
* Switch (/complete switch) to thiserror
* Remove unnecessary transmutes

* Fix null pointer assert

* Update wasm32 test

rust/tvm-graph-rt/Cargo.toml
rust/tvm-graph-rt/src/array.rs
rust/tvm-graph-rt/src/errors.rs
rust/tvm-graph-rt/src/graph.rs
rust/tvm-graph-rt/src/module/dso.rs
rust/tvm-graph-rt/src/threading.rs
rust/tvm-graph-rt/src/workspace.rs
rust/tvm-graph-rt/tests/test_wasm32/src/main.rs
rust/tvm-macros/Cargo.toml
rust/tvm-rt/src/object/object_ptr.rs
rust/tvm-sys/Cargo.toml

index 0cf2ac1..d8dfcdb 100644 (file)
@@ -28,8 +28,9 @@ authors = ["TVM Contributors"]
 edition = "2018"
 
 [dependencies]
-crossbeam = "0.7.3"
-failure = "0.1"
+crossbeam-channel = "0.4"
+thiserror = "1"
+
 itertools = "0.8"
 lazy_static = "1.4"
 ndarray="0.12"
index 1ed0f3c..b911aa8 100644 (file)
 
 use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
 
-use failure::{ensure, Error};
 use ndarray;
 use tvm_sys::{ffi::DLTensor, Context, DataType};
 
 use crate::allocator::Allocation;
+use crate::errors::ArrayError;
+use std::alloc::LayoutErr;
 
 /// A `Storage` is a container which holds `Tensor` data.
 #[derive(PartialEq)]
@@ -36,7 +37,7 @@ pub enum Storage<'a> {
 }
 
 impl<'a> Storage<'a> {
-    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
+    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, LayoutErr> {
         Ok(Storage::Owned(Allocation::new(size, align)?))
     }
 
@@ -297,13 +298,11 @@ impl<'a> Tensor<'a> {
 macro_rules! impl_ndarray_try_from_tensor {
     ($type:ty, $dtype:expr) => {
         impl<'t> TryFrom<Tensor<'t>> for ndarray::ArrayD<$type> {
-            type Error = Error;
-            fn try_from(tensor: Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
-                ensure!(
-                    tensor.dtype == $dtype,
-                    "Cannot convert Tensor with dtype {:?} to ndarray",
-                    tensor.dtype
-                );
+            type Error = ArrayError;
+            fn try_from(tensor: Tensor) -> Result<ndarray::ArrayD<$type>, Self::Error> {
+                if tensor.dtype != $dtype {
+                    return Err(ArrayError::IncompatibleDataType(tensor.dtype));
+                }
                 Ok(ndarray::Array::from_shape_vec(
                     tensor
                         .shape
@@ -311,7 +310,8 @@ macro_rules! impl_ndarray_try_from_tensor {
                         .map(|s| *s as usize)
                         .collect::<Vec<usize>>(),
                     tensor.to_vec::<$type>(),
-                )?)
+                )
+                .map_err(|_| ArrayError::ShapeError(tensor.shape.clone()))?)
             }
         }
     };
index d82da15..2ca97bd 100644 (file)
  * under the License.
  */
 
-use failure::Fail;
+use thiserror::Error;
+use tvm_sys::DataType;
 
-#[derive(Debug, Fail)]
+#[derive(Debug, Error)]
 pub enum GraphFormatError {
-    #[fail(display = "Could not parse graph json")]
-    Parse(#[fail(cause)] failure::Error),
-    #[fail(display = "Could not parse graph params")]
+    #[error("Could not parse graph json")]
+    Parse(#[from] serde_json::Error),
+    #[error("Could not parse graph params")]
     Params,
-    #[fail(display = "{} is missing attr: {}", 0, 1)]
+    #[error("{0} is missing attr: {1}")]
     MissingAttr(String, String),
-    #[fail(display = "Missing field: {}", 0)]
+    #[error("Graph has invalid attr that can't be parsed: {0}")]
+    InvalidAttr(#[from] std::num::ParseIntError),
+    #[error("Missing field: {0}")]
     MissingField(&'static str),
-    #[fail(display = "Invalid DLType: {}", 0)]
+    #[error("Invalid DLType: {0}")]
     InvalidDLType(String),
+    #[error("Unsupported Op: {0}")]
+    UnsupportedOp(String),
+}
+
+#[derive(Debug, Error)]
+#[error("Function {0} not found")]
+pub struct FunctionNotFound(pub String);
+
+#[derive(Debug, Error)]
+#[error("Pointer {0:?} invalid when freeing")]
+pub struct InvalidPointer(pub *mut u8);
+
+#[derive(Debug, Error)]
+pub enum ArrayError {
+    #[error("Cannot convert Tensor with dtype {0} to ndarray")]
+    IncompatibleDataType(DataType),
+    #[error("Shape error when casting ndarray to TVM Array with shape {0:?}")]
+    ShapeError(Vec<i64>),
 }
index 895739d..91021dd 100644 (file)
  * under the License.
  */
 
-use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
+use std::{
+    cmp, collections::HashMap, convert::TryFrom, error::Error, iter::FromIterator, mem, str,
+};
 
-use failure::{ensure, format_err, Error};
 use itertools::izip;
 use nom::{
     character::complete::{alpha1, digit1},
@@ -27,7 +28,6 @@ use nom::{
     number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8},
     opt, tag, take, tuple,
 };
-
 use serde::{Deserialize, Serialize};
 use serde_json;
 
@@ -35,7 +35,7 @@ use tvm_sys::ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCod
 
 use tvm_sys::{ffi::DLTensor, ArgValue, Context, DataType, DeviceType};
 
-use crate::{errors::GraphFormatError, Module, Storage, Tensor};
+use crate::{errors::*, Module, Storage, Tensor};
 
 // @see `kTVMNDArrayMagic` in `ndarray.h`
 const _NDARRAY_MAGIC: u64 = 0xDD5E_40F0_96B4_A13F;
@@ -114,7 +114,7 @@ macro_rules! get_node_attr {
 }
 
 impl Node {
-    fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
+    fn parse_attrs(&self) -> Result<NodeAttrs, GraphFormatError> {
         let attrs = self
             .attrs
             .as_ref()
@@ -128,15 +128,15 @@ impl Node {
 }
 
 impl<'a> TryFrom<&'a String> for Graph {
-    type Error = Error;
-    fn try_from(graph_json: &String) -> Result<Self, self::Error> {
+    type Error = GraphFormatError;
+    fn try_from(graph_json: &String) -> Result<Self, GraphFormatError> {
         let graph = serde_json::from_str(graph_json)?;
         Ok(graph)
     }
 }
 
 impl<'a> TryFrom<&'a str> for Graph {
-    type Error = Error;
+    type Error = GraphFormatError;
     fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
         let graph = serde_json::from_str(graph_json)?;
         Ok(graph)
@@ -177,7 +177,7 @@ pub struct GraphExecutor<'m, 't> {
 unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
 
 impl<'m, 't> GraphExecutor<'m, 't> {
-    pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error> {
+    pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Box<dyn Error>> {
         let tensors = Self::setup_storages(&graph)?;
         Ok(GraphExecutor {
             op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
@@ -194,7 +194,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
     }
 
     /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
-    fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error> {
+    fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Box<dyn Error>> {
         let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
         let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
         let dtypes = graph
@@ -221,7 +221,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
         let mut storages: Vec<Storage> = storage_num_bytes
             .into_iter()
             .map(|nbytes| Storage::new(nbytes, align))
-            .collect::<Result<Vec<Storage>, Error>>()?;
+            .collect::<Result<Vec<Storage>, std::alloc::LayoutErr>>()?;
 
         let tensors = izip!(storage_ids, shapes, dtypes)
             .map(|(storage_id, shape, dtype)| {
@@ -246,8 +246,10 @@ impl<'m, 't> GraphExecutor<'m, 't> {
         graph: &Graph,
         lib: &'m M,
         tensors: &[Tensor<'t>],
-    ) -> Result<Vec<Box<dyn Fn() + 'm>>, Error> {
-        ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
+    ) -> Result<Vec<Box<dyn Fn() + 'm>>, Box<dyn Error + 'static>> {
+        if !graph.node_row_ptr.is_some() {
+            return Err(GraphFormatError::MissingField("node_row_ptr").into());
+        }
         let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
 
         let mut op_execs = Vec::new();
@@ -255,10 +257,14 @@ impl<'m, 't> GraphExecutor<'m, 't> {
             if node.op == "null" {
                 continue;
             }
-            ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
-            ensure!(node.attrs.is_some(), "Missing node attrs.");
+            if node.op != "tvm_op" {
+                return Err(GraphFormatError::UnsupportedOp(node.op.to_owned()).into());
+            }
+            if !node.attrs.is_some() {
+                return Err(GraphFormatError::MissingAttr(node.op.clone(), "".to_string()).into());
+            }
 
-            let attrs = node.parse_attrs()?;
+            let attrs: NodeAttrs = node.parse_attrs()?.into();
 
             if attrs.func_name == "__nop" {
                 continue;
@@ -266,14 +272,14 @@ impl<'m, 't> GraphExecutor<'m, 't> {
 
             let func = lib
                 .get_function(&attrs.func_name)
-                .ok_or_else(|| format_err!("Library is missing function {}", attrs.func_name))?;
+                .ok_or_else(|| FunctionNotFound(attrs.func_name.clone()))?;
             let arg_indices = node
                 .inputs
                 .iter()
                 .map(|entry| graph.entry_index(entry))
                 .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i] + oi)));
 
-            let dl_tensors = arg_indices
+            let dl_tensors: Vec<DLTensor> = arg_indices
                 .map(|idx| {
                     let tensor = &tensors[idx?];
                     Ok(if attrs.flatten_data {
@@ -282,14 +288,15 @@ impl<'m, 't> GraphExecutor<'m, 't> {
                         DLTensor::from(tensor)
                     })
                 })
-                .collect::<Result<Vec<DLTensor>, Error>>()
-                .unwrap();
+                .collect::<Result<Vec<DLTensor>, GraphFormatError>>()?
+                .into();
             let op: Box<dyn Fn()> = Box::new(move || {
-                let args = dl_tensors
+                let args: Vec<ArgValue> = dl_tensors
                     .iter()
                     .map(|t| t.into())
                     .collect::<Vec<ArgValue>>();
-                func(&args).unwrap();
+                let err_str = format!("Function {} failed to execute", attrs.func_name);
+                func(&args).expect(&err_str);
             });
             op_execs.push(op);
         }
index 51645d5..f1145da 100644 (file)
@@ -59,7 +59,7 @@ macro_rules! init_context_func {
 }
 
 impl<'a> DsoModule<'a> {
-    pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, failure::Error> {
+    pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, std::io::Error> {
         let lib = libloading::Library::new(filename)?;
 
         init_context_func!(
index 9b83ff3..cbb3bf1 100644 (file)
@@ -29,7 +29,7 @@ use std::{
 #[cfg(not(target_arch = "wasm32"))]
 use std::env;
 
-use crossbeam::channel::{bounded, Receiver, Sender};
+use crossbeam_channel::{bounded, Receiver, Sender};
 use tvm_sys::ffi::TVMParallelGroupEnv;
 
 pub(crate) type FTVMParallelLambda =
index 35cfe91..cf26497 100644 (file)
 
 use std::{
     cell::RefCell,
+    error::Error,
     os::raw::{c_int, c_void},
     ptr,
 };
 
-use failure::{format_err, Error};
-
 use crate::allocator::Allocation;
+use crate::errors::InvalidPointer;
+use std::alloc::LayoutErr;
 
 const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
 
@@ -49,13 +50,13 @@ impl WorkspacePool {
         }
     }
 
-    fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> {
+    fn alloc_new(&mut self, size: usize) -> Result<*mut u8, LayoutErr> {
         self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
         self.in_use.push(self.workspaces.len() - 1);
         Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
     }
 
-    fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
+    fn alloc(&mut self, size: usize) -> Result<*mut u8, LayoutErr> {
         if self.free.is_empty() {
             return self.alloc_new(size);
         }
@@ -82,7 +83,7 @@ impl WorkspacePool {
         }
     }
 
-    fn free(&mut self, ptr: *mut u8) -> Result<(), Error> {
+    fn free(&mut self, ptr: *mut u8) -> Result<(), Box<dyn Error>> {
         let mut ws_idx = None;
         for i in 0..self.in_use.len() {
             let idx = self.in_use[i];
@@ -92,7 +93,7 @@ impl WorkspacePool {
                 break;
             }
         }
-        let ws_idx = ws_idx.ok_or_else(|| format_err!("Invalid pointer"))?;
+        let ws_idx = ws_idx.ok_or_else(|| InvalidPointer(ptr))?;
         self.free.push(ws_idx);
         Ok(())
     }
index a46cfa9..67ef217 100644 (file)
@@ -30,10 +30,10 @@ unsafe fn __get_tvm_module_ctx() -> i32 {
 
 extern crate ndarray;
 #[macro_use]
-extern crate tvm_runtime;
+extern crate tvm_graph_rt;
 
 use ndarray::Array;
-use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
+use tvm_graph_rt::{DLTensor, Module as _, SystemLibModule};
 
 fn main() {
     // try static
index 7abc9ae..a9ac09e 100644 (file)
@@ -30,7 +30,7 @@ edition = "2018"
 proc-macro = true
 
 [dependencies]
-goblin = "0.0.24"
+goblin = "^0.2"
 proc-macro2 = "^1.0"
 quote = "^1.0"
 syn = { version = "1.0.17", features = ["full", "extra-traits"] }
index 7d133fa..6880824 100644 (file)
@@ -38,7 +38,7 @@ type Deleter = unsafe extern "C" fn(object: *mut Object) -> ();
 #[derive(Debug)]
 #[repr(C)]
 pub struct Object {
-    /// The index into into TVM's runtime type information table.
+    /// The index into TVM's runtime type information table.
     pub(self) type_index: u32,
     // TODO(@jroesch): pretty sure Rust and C++ atomics are the same, but not sure.
     // NB: in general we should not touch this in Rust.
@@ -57,10 +57,10 @@ pub struct Object {
 /// trait magic here to get a monomorphized deleter for each object
 /// "subtype".
 ///
-/// This function just transmutes the pointer to the correct type
+/// This function just converts the pointer to the correct type
 /// and invokes the underlying typed delete function.
 unsafe extern "C" fn delete<T: IsObject>(object: *mut Object) {
-    let typed_object: *mut T = std::mem::transmute(object);
+    let typed_object: *mut T = object as *mut T;
     T::typed_delete(typed_object);
 }
 
@@ -104,8 +104,7 @@ impl Object {
         } else {
             let mut index = 0;
             unsafe {
-                let index_ptr = std::mem::transmute(&mut index);
-                if TVMObjectTypeKey2Index(cstring.as_ptr(), index_ptr) != 0 {
+                if TVMObjectTypeKey2Index(cstring.as_ptr(), &mut index) != 0 {
                     panic!(crate::get_last_error())
                 }
             }
@@ -130,16 +129,16 @@ impl Object {
 
     /// Increases the object's reference count by one.
     pub(self) fn inc_ref(&self) {
+        let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void;
         unsafe {
-            let raw_ptr = std::mem::transmute(self);
             assert_eq!(TVMObjectRetain(raw_ptr), 0);
         }
     }
 
     /// Decreases the object's reference count by one.
     pub(self) fn dec_ref(&self) {
+        let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void;
         unsafe {
-            let raw_ptr = std::mem::transmute(self);
             assert_eq!(TVMObjectFree(raw_ptr), 0);
         }
     }
@@ -277,10 +276,9 @@ impl<T: IsObject> std::ops::Deref for ObjectPtr<T> {
 
 impl<'a, T: IsObject> From<ObjectPtr<T>> for RetValue {
     fn from(object_ptr: ObjectPtr<T>) -> RetValue {
-        let raw_object_ptr = ObjectPtr::leak(object_ptr);
-        let void_ptr: *mut std::ffi::c_void = unsafe { std::mem::transmute(raw_object_ptr) };
-        assert!(!void_ptr.is_null());
-        RetValue::ObjectHandle(void_ptr)
+        let raw_object_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void;
+        assert!(!raw_object_ptr.is_null());
+        RetValue::ObjectHandle(raw_object_ptr)
     }
 }
 
@@ -290,8 +288,7 @@ impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
     fn try_from(ret_value: RetValue) -> Result<ObjectPtr<T>, Self::Error> {
         match ret_value {
             RetValue::ObjectHandle(handle) => {
-                let handle: *mut Object = unsafe { std::mem::transmute(handle) };
-                let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
+                let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?;
                 debug_assert!(optr.count() >= 1);
                 println!("back to type {}", optr.count());
                 optr.downcast()
@@ -304,10 +301,9 @@ impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
 impl<'a, T: IsObject> From<ObjectPtr<T>> for ArgValue<'a> {
     fn from(object_ptr: ObjectPtr<T>) -> ArgValue<'a> {
         debug_assert!(object_ptr.count() >= 1);
-        let raw_object_ptr = ObjectPtr::leak(object_ptr);
-        let void_ptr: *mut std::ffi::c_void = unsafe { std::mem::transmute(raw_object_ptr) };
-        assert!(!void_ptr.is_null());
-        ArgValue::ObjectHandle(void_ptr)
+        let raw_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void;
+        assert!(!raw_ptr.is_null());
+        ArgValue::ObjectHandle(raw_ptr)
     }
 }
 
@@ -317,8 +313,7 @@ impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
     fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error> {
         match arg_value {
             ArgValue::ObjectHandle(handle) => {
-                let handle = unsafe { std::mem::transmute(handle) };
-                let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
+                let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?;
                 debug_assert!(optr.count() >= 1);
                 println!("count: {}", optr.count());
                 optr.downcast()
index fe4d0bf..faddce4 100644 (file)
@@ -32,4 +32,4 @@ ndarray = "0.12"
 enumn = "^0.1"
 
 [build-dependencies]
-bindgen = "0.51"
+bindgen = { version="0.51", default-features=false }