[Rust] Fixes for wasm32 target (#5489)
authorMORITA Kazutaka <morita.kazutaka@gmail.com>
Sat, 2 May 2020 00:06:43 +0000 (09:06 +0900)
committerGitHub <noreply@github.com>
Sat, 2 May 2020 00:06:43 +0000 (17:06 -0700)
* [Rust] Fixes for wasm32 target

* [Rust] Add test for wasm32

* allow cargo config to be into repo

* Disable wasm tests in CI

15 files changed:
rust/Cargo.toml
rust/common/build.rs
rust/common/src/array.rs
rust/common/src/lib.rs
rust/runtime/src/array.rs
rust/runtime/src/graph.rs
rust/runtime/src/module/mod.rs
rust/runtime/src/threading.rs
rust/runtime/tests/test_wasm32/.cargo/config [new file with mode: 0644]
rust/runtime/tests/test_wasm32/Cargo.toml [new file with mode: 0644]
rust/runtime/tests/test_wasm32/build.rs [new file with mode: 0644]
rust/runtime/tests/test_wasm32/src/build_test_lib.py [new file with mode: 0755]
rust/runtime/tests/test_wasm32/src/main.rs [new file with mode: 0644]
tests/lint/check_file_type.py
tests/scripts/task_rust.sh

index 8467f6a..f08f861 100644 (file)
@@ -22,6 +22,7 @@ members = [
        "runtime",
        "runtime/tests/test_tvm_basic",
        "runtime/tests/test_tvm_dso",
+       "runtime/tests/test_wasm32",
        "runtime/tests/test_nn",
        "frontend",
        "frontend/tests/basics",
index b3ae7b6..07326f4 100644 (file)
@@ -51,6 +51,7 @@ fn main() {
         .layout_tests(false)
         .derive_partialeq(true)
         .derive_eq(true)
+        .derive_default(true)
         .generate()
         .expect("unable to generate bindings")
         .write_to_file(PathBuf::from("src/c_runtime_api.rs"))
index d0a66a6..a8f4f98 100644 (file)
@@ -133,6 +133,7 @@ macro_rules! impl_dltensor_from_ndarray {
                     shape: arr.shape().as_ptr() as *const i64 as *mut i64,
                     strides: arr.strides().as_ptr() as *const isize as *mut i64,
                     byte_offset: 0,
+                    ..Default::default()
                 }
             }
         }
index 2ae64e7..33b2993 100644 (file)
@@ -31,8 +31,13 @@ pub mod ffi {
 
     include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
 
-    pub type BackendPackedCFunc =
-        extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
+    pub type BackendPackedCFunc = extern "C" fn(
+        args: *const TVMValue,
+        type_codes: *const c_int,
+        num_args: c_int,
+        out_ret_value: *mut TVMValue,
+        out_ret_tcode: *mut u32,
+    ) -> c_int;
 }
 
 pub mod array;
index 2b6c7c2..c38b3ff 100644 (file)
@@ -297,6 +297,7 @@ impl<'a> Tensor<'a> {
                 self.strides.as_ref().unwrap().as_ptr()
             } as *mut i64,
             byte_offset: 0,
+            ..Default::default()
         }
     }
 }
index 518bf72..71541ba 100644 (file)
@@ -382,7 +382,18 @@ named! {
 // Converts a bytes to String.
 named! {
     name<String>,
-    map_res!(length_data!(le_u64), |b: &[u8]| String::from_utf8(b.to_vec()))
+    do_parse!(
+        len_l: le_u32 >>
+        len_h: le_u32 >>
+        data: take!(len_l) >>
+        (
+            if len_h == 0 {
+                String::from_utf8(data.to_vec()).unwrap()
+            } else {
+                panic!("Too long string")
+            }
+        )
+    )
 }
 
 // Parses a TVMContext
index 856dd78..cb4d777 100644 (file)
@@ -44,9 +44,17 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<
                 (val, code as i32)
             })
             .unzip();
-        let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
+        let ret: TVMRetValue = TVMRetValue::default();
+        let (mut ret_val, mut ret_type_code) = ret.to_tvm_value();
+        let exit_code = func(
+            values.as_ptr(),
+            type_codes.as_ptr(),
+            values.len() as i32,
+            &mut ret_val,
+            &mut ret_type_code,
+        );
         if exit_code == 0 {
-            Ok(TVMRetValue::default())
+            Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code))
         } else {
             Err(tvm_common::errors::FuncCallError::get_with_context(
                 func_name.clone(),
index f473bbf..b8be012 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 use std::{
-    env,
     os::raw::{c_int, c_void},
     sync::{
         atomic::{AtomicUsize, Ordering},
@@ -27,6 +26,9 @@ use std::{
     thread::{self, JoinHandle},
 };
 
+#[cfg(not(target_arch = "wasm32"))]
+use std::env;
+
 use crossbeam::channel::{bounded, Receiver, Sender};
 use tvm_common::ffi::TVMParallelGroupEnv;
 
@@ -147,7 +149,10 @@ impl ThreadPool {
 
     fn run_worker(queue: Receiver<Task>) {
         loop {
-            let task = queue.recv().expect("should recv");
+            let task = match queue.recv() {
+                Ok(v) => v,
+                Err(_) => break,
+            };
             let result = task.run();
             if result == <i32>::min_value() {
                 break;
diff --git a/rust/runtime/tests/test_wasm32/.cargo/config b/rust/runtime/tests/test_wasm32/.cargo/config
new file mode 100644 (file)
index 0000000..6b77899
--- /dev/null
@@ -0,0 +1,2 @@
+[build]
+target = "wasm32-wasi"
diff --git a/rust/runtime/tests/test_wasm32/Cargo.toml b/rust/runtime/tests/test_wasm32/Cargo.toml
new file mode 100644 (file)
index 0000000..1d3373a
--- /dev/null
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "test-wasm32"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["TVM Contributors"]
+
+[dependencies]
+ndarray="0.12"
+tvm-runtime = { path = "../../" }
diff --git a/rust/runtime/tests/test_wasm32/build.rs b/rust/runtime/tests/test_wasm32/build.rs
new file mode 100644 (file)
index 0000000..8b72be2
--- /dev/null
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{path::PathBuf, process::Command};
+
+fn main() {
+    let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
+    out_dir.push("lib");
+
+    if !out_dir.is_dir() {
+        std::fs::create_dir(&out_dir).unwrap();
+    }
+
+    let obj_file = out_dir.join("test.o");
+    let lib_file = out_dir.join("libtest_wasm32.a");
+
+    let output = Command::new(concat!(
+        env!("CARGO_MANIFEST_DIR"),
+        "/src/build_test_lib.py"
+    ))
+    .arg(&out_dir)
+    .output()
+    .expect("Failed to execute command");
+    assert!(
+        obj_file.exists(),
+        "Could not build tvm lib: {}",
+        String::from_utf8(output.stderr)
+            .unwrap()
+            .trim()
+            .split("\n")
+            .last()
+            .unwrap_or("")
+    );
+
+    let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8");
+    let output = Command::new(ar)
+        .arg("rcs")
+        .arg(&lib_file)
+        .arg(&obj_file)
+        .output()
+        .expect("Failed to execute command");
+    assert!(
+        lib_file.exists(),
+        "Could not create archive: {}",
+        String::from_utf8(output.stderr)
+            .unwrap()
+            .trim()
+            .split("\n")
+            .last()
+            .unwrap_or("")
+    );
+
+    println!("cargo:rustc-link-lib=static=test_wasm32");
+    println!("cargo:rustc-link-search=native={}", out_dir.display());
+}
diff --git a/rust/runtime/tests/test_wasm32/src/build_test_lib.py b/rust/runtime/tests/test_wasm32/src/build_test_lib.py
new file mode 100755 (executable)
index 0000000..6016c60
--- /dev/null
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Prepares a simple TVM library for testing."""
+
+from os import path as osp
+import sys
+
+import tvm
+from tvm import te
+
+def main():
+    n = te.var('n')
+    A = te.placeholder((n,), name='A')
+    B = te.placeholder((n,), name='B')
+    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    s = tvm.te.create_schedule(C.op)
+    s[C].parallel(s[C].op.axis[0])
+    print(tvm.lower(s, [A, B, C], simple_mode=True))
+    tvm.build(s, [A, B, C], 'llvm -target=wasm32-unknown-unknown --system-lib').save(osp.join(sys.argv[1], 'test.o'))
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/runtime/tests/test_wasm32/src/main.rs b/rust/runtime/tests/test_wasm32/src/main.rs
new file mode 100644 (file)
index 0000000..a46cfa9
--- /dev/null
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+extern "C" {
+    static __tvm_module_ctx: i32;
+}
+
+#[no_mangle]
+unsafe fn __get_tvm_module_ctx() -> i32 {
+    // Refer a symbol in the libtest_wasm32.a to make sure that the link of the
+    // library is not optimized out.
+    __tvm_module_ctx
+}
+
+extern crate ndarray;
+#[macro_use]
+extern crate tvm_runtime;
+
+use ndarray::Array;
+use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
+
+fn main() {
+    // try static
+    let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
+    let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
+    let mut c = Array::from_vec(vec![0f32; 4]);
+    let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
+    let mut a_dl: DLTensor = (&mut a).into();
+    let mut b_dl: DLTensor = (&mut b).into();
+    let mut c_dl: DLTensor = (&mut c).into();
+
+    let syslib = SystemLibModule::default();
+    let add = syslib
+        .get_function("default_function")
+        .expect("main function not found");
+    call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
+    assert!(c.all_close(&e, 1e-8f32));
+}
index 0ec0a2e..04d6c94 100644 (file)
@@ -103,7 +103,8 @@ ALLOW_SPECIFIC_FILE = {
     "KEYS",
     "DISCLAIMER",
     "Jenkinsfile",
-    # sgx config
+    # cargo config
+    "rust/runtime/tests/test_wasm32/.cargo/config",
     "apps/sgx/.cargo/config",
     # html for demo purposes
     "tests/webgl/test_static_webgl_library.html",
index fae07d3..5529632 100755 (executable)
@@ -54,6 +54,12 @@ cd tests/test_tvm_dso
 cargo run
 cd -
 
+# # run wasm32 test
+# cd tests/test_wasm32
+# cargo build
+# wasmtime $RUST_DIR/target/wasm32-wasi/debug/test-wasm32.wasm
+# cd -
+
 # run nn graph test
 cd tests/test_nn
 cargo run