edition = "2018"
[dependencies]
-crossbeam = "0.7.3"
-failure = "0.1"
+crossbeam-channel = "0.4"
+thiserror = "1"
+
itertools = "0.8"
lazy_static = "1.4"
ndarray="0.12"
use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
-use failure::{ensure, Error};
use ndarray;
use tvm_sys::{ffi::DLTensor, Context, DataType};
use crate::allocator::Allocation;
+use crate::errors::ArrayError;
+use std::alloc::LayoutErr;
/// A `Storage` is a container which holds `Tensor` data.
#[derive(PartialEq)]
}
impl<'a> Storage<'a> {
- pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
+ pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, LayoutErr> {
Ok(Storage::Owned(Allocation::new(size, align)?))
}
macro_rules! impl_ndarray_try_from_tensor {
($type:ty, $dtype:expr) => {
impl<'t> TryFrom<Tensor<'t>> for ndarray::ArrayD<$type> {
- type Error = Error;
- fn try_from(tensor: Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
- ensure!(
- tensor.dtype == $dtype,
- "Cannot convert Tensor with dtype {:?} to ndarray",
- tensor.dtype
- );
+ type Error = ArrayError;
+ fn try_from(tensor: Tensor) -> Result<ndarray::ArrayD<$type>, Self::Error> {
+ if tensor.dtype != $dtype {
+ return Err(ArrayError::IncompatibleDataType(tensor.dtype));
+ }
Ok(ndarray::Array::from_shape_vec(
tensor
.shape
.map(|s| *s as usize)
.collect::<Vec<usize>>(),
tensor.to_vec::<$type>(),
- )?)
+ )
+ .map_err(|_| ArrayError::ShapeError(tensor.shape.clone()))?)
}
}
};
* under the License.
*/
-use failure::Fail;
+use thiserror::Error;
+use tvm_sys::DataType;
-#[derive(Debug, Fail)]
+#[derive(Debug, Error)]
pub enum GraphFormatError {
- #[fail(display = "Could not parse graph json")]
- Parse(#[fail(cause)] failure::Error),
- #[fail(display = "Could not parse graph params")]
+ #[error("Could not parse graph json")]
+ Parse(#[from] serde_json::Error),
+ #[error("Could not parse graph params")]
Params,
- #[fail(display = "{} is missing attr: {}", 0, 1)]
+ #[error("{0} is missing attr: {1}")]
MissingAttr(String, String),
- #[fail(display = "Missing field: {}", 0)]
+ #[error("Graph has invalid attr that can't be parsed: {0}")]
+ InvalidAttr(#[from] std::num::ParseIntError),
+ #[error("Missing field: {0}")]
MissingField(&'static str),
- #[fail(display = "Invalid DLType: {}", 0)]
+ #[error("Invalid DLType: {0}")]
InvalidDLType(String),
+ #[error("Unsupported Op: {0}")]
+ UnsupportedOp(String),
+}
+
+#[derive(Debug, Error)]
+#[error("Function {0} not found")]
+pub struct FunctionNotFound(pub String);
+
+#[derive(Debug, Error)]
+#[error("Pointer {0:?} invalid when freeing")]
+pub struct InvalidPointer(pub *mut u8);
+
+#[derive(Debug, Error)]
+pub enum ArrayError {
+ #[error("Cannot convert Tensor with dtype {0} to ndarray")]
+ IncompatibleDataType(DataType),
+ #[error("Shape error when casting ndarray to TVM Array with shape {0:?}")]
+ ShapeError(Vec<i64>),
}
* under the License.
*/
-use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
+use std::{
+ cmp, collections::HashMap, convert::TryFrom, error::Error, iter::FromIterator, mem, str,
+};
-use failure::{ensure, format_err, Error};
use itertools::izip;
use nom::{
character::complete::{alpha1, digit1},
number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8},
opt, tag, take, tuple,
};
-
use serde::{Deserialize, Serialize};
use serde_json;
use tvm_sys::{ffi::DLTensor, ArgValue, Context, DataType, DeviceType};
-use crate::{errors::GraphFormatError, Module, Storage, Tensor};
+use crate::{errors::*, Module, Storage, Tensor};
// @see `kTVMNDArrayMagic` in `ndarray.h`
const _NDARRAY_MAGIC: u64 = 0xDD5E_40F0_96B4_A13F;
}
impl Node {
- fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
+ fn parse_attrs(&self) -> Result<NodeAttrs, GraphFormatError> {
let attrs = self
.attrs
.as_ref()
}
impl<'a> TryFrom<&'a String> for Graph {
- type Error = Error;
- fn try_from(graph_json: &String) -> Result<Self, self::Error> {
+ type Error = GraphFormatError;
+ fn try_from(graph_json: &String) -> Result<Self, GraphFormatError> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
}
}
impl<'a> TryFrom<&'a str> for Graph {
- type Error = Error;
+ type Error = GraphFormatError;
fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
impl<'m, 't> GraphExecutor<'m, 't> {
- pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error> {
+ pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Box<dyn Error>> {
let tensors = Self::setup_storages(&graph)?;
Ok(GraphExecutor {
op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
}
/// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
- fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error> {
+ fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Box<dyn Error>> {
let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
let dtypes = graph
let mut storages: Vec<Storage> = storage_num_bytes
.into_iter()
.map(|nbytes| Storage::new(nbytes, align))
- .collect::<Result<Vec<Storage>, Error>>()?;
+ .collect::<Result<Vec<Storage>, std::alloc::LayoutErr>>()?;
let tensors = izip!(storage_ids, shapes, dtypes)
.map(|(storage_id, shape, dtype)| {
graph: &Graph,
lib: &'m M,
tensors: &[Tensor<'t>],
- ) -> Result<Vec<Box<dyn Fn() + 'm>>, Error> {
- ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
+ ) -> Result<Vec<Box<dyn Fn() + 'm>>, Box<dyn Error + 'static>> {
+ if !graph.node_row_ptr.is_some() {
+ return Err(GraphFormatError::MissingField("node_row_ptr").into());
+ }
let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
let mut op_execs = Vec::new();
if node.op == "null" {
continue;
}
- ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
- ensure!(node.attrs.is_some(), "Missing node attrs.");
+ if node.op != "tvm_op" {
+ return Err(GraphFormatError::UnsupportedOp(node.op.to_owned()).into());
+ }
+ if !node.attrs.is_some() {
+ return Err(GraphFormatError::MissingAttr(node.op.clone(), "".to_string()).into());
+ }
- let attrs = node.parse_attrs()?;
+ let attrs: NodeAttrs = node.parse_attrs()?.into();
if attrs.func_name == "__nop" {
continue;
let func = lib
.get_function(&attrs.func_name)
- .ok_or_else(|| format_err!("Library is missing function {}", attrs.func_name))?;
+ .ok_or_else(|| FunctionNotFound(attrs.func_name.clone()))?;
let arg_indices = node
.inputs
.iter()
.map(|entry| graph.entry_index(entry))
.chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i] + oi)));
- let dl_tensors = arg_indices
+ let dl_tensors: Vec<DLTensor> = arg_indices
.map(|idx| {
let tensor = &tensors[idx?];
Ok(if attrs.flatten_data {
DLTensor::from(tensor)
})
})
- .collect::<Result<Vec<DLTensor>, Error>>()
- .unwrap();
+ .collect::<Result<Vec<DLTensor>, GraphFormatError>>()?
+ .into();
let op: Box<dyn Fn()> = Box::new(move || {
- let args = dl_tensors
+ let args: Vec<ArgValue> = dl_tensors
.iter()
.map(|t| t.into())
.collect::<Vec<ArgValue>>();
- func(&args).unwrap();
+ let err_str = format!("Function {} failed to execute", attrs.func_name);
+ func(&args).expect(&err_str);
});
op_execs.push(op);
}
}
impl<'a> DsoModule<'a> {
- pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, failure::Error> {
+ pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, std::io::Error> {
let lib = libloading::Library::new(filename)?;
init_context_func!(
#[cfg(not(target_arch = "wasm32"))]
use std::env;
-use crossbeam::channel::{bounded, Receiver, Sender};
+use crossbeam_channel::{bounded, Receiver, Sender};
use tvm_sys::ffi::TVMParallelGroupEnv;
pub(crate) type FTVMParallelLambda =
use std::{
cell::RefCell,
+ error::Error,
os::raw::{c_int, c_void},
ptr,
};
-use failure::{format_err, Error};
-
use crate::allocator::Allocation;
+use crate::errors::InvalidPointer;
+use std::alloc::LayoutErr;
const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
}
}
- fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> {
+ fn alloc_new(&mut self, size: usize) -> Result<*mut u8, LayoutErr> {
self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
self.in_use.push(self.workspaces.len() - 1);
Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
}
- fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
+ fn alloc(&mut self, size: usize) -> Result<*mut u8, LayoutErr> {
if self.free.is_empty() {
return self.alloc_new(size);
}
}
}
- fn free(&mut self, ptr: *mut u8) -> Result<(), Error> {
+ fn free(&mut self, ptr: *mut u8) -> Result<(), Box<dyn Error>> {
let mut ws_idx = None;
for i in 0..self.in_use.len() {
let idx = self.in_use[i];
break;
}
}
- let ws_idx = ws_idx.ok_or_else(|| format_err!("Invalid pointer"))?;
+ let ws_idx = ws_idx.ok_or_else(|| InvalidPointer(ptr))?;
self.free.push(ws_idx);
Ok(())
}
extern crate ndarray;
#[macro_use]
-extern crate tvm_runtime;
+extern crate tvm_graph_rt;
use ndarray::Array;
-use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
+use tvm_graph_rt::{DLTensor, Module as _, SystemLibModule};
fn main() {
// try static
proc-macro = true
[dependencies]
-goblin = "0.0.24"
+goblin = "^0.2"
proc-macro2 = "^1.0"
quote = "^1.0"
syn = { version = "1.0.17", features = ["full", "extra-traits"] }
#[derive(Debug)]
#[repr(C)]
pub struct Object {
- /// The index into into TVM's runtime type information table.
+ /// The index into TVM's runtime type information table.
pub(self) type_index: u32,
// TODO(@jroesch): pretty sure Rust and C++ atomics are the same, but not sure.
// NB: in general we should not touch this in Rust.
/// trait magic here to get a monomorphized deleter for each object
/// "subtype".
///
-/// This function just transmutes the pointer to the correct type
+/// This function just converts the pointer to the correct type
/// and invokes the underlying typed delete function.
unsafe extern "C" fn delete<T: IsObject>(object: *mut Object) {
- let typed_object: *mut T = std::mem::transmute(object);
+ let typed_object: *mut T = object as *mut T;
T::typed_delete(typed_object);
}
} else {
let mut index = 0;
unsafe {
- let index_ptr = std::mem::transmute(&mut index);
- if TVMObjectTypeKey2Index(cstring.as_ptr(), index_ptr) != 0 {
+ if TVMObjectTypeKey2Index(cstring.as_ptr(), &mut index) != 0 {
panic!(crate::get_last_error())
}
}
/// Increases the object's reference count by one.
pub(self) fn inc_ref(&self) {
+ let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void;
unsafe {
- let raw_ptr = std::mem::transmute(self);
assert_eq!(TVMObjectRetain(raw_ptr), 0);
}
}
/// Decreases the object's reference count by one.
pub(self) fn dec_ref(&self) {
+ let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void;
unsafe {
- let raw_ptr = std::mem::transmute(self);
assert_eq!(TVMObjectFree(raw_ptr), 0);
}
}
impl<'a, T: IsObject> From<ObjectPtr<T>> for RetValue {
fn from(object_ptr: ObjectPtr<T>) -> RetValue {
- let raw_object_ptr = ObjectPtr::leak(object_ptr);
- let void_ptr: *mut std::ffi::c_void = unsafe { std::mem::transmute(raw_object_ptr) };
- assert!(!void_ptr.is_null());
- RetValue::ObjectHandle(void_ptr)
+ let raw_object_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void;
+ assert!(!raw_object_ptr.is_null());
+ RetValue::ObjectHandle(raw_object_ptr)
}
}
fn try_from(ret_value: RetValue) -> Result<ObjectPtr<T>, Self::Error> {
match ret_value {
RetValue::ObjectHandle(handle) => {
- let handle: *mut Object = unsafe { std::mem::transmute(handle) };
- let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
+ let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?;
debug_assert!(optr.count() >= 1);
println!("back to type {}", optr.count());
optr.downcast()
impl<'a, T: IsObject> From<ObjectPtr<T>> for ArgValue<'a> {
fn from(object_ptr: ObjectPtr<T>) -> ArgValue<'a> {
debug_assert!(object_ptr.count() >= 1);
- let raw_object_ptr = ObjectPtr::leak(object_ptr);
- let void_ptr: *mut std::ffi::c_void = unsafe { std::mem::transmute(raw_object_ptr) };
- assert!(!void_ptr.is_null());
- ArgValue::ObjectHandle(void_ptr)
+ let raw_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void;
+ assert!(!raw_ptr.is_null());
+ ArgValue::ObjectHandle(raw_ptr)
}
}
fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error> {
match arg_value {
ArgValue::ObjectHandle(handle) => {
- let handle = unsafe { std::mem::transmute(handle) };
- let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
+ let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?;
debug_assert!(optr.count() >= 1);
println!("count: {}", optr.count());
optr.downcast()
enumn = "^0.1"
[build-dependencies]
-bindgen = "0.51"
+bindgen = { version="0.51", default-features=false }