From aa84ee2c010830a4460d67d7cdb10305053f3112 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 23 Jun 2020 11:39:48 -0700 Subject: [PATCH] Rust Refactor Stage 4: Rewrite Rust graph runtime to use new APIs (#5830) * Port graph-runtime to new API * --amend * Fix file lint * Remove old travis file * Add @kazum's patch * Update rust/tvm-sys/src/datatype.rs Co-authored-by: Andrew Co-authored-by: Andrew --- rust/Cargo.toml | 5 + rust/runtime/src/graph.rs | 1 + rust/tvm-graph-rt/Cargo.toml | 44 ++ rust/tvm-graph-rt/src/allocator.rs | 73 +++ rust/tvm-graph-rt/src/array.rs | 401 +++++++++++++++++ rust/tvm-graph-rt/src/errors.rs | 34 ++ rust/tvm-graph-rt/src/graph.rs | 495 +++++++++++++++++++++ rust/tvm-graph-rt/src/lib.rs | 68 +++ rust/tvm-graph-rt/src/module/dso.rs | 148 ++++++ rust/tvm-graph-rt/src/module/mod.rs | 64 +++ rust/tvm-graph-rt/src/module/syslib.rs | 73 +++ rust/tvm-graph-rt/src/threading.rs | 263 +++++++++++ rust/tvm-graph-rt/src/workspace.rs | 138 ++++++ rust/tvm-graph-rt/tests/.gitignore | 3 + rust/tvm-graph-rt/tests/build_model.py | 53 +++ rust/tvm-graph-rt/tests/test_graph_serde.rs | 83 ++++ rust/tvm-graph-rt/tests/test_nn/Cargo.toml | 31 ++ rust/tvm-graph-rt/tests/test_nn/build.rs | 70 +++ .../tests/test_nn/src/build_test_graph.py | 55 +++ rust/tvm-graph-rt/tests/test_nn/src/main.rs | 105 +++++ rust/tvm-graph-rt/tests/test_tvm_basic/Cargo.toml | 29 ++ rust/tvm-graph-rt/tests/test_tvm_basic/build.rs | 69 +++ .../tests/test_tvm_basic/src/build_test_lib.py | 38 ++ rust/tvm-graph-rt/tests/test_tvm_basic/src/main.rs | 50 +++ rust/tvm-graph-rt/tests/test_tvm_dso/Cargo.toml | 26 ++ rust/tvm-graph-rt/tests/test_tvm_dso/build.rs | 42 ++ .../tests/test_tvm_dso/src/build_test_lib.py | 41 ++ rust/tvm-graph-rt/tests/test_tvm_dso/src/main.rs | 42 ++ rust/tvm-graph-rt/tests/test_wasm32/.cargo/config | 2 + rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml | 30 ++ rust/tvm-graph-rt/tests/test_wasm32/build.rs | 77 ++++ .../tests/test_wasm32/src/build_test_lib.py | 38 ++ rust/tvm-graph-rt/tests/test_wasm32/src/main.rs | 54 +++ rust/tvm-sys/build.rs | 1 + rust/tvm-sys/src/array.rs | 1 + rust/tvm-sys/src/datatype.rs | 14 +- tests/lint/check_file_type.py | 1 + 37 files changed, 2761 insertions(+), 1 deletion(-) create mode 100644 rust/tvm-graph-rt/Cargo.toml create mode 100644 rust/tvm-graph-rt/src/allocator.rs create mode 100644 rust/tvm-graph-rt/src/array.rs create mode 100644 rust/tvm-graph-rt/src/errors.rs create mode 100644 rust/tvm-graph-rt/src/graph.rs create mode 100644 rust/tvm-graph-rt/src/lib.rs create mode 100644 rust/tvm-graph-rt/src/module/dso.rs create mode 100644 rust/tvm-graph-rt/src/module/mod.rs create mode 100644 rust/tvm-graph-rt/src/module/syslib.rs create mode 100644 rust/tvm-graph-rt/src/threading.rs create mode 100644 rust/tvm-graph-rt/src/workspace.rs create mode 100644 rust/tvm-graph-rt/tests/.gitignore create mode 100755 rust/tvm-graph-rt/tests/build_model.py create mode 100644 rust/tvm-graph-rt/tests/test_graph_serde.rs create mode 100644 rust/tvm-graph-rt/tests/test_nn/Cargo.toml create mode 100644 rust/tvm-graph-rt/tests/test_nn/build.rs create mode 100755 rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py create mode 100644 rust/tvm-graph-rt/tests/test_nn/src/main.rs create mode 100644 rust/tvm-graph-rt/tests/test_tvm_basic/Cargo.toml create mode 100644 rust/tvm-graph-rt/tests/test_tvm_basic/build.rs create mode 100755 rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py create mode 100644 rust/tvm-graph-rt/tests/test_tvm_basic/src/main.rs create mode 100644 rust/tvm-graph-rt/tests/test_tvm_dso/Cargo.toml create mode 100644 rust/tvm-graph-rt/tests/test_tvm_dso/build.rs create mode 100755 rust/tvm-graph-rt/tests/test_tvm_dso/src/build_test_lib.py create mode 100644 rust/tvm-graph-rt/tests/test_tvm_dso/src/main.rs create mode 100644 rust/tvm-graph-rt/tests/test_wasm32/.cargo/config create mode 100644 rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml create mode 100644 rust/tvm-graph-rt/tests/test_wasm32/build.rs create mode 100755 rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py create mode 100644 rust/tvm-graph-rt/tests/test_wasm32/src/main.rs diff --git a/rust/Cargo.toml b/rust/Cargo.toml index d9bb3ab..9542178 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -32,4 +32,9 @@ members = [ "tvm-macros", "tvm-rt", "tvm", + "tvm-graph-rt", + "tvm-graph-rt/tests/test_tvm_basic", + "tvm-graph-rt/tests/test_tvm_dso", + "tvm-graph-rt/tests/test_wasm32", + "tvm-graph-rt/tests/test_nn", ] diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index 71541ba..c1f44ef 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -24,6 +24,7 @@ use nom::{ character::complete::{alpha1, digit1}, number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8}, }; + use serde; use serde_json; use tvm_common::{ diff --git a/rust/tvm-graph-rt/Cargo.toml b/rust/tvm-graph-rt/Cargo.toml new file mode 100644 index 0000000..0cf2ac1 --- /dev/null +++ b/rust/tvm-graph-rt/Cargo.toml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-graph-rt" +version = "0.1.0" +license = "Apache-2.0" +description = "A static graph runtime for TVM." +repository = "https://github.com/apache/incubator-tvm" +readme = "README.md" +keywords = ["tvm"] +categories = ["api-bindings", "science"] +authors = ["TVM Contributors"] +edition = "2018" + +[dependencies] +crossbeam = "0.7.3" +failure = "0.1" +itertools = "0.8" +lazy_static = "1.4" +ndarray="0.12" +nom = "5.0" +num_cpus = "1.10" +serde = { version = "^1.0", features = ["derive"] } +serde_json = "^1.0" +tvm-sys = { version = "0.1", path = "../tvm-sys" } +tvm-macros = { version = "0.1", path = "../tvm-macros" } + +[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies] +libloading = "0.5" diff --git a/rust/tvm-graph-rt/src/allocator.rs b/rust/tvm-graph-rt/src/allocator.rs new file mode 100644 index 0000000..81499af --- /dev/null +++ b/rust/tvm-graph-rt/src/allocator.rs @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::alloc::{self, Layout, LayoutErr}; + +const DEFAULT_ALIGN_BYTES: usize = 4; + +#[derive(PartialEq, Eq)] +pub struct Allocation { + layout: Layout, + ptr: *mut u8, +} + +impl Allocation { + /// Allocates a chunk of memory of `size` bytes with optional alignment. + pub fn new(size: usize, align: Option) -> Result { + let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); + let layout = Layout::from_size_align(size, alignment)?; + let ptr = unsafe { alloc::alloc(layout) }; + if ptr.is_null() { + alloc::handle_alloc_error(layout); + } + Ok(Self { ptr, layout }) + } + + pub fn as_mut_ptr(&self) -> *mut u8 { + self.ptr + } + + /// Returns the size of the Allocation in bytes. + pub fn size(&self) -> usize { + self.layout.size() + } + + /// Returns the byte alignment of the Allocation. + pub fn align(&self) -> usize { + self.layout.align() + } + + /// Returns a view of the Allocation. + pub fn as_slice(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.as_mut_ptr(), self.size()) } + } + + /// Returns a mutable view of the Allocation. + pub fn as_mut_slice(&mut self) -> &mut [u8] { + unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) } + } +} + +impl Drop for Allocation { + fn drop(&mut self) { + unsafe { + alloc::dealloc(self.ptr, self.layout); + } + } +} diff --git a/rust/tvm-graph-rt/src/array.rs b/rust/tvm-graph-rt/src/array.rs new file mode 100644 index 0000000..8209b59 --- /dev/null +++ b/rust/tvm-graph-rt/src/array.rs @@ -0,0 +1,401 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice}; + +use failure::{ensure, Error}; +use ndarray; +use tvm_sys::{ffi::DLTensor, Context, DataType}; + +use crate::allocator::Allocation; + +/// A `Storage` is a container which holds `Tensor` data. +#[derive(PartialEq)] +pub enum Storage<'a> { + /// A `Storage` which owns its contained bytes. + Owned(Allocation), + + /// A view of an existing `Storage`. + View(&'a mut [u8], usize), // ptr, align +} + +impl<'a> Storage<'a> { + pub fn new(size: usize, align: Option) -> Result, Error> { + Ok(Storage::Owned(Allocation::new(size, align)?)) + } + + pub fn as_mut_ptr(&self) -> *mut u8 { + match self { + Storage::Owned(alloc) => alloc.as_mut_ptr(), + Storage::View(slice, _) => slice.as_ptr() as *mut u8, + } + } + + pub fn size(&self) -> usize { + match self { + Storage::Owned(alloc) => alloc.size(), + Storage::View(slice, _) => slice.len(), + } + } + + pub fn align(&self) -> usize { + match self { + Storage::Owned(alloc) => alloc.align(), + Storage::View(_, align) => *align, + } + } + + pub fn as_ptr(&self) -> *const u8 { + self.as_mut_ptr() as *const _ + } + + /// Returns a `Storage::View` which points to an owned `Storage::Owned`. + pub fn view(&self) -> Storage<'a> { + match self { + Storage::Owned(alloc) => Storage::View( + unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) }, + self.align(), + ), + Storage::View(slice, _) => Storage::View( + unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) }, + self.align(), + ), + } + } + + pub fn is_owned(&self) -> bool { + match self { + Storage::Owned(_) => true, + _ => false, + } + } + + /// Returns an owned version of this storage via cloning. + pub fn to_owned(&self) -> Storage<'static> { + let s = Storage::new(self.size(), Some(self.align())).unwrap(); + unsafe { + s.as_mut_ptr() + .copy_from_nonoverlapping(self.as_ptr(), self.size()); + } + s + } + + /// Returns a view of the stored data. + pub fn as_slice(&self) -> &[u8] { + match self { + Storage::Owned(alloc) => alloc.as_slice(), + Storage::View(slice, _) => &*slice, + } + } + + /// Returns a mutable view of the stored data. + pub fn as_mut_slice(&mut self) -> &mut [u8] { + match self { + Storage::Owned(alloc) => alloc.as_mut_slice(), + Storage::View(slice, _) => slice, + } + } +} + +impl<'d, 's, T> From<&'d [T]> for Storage<'s> { + fn from(data: &'d [T]) -> Self { + let data = unsafe { + slice::from_raw_parts_mut( + data.as_ptr() as *const u8 as *mut u8, + data.len() * mem::size_of::() as usize, + ) + }; + Storage::View(data, mem::align_of::()) + } +} + +/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`. +/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or +/// converted to `ndarray::Array` for non-TVM processing. +/// +/// # Examples +/// +/// ``` +/// extern crate ndarray; +/// use std::convert::TryInto; +/// use tvm_runtime::{call_packed, DLTensor, ArgValue, RetValue, Tensor}; +/// +/// let mut a_nd: ndarray::Array1 = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]); +/// let mut a: Tensor = a_nd.into(); +/// let mut a_dl: DLTensor = (&mut a).into(); +/// +/// let tvm_fn = |args: &[ArgValue]| -> Result { Ok(RetValue::default()) }; +/// call_packed!(tvm_fn, &mut a_dl); +/// +/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs. +/// let mut a_nd: ndarray::ArrayD = a.try_into().unwrap(); +/// ``` +#[derive(PartialEq)] +pub struct Tensor<'a> { + /// The bytes which contain the data this `Tensor` represents. + pub(crate) data: Storage<'a>, + pub(crate) ctx: Context, + pub(crate) dtype: DataType, + pub(crate) shape: Vec, + // ^ not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h + /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous. + pub(crate) strides: Option>, + pub(crate) byte_offset: isize, + /// The number of elements in the `Tensor`. + pub(crate) size: usize, +} + +unsafe impl<'a> Send for Tensor<'a> {} + +impl<'a> Tensor<'a> { + pub fn shape(&self) -> Vec { + self.shape.clone() + } + + pub fn data(&self) -> &Storage { + &self.data + } + + pub fn data_mut(&mut self) -> &'a mut Storage { + &mut self.data + } + + /// Returns the data of this `Tensor` as a `Vec`. + /// + /// # Panics + /// + /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`. + pub fn to_vec(&self) -> Vec { + assert!(self.is_contiguous()); + assert!(self.dtype.is_type::()); + unsafe { slice::from_raw_parts(self.data.as_ptr() as *const T, self.size).to_vec() } + } + + /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory. + pub fn is_contiguous(&self) -> bool { + match self.strides { + None => true, + Some(ref strides) => { + // check that stride for each dimension is the + // product of all trailing dimensons' shapes + self.shape + .iter() + .zip(strides) + .rfold( + (true, 1), + |(is_contig, expected_stride), (shape, stride)| { + ( + is_contig && *stride == expected_stride, + expected_stride * (*shape as usize), + ) + }, + ) + .0 + } + } + } + + /// Returns a clone of this `Tensor`. + /// + /// # Panics + /// + /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`. + pub fn copy(&mut self, other: &Tensor) { + assert!( + self.dtype == other.dtype && self.size == other.size, + "Tensor shape/dtype mismatch." + ); + assert!( + self.is_contiguous() && other.is_contiguous(), + "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`", + self.strides, + other.strides + ); + unsafe { + self.data + .as_mut_ptr() + .offset(self.byte_offset as isize) + .copy_from_nonoverlapping( + other.data.as_mut_ptr().offset(other.byte_offset), + other.size * other.dtype.itemsize(), + ); + } + } + + /// Returns an owned version of this `Tensor` via cloning. + pub fn to_owned(&self) -> Tensor<'static> { + let t = Tensor { + data: self.data.to_owned(), + ctx: self.ctx, + dtype: self.dtype, + size: self.size, + shape: self.shape.clone(), + strides: None, + byte_offset: 0, + }; + unsafe { mem::transmute::, Tensor<'static>>(t) } + } + + fn from_array_storage<'s, T, D: ndarray::Dimension>( + arr: &ndarray::Array, + storage: Storage<'s>, + dtype_fn: fn(u8, u16) -> DataType, + ) -> Tensor<'s> { + let type_width = mem::size_of::() as u8; + + Tensor { + data: storage, + ctx: Context::default(), + dtype: dtype_fn(8 * type_width, 1), + size: arr.len(), + shape: arr.shape().iter().map(|&v| v as i64).collect(), + strides: Some(arr.strides().iter().map(|&v| v as usize).collect()), + byte_offset: 0, + } + } + + pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor { + assert!(!flatten || self.is_contiguous()); + DLTensor { + data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void, + ctx: self.ctx.into(), + ndim: if flatten { 1 } else { self.shape.len() } as i32, + dtype: self.dtype.into(), + shape: if flatten { + &self.size as *const _ as *mut i64 + } else { + self.shape.as_ptr() + } as *mut i64, + strides: if flatten || self.is_contiguous() { + ptr::null_mut() + } else { + self.strides.as_ref().unwrap().as_ptr() + } as *mut i64, + byte_offset: 0, + ..Default::default() + } + } +} + +/// Conversions to `ndarray::Array` from `Tensor`, if the types match. +macro_rules! impl_ndarray_try_from_tensor { + ($type:ty, $dtype:expr) => { + impl<'t> TryFrom> for ndarray::ArrayD<$type> { + type Error = Error; + fn try_from(tensor: Tensor) -> Result, Error> { + ensure!( + tensor.dtype == $dtype, + "Cannot convert Tensor with dtype {:?} to ndarray", + tensor.dtype + ); + Ok(ndarray::Array::from_shape_vec( + tensor + .shape + .iter() + .map(|s| *s as usize) + .collect::>(), + tensor.to_vec::<$type>(), + )?) + } + } + }; +} + +macro_rules! make_dtype_const { + ($name: ident, $cnst:expr) => { + pub const $name: DataType = $cnst; + }; +} + +make_dtype_const!(DTYPE_INT32, DataType::int(32, 1)); +make_dtype_const!(DTYPE_UINT32, DataType::uint(32, 1)); +make_dtype_const!(DTYPE_FLOAT32, DataType::float(32, 1)); +make_dtype_const!(DTYPE_FLOAT64, DataType::float(64, 1)); +impl_ndarray_try_from_tensor!(i32, DTYPE_INT32); +impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32); +impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32); +impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64); + +impl<'a, 't> From<&'a Tensor<'t>> for DLTensor { + fn from(tensor: &'a Tensor<'t>) -> Self { + Tensor::as_dltensor(tensor, false /* flatten */) + } +} + +impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor { + fn from(tensor: &'a mut Tensor<'t>) -> Self { + Tensor::as_dltensor(tensor, false /* flatten */) + } +} + +impl<'a> From for Tensor<'a> { + fn from(dlt: DLTensor) -> Self { + unsafe { + let dtype = DataType::from(dlt.dtype); + let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec(); + let size = shape.iter().map(|v| *v as usize).product::() as usize; + let storage = Storage::from(slice::from_raw_parts( + dlt.data as *const u8, + dtype.itemsize() * size, + )); + Self { + data: storage, + ctx: Context::default(), + dtype, + size, + shape, + strides: if dlt.strides.is_null() { + None + } else { + Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec()) + }, + byte_offset: dlt.byte_offset as isize, + } + } + } +} + +/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`. +/// +/// # Panics +/// +/// Panics if the ndarray is not contiguous. +macro_rules! impl_tensor_from_ndarray { + ($type:ty, $dtype_fn:expr) => { + impl From> for Tensor<'static> { + fn from(arr: ndarray::Array<$type, D>) -> Self { + let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous")); + Tensor::from_array_storage(&arr, storage.to_owned(), $dtype_fn) + } + } + impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> { + fn from(arr: &'a ndarray::Array<$type, D>) -> Self { + let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous")); + Tensor::from_array_storage(arr, storage, $dtype_fn) + } + } + }; +} + +impl_tensor_from_ndarray!(f32, DataType::float); +impl_tensor_from_ndarray!(f64, DataType::float); +impl_tensor_from_ndarray!(i32, DataType::int); +impl_tensor_from_ndarray!(i64, DataType::int); +impl_tensor_from_ndarray!(u32, DataType::uint); +impl_tensor_from_ndarray!(u64, DataType::uint); diff --git a/rust/tvm-graph-rt/src/errors.rs b/rust/tvm-graph-rt/src/errors.rs new file mode 100644 index 0000000..d82da15 --- /dev/null +++ b/rust/tvm-graph-rt/src/errors.rs @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use failure::Fail; + +#[derive(Debug, Fail)] +pub enum GraphFormatError { + #[fail(display = "Could not parse graph json")] + Parse(#[fail(cause)] failure::Error), + #[fail(display = "Could not parse graph params")] + Params, + #[fail(display = "{} is missing attr: {}", 0, 1)] + MissingAttr(String, String), + #[fail(display = "Missing field: {}", 0)] + MissingField(&'static str), + #[fail(display = "Invalid DLType: {}", 0)] + InvalidDLType(String), +} diff --git a/rust/tvm-graph-rt/src/graph.rs b/rust/tvm-graph-rt/src/graph.rs new file mode 100644 index 0000000..895739d --- /dev/null +++ b/rust/tvm-graph-rt/src/graph.rs @@ -0,0 +1,495 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str}; + +use failure::{ensure, format_err, Error}; +use itertools::izip; +use nom::{ + character::complete::{alpha1, digit1}, + complete, count, do_parse, length_count, map, named, + number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8}, + opt, tag, take, tuple, +}; + +use serde::{Deserialize, Serialize}; +use serde_json; + +use tvm_sys::ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt}; + +use tvm_sys::{ffi::DLTensor, ArgValue, Context, DataType, DeviceType}; + +use crate::{errors::GraphFormatError, Module, Storage, Tensor}; + +// @see `kTVMNDArrayMagic` in `ndarray.h` +const _NDARRAY_MAGIC: u64 = 0xDD5E_40F0_96B4_A13F; +// @see `kTVMNDArrayListMagic` in `graph_runtime.h` +const _NDARRAY_LIST_MAGIC: u64 = 0xF7E5_8D4F_0504_9CB7; + +/// A TVM computation graph. +/// +/// # Examples +/// +/// ```norun +/// let graph_json = fs::read_to_string("graph.json").unwrap(); +/// let graph = Graph::try_from(&graph_json).unwrap(); +/// ``` +#[derive(Serialize, Deserialize, Debug)] +pub struct Graph { + pub nodes: Vec, + pub arg_nodes: Vec, + pub heads: Vec, + pub node_row_ptr: Option>, + pub attrs: Option>, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Entry { + pub id: usize, + pub index: usize, + pub version: usize, +} + +impl Graph { + fn entry_index(&self, entry: &Entry) -> Result { + self.node_row_ptr + .as_ref() + .map(|nrp| nrp[entry.id] + entry.index) + .ok_or_else(|| GraphFormatError::MissingField("node_row_ptr")) + } + + /// Attempt to deserialize a JSON attribute to a type `T`. + fn get_attr(&self, attr: &str) -> Result { + Ok(serde_json::from_value::( + self.attrs + .as_ref() + .ok_or(GraphFormatError::MissingField("attrs"))? + .get(attr) + .ok_or_else(|| { + GraphFormatError::MissingAttr("graph".to_string(), attr.to_string()) + })? + .to_owned(), + ) + .map_err(|err| GraphFormatError::Parse(err.into()))?) + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Node { + pub op: String, + pub name: String, + pub inputs: Vec, + pub attrs: Option>, + pub control_deps: Option>, +} + +struct NodeAttrs { + func_name: String, + num_outputs: usize, + flatten_data: bool, +} + +macro_rules! get_node_attr { + ($node:expr, $attrs:ident, $attr:literal) => { + $attrs + .get($attr) + .ok_or_else(|| GraphFormatError::MissingAttr($node.to_owned(), $attr.to_owned())) + }; +} + +impl Node { + fn parse_attrs(&self) -> Result { + let attrs = self + .attrs + .as_ref() + .ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?; + Ok(NodeAttrs { + func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(), + num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::()?, + flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::()? == 1, + }) + } +} + +impl<'a> TryFrom<&'a String> for Graph { + type Error = Error; + fn try_from(graph_json: &String) -> Result { + let graph = serde_json::from_str(graph_json)?; + Ok(graph) + } +} + +impl<'a> TryFrom<&'a str> for Graph { + type Error = Error; + fn try_from(graph_json: &'a str) -> Result { + let graph = serde_json::from_str(graph_json)?; + Ok(graph) + } +} + +/// A executor for a TVM computation graph. +/// +/// # Examples +/// +/// ```norun +/// use ndarray::Array; +/// +/// let syslib = SystemLibModule::default(); // a provider of TVM functions +/// +/// let mut params_bytes = Vec::new(); +/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap(); +/// let params = tvm::runtime::load_param_dict(¶ms_bytes).unwrap(); +/// +/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap(); +/// +/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); +/// exec.load_params(params); +/// +/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]); +/// exec.set_input("data", x.into()); +/// exec.run(); +/// let output = exec.get_output(0).unwrap(); +/// +/// println!("{:#?}", Array::try_from(output).unwrap()); +/// ``` +pub struct GraphExecutor<'m, 't> { + graph: Graph, + op_execs: Vec>, + tensors: Vec>, +} + +unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {} + +impl<'m, 't> GraphExecutor<'m, 't> { + pub fn new(graph: Graph, lib: &'m M) -> Result { + let tensors = Self::setup_storages(&graph)?; + Ok(GraphExecutor { + op_execs: Self::setup_op_execs(&graph, lib, &tensors)?, + tensors, + graph, + }) + } + + /// Runs the computation graph. + pub fn run(&mut self) { + self.op_execs.iter().for_each(|op_exec| { + op_exec(); + }); + } + + /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output. + fn setup_storages<'a>(graph: &'a Graph) -> Result>, Error> { + let storage_ids = graph.get_attr::<(String, Vec)>("storage_id")?.1; + let shapes = graph.get_attr::<(String, Vec>)>("shape")?.1; + let dtypes = graph + .get_attr::<(String, Vec)>("dltype")? + .1 + .iter() + .map(|dltype| { + if let Ok((_, dtype)) = tvm_str_to_type(dltype) { + Ok(dtype) + } else { + Err(GraphFormatError::InvalidDLType(dltype.to_string())) + } + }) + .collect::, GraphFormatError>>()?; + + let align = dtypes.iter().map(|dtype| dtype.bits() as usize).max(); + let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; + for (i, &storage_id) in storage_ids.iter().enumerate() { + let dtype_size = (dtypes[i].bits() * dtypes[i].lanes()) >> 3; + let nbytes = dtype_size * shapes[i].iter().product::() as usize; + storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]); + } + + let mut storages: Vec = storage_num_bytes + .into_iter() + .map(|nbytes| Storage::new(nbytes, align)) + .collect::, Error>>()?; + + let tensors = izip!(storage_ids, shapes, dtypes) + .map(|(storage_id, shape, dtype)| { + let storage = storages[storage_id].view(); + Tensor { + data: mem::replace(&mut storages[storage_id], storage), + ctx: Context::default(), + dtype, + size: shape.iter().product::() as usize, + shape, + strides: None, + byte_offset: 0, + } + }) + .collect(); + + Ok(tensors) + } + + /// Creates closures which represent the computation performed by this graph. + fn setup_op_execs( + graph: &Graph, + lib: &'m M, + tensors: &[Tensor<'t>], + ) -> Result>, Error> { + ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr."); + let node_row_ptr = graph.node_row_ptr.as_ref().unwrap(); + + let mut op_execs = Vec::new(); + for (i, node) in graph.nodes.iter().enumerate() { + if node.op == "null" { + continue; + } + ensure!(node.op == "tvm_op", "Only TVM ops are supported."); + ensure!(node.attrs.is_some(), "Missing node attrs."); + + let attrs = node.parse_attrs()?; + + if attrs.func_name == "__nop" { + continue; + } + + let func = lib + .get_function(&attrs.func_name) + .ok_or_else(|| format_err!("Library is missing function {}", attrs.func_name))?; + let arg_indices = node + .inputs + .iter() + .map(|entry| graph.entry_index(entry)) + .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i] + oi))); + + let dl_tensors = arg_indices + .map(|idx| { + let tensor = &tensors[idx?]; + Ok(if attrs.flatten_data { + Tensor::as_dltensor(tensor, true /* flatten */) + } else { + DLTensor::from(tensor) + }) + }) + .collect::, Error>>() + .unwrap(); + let op: Box = Box::new(move || { + let args = dl_tensors + .iter() + .map(|t| t.into()) + .collect::>(); + func(&args).unwrap(); + }); + op_execs.push(op); + } + Ok(op_execs) + } + + pub fn load_params(&mut self, params: HashMap) { + params.into_iter().for_each(|(name, param)| { + self.set_input(name, param); + }) + } + + #[allow(clippy::if_same_then_else)] + pub fn set_input>(&mut self, name: S, value: Tensor) { + if let Some(idx) = self.get_input_index(name.as_ref()) { + // TODO: consider `new_with_params` to avoid ever allocating + let ptr = self.tensors[idx].data.as_ptr(); + let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr); + let owner = to_replace.nth(0).unwrap(); + if value.data.is_owned() { + // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr + // mem::replace(&mut (*owner), value); + // to_replace.for_each(|t| { + // panic!("replacing"); + // t.data = owner.data.view(); + // }); + owner.copy(&value); + } else { + owner.copy(&value); + } + } else { + println!("Unexpected input `{}`", name.as_ref()); + } + } + + /// Returns the graph input with name `name`, if it exists. + pub fn get_input>(&mut self, name: S) -> Option<&Tensor> { + self.get_input_index(name.as_ref()) + .map(move |idx| &self.tensors[idx]) + } + + /// Returns the graph output with index `index`, if it exists. + pub fn get_output(&self, idx: usize) -> Option<&Tensor> { + let graph = &self.graph; + graph.heads.get(idx).and_then(|entry| { + graph + .entry_index(entry) + .map(|idx| self.tensors.get(idx)) + .unwrap_or(None) + }) + } + + /// Returns the index for graph input with name `name`, if it exists. + pub fn get_input_index>(&self, name: S) -> Option { + let graph = &self.graph; + (0..graph.nodes.len()) + .skip_while(|&i| graph.nodes[i].name != name.as_ref()) + .nth(0) + .and_then(|i| { + if graph.arg_nodes.iter().any(|&id| id == i) { + graph.node_row_ptr.as_ref().map(|nrp| nrp[i]) + } else { + None + } + }) + } +} + +// Converts a string to TVM DLDataTypeCode. @see `String2DLDataType` in packed_func.h +named! { + tvm_str_to_type<&str, DataType>, + do_parse!( + type_name: alpha1 >> + bits: digit1 >> + lanes: opt!(complete!(tuple!(tag!("x"), digit1))) >> + ( + DataType::new( + match type_name { + "int" => DLDataTypeCode_kDLInt, + "uint" => DLDataTypeCode_kDLUInt, + "float" => DLDataTypeCode_kDLFloat, + _ => DLDataTypeCode_kDLFloat, + } as u8, + bits.parse::().unwrap() as u8, + lanes + .map(|(_, lanes)| lanes.parse::().unwrap() as u16) + .unwrap_or(1), + ) + ) + ) +} + +// Converts a bytes to String. +named! { + name, + do_parse!( + len_l: le_u32 >> + len_h: le_u32 >> + data: take!(len_l) >> + ( + if len_h == 0 { + String::from_utf8(data.to_vec()).unwrap() + } else { + panic!("Too long string") + } + ) + ) +} + +// Parses a Context +named! { + tvm_ctx<&[u8], Context>, + do_parse!( + device_type: le_u32 >> + device_id: le_i32 >> + ( + Context { + device_type: DeviceType::from(device_type), + device_id: device_id as usize, + } + ) + ) +} + +// Parses a DataType +named! { + data_type<&[u8], DataType>, + do_parse!( + code: le_u8 >> + bits: le_u8 >> + lanes: le_u16 >> + (DataType::new(code, bits, lanes))) +} + +// Parses a Tensor from a TVM array file. +named! { + tensor, + do_parse!( + take!(8) >> + le_u64 >> + ctx: tvm_ctx >> + ndim: le_u32 >> + dtype: data_type >> + shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize) >> + length: le_i64 >> + data: take!(length) >> + ( + Tensor { + data: Storage::from(data), + ctx: ctx, + dtype: dtype, + size: shape.iter().product::() as usize, + shape: shape, + strides: None, + byte_offset: 0, + } + ) + ) +} + +// Parses a graph params dict from a params binary file. +named! { + parse_param_dict>, + do_parse!( + take!(8) >> + le_u64 >> + names: length_count!(le_u64, name) >> + tensors: length_count!(le_u64, tensor) >> + ( + HashMap::from_iter(names.into_iter().zip(tensors.into_iter())) + ) + ) +} + +/// Loads a param dict saved using `relay.save_param_dict`. +pub fn load_param_dict(bytes: &[u8]) -> Result, GraphFormatError> { + if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) { + if remaining_bytes.is_empty() { + Ok(param_dict) + } else { + Err(GraphFormatError::Params) + } + } else { + Err(GraphFormatError::Params) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_str_to_type() { + assert_eq!( + tvm_str_to_type("float24").unwrap().1, + DataType::float(24, 1) + ); + assert_eq!( + tvm_str_to_type("uint111x44").unwrap().1, + DataType::uint(111, 44) + ); + } +} diff --git a/rust/tvm-graph-rt/src/lib.rs b/rust/tvm-graph-rt/src/lib.rs new file mode 100644 index 0000000..0e3db52 --- /dev/null +++ b/rust/tvm-graph-rt/src/lib.rs @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`. +//! It's mainly useful for compiling to WebAssembly and SGX, +//! but also native if you prefer Rust to C++. +//! +//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`. +//! Single-function modules are used via the `packed_func!` macro after obtaining +//! the function from `runtime::SystemLibModule` +//! +//! The main entrypoints to this crate are `GraphExecutor` +//! For examples of use, please refer to the multi-file tests in the `tests` directory. + +use lazy_static::lazy_static; + +mod allocator; +mod array; +pub mod errors; +mod graph; +mod module; +mod threading; +mod workspace; + +pub use tvm_macros::import_module; +pub use tvm_sys::{ + call_packed, + errors::*, + ffi::{self, DLTensor}, + packed_func::{self, *}, + ArgValue, RetValue, +}; + +pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*}; + +lazy_static! { + static ref LAST_ERROR: std::sync::RwLock> = + std::sync::RwLock::new(None); +} + +#[no_mangle] +pub unsafe extern "C" fn TVMAPISetLastError(cmsg: *const i8) { + *LAST_ERROR.write().unwrap() = Some(std::ffi::CStr::from_ptr(cmsg)); +} + +#[no_mangle] +pub extern "C" fn TVMGetLastError() -> *const std::os::raw::c_char { + match *LAST_ERROR.read().unwrap() { + Some(err) => err.as_ptr(), + None => std::ptr::null(), + } +} diff --git a/rust/tvm-graph-rt/src/module/dso.rs b/rust/tvm-graph-rt/src/module/dso.rs new file mode 100644 index 0000000..51645d5 --- /dev/null +++ b/rust/tvm-graph-rt/src/module/dso.rs @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + cell::RefCell, + collections::HashMap, + ffi::CStr, + os::raw::{c_char, c_int, c_void}, + pin::Pin, +}; + +use tvm_sys::{ffi::BackendPackedCFunc, packed_func::PackedFunc}; + +use crate::{ + threading::{TVMBackendParallelBarrier, TVMBackendParallelLaunch}, + workspace::{TVMBackendAllocWorkspace, TVMBackendFreeWorkspace}, + TVMAPISetLastError, +}; + +use super::Module; + +const TVM_MAIN: &[u8] = b"__tvm_main__"; +const TVM_MODULE_CTX: &[u8] = b"__tvm_module_ctx"; + +/// A module backed by a Dynamic Shared Object (dylib). +pub struct DsoModule<'a> { + lib: libloading::Library, + packed_funcs: RefCell>, + _pin: std::marker::PhantomPinned, +} + +macro_rules! init_context_func { + ($lib:ident, $( ($fn:ident, $sig:ty) ),+ $(,)?) => { + unsafe { + $( + let fn_ptr = $lib.get::<*mut $sig>(concat!("__", stringify!($fn)).as_bytes()); + if let Ok(fn_ptr) = fn_ptr { + **fn_ptr = $fn; + } + )+ + } + }; +} + +impl<'a> DsoModule<'a> { + pub fn new>(filename: P) -> Result>, failure::Error> { + let lib = libloading::Library::new(filename)?; + + init_context_func!( + lib, + (TVMAPISetLastError, unsafe extern "C" fn(*const i8)), + ( + TVMBackendAllocWorkspace, + unsafe extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void + ), + ( + TVMBackendFreeWorkspace, + unsafe extern "C" fn(c_int, c_int, *mut c_void) -> c_int + ), + ( + TVMBackendParallelLaunch, + unsafe extern "C" fn( + crate::threading::FTVMParallelLambda, + *const c_void, + usize, + ) -> c_int + ), + ( + TVMBackendParallelBarrier, + unsafe extern "C" fn(usize, *const tvm_sys::ffi::TVMParallelGroupEnv) + ), + ); + + // Pin the module in memory so that `ctx` pointer (below) is stable. + let dso_mod = Box::pin(Self { + lib, + packed_funcs: RefCell::new(HashMap::new()), + _pin: std::marker::PhantomPinned, + }); + + unsafe { + if let Ok(ctx) = dso_mod.lib.get::<*mut *const c_void>(TVM_MODULE_CTX) { + **ctx = &dso_mod as *const _ as *const c_void; + } + } + + Ok(dso_mod) + } +} + +impl<'a> Module for DsoModule<'a> { + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)> { + let name = name.as_ref(); + let func = match unsafe { + self.lib + .get::(if name.as_bytes() == TVM_MAIN { + // If __tvm_main__ is present, it contains the name of the + // actual main function. + match self + .lib + .get::<*const c_char>(TVM_MAIN) + .map(|p| CStr::from_ptr(*p)) + { + Ok(m) => m.to_bytes(), + _ => return None, + } + } else { + name.as_bytes() + }) + } { + Ok(func) => unsafe { func.into_raw() }, + Err(_) => return None, + }; + + self.packed_funcs.borrow_mut().insert( + name.to_string(), + &*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)), + ); + + self.packed_funcs.borrow().get(name).copied() + } +} + +impl<'a> Drop for DsoModule<'a> { + fn drop(&mut self) { + self.packed_funcs + .replace(HashMap::new()) + .into_iter() + .map(|(_name, f)| unsafe { Box::from_raw(f as *const _ as *mut (dyn PackedFunc)) }) + .for_each(std::mem::drop); + } +} diff --git a/rust/tvm-graph-rt/src/module/mod.rs b/rust/tvm-graph-rt/src/module/mod.rs new file mode 100644 index 0000000..511ba4b --- /dev/null +++ b/rust/tvm-graph-rt/src/module/mod.rs @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))] +mod dso; +mod syslib; + +use tvm_sys::{ + ffi::BackendPackedCFunc, + packed_func::{ArgValue, PackedFunc, RetValue, TVMValue}, +}; + +#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))] +pub use dso::DsoModule; +pub use syslib::SystemLibModule; + +pub trait Module { + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)>; +} + +// @see `WrapPackedFunc` in `llvm_module.cc`. +fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box { + Box::new(move |args: &[ArgValue]| { + let (values, type_codes): (Vec, Vec) = args + .iter() + .map(|arg| { + let (val, code) = arg.to_tvm_value(); + (val, code as i32) + }) + .unzip(); + let ret: RetValue = RetValue::default(); + let (mut ret_val, mut ret_type_code) = ret.to_tvm_value(); + let exit_code = func( + values.as_ptr(), + type_codes.as_ptr(), + values.len() as i32, + &mut ret_val, + &mut ret_type_code, + ); + if exit_code == 0 { + Ok(RetValue::from_tvm_value(ret_val, ret_type_code)) + } else { + Err(tvm_sys::errors::FuncCallError::get_with_context( + func_name.clone(), + )) + } + }) +} diff --git a/rust/tvm-graph-rt/src/module/syslib.rs b/rust/tvm-graph-rt/src/module/syslib.rs new file mode 100644 index 0000000..0279e31 --- /dev/null +++ b/rust/tvm-graph-rt/src/module/syslib.rs @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, +}; + +use lazy_static::lazy_static; + +use tvm_sys::{ffi::BackendPackedCFunc, packed_func::PackedFunc}; + +use super::Module; + +pub struct SystemLibModule; + +#[cfg(target_env = "sgx")] +extern "C" { + fn __tvm_module_startup(); +} + +lazy_static! { + static ref SYSTEM_LIB_FUNCTIONS: Mutex> = + Mutex::new(HashMap::new()); +} + +impl Module for SystemLibModule { + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)> { + SYSTEM_LIB_FUNCTIONS + .lock() + .unwrap() + .get(name.as_ref()) + .copied() + } +} + +impl Default for SystemLibModule { + fn default() -> Self { + #[cfg(target_env = "sgx")] + unsafe { + __tvm_module_startup(); + } + SystemLibModule {} + } +} + +#[no_mangle] +pub extern "C" fn TVMBackendRegisterSystemLibSymbol( + cname: *const c_char, + func: BackendPackedCFunc, +) -> i32 { + let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; + SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert( + name.to_string(), + &*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)), + ); + 0 +} diff --git a/rust/tvm-graph-rt/src/threading.rs b/rust/tvm-graph-rt/src/threading.rs new file mode 100644 index 0000000..bda53a8 --- /dev/null +++ b/rust/tvm-graph-rt/src/threading.rs @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + os::raw::{c_int, c_void}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Barrier, + }, + thread::{self, JoinHandle}, +}; + +#[cfg(not(target_arch = "wasm32"))] +use std::env; + +use crossbeam::channel::{bounded, Receiver, Sender}; +use tvm_sys::ffi::TVMParallelGroupEnv; + +pub(crate) type FTVMParallelLambda = + extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; + +/// Holds a parallel job request made by a TVM library function. +struct Job { + cb: FTVMParallelLambda, + cdata: *const c_void, + req_num_tasks: usize, + pending: Arc, +} + +impl Job { + /// Splits this job into a number of `Task`s which can be scheduled. + fn tasks(&self, num_workers: usize) -> Vec { + let num_tasks = if self.req_num_tasks == 0 { + num_workers + } else { + self.req_num_tasks.min(num_workers) + }; + self.pending.store(num_tasks, Ordering::SeqCst); + + let barrier = Arc::new(Barrier::new(num_tasks)); + + (0..num_tasks) + .map(move |i| Task { + id: i, + flambda: self.cb, + penv: TVMParallelGroupEnv { + sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void, + num_task: num_tasks as i32, + }, + cdata: self.cdata, + pending: Arc::clone(&self.pending), + }) + .collect() + } + + /// Waits for all tasks in this `Job` to be completed. + fn wait(&self) { + while self.pending.load(Ordering::Acquire) > 0 { + thread::yield_now(); + } + } +} + +/// A chunk of work requested by a TVM function. +struct Task { + id: usize, + flambda: FTVMParallelLambda, + penv: TVMParallelGroupEnv, + cdata: *const c_void, + pending: Arc, +} +unsafe impl Send for Task {} +unsafe impl Sync for Task {} + +impl Task { + fn run(self) -> i32 { + let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata); + self.pending.fetch_sub(1, Ordering::AcqRel); + status + } +} + +#[derive(Default)] +struct Threads { + #[allow(unused)] + handles: Vec>, + queues: Vec>, +} + +impl<'a> Threads { + fn launch) + 'static + Copy>( + num_threads: usize, + cb: F, + ) -> Self { + let (handles, queues) = (0..num_threads) + .map(|_| { + let (p, c) = bounded(2); + let handle = thread::spawn(move || cb(c.into())); + (handle, p) + }) + .unzip(); + Threads { handles, queues } + } +} + +struct ThreadPool { + num_workers: usize, + #[allow(unused)] + threads: Threads, +} + +thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new()); + +impl ThreadPool { + fn new() -> Self { + let num_workers = max_concurrency(); + ThreadPool { + num_workers, + threads: Threads::launch(num_workers, ThreadPool::run_worker), + } + } + + fn launch(&self, job: Job) { + let mut tasks = job.tasks(self.num_workers + 1); + + for (i, task) in tasks.split_off(1).into_iter().enumerate() { + self.threads.queues[i].send(task).expect("should send"); + } + + tasks.pop().unwrap().run(); + job.wait(); + } + + fn run_worker(queue: Receiver) { + loop { + let task = match queue.recv() { + Ok(v) => v, + Err(_) => break, + }; + let result = task.run(); + if result == ::min_value() { + break; + } else if result != 0 { + panic!("Error running task."); + } + } + } +} + +#[cfg(not(target_arch = "wasm32"))] +fn max_concurrency() -> usize { + if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or_else(|_| env::var("OMP_NUM_THREADS")) { + if let Ok(threads) = usize::from_str_radix(&threads_str, 10) { + return threads; + } + } + num_cpus::get() +} + +#[cfg(target_arch = "wasm32")] +fn max_concurrency() -> usize { + 0 // wasm doesn't support threads yet +} + +#[no_mangle] +pub extern "C" fn TVMBackendParallelLaunch( + cb: FTVMParallelLambda, + cdata: *const c_void, + num_task: usize, +) -> c_int { + if max_concurrency() < 2 { + let penv = TVMParallelGroupEnv { + sync_handle: std::ptr::null_mut(), + num_task: 1, + }; + cb(0, &penv as *const _, cdata); + } else { + THREAD_POOL.with(|pool| { + pool.launch(Job { + cb, + cdata, + req_num_tasks: num_task, + pending: Arc::new(AtomicUsize::new(0)), + }); + }); + } + 0 +} + +// @see issue 988 for information on why this function is used. +#[no_mangle] +pub unsafe extern "C" fn TVMBackendParallelBarrier( + _task_id: usize, + penv: *const TVMParallelGroupEnv, +) { + let barrier: &Arc = &*((*penv).sync_handle as *const Arc); + barrier.wait(); +} + +#[cfg(test)] +mod tests { + use std::{ptr, thread, time::Duration}; + + use super::*; + + #[test] + fn test_max_concurrency() { + env::set_var("TVM_NUM_THREADS", "42"); + env::set_var("OMP_NUM_THREADS", "24"); + assert_eq!(max_concurrency(), 42); + env::remove_var("TVM_NUM_THREADS"); + assert_eq!(max_concurrency(), 24); + } + + extern "C" fn flambda( + task_id: usize, + penv: *const TVMParallelGroupEnv, + cdata: *const c_void, + ) -> i32 { + if cdata.is_null() { + return 0; + } + unsafe { + let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize)); + thread::sleep(Duration::from_millis(50 * task_id as u64)); + counter.fetch_add(1, Ordering::SeqCst); + task_ids_sum.fetch_add(task_id, Ordering::SeqCst); + assert_eq!((*penv).num_task, 3); + } + 0 + } + + #[test] + fn test_parallel_launch() { + TVMBackendParallelLaunch(flambda, ptr::null(), 6); + let counter = AtomicUsize::new(0); + let task_ids_sum = AtomicUsize::new(0); + let cdata = (counter, task_ids_sum); + let num_tasks = 3; + TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks); + assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks); + assert_eq!( + cdata.1.load(Ordering::SeqCst), + (0..num_tasks).sum::() + ); + } +} diff --git a/rust/tvm-graph-rt/src/workspace.rs b/rust/tvm-graph-rt/src/workspace.rs new file mode 100644 index 0000000..35cfe91 --- /dev/null +++ b/rust/tvm-graph-rt/src/workspace.rs @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + cell::RefCell, + os::raw::{c_int, c_void}, + ptr, +}; + +use failure::{format_err, Error}; + +use crate::allocator::Allocation; + +const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h` + +pub fn remove_item(vec: &mut Vec, item: &T) -> Option { + let pos = vec.iter().position(|x| *x == *item)?; + Some(vec.remove(pos)) +} + +struct WorkspacePool { + workspaces: Vec, + free: Vec, + in_use: Vec, +} + +impl WorkspacePool { + fn new() -> Self { + WorkspacePool { + workspaces: Vec::new(), + free: Vec::new(), + in_use: Vec::new(), + } + } + + fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> { + self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?); + self.in_use.push(self.workspaces.len() - 1); + Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) + } + + fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> { + if self.free.is_empty() { + return self.alloc_new(size); + } + let idx = self + .free + .iter() + .fold(None, |cur_ws_idx: Option, &idx| { + let ws_size = self.workspaces[idx].size(); + if ws_size < size { + return cur_ws_idx; + } + cur_ws_idx.or(Some(idx)).and_then(|cur_idx| { + let cur_size = self.workspaces[cur_idx].size(); + Some(if ws_size <= cur_size { idx } else { cur_idx }) + }) + }); + match idx { + Some(idx) => { + remove_item(&mut self.free, &idx).unwrap(); + self.in_use.push(idx); + Ok(self.workspaces[idx].as_mut_ptr()) + } + None => self.alloc_new(size), + } + } + + fn free(&mut self, ptr: *mut u8) -> Result<(), Error> { + let mut ws_idx = None; + for i in 0..self.in_use.len() { + let idx = self.in_use[i]; + if self.workspaces[idx].as_mut_ptr() == ptr { + self.in_use.remove(i); + ws_idx = Some(idx); + break; + } + } + let ws_idx = ws_idx.ok_or_else(|| format_err!("Invalid pointer"))?; + self.free.push(ws_idx); + Ok(()) + } +} + +thread_local!(static WORKSPACE_POOL: RefCell = RefCell::new(WorkspacePool::new())); + +const WORKSPACE_PAGE_SIZE: usize = 4 << 10; + +#[no_mangle] +pub extern "C" fn TVMBackendAllocWorkspace( + _device_type: c_int, + _device_id: c_int, + size: u64, + _dtype_code_hint: c_int, + _dtype_bits_hint: c_int, +) -> *mut c_void { + let nbytes = if size == 0 { + WORKSPACE_PAGE_SIZE + } else { + size as usize + }; + WORKSPACE_POOL.with(|pool_cell| { + pool_cell + .borrow_mut() + .alloc(nbytes as usize) + .unwrap_or(ptr::null_mut()) as *mut c_void + }) +} + +#[no_mangle] +pub extern "C" fn TVMBackendFreeWorkspace( + _device_type: c_int, + _device_id: c_int, + ptr: *mut c_void, +) -> c_int { + WORKSPACE_POOL.with(|pool_cell| { + (match pool_cell.borrow_mut().free(ptr as *mut u8) { + Ok(()) => 0, + Err(_) => -1, + }) as c_int + }) +} diff --git a/rust/tvm-graph-rt/tests/.gitignore b/rust/tvm-graph-rt/tests/.gitignore new file mode 100644 index 0000000..8110767 --- /dev/null +++ b/rust/tvm-graph-rt/tests/.gitignore @@ -0,0 +1,3 @@ +*.json +*.params +*.o diff --git a/rust/tvm-graph-rt/tests/build_model.py b/rust/tvm-graph-rt/tests/build_model.py new file mode 100755 index 0000000..ddfa03b --- /dev/null +++ b/rust/tvm-graph-rt/tests/build_model.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Builds a simple graph for testing.""" + +from os import path as osp + +import numpy as np +import tvm +from tvm import te +from tvm import relay +from tvm.relay import testing + +CWD = osp.dirname(osp.abspath(osp.expanduser(__file__))) + +def _get_model(dshape): + data = relay.var('data', shape=dshape) + fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2) + fc = relay.nn.bias_add(fc, relay.var("dense_bias")) + left, right = relay.split(fc, indices_or_sections=2, axis=1) + one = relay.const(1, dtype="float32") + return relay.Tuple([(left + one), (right - one), fc]) + + +def main(): + dshape = (32, 16) + net = _get_model(dshape) + mod, params = testing.create_workload(net) + graph, lib, params = relay.build( + mod, 'llvm', params=params) + + with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet: + f_resnet.write(graph) + with open(osp.join(CWD, 'graph.params'), 'wb') as f_params: + f_params.write(relay.save_param_dict(params)) + +if __name__ == '__main__': + main() diff --git a/rust/tvm-graph-rt/tests/test_graph_serde.rs b/rust/tvm-graph-rt/tests/test_graph_serde.rs new file mode 100644 index 0000000..6cea4ad --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_graph_serde.rs @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate serde; +extern crate serde_json; + +extern crate tvm_runtime; + +use std::{convert::TryFrom, fs, io::Read}; + +use tvm_runtime::Graph; + +macro_rules! mf_dir { + ($p:literal) => { + concat!(env!("CARGO_MANIFEST_DIR"), $p) + }; +} + +static PARAMS_FIXTURE_PATH: &str = mf_dir!("/tests/graph.params"); + +#[test] +fn test_load_graph() { + let output = std::process::Command::new(mf_dir!("/tests/build_model.py")) + .env( + "PYTHONPATH", + concat!( + mf_dir!("/../../python"), + ":", + mf_dir!("/../../nnvm/python"), + ":", + mf_dir!("/../../topi/python") + ), + ) + .output() + .expect("Failed to build test model"); + assert!( + std::path::Path::new(PARAMS_FIXTURE_PATH).exists(), + "Could not build test graph fixture: STDOUT:\n\n{}\nSTDERR: {}\n\n", + String::from_utf8(output.stdout).unwrap(), + String::from_utf8(output.stderr).unwrap() + ); + let mut params_bytes = Vec::new(); + fs::File::open(PARAMS_FIXTURE_PATH) + .unwrap() + .read_to_end(&mut params_bytes) + .unwrap(); + let _params = tvm_runtime::load_param_dict(¶ms_bytes); + + let graph = Graph::try_from( + &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(), + ) + .unwrap(); + + assert_eq!(graph.nodes[3].op, "tvm_op"); + assert_eq!( + graph.nodes[3] + .attrs + .as_ref() + .unwrap() + .get("func_name") + .unwrap(), + "fused_nn_dense_nn_bias_add" + ); + assert_eq!(graph.nodes[3].inputs[0].index, 0); + assert_eq!(graph.nodes[4].inputs[0].index, 0); + assert_eq!(graph.heads.len(), 3); +} diff --git a/rust/tvm-graph-rt/tests/test_nn/Cargo.toml b/rust/tvm-graph-rt/tests/test_nn/Cargo.toml new file mode 100644 index 0000000..158f9e2 --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_nn/Cargo.toml @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "test-rt-nn" +version = "0.0.0" +license = "Apache-2.0" +authors = ["TVM Contributors"] + +[dependencies] +ndarray="0.12" +serde = "1.0" +serde_json = "1.0" +tvm-graph-rt = { path = "../../" } + +[build-dependencies] +ar = "0.6" diff --git a/rust/tvm-graph-rt/tests/test_nn/build.rs b/rust/tvm-graph-rt/tests/test_nn/build.rs new file mode 100644 index 0000000..8ae1131 --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_nn/build.rs @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate ar; + +use std::{env, fs::File, path::Path, process::Command}; + +use ar::Builder; + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + let out_dir = Path::new(&out_dir).join("test_nn"); + + std::fs::create_dir_all(&out_dir).unwrap(); + + let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); + let manifest_dir = Path::new(&manifest_dir); + + let generator = manifest_dir.join("src").join("build_test_graph.py"); + + let graph_path = out_dir.join("graph.o"); + + let output = Command::new(&generator) + .arg(&out_dir) + .output() + .expect("Failed to execute command"); + + assert!( + graph_path.exists(), + "Could not build graph lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + let lib_file = out_dir.join("libtestnn.a"); + let file = File::create(&lib_file).unwrap(); + let mut builder = Builder::new(file); + builder.append_path(graph_path).unwrap(); + + let status = Command::new("ranlib") + .arg(&lib_file) + .status() + .expect("fdjlksafjdsa"); + + assert!(status.success()); + + println!("cargo:rustc-link-lib=static=testnn"); + println!("cargo:rustc-link-search=native={}", out_dir.display()); + println!("cargo:rerun-if-changed={}", generator.display()); +} diff --git a/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py b/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py new file mode 100755 index 0000000..cb7c4f7 --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Builds a simple graph for testing.""" + +from os import path as osp +import sys + +import numpy as np +import tvm +from tvm import te +from tvm import relay +from tvm.relay import testing + + +def _get_model(dshape): + data = relay.var('data', shape=dshape) + fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2) + fc = relay.nn.bias_add(fc, relay.var("dense_bias")) + left, right = relay.split(fc, indices_or_sections=2, axis=1) + one = relay.const(1, dtype="float32") + return relay.Tuple([(left + one), (right - one), fc]) + +def main(): + dshape = (4, 8) + net = _get_model(dshape) + mod, params = testing.create_workload(net) + graph, lib, params = relay.build( + mod, 'llvm --system-lib', params=params) + + out_dir = sys.argv[1] + lib.save(osp.join(sys.argv[1], 'graph.o')) + with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet: + f_resnet.write(graph) + + with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params: + f_params.write(relay.save_param_dict(params)) + +if __name__ == '__main__': + main() diff --git a/rust/tvm-graph-rt/tests/test_nn/src/main.rs b/rust/tvm-graph-rt/tests/test_nn/src/main.rs new file mode 100644 index 0000000..505c544 --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_nn/src/main.rs @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#[macro_use] +extern crate ndarray; +extern crate serde; +extern crate serde_json; + +extern crate tvm_runtime; +use std::{collections::HashMap, convert::TryFrom, fs, io::Read}; + +use ndarray::Array; +use tvm_runtime::{Graph, GraphExecutor, SystemLibModule, Tensor}; + +const BATCH_SIZE: usize = 4; +const IN_DIM: usize = 8; + +macro_rules! check_sum { + ($e:expr, $a:ident, $b:ident) => { + let a = Array::try_from($e.get_input(stringify!($a)).unwrap().to_owned()).unwrap(); + check_sum!(a, $b); + }; + ($e:expr, $a:expr, $b:ident) => { + let a = Array::try_from($e.get_output($a).unwrap().to_owned()).unwrap(); + check_sum!(a, $b); + }; + ($a:ident, $b:ident) => { + let a_sum: f32 = $a.scalar_sum(); + let b_sum: f32 = $b.scalar_sum(); + assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum); + }; +} + +fn main() { + let syslib = SystemLibModule::default(); + + let mut params_bytes = Vec::new(); + fs::File::open(concat!(env!("OUT_DIR"), "/test_nn/graph.params")) + .unwrap() + .read_to_end(&mut params_bytes) + .unwrap(); + let params = tvm_runtime::load_param_dict(¶ms_bytes) + .unwrap() + .into_iter() + .map(|(k, v)| (k, v.to_owned())) + .collect::>>(); + + let graph = Graph::try_from( + &fs::read_to_string(concat!(env!("OUT_DIR"), "/test_nn/graph.json")).unwrap(), + ) + .unwrap(); + let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); + + let x = Array::from_shape_vec( + (BATCH_SIZE, IN_DIM), + (0..BATCH_SIZE * IN_DIM) + .map(|x| x as f32) + .collect::>(), + ) + .unwrap(); + + let p0 = params.get("p0").unwrap().to_owned(); + let p1 = params.get("p1").unwrap().to_owned(); + println!("p0: {:?}", p0.shape()); + println!("p1: {:?}", p1.shape()); + let w = Array::try_from(p0) + .unwrap() + .into_shape((BATCH_SIZE * 4, IN_DIM)) + .unwrap(); + let b = Array::try_from(p1).unwrap(); + let dense = x.dot(&w.t()) + &b; + let left = dense.slice(s![.., 0..IN_DIM]); + let right = dense.slice(s![.., IN_DIM..]); + let expected_o0 = &left + 1f32; + let expected_o1 = &right - 1f32; + + exec.load_params(params); + exec.set_input("data", (&x).into()); + + check_sum!(exec, data, x); + check_sum!(exec, p0, w); + check_sum!(exec, p1, b); + + exec.run(); + + check_sum!(exec, 0, expected_o0); + check_sum!(exec, 1, expected_o1); + check_sum!(exec, 2, dense); +} diff --git a/rust/tvm-graph-rt/tests/test_tvm_basic/Cargo.toml b/rust/tvm-graph-rt/tests/test_tvm_basic/Cargo.toml new file mode 100644 index 0000000..c1e87ef --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_tvm_basic/Cargo.toml @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "test-rt-tvm-basic" +version = "0.0.0" +license = "Apache-2.0" +authors = ["TVM Contributors"] + +[dependencies] +ndarray="0.12" +tvm-graph-rt = { path = "../../" } + +[build-dependencies] +ar = "0.6" diff --git a/rust/tvm-graph-rt/tests/test_tvm_basic/build.rs b/rust/tvm-graph-rt/tests/test_tvm_basic/build.rs new file mode 100644 index 0000000..ade9e02 --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_tvm_basic/build.rs @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate ar; + +use std::{path::PathBuf, process::Command}; + +use ar::Builder; +use std::fs::File; + +fn main() { + let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + out_dir.push("lib"); + + if !out_dir.is_dir() { + std::fs::create_dir(&out_dir).unwrap(); + } + + let obj_file = out_dir.join("test.o"); + let lib_file = out_dir.join("libtest_basic.a"); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_lib.py" + )) + .arg(&out_dir) + .output() + .expect("Failed to execute command"); + assert!( + obj_file.exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + let mut builder = Builder::new(File::create(&lib_file).unwrap()); + builder.append_path(&obj_file).unwrap(); + drop(builder); + + let status = Command::new("ranlib") + .arg(&lib_file) + .status() + .expect("fdjlksafjdsa"); + + assert!(status.success()); + + println!("cargo:rustc-link-lib=static=test_basic"); + println!("cargo:rustc-link-search=native={}", out_dir.display()); +} diff --git a/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py b/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py new file mode 100755 index 0000000..bf7e60a --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Prepares a simple TVM library for testing.""" + +from os import path as osp +import sys + +import tvm +from tvm import te + +def main(): + n = te.var('n') + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.te.create_schedule(C.op) + s[C].parallel(s[C].op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + tvm.build(s, [A, B, C], 'llvm --system-lib').save(osp.join(sys.argv[1], 'test.o')) + +if __name__ == '__main__': + main() diff --git a/rust/tvm-graph-rt/tests/test_tvm_basic/src/main.rs b/rust/tvm-graph-rt/tests/test_tvm_basic/src/main.rs new file mode 100644 index 0000000..653cb43 --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_tvm_basic/src/main.rs @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate ndarray; +#[macro_use] +extern crate tvm_runtime; + +use ndarray::Array; +use tvm_runtime::{DLTensor, Module as _, SystemLibModule}; + +mod tvm_mod { + import_module!("lib/test.o"); +} + +fn main() { + // try static + let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); + let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); + let mut c = Array::from_vec(vec![0f32; 4]); + let e = Array::from_vec(vec![2f32, 2., 4., 4.]); + let mut a_dl: DLTensor = (&mut a).into(); + let mut b_dl: DLTensor = (&mut b).into(); + let mut c_dl: DLTensor = (&mut c).into(); + call_packed!(tvm_mod::default_function, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); + assert!(c.all_close(&e, 1e-8f32)); + + // try runtime + let syslib = SystemLibModule::default(); + let add = syslib + .get_function("default_function") + .expect("main function not found"); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); + assert!(c.all_close(&e, 1e-8f32)); +} diff --git a/rust/tvm-graph-rt/tests/test_tvm_dso/Cargo.toml b/rust/tvm-graph-rt/tests/test_tvm_dso/Cargo.toml new file mode 100644 index 0000000..1909268 --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_tvm_dso/Cargo.toml @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "test-rt-tvm-dso" +version = "0.0.0" +license = "Apache-2.0" +authors = ["TVM Contributors"] + +[dependencies] +ndarray="0.12" +tvm-graph-rt = { path = "../../" } diff --git a/rust/tvm-graph-rt/tests/test_tvm_dso/build.rs b/rust/tvm-graph-rt/tests/test_tvm_dso/build.rs new file mode 100644 index 0000000..f1d9822 --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_tvm_dso/build.rs @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{env, path::Path, process::Command}; + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_lib.py" + )) + .arg(&out_dir) + .output() + .expect("Failed to execute command"); + assert!( + Path::new(&format!("{}/test.so", out_dir)).exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); +} diff --git a/rust/tvm-graph-rt/tests/test_tvm_dso/src/build_test_lib.py b/rust/tvm-graph-rt/tests/test_tvm_dso/src/build_test_lib.py new file mode 100755 index 0000000..cb7353f --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_tvm_dso/src/build_test_lib.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Prepares a simple TVM library for testing.""" + +from os import path as osp +import sys + +import tvm +from tvm import te +from tvm.contrib import cc + +def main(): + n = te.var('n') + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.te.create_schedule(C.op) + s[C].parallel(s[C].op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + obj_file = osp.join(sys.argv[1], 'test.o') + tvm.build(s, [A, B, C], 'llvm').save(obj_file) + cc.create_shared(osp.join(sys.argv[1], 'test.so'), [obj_file]) + +if __name__ == '__main__': + main() diff --git a/rust/tvm-graph-rt/tests/test_tvm_dso/src/main.rs b/rust/tvm-graph-rt/tests/test_tvm_dso/src/main.rs new file mode 100644 index 0000000..953676c --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_tvm_dso/src/main.rs @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate ndarray; +#[macro_use] +extern crate tvm_runtime; + +use ndarray::Array; +use tvm_runtime::{DLTensor, DsoModule, Module}; + +fn main() { + tvm_runtime::TVMGetLastError(); + let module = DsoModule::new(concat!(env!("OUT_DIR"), "/test.so")).unwrap(); + let add = module + .get_function("__tvm_main__") + .expect("main function not found"); + let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); + let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); + let mut c = Array::from_vec(vec![0f32; 4]); + let e = Array::from_vec(vec![2f32, 2., 4., 4.]); + let mut a_dl: DLTensor = (&mut a).into(); + let mut b_dl: DLTensor = (&mut b).into(); + let mut c_dl: DLTensor = (&mut c).into(); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); + assert!(c.all_close(&e, 1e-8f32)); +} diff --git a/rust/tvm-graph-rt/tests/test_wasm32/.cargo/config b/rust/tvm-graph-rt/tests/test_wasm32/.cargo/config new file mode 100644 index 0000000..6b77899 --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_wasm32/.cargo/config @@ -0,0 +1,2 @@ +[build] +target = "wasm32-wasi" diff --git a/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml b/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml new file mode 100644 index 0000000..aed467f --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "test-rt-wasm32" +version = "0.0.0" +license = "Apache-2.0" +authors = ["TVM Contributors"] +edition = "2018" + +[dependencies] +ndarray="0.12" +tvm-graph-rt = { path = "../../" } + +[build-dependencies] +anyhow = "^1.0" diff --git a/rust/tvm-graph-rt/tests/test_wasm32/build.rs b/rust/tvm-graph-rt/tests/test_wasm32/build.rs new file mode 100644 index 0000000..5c816c3 --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_wasm32/build.rs @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{path::PathBuf, process::Command}; + +use anyhow::{Context, Result}; + +fn main() -> Result<()> { + let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + out_dir.push("lib"); + + if !out_dir.is_dir() { + std::fs::create_dir(&out_dir).context("failed to create directory for WASM outputs")?; + } + + let obj_file = out_dir.join("test.o"); + let lib_file = out_dir.join("libtest_wasm32.a"); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_lib.py" + )) + .arg(&out_dir) + .output() + .context("failed to execute Python script for generating TVM library")?; + + assert!( + obj_file.exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8"); + + let output = Command::new(ar) + .arg("rcs") + .arg(&lib_file) + .arg(&obj_file) + .output() + .context("failed to run LLVM_AR command")?; + + assert!( + lib_file.exists(), + "Could not create archive: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + println!("cargo:rustc-link-lib=static=test_wasm32"); + println!("cargo:rustc-link-search=native={}", out_dir.display()); + Ok(()) +} diff --git a/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py b/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py new file mode 100755 index 0000000..6016c60 --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Prepares a simple TVM library for testing.""" + +from os import path as osp +import sys + +import tvm +from tvm import te + +def main(): + n = te.var('n') + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.te.create_schedule(C.op) + s[C].parallel(s[C].op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + tvm.build(s, [A, B, C], 'llvm -target=wasm32-unknown-unknown --system-lib').save(osp.join(sys.argv[1], 'test.o')) + +if __name__ == '__main__': + main() diff --git a/rust/tvm-graph-rt/tests/test_wasm32/src/main.rs b/rust/tvm-graph-rt/tests/test_wasm32/src/main.rs new file mode 100644 index 0000000..a46cfa9 --- /dev/null +++ b/rust/tvm-graph-rt/tests/test_wasm32/src/main.rs @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern "C" { + static __tvm_module_ctx: i32; +} + +#[no_mangle] +unsafe fn __get_tvm_module_ctx() -> i32 { + // Refer a symbol in the libtest_wasm32.a to make sure that the link of the + // library is not optimized out. + __tvm_module_ctx +} + +extern crate ndarray; +#[macro_use] +extern crate tvm_runtime; + +use ndarray::Array; +use tvm_runtime::{DLTensor, Module as _, SystemLibModule}; + +fn main() { + // try static + let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); + let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); + let mut c = Array::from_vec(vec![0f32; 4]); + let e = Array::from_vec(vec![2f32, 2., 4., 4.]); + let mut a_dl: DLTensor = (&mut a).into(); + let mut b_dl: DLTensor = (&mut b).into(); + let mut c_dl: DLTensor = (&mut c).into(); + + let syslib = SystemLibModule::default(); + let add = syslib + .get_function("default_function") + .expect("main function not found"); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); + assert!(c.all_close(&e, 1e-8f32)); +} diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs index 85e16be..01d2934 100644 --- a/rust/tvm-sys/build.rs +++ b/rust/tvm-sys/build.rs @@ -54,6 +54,7 @@ fn main() { .layout_tests(false) .derive_partialeq(true) .derive_eq(true) + .derive_default(true) .generate() .expect("unable to generate bindings") .write_to_file(PathBuf::from("src/c_runtime_api.rs")) diff --git a/rust/tvm-sys/src/array.rs b/rust/tvm-sys/src/array.rs index 1627e9e..5d09d86 100644 --- a/rust/tvm-sys/src/array.rs +++ b/rust/tvm-sys/src/array.rs @@ -48,6 +48,7 @@ macro_rules! impl_dltensor_from_ndarray { shape: arr.shape().as_ptr() as *const i64 as *mut i64, strides: arr.strides().as_ptr() as *const i64 as *mut i64, byte_offset: 0, + ..Default::default() } } } diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs index ccdee3f..c98d374 100644 --- a/rust/tvm-sys/src/datatype.rs +++ b/rust/tvm-sys/src/datatype.rs @@ -39,7 +39,7 @@ pub struct DataType { } impl DataType { - pub fn new(code: u8, bits: u8, lanes: u16) -> DataType { + pub const fn new(code: u8, bits: u8, lanes: u16) -> DataType { DataType { code, bits, lanes } } @@ -73,6 +73,18 @@ impl DataType { pub fn lanes(&self) -> usize { self.lanes as usize } + + pub const fn int(bits: u8, lanes: u16) -> DataType { + DataType::new(DL_INT_CODE, bits, lanes) + } + + pub const fn float(bits: u8, lanes: u16) -> DataType { + DataType::new(DL_FLOAT_CODE, bits, lanes) + } + + pub const fn uint(bits: u8, lanes: u16) -> DataType { + DataType::new(DL_UINT_CODE, bits, lanes) + } } impl<'a> From<&'a DataType> for DLDataType { diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index da3a456..e3e74ad 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -107,6 +107,7 @@ ALLOW_SPECIFIC_FILE = { "Jenkinsfile", # cargo config "rust/runtime/tests/test_wasm32/.cargo/config", + "rust/tvm-graph-rt/tests/test_wasm32/.cargo/config", "apps/sgx/.cargo/config", # html for demo purposes "web/apps/browser/rpc_server.html", -- 2.7.4