From a479432d901f3ad0b4f2e8622ae65960999c1f33 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Tue, 28 May 2019 15:20:18 -0700 Subject: [PATCH] [RUST] Rust DSO module (#2976) --- rust/Cargo.toml | 1 + rust/common/build.rs | 27 ++-- rust/runtime/Cargo.toml | 3 + rust/runtime/src/module/dso.rs | 144 +++++++++++++++++++++ rust/runtime/src/module/mod.rs | 56 ++++++++ rust/runtime/src/{module.rs => module/syslib.rs} | 35 +---- rust/runtime/src/threading.rs | 2 +- rust/runtime/tests/test_tvm_dso/Cargo.toml | 26 ++++ rust/runtime/tests/test_tvm_dso/build.rs | 42 ++++++ .../tests/test_tvm_dso/src/build_test_lib.py | 40 ++++++ rust/runtime/tests/test_tvm_dso/src/main.rs | 42 ++++++ tests/scripts/task_rust.sh | 4 + 12 files changed, 379 insertions(+), 43 deletions(-) create mode 100644 rust/runtime/src/module/dso.rs create mode 100644 rust/runtime/src/module/mod.rs rename rust/runtime/src/{module.rs => module/syslib.rs} (62%) create mode 100644 rust/runtime/tests/test_tvm_dso/Cargo.toml create mode 100644 rust/runtime/tests/test_tvm_dso/build.rs create mode 100755 rust/runtime/tests/test_tvm_dso/src/build_test_lib.py create mode 100644 rust/runtime/tests/test_tvm_dso/src/main.rs diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 25466e0..6e89bae 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -20,6 +20,7 @@ members = [ "common", "runtime", "runtime/tests/test_tvm_basic", + "runtime/tests/test_tvm_dso", "runtime/tests/test_nnvm", "frontend", "frontend/tests/basics", diff --git a/rust/common/build.rs b/rust/common/build.rs index 5dac99e..919e0ad 100644 --- a/rust/common/build.rs +++ b/rust/common/build.rs @@ -22,23 +22,30 @@ extern crate bindgen; use std::path::PathBuf; fn main() { + let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({ + let tvm_home = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .canonicalize() + .unwrap(); + tvm_home + .parent() + .unwrap() + .parent() + .unwrap() + .to_str() + .unwrap() + .to_string() + }); if cfg!(feature = "bindings") { println!("cargo:rerun-if-env-changed=TVM_HOME"); println!("cargo:rustc-link-lib=dylib=tvm_runtime"); - println!("cargo:rustc-link-search={}/build", env!("TVM_HOME")); + println!("cargo:rustc-link-search={}/build", tvm_home); } // @see rust-bindgen#550 for `blacklist_type` bindgen::Builder::default() - .header(format!( - "{}/include/tvm/runtime/c_runtime_api.h", - env!("TVM_HOME") - )) - .header(format!( - "{}/include/tvm/runtime/c_backend_api.h", - env!("TVM_HOME") - )) - .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME"))) + .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home)) + .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home)) + .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home)) .blacklist_type("max_align_t") .layout_tests(false) .derive_partialeq(true) diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml index 8e70565..5809af0 100644 --- a/rust/runtime/Cargo.toml +++ b/rust/runtime/Cargo.toml @@ -45,3 +45,6 @@ tvm-common = { version = "0.1.0", path = "../common/" } [target.'cfg(not(target_env = "sgx"))'.dependencies] num_cpus = "1.8.0" + +[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies] +libloading = "0.5" diff --git a/rust/runtime/src/module/dso.rs b/rust/runtime/src/module/dso.rs new file mode 100644 index 0000000..3442fad --- /dev/null +++ b/rust/runtime/src/module/dso.rs @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + cell::RefCell, + collections::HashMap, + ffi::CStr, + os::raw::{c_char, c_int, c_void}, + pin::Pin, +}; + +use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc}; + +use crate::{ + threading::{TVMBackendParallelBarrier, TVMBackendParallelLaunch}, + workspace::{TVMBackendAllocWorkspace, TVMBackendFreeWorkspace}, + TVMAPISetLastError, +}; + +use super::Module; + +const TVM_MAIN: &'static [u8] = b"__tvm_main__"; +const TVM_MODULE_CTX: &'static [u8] = b"__tvm_module_ctx"; + +/// A module backed by a Dynamic Shared Object (dylib). +pub struct DsoModule<'a> { + lib: libloading::Library, + packed_funcs: RefCell>, + _pin: std::marker::PhantomPinned, +} + +macro_rules! init_context_func { + ($lib:ident, $( ($fn:ident, $sig:ty) ),+ $(,)?) => { + unsafe { + $( + let fn_ptr = $lib.get::<*mut $sig>(concat!("__", stringify!($fn)).as_bytes()); + if let Ok(fn_ptr) = fn_ptr { + **fn_ptr = $fn; + } + )+ + } + }; +} + +impl<'a> DsoModule<'a> { + pub fn new>(filename: P) -> Result>, failure::Error> { + let lib = libloading::Library::new(filename)?; + + init_context_func!( + lib, + (TVMAPISetLastError, extern "C" fn(*const i8)), + ( + TVMBackendAllocWorkspace, + extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void + ), + ( + TVMBackendFreeWorkspace, + extern "C" fn(c_int, c_int, *mut c_void) -> c_int + ), + ( + TVMBackendParallelLaunch, + extern "C" fn(crate::threading::FTVMParallelLambda, *const c_void, usize) -> c_int + ), + ( + TVMBackendParallelBarrier, + extern "C" fn(usize, *const tvm_common::ffi::TVMParallelGroupEnv) + ), + ); + + // Pin the module in memory so that `ctx` pointer (below) is stable. + let dso_mod = Box::pin(Self { + lib, + packed_funcs: RefCell::new(HashMap::new()), + _pin: std::marker::PhantomPinned, + }); + + unsafe { + if let Ok(ctx) = dso_mod.lib.get::<*mut *const c_void>(TVM_MODULE_CTX) { + **ctx = &dso_mod as *const _ as *const c_void; + } + } + + Ok(dso_mod) + } +} + +impl<'a> Module for DsoModule<'a> { + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)> { + let name = name.as_ref(); + let func = match unsafe { + self.lib + .get::(if name.as_bytes() == TVM_MAIN { + // If __tvm_main__ is present, it contains the name of the + // actual main function. + match self + .lib + .get::<*const c_char>(TVM_MAIN) + .map(|p| CStr::from_ptr(*p)) + { + Ok(m) => m.to_bytes(), + _ => return None, + } + } else { + name.as_bytes() + }) + } { + Ok(func) => unsafe { func.into_raw() }, + Err(_) => return None, + }; + + self.packed_funcs.borrow_mut().insert( + name.to_string(), + &*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)), + ); + + self.packed_funcs.borrow().get(name).map(|f| *f) + } +} + +impl<'a> Drop for DsoModule<'a> { + fn drop(&mut self) { + self.packed_funcs + .replace(HashMap::new()) + .into_iter() + .map(|(_name, f)| unsafe { Box::from_raw(f as *const _ as *mut (dyn PackedFunc)) }) + .for_each(std::mem::drop); + } +} diff --git a/rust/runtime/src/module/mod.rs b/rust/runtime/src/module/mod.rs new file mode 100644 index 0000000..2c7c107 --- /dev/null +++ b/rust/runtime/src/module/mod.rs @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))] +mod dso; +mod syslib; + +use tvm_common::{ + ffi::BackendPackedCFunc, + packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue}, +}; + +#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))] +pub use dso::DsoModule; +pub use syslib::SystemLibModule; + +pub trait Module { + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)>; +} + +// @see `WrapPackedFunc` in `llvm_module.cc`. +fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box { + box move |args: &[TVMArgValue]| { + let (values, type_codes): (Vec, Vec) = args + .into_iter() + .map(|arg| { + let (val, code) = arg.to_tvm_value(); + (val, code as i32) + }) + .unzip(); + let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32); + if exit_code == 0 { + Ok(TVMRetValue::default()) + } else { + Err(tvm_common::errors::FuncCallError::get_with_context( + func_name.clone(), + )) + } + } +} diff --git a/rust/runtime/src/module.rs b/rust/runtime/src/module/syslib.rs similarity index 62% rename from rust/runtime/src/module.rs rename to rust/runtime/src/module/syslib.rs index 865338f..227b8c7 100644 --- a/rust/runtime/src/module.rs +++ b/rust/runtime/src/module/syslib.rs @@ -21,14 +21,9 @@ use std::{ collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, }; -use tvm_common::{ - ffi::BackendPackedCFunc, - packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue}, -}; +use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc}; -pub trait Module { - fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)>; -} +use super::Module; pub struct SystemLibModule; @@ -53,30 +48,6 @@ impl Default for SystemLibModule { } } -// @see `WrapPackedFunc` in `llvm_module.cc`. -pub(super) fn wrap_backend_packed_func( - func_name: String, - func: BackendPackedCFunc, -) -> Box { - box move |args: &[TVMArgValue]| { - let (values, type_codes): (Vec, Vec) = args - .into_iter() - .map(|arg| { - let (val, code) = arg.to_tvm_value(); - (val, code as i32) - }) - .unzip(); - let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32); - if exit_code == 0 { - Ok(TVMRetValue::default()) - } else { - Err(tvm_common::errors::FuncCallError::get_with_context( - func_name.clone(), - )) - } - } -} - #[no_mangle] pub extern "C" fn TVMBackendRegisterSystemLibSymbol( cname: *const c_char, @@ -85,7 +56,7 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol( let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert( name.to_string(), - &*Box::leak(wrap_backend_packed_func(name.to_string(), func)), + &*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)), ); return 0; } diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs index 9614384..eb2f418 100644 --- a/rust/runtime/src/threading.rs +++ b/rust/runtime/src/threading.rs @@ -42,7 +42,7 @@ use tvm_common::ffi::TVMParallelGroupEnv; #[cfg(target_env = "sgx")] use super::{TVMArgValue, TVMRetValue}; -type FTVMParallelLambda = +pub(crate) type FTVMParallelLambda = extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; /// Holds a parallel job request made by a TVM library function. diff --git a/rust/runtime/tests/test_tvm_dso/Cargo.toml b/rust/runtime/tests/test_tvm_dso/Cargo.toml new file mode 100644 index 0000000..afe7f26 --- /dev/null +++ b/rust/runtime/tests/test_tvm_dso/Cargo.toml @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "test-tvm-dso" +version = "0.0.0" +license = "Apache-2.0" +authors = ["TVM Contributors"] + +[dependencies] +ndarray="0.12" +tvm-runtime = { path = "../../" } diff --git a/rust/runtime/tests/test_tvm_dso/build.rs b/rust/runtime/tests/test_tvm_dso/build.rs new file mode 100644 index 0000000..f1d9822 --- /dev/null +++ b/rust/runtime/tests/test_tvm_dso/build.rs @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{env, path::Path, process::Command}; + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_lib.py" + )) + .arg(&out_dir) + .output() + .expect("Failed to execute command"); + assert!( + Path::new(&format!("{}/test.so", out_dir)).exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); +} diff --git a/rust/runtime/tests/test_tvm_dso/src/build_test_lib.py b/rust/runtime/tests/test_tvm_dso/src/build_test_lib.py new file mode 100755 index 0000000..63b43a5 --- /dev/null +++ b/rust/runtime/tests/test_tvm_dso/src/build_test_lib.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Prepares a simple TVM library for testing.""" + +from os import path as osp +import sys + +import tvm +from tvm.contrib import cc + +def main(): + n = tvm.var('n') + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.create_schedule(C.op) + s[C].parallel(s[C].op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + obj_file = osp.join(sys.argv[1], 'test.o') + tvm.build(s, [A, B, C], 'llvm').save(obj_file) + cc.create_shared(osp.join(sys.argv[1], 'test.so'), [obj_file]) + +if __name__ == '__main__': + main() diff --git a/rust/runtime/tests/test_tvm_dso/src/main.rs b/rust/runtime/tests/test_tvm_dso/src/main.rs new file mode 100644 index 0000000..953676c --- /dev/null +++ b/rust/runtime/tests/test_tvm_dso/src/main.rs @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate ndarray; +#[macro_use] +extern crate tvm_runtime; + +use ndarray::Array; +use tvm_runtime::{DLTensor, DsoModule, Module}; + +fn main() { + tvm_runtime::TVMGetLastError(); + let module = DsoModule::new(concat!(env!("OUT_DIR"), "/test.so")).unwrap(); + let add = module + .get_function("__tvm_main__") + .expect("main function not found"); + let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); + let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); + let mut c = Array::from_vec(vec![0f32; 4]); + let e = Array::from_vec(vec![2f32, 2., 4., 4.]); + let mut a_dl: DLTensor = (&mut a).into(); + let mut b_dl: DLTensor = (&mut b).into(); + let mut c_dl: DLTensor = (&mut c).into(); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); + assert!(c.all_close(&e, 1e-8f32)); +} diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index 1728fec..cdf777c 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -48,6 +48,10 @@ cd tests/test_tvm_basic cargo run cd - +cd tests/test_tvm_dso +cargo run +cd - + # run NNVM graph test cd tests/test_nnvm cargo run -- 2.7.4