tab_spaces = 4
newline_style = "Auto"
use_small_heuristics = "Default"
-indent_style = "Block"
-wrap_comments = false
-format_code_in_doc_comments = false
-comment_width = 80
-normalize_comments = false
-normalize_doc_attributes = false
-format_strings = false
-format_macro_matchers = false
-format_macro_bodies = true
-empty_item_single_line = true
-struct_lit_single_line = true
-fn_single_line = false
-where_single_line = false
-imports_indent = "Block"
-imports_layout = "Mixed"
-merge_imports = true
reorder_imports = true
reorder_modules = true
-reorder_impl_items = false
-type_punctuation_density = "Wide"
-space_before_colon = false
-space_after_colon = true
-spaces_around_ranges = false
-binop_separator = "Front"
remove_nested_parens = true
-combine_control_expr = true
-overflow_delimited_expr = false
-struct_field_align_threshold = 0
-enum_discrim_align_threshold = 0
-match_arm_blocks = true
-force_multiline_blocks = false
fn_args_layout = "Tall"
-brace_style = "SameLineWhere"
-control_brace_style = "AlwaysSameLine"
-trailing_semicolon = true
-trailing_comma = "Vertical"
-match_block_trailing_comma = false
-blank_lines_upper_bound = 1
-blank_lines_lower_bound = 0
edition = "2018"
-version = "One"
-inline_attribute_width = 0
merge_derives = true
use_try_shorthand = false
use_field_init_shorthand = false
force_explicit_abi = true
-condense_wildcard_suffixes = false
-color = "Auto"
-unstable_features = false
-disable_all_formatting = false
-skip_children = false
-hide_parse_errors = false
-error_on_line_overflow = false
-error_on_unformatted = false
-report_todo = "Never"
-report_fixme = "Never"
-ignore = []
-emit_mode = "Files"
-make_backup = false
pub mod value;
pub use errors::*;
-pub use ffi::{TVMByteArray, TVMContext, DLDataType as TVMType};
+pub use ffi::{DLDataType as TVMType, TVMByteArray, TVMContext};
pub use packed_func::{TVMArgValue, TVMRetValue};
pub use crate::ffi::TVMValue;
use crate::{errors::ValueDowncastError, ffi::*};
-pub trait PackedFunc : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
+pub trait PackedFunc:
+ Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync
+{
+}
-impl<T> PackedFunc for T
- where T : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
+impl<T> PackedFunc for T where
+ T: Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync
+{
+}
/// Calls a packed function and returns a `TVMRetValue`.
///
ObjectHandle(*mut c_void),
ModuleHandle(TVMModuleHandle),
FuncHandle(TVMFunctionHandle),
- NDArrayContainer(*mut c_void),
+ NDArrayHandle(*mut c_void),
$($extra_variant($variant_type)),+
}
TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
- TVMTypeCode_kTVMNDArrayHandle => NDArrayContainer($value.v_handle),
+ TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
$( $tvm_type => { $from_tvm_type } ),+
_ => unimplemented!("{}", type_code),
}
TVMValue { v_handle: *val },
TVMTypeCode_kTVMPackedFuncHandle
),
- NDArrayContainer(val) =>
+ NDArrayHandle(val) =>
(TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
$( $self_type($val) => { $from_self_type } ),+
}
//! # Example
//!
//! ```
-//! let ctx = TVMContext::new(1, 0);
+//! # use tvm_frontend::{TVMDeviceType, TVMContext};
+//! let cpu = TVMDeviceType::from("cpu");
+//! let ctx = TVMContext::new(cpu, 0);
//! let cpu0 = TVMContext::cpu(0);
//! assert_eq!(ctx, cpu0);
//! ```
//! Or from a supported device name.
//!
//! ```
+//! use tvm_frontend::TVMContext;
//! let cpu0 = TVMContext::from("cpu");
//! println!("{}", cpu0);
//! ```
/// ## Example
///
/// ```
+/// use tvm_frontend::TVMDeviceType;
/// let cpu = TVMDeviceType::from("cpu");
/// println!("device is: {}", cpu);
///```
/// ## Examples
///
/// ```
-/// let ctx = TVMContext::from("gpu");
+/// use tvm_frontend::TVMContext;
+/// let ctx = TVMContext::from("cpu");
/// assert!(ctx.exist());
///
/// ```
/// It is possible to query the underlying context as follows
///
/// ```
-/// println!("maximun threads per block: {}", ctx.max_threads_per_block());
-/// println!("compute version: {}", ctx.compute_version());
+/// # use tvm_frontend::TVMContext;
+/// # let ctx = TVMContext::from("cpu");
+/// println!("context exists: {}", ctx.exist());
/// ```
+// TODO: add example back for GPU
+// println!("compute version: {}", ctx.compute_version());
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
pub struct TVMContext {
/// Supported device types
impl TVMContext {
/// Checks whether the context exists or not.
pub fn exist(&self) -> bool {
- let func = function::Function::get("_GetDeviceAttr").expect("API function always exists");
- let dt = self.device_type.0 as usize;
+ let func = function::Function::get("runtime.GetDeviceAttr")
+ .expect("TVM FFI functions must always be registered.");
+ let dt = self.device_type.0 as isize;
// `unwrap` is ok here because if there is any error,
        // it would occur inside `call_packed!`
- let ret: u64 = call_packed!(func, dt, self.device_id, 0)
+ let ret: i64 = call_packed!(func, dt, self.device_id, 0)
.unwrap()
.try_into()
.unwrap();
($(($attr_name:ident, $attr_kind:expr));+) => {
$(
impl TVMContext {
- pub fn $attr_name(&self) -> usize {
- let func = function::Function::get("_GetDeviceAttr")
- .expect("API function always exists");
- let dt = self.device_type.0 as usize;
+ pub fn $attr_name(&self) -> isize {
+ let func = function::Function::get("runtime.GetDeviceAttr")
+ .expect("TVM FFI functions must always be registered.");
+ let dt = self.device_type.0 as isize;
+ // TODO(@jroesch): these functions CAN and WILL return NULL
+ // we should make these optional or somesuch to handle this.
// `unwrap` is ok here because if there is any error,
                // it would occur in the function call.
function::Builder::from(func)
.arg(dt)
- .arg(self.device_id as usize)
+ .arg(self.device_id as isize)
.arg($attr_kind)
.invoke()
.unwrap()
&mut names_ptr as *mut _,
));
let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as usize) };
- Mutex::new(
- names_list
- .iter()
- .map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
- .collect(),
- )
+ let names_list = names_list
+ .iter()
+ .map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
+ .collect();
+
+ Mutex::new(names_list)
};
}
|| tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int
|| tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int
{
- check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, &mut tcode as *mut _));
+ check_call!(ffi::TVMCbArgToReturn(
+ &mut value as *mut _,
+ &mut tcode as *mut _
+ ));
}
local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32));
}
/// ## Example
///
/// ```
+/// # use tvm_frontend::{TVMArgValue, function, TVMRetValue};
+/// # use tvm_frontend::function::Builder;
+/// # use failure::Error;
/// use std::convert::TryInto;
///
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
/// let arg: i64 = arg.try_into()?;
/// ret += arg;
/// }
-/// let ret_val = TVMRetValue::from(&ret);
+/// let ret_val = TVMRetValue::from(ret);
/// Ok(ret_val)
/// }
///
-/// tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
-/// let mut registered = function::Builder::default();
-/// registered.get_function("mysum", true);
+/// function::register(sum, "mysum".to_owned(), false).unwrap();
+/// let mut registered = Builder::default();
+/// registered.get_function("mysum");
/// assert!(registered.func.is_some());
/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60);
/// ## Example
///
/// ```
-/// use std::convert::TryInto;
+/// # use std::convert::TryInto;
+/// # use tvm_frontend::{register_global_func, TVMArgValue, TVMRetValue};
+/// # use failure::Error;
+/// # use tvm_frontend::function::Builder;
///
/// register_global_func! {
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
/// let arg: f64 = arg.try_into()?;
/// ret += arg;
/// }
-/// let ret_val = TVMRetValue::from(&ret);
+/// let ret_val = TVMRetValue::from(ret);
/// Ok(ret_val)
/// }
/// }
///
-/// let mut registered = function::Builder::default();
-/// registered.get_function("sum", true);
+/// let mut registered = Builder::default();
+/// registered.get_function("sum");
/// assert!(registered.func.is_some());
/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60f64);
///
/// Instead of
///
-/// ```
-/// function::Builder::from(func).arg(&a).arg(&b).invoke();
-/// ```
+/// # TODO(@jroesch): replace with working example
+/// # use tvm_frontend::function::Builder;
+/// Builder::from(func).arg(&a).arg(&b).invoke();
///
/// one can use
///
-/// ```
+/// # use tvm_frontend::call_packed;
/// call_packed!(func, &a, &b);
-/// ```
#[macro_export]
macro_rules! call_packed {
($fn_name:expr, $($arg:expr),*) => {{
mod tests {
use super::*;
- static CANARY: &str = "module._LoadFromFile";
+ static CANARY: &str = "runtime.ModuleLoadFromFile";
- #[test]
- fn list_global_func() {
- assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
- }
+ // #[test]
+ // fn list_global_func() {
+ // assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
+ // }
#[test]
fn get_fn() {
ndarray::NDArray,
tvm_common::{
errors as common_errors,
- ffi::{self, TVMByteArray, DLDataType},
+ ffi::{self, DLDataType, TVMByteArray},
packed_func::{TVMArgValue, TVMRetValue},
},
};
+pub type DataType = DLDataType;
+
// Macro to check the return call to TVM runtime shared library.
macro_rules! check_call {
($e:expr) => {{
format_err!("Bad module load path: `{}`.", path.as_ref().display())
})?,
)?;
- let func = Function::get("module._LoadFromFile").expect("API function always exists");
+ let func = Function::get("runtime.ModuleLoadFromFile").expect("API function always exists");
let cpath =
CString::new(path.as_ref().to_str().ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
/// Checks if a target device is enabled for a module.
pub fn enabled(&self, target: &str) -> bool {
- let func = Function::get("module._Enabled").expect("API function always exists");
+ let func = Function::get("runtime.RuntimeEnabled").expect("API function always exists");
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
let tgt = CString::new(target).unwrap();
//! # Example
//!
//! ```
+//! # use tvm_frontend::{NDArray, TVMContext, DataType};
+//! # use ndarray::{Array, ArrayD};
+//! # use std::str::FromStr;
+//! use std::convert::TryFrom;
+//!
//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
//! .unwrap()
//! .into_dyn(); // Rust's ndarray
-//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
-//! assert_eq!(nd.shape(), Some(&mut [2, 2]));
+//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), DataType::from_str("float32").unwrap()).unwrap();
+//! assert_eq!(nd.shape(), Some(&mut [2, 2][..]));
//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
//! assert!(rnd.all_close(&a, 1e-8f32));
//! ```
use failure::Error;
use num_traits::Num;
use rust_ndarray::{Array, ArrayD};
+use std::convert::TryInto;
+use std::ffi::c_void;
+use tvm_common::ffi::DLTensor;
use tvm_common::{ffi, TVMType};
use crate::{errors, TVMByteArray, TVMContext};
///
/// Wrapper around TVM array handle.
#[derive(Debug)]
-pub struct NDArray {
- pub(crate) handle: ffi::TVMArrayHandle,
- is_view: bool,
+pub enum NDArray {
+ Borrowed { handle: ffi::TVMArrayHandle },
+ Owned { handle: *mut c_void },
}
impl NDArray {
pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
- NDArray {
- handle,
- is_view: true,
+ NDArray::Borrowed { handle }
+ }
+
+ pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self {
+ NDArray::Owned { handle }
+ }
+
+ pub fn as_dltensor(&self) -> &DLTensor {
+ unsafe {
+ match self {
+ NDArray::Borrowed { ref handle } => std::mem::transmute(*handle),
+ NDArray::Owned { ref handle } => std::mem::transmute(*handle),
+ }
}
}
- /// Returns the underlying array handle.
- pub fn handle(&self) -> ffi::TVMArrayHandle {
- self.handle
+ pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor {
+ unsafe {
+ match self {
+ NDArray::Borrowed { ref handle } => std::mem::transmute(*handle),
+ NDArray::Owned { ref handle } => std::mem::transmute(*handle),
+ }
+ }
}
pub fn is_view(&self) -> bool {
- self.is_view
+ if let &NDArray::Borrowed { .. } = self {
+ true
+ } else {
+ false
+ }
}
/// Returns the shape of the NDArray.
pub fn shape(&self) -> Option<&mut [usize]> {
- let arr = unsafe { *(self.handle) };
+ let arr = self.as_dltensor();
if arr.shape.is_null() || arr.data.is_null() {
return None;
};
/// Returns the context which the NDArray was defined.
pub fn ctx(&self) -> TVMContext {
- unsafe { (*self.handle).ctx.into() }
+ self.as_dltensor().ctx.into()
}
/// Returns the type of the entries of the NDArray.
pub fn dtype(&self) -> TVMType {
- unsafe { (*self.handle).dtype }
+ self.as_dltensor().dtype
}
/// Returns the number of dimensions of the NDArray.
pub fn ndim(&self) -> usize {
- unsafe { (*self.handle).ndim as usize }
+ self.as_dltensor()
+ .ndim
+ .try_into()
+ .expect("number of dimensions must always be positive")
}
/// Returns the strides of the underlying NDArray.
pub fn strides(&self) -> Option<&[usize]> {
unsafe {
let sz = self.ndim() * mem::size_of::<usize>();
- let slc = slice::from_raw_parts((*self.handle).strides as *const usize, sz);
+ let strides_ptr = self.as_dltensor().strides as *const usize;
+ let slc = slice::from_raw_parts(strides_ptr, sz);
Some(slc)
}
}
}
pub fn byte_offset(&self) -> isize {
- unsafe { (*self.handle).byte_offset as isize }
+ self.as_dltensor().byte_offset as isize
}
/// Flattens the NDArray to a `Vec` of the same type in cpu.
/// ## Example
///
/// ```
- /// let shape = &mut [4];
+ /// # use tvm_frontend::{TVMContext, DataType, NDArray};
+ /// # use std::str::FromStr;
+ /// let mut shape = [4];
/// let mut data = vec![1i32, 2, 3, 4];
/// let ctx = TVMContext::cpu(0);
- /// let mut ndarray = empty(shape, ctx, TVMType::from("int32"));
+ /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap());
/// ndarray.copy_from_buffer(&mut data);
- /// assert_eq!(ndarray.shape(), Some(shape));
+ /// assert_eq!(ndarray.shape(), Some(&mut shape[..]));
/// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
/// ```
pub fn to_vec<T>(&self) -> Result<Vec<T>, Error> {
self.dtype(),
);
let target = self.copy_to_ndarray(earr)?;
- let arr = unsafe { *(target.handle) };
+ let arr = target.as_dltensor();
let sz = self.size().ok_or(errors::MissingShapeError)?;
let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>());
unsafe {
/// ## Example
///
/// ```
+ /// # use tvm_frontend::{TVMContext, DataType, NDArray};
+ /// # use std::str::FromStr;
/// let shape = &mut [2];
- /// let mut data = vec![1f32, 2];
- /// let ctx = TVMContext::gpu(0);
- /// let mut ndarray = empty(shape, ctx, TVMType::from("int32"));
+ /// let mut data = vec![1f32, 2.0];
+ /// let ctx = TVMContext::cpu(0);
+ /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap());
/// ndarray.copy_from_buffer(&mut data);
/// ```
///
/// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
check_call!(ffi::TVMArrayCopyFromBytes(
- self.handle,
+ self.as_raw_dltensor(),
data.as_ptr() as *mut _,
data.len() * mem::size_of::<T>()
));
);
}
check_call!(ffi::TVMArrayCopyFromTo(
- self.handle,
- target.handle,
+ self.as_raw_dltensor(),
+ target.as_raw_dltensor(),
ptr::null_mut() as ffi::TVMStreamHandle
));
Ok(target)
ctx.device_id as c_int,
&mut handle as *mut _,
));
- NDArray {
- handle,
- is_view: false,
- }
+ NDArray::Borrowed { handle: handle }
}
}
impl Drop for NDArray {
fn drop(&mut self) {
- if !self.is_view {
- check_call!(ffi::TVMArrayFree(self.handle));
+ if let &mut NDArray::Owned { .. } = self {
+ check_call!(ffi::TVMArrayFree(self.as_raw_dltensor()));
}
}
}
//! `TVMRetValue` is the owned version of `TVMPODValue`.
use std::convert::TryFrom;
+// use std::ffi::c_void;
+use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};
use tvm_common::{
errors::ValueDowncastError,
- ffi::{TVMArrayHandle, TVMFunctionHandle, TVMModuleHandle},
+ ffi::{TVMFunctionHandle, TVMModuleHandle},
try_downcast,
};
-use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue};
-
macro_rules! impl_handle_val {
($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => {
impl<'a> From<&'a $type> for TVMArgValue<'a> {
impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new);
impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new);
-impl_handle_val!(NDArray, ArrayHandle, TVMArrayHandle, NDArray::new);
+
+impl<'a> From<&'a NDArray> for TVMArgValue<'a> {
+ fn from(arg: &'a NDArray) -> Self {
+ match arg {
+ &NDArray::Borrowed { handle } => TVMArgValue::ArrayHandle(handle),
+ &NDArray::Owned { handle } => TVMArgValue::NDArrayHandle(handle),
+ }
+ }
+}
+
+impl<'a> From<&'a mut NDArray> for TVMArgValue<'a> {
+ fn from(arg: &'a mut NDArray) -> Self {
+ match arg {
+ &mut NDArray::Borrowed { handle } => TVMArgValue::ArrayHandle(handle),
+ &mut NDArray::Owned { handle } => TVMArgValue::NDArrayHandle(handle),
+ }
+ }
+}
+
+impl<'a> TryFrom<TVMArgValue<'a>> for NDArray {
+ type Error = ValueDowncastError;
+ fn try_from(val: TVMArgValue<'a>) -> Result<NDArray, Self::Error> {
+ try_downcast!(val -> NDArray,
+ |TVMArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) },
+ |TVMArgValue::ArrayHandle(val)| { NDArray::new(val) })
+ }
+}
+
+impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for NDArray {
+ type Error = ValueDowncastError;
+ fn try_from(val: &'a TVMArgValue<'v>) -> Result<NDArray, Self::Error> {
+ try_downcast!(val -> NDArray,
+ |TVMArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(*val) },
+ |TVMArgValue::ArrayHandle(val)| { NDArray::new(*val) })
+ }
+}
+
+impl From<NDArray> for TVMRetValue {
+ fn from(val: NDArray) -> TVMRetValue {
+ match val {
+ NDArray::Owned { handle } => TVMRetValue::NDArrayHandle(handle),
+ _ => panic!("NYI"),
+ }
+ }
+}
+
+impl TryFrom<TVMRetValue> for NDArray {
+ type Error = ValueDowncastError;
+ fn try_from(val: TVMRetValue) -> Result<NDArray, Self::Error> {
+ try_downcast!(val -> NDArray,
+ |TVMRetValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) },
+ |TVMRetValue::ArrayHandle(val)| { NDArray::new(val) })
+ }
+}
#[cfg(test)]
mod tests {
.unwrap()
.try_into()
.unwrap();
- assert_eq!(ret, 14f32);
+ assert_eq!(ret, 7f32);
}
extern crate proc_macro;
+use quote::quote;
use std::{fs::File, io::Read};
use syn::parse::{Parse, ParseStream, Result};
-use syn::{LitStr};
-use quote::quote;
+use syn::LitStr;
use std::path::PathBuf;
impl Parse for ImportModule {
fn parse(input: ParseStream) -> Result<Self> {
let importing_file: LitStr = input.parse()?;
- Ok(ImportModule {
- importing_file,
- })
+ Ok(ImportModule { importing_file })
}
}
pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let import_module_args = syn::parse_macro_input!(input as ImportModule);
- let manifest = std::env::var("CARGO_MANIFEST_DIR")
- .expect("variable should always be set by Cargo.");
+ let manifest =
+ std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo.");
let mut path = PathBuf::new();
path.push(manifest);
SYSTEM_LIB_FUNCTIONS
.lock()
.unwrap()
- .get(name.as_ref()).copied()
+ .get(name.as_ref())
+ .copied()
}
}
thread::{self, JoinHandle},
};
-use crossbeam::channel::{Sender, Receiver, bounded};
+use crossbeam::channel::{bounded, Receiver, Sender};
use tvm_common::ffi::TVMParallelGroupEnv;
pub(crate) type FTVMParallelLambda =
let mut tasks = job.tasks(self.num_workers + 1);
for (i, task) in tasks.split_off(1).into_iter().enumerate() {
- self.threads.queues[i].send(task)
- .expect("should send");
+ self.threads.queues[i].send(task).expect("should send");
}
tasks.pop().unwrap().run();
def _get_model(dshape):
data = relay.var('data', shape=dshape)
fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2)
- fc = relay.nn.bias_add(data, relay.var("dense_bias"))
+ fc = relay.nn.bias_add(fc, relay.var("dense_bias"))
left, right = relay.split(fc, indices_or_sections=2, axis=1)
one = relay.const(1, dtype="float32")
return relay.Tuple([(left + one), (right - one), fc])
.unwrap()
.get("func_name")
.unwrap(),
- "fuse_dense"
+ "fused_nn_dense_nn_bias_add"
);
- assert_eq!(graph.nodes[5].inputs[0].index, 0);
- assert_eq!(graph.nodes[6].inputs[0].index, 1);
- assert_eq!(graph.heads.len(), 2);
+ assert_eq!(graph.nodes[3].inputs[0].index, 0);
+ assert_eq!(graph.nodes[4].inputs[0].index, 0);
+ assert_eq!(graph.heads.len(), 3);
}
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
+ let out_dir = Path::new(&out_dir).join("test_nn");
+
+ std::fs::create_dir_all(&out_dir).unwrap();
+
+ let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
+ let manifest_dir = Path::new(&manifest_dir);
+
+ let generator = manifest_dir.join("src").join("build_test_graph.py");
+
+ let graph_path = out_dir.join("graph.o");
+
+ let output = Command::new(&generator)
+ .arg(&out_dir)
+ .output()
+ .expect("Failed to execute command");
- let output = Command::new(concat!(
- env!("CARGO_MANIFEST_DIR"),
- "/src/build_test_graph.py"
- ))
- .arg(&out_dir)
- .output()
- .expect("Failed to execute command");
assert!(
- Path::new(&format!("{}/graph.o", out_dir)).exists(),
+ graph_path.exists(),
"Could not build graph lib: {}",
String::from_utf8(output.stderr)
.unwrap()
.unwrap_or("")
);
- let lib_file = format!("{}/libtestnn.a", out_dir);
+ let lib_file = out_dir.join("libtestnn.a");
let file = File::create(&lib_file).unwrap();
let mut builder = Builder::new(file);
- builder.append_path(format!("{}/graph.o", out_dir)).unwrap();
+ builder.append_path(graph_path).unwrap();
let status = Command::new("ranlib")
.arg(&lib_file)
assert!(status.success());
-
println!("cargo:rustc-link-lib=static=testnn");
- println!("cargo:rustc-link-search=native={}", out_dir);
+ println!("cargo:rustc-link-search=native={}", out_dir.display());
+ println!("cargo:rerun-if-changed={}", generator.display());
}
def _get_model(dshape):
data = relay.var('data', shape=dshape)
fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2)
- fc = relay.nn.bias_add(data, relay.var("dense_bias"))
+ fc = relay.nn.bias_add(fc, relay.var("dense_bias"))
left, right = relay.split(fc, indices_or_sections=2, axis=1)
one = relay.const(1, dtype="float32")
return relay.Tuple([(left + one), (right - one), fc])
let syslib = SystemLibModule::default();
let mut params_bytes = Vec::new();
- fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
+ fs::File::open(concat!(env!("OUT_DIR"), "/test_nn/graph.params"))
.unwrap()
.read_to_end(&mut params_bytes)
.unwrap();
.map(|(k, v)| (k, v.to_owned()))
.collect::<HashMap<String, Tensor<'static>>>();
- let graph =
- Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap())
- .unwrap();
+ let graph = Graph::try_from(
+ &fs::read_to_string(concat!(env!("OUT_DIR"), "/test_nn/graph.json")).unwrap(),
+ )
+ .unwrap();
let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
let x = Array::from_shape_vec(
.collect::<Vec<f32>>(),
)
.unwrap();
- let w = Array::try_from(params.get("dense0_weight").unwrap().to_owned())
+
+ let p0 = params.get("p0").unwrap().to_owned();
+ let p1 = params.get("p1").unwrap().to_owned();
+ println!("p0: {:?}", p0.shape());
+ println!("p1: {:?}", p1.shape());
+ let w = Array::try_from(p0)
.unwrap()
- .into_shape((IN_DIM * 2, IN_DIM))
+ .into_shape((BATCH_SIZE * 4, IN_DIM))
.unwrap();
- let b = Array::try_from(params.get("dense0_bias").unwrap().to_owned()).unwrap();
+ let b = Array::try_from(p1).unwrap();
let dense = x.dot(&w.t()) + &b;
let left = dense.slice(s![.., 0..IN_DIM]);
let right = dense.slice(s![.., IN_DIM..]);
exec.set_input("data", (&x).into());
check_sum!(exec, data, x);
- check_sum!(exec, dense0_weight, w);
- check_sum!(exec, dense0_bias, b);
+ check_sum!(exec, p0, w);
+ check_sum!(exec, p1, b);
exec.run();
set -e
set -u
-# Temporary disable rust tests
-# remove this line to re-enable.
-exit 0
-
export TVM_HOME="$(git rev-parse --show-toplevel)"
export LD_LIBRARY_PATH="$TVM_HOME/lib:$TVM_HOME/build:${LD_LIBRARY_PATH:-}"
export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/topi/python"
export RUST_DIR="$TVM_HOME/rust"
+export LLVM_CONFIG_PATH=`which llvm-config-8`
+echo "Using $LLVM_CONFIG_PATH"
cd $RUST_DIR
cargo fmt -- --check