"common",
"runtime",
"runtime/tests/test_tvm_basic",
+ "runtime/tests/test_tvm_dso",
"runtime/tests/test_nnvm",
"frontend",
"frontend/tests/basics",
use std::path::PathBuf;
fn main() {
+ let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({
+ let tvm_home = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+ .canonicalize()
+ .unwrap();
+ tvm_home
+ .parent()
+ .unwrap()
+ .parent()
+ .unwrap()
+ .to_str()
+ .unwrap()
+ .to_string()
+ });
if cfg!(feature = "bindings") {
println!("cargo:rerun-if-env-changed=TVM_HOME");
println!("cargo:rustc-link-lib=dylib=tvm_runtime");
- println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
+ println!("cargo:rustc-link-search={}/build", tvm_home);
}
// @see rust-bindgen#550 for `blacklist_type`
bindgen::Builder::default()
- .header(format!(
- "{}/include/tvm/runtime/c_runtime_api.h",
- env!("TVM_HOME")
- ))
- .header(format!(
- "{}/include/tvm/runtime/c_backend_api.h",
- env!("TVM_HOME")
- ))
- .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
+ .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home))
+ .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home))
+ .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home))
.blacklist_type("max_align_t")
.layout_tests(false)
.derive_partialeq(true)
[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
+
+[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies]
+libloading = "0.5"
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+ cell::RefCell,
+ collections::HashMap,
+ ffi::CStr,
+ os::raw::{c_char, c_int, c_void},
+ pin::Pin,
+};
+
+use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc};
+
+use crate::{
+ threading::{TVMBackendParallelBarrier, TVMBackendParallelLaunch},
+ workspace::{TVMBackendAllocWorkspace, TVMBackendFreeWorkspace},
+ TVMAPISetLastError,
+};
+
+use super::Module;
+
+const TVM_MAIN: &'static [u8] = b"__tvm_main__";
+const TVM_MODULE_CTX: &'static [u8] = b"__tvm_module_ctx";
+
+/// A module backed by a Dynamic Shared Object (dylib).
+pub struct DsoModule<'a> {
+ lib: libloading::Library,
+ packed_funcs: RefCell<HashMap<String, &'a (dyn PackedFunc)>>,
+ _pin: std::marker::PhantomPinned,
+}
+
+macro_rules! init_context_func {
+ ($lib:ident, $( ($fn:ident, $sig:ty) ),+ $(,)?) => {
+ unsafe {
+ $(
+ let fn_ptr = $lib.get::<*mut $sig>(concat!("__", stringify!($fn)).as_bytes());
+ if let Ok(fn_ptr) = fn_ptr {
+ **fn_ptr = $fn;
+ }
+ )+
+ }
+ };
+}
+
+impl<'a> DsoModule<'a> {
+ pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, failure::Error> {
+ let lib = libloading::Library::new(filename)?;
+
+ init_context_func!(
+ lib,
+ (TVMAPISetLastError, extern "C" fn(*const i8)),
+ (
+ TVMBackendAllocWorkspace,
+ extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void
+ ),
+ (
+ TVMBackendFreeWorkspace,
+ extern "C" fn(c_int, c_int, *mut c_void) -> c_int
+ ),
+ (
+ TVMBackendParallelLaunch,
+ extern "C" fn(crate::threading::FTVMParallelLambda, *const c_void, usize) -> c_int
+ ),
+ (
+ TVMBackendParallelBarrier,
+ extern "C" fn(usize, *const tvm_common::ffi::TVMParallelGroupEnv)
+ ),
+ );
+
+ // Pin the module in memory so that `ctx` pointer (below) is stable.
+ let dso_mod = Box::pin(Self {
+ lib,
+ packed_funcs: RefCell::new(HashMap::new()),
+ _pin: std::marker::PhantomPinned,
+ });
+
+ unsafe {
+ if let Ok(ctx) = dso_mod.lib.get::<*mut *const c_void>(TVM_MODULE_CTX) {
+ **ctx = &dso_mod as *const _ as *const c_void;
+ }
+ }
+
+ Ok(dso_mod)
+ }
+}
+
+impl<'a> Module for DsoModule<'a> {
+ fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
+ let name = name.as_ref();
+ let func = match unsafe {
+ self.lib
+ .get::<BackendPackedCFunc>(if name.as_bytes() == TVM_MAIN {
+ // If __tvm_main__ is present, it contains the name of the
+ // actual main function.
+ match self
+ .lib
+ .get::<*const c_char>(TVM_MAIN)
+ .map(|p| CStr::from_ptr(*p))
+ {
+ Ok(m) => m.to_bytes(),
+ _ => return None,
+ }
+ } else {
+ name.as_bytes()
+ })
+ } {
+ Ok(func) => unsafe { func.into_raw() },
+ Err(_) => return None,
+ };
+
+ self.packed_funcs.borrow_mut().insert(
+ name.to_string(),
+ &*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)),
+ );
+
+ self.packed_funcs.borrow().get(name).map(|f| *f)
+ }
+}
+
+impl<'a> Drop for DsoModule<'a> {
+ fn drop(&mut self) {
+ self.packed_funcs
+ .replace(HashMap::new())
+ .into_iter()
+ .map(|(_name, f)| unsafe { Box::from_raw(f as *const _ as *mut (dyn PackedFunc)) })
+ .for_each(std::mem::drop);
+ }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))]
+mod dso;
+mod syslib;
+
+use tvm_common::{
+ ffi::BackendPackedCFunc,
+ packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
+};
+
+#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))]
+pub use dso::DsoModule;
+pub use syslib::SystemLibModule;
+
+pub trait Module {
+ fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
+}
+
+// @see `WrapPackedFunc` in `llvm_module.cc`.
+fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<dyn PackedFunc> {
+ box move |args: &[TVMArgValue]| {
+ let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
+ .into_iter()
+ .map(|arg| {
+ let (val, code) = arg.to_tvm_value();
+ (val, code as i32)
+ })
+ .unzip();
+ let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
+ if exit_code == 0 {
+ Ok(TVMRetValue::default())
+ } else {
+ Err(tvm_common::errors::FuncCallError::get_with_context(
+ func_name.clone(),
+ ))
+ }
+ }
+}
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
};
-use tvm_common::{
- ffi::BackendPackedCFunc,
- packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
-};
+use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc};
-pub trait Module {
- fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
-}
+use super::Module;
pub struct SystemLibModule;
}
}
-// @see `WrapPackedFunc` in `llvm_module.cc`.
-pub(super) fn wrap_backend_packed_func(
- func_name: String,
- func: BackendPackedCFunc,
-) -> Box<dyn PackedFunc> {
- box move |args: &[TVMArgValue]| {
- let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
- .into_iter()
- .map(|arg| {
- let (val, code) = arg.to_tvm_value();
- (val, code as i32)
- })
- .unzip();
- let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
- if exit_code == 0 {
- Ok(TVMRetValue::default())
- } else {
- Err(tvm_common::errors::FuncCallError::get_with_context(
- func_name.clone(),
- ))
- }
- }
-}
-
#[no_mangle]
pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
cname: *const c_char,
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
name.to_string(),
- &*Box::leak(wrap_backend_packed_func(name.to_string(), func)),
+ &*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)),
);
return 0;
}
#[cfg(target_env = "sgx")]
use super::{TVMArgValue, TVMRetValue};
-type FTVMParallelLambda =
+pub(crate) type FTVMParallelLambda =
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
/// Holds a parallel job request made by a TVM library function.
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "test-tvm-dso"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["TVM Contributors"]
+
+[dependencies]
+ndarray="0.12"
+tvm-runtime = { path = "../../" }
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{env, path::Path, process::Command};
+
+fn main() {
+ let out_dir = env::var("OUT_DIR").unwrap();
+
+ let output = Command::new(concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/src/build_test_lib.py"
+ ))
+ .arg(&out_dir)
+ .output()
+ .expect("Failed to execute command");
+ assert!(
+ Path::new(&format!("{}/test.so", out_dir)).exists(),
+ "Could not build tvm lib: {}",
+ String::from_utf8(output.stderr)
+ .unwrap()
+ .trim()
+ .split("\n")
+ .last()
+ .unwrap_or("")
+ );
+}
--- /dev/null
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Prepares a simple TVM library for testing."""
+
+from os import path as osp
+import sys
+
+import tvm
+from tvm.contrib import cc
+
+def main():
+ n = tvm.var('n')
+ A = tvm.placeholder((n,), name='A')
+ B = tvm.placeholder((n,), name='B')
+ C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+ s = tvm.create_schedule(C.op)
+ s[C].parallel(s[C].op.axis[0])
+ print(tvm.lower(s, [A, B, C], simple_mode=True))
+ obj_file = osp.join(sys.argv[1], 'test.o')
+ tvm.build(s, [A, B, C], 'llvm').save(obj_file)
+ cc.create_shared(osp.join(sys.argv[1], 'test.so'), [obj_file])
+
+if __name__ == '__main__':
+ main()
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+extern crate ndarray;
+#[macro_use]
+extern crate tvm_runtime;
+
+use ndarray::Array;
+use tvm_runtime::{DLTensor, DsoModule, Module};
+
+fn main() {
+ tvm_runtime::TVMGetLastError();
+ let module = DsoModule::new(concat!(env!("OUT_DIR"), "/test.so")).unwrap();
+ let add = module
+ .get_function("__tvm_main__")
+ .expect("main function not found");
+ let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
+ let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
+ let mut c = Array::from_vec(vec![0f32; 4]);
+ let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
+ let mut a_dl: DLTensor = (&mut a).into();
+ let mut b_dl: DLTensor = (&mut b).into();
+ let mut c_dl: DLTensor = (&mut c).into();
+ call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
+ assert!(c.all_close(&e, 1e-8f32));
+}
cargo run
cd -
+cd tests/test_tvm_dso
+cargo run
+cd -
+
# run NNVM graph test
cd tests/test_nnvm
cargo run