[RUST] Rust DSO module (#2976)
authorNick Hynes <nhynes@berkeley.edu>
Tue, 28 May 2019 22:20:18 +0000 (15:20 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Tue, 28 May 2019 22:20:18 +0000 (15:20 -0700)
12 files changed:
rust/Cargo.toml
rust/common/build.rs
rust/runtime/Cargo.toml
rust/runtime/src/module/dso.rs [new file with mode: 0644]
rust/runtime/src/module/mod.rs [new file with mode: 0644]
rust/runtime/src/module/syslib.rs [moved from rust/runtime/src/module.rs with 62% similarity]
rust/runtime/src/threading.rs
rust/runtime/tests/test_tvm_dso/Cargo.toml [new file with mode: 0644]
rust/runtime/tests/test_tvm_dso/build.rs [new file with mode: 0644]
rust/runtime/tests/test_tvm_dso/src/build_test_lib.py [new file with mode: 0755]
rust/runtime/tests/test_tvm_dso/src/main.rs [new file with mode: 0644]
tests/scripts/task_rust.sh

index 25466e0..6e89bae 100644 (file)
@@ -20,6 +20,7 @@ members = [
        "common",
        "runtime",
        "runtime/tests/test_tvm_basic",
+       "runtime/tests/test_tvm_dso",
        "runtime/tests/test_nnvm",
        "frontend",
        "frontend/tests/basics",
index 5dac99e..919e0ad 100644 (file)
@@ -22,23 +22,30 @@ extern crate bindgen;
 use std::path::PathBuf;
 
 fn main() {
+    let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({
+        let tvm_home = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+            .canonicalize()
+            .unwrap();
+        tvm_home
+            .parent()
+            .unwrap()
+            .parent()
+            .unwrap()
+            .to_str()
+            .unwrap()
+            .to_string()
+    });
     if cfg!(feature = "bindings") {
         println!("cargo:rerun-if-env-changed=TVM_HOME");
         println!("cargo:rustc-link-lib=dylib=tvm_runtime");
-        println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
+        println!("cargo:rustc-link-search={}/build", tvm_home);
     }
 
     // @see rust-bindgen#550 for `blacklist_type`
     bindgen::Builder::default()
-        .header(format!(
-            "{}/include/tvm/runtime/c_runtime_api.h",
-            env!("TVM_HOME")
-        ))
-        .header(format!(
-            "{}/include/tvm/runtime/c_backend_api.h",
-            env!("TVM_HOME")
-        ))
-        .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
+        .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home))
+        .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home))
+        .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home))
         .blacklist_type("max_align_t")
         .layout_tests(false)
         .derive_partialeq(true)
index 8e70565..5809af0 100644 (file)
@@ -45,3 +45,6 @@ tvm-common = { version = "0.1.0", path = "../common/" }
 
 [target.'cfg(not(target_env = "sgx"))'.dependencies]
 num_cpus = "1.8.0"
+
+[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies]
+libloading = "0.5"
diff --git a/rust/runtime/src/module/dso.rs b/rust/runtime/src/module/dso.rs
new file mode 100644 (file)
index 0000000..3442fad
--- /dev/null
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+    cell::RefCell,
+    collections::HashMap,
+    ffi::CStr,
+    os::raw::{c_char, c_int, c_void},
+    pin::Pin,
+};
+
+use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc};
+
+use crate::{
+    threading::{TVMBackendParallelBarrier, TVMBackendParallelLaunch},
+    workspace::{TVMBackendAllocWorkspace, TVMBackendFreeWorkspace},
+    TVMAPISetLastError,
+};
+
+use super::Module;
+
+const TVM_MAIN: &'static [u8] = b"__tvm_main__";
+const TVM_MODULE_CTX: &'static [u8] = b"__tvm_module_ctx";
+
+/// A module backed by a Dynamic Shared Object (dylib).
+pub struct DsoModule<'a> {
+    lib: libloading::Library,
+    packed_funcs: RefCell<HashMap<String, &'a (dyn PackedFunc)>>,
+    _pin: std::marker::PhantomPinned,
+}
+
+macro_rules! init_context_func {
+    ($lib:ident, $( ($fn:ident, $sig:ty) ),+ $(,)?) => {
+        unsafe {
+            $(
+                let fn_ptr = $lib.get::<*mut $sig>(concat!("__", stringify!($fn)).as_bytes());
+                if let Ok(fn_ptr) = fn_ptr {
+                    **fn_ptr = $fn;
+                }
+            )+
+        }
+    };
+}
+
+impl<'a> DsoModule<'a> {
+    pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, failure::Error> {
+        let lib = libloading::Library::new(filename)?;
+
+        init_context_func!(
+            lib,
+            (TVMAPISetLastError, extern "C" fn(*const i8)),
+            (
+                TVMBackendAllocWorkspace,
+                extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void
+            ),
+            (
+                TVMBackendFreeWorkspace,
+                extern "C" fn(c_int, c_int, *mut c_void) -> c_int
+            ),
+            (
+                TVMBackendParallelLaunch,
+                extern "C" fn(crate::threading::FTVMParallelLambda, *const c_void, usize) -> c_int
+            ),
+            (
+                TVMBackendParallelBarrier,
+                extern "C" fn(usize, *const tvm_common::ffi::TVMParallelGroupEnv)
+            ),
+        );
+
+        // Pin the module in memory so that `ctx` pointer (below) is stable.
+        let dso_mod = Box::pin(Self {
+            lib,
+            packed_funcs: RefCell::new(HashMap::new()),
+            _pin: std::marker::PhantomPinned,
+        });
+
+        unsafe {
+            if let Ok(ctx) = dso_mod.lib.get::<*mut *const c_void>(TVM_MODULE_CTX) {
+                **ctx = &dso_mod as *const _ as *const c_void;
+            }
+        }
+
+        Ok(dso_mod)
+    }
+}
+
+impl<'a> Module for DsoModule<'a> {
+    fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
+        let name = name.as_ref();
+        let func = match unsafe {
+            self.lib
+                .get::<BackendPackedCFunc>(if name.as_bytes() == TVM_MAIN {
+                    // If __tvm_main__ is present, it contains the name of the
+                    // actual main function.
+                    match self
+                        .lib
+                        .get::<*const c_char>(TVM_MAIN)
+                        .map(|p| CStr::from_ptr(*p))
+                    {
+                        Ok(m) => m.to_bytes(),
+                        _ => return None,
+                    }
+                } else {
+                    name.as_bytes()
+                })
+        } {
+            Ok(func) => unsafe { func.into_raw() },
+            Err(_) => return None,
+        };
+
+        self.packed_funcs.borrow_mut().insert(
+            name.to_string(),
+            &*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)),
+        );
+
+        self.packed_funcs.borrow().get(name).map(|f| *f)
+    }
+}
+
+impl<'a> Drop for DsoModule<'a> {
+    fn drop(&mut self) {
+        self.packed_funcs
+            .replace(HashMap::new())
+            .into_iter()
+            .map(|(_name, f)| unsafe { Box::from_raw(f as *const _ as *mut (dyn PackedFunc)) })
+            .for_each(std::mem::drop);
+    }
+}
diff --git a/rust/runtime/src/module/mod.rs b/rust/runtime/src/module/mod.rs
new file mode 100644 (file)
index 0000000..2c7c107
--- /dev/null
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))]
+mod dso;
+mod syslib;
+
+use tvm_common::{
+    ffi::BackendPackedCFunc,
+    packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
+};
+
+#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))]
+pub use dso::DsoModule;
+pub use syslib::SystemLibModule;
+
+pub trait Module {
+    fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
+}
+
+// @see `WrapPackedFunc` in `llvm_module.cc`.
+fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<dyn PackedFunc> {
+    box move |args: &[TVMArgValue]| {
+        let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
+            .into_iter()
+            .map(|arg| {
+                let (val, code) = arg.to_tvm_value();
+                (val, code as i32)
+            })
+            .unzip();
+        let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
+        if exit_code == 0 {
+            Ok(TVMRetValue::default())
+        } else {
+            Err(tvm_common::errors::FuncCallError::get_with_context(
+                func_name.clone(),
+            ))
+        }
+    }
+}
similarity index 62%
rename from rust/runtime/src/module.rs
rename to rust/runtime/src/module/syslib.rs
index 865338f..227b8c7 100644 (file)
@@ -21,14 +21,9 @@ use std::{
     collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
 };
 
-use tvm_common::{
-    ffi::BackendPackedCFunc,
-    packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue},
-};
+use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc};
 
-pub trait Module {
-    fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)>;
-}
+use super::Module;
 
 pub struct SystemLibModule;
 
@@ -53,30 +48,6 @@ impl Default for SystemLibModule {
     }
 }
 
-// @see `WrapPackedFunc` in `llvm_module.cc`.
-pub(super) fn wrap_backend_packed_func(
-    func_name: String,
-    func: BackendPackedCFunc,
-) -> Box<dyn PackedFunc> {
-    box move |args: &[TVMArgValue]| {
-        let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
-            .into_iter()
-            .map(|arg| {
-                let (val, code) = arg.to_tvm_value();
-                (val, code as i32)
-            })
-            .unzip();
-        let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
-        if exit_code == 0 {
-            Ok(TVMRetValue::default())
-        } else {
-            Err(tvm_common::errors::FuncCallError::get_with_context(
-                func_name.clone(),
-            ))
-        }
-    }
-}
-
 #[no_mangle]
 pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
     cname: *const c_char,
@@ -85,7 +56,7 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
     let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
     SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
         name.to_string(),
-        &*Box::leak(wrap_backend_packed_func(name.to_string(), func)),
+        &*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)),
     );
     return 0;
 }
index 9614384..eb2f418 100644 (file)
@@ -42,7 +42,7 @@ use tvm_common::ffi::TVMParallelGroupEnv;
 #[cfg(target_env = "sgx")]
 use super::{TVMArgValue, TVMRetValue};
 
-type FTVMParallelLambda =
+pub(crate) type FTVMParallelLambda =
     extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
 
 /// Holds a parallel job request made by a TVM library function.
diff --git a/rust/runtime/tests/test_tvm_dso/Cargo.toml b/rust/runtime/tests/test_tvm_dso/Cargo.toml
new file mode 100644 (file)
index 0000000..afe7f26
--- /dev/null
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "test-tvm-dso"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["TVM Contributors"]
+
+[dependencies]
+ndarray="0.12"
+tvm-runtime = { path = "../../" }
diff --git a/rust/runtime/tests/test_tvm_dso/build.rs b/rust/runtime/tests/test_tvm_dso/build.rs
new file mode 100644 (file)
index 0000000..f1d9822
--- /dev/null
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{env, path::Path, process::Command};
+
+fn main() {
+    let out_dir = env::var("OUT_DIR").unwrap();
+
+    let output = Command::new(concat!(
+        env!("CARGO_MANIFEST_DIR"),
+        "/src/build_test_lib.py"
+    ))
+    .arg(&out_dir)
+    .output()
+    .expect("Failed to execute command");
+    assert!(
+        Path::new(&format!("{}/test.so", out_dir)).exists(),
+        "Could not build tvm lib: {}",
+        String::from_utf8(output.stderr)
+            .unwrap()
+            .trim()
+            .split("\n")
+            .last()
+            .unwrap_or("")
+    );
+}
diff --git a/rust/runtime/tests/test_tvm_dso/src/build_test_lib.py b/rust/runtime/tests/test_tvm_dso/src/build_test_lib.py
new file mode 100755 (executable)
index 0000000..63b43a5
--- /dev/null
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Prepares a simple TVM library for testing."""
+
+from os import path as osp
+import sys
+
+import tvm
+from tvm.contrib import cc
+
+def main():
+    n = tvm.var('n')
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    s = tvm.create_schedule(C.op)
+    s[C].parallel(s[C].op.axis[0])
+    print(tvm.lower(s, [A, B, C], simple_mode=True))
+    obj_file = osp.join(sys.argv[1], 'test.o')
+    tvm.build(s, [A, B, C], 'llvm').save(obj_file)
+    cc.create_shared(osp.join(sys.argv[1], 'test.so'), [obj_file])
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/runtime/tests/test_tvm_dso/src/main.rs b/rust/runtime/tests/test_tvm_dso/src/main.rs
new file mode 100644 (file)
index 0000000..953676c
--- /dev/null
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+extern crate ndarray;
+#[macro_use]
+extern crate tvm_runtime;
+
+use ndarray::Array;
+use tvm_runtime::{DLTensor, DsoModule, Module};
+
+fn main() {
+    tvm_runtime::TVMGetLastError();
+    let module = DsoModule::new(concat!(env!("OUT_DIR"), "/test.so")).unwrap();
+    let add = module
+        .get_function("__tvm_main__")
+        .expect("main function not found");
+    let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
+    let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
+    let mut c = Array::from_vec(vec![0f32; 4]);
+    let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
+    let mut a_dl: DLTensor = (&mut a).into();
+    let mut b_dl: DLTensor = (&mut b).into();
+    let mut c_dl: DLTensor = (&mut c).into();
+    call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
+    assert!(c.all_close(&e, 1e-8f32));
+}
index 1728fec..cdf777c 100755 (executable)
@@ -48,6 +48,10 @@ cd tests/test_tvm_basic
 cargo run
 cd -
 
+cd tests/test_tvm_dso
+cargo run
+cd -
+
 # run NNVM graph test
 cd tests/test_nnvm
 cargo run