Rust Refactor Stage 4: Rewrite Rust graph runtime to use new APIs (#5830)
authorJared Roesch <roeschinc@gmail.com>
Tue, 23 Jun 2020 18:39:48 +0000 (11:39 -0700)
committerGitHub <noreply@github.com>
Tue, 23 Jun 2020 18:39:48 +0000 (11:39 -0700)
* Port graph-runtime to new API

* --amend

* Fix file lint

* Remove old travis file

* Add @kazum's patch

* Update rust/tvm-sys/src/datatype.rs

Co-authored-by: Andrew <amcharg@gmail.com>
Co-authored-by: Andrew <amcharg@gmail.com>
37 files changed:
rust/Cargo.toml
rust/runtime/src/graph.rs
rust/tvm-graph-rt/Cargo.toml [new file with mode: 0644]
rust/tvm-graph-rt/src/allocator.rs [new file with mode: 0644]
rust/tvm-graph-rt/src/array.rs [new file with mode: 0644]
rust/tvm-graph-rt/src/errors.rs [new file with mode: 0644]
rust/tvm-graph-rt/src/graph.rs [new file with mode: 0644]
rust/tvm-graph-rt/src/lib.rs [new file with mode: 0644]
rust/tvm-graph-rt/src/module/dso.rs [new file with mode: 0644]
rust/tvm-graph-rt/src/module/mod.rs [new file with mode: 0644]
rust/tvm-graph-rt/src/module/syslib.rs [new file with mode: 0644]
rust/tvm-graph-rt/src/threading.rs [new file with mode: 0644]
rust/tvm-graph-rt/src/workspace.rs [new file with mode: 0644]
rust/tvm-graph-rt/tests/.gitignore [new file with mode: 0644]
rust/tvm-graph-rt/tests/build_model.py [new file with mode: 0755]
rust/tvm-graph-rt/tests/test_graph_serde.rs [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_nn/Cargo.toml [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_nn/build.rs [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py [new file with mode: 0755]
rust/tvm-graph-rt/tests/test_nn/src/main.rs [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_tvm_basic/Cargo.toml [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_tvm_basic/build.rs [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py [new file with mode: 0755]
rust/tvm-graph-rt/tests/test_tvm_basic/src/main.rs [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_tvm_dso/Cargo.toml [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_tvm_dso/build.rs [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_tvm_dso/src/build_test_lib.py [new file with mode: 0755]
rust/tvm-graph-rt/tests/test_tvm_dso/src/main.rs [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_wasm32/.cargo/config [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_wasm32/build.rs [new file with mode: 0644]
rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py [new file with mode: 0755]
rust/tvm-graph-rt/tests/test_wasm32/src/main.rs [new file with mode: 0644]
rust/tvm-sys/build.rs
rust/tvm-sys/src/array.rs
rust/tvm-sys/src/datatype.rs
tests/lint/check_file_type.py

index d9bb3ab..9542178 100644 (file)
@@ -32,4 +32,9 @@ members = [
        "tvm-macros",
        "tvm-rt",
        "tvm",
+       "tvm-graph-rt",
+       "tvm-graph-rt/tests/test_tvm_basic",
+       "tvm-graph-rt/tests/test_tvm_dso",
+       "tvm-graph-rt/tests/test_wasm32",
+       "tvm-graph-rt/tests/test_nn",
 ]
index 71541ba..c1f44ef 100644 (file)
@@ -24,6 +24,7 @@ use nom::{
     character::complete::{alpha1, digit1},
     number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8},
 };
+
 use serde;
 use serde_json;
 use tvm_common::{
diff --git a/rust/tvm-graph-rt/Cargo.toml b/rust/tvm-graph-rt/Cargo.toml
new file mode 100644 (file)
index 0000000..0cf2ac1
--- /dev/null
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "tvm-graph-rt"
+version = "0.1.0"
+license = "Apache-2.0"
+description = "A static graph runtime for TVM."
+repository = "https://github.com/apache/incubator-tvm"
+readme = "README.md"
+keywords = ["tvm"]
+categories = ["api-bindings", "science"]
+authors = ["TVM Contributors"]
+edition = "2018"
+
+[dependencies]
+crossbeam = "0.7.3"
+failure = "0.1"
+itertools = "0.8"
+lazy_static = "1.4"
+ndarray="0.12"
+nom = "5.0"
+num_cpus = "1.10"
+serde = { version = "^1.0", features = ["derive"] }
+serde_json = "^1.0"
+tvm-sys = { version = "0.1", path = "../tvm-sys" }
+tvm-macros = { version = "0.1", path = "../tvm-macros" }
+
+[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies]
+libloading = "0.5"
diff --git a/rust/tvm-graph-rt/src/allocator.rs b/rust/tvm-graph-rt/src/allocator.rs
new file mode 100644 (file)
index 0000000..81499af
--- /dev/null
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::alloc::{self, Layout, LayoutErr};
+
+const DEFAULT_ALIGN_BYTES: usize = 4;
+
+#[derive(PartialEq, Eq)]
+pub struct Allocation {
+    layout: Layout,
+    ptr: *mut u8,
+}
+
+impl Allocation {
+    /// Allocates a chunk of memory of `size` bytes with optional alignment.
+    pub fn new(size: usize, align: Option<usize>) -> Result<Self, LayoutErr> {
+        let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
+        let layout = Layout::from_size_align(size, alignment)?;
+        let ptr = unsafe { alloc::alloc(layout) };
+        if ptr.is_null() {
+            alloc::handle_alloc_error(layout);
+        }
+        Ok(Self { ptr, layout })
+    }
+
+    pub fn as_mut_ptr(&self) -> *mut u8 {
+        self.ptr
+    }
+
+    /// Returns the size of the Allocation in bytes.
+    pub fn size(&self) -> usize {
+        self.layout.size()
+    }
+
+    /// Returns the byte alignment of the Allocation.
+    pub fn align(&self) -> usize {
+        self.layout.align()
+    }
+
+    /// Returns a view of the Allocation.
+    pub fn as_slice(&self) -> &[u8] {
+        unsafe { std::slice::from_raw_parts(self.as_mut_ptr(), self.size()) }
+    }
+
+    /// Returns a mutable view of the Allocation.
+    pub fn as_mut_slice(&mut self) -> &mut [u8] {
+        unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) }
+    }
+}
+
+impl Drop for Allocation {
+    fn drop(&mut self) {
+        unsafe {
+            alloc::dealloc(self.ptr, self.layout);
+        }
+    }
+}
diff --git a/rust/tvm-graph-rt/src/array.rs b/rust/tvm-graph-rt/src/array.rs
new file mode 100644 (file)
index 0000000..8209b59
--- /dev/null
@@ -0,0 +1,401 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
+
+use failure::{ensure, Error};
+use ndarray;
+use tvm_sys::{ffi::DLTensor, Context, DataType};
+
+use crate::allocator::Allocation;
+
+/// A `Storage` is a container which holds `Tensor` data.
+#[derive(PartialEq)]
+pub enum Storage<'a> {
+    /// A `Storage` which owns its contained bytes.
+    Owned(Allocation),
+
+    /// A view of an existing `Storage`.
+    View(&'a mut [u8], usize), // ptr, align
+}
+
+impl<'a> Storage<'a> {
+    pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
+        Ok(Storage::Owned(Allocation::new(size, align)?))
+    }
+
+    pub fn as_mut_ptr(&self) -> *mut u8 {
+        match self {
+            Storage::Owned(alloc) => alloc.as_mut_ptr(),
+            Storage::View(slice, _) => slice.as_ptr() as *mut u8,
+        }
+    }
+
+    pub fn size(&self) -> usize {
+        match self {
+            Storage::Owned(alloc) => alloc.size(),
+            Storage::View(slice, _) => slice.len(),
+        }
+    }
+
+    pub fn align(&self) -> usize {
+        match self {
+            Storage::Owned(alloc) => alloc.align(),
+            Storage::View(_, align) => *align,
+        }
+    }
+
+    pub fn as_ptr(&self) -> *const u8 {
+        self.as_mut_ptr() as *const _
+    }
+
+    /// Returns a `Storage::View` which points to an owned `Storage::Owned`.
+    pub fn view(&self) -> Storage<'a> {
+        match self {
+            Storage::Owned(alloc) => Storage::View(
+                unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
+                self.align(),
+            ),
+            Storage::View(slice, _) => Storage::View(
+                unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
+                self.align(),
+            ),
+        }
+    }
+
+    pub fn is_owned(&self) -> bool {
+        match self {
+            Storage::Owned(_) => true,
+            _ => false,
+        }
+    }
+
+    /// Returns an owned version of this storage via cloning.
+    pub fn to_owned(&self) -> Storage<'static> {
+        let s = Storage::new(self.size(), Some(self.align())).unwrap();
+        unsafe {
+            s.as_mut_ptr()
+                .copy_from_nonoverlapping(self.as_ptr(), self.size());
+        }
+        s
+    }
+
+    /// Returns a view of the stored data.
+    pub fn as_slice(&self) -> &[u8] {
+        match self {
+            Storage::Owned(alloc) => alloc.as_slice(),
+            Storage::View(slice, _) => &*slice,
+        }
+    }
+
+    /// Returns a mutable view of the stored data.
+    pub fn as_mut_slice(&mut self) -> &mut [u8] {
+        match self {
+            Storage::Owned(alloc) => alloc.as_mut_slice(),
+            Storage::View(slice, _) => slice,
+        }
+    }
+}
+
+impl<'d, 's, T> From<&'d [T]> for Storage<'s> {
+    fn from(data: &'d [T]) -> Self {
+        let data = unsafe {
+            slice::from_raw_parts_mut(
+                data.as_ptr() as *const u8 as *mut u8,
+                data.len() * mem::size_of::<T>() as usize,
+            )
+        };
+        Storage::View(data, mem::align_of::<T>())
+    }
+}
+
+/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
+/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
+/// converted to `ndarray::Array` for non-TVM processing.
+///
+/// # Examples
+///
+/// ```
+/// extern crate ndarray;
+/// use std::convert::TryInto;
+/// use tvm_runtime::{call_packed, DLTensor, ArgValue, RetValue, Tensor};
+///
+/// let mut a_nd: ndarray::Array1<f32> = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// let mut a: Tensor = a_nd.into();
+/// let mut a_dl: DLTensor = (&mut a).into();
+///
+/// let tvm_fn = |args: &[ArgValue]| -> Result<RetValue, ()> { Ok(RetValue::default()) };
+/// call_packed!(tvm_fn, &mut a_dl);
+///
+/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
+/// let mut a_nd: ndarray::ArrayD<f32> = a.try_into().unwrap();
+/// ```
+#[derive(PartialEq)]
+pub struct Tensor<'a> {
+    /// The bytes which contain the data this `Tensor` represents.
+    pub(crate) data: Storage<'a>,
+    pub(crate) ctx: Context,
+    pub(crate) dtype: DataType,
+    pub(crate) shape: Vec<i64>,
+    // ^ not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
+    /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
+    pub(crate) strides: Option<Vec<usize>>,
+    pub(crate) byte_offset: isize,
+    /// The number of elements in the `Tensor`.
+    pub(crate) size: usize,
+}
+
+unsafe impl<'a> Send for Tensor<'a> {}
+
+impl<'a> Tensor<'a> {
+    pub fn shape(&self) -> Vec<i64> {
+        self.shape.clone()
+    }
+
+    pub fn data(&self) -> &Storage {
+        &self.data
+    }
+
+    pub fn data_mut(&mut self) -> &'a mut Storage {
+        &mut self.data
+    }
+
+    /// Returns the data of this `Tensor` as a `Vec`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
+    pub fn to_vec<T: 'static + std::fmt::Debug + Clone>(&self) -> Vec<T> {
+        assert!(self.is_contiguous());
+        assert!(self.dtype.is_type::<T>());
+        unsafe { slice::from_raw_parts(self.data.as_ptr() as *const T, self.size).to_vec() }
+    }
+
+    /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
+    pub fn is_contiguous(&self) -> bool {
+        match self.strides {
+            None => true,
+            Some(ref strides) => {
+                // check that stride for each dimension is the
+                // product of all trailing dimensons' shapes
+                self.shape
+                    .iter()
+                    .zip(strides)
+                    .rfold(
+                        (true, 1),
+                        |(is_contig, expected_stride), (shape, stride)| {
+                            (
+                                is_contig && *stride == expected_stride,
+                                expected_stride * (*shape as usize),
+                            )
+                        },
+                    )
+                    .0
+            }
+        }
+    }
+
+    /// Returns a clone of this `Tensor`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
+    pub fn copy(&mut self, other: &Tensor) {
+        assert!(
+            self.dtype == other.dtype && self.size == other.size,
+            "Tensor shape/dtype mismatch."
+        );
+        assert!(
+      self.is_contiguous() && other.is_contiguous(),
+      "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
+      self.strides,
+      other.strides
+    );
+        unsafe {
+            self.data
+                .as_mut_ptr()
+                .offset(self.byte_offset as isize)
+                .copy_from_nonoverlapping(
+                    other.data.as_mut_ptr().offset(other.byte_offset),
+                    other.size * other.dtype.itemsize(),
+                );
+        }
+    }
+
+    /// Returns an owned version of this `Tensor` via cloning.
+    pub fn to_owned(&self) -> Tensor<'static> {
+        let t = Tensor {
+            data: self.data.to_owned(),
+            ctx: self.ctx,
+            dtype: self.dtype,
+            size: self.size,
+            shape: self.shape.clone(),
+            strides: None,
+            byte_offset: 0,
+        };
+        unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
+    }
+
+    fn from_array_storage<'s, T, D: ndarray::Dimension>(
+        arr: &ndarray::Array<T, D>,
+        storage: Storage<'s>,
+        dtype_fn: fn(u8, u16) -> DataType,
+    ) -> Tensor<'s> {
+        let type_width = mem::size_of::<T>() as u8;
+
+        Tensor {
+            data: storage,
+            ctx: Context::default(),
+            dtype: dtype_fn(8 * type_width, 1),
+            size: arr.len(),
+            shape: arr.shape().iter().map(|&v| v as i64).collect(),
+            strides: Some(arr.strides().iter().map(|&v| v as usize).collect()),
+            byte_offset: 0,
+        }
+    }
+
+    pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor {
+        assert!(!flatten || self.is_contiguous());
+        DLTensor {
+            data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void,
+            ctx: self.ctx.into(),
+            ndim: if flatten { 1 } else { self.shape.len() } as i32,
+            dtype: self.dtype.into(),
+            shape: if flatten {
+                &self.size as *const _ as *mut i64
+            } else {
+                self.shape.as_ptr()
+            } as *mut i64,
+            strides: if flatten || self.is_contiguous() {
+                ptr::null_mut()
+            } else {
+                self.strides.as_ref().unwrap().as_ptr()
+            } as *mut i64,
+            byte_offset: 0,
+            ..Default::default()
+        }
+    }
+}
+
+/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
+macro_rules! impl_ndarray_try_from_tensor {
+    ($type:ty, $dtype:expr) => {
+        impl<'t> TryFrom<Tensor<'t>> for ndarray::ArrayD<$type> {
+            type Error = Error;
+            fn try_from(tensor: Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
+                ensure!(
+                    tensor.dtype == $dtype,
+                    "Cannot convert Tensor with dtype {:?} to ndarray",
+                    tensor.dtype
+                );
+                Ok(ndarray::Array::from_shape_vec(
+                    tensor
+                        .shape
+                        .iter()
+                        .map(|s| *s as usize)
+                        .collect::<Vec<usize>>(),
+                    tensor.to_vec::<$type>(),
+                )?)
+            }
+        }
+    };
+}
+
+macro_rules! make_dtype_const {
+    ($name: ident, $cnst:expr) => {
+        pub const $name: DataType = $cnst;
+    };
+}
+
+make_dtype_const!(DTYPE_INT32, DataType::int(32, 1));
+make_dtype_const!(DTYPE_UINT32, DataType::uint(32, 1));
+make_dtype_const!(DTYPE_FLOAT32, DataType::float(32, 1));
+make_dtype_const!(DTYPE_FLOAT64, DataType::float(64, 1));
+impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
+impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
+impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
+impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
+
+impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
+    fn from(tensor: &'a Tensor<'t>) -> Self {
+        Tensor::as_dltensor(tensor, false /* flatten */)
+    }
+}
+
+impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
+    fn from(tensor: &'a mut Tensor<'t>) -> Self {
+        Tensor::as_dltensor(tensor, false /* flatten */)
+    }
+}
+
+impl<'a> From<DLTensor> for Tensor<'a> {
+    fn from(dlt: DLTensor) -> Self {
+        unsafe {
+            let dtype = DataType::from(dlt.dtype);
+            let shape = slice::from_raw_parts(dlt.shape, dlt.ndim as usize).to_vec();
+            let size = shape.iter().map(|v| *v as usize).product::<usize>() as usize;
+            let storage = Storage::from(slice::from_raw_parts(
+                dlt.data as *const u8,
+                dtype.itemsize() * size,
+            ));
+            Self {
+                data: storage,
+                ctx: Context::default(),
+                dtype,
+                size,
+                shape,
+                strides: if dlt.strides.is_null() {
+                    None
+                } else {
+                    Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec())
+                },
+                byte_offset: dlt.byte_offset as isize,
+            }
+        }
+    }
+}
+
+/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
+///
+/// # Panics
+///
+/// Panics if the ndarray is not contiguous.
+macro_rules! impl_tensor_from_ndarray {
+    ($type:ty, $dtype_fn:expr) => {
+        impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
+            fn from(arr: ndarray::Array<$type, D>) -> Self {
+                let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous"));
+                Tensor::from_array_storage(&arr, storage.to_owned(), $dtype_fn)
+            }
+        }
+        impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
+            fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
+                let storage = Storage::from(arr.as_slice().expect("NDArray must be contiguous"));
+                Tensor::from_array_storage(arr, storage, $dtype_fn)
+            }
+        }
+    };
+}
+
+impl_tensor_from_ndarray!(f32, DataType::float);
+impl_tensor_from_ndarray!(f64, DataType::float);
+impl_tensor_from_ndarray!(i32, DataType::int);
+impl_tensor_from_ndarray!(i64, DataType::int);
+impl_tensor_from_ndarray!(u32, DataType::uint);
+impl_tensor_from_ndarray!(u64, DataType::uint);
diff --git a/rust/tvm-graph-rt/src/errors.rs b/rust/tvm-graph-rt/src/errors.rs
new file mode 100644 (file)
index 0000000..d82da15
--- /dev/null
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use failure::Fail;
+
+#[derive(Debug, Fail)]
+pub enum GraphFormatError {
+    #[fail(display = "Could not parse graph json")]
+    Parse(#[fail(cause)] failure::Error),
+    #[fail(display = "Could not parse graph params")]
+    Params,
+    #[fail(display = "{} is missing attr: {}", 0, 1)]
+    MissingAttr(String, String),
+    #[fail(display = "Missing field: {}", 0)]
+    MissingField(&'static str),
+    #[fail(display = "Invalid DLType: {}", 0)]
+    InvalidDLType(String),
+}
diff --git a/rust/tvm-graph-rt/src/graph.rs b/rust/tvm-graph-rt/src/graph.rs
new file mode 100644 (file)
index 0000000..895739d
--- /dev/null
@@ -0,0 +1,495 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
+
+use failure::{ensure, format_err, Error};
+use itertools::izip;
+use nom::{
+    character::complete::{alpha1, digit1},
+    complete, count, do_parse, length_count, map, named,
+    number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8},
+    opt, tag, take, tuple,
+};
+
+use serde::{Deserialize, Serialize};
+use serde_json;
+
+use tvm_sys::ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt};
+
+use tvm_sys::{ffi::DLTensor, ArgValue, Context, DataType, DeviceType};
+
+use crate::{errors::GraphFormatError, Module, Storage, Tensor};
+
+// @see `kTVMNDArrayMagic` in `ndarray.h`
+const _NDARRAY_MAGIC: u64 = 0xDD5E_40F0_96B4_A13F;
+// @see `kTVMNDArrayListMagic` in `graph_runtime.h`
+const _NDARRAY_LIST_MAGIC: u64 = 0xF7E5_8D4F_0504_9CB7;
+
+/// A TVM computation graph.
+///
+/// # Examples
+///
+/// ```norun
+/// let graph_json = fs::read_to_string("graph.json").unwrap();
+/// let graph = Graph::try_from(&graph_json).unwrap();
+/// ```
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Graph {
+    pub nodes: Vec<Node>,
+    pub arg_nodes: Vec<usize>,
+    pub heads: Vec<Entry>,
+    pub node_row_ptr: Option<Vec<usize>>,
+    pub attrs: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Entry {
+    pub id: usize,
+    pub index: usize,
+    pub version: usize,
+}
+
+impl Graph {
+    fn entry_index(&self, entry: &Entry) -> Result<usize, GraphFormatError> {
+        self.node_row_ptr
+            .as_ref()
+            .map(|nrp| nrp[entry.id] + entry.index)
+            .ok_or_else(|| GraphFormatError::MissingField("node_row_ptr"))
+    }
+
+    /// Attempt to deserialize a JSON attribute to a type `T`.
+    fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T, GraphFormatError> {
+        Ok(serde_json::from_value::<T>(
+            self.attrs
+                .as_ref()
+                .ok_or(GraphFormatError::MissingField("attrs"))?
+                .get(attr)
+                .ok_or_else(|| {
+                    GraphFormatError::MissingAttr("graph".to_string(), attr.to_string())
+                })?
+                .to_owned(),
+        )
+        .map_err(|err| GraphFormatError::Parse(err.into()))?)
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Node {
+    pub op: String,
+    pub name: String,
+    pub inputs: Vec<Entry>,
+    pub attrs: Option<HashMap<String, String>>,
+    pub control_deps: Option<Vec<Entry>>,
+}
+
+struct NodeAttrs {
+    func_name: String,
+    num_outputs: usize,
+    flatten_data: bool,
+}
+
+macro_rules! get_node_attr {
+    ($node:expr, $attrs:ident, $attr:literal) => {
+        $attrs
+            .get($attr)
+            .ok_or_else(|| GraphFormatError::MissingAttr($node.to_owned(), $attr.to_owned()))
+    };
+}
+
+impl Node {
+    fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
+        let attrs = self
+            .attrs
+            .as_ref()
+            .ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?;
+        Ok(NodeAttrs {
+            func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(),
+            num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::<usize>()?,
+            flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::<u8>()? == 1,
+        })
+    }
+}
+
+impl<'a> TryFrom<&'a String> for Graph {
+    type Error = Error;
+    fn try_from(graph_json: &String) -> Result<Self, self::Error> {
+        let graph = serde_json::from_str(graph_json)?;
+        Ok(graph)
+    }
+}
+
+impl<'a> TryFrom<&'a str> for Graph {
+    type Error = Error;
+    fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
+        let graph = serde_json::from_str(graph_json)?;
+        Ok(graph)
+    }
+}
+
+/// A executor for a TVM computation graph.
+///
+/// # Examples
+///
+/// ```norun
+/// use ndarray::Array;
+///
+/// let syslib = SystemLibModule::default(); // a provider of TVM functions
+///
+/// let mut params_bytes = Vec::new();
+/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
+/// let params = tvm::runtime::load_param_dict(&params_bytes).unwrap();
+///
+/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap();
+///
+/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
+/// exec.load_params(params);
+///
+/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// exec.set_input("data", x.into());
+/// exec.run();
+/// let output = exec.get_output(0).unwrap();
+///
+/// println!("{:#?}", Array::try_from(output).unwrap());
+/// ```
+pub struct GraphExecutor<'m, 't> {
+    graph: Graph,
+    op_execs: Vec<Box<dyn Fn() + 'm>>,
+    tensors: Vec<Tensor<'t>>,
+}
+
+unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
+
+impl<'m, 't> GraphExecutor<'m, 't> {
+    pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error> {
+        let tensors = Self::setup_storages(&graph)?;
+        Ok(GraphExecutor {
+            op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
+            tensors,
+            graph,
+        })
+    }
+
+    /// Runs the computation graph.
+    pub fn run(&mut self) {
+        self.op_execs.iter().for_each(|op_exec| {
+            op_exec();
+        });
+    }
+
+    /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
+    fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error> {
+        let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
+        let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
+        let dtypes = graph
+            .get_attr::<(String, Vec<String>)>("dltype")?
+            .1
+            .iter()
+            .map(|dltype| {
+                if let Ok((_, dtype)) = tvm_str_to_type(dltype) {
+                    Ok(dtype)
+                } else {
+                    Err(GraphFormatError::InvalidDLType(dltype.to_string()))
+                }
+            })
+            .collect::<Result<Vec<DataType>, GraphFormatError>>()?;
+
+        let align = dtypes.iter().map(|dtype| dtype.bits() as usize).max();
+        let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
+        for (i, &storage_id) in storage_ids.iter().enumerate() {
+            let dtype_size = (dtypes[i].bits() * dtypes[i].lanes()) >> 3;
+            let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
+            storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
+        }
+
+        let mut storages: Vec<Storage> = storage_num_bytes
+            .into_iter()
+            .map(|nbytes| Storage::new(nbytes, align))
+            .collect::<Result<Vec<Storage>, Error>>()?;
+
+        let tensors = izip!(storage_ids, shapes, dtypes)
+            .map(|(storage_id, shape, dtype)| {
+                let storage = storages[storage_id].view();
+                Tensor {
+                    data: mem::replace(&mut storages[storage_id], storage),
+                    ctx: Context::default(),
+                    dtype,
+                    size: shape.iter().product::<i64>() as usize,
+                    shape,
+                    strides: None,
+                    byte_offset: 0,
+                }
+            })
+            .collect();
+
+        Ok(tensors)
+    }
+
+    /// Creates closures which represent the computation performed by this graph.
+    fn setup_op_execs<M: 'm + Module>(
+        graph: &Graph,
+        lib: &'m M,
+        tensors: &[Tensor<'t>],
+    ) -> Result<Vec<Box<dyn Fn() + 'm>>, Error> {
+        ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
+        let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
+
+        let mut op_execs = Vec::new();
+        for (i, node) in graph.nodes.iter().enumerate() {
+            if node.op == "null" {
+                continue;
+            }
+            ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
+            ensure!(node.attrs.is_some(), "Missing node attrs.");
+
+            let attrs = node.parse_attrs()?;
+
+            if attrs.func_name == "__nop" {
+                continue;
+            }
+
+            let func = lib
+                .get_function(&attrs.func_name)
+                .ok_or_else(|| format_err!("Library is missing function {}", attrs.func_name))?;
+            let arg_indices = node
+                .inputs
+                .iter()
+                .map(|entry| graph.entry_index(entry))
+                .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i] + oi)));
+
+            let dl_tensors = arg_indices
+                .map(|idx| {
+                    let tensor = &tensors[idx?];
+                    Ok(if attrs.flatten_data {
+                        Tensor::as_dltensor(tensor, true /* flatten */)
+                    } else {
+                        DLTensor::from(tensor)
+                    })
+                })
+                .collect::<Result<Vec<DLTensor>, Error>>()
+                .unwrap();
+            let op: Box<dyn Fn()> = Box::new(move || {
+                let args = dl_tensors
+                    .iter()
+                    .map(|t| t.into())
+                    .collect::<Vec<ArgValue>>();
+                func(&args).unwrap();
+            });
+            op_execs.push(op);
+        }
+        Ok(op_execs)
+    }
+
+    pub fn load_params(&mut self, params: HashMap<String, Tensor>) {
+        params.into_iter().for_each(|(name, param)| {
+            self.set_input(name, param);
+        })
+    }
+
+    #[allow(clippy::if_same_then_else)]
+    pub fn set_input<S: AsRef<str>>(&mut self, name: S, value: Tensor) {
+        if let Some(idx) = self.get_input_index(name.as_ref()) {
+            // TODO: consider `new_with_params` to avoid ever allocating
+            let ptr = self.tensors[idx].data.as_ptr();
+            let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr);
+            let owner = to_replace.nth(0).unwrap();
+            if value.data.is_owned() {
+                // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
+                // mem::replace(&mut (*owner), value);
+                // to_replace.for_each(|t| {
+                //   panic!("replacing");
+                //   t.data = owner.data.view();
+                // });
+                owner.copy(&value);
+            } else {
+                owner.copy(&value);
+            }
+        } else {
+            println!("Unexpected input `{}`", name.as_ref());
+        }
+    }
+
+    /// Returns the graph input with name `name`, if it exists.
+    pub fn get_input<S: AsRef<str>>(&mut self, name: S) -> Option<&Tensor> {
+        self.get_input_index(name.as_ref())
+            .map(move |idx| &self.tensors[idx])
+    }
+
+    /// Returns the graph output with index `index`, if it exists.
+    pub fn get_output(&self, idx: usize) -> Option<&Tensor> {
+        let graph = &self.graph;
+        graph.heads.get(idx).and_then(|entry| {
+            graph
+                .entry_index(entry)
+                .map(|idx| self.tensors.get(idx))
+                .unwrap_or(None)
+        })
+    }
+
+    /// Returns the index for graph input with name `name`, if it exists.
+    pub fn get_input_index<S: AsRef<str>>(&self, name: S) -> Option<usize> {
+        let graph = &self.graph;
+        (0..graph.nodes.len())
+            .skip_while(|&i| graph.nodes[i].name != name.as_ref())
+            .nth(0)
+            .and_then(|i| {
+                if graph.arg_nodes.iter().any(|&id| id == i) {
+                    graph.node_row_ptr.as_ref().map(|nrp| nrp[i])
+                } else {
+                    None
+                }
+            })
+    }
+}
+
+// Converts a string to TVM DLDataTypeCode. @see `String2DLDataType` in packed_func.h
+named! {
+  tvm_str_to_type<&str, DataType>,
+  do_parse!(
+    type_name: alpha1 >>
+    bits:      digit1 >>
+    lanes:     opt!(complete!(tuple!(tag!("x"), digit1))) >>
+    (
+        DataType::new(
+            match type_name {
+                "int" => DLDataTypeCode_kDLInt,
+                "uint" => DLDataTypeCode_kDLUInt,
+                "float" => DLDataTypeCode_kDLFloat,
+                _ => DLDataTypeCode_kDLFloat,
+            } as u8,
+            bits.parse::<u8>().unwrap() as u8,
+            lanes
+                .map(|(_, lanes)| lanes.parse::<u16>().unwrap() as u16)
+                .unwrap_or(1),
+        )
+    )
+  )
+}
+
+// Converts a bytes to String.
+named! {
+    name<String>,
+    do_parse!(
+        len_l: le_u32 >>
+        len_h: le_u32 >>
+        data: take!(len_l) >>
+        (
+            if len_h == 0 {
+                String::from_utf8(data.to_vec()).unwrap()
+            } else {
+                panic!("Too long string")
+            }
+        )
+    )
+}
+
+// Parses a Context
+named! {
+  tvm_ctx<&[u8], Context>,
+  do_parse!(
+    device_type: le_u32 >>
+    device_id:   le_i32 >>
+    (
+        Context {
+            device_type: DeviceType::from(device_type),
+            device_id: device_id as usize,
+        }
+    )
+  )
+}
+
+// Parses a DataType
+named! {
+  data_type<&[u8], DataType>,
+  do_parse!(
+    code:  le_u8  >>
+    bits:  le_u8  >>
+    lanes: le_u16 >>
+    (DataType::new(code, bits, lanes)))
+}
+
+// Parses a Tensor from a TVM array file.
+named! {
+    tensor<Tensor>,
+    do_parse!(
+                take!(8)      >>
+                le_u64        >>
+        ctx:    tvm_ctx       >>
+        ndim:   le_u32        >>
+        dtype:  data_type     >>
+        shape:  count!(map!(le_i64, |sz| sz as i64), ndim as usize) >>
+        length: le_i64        >>
+        data:   take!(length) >>
+        (
+            Tensor {
+                data: Storage::from(data),
+                ctx: ctx,
+                dtype: dtype,
+                size: shape.iter().product::<i64>() as usize,
+                shape: shape,
+                strides: None,
+                byte_offset: 0,
+            }
+        )
+    )
+}
+
+// Parses a graph params dict from a params binary file.
+named! {
+    parse_param_dict<HashMap<String, Tensor>>,
+    do_parse!(
+                 take!(8)                      >>
+                 le_u64                        >>
+        names:   length_count!(le_u64, name)   >>
+        tensors: length_count!(le_u64, tensor) >>
+        (
+            HashMap::from_iter(names.into_iter().zip(tensors.into_iter()))
+        )
+    )
+}
+
+/// Loads a param dict saved using `relay.save_param_dict`.
+pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>, GraphFormatError> {
+    if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
+        if remaining_bytes.is_empty() {
+            Ok(param_dict)
+        } else {
+            Err(GraphFormatError::Params)
+        }
+    } else {
+        Err(GraphFormatError::Params)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_str_to_type() {
+        assert_eq!(
+            tvm_str_to_type("float24").unwrap().1,
+            DataType::float(24, 1)
+        );
+        assert_eq!(
+            tvm_str_to_type("uint111x44").unwrap().1,
+            DataType::uint(111, 44)
+        );
+    }
+}
diff --git a/rust/tvm-graph-rt/src/lib.rs b/rust/tvm-graph-rt/src/lib.rs
new file mode 100644 (file)
index 0000000..0e3db52
--- /dev/null
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`.
+//! It's mainly useful for compiling to WebAssembly and SGX,
+//! but also native if you prefer Rust to C++.
+//!
+//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`.
+//! Single-function modules are used via the `packed_func!` macro after obtaining
+//! the function from `runtime::SystemLibModule`
+//!
+//! The main entrypoints to this crate are `GraphExecutor`
+//! For examples of use, please refer to the multi-file tests in the `tests` directory.
+
+use lazy_static::lazy_static;
+
+mod allocator;
+mod array;
+pub mod errors;
+mod graph;
+mod module;
+mod threading;
+mod workspace;
+
+pub use tvm_macros::import_module;
+pub use tvm_sys::{
+    call_packed,
+    errors::*,
+    ffi::{self, DLTensor},
+    packed_func::{self, *},
+    ArgValue, RetValue,
+};
+
+pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*};
+
+lazy_static! {
+    static ref LAST_ERROR: std::sync::RwLock<Option<&'static std::ffi::CStr>> =
+        std::sync::RwLock::new(None);
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
+    *LAST_ERROR.write().unwrap() = Some(std::ffi::CStr::from_ptr(cmsg));
+}
+
+#[no_mangle]
+pub extern "C" fn TVMGetLastError() -> *const std::os::raw::c_char {
+    match *LAST_ERROR.read().unwrap() {
+        Some(err) => err.as_ptr(),
+        None => std::ptr::null(),
+    }
+}
diff --git a/rust/tvm-graph-rt/src/module/dso.rs b/rust/tvm-graph-rt/src/module/dso.rs
new file mode 100644 (file)
index 0000000..51645d5
--- /dev/null
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+    cell::RefCell,
+    collections::HashMap,
+    ffi::CStr,
+    os::raw::{c_char, c_int, c_void},
+    pin::Pin,
+};
+
+use tvm_sys::{ffi::BackendPackedCFunc, packed_func::PackedFunc};
+
+use crate::{
+    threading::{TVMBackendParallelBarrier, TVMBackendParallelLaunch},
+    workspace::{TVMBackendAllocWorkspace, TVMBackendFreeWorkspace},
+    TVMAPISetLastError,
+};
+
+use super::Module;
+
+const TVM_MAIN: &[u8] = b"__tvm_main__";
+const TVM_MODULE_CTX: &[u8] = b"__tvm_module_ctx";
+
+/// A module backed by a Dynamic Shared Object (dylib).
+pub struct DsoModule<'a> {
+    lib: libloading::Library,
+    packed_funcs: RefCell<HashMap<String, &'a (dyn PackedFunc)>>,
+    _pin: std::marker::PhantomPinned,
+}
+
+macro_rules! init_context_func {
+    ($lib:ident, $( ($fn:ident, $sig:ty) ),+ $(,)?) => {
+        unsafe {
+            $(
+                let fn_ptr = $lib.get::<*mut $sig>(concat!("__", stringify!($fn)).as_bytes());
+                if let Ok(fn_ptr) = fn_ptr {
+                    **fn_ptr = $fn;
+                }
+            )+
+        }
+    };
+}
+
+impl<'a> DsoModule<'a> {
+    pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, failure::Error> {
+        let lib = libloading::Library::new(filename)?;
+
+        init_context_func!(
+            lib,
+            (TVMAPISetLastError, unsafe extern "C" fn(*const i8)),
+            (
+                TVMBackendAllocWorkspace,
+                unsafe extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void
+            ),
+            (
+                TVMBackendFreeWorkspace,
+                unsafe extern "C" fn(c_int, c_int, *mut c_void) -> c_int
+            ),
+            (
+                TVMBackendParallelLaunch,
+                unsafe extern "C" fn(
+                    crate::threading::FTVMParallelLambda,
+                    *const c_void,
+                    usize,
+                ) -> c_int
+            ),
+            (
+                TVMBackendParallelBarrier,
+                unsafe extern "C" fn(usize, *const tvm_sys::ffi::TVMParallelGroupEnv)
+            ),
+        );
+
+        // Pin the module in memory so that `ctx` pointer (below) is stable.
+        let dso_mod = Box::pin(Self {
+            lib,
+            packed_funcs: RefCell::new(HashMap::new()),
+            _pin: std::marker::PhantomPinned,
+        });
+
+        unsafe {
+            if let Ok(ctx) = dso_mod.lib.get::<*mut *const c_void>(TVM_MODULE_CTX) {
+                **ctx = &dso_mod as *const _ as *const c_void;
+            }
+        }
+
+        Ok(dso_mod)
+    }
+}
+
+impl<'a> Module for DsoModule<'a> {
+    fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
+        let name = name.as_ref();
+        let func = match unsafe {
+            self.lib
+                .get::<BackendPackedCFunc>(if name.as_bytes() == TVM_MAIN {
+                    // If __tvm_main__ is present, it contains the name of the
+                    // actual main function.
+                    match self
+                        .lib
+                        .get::<*const c_char>(TVM_MAIN)
+                        .map(|p| CStr::from_ptr(*p))
+                    {
+                        Ok(m) => m.to_bytes(),
+                        _ => return None,
+                    }
+                } else {
+                    name.as_bytes()
+                })
+        } {
+            Ok(func) => unsafe { func.into_raw() },
+            Err(_) => return None,
+        };
+
+        self.packed_funcs.borrow_mut().insert(
+            name.to_string(),
+            &*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)),
+        );
+
+        self.packed_funcs.borrow().get(name).copied()
+    }
+}
+
+impl<'a> Drop for DsoModule<'a> {
+    fn drop(&mut self) {
+        self.packed_funcs
+            .replace(HashMap::new())
+            .into_iter()
+            .map(|(_name, f)| unsafe { Box::from_raw(f as *const _ as *mut (dyn PackedFunc)) })
+            .for_each(std::mem::drop);
+    }
+}
diff --git a/rust/tvm-graph-rt/src/module/mod.rs b/rust/tvm-graph-rt/src/module/mod.rs
new file mode 100644 (file)
index 0000000..511ba4b
--- /dev/null
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))]
+mod dso;
+mod syslib;
+
+use tvm_sys::{
+    ffi::BackendPackedCFunc,
+    packed_func::{ArgValue, PackedFunc, RetValue, TVMValue},
+};
+
+#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))]
+pub use dso::DsoModule;
+pub use syslib::SystemLibModule;
+
+pub trait Module {
+    fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
+}
+
+// @see `WrapPackedFunc` in `llvm_module.cc`.
+fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<dyn PackedFunc> {
+    Box::new(move |args: &[ArgValue]| {
+        let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
+            .iter()
+            .map(|arg| {
+                let (val, code) = arg.to_tvm_value();
+                (val, code as i32)
+            })
+            .unzip();
+        let ret: RetValue = RetValue::default();
+        let (mut ret_val, mut ret_type_code) = ret.to_tvm_value();
+        let exit_code = func(
+            values.as_ptr(),
+            type_codes.as_ptr(),
+            values.len() as i32,
+            &mut ret_val,
+            &mut ret_type_code,
+        );
+        if exit_code == 0 {
+            Ok(RetValue::from_tvm_value(ret_val, ret_type_code))
+        } else {
+            Err(tvm_sys::errors::FuncCallError::get_with_context(
+                func_name.clone(),
+            ))
+        }
+    })
+}
diff --git a/rust/tvm-graph-rt/src/module/syslib.rs b/rust/tvm-graph-rt/src/module/syslib.rs
new file mode 100644 (file)
index 0000000..0279e31
--- /dev/null
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+    collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
+};
+
+use lazy_static::lazy_static;
+
+use tvm_sys::{ffi::BackendPackedCFunc, packed_func::PackedFunc};
+
+use super::Module;
+
+pub struct SystemLibModule;
+
+#[cfg(target_env = "sgx")]
+extern "C" {
+    fn __tvm_module_startup();
+}
+
+lazy_static! {
+    static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, &'static (dyn PackedFunc)>> =
+        Mutex::new(HashMap::new());
+}
+
+impl Module for SystemLibModule {
+    fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
+        SYSTEM_LIB_FUNCTIONS
+            .lock()
+            .unwrap()
+            .get(name.as_ref())
+            .copied()
+    }
+}
+
+impl Default for SystemLibModule {
+    fn default() -> Self {
+        #[cfg(target_env = "sgx")]
+        unsafe {
+            __tvm_module_startup();
+        }
+        SystemLibModule {}
+    }
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
+    cname: *const c_char,
+    func: BackendPackedCFunc,
+) -> i32 {
+    let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
+    SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
+        name.to_string(),
+        &*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)),
+    );
+    0
+}
diff --git a/rust/tvm-graph-rt/src/threading.rs b/rust/tvm-graph-rt/src/threading.rs
new file mode 100644 (file)
index 0000000..bda53a8
--- /dev/null
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+    os::raw::{c_int, c_void},
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc, Barrier,
+    },
+    thread::{self, JoinHandle},
+};
+
+#[cfg(not(target_arch = "wasm32"))]
+use std::env;
+
+use crossbeam::channel::{bounded, Receiver, Sender};
+use tvm_sys::ffi::TVMParallelGroupEnv;
+
+pub(crate) type FTVMParallelLambda =
+    extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
+
+/// Holds a parallel job request made by a TVM library function.
+struct Job {
+    cb: FTVMParallelLambda,
+    cdata: *const c_void,
+    req_num_tasks: usize,
+    pending: Arc<AtomicUsize>,
+}
+
+impl Job {
+    /// Splits this job into a number of `Task`s which can be scheduled.
+    fn tasks(&self, num_workers: usize) -> Vec<Task> {
+        let num_tasks = if self.req_num_tasks == 0 {
+            num_workers
+        } else {
+            self.req_num_tasks.min(num_workers)
+        };
+        self.pending.store(num_tasks, Ordering::SeqCst);
+
+        let barrier = Arc::new(Barrier::new(num_tasks));
+
+        (0..num_tasks)
+            .map(move |i| Task {
+                id: i,
+                flambda: self.cb,
+                penv: TVMParallelGroupEnv {
+                    sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void,
+                    num_task: num_tasks as i32,
+                },
+                cdata: self.cdata,
+                pending: Arc::clone(&self.pending),
+            })
+            .collect()
+    }
+
+    /// Waits for all tasks in this `Job` to be completed.
+    fn wait(&self) {
+        while self.pending.load(Ordering::Acquire) > 0 {
+            thread::yield_now();
+        }
+    }
+}
+
+/// A chunk of work requested by a TVM function.
+struct Task {
+    id: usize,
+    flambda: FTVMParallelLambda,
+    penv: TVMParallelGroupEnv,
+    cdata: *const c_void,
+    pending: Arc<AtomicUsize>,
+}
+unsafe impl Send for Task {}
+unsafe impl Sync for Task {}
+
+impl Task {
+    fn run(self) -> i32 {
+        let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata);
+        self.pending.fetch_sub(1, Ordering::AcqRel);
+        status
+    }
+}
+
+#[derive(Default)]
+struct Threads {
+    #[allow(unused)]
+    handles: Vec<JoinHandle<()>>,
+    queues: Vec<Sender<Task>>,
+}
+
+impl<'a> Threads {
+    fn launch<F: Sync + Send + FnOnce(Receiver<Task>) + 'static + Copy>(
+        num_threads: usize,
+        cb: F,
+    ) -> Self {
+        let (handles, queues) = (0..num_threads)
+            .map(|_| {
+                let (p, c) = bounded(2);
+                let handle = thread::spawn(move || cb(c.into()));
+                (handle, p)
+            })
+            .unzip();
+        Threads { handles, queues }
+    }
+}
+
+struct ThreadPool {
+    num_workers: usize,
+    #[allow(unused)]
+    threads: Threads,
+}
+
+thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new());
+
+impl ThreadPool {
+    fn new() -> Self {
+        let num_workers = max_concurrency();
+        ThreadPool {
+            num_workers,
+            threads: Threads::launch(num_workers, ThreadPool::run_worker),
+        }
+    }
+
+    fn launch(&self, job: Job) {
+        let mut tasks = job.tasks(self.num_workers + 1);
+
+        for (i, task) in tasks.split_off(1).into_iter().enumerate() {
+            self.threads.queues[i].send(task).expect("should send");
+        }
+
+        tasks.pop().unwrap().run();
+        job.wait();
+    }
+
+    fn run_worker(queue: Receiver<Task>) {
+        loop {
+            let task = match queue.recv() {
+                Ok(v) => v,
+                Err(_) => break,
+            };
+            let result = task.run();
+            if result == <i32>::min_value() {
+                break;
+            } else if result != 0 {
+                panic!("Error running task.");
+            }
+        }
+    }
+}
+
+#[cfg(not(target_arch = "wasm32"))]
+fn max_concurrency() -> usize {
+    if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or_else(|_| env::var("OMP_NUM_THREADS")) {
+        if let Ok(threads) = usize::from_str_radix(&threads_str, 10) {
+            return threads;
+        }
+    }
+    num_cpus::get()
+}
+
+#[cfg(target_arch = "wasm32")]
+fn max_concurrency() -> usize {
+    0 // wasm doesn't support threads yet
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendParallelLaunch(
+    cb: FTVMParallelLambda,
+    cdata: *const c_void,
+    num_task: usize,
+) -> c_int {
+    if max_concurrency() < 2 {
+        let penv = TVMParallelGroupEnv {
+            sync_handle: std::ptr::null_mut(),
+            num_task: 1,
+        };
+        cb(0, &penv as *const _, cdata);
+    } else {
+        THREAD_POOL.with(|pool| {
+            pool.launch(Job {
+                cb,
+                cdata,
+                req_num_tasks: num_task,
+                pending: Arc::new(AtomicUsize::new(0)),
+            });
+        });
+    }
+    0
+}
+
+// @see issue 988 for information on why this function is used.
+#[no_mangle]
+pub unsafe extern "C" fn TVMBackendParallelBarrier(
+    _task_id: usize,
+    penv: *const TVMParallelGroupEnv,
+) {
+    let barrier: &Arc<Barrier> = &*((*penv).sync_handle as *const Arc<Barrier>);
+    barrier.wait();
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{ptr, thread, time::Duration};
+
+    use super::*;
+
+    #[test]
+    fn test_max_concurrency() {
+        env::set_var("TVM_NUM_THREADS", "42");
+        env::set_var("OMP_NUM_THREADS", "24");
+        assert_eq!(max_concurrency(), 42);
+        env::remove_var("TVM_NUM_THREADS");
+        assert_eq!(max_concurrency(), 24);
+    }
+
+    extern "C" fn flambda(
+        task_id: usize,
+        penv: *const TVMParallelGroupEnv,
+        cdata: *const c_void,
+    ) -> i32 {
+        if cdata.is_null() {
+            return 0;
+        }
+        unsafe {
+            let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize));
+            thread::sleep(Duration::from_millis(50 * task_id as u64));
+            counter.fetch_add(1, Ordering::SeqCst);
+            task_ids_sum.fetch_add(task_id, Ordering::SeqCst);
+            assert_eq!((*penv).num_task, 3);
+        }
+        0
+    }
+
+    #[test]
+    fn test_parallel_launch() {
+        TVMBackendParallelLaunch(flambda, ptr::null(), 6);
+        let counter = AtomicUsize::new(0);
+        let task_ids_sum = AtomicUsize::new(0);
+        let cdata = (counter, task_ids_sum);
+        let num_tasks = 3;
+        TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
+        assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks);
+        assert_eq!(
+            cdata.1.load(Ordering::SeqCst),
+            (0..num_tasks).sum::<usize>()
+        );
+    }
+}
diff --git a/rust/tvm-graph-rt/src/workspace.rs b/rust/tvm-graph-rt/src/workspace.rs
new file mode 100644 (file)
index 0000000..35cfe91
--- /dev/null
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+    cell::RefCell,
+    os::raw::{c_int, c_void},
+    ptr,
+};
+
+use failure::{format_err, Error};
+
+use crate::allocator::Allocation;
+
+const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
+
+pub fn remove_item<T: PartialEq>(vec: &mut Vec<T>, item: &T) -> Option<T> {
+    let pos = vec.iter().position(|x| *x == *item)?;
+    Some(vec.remove(pos))
+}
+
+struct WorkspacePool {
+    workspaces: Vec<Allocation>,
+    free: Vec<usize>,
+    in_use: Vec<usize>,
+}
+
+impl WorkspacePool {
+    fn new() -> Self {
+        WorkspacePool {
+            workspaces: Vec::new(),
+            free: Vec::new(),
+            in_use: Vec::new(),
+        }
+    }
+
+    fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> {
+        self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
+        self.in_use.push(self.workspaces.len() - 1);
+        Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
+    }
+
+    fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
+        if self.free.is_empty() {
+            return self.alloc_new(size);
+        }
+        let idx = self
+            .free
+            .iter()
+            .fold(None, |cur_ws_idx: Option<usize>, &idx| {
+                let ws_size = self.workspaces[idx].size();
+                if ws_size < size {
+                    return cur_ws_idx;
+                }
+                cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
+                    let cur_size = self.workspaces[cur_idx].size();
+                    Some(if ws_size <= cur_size { idx } else { cur_idx })
+                })
+            });
+        match idx {
+            Some(idx) => {
+                remove_item(&mut self.free, &idx).unwrap();
+                self.in_use.push(idx);
+                Ok(self.workspaces[idx].as_mut_ptr())
+            }
+            None => self.alloc_new(size),
+        }
+    }
+
+    fn free(&mut self, ptr: *mut u8) -> Result<(), Error> {
+        let mut ws_idx = None;
+        for i in 0..self.in_use.len() {
+            let idx = self.in_use[i];
+            if self.workspaces[idx].as_mut_ptr() == ptr {
+                self.in_use.remove(i);
+                ws_idx = Some(idx);
+                break;
+            }
+        }
+        let ws_idx = ws_idx.ok_or_else(|| format_err!("Invalid pointer"))?;
+        self.free.push(ws_idx);
+        Ok(())
+    }
+}
+
+thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new()));
+
+const WORKSPACE_PAGE_SIZE: usize = 4 << 10;
+
+#[no_mangle]
+pub extern "C" fn TVMBackendAllocWorkspace(
+    _device_type: c_int,
+    _device_id: c_int,
+    size: u64,
+    _dtype_code_hint: c_int,
+    _dtype_bits_hint: c_int,
+) -> *mut c_void {
+    let nbytes = if size == 0 {
+        WORKSPACE_PAGE_SIZE
+    } else {
+        size as usize
+    };
+    WORKSPACE_POOL.with(|pool_cell| {
+        pool_cell
+            .borrow_mut()
+            .alloc(nbytes as usize)
+            .unwrap_or(ptr::null_mut()) as *mut c_void
+    })
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendFreeWorkspace(
+    _device_type: c_int,
+    _device_id: c_int,
+    ptr: *mut c_void,
+) -> c_int {
+    WORKSPACE_POOL.with(|pool_cell| {
+        (match pool_cell.borrow_mut().free(ptr as *mut u8) {
+            Ok(()) => 0,
+            Err(_) => -1,
+        }) as c_int
+    })
+}
diff --git a/rust/tvm-graph-rt/tests/.gitignore b/rust/tvm-graph-rt/tests/.gitignore
new file mode 100644 (file)
index 0000000..8110767
--- /dev/null
@@ -0,0 +1,3 @@
+*.json
+*.params
+*.o
diff --git a/rust/tvm-graph-rt/tests/build_model.py b/rust/tvm-graph-rt/tests/build_model.py
new file mode 100755 (executable)
index 0000000..ddfa03b
--- /dev/null
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Builds a simple graph for testing."""
+
+from os import path as osp
+
+import numpy as np
+import tvm
+from tvm import te
+from tvm import relay
+from tvm.relay import testing
+
+CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
+
+def _get_model(dshape):
+    data = relay.var('data', shape=dshape)
+    fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2)
+    fc = relay.nn.bias_add(fc, relay.var("dense_bias"))
+    left, right = relay.split(fc, indices_or_sections=2, axis=1)
+    one = relay.const(1, dtype="float32")
+    return relay.Tuple([(left + one), (right - one), fc])
+
+
+def main():
+    dshape = (32, 16)
+    net = _get_model(dshape)
+    mod, params = testing.create_workload(net)
+    graph, lib, params = relay.build(
+        mod, 'llvm', params=params)
+
+    with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet:
+        f_resnet.write(graph)
+    with open(osp.join(CWD, 'graph.params'), 'wb') as f_params:
+        f_params.write(relay.save_param_dict(params))
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/tvm-graph-rt/tests/test_graph_serde.rs b/rust/tvm-graph-rt/tests/test_graph_serde.rs
new file mode 100644 (file)
index 0000000..6cea4ad
--- /dev/null
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+extern crate serde;
+extern crate serde_json;
+
+extern crate tvm_runtime;
+
+use std::{convert::TryFrom, fs, io::Read};
+
+use tvm_runtime::Graph;
+
+macro_rules! mf_dir {
+    ($p:literal) => {
+        concat!(env!("CARGO_MANIFEST_DIR"), $p)
+    };
+}
+
+static PARAMS_FIXTURE_PATH: &str = mf_dir!("/tests/graph.params");
+
+#[test]
+fn test_load_graph() {
+    let output = std::process::Command::new(mf_dir!("/tests/build_model.py"))
+        .env(
+            "PYTHONPATH",
+            concat!(
+                mf_dir!("/../../python"),
+                ":",
+                mf_dir!("/../../nnvm/python"),
+                ":",
+                mf_dir!("/../../topi/python")
+            ),
+        )
+        .output()
+        .expect("Failed to build test model");
+    assert!(
+        std::path::Path::new(PARAMS_FIXTURE_PATH).exists(),
+        "Could not build test graph fixture: STDOUT:\n\n{}\nSTDERR: {}\n\n",
+        String::from_utf8(output.stdout).unwrap(),
+        String::from_utf8(output.stderr).unwrap()
+    );
+    let mut params_bytes = Vec::new();
+    fs::File::open(PARAMS_FIXTURE_PATH)
+        .unwrap()
+        .read_to_end(&mut params_bytes)
+        .unwrap();
+    let _params = tvm_runtime::load_param_dict(&params_bytes);
+
+    let graph = Graph::try_from(
+        &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(),
+    )
+    .unwrap();
+
+    assert_eq!(graph.nodes[3].op, "tvm_op");
+    assert_eq!(
+        graph.nodes[3]
+            .attrs
+            .as_ref()
+            .unwrap()
+            .get("func_name")
+            .unwrap(),
+        "fused_nn_dense_nn_bias_add"
+    );
+    assert_eq!(graph.nodes[3].inputs[0].index, 0);
+    assert_eq!(graph.nodes[4].inputs[0].index, 0);
+    assert_eq!(graph.heads.len(), 3);
+}
diff --git a/rust/tvm-graph-rt/tests/test_nn/Cargo.toml b/rust/tvm-graph-rt/tests/test_nn/Cargo.toml
new file mode 100644 (file)
index 0000000..158f9e2
--- /dev/null
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "test-rt-nn"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["TVM Contributors"]
+
+[dependencies]
+ndarray="0.12"
+serde = "1.0"
+serde_json = "1.0"
+tvm-graph-rt = { path = "../../" }
+
+[build-dependencies]
+ar = "0.6"
diff --git a/rust/tvm-graph-rt/tests/test_nn/build.rs b/rust/tvm-graph-rt/tests/test_nn/build.rs
new file mode 100644 (file)
index 0000000..8ae1131
--- /dev/null
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+extern crate ar;
+
+use std::{env, fs::File, path::Path, process::Command};
+
+use ar::Builder;
+
+fn main() {
+    let out_dir = env::var("OUT_DIR").unwrap();
+    let out_dir = Path::new(&out_dir).join("test_nn");
+
+    std::fs::create_dir_all(&out_dir).unwrap();
+
+    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
+    let manifest_dir = Path::new(&manifest_dir);
+
+    let generator = manifest_dir.join("src").join("build_test_graph.py");
+
+    let graph_path = out_dir.join("graph.o");
+
+    let output = Command::new(&generator)
+        .arg(&out_dir)
+        .output()
+        .expect("Failed to execute command");
+
+    assert!(
+        graph_path.exists(),
+        "Could not build graph lib: {}",
+        String::from_utf8(output.stderr)
+            .unwrap()
+            .trim()
+            .split("\n")
+            .last()
+            .unwrap_or("")
+    );
+
+    let lib_file = out_dir.join("libtestnn.a");
+    let file = File::create(&lib_file).unwrap();
+    let mut builder = Builder::new(file);
+    builder.append_path(graph_path).unwrap();
+
+    let status = Command::new("ranlib")
+        .arg(&lib_file)
+        .status()
+        .expect("fdjlksafjdsa");
+
+    assert!(status.success());
+
+    println!("cargo:rustc-link-lib=static=testnn");
+    println!("cargo:rustc-link-search=native={}", out_dir.display());
+    println!("cargo:rerun-if-changed={}", generator.display());
+}
diff --git a/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py b/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py
new file mode 100755 (executable)
index 0000000..cb7c4f7
--- /dev/null
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Builds a simple graph for testing."""
+
+from os import path as osp
+import sys
+
+import numpy as np
+import tvm
+from tvm import te
+from tvm import relay
+from tvm.relay import testing
+
+
+def _get_model(dshape):
+    data = relay.var('data', shape=dshape)
+    fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2)
+    fc = relay.nn.bias_add(fc, relay.var("dense_bias"))
+    left, right = relay.split(fc, indices_or_sections=2, axis=1)
+    one = relay.const(1, dtype="float32")
+    return relay.Tuple([(left + one), (right - one), fc])
+
+def main():
+    dshape = (4, 8)
+    net = _get_model(dshape)
+    mod, params = testing.create_workload(net)
+    graph, lib, params = relay.build(
+        mod, 'llvm --system-lib', params=params)
+
+    out_dir = sys.argv[1]
+    lib.save(osp.join(sys.argv[1], 'graph.o'))
+    with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet:
+        f_resnet.write(graph)
+
+    with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params:
+        f_params.write(relay.save_param_dict(params))
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/tvm-graph-rt/tests/test_nn/src/main.rs b/rust/tvm-graph-rt/tests/test_nn/src/main.rs
new file mode 100644 (file)
index 0000000..505c544
--- /dev/null
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#[macro_use]
+extern crate ndarray;
+extern crate serde;
+extern crate serde_json;
+
+extern crate tvm_runtime;
+use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
+
+use ndarray::Array;
+use tvm_runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
+
+const BATCH_SIZE: usize = 4;
+const IN_DIM: usize = 8;
+
+macro_rules! check_sum {
+    ($e:expr, $a:ident, $b:ident) => {
+        let a = Array::try_from($e.get_input(stringify!($a)).unwrap().to_owned()).unwrap();
+        check_sum!(a, $b);
+    };
+    ($e:expr, $a:expr, $b:ident) => {
+        let a = Array::try_from($e.get_output($a).unwrap().to_owned()).unwrap();
+        check_sum!(a, $b);
+    };
+    ($a:ident, $b:ident) => {
+        let a_sum: f32 = $a.scalar_sum();
+        let b_sum: f32 = $b.scalar_sum();
+        assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum);
+    };
+}
+
+fn main() {
+    let syslib = SystemLibModule::default();
+
+    let mut params_bytes = Vec::new();
+    fs::File::open(concat!(env!("OUT_DIR"), "/test_nn/graph.params"))
+        .unwrap()
+        .read_to_end(&mut params_bytes)
+        .unwrap();
+    let params = tvm_runtime::load_param_dict(&params_bytes)
+        .unwrap()
+        .into_iter()
+        .map(|(k, v)| (k, v.to_owned()))
+        .collect::<HashMap<String, Tensor<'static>>>();
+
+    let graph = Graph::try_from(
+        &fs::read_to_string(concat!(env!("OUT_DIR"), "/test_nn/graph.json")).unwrap(),
+    )
+    .unwrap();
+    let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
+
+    let x = Array::from_shape_vec(
+        (BATCH_SIZE, IN_DIM),
+        (0..BATCH_SIZE * IN_DIM)
+            .map(|x| x as f32)
+            .collect::<Vec<f32>>(),
+    )
+    .unwrap();
+
+    let p0 = params.get("p0").unwrap().to_owned();
+    let p1 = params.get("p1").unwrap().to_owned();
+    println!("p0: {:?}", p0.shape());
+    println!("p1: {:?}", p1.shape());
+    let w = Array::try_from(p0)
+        .unwrap()
+        .into_shape((BATCH_SIZE * 4, IN_DIM))
+        .unwrap();
+    let b = Array::try_from(p1).unwrap();
+    let dense = x.dot(&w.t()) + &b;
+    let left = dense.slice(s![.., 0..IN_DIM]);
+    let right = dense.slice(s![.., IN_DIM..]);
+    let expected_o0 = &left + 1f32;
+    let expected_o1 = &right - 1f32;
+
+    exec.load_params(params);
+    exec.set_input("data", (&x).into());
+
+    check_sum!(exec, data, x);
+    check_sum!(exec, p0, w);
+    check_sum!(exec, p1, b);
+
+    exec.run();
+
+    check_sum!(exec, 0, expected_o0);
+    check_sum!(exec, 1, expected_o1);
+    check_sum!(exec, 2, dense);
+}
diff --git a/rust/tvm-graph-rt/tests/test_tvm_basic/Cargo.toml b/rust/tvm-graph-rt/tests/test_tvm_basic/Cargo.toml
new file mode 100644 (file)
index 0000000..c1e87ef
--- /dev/null
@@ -0,0 +1,29 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "test-rt-tvm-basic"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["TVM Contributors"]
+
+[dependencies]
+ndarray="0.12"
+tvm-graph-rt = { path = "../../" }
+
+[build-dependencies]
+ar = "0.6"
diff --git a/rust/tvm-graph-rt/tests/test_tvm_basic/build.rs b/rust/tvm-graph-rt/tests/test_tvm_basic/build.rs
new file mode 100644 (file)
index 0000000..ade9e02
--- /dev/null
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+extern crate ar;
+
+use std::{path::PathBuf, process::Command};
+
+use ar::Builder;
+use std::fs::File;
+
+fn main() {
+    let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
+    out_dir.push("lib");
+
+    if !out_dir.is_dir() {
+        std::fs::create_dir(&out_dir).unwrap();
+    }
+
+    let obj_file = out_dir.join("test.o");
+    let lib_file = out_dir.join("libtest_basic.a");
+
+    let output = Command::new(concat!(
+        env!("CARGO_MANIFEST_DIR"),
+        "/src/build_test_lib.py"
+    ))
+    .arg(&out_dir)
+    .output()
+    .expect("Failed to execute command");
+    assert!(
+        obj_file.exists(),
+        "Could not build tvm lib: {}",
+        String::from_utf8(output.stderr)
+            .unwrap()
+            .trim()
+            .split("\n")
+            .last()
+            .unwrap_or("")
+    );
+
+    let mut builder = Builder::new(File::create(&lib_file).unwrap());
+    builder.append_path(&obj_file).unwrap();
+    drop(builder);
+
+    let status = Command::new("ranlib")
+        .arg(&lib_file)
+        .status()
+        .expect("fdjlksafjdsa");
+
+    assert!(status.success());
+
+    println!("cargo:rustc-link-lib=static=test_basic");
+    println!("cargo:rustc-link-search=native={}", out_dir.display());
+}
diff --git a/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py b/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py
new file mode 100755 (executable)
index 0000000..bf7e60a
--- /dev/null
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Prepares a simple TVM library for testing."""
+
+from os import path as osp
+import sys
+
+import tvm
+from tvm import te
+
+def main():
+    n = te.var('n')
+    A = te.placeholder((n,), name='A')
+    B = te.placeholder((n,), name='B')
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    s = tvm.te.create_schedule(C.op)
+    s[C].parallel(s[C].op.axis[0])
+    print(tvm.lower(s, [A, B, C], simple_mode=True))
+    tvm.build(s, [A, B, C], 'llvm --system-lib').save(osp.join(sys.argv[1], 'test.o'))
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/tvm-graph-rt/tests/test_tvm_basic/src/main.rs b/rust/tvm-graph-rt/tests/test_tvm_basic/src/main.rs
new file mode 100644 (file)
index 0000000..653cb43
--- /dev/null
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+extern crate ndarray;
+#[macro_use]
+extern crate tvm_runtime;
+
+use ndarray::Array;
+use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
+
+mod tvm_mod {
+    import_module!("lib/test.o");
+}
+
+fn main() {
+    // try static
+    let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
+    let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
+    let mut c = Array::from_vec(vec![0f32; 4]);
+    let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
+    let mut a_dl: DLTensor = (&mut a).into();
+    let mut b_dl: DLTensor = (&mut b).into();
+    let mut c_dl: DLTensor = (&mut c).into();
+    call_packed!(tvm_mod::default_function, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
+    assert!(c.all_close(&e, 1e-8f32));
+
+    // try runtime
+    let syslib = SystemLibModule::default();
+    let add = syslib
+        .get_function("default_function")
+        .expect("main function not found");
+    call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
+    assert!(c.all_close(&e, 1e-8f32));
+}
diff --git a/rust/tvm-graph-rt/tests/test_tvm_dso/Cargo.toml b/rust/tvm-graph-rt/tests/test_tvm_dso/Cargo.toml
new file mode 100644 (file)
index 0000000..1909268
--- /dev/null
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "test-rt-tvm-dso"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["TVM Contributors"]
+
+[dependencies]
+ndarray="0.12"
+tvm-graph-rt = { path = "../../" }
diff --git a/rust/tvm-graph-rt/tests/test_tvm_dso/build.rs b/rust/tvm-graph-rt/tests/test_tvm_dso/build.rs
new file mode 100644 (file)
index 0000000..f1d9822
--- /dev/null
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{env, path::Path, process::Command};
+
+fn main() {
+    let out_dir = env::var("OUT_DIR").unwrap();
+
+    let output = Command::new(concat!(
+        env!("CARGO_MANIFEST_DIR"),
+        "/src/build_test_lib.py"
+    ))
+    .arg(&out_dir)
+    .output()
+    .expect("Failed to execute command");
+    assert!(
+        Path::new(&format!("{}/test.so", out_dir)).exists(),
+        "Could not build tvm lib: {}",
+        String::from_utf8(output.stderr)
+            .unwrap()
+            .trim()
+            .split("\n")
+            .last()
+            .unwrap_or("")
+    );
+}
diff --git a/rust/tvm-graph-rt/tests/test_tvm_dso/src/build_test_lib.py b/rust/tvm-graph-rt/tests/test_tvm_dso/src/build_test_lib.py
new file mode 100755 (executable)
index 0000000..cb7353f
--- /dev/null
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Prepares a simple TVM library for testing."""
+
+from os import path as osp
+import sys
+
+import tvm
+from tvm import te
+from tvm.contrib import cc
+
+def main():
+    n = te.var('n')
+    A = te.placeholder((n,), name='A')
+    B = te.placeholder((n,), name='B')
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    s = tvm.te.create_schedule(C.op)
+    s[C].parallel(s[C].op.axis[0])
+    print(tvm.lower(s, [A, B, C], simple_mode=True))
+    obj_file = osp.join(sys.argv[1], 'test.o')
+    tvm.build(s, [A, B, C], 'llvm').save(obj_file)
+    cc.create_shared(osp.join(sys.argv[1], 'test.so'), [obj_file])
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/tvm-graph-rt/tests/test_tvm_dso/src/main.rs b/rust/tvm-graph-rt/tests/test_tvm_dso/src/main.rs
new file mode 100644 (file)
index 0000000..953676c
--- /dev/null
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+extern crate ndarray;
+#[macro_use]
+extern crate tvm_runtime;
+
+use ndarray::Array;
+use tvm_runtime::{DLTensor, DsoModule, Module};
+
+fn main() {
+    tvm_runtime::TVMGetLastError();
+    let module = DsoModule::new(concat!(env!("OUT_DIR"), "/test.so")).unwrap();
+    let add = module
+        .get_function("__tvm_main__")
+        .expect("main function not found");
+    let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
+    let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
+    let mut c = Array::from_vec(vec![0f32; 4]);
+    let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
+    let mut a_dl: DLTensor = (&mut a).into();
+    let mut b_dl: DLTensor = (&mut b).into();
+    let mut c_dl: DLTensor = (&mut c).into();
+    call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
+    assert!(c.all_close(&e, 1e-8f32));
+}
diff --git a/rust/tvm-graph-rt/tests/test_wasm32/.cargo/config b/rust/tvm-graph-rt/tests/test_wasm32/.cargo/config
new file mode 100644 (file)
index 0000000..6b77899
--- /dev/null
@@ -0,0 +1,2 @@
+[build]
+target = "wasm32-wasi"
diff --git a/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml b/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml
new file mode 100644 (file)
index 0000000..aed467f
--- /dev/null
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "test-rt-wasm32"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["TVM Contributors"]
+edition = "2018"
+
+[dependencies]
+ndarray="0.12"
+tvm-graph-rt = { path = "../../" }
+
+[build-dependencies]
+anyhow = "^1.0"
diff --git a/rust/tvm-graph-rt/tests/test_wasm32/build.rs b/rust/tvm-graph-rt/tests/test_wasm32/build.rs
new file mode 100644 (file)
index 0000000..5c816c3
--- /dev/null
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{path::PathBuf, process::Command};
+
+use anyhow::{Context, Result};
+
+fn main() -> Result<()> {
+    let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
+    out_dir.push("lib");
+
+    if !out_dir.is_dir() {
+        std::fs::create_dir(&out_dir).context("failed to create directory for WASM outputs")?;
+    }
+
+    let obj_file = out_dir.join("test.o");
+    let lib_file = out_dir.join("libtest_wasm32.a");
+
+    let output = Command::new(concat!(
+        env!("CARGO_MANIFEST_DIR"),
+        "/src/build_test_lib.py"
+    ))
+    .arg(&out_dir)
+    .output()
+    .context("failed to execute Python script for generating TVM library")?;
+
+    assert!(
+        obj_file.exists(),
+        "Could not build tvm lib: {}",
+        String::from_utf8(output.stderr)
+            .unwrap()
+            .trim()
+            .split("\n")
+            .last()
+            .unwrap_or("")
+    );
+
+    let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8");
+
+    let output = Command::new(ar)
+        .arg("rcs")
+        .arg(&lib_file)
+        .arg(&obj_file)
+        .output()
+        .context("failed to run LLVM_AR command")?;
+
+    assert!(
+        lib_file.exists(),
+        "Could not create archive: {}",
+        String::from_utf8(output.stderr)
+            .unwrap()
+            .trim()
+            .split("\n")
+            .last()
+            .unwrap_or("")
+    );
+
+    println!("cargo:rustc-link-lib=static=test_wasm32");
+    println!("cargo:rustc-link-search=native={}", out_dir.display());
+    Ok(())
+}
diff --git a/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py b/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py
new file mode 100755 (executable)
index 0000000..6016c60
--- /dev/null
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Prepares a simple TVM library for testing."""
+
+from os import path as osp
+import sys
+
+import tvm
+from tvm import te
+
+def main():
+    n = te.var('n')
+    A = te.placeholder((n,), name='A')
+    B = te.placeholder((n,), name='B')
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    s = tvm.te.create_schedule(C.op)
+    s[C].parallel(s[C].op.axis[0])
+    print(tvm.lower(s, [A, B, C], simple_mode=True))
+    tvm.build(s, [A, B, C], 'llvm -target=wasm32-unknown-unknown --system-lib').save(osp.join(sys.argv[1], 'test.o'))
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/tvm-graph-rt/tests/test_wasm32/src/main.rs b/rust/tvm-graph-rt/tests/test_wasm32/src/main.rs
new file mode 100644 (file)
index 0000000..a46cfa9
--- /dev/null
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+extern "C" {
+    static __tvm_module_ctx: i32;
+}
+
+#[no_mangle]
+unsafe fn __get_tvm_module_ctx() -> i32 {
+    // Refer a symbol in the libtest_wasm32.a to make sure that the link of the
+    // library is not optimized out.
+    __tvm_module_ctx
+}
+
+extern crate ndarray;
+#[macro_use]
+extern crate tvm_runtime;
+
+use ndarray::Array;
+use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
+
+fn main() {
+    // try static
+    let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
+    let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
+    let mut c = Array::from_vec(vec![0f32; 4]);
+    let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
+    let mut a_dl: DLTensor = (&mut a).into();
+    let mut b_dl: DLTensor = (&mut b).into();
+    let mut c_dl: DLTensor = (&mut c).into();
+
+    let syslib = SystemLibModule::default();
+    let add = syslib
+        .get_function("default_function")
+        .expect("main function not found");
+    call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
+    assert!(c.all_close(&e, 1e-8f32));
+}
index 85e16be..01d2934 100644 (file)
@@ -54,6 +54,7 @@ fn main() {
         .layout_tests(false)
         .derive_partialeq(true)
         .derive_eq(true)
+        .derive_default(true)
         .generate()
         .expect("unable to generate bindings")
         .write_to_file(PathBuf::from("src/c_runtime_api.rs"))
index 1627e9e..5d09d86 100644 (file)
@@ -48,6 +48,7 @@ macro_rules! impl_dltensor_from_ndarray {
                     shape: arr.shape().as_ptr() as *const i64 as *mut i64,
                     strides: arr.strides().as_ptr() as *const i64 as *mut i64,
                     byte_offset: 0,
+                    ..Default::default()
                 }
             }
         }
index ccdee3f..c98d374 100644 (file)
@@ -39,7 +39,7 @@ pub struct DataType {
 }
 
 impl DataType {
-    pub fn new(code: u8, bits: u8, lanes: u16) -> DataType {
+    pub const fn new(code: u8, bits: u8, lanes: u16) -> DataType {
         DataType { code, bits, lanes }
     }
 
@@ -73,6 +73,18 @@ impl DataType {
     pub fn lanes(&self) -> usize {
         self.lanes as usize
     }
+
+    pub const fn int(bits: u8, lanes: u16) -> DataType {
+        DataType::new(DL_INT_CODE, bits, lanes)
+    }
+
+    pub const fn float(bits: u8, lanes: u16) -> DataType {
+        DataType::new(DL_FLOAT_CODE, bits, lanes)
+    }
+
+    pub const fn uint(bits: u8, lanes: u16) -> DataType {
+        DataType::new(DL_UINT_CODE, bits, lanes)
+    }
 }
 
 impl<'a> From<&'a DataType> for DLDataType {
index da3a456..e3e74ad 100644 (file)
@@ -107,6 +107,7 @@ ALLOW_SPECIFIC_FILE = {
     "Jenkinsfile",
     # cargo config
     "rust/runtime/tests/test_wasm32/.cargo/config",
+    "rust/tvm-graph-rt/tests/test_wasm32/.cargo/config",
     "apps/sgx/.cargo/config",
     # html for demo purposes
     "web/apps/browser/rpc_server.html",