[Rust] Static syslib (#3274)
authorNick Hynes <nhynes@nhynes.com>
Sun, 9 Jun 2019 03:56:58 +0000 (20:56 -0700)
committerGitHub <noreply@github.com>
Sun, 9 Jun 2019 03:56:58 +0000 (20:56 -0700)
rust/Cargo.toml
rust/macros/Cargo.toml [new file with mode: 0644]
rust/macros/src/lib.rs [new file with mode: 0644]
rust/runtime/Cargo.toml
rust/runtime/src/graph.rs
rust/runtime/src/lib.rs
rust/runtime/tests/test_tvm_basic/build.rs
rust/runtime/tests/test_tvm_basic/src/main.rs

index 6e89bae..02e2c7c 100644 (file)
@@ -18,6 +18,7 @@
 [workspace]
 members = [
        "common",
+       "macros",
        "runtime",
        "runtime/tests/test_tvm_basic",
        "runtime/tests/test_tvm_dso",
diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml
new file mode 100644 (file)
index 0000000..15773b6
--- /dev/null
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "tvm-macros"
+version = "0.1.0"
+license = "Apache-2.0"
+description = "Proc macros used by the TVM crates."
+repository = "https://github.com/dmlc/tvm"
+readme = "README.md"
+keywords = ["tvm"]
+authors = ["TVM Contributors"]
+edition = "2018"
+
+[lib]
+proc-macro = true
+
+[dependencies]
+goblin = "0.0.22"
+proc-macro2 = "0.4"
+proc-quote = "0.2"
+syn = "0.15"
diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs
new file mode 100644 (file)
index 0000000..704f7c1
--- /dev/null
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#![feature(bind_by_move_pattern_guards, proc_macro_span)]
+
+extern crate proc_macro;
+
+use std::{fs::File, io::Read};
+
+use proc_quote::quote;
+
+#[proc_macro]
+pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let obj_file_path = syn::parse_macro_input!(input as syn::LitStr);
+
+    let mut path = obj_file_path.span().unwrap().source_file().path();
+    path.pop(); // remove the filename
+    path.push(obj_file_path.value());
+
+    let mut fd = File::open(&path)
+        .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
+    let mut buffer = Vec::new();
+    fd.read_to_end(&mut buffer).unwrap();
+
+    let fn_names = match goblin::Object::parse(&buffer).unwrap() {
+        goblin::Object::Elf(elf) => elf
+            .syms
+            .iter()
+            .filter_map(|s| {
+                if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
+                    return None;
+                }
+                match elf.strtab.get(s.st_name) {
+                    Some(Ok(name)) if name != "" => {
+                        Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
+                    }
+                    _ => None,
+                }
+            })
+            .collect::<Vec<_>>(),
+        goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
+            obj.symbols()
+                .filter_map(|s| match s {
+                    Ok((name, nlist))
+                        if nlist.is_global()
+                            && nlist.n_sect != 0
+                            && !name.ends_with("tvm_module_ctx") =>
+                    {
+                        Some(syn::Ident::new(
+                            if name.starts_with('_') {
+                                // Mach objects prepend a _ to globals.
+                                &name[1..]
+                            } else {
+                                &name
+                            },
+                            proc_macro2::Span::call_site(),
+                        ))
+                    }
+                    _ => None,
+                })
+                .collect::<Vec<_>>()
+        }
+        _ => panic!("Unsupported object format."),
+    };
+
+    let extern_fns = quote! {
+        mod ext {
+            extern "C" {
+                #(
+                    pub(super) fn #fn_names(
+                        args: *const tvm_runtime::ffi::TVMValue,
+                        type_codes: *const std::os::raw::c_int,
+                        num_args: std::os::raw::c_int
+                    ) -> std::os::raw::c_int;
+                )*
+            }
+        }
+    };
+
+    let fns = quote! {
+        use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
+        #extern_fns
+
+        #(
+            pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
+                let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
+                   .into_iter()
+                   .map(|arg| {
+                       let (val, code) = arg.to_tvm_value();
+                       (val, code as i32)
+                   })
+                   .unzip();
+                let exit_code = unsafe {
+                    ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
+                };
+                if exit_code == 0 {
+                    Ok(TVMRetValue::default())
+                } else {
+                    Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
+                }
+            }
+        )*
+    };
+
+    proc_macro::TokenStream::from(fns)
+}
index 5809af0..3c81a93 100644 (file)
@@ -41,7 +41,8 @@ nom = {version = "4.0.0", default-features = false }
 serde = "1.0.59"
 serde_derive = "1.0.79"
 serde_json = "1.0.17"
-tvm-common = { version = "0.1.0", path = "../common/" }
+tvm-common = { version = "0.1", path = "../common" }
+tvm-macros = { version = "0.1", path = "../macros" }
 
 [target.'cfg(not(target_env = "sgx"))'.dependencies]
 num_cpus = "1.8.0"
index bff02f5..cacd7a3 100644 (file)
@@ -164,7 +164,7 @@ impl<'a> TryFrom<&'a str> for Graph {
 /// ```
 pub struct GraphExecutor<'m, 't> {
     graph: Graph,
-    op_execs: Vec<Box<Fn() + 'm>>,
+    op_execs: Vec<Box<dyn Fn() + 'm>>,
     tensors: Vec<Tensor<'t>>,
 }
 
@@ -240,7 +240,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
         graph: &Graph,
         lib: &'m M,
         tensors: &Vec<Tensor<'t>>,
-    ) -> Result<Vec<Box<Fn() + 'm>>, Error> {
+    ) -> Result<Vec<Box<dyn Fn() + 'm>>, Error> {
         ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
         let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
 
@@ -279,7 +279,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
                 })
                 .collect::<Result<Vec<DLTensor>, Error>>()
                 .unwrap();
-            let op: Box<Fn()> = box move || {
+            let op: Box<dyn Fn()> = box move || {
                 let args = dl_tensors
                     .iter()
                     .map(|t| t.into())
index c774d5b..010fbf7 100644 (file)
@@ -29,7 +29,6 @@
 //! For examples of use, please refer to the multi-file tests in the `tests` directory.
 
 #![feature(
-    alloc,
     allocator_api,
     box_syntax,
     fn_traits,
@@ -77,6 +76,7 @@ pub use tvm_common::{
     packed_func::{self, *},
     TVMArgValue, TVMRetValue,
 };
+pub use tvm_macros::import_module;
 
 pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*};
 
index ea3bfcb..3439f9c 100644 (file)
 
 extern crate ar;
 
-use std::{env, path::Path, process::Command};
+use std::{path::PathBuf, process::Command};
 
 use ar::Builder;
 use std::fs::File;
 
 fn main() {
-    let out_dir = env::var("OUT_DIR").unwrap();
+    let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
+    out_dir.push("lib");
+
+    if !out_dir.is_dir() {
+        std::fs::create_dir(&out_dir).unwrap();
+    }
+
+    let obj_file = out_dir.join("test.o");
+    let lib_file = out_dir.join("libtest.a");
 
     let output = Command::new(concat!(
         env!("CARGO_MANIFEST_DIR"),
@@ -35,7 +43,7 @@ fn main() {
     .output()
     .expect("Failed to execute command");
     assert!(
-        Path::new(&format!("{}/test.o", out_dir)).exists(),
+        obj_file.exists(),
         "Could not build tvm lib: {}",
         String::from_utf8(output.stderr)
             .unwrap()
@@ -45,9 +53,9 @@ fn main() {
             .unwrap_or("")
     );
 
-    let mut builder = Builder::new(File::create(format!("{}/libtest.a", out_dir)).unwrap());
-    builder.append_path(format!("{}/test.o", out_dir)).unwrap();
+    let mut builder = Builder::new(File::create(lib_file).unwrap());
+    builder.append_path(obj_file).unwrap();
 
     println!("cargo:rustc-link-lib=static=test");
-    println!("cargo:rustc-link-search=native={}", out_dir);
+    println!("cargo:rustc-link-search=native={}", out_dir.display());
 }
index 14bb7c2..a83078e 100644 (file)
@@ -22,13 +22,14 @@ extern crate ndarray;
 extern crate tvm_runtime;
 
 use ndarray::Array;
-use tvm_runtime::{DLTensor, Module, SystemLibModule};
+use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
+
+mod tvm_mod {
+    import_module!("../lib/test.o");
+}
 
 fn main() {
-    let syslib = SystemLibModule::default();
-    let add = syslib
-        .get_function("default_function")
-        .expect("main function not found");
+    // try static
     let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
     let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
     let mut c = Array::from_vec(vec![0f32; 4]);
@@ -36,6 +37,14 @@ fn main() {
     let mut a_dl: DLTensor = (&mut a).into();
     let mut b_dl: DLTensor = (&mut b).into();
     let mut c_dl: DLTensor = (&mut c).into();
+    call_packed!(tvm_mod::default_function, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
+    assert!(c.all_close(&e, 1e-8f32));
+
+    // try runtime
+    let syslib = SystemLibModule::default();
+    let add = syslib
+        .get_function("default_function")
+        .expect("main function not found");
     call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
     assert!(c.all_close(&e, 1e-8f32));
 }