From 702fd0f0f4ef283dedac3e65e71455ca4661a8e3 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 7 May 2020 13:47:36 -0700 Subject: [PATCH] [WEB][RUNTIME] TVM WebAssembly JS Runtime (#5506) * [WEB] Remove the old web runtime * [WEB][RUNTIME] TVM WebAssembly Runtime This PR introduces a brand new TVM web runtime based on the WASM standard API. Main highlights: - The new runtime is rewritten using the Typescript. - The new runtime now directly interfaces with WebAssembly's standard API, instead of relying on emscripten's API. This change will make the js runtime more portable to runtime variants. For example, we could also try to make it interface with the tvm's rust runtime implementation. - System library can be provided through WASI - We also build a hack to enable Emscripten to generate a WASI like bundle for runtime environment on the Web. - The wasm generation now uses the mainlin LLVM. - Dynamic link(dlopen) is not used due to limitation of wasm, instead we rely on the recent new RPC refactor to directly restart a new session for each wasm binary sent to the RPC. * Address review comments * Skip tensorcore test --- docs/api/python/contrib.rst | 6 +- python/tvm/_ffi/libinfo.py | 4 + python/tvm/contrib/{emscripten.py => emcc.py} | 48 +- python/tvm/exec/rpc_proxy.py | 11 +- tests/lint/check_file_type.py | 5 +- tests/lint/rat-excludes | 3 + tests/scripts/task_python_docs.sh | 4 - tests/web/test_packed_func.js | 72 -- tests/webgl/README.md | 24 - tests/webgl/test_local_gemm.py | 58 - tests/webgl/test_local_save_load.py | 53 - tests/webgl/test_local_topi_conv2d_nchw.py | 99 -- tests/webgl/test_local_topi_dense.py | 76 -- tests/webgl/test_local_topi_pooling.py | 132 -- tests/webgl/test_local_topi_softmax.py | 96 -- tests/webgl/test_remote_save_load.py | 96 -- tests/webgl/test_static_webgl_library.html | 72 -- tests/webgl/test_static_webgl_library.py | 66 - .../tests/python/test_topi_conv2d_nhwc_winograd.py | 8 +- web/.eslintignore | 1 + web/.gitignore | 6 + web/.jsdoc_conf.json | 7 - web/Makefile | 51 + web/README.md | 169 +-- web/apps/browser/rpc_server.html | 79 ++ web/apps/node/example.js | 37 + web/apps/node/wasi_example.js | 36 + .../node/wasi_rpc_server.js} | 31 +- .../emcc/decorate_as_wasi.py | 45 +- web/emcc/preload.js | 41 + web/emcc/tvmjs_support.cc | 193 +++ web/emcc/wasm_runtime.cc | 92 ++ web/example_rpc.html | 61 - web/package.json | 29 + web/{.eslintrc.js => rollup.config.js} | 41 +- web/src/ctypes.ts | 229 ++++ web/src/environment.ts | 146 +++ web/src/index.ts | 27 + web/src/memory.ts | 408 +++++++ web/src/rpc_server.ts | 379 ++++++ web/src/runtime.ts | 1113 +++++++++++++++++ web/src/support.ts | 64 + web/src/types.ts | 53 + {tests/web => web/tests/node}/test_module_load.js | 35 +- .../tests/node/test_ndarray.js | 52 +- web/tests/node/test_packed_func.js | 130 ++ .../web => web/tests/python}/prepare_test_libs.py | 21 +- .../web => web/tests/python}/websock_rpc_test.py | 64 +- web/tsconfig.json | 13 + web/tvm_runtime.js | 1274 -------------------- web/web_runtime.cc | 88 -- 51 files changed, 3358 insertions(+), 2590 deletions(-) rename python/tvm/contrib/{emscripten.py => emcc.py} (65%) delete mode 100644 tests/web/test_packed_func.js delete mode 100644 tests/webgl/README.md delete mode 100644 tests/webgl/test_local_gemm.py delete mode 100644 tests/webgl/test_local_save_load.py delete mode 100644 tests/webgl/test_local_topi_conv2d_nchw.py delete mode 100644 tests/webgl/test_local_topi_dense.py delete mode 100644 tests/webgl/test_local_topi_pooling.py delete mode 100644 tests/webgl/test_local_topi_softmax.py delete mode 100644 tests/webgl/test_remote_save_load.py delete mode 100644 tests/webgl/test_static_webgl_library.html delete mode 100644 tests/webgl/test_static_webgl_library.py create mode 100644 web/.eslintignore create mode 100644 web/.gitignore delete mode 100644 web/.jsdoc_conf.json create mode 100644 web/Makefile create mode 100644 web/apps/browser/rpc_server.html create mode 100644 web/apps/node/example.js create mode 100644 web/apps/node/wasi_example.js rename web/{example_rpc_node.js => apps/node/wasi_rpc_server.js} (60%) rename tests/webgl/test_local_multi_stage.py => web/emcc/decorate_as_wasi.py (50%) create mode 100644 web/emcc/preload.js create mode 100644 web/emcc/tvmjs_support.cc create mode 100644 web/emcc/wasm_runtime.cc delete mode 100644 web/example_rpc.html create mode 100644 web/package.json rename web/{.eslintrc.js => rollup.config.js} (69%) create mode 100644 web/src/ctypes.ts create mode 100644 web/src/environment.ts create mode 100644 web/src/index.ts create mode 100644 web/src/memory.ts create mode 100644 web/src/rpc_server.ts create mode 100644 web/src/runtime.ts create mode 100644 web/src/support.ts create mode 100644 web/src/types.ts rename {tests/web => web/tests/node}/test_module_load.js (64%) rename tests/web/test_basic.js => web/tests/node/test_ndarray.js (55%) create mode 100644 web/tests/node/test_packed_func.js rename {tests/web => web/tests/python}/prepare_test_libs.py (69%) rename {tests/web => web/tests/python}/websock_rpc_test.py (55%) create mode 100644 web/tsconfig.json delete mode 100644 web/tvm_runtime.js delete mode 100644 web/web_runtime.cc diff --git a/docs/api/python/contrib.rst b/docs/api/python/contrib.rst index b482d30..8ac4e1f 100644 --- a/docs/api/python/contrib.rst +++ b/docs/api/python/contrib.rst @@ -48,9 +48,9 @@ tvm.contrib.dlpack .. automodule:: tvm.contrib.dlpack :members: -tvm.contrib.emscripten -~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: tvm.contrib.emscripten +tvm.contrib.emcc +~~~~~~~~~~~~~~~~ +.. automodule:: tvm.contrib.emcc :members: tvm.contrib.miopen diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index 0d1a4e2..de8f7b5 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -88,6 +88,9 @@ def find_lib_path(name=None, search_path=None, optional=False): dll_path.append(install_lib_dir) + if os.path.isdir(source_dir): + dll_path.append(os.path.join(source_dir, "web", "dist", "wasm")) + dll_path = [os.path.realpath(x) for x in dll_path] if search_path is not None: if isinstance(search_path, list): @@ -154,6 +157,7 @@ def find_include_path(name=None, search_path=None, optional=False): ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) source_dir = os.path.join(ffi_dir, "..", "..", "..") install_include_dir = os.path.join(ffi_dir, "..", "..", "..", "..") + third_party_dir = os.path.join(source_dir, "3rdparty") header_path = [] diff --git a/python/tvm/contrib/emscripten.py b/python/tvm/contrib/emcc.py similarity index 65% rename from python/tvm/contrib/emscripten.py rename to python/tvm/contrib/emcc.py index 7f31273..6df205a 100644 --- a/python/tvm/contrib/emscripten.py +++ b/python/tvm/contrib/emcc.py @@ -16,18 +16,16 @@ # under the License. """Util to invoke emscripten compilers in the system.""" # pylint: disable=invalid-name -from __future__ import absolute_import as _abs - import subprocess -from .._ffi.base import py_str -from .._ffi.libinfo import find_lib_path +from tvm._ffi.base import py_str +from tvm._ffi.libinfo import find_lib_path + -def create_js(output, - objects, - options=None, - side_module=False, - cc="emcc"): - """Create emscripten javascript library. +def create_tvmjs_wasm(output, + objects, + options=None, + cc="emcc"): + """Create wasm that is supposed to run with the tvmjs. Parameters ---------- @@ -44,25 +42,27 @@ def create_js(output, The compile string. """ cmd = [cc] - cmd += ["-Oz"] - if not side_module: - cmd += ["-s", "RESERVED_FUNCTION_POINTERS=2"] - cmd += ["-s", "NO_EXIT_RUNTIME=1"] - extra_methods = ['cwrap', 'getValue', 'setValue', 'addFunction'] - cfg = "[" + (','.join("\'%s\'" % x for x in extra_methods)) + "]" - cmd += ["-s", "EXTRA_EXPORTED_RUNTIME_METHODS=" + cfg] - else: - cmd += ["-s", "SIDE_MODULE=1"] - cmd += ["-o", output] + cmd += ["-O3"] + + cmd += ["-std=c++14"] + cmd += ["-s", "ERROR_ON_UNDEFINED_SYMBOLS=0"] + cmd += ["-s", "STANDALONE_WASM=1"] + cmd += ["-s", "ALLOW_MEMORY_GROWTH=1"] + + objects = [objects] if isinstance(objects, str) else objects + with_runtime = False for obj in objects: - if obj.find("libtvm_web_runtime.bc") != -1: + if obj.find("wasm_runtime.bc") != -1: with_runtime = True - if not with_runtime and not side_module: - objects += [find_lib_path("libtvm_web_runtime.bc")[0]] + if not with_runtime: + objects += [find_lib_path("wasm_runtime.bc")[0]] + objects += [find_lib_path("tvmjs_support.bc")[0]] + + cmd += ["-o", output] cmd += objects if options: @@ -79,4 +79,4 @@ def create_js(output, msg += py_str(out) raise RuntimeError(msg) -create_js.object_format = "bc" +create_tvmjs_wasm.object_format = "bc" diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py index 4cf3413..59da8fa 100644 --- a/python/tvm/exec/rpc_proxy.py +++ b/python/tvm/exec/rpc_proxy.py @@ -29,12 +29,11 @@ from ..rpc.proxy import Proxy def find_example_resource(): """Find resource examples.""" curr_path = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - base_path = os.path.join(curr_path, "../../../") - index_page = os.path.join(base_path, "web/example_rpc.html") + base_path = os.path.abspath(os.path.join(curr_path, "..", "..", "..")) + index_page = os.path.join(base_path, "web", "apps", "browser", "rpc_server.html") js_files = [ - os.path.join(base_path, "web/tvm_runtime.js"), - os.path.join(base_path, "build/libtvm_web_runtime.js"), - os.path.join(base_path, "build/libtvm_web_runtime.js.mem") + os.path.join(base_path, "web/dist/tvmjs.bundle.js"), + os.path.join(base_path, "web/dist/wasm/tvmjs_runtime.wasi.js") ] for fname in [index_page] + js_files: if not os.path.exists(fname): @@ -69,7 +68,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--host', type=str, default="0.0.0.0", + parser.add_argument('--host', type=str, default="localhost", help='the hostname of the server') parser.add_argument('--port', type=int, default=9090, help='The port of the RPC') diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index 04d6c94..da3a456 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -36,6 +36,7 @@ ALLOW_EXTENSION = { "scala", "java", "go", + "ts", "sh", "py", "pyi", @@ -81,6 +82,7 @@ ALLOW_EXTENSION = { # List of file names allowed ALLOW_FILE_NAME = { ".gitignore", + ".eslintignore", ".gitattributes", "README", "Makefile", @@ -107,8 +109,7 @@ ALLOW_SPECIFIC_FILE = { "rust/runtime/tests/test_wasm32/.cargo/config", "apps/sgx/.cargo/config", # html for demo purposes - "tests/webgl/test_static_webgl_library.html", - "web/example_rpc.html", + "web/apps/browser/rpc_server.html", # images are normally not allowed # discuss with committers before add more images "apps/android_rpc/app/src/main/res/mipmap-hdpi/ic_launcher.png", diff --git a/tests/lint/rat-excludes b/tests/lint/rat-excludes index 5421d22..0714850 100644 --- a/tests/lint/rat-excludes +++ b/tests/lint/rat-excludes @@ -28,6 +28,8 @@ core.cpp build _static _build +node_modules +dist .*~ \#..*\# \.#.* @@ -40,6 +42,7 @@ RelayVisitor.py # Specific files package-list MANIFEST +.eslintignore .gitignore .gitattributes .gitmodules diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 819961d..41006f4 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -43,9 +43,6 @@ cd .. make doc rm -f docs/doxygen/html/*.map docs/doxygen/html/*.md5 -# JS doc -jsdoc -c web/.jsdoc_conf.json web/tvm_runtime.js web/README.md - # Java doc make javadoc @@ -54,7 +51,6 @@ rm -rf _docs mv docs/_build/html _docs rm -f _docs/.buildinfo mv docs/doxygen/html _docs/doxygen -mv out _docs/jsdoc mv jvm/core/target/site/apidocs _docs/javadoc echo "Start creating the docs tarball.." diff --git a/tests/web/test_packed_func.js b/tests/web/test_packed_func.js deleted file mode 100644 index d239f73..0000000 --- a/tests/web/test_packed_func.js +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -// Load Emscripten Module, need to change path to root/build -const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); - -function testGetGlobal() { - var targs = [10, 10.0, "hello"] - tvm.registerFunc("my_packed_func", function () { - tvm.assert(Array.from(arguments).toString() == targs, "assert fail"); - return 10 - }); - var f = tvm.getGlobalFunc("my_packed_func") - tvm.assert(tvm.isPackedFunc(f)); - y = f.apply(null, targs); - tvm.assert(y == 10); - f.release(); -} - - -function testReturnFunc() { - function addy(y) { - function add(x) { - return x + y; - } - return add; - } - var myf = tvm.convertFunc(addy); - var f = myf(10); - tvm.assert(tvm.isPackedFunc(f)); - tvm.assert(f(11) == 21); - myf.release(); - f.release(); -} - -function testByteArray() { - var a = new Uint8Array(3); - a[0] = 1; - a[1] = 2; - function myfunc(ss){ - tvm.assert(ss instanceof Uint8Array); - tvm.assert(ss.toString() == a); - } - f = tvm.convertFunc(myfunc); - f(a); - f.release(); -} - -testGetGlobal(); -testReturnFunc(); -testByteArray(); diff --git a/tests/webgl/README.md b/tests/webgl/README.md deleted file mode 100644 index 5303cc0..0000000 --- a/tests/webgl/README.md +++ /dev/null @@ -1,24 +0,0 @@ - - - - - - - - - - - - - - - - - -## Test cases for the WebGL backend - -Any test case with name `test_local_...` tests the C++ OpenGL backend on the -local OS, which can be executed automatically. - -Any test case with name `test_remote_...` tests the WebGL backend within the -browser, which must be run manually. See instruction within the test. diff --git a/tests/webgl/test_local_gemm.py b/tests/webgl/test_local_gemm.py deleted file mode 100644 index 6bd22bf..0000000 --- a/tests/webgl/test_local_gemm.py +++ /dev/null @@ -1,58 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te -import numpy as np - -def test_local_gemm(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - nn = 1024 - n = te.var('n') - n = tvm.runtime.convert(nn) - m = n - l = n - A = te.placeholder((n, l), name='A', dtype='int32') - B = te.placeholder((m, l), name='B', dtype='int32') - k = te.reduce_axis((0, l), name='k') - C = te.compute((n, m), lambda ii, jj: te.sum(A[ii, k] * B[jj, k], axis=k), - name='CC') - - s = te.create_schedule(C.op) - s[C].opengl() - print(tvm.lower(s, [A, B, C], simple_mode=True)) - - f = tvm.build(s, [A, B, C], "opengl", name="gemm") - print("------opengl code------") - print(f.imported_modules[0].get_source(fmt="gl")) - - ctx = tvm.opengl() - n, m, l = nn, nn, nn - a_np = np.random.uniform(low=0, high=10, size=(n, l)).astype(A.dtype) - b_np = np.random.uniform(low=0, high=10, size=(m, l)).astype(B.dtype) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) - f(a, b, c) - - tvm.testing.assert_allclose(c.asnumpy(), np.dot(a_np, b_np.T)) - -if __name__ == "__main__": - test_local_gemm() diff --git a/tests/webgl/test_local_save_load.py b/tests/webgl/test_local_save_load.py deleted file mode 100644 index cca6802..0000000 --- a/tests/webgl/test_local_save_load.py +++ /dev/null @@ -1,53 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import numpy as np -import tvm -from tvm import te -from tvm import rpc -from tvm.contrib import util, emscripten - -def test_local_save_load(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - n = te.var("n") - A = te.placeholder((n,), name='A', dtype='int32') - B = te.placeholder((n,), name='B', dtype='int32') - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) - s[C].opengl() - - f = tvm.build(s, [A, B, C], "opengl", target_host="llvm", name="myadd") - - ctx = tvm.opengl(0) - n = 10 - a = tvm.nd.array(np.random.uniform(high=10, size=(n)).astype(A.dtype), ctx) - b = tvm.nd.array(np.random.uniform(high=10, size=(n)).astype(B.dtype), ctx) - c = tvm.nd.array(np.zeros((n), dtype=C.dtype), ctx) - f(a, b, c) - - temp = util.tempdir() - path_so = temp.relpath("myadd.so") - f.export_library(path_so) - f1 = tvm.runtime.load_module(path_so) - f1(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) - -if __name__ == "__main__": - test_local_save_load() diff --git a/tests/webgl/test_local_topi_conv2d_nchw.py b/tests/webgl/test_local_topi_conv2d_nchw.py deleted file mode 100644 index 0d9b777..0000000 --- a/tests/webgl/test_local_topi_conv2d_nchw.py +++ /dev/null @@ -1,99 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Example code to do convolution. -Copied from topi/tests/python/test_topi_conv2d_nchw.py. -Should be removed once we fix OpenGL testing on Jenkins.""" -import os -import numpy as np -import tvm -from tvm import te -import topi -from tvm.contrib.pickle_memoize import memoize -from topi.util import get_const_tuple - -def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding): - in_height = in_width = in_size - - A = te.placeholder((batch, in_channel, in_height, in_width), name='A') - W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W') - B = topi.nn.conv2d_nchw(A, W, stride, padding) - C = topi.nn.relu(B) - - a_shape = get_const_tuple(A.shape) - w_shape = get_const_tuple(W.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d.verify_con2d_nchw") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s1 = topi.generic.schedule_conv2d_nchw([B]) - s2 = topi.generic.schedule_conv2d_nchw([C]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) - with tvm.target.build_config(auto_unroll_max_step=1400, - unroll_explicit=(device != "cuda")): - func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) - func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) - func1(a, w, b) - func2(a, w, c) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - - -def test_conv2d_nchw(): - # ResNet18 worklaods - verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0) - verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1) - verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0) - verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1) - verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1) - verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0) - verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1) - verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) - verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) - verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) - # Vgg16 workloads - verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1) - # Super resolution workloads - verify_conv2d_nchw(1, 1, 224, 64, 5, 1, 2) - verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1) - verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1) - -if __name__ == "__main__": - test_conv2d_nchw() diff --git a/tests/webgl/test_local_topi_dense.py b/tests/webgl/test_local_topi_dense.py deleted file mode 100644 index 60dfe1f..0000000 --- a/tests/webgl/test_local_topi_dense.py +++ /dev/null @@ -1,76 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for dense operator -Copied from topi/tests/python/test_topi_dense.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" -import numpy as np -import tvm -from tvm import te -import topi -from topi.util import get_const_tuple -from tvm.contrib.pickle_memoize import memoize - - -def verify_dense(batch, in_dim, out_dim, use_bias=True): - A = te.placeholder((batch, in_dim), name='A') - B = te.placeholder((out_dim, in_dim), name='B') - C = te.placeholder((out_dim,), name='C') - D = topi.nn.dense(A, B, C if use_bias else None) - D = topi.nn.relu(D) - dtype = A.dtype - - # use memoize to pickle the test data for next time use - @memoize("topi.tests.test_topi_dense") - def get_ref_data(): - a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype) - b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype) - c_np = np.random.uniform(size=(out_dim,)).astype(dtype) - if use_bias: - d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0) - else: - d_np = np.maximum(np.dot(a_np, b_np.T), 0.0) - return (a_np, b_np, c_np, d_np) - # get the test data - a_np, b_np, c_np, d_np = get_ref_data() - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_dense(D) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(c_np, ctx) - d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx) - f = tvm.build(s, [A, B, C, D], device, name="dense") - f(a, b, c, d) - tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_dense(): - verify_dense(1, 1024, 1000, use_bias=True) - verify_dense(1, 1024, 1000, use_bias=False) - - -if __name__ == "__main__": - test_dense() diff --git a/tests/webgl/test_local_topi_pooling.py b/tests/webgl/test_local_topi_pooling.py deleted file mode 100644 index 3adae7b..0000000 --- a/tests/webgl/test_local_topi_pooling.py +++ /dev/null @@ -1,132 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for pooling -Copied from topi/tests/python/test_topi_pooling.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" -import numpy as np -import tvm -from tvm import te -import topi -import math -from topi.util import get_const_tuple - -def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): - iw = ih - kw = kh - sw = sh - ph, pw = padding - A = te.placeholder((n, ic, ih, iw), name='A') - B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, - pool_type=pool_type, ceil_mode=ceil_mode) - B = topi.nn.relu(B) - dtype = A.dtype - - bshape = get_const_tuple(B.shape) - ashape = get_const_tuple(A.shape) - if ceil_mode: - assert bshape[2] == int(math.ceil(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pw * 2) / sw) + 1) - else: - assert bshape[2] == int(math.floor(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1) - - - a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype) - pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype) - no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw))) - pad_np[np.ix_(*no_zero)] = a_np - _, oc, oh, ow = get_const_tuple(B.shape) - b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype) - - if pool_type == 'avg': - for i in range(oh): - for j in range(ow): - b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) - elif pool_type =='max': - for i in range(oh): - for j in range(ow): - b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) - b_np = np.maximum(b_np, 0.0) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_pool(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) - print(tvm.lower(s, [A, B], simple_mode=True)) - - f = tvm.build(s, [A, B], device) - f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_pool(): - verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False) - verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False) - verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True) - - - -def verify_global_pool(n, c, h, w, pool_type): - A = te.placeholder((n, c, h, w), name='A') - B = topi.nn.global_pool(A, pool_type=pool_type) - B = topi.nn.relu(B) - - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - if pool_type == 'avg': - b_np = np.mean(a_np, axis=(2,3), keepdims=True) - elif pool_type =='max': - b_np = np.max(a_np, axis=(2,3), keepdims=True) - b_np = np.maximum(b_np, 0.0) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_global_pool(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - f = tvm.build(s, [A, B], device) - f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_global_pool(): - verify_global_pool(1, 1024, 7, 7, 'avg') - verify_global_pool(4, 1024, 7, 7, 'avg') - verify_global_pool(1, 1024, 7, 7, 'max') - verify_global_pool(4, 1024, 7, 7, 'max') - - -if __name__ == "__main__": - test_pool() - test_global_pool() diff --git a/tests/webgl/test_local_topi_softmax.py b/tests/webgl/test_local_topi_softmax.py deleted file mode 100644 index c0ddbf2..0000000 --- a/tests/webgl/test_local_topi_softmax.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for softmax -Copied from topi/tests/python/test_topi_softmax.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" - -import os -import numpy as np -import tvm -from tvm import te -import topi -import logging -from topi.util import get_const_tuple - -def verify_softmax(m, n): - A = te.placeholder((m, n), name='A') - B = topi.nn.softmax(A) - # confirm lower works - s = te.create_schedule([B.op]) - tvm.lower(s, [A, B], simple_mode=True) - - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = topi.testing.softmax_python(a_np) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_softmax(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - foo = tvm.build(s, [A, B], device, name="softmax") - foo(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ["opengl"]: - check_device(device) - -def test_softmax(): - verify_softmax(32, 10) - verify_softmax(3, 4) - - -def verify_log_softmax(m, n): - A = te.placeholder((m, n), name='A') - B = topi.nn.log_softmax(A) - # confirm lower works - s = te.create_schedule([B.op]) - tvm.lower(s, [A, B], simple_mode=True) - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = topi.testing.log_softmax_python(a_np) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_softmax(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - foo = tvm.build(s, [A, B], device, name="log_softmax") - foo(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ["opengl"]: - check_device(device) - - -def test_log_softmax(): - verify_log_softmax(32, 10) - verify_log_softmax(3, 4) - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - test_softmax() - test_log_softmax() diff --git a/tests/webgl/test_remote_save_load.py b/tests/webgl/test_remote_save_load.py deleted file mode 100644 index 34bbb3f..0000000 --- a/tests/webgl/test_remote_save_load.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -The following instruction is based on web/README.md. - -Setup an RPC server: -$ python -m tvm.exec.rpc_proxy --example-rpc=1 - -Go to http://localhost:9190 in browser. - -Click "Connect To Proxy". - -Run this test script: -$ python tests/webgl/test_remote_save_load.py -""" - -import numpy as np -import tvm -from tvm import te -from tvm import rpc -from tvm.contrib import util, emscripten - -proxy_host = "localhost" -proxy_port = 9090 - -def try_remote_save_load(): - if not tvm.runtime.enabled("rpc"): - return - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - # Build the module. - n = te.var("n") - A = te.placeholder((n,), name='A') - B = te.placeholder((n,), name='B') - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) - s[C].opengl() - target_host = "llvm -target=asmjs-unknown-emscripten -system-lib" - f = tvm.build(s, [A, B, C], "opengl", target_host=target_host, name="myadd") - - remote = rpc.connect(proxy_host, proxy_port, key="js") - - temp = util.tempdir() - ctx = remote.opengl(0) - path_obj = temp.relpath("myadd.bc") - path_dso = temp.relpath("myadd.js") - path_gl = temp.relpath("myadd.gl") - path_json = temp.relpath("myadd.tvm_meta.json") - - f.save(path_obj) - emscripten.create_js(path_dso, path_obj, side_module=True) - f.imported_modules[0].save(path_gl) - - remote.upload(path_dso, "myadd.dso") - remote.upload(path_gl) - remote.upload(path_json) - - remote.download("myadd.dso") - remote.download("myadd.gl") - remote.download("myadd.tvm_meta.json") - - print('Loading myadd.dso') - fhost = remote.load_module("myadd.dso") - - print('Loading myadd.gl') - fdev = remote.load_module("myadd.gl") - - print('import_module') - fhost.import_module(fdev) - - print('running...') - a = tvm.nd.array(np.random.uniform(size=16).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(16, dtype=A.dtype), ctx) - c = tvm.nd.array(np.zeros(16, dtype=C.dtype), ctx) - fhost(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) - -if __name__ == "__main__": - try_remote_save_load() diff --git a/tests/webgl/test_static_webgl_library.html b/tests/webgl/test_static_webgl_library.html deleted file mode 100644 index f9268c6..0000000 --- a/tests/webgl/test_static_webgl_library.html +++ /dev/null @@ -1,72 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - TVM RPC Test Page - - - -

TVM Test Page

-
- - - - - - - - \ No newline at end of file diff --git a/tests/webgl/test_static_webgl_library.py b/tests/webgl/test_static_webgl_library.py deleted file mode 100644 index 929da4c..0000000 --- a/tests/webgl/test_static_webgl_library.py +++ /dev/null @@ -1,66 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Create a static WebGL library and run it in the browser.""" - -from __future__ import absolute_import, print_function - -import os, shutil, SimpleHTTPServer, SocketServer -import tvm -from tvm import te -from tvm.contrib import emscripten, util -import numpy as np - -def try_static_webgl_library(): - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - - # Change to lib/ which contains "libtvm_runtime.bc". - os.chdir(os.path.join(curr_path, "../../lib")) - - # Create OpenGL module. - n = te.var("n") - A = te.placeholder((n,), name='A', dtype="float") - B = te.compute((n,), lambda *i: A[i], name="B") - - s = te.create_schedule(B.op) - s[B].opengl() - - target_host = "llvm -target=asmjs-unknown-emscripten -system-lib" - f = tvm.build(s, [A, B], name="identity", target="opengl", - target_host=target_host) - - # Create a JS library that contains both the module and the tvm runtime. - path_dso = "identity_static.js" - f.export_library(path_dso, emscripten.create_js, options=[ - "-s", "USE_GLFW=3", - "-s", "USE_WEBGL2=1", - "-lglfw", - ]) - - # Create "tvm_runtime.js" and "identity_static.html" in lib/ - shutil.copyfile(os.path.join(curr_path, "../../web/tvm_runtime.js"), - "tvm_runtime.js") - shutil.copyfile(os.path.join(curr_path, "test_static_webgl_library.html"), - "identity_static.html") - - port = 8080 - handler = SimpleHTTPServer.SimpleHTTPRequestHandler - httpd = SocketServer.TCPServer(("", port), handler) - print("Please open http://localhost:" + str(port) + "/identity_static.html") - httpd.serve_forever() - -if __name__ == "__main__": - try_static_webgl_library() diff --git a/topi/tests/python/test_topi_conv2d_nhwc_winograd.py b/topi/tests/python/test_topi_conv2d_nhwc_winograd.py index 45f0599..a7e5532 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc_winograd.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_winograd.py @@ -137,7 +137,8 @@ def test_conv2d_nhwc_winograd_direct(): def test_conv2d_nhwc_winograd_tensorcore(): """Test the conv2d with winograd for nhwc layout""" - print("test_winograd_tensorcore...") + if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): + return verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1, bgemm="tensorcore") verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1, bgemm="tensorcore") verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1, bgemm="tensorcore") @@ -145,8 +146,7 @@ def test_conv2d_nhwc_winograd_tensorcore(): verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, (1, 1), add_relu=True, bgemm="tensorcore") verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, "SAME", add_relu=True, bgemm="tensorcore") + if __name__ == "__main__": test_conv2d_nhwc_winograd_direct() - - if nvcc.have_tensorcore(tvm.gpu(0).compute_version): - test_conv2d_nhwc_winograd_tensorcore() + test_conv2d_nhwc_winograd_tensorcore() diff --git a/web/.eslintignore b/web/.eslintignore new file mode 100644 index 0000000..1521c8b --- /dev/null +++ b/web/.eslintignore @@ -0,0 +1 @@ +dist diff --git a/web/.gitignore b/web/.gitignore new file mode 100644 index 0000000..a3135cf --- /dev/null +++ b/web/.gitignore @@ -0,0 +1,6 @@ +.vscode +*~ +out +node_modules +package-lock.json +build diff --git a/web/.jsdoc_conf.json b/web/.jsdoc_conf.json deleted file mode 100644 index 33783b3..0000000 --- a/web/.jsdoc_conf.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "templates": { - "default": { - "includeDate": false - } - } -} diff --git a/web/Makefile b/web/Makefile new file mode 100644 index 0000000..be7fa19 --- /dev/null +++ b/web/Makefile @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +TVM_ROOT=$(shell cd ..; pwd) + +INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\ + -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include + +.PHONY: clean all + +all: dist/wasm/tvmjs_runtime.wasm dist/wasm/tvmjs_runtime.wasi.js + +EMCC = emcc + +EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++14 -Wno-ignored-attributes \ + -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1 -s ERROR_ON_UNDEFINED_SYMBOLS=0 + +EMCC_LDFLAGS = --pre-js emcc/preload.js + +dist/wasm/%.bc: emcc/%.cc + @mkdir -p $(@D) + $(EMCC) $(EMCC_CFLAGS) -c -MM -MT dist/wasm/$*.bc $< >dist/wasm/$*.d + $(EMCC) $(EMCC_CFLAGS) -c -o dist/wasm/$*.bc $< + + +dist/wasm/tvmjs_runtime.wasm: dist/wasm/wasm_runtime.bc dist/wasm/tvmjs_support.bc + @mkdir -p $(@D) + $(EMCC) $(EMCC_CFLAGS) -o dist/wasm/tvmjs_runtime.js $+ $(EMCC_LDFLAGS) + + +dist/wasm/tvmjs_runtime.wasi.js: dist/wasm/tvmjs_runtime.wasm emcc/decorate_as_wasi.py + python3 emcc/decorate_as_wasi.py dist/wasm/tvmjs_runtime.js $@ + +clean: + @rm -rf dist/wasm + +-include dist/wasm/*.d diff --git a/web/README.md b/web/README.md index 5dfd691..66a64a3 100644 --- a/web/README.md +++ b/web/README.md @@ -15,163 +15,70 @@ -# TVM WebAssembly and Javascript Backend +# TVM WebAssembly Runtime -This folder contains TVM WebAssembly and Javascript backend through Emscripten. +This folder contains TVM WebAssembly Runtime. ## Installation -While the LLVM main branch support webassembly as a target. We still need a good runtime with libc and other -system library support. Emscripten toolchain offers that nicely. The general idea is to build TVM against -the fastcomp LLVM backend in the Emscripten project and allow us to generate ```asmjs-unknown-emscripten``` -as a backend target. + +The LLVM main branch support webassembly as a target, we can directly +build TVM with LLVM mainline to generate wasm modules. +Note that, however, we still need emscripten to compile the runtime and provide system library support. + +Note that so far we requires everything to be in the source and setup PYTHONPATH(instead of use setup.py install). ### Setup Emscripten -Checkout [Emscripten Portable SDK Downloads](https://kripken.github.io/emscripten-site/docs/getting_started/downloads.html) -to download emsdk-portable and unzip it on a local folder. Follow the installation guide from emscripten document. -```bash -./emsdk update -./emsdk install latest -./emsdk activate latest -``` +We use emscripten to compile our runtime wasm library as well as a WASI variant that we can deploy +to the browser environment. -Because we need to compile against the LLVM backend of emscripten, we will need the source and llvm library. -Which can be installed via following command. +Follow [Emscripten](https://emscripten.org/) to download emsdk and install emcc on your local environment. -```bash -./emsdk install clang-incoming-64bit -./emsdk activate clang-incoming-64bit -``` +### Build TVM Wasm Runtime -### Setup Environment Variable +After the emcc is setup correctly. We can build tvm's wasm runtime by typing `make` in the web folder. -In normal setting, we can setup the necessary environment variable with the following command. ```bash -source /path-to-emsdk-portable/emsdk_env.sh +make ``` -However, this will put emscripten's clang and llvm path ahead of the current system path. -What you can do is to set the path manually, by putting emscripten's path after the PATH like the following ones. -You can get the detailed path by type ```./emsdk activate``` -```bash -export PATH=${PATH}:/emsdk-related-path-here +This command will create the follow files: +- `dist/wasm/libtvm_runtime.bc` bitcode library `tvm.contrib.emcc` will link into. +- `dist/wasm/tvmjs_runtime.wasm` a standalone wasm runtime for testing purposes. +- `dist/wasm/tvmjs_runtime.wasi.js` a WASI compatible library generated by emscripten that can be fed into runtime. -``` -### Build TVM with Fastcomp LLVM +### Build TVM Wasm JS Frontend -To build TVM with Emscripten's Fastcomp LLVM, we can modify the LLVM_CONFIG in ```config.mk``` -to point to fastcomp's llvm-config and build TVM normally. +Type the following command in the web folder. ```bash -LLVM_CONFIG = /path/to/emsdk-portable/clang/fastcomp/build_incoming_64/bin/llvm-config +npm run bundle ``` -### Build TVM Web Runtime +This command will create the tvmjs library that we can use to interface with the wasm runtime. -The above command gives us the TVM compiling environment. Now we need to build runtime, -to do so, make sure we set the environment correctly as in previous section and type -```bash -make web -``` +## Use TVM to Generate Wasm Library and Run it -This will create ```build/libtvm_web_runtime.bc``` and ```build/libtvm_web_runtime.js```. - -## Use TVM to Generate Javascript Library - -The general idea is to use TVM as normally and set target to be ```llvm -target=asmjs-unknown-emscripten -system-lib```. - -The following code snippet from [tests/web/prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/tests/web/prepare_test_libs.py) demonstrate -the compilation process. - -```python -import tvm -from tvm import te -from tvm.contrib import emscripten -import os -def prepare_test_libs(base_path): - target = "llvm -target=asmjs-unknown-emscripten -system-lib" - if not tvm.runtime.enabled(target): - raise RuntimeError("Target %s is not enbaled" % target) - n = te.var("n") - A = te.placeholder((n,), name='A') - B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') - s = te.create_schedule(B.op) - fadd1 = tvm.build(s, [A, B], target, name="add_one") - obj_path = os.path.join(base_path, "test_add_one.bc") - fadd1.save(obj_path) - emscripten.create_js(os.path.join(base_path, "test_module.js"), obj_path) - -if __name__ == "__main__": - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - prepare_test_libs(os.path.join(curr_path, "../../build")) -``` +Check code snippet in -In this workflow, we use TVM to generate a ```.bc``` file and statically link -that with the ```build/libtvm_web_runtime.bc```(emscripten.create_js will help you do that). -The result js library is a library that contains both TVM runtime and the compiled function. - - -## Run the Generated Library - -The following code snippet from [tests/web/test_module_load.js](https://github.com/apache/incubator-tvm/tree/master/tests/web/test_module_load.js) demonstrate -how to run the compiled library. - -```js -// Load Emscripten Module, need to change path to root/build -const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/test_module.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); - -// Load system library, the compiled function is registered in sysLib. -var sysLib = tvm.systemLib(); - -function randomArray(length, max) { - return Array.apply(null, Array(length)).map(function() { - return Math.random() * max; - }); -} - -function testAddOne() { - // grab pre-loaded function - var faddOne = sysLib.getFunction("add_one"); - var assert = require('assert'); - tvm.assert(tvm.isPackedFunc(faddOne)); - var n = 124; - var A = tvm.empty(n).copyFrom(randomArray(n, 1)); - var B = tvm.empty(n); - // call the function. - faddOne(A, B); - AA = A.asArray(); // retrieve values in js array - BB = B.asArray(); // retrieve values in js array - // verify - for (var i = 0; i < BB.length; ++i) { - assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5); - } - faddOne.release(); -} - -testAddOne(); -sysLib.release(); -console.log("Finish verifying test_module_load"); -``` +- [tests/python/prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/web/tests/pythob/prepare_test_libs.py) + shows how to create a wasm library that links with tvm runtime. + - Note that all wasm libraries have to created using the `--system-lib` option + - emcc.create_wasm will automatically link the runtime library `dist/wasm/libtvm_runtime.bc` +- [tests/web/test_module_load.js](https://github.com/apache/incubator-tvm/tree/master/web/tests/node/test_module_load.js) demonstrate + how to run the generated library through tvmjs API. -Current example supports static linking, which is the preferred way to get more efficiency -in javascript backend. -## Proxy based RPC +## Run Wasm Remotely through WebSocket RPC. -We can now use javascript end to start an RPC server and connect to it from python side, +We can now use js side to start an RPC server and connect to it from python side, making the testing flow easier. -The following is an example to reproduce this. This requires everything to be in the git source and setup PYTHONPATH(instead of use setup.py install) -- run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy. -- Open broswer, goto the server webpage click Connect to proxy. - - Alternatively run "node web/example_rpc_node.js" -- run "python tests/web/websock_rpc_test.py" to run the rpc client. - -The general idea is to use Emscripten's dynamic linking to dynamically load modules. +The following is an example to reproduce this. +- run `python -m tvm.exec.rpc_proxy --example-rpc=1` to start proxy. +- Start the WebSocket RPC + - Browswer version: open https://localhost:8888, click connect to proxy + - NodeJS version: `npm run rpc` +- run `python tests/node/websock_rpc_test.py` to run the rpc client. diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html new file mode 100644 index 0000000..22907f1 --- /dev/null +++ b/web/apps/browser/rpc_server.html @@ -0,0 +1,79 @@ + + + + + + + + + + + + + + + + + + + + TVM RPC Test Page + + + + + +

TVM WebSocket RPC Server

+ To use this page + + +

Options

+ Proxy URL
+ RPC Server Key
+ + +
+ + + diff --git a/web/apps/node/example.js b/web/apps/node/example.js new file mode 100644 index 0000000..f81a9c9 --- /dev/null +++ b/web/apps/node/example.js @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Example code to start the runtime. + */ +const path = require("path"); +const fs = require("fs"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); +// Here we pass the javascript module generated by emscripten as the +// LibraryProvider to provide WASI related libraries. +// the async version of the API. +tvmjs.instantiate(wasmSource, new EmccWASI()) +.then((tvm) => { + // List all the global functions from the runtime. + console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames()); +}); + diff --git a/web/apps/node/wasi_example.js b/web/apps/node/wasi_example.js new file mode 100644 index 0000000..95ec2e0 --- /dev/null +++ b/web/apps/node/wasi_example.js @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Example code to start the runtime. + */ +const { WASI } = require('wasi'); +const path = require("path"); +const fs = require("fs"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +const wasi = new WASI({ args: process.argv, env: process.env }); +// Here we pass the javascript module generated by emscripten as the +// LibraryProvider to provide WASI related libraries. +const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), wasi); + +// List all the global functions from the runtime. +console.log("Runtime using WASI\n", tvm.listGlobalFuncNames()); diff --git a/web/example_rpc_node.js b/web/apps/node/wasi_rpc_server.js similarity index 60% rename from web/example_rpc_node.js rename to web/apps/node/wasi_rpc_server.js index 45f917a..eb4c6ed 100644 --- a/web/example_rpc_node.js +++ b/web/apps/node/wasi_rpc_server.js @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,17 +17,20 @@ * under the License. */ -// Javascript RPC server example -// Start and connect to websocket proxy. +/** + * Example code to start the RPC server on nodejs using WASI + */ +const { WASI } = require("wasi"); +const tvmjs = require("../../dist"); + +// Get import returns a fresh library in each call. +const getImports = () => { + return new WASI({ + args: process.argv, + env: process.env + }); +}; -// Load Emscripten Module, need to change path to root/lib -const path = require("path"); -process.chdir(path.join(__dirname, "../lib")); -var Module = require("../lib/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); +const proxyUrl = "ws://localhost:8888/ws"; -var websock_proxy = "ws://localhost:9190/ws"; -var num_sess = 100; -tvm.startRPCServer(websock_proxy, "js", num_sess) +new tvmjs.RPCServer(proxyUrl, "wasm", getImports, console.log); diff --git a/tests/webgl/test_local_multi_stage.py b/web/emcc/decorate_as_wasi.py similarity index 50% rename from tests/webgl/test_local_multi_stage.py rename to web/emcc/decorate_as_wasi.py index 54a554b..741e33b 100644 --- a/tests/webgl/test_local_multi_stage.py +++ b/web/emcc/decorate_as_wasi.py @@ -14,34 +14,29 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te -import numpy as np +"""Decorate emcc generated js to a WASI compatible API.""" -def test_local_multi_stage(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return +import sys - n = te.var("n") - A = te.placeholder((n,), name='A', dtype="int32") - B = te.compute((n,), lambda i: A[i] + 1, name="B") - C = te.compute((n,), lambda i: B[i] * 2, name="C") +template_head = """ +function EmccWASI() { +""" - s = te.create_schedule(C.op) - s[B].opengl() - s[C].opengl() +template_tail = """ + this.Module = Module; + this.start = Module.wasmLibraryProvider.start; + this.imports = Module.wasmLibraryProvider.imports; + this.wasiImport = this.imports["wasi_snapshot_preview1"]; +} - f = tvm.build(s, [A, C], "opengl", name="multi_stage") - - ctx = tvm.opengl(0) - n = 10 - a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx) - c = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx) - f(a, c) - - tvm.testing.assert_allclose(c.asnumpy(), (a.asnumpy() + 1) * 2) +if (typeof module !== "undefined" && module.exports) { + module.exports = EmccWASI; +} +""" if __name__ == "__main__": - test_local_multi_stage() + if len(sys.argv) != 3: + print("Usage ") + result = template_head + open(sys.argv[1]).read() + template_tail + with open(sys.argv[2], "w") as fo: + fo.write(result) diff --git a/web/emcc/preload.js b/web/emcc/preload.js new file mode 100644 index 0000000..882280f --- /dev/null +++ b/web/emcc/preload.js @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* eslint-disable no-unused-vars */ +/** + * JS config used by --pre-js in emcc. + * Wrap module as a LibraryProvider. + */ + +var __wasmLib = {}; + +function __wasmLibInstantiateWasm(imports, successCallback) { + __wasmLib.imports = imports; + __wasmLib.successCallback = successCallback; +} + +function __wasmLibStart(wasmInstance) { + __wasmLib.successCallback(wasmInstance); +} + +__wasmLib.start = __wasmLibStart; + +var Module = { + "instantiateWasm": __wasmLibInstantiateWasm, + "wasmLibraryProvider": __wasmLib +}; diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc new file mode 100644 index 0000000..97099e7 --- /dev/null +++ b/web/emcc/tvmjs_support.cc @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvmjs_support.cc + * \brief Support functions to be linked with wasm_runtime to provide + * PackedFunc callbacks in tvmjs. + * We do not need to link this file in standalone wasm. + */ + +// configurations for the dmlc log. +#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + + +#include +#include +#include +#include +#include + +extern "C" { +// --- Additional C API for the Wasm runtime --- +/*! + * \brief Allocate space aligned to 64 bit. + * \param size The size of the space. + * \return The allocated space. + */ +TVM_DLL void* TVMWasmAllocSpace(int size); + +/*! + * \brief Free the space allocated by TVMWasmAllocSpace. + * \param data The data pointer. + */ +TVM_DLL void TVMWasmFreeSpace(void* data); + +/*! + * \brief Create PackedFunc from a resource handle. + * \param resource_handle The handle to the resource. + * \param out The output PackedFunc. + * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer +3A * \return 0 if success. + */ +TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, + TVMFunctionHandle *out); + +// --- APIs to be implemented by the frontend. --- +/*! + * \brief Wasm frontend packed function caller. + * + * \param args The arguments + * \param type_codes The type codes of the arguments + * \param num_args Number of arguments. + * \param ret The return value handle. + * \param resource_handle The handle additional resouce handle from fron-end. + * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. + */ +extern int TVMWasmPackedCFunc(TVMValue* args, + int* type_codes, + int num_args, + TVMRetValueHandle ret, + void* resource_handle); + +/*! + * \brief Wasm frontend resource finalizer. + * \param resource_handle The pointer to the external resource. + */ +extern void TVMWasmPackedCFuncFinalizer(void* resource_handle); +} // extern "C" + + +void* TVMWasmAllocSpace(int size) { + int num_count = (size + 7) / 8; + return new int64_t[num_count]; +} + +void TVMWasmFreeSpace(void* arr) { + delete[] static_cast(arr); +} + +int TVMWasmFuncCreateFromCFunc(void* resource_handle, + TVMFunctionHandle *out) { + return TVMFuncCreateFromCFunc( + TVMWasmPackedCFunc, resource_handle, + TVMWasmPackedCFuncFinalizer, out); +} + + +namespace tvm { +namespace runtime { + +// chrono in the WASI does not provide very accurate time support +// and also have problems in the i64 support in browser. +// We redirect the timer to a JS side time using performance.now +PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, + TVMContext ctx, + int number, + int repeat, + int min_repeat_ms) { + auto ftimer = [pf, ctx, number, repeat, min_repeat_ms]( + TVMArgs args, TVMRetValue *rv) { + + TVMRetValue temp; + auto finvoke = [&](int n) { + // start timing + for (int i = 0; i < n; ++i) { + pf.CallPacked(args, &temp); + } + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + }; + + auto* get_timer = runtime::Registry::Get("wasm.GetTimer"); + CHECK(get_timer != nullptr) << "Cannot find wasm.GetTimer in the global function"; + TypedPackedFunc timer_ms = (*get_timer)( + TypedPackedFunc(finvoke)); + + std::ostringstream os; + finvoke(1); + + int setup_number = number; + + for (int i = 0; i < repeat; ++i) { + double duration_ms = 0.0; + + do { + if (duration_ms > 0.0) { + setup_number = static_cast( + std::max((min_repeat_ms / (duration_ms / number) + 1), + number * 1.618)); // 1.618 is chosen by random + } + duration_ms = timer_ms(setup_number); + } while (duration_ms < min_repeat_ms); + + double speed = duration_ms / setup_number / 1000; + os.write(reinterpret_cast(&speed), sizeof(speed)); + } + + std::string blob = os.str(); + TVMByteArray arr; + arr.size = blob.length(); + arr.data = blob.data(); + // return the time. + *rv = arr; + }; + return PackedFunc(ftimer); +} + +TVM_REGISTER_GLOBAL("wasm.RPCTimeEvaluator") +.set_body_typed([](Optional opt_mod, + std::string name, + int device_type, + int device_id, + int number, + int repeat, + int min_repeat_ms) { + TVMContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + + if (opt_mod.defined()) { + Module m = opt_mod.value(); + std::string tkey = m->type_key(); + return WrapWasmTimeEvaluator( + m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); + } else { + auto* pf = runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; + return WrapWasmTimeEvaluator( + *pf, ctx, number, repeat, min_repeat_ms); + } +}); + +} // namespace runtime +} // namespace tvm diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc new file mode 100644 index 0000000..6ff652c --- /dev/null +++ b/web/emcc/wasm_runtime.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file wasm_runtime.cc + * \brief TVM wasm runtime library pack. + */ + +// configurations for the dmlc log. +#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + +#include +#include + +#include "src/runtime/c_runtime_api.cc" +#include "src/runtime/cpu_device_api.cc" +#include "src/runtime/workspace_pool.cc" +#include "src/runtime/library_module.cc" +#include "src/runtime/system_library.cc" + +#include "src/runtime/module.cc" +#include "src/runtime/ndarray.cc" +#include "src/runtime/object.cc" +#include "src/runtime/registry.cc" +#include "src/runtime/file_util.cc" +#include "src/runtime/graph/graph_runtime.cc" +#include "src/runtime/rpc/rpc_session.cc" +#include "src/runtime/rpc/rpc_endpoint.cc" +#include "src/runtime/rpc/rpc_event_impl.cc" +#include "src/runtime/rpc/rpc_channel.cc" +#include "src/runtime/rpc/rpc_local_session.cc" +#include "src/runtime/rpc/rpc_module.cc" + + +// --- Implementations of backend and wasm runtime API. --- + +int TVMBackendParallelLaunch(FTVMParallelLambda flambda, + void* cdata, + int num_task) { + TVMParallelGroupEnv env; + env.num_task = 1; + flambda(0, &env, cdata); + return 0; +} + +int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { + return 0; +} + +// --- Environment PackedFuncs for testing --- +namespace tvm { +namespace runtime { + +TVM_REGISTER_GLOBAL("testing.echo") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = args[0]; +}); + +TVM_REGISTER_GLOBAL("testing.add_one") +.set_body_typed([](int x) { + return x + 1; +}); + +TVM_REGISTER_GLOBAL("testing.wrap_callback") +.set_body([](TVMArgs args, TVMRetValue *ret) { + PackedFunc pf = args[0]; + *ret = runtime::TypedPackedFunc([pf](){ + pf(); + }); + }); +} // namespace runtime +} // namespace tvm diff --git a/web/example_rpc.html b/web/example_rpc.html deleted file mode 100644 index ae2b1dd..0000000 --- a/web/example_rpc.html +++ /dev/null @@ -1,61 +0,0 @@ - - - - - - - - - - - - - - - - - - - TVM RPC Test Page - - - - -

TVM Test Page

- To use this page, the easiest way is to do -
    -
  • run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy. -
  • Click Connect to proxy. -
  • run "python tests/web/websock_rpc_test.py" to run the rpc client. -
-

Options

- Proxy URL
- RPC Server Key
- - -
- - - - diff --git a/web/package.json b/web/package.json new file mode 100644 index 0000000..76aa111 --- /dev/null +++ b/web/package.json @@ -0,0 +1,29 @@ +{ + "name": "tvmjs", + "displayName": "TVM Wasm JS runtime", + "license": "Apache-2.0", + "version": "0.7.0", + "scripts": { + "build": "tsc -b", + "watch": "tsc -b -w", + "lint": "eslint -c .eslintrc.json .", + "bundle": "npm run build && rollup -c rollup.config.js", + "example": "npm run bundle && node apps/node/example.js", + "example:wasi": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_example.js", + "rpc": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_rpc_server.js" + }, + "devDependencies": { + "typescript": "^3.8.3", + "@types/node": "^12.12.37", + "eslint": "^6.8.0", + "@typescript-eslint/eslint-plugin": "^2.29.0", + "@typescript-eslint/parser": "^2.29.0", + "typedoc": "^0.17.6", + "rollup": "^2.7.6", + "ws": "^7.2.5", + "@rollup/plugin-commonjs": "^11.1.0", + "@rollup/plugin-node-resolve": "^7.1.3", + "rollup-plugin-typescript2": "^0.27.0" + }, + "dependencies": {} +} diff --git a/web/.eslintrc.js b/web/rollup.config.js similarity index 69% rename from web/.eslintrc.js rename to web/rollup.config.js index 2e82ba5..0046e44 100644 --- a/web/.eslintrc.js +++ b/web/rollup.config.js @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,29 +17,18 @@ * under the License. */ -module.exports = { - "env": { - "browser": true, - "node": true, - "es6": true +import commonjs from '@rollup/plugin-commonjs'; +import resolve from '@rollup/plugin-node-resolve'; + +export default { + input: 'dist/index.js', + output: { + file: 'dist/tvmjs.bundle.js', + format: 'umd', + name: 'tvmjs', + exports: 'named', + globals: {'ws': 'ws'} }, - "extends": "eslint:recommended", - "rules": { - "indent": [ - "error", - 2 - ], - "linebreak-style": [ - "error", - "unix" - ], - "quotes": [ - "error", - "double" - ], - "semi": [ - "error", - "always" - ] - } + plugins: [commonjs(), resolve()], + external: ['ws'] }; diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts new file mode 100644 index 0000000..f533b4e --- /dev/null +++ b/web/src/ctypes.ts @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Types for C API. + */ + +/** A pointer to points to the raw address space. */ +export type Pointer = number; + +/** A pointer offset, need to add a base address to get a valid ptr. */ +export type PtrOffset = number; + +// -- TVM runtime C API -- +/** + * const char *TVMGetLastError(); + */ +export type FTVMGetLastError = () => Pointer; + +/** + * int TVMModGetFunction(TVMModuleHandle mod, + * const char* func_name, + * int query_imports, + * TVMFunctionHandle *out); + */ +export type FTVMModGetFunction = ( + mod: Pointer, funcName: Pointer, queryImports: number, out: Pointer) => number; +/** + * int TVMModImport(TVMModuleHandle mod, + * TVMModuleHandle dep); + */ +export type FTVMModImport = (mod: Pointer, dep: Pointer) => number; +/** + * int TVMModFree(TVMModuleHandle mod); + */ +export type FTVMModFree = (mod: Pointer) => number; + +/** + * int TVMFuncFree(TVMFunctionHandle func); + */ +export type FTVMFuncFree = (func: Pointer) => number; + +/** + * int TVMFuncCall(TVMFunctionHandle func, + * TVMValue* arg_values, + * int* type_codes, + * int num_args, + * TVMValue* ret_val, + * int* ret_type_code); + */ +export type FTVMFuncCall = ( + func: Pointer, argValues: Pointer, typeCode: Pointer, + nargs: number, retValue: Pointer, retCode: Pointer) => number; + +/** + * int TVMCFuncSetReturn(TVMRetValueHandle ret, + * TVMValue* value, + * int* type_code, + * int num_ret); + */ +export type FTVMCFuncSetReturn = ( + ret: Pointer, value: Pointer, typeCode: Pointer, numRet: number) => number; + +/** + * int TVMCbArgToReturn(TVMValue* value, int* code); + */ +export type FTVMCbArgToReturn = (value: Pointer, code: Pointer) => number; + +/** + * int TVMFuncListGlobalNames(int* outSize, const char*** outArray); + */ +export type FTVMFuncListGlobalNames = (outSize: Pointer, outArray: Pointer) => number; + +/** + * int TVMFuncRegisterGlobal( + * const char* name, TVMFunctionHandle f, int override); + */ +export type FTVMFuncRegisterGlobal = ( + name: Pointer, f: Pointer, override: number) => number; + +/** + *int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); + */ +export type FTVMFuncGetGlobal = (name: Pointer, out: Pointer) => number; + +/** + * int TVMArrayAlloc(const tvm_index_t* shape, + * int ndim, + * int dtype_code, + * int dtype_bits, + * int dtype_lanes, + * int device_type, + * int device_id, + * TVMArrayHandle* out); + */ +export type FTVMArrayAlloc = ( + shape: Pointer, ndim: number, + dtypeCode: number, dtypeBits: number, + dtypeLanes: number, deviceType: number, deviceId: number, + out: Pointer) => number; + +/** + * int TVMArrayFree(TVMArrayHandle handle); + */ +export type FTVMArrayFree = (handle: Pointer) => number; + +/** + * int TVMArrayCopyFromBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyFromBytes = ( + handle: Pointer, data: Pointer, nbytes: number) => number; + +/** + * int TVMArrayCopyToBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyToBytes = ( + handle: Pointer, data: Pointer, nbytes: number) => number; + +/** + * int TVMArrayCopyFromTo(TVMArrayHandle from, + * TVMArrayHandle to, + * TVMStreamHandle stream); + */ +export type FTVMArrayCopyFromTo = ( + from: Pointer, to: Pointer, stream: Pointer) => number; + +/** + * int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); + */ +export type FTVMSynchronize = ( + deviceType: number, deviceId: number, stream: Pointer) => number; + +/** + * typedef int (*TVMBackendPackedCFunc)(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMValue* out_ret_value, + * int* out_ret_tcode); + */ +export type FTVMBackendPackedCFunc = ( + argValues: Pointer, argCodes: Pointer, nargs: number, + outValue: Pointer, outCode: Pointer) => number; + +// -- TVM Wasm Auxiliary C API -- + +/** void* TVMWasmAllocSpace(int size); */ +export type FTVMWasmAllocSpace = (size: number) => Pointer; + +/** void TVMWasmFreeSpace(void* data); */ +export type FTVMWasmFreeSpace = (ptr: Pointer) => void; + +/** + * int TVMWasmPackedCFunc(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMRetValueHandle ret, + * void* resource_handle); + */ +export type FTVMWasmPackedCFunc = ( + args: Pointer, typeCodes: Pointer, nargs: number, + ret: Pointer, resourceHandle: Pointer) => number; + +/** + * int TVMWasmFuncCreateFromCFunc(void* resource_handle, + * TVMFunctionHandle *out); + */ +export type FTVMWasmFuncCreateFromCFunc = ( + resource: Pointer, out: Pointer) => number; + +/** + * void TVMWasmPackedCFuncFinalizer(void* resource_handle); + */ +export type FTVMWasmPackedCFuncFinalizer = (resourceHandle: Pointer) => void; + +/** + * Size of common data types. + */ +export const enum SizeOf { + U8 = 1, + U16 = 2, + I32 = 4, + I64 = 8, + F32 = 4, + F64 = 8, + TVMValue = 8, + DLDataType = I32, + DLContext = I32 + I32, +} + +/** + * Type code in TVM FFI. + */ +export const enum TypeCode { + Int = 0, + UInt = 1, + Float = 2, + TVMOpaqueHandle = 3, + Null = 4, + TVMDataType = 5, + TVMContext = 6, + TVMDLTensorHandle = 7, + TVMObjectHandle = 8, + TVMModuleHandle = 9, + TVMPackedFuncHandle = 10, + TVMStr = 11, + TVMBytes = 12, + TVMNDArrayHandle = 13, + TVMObjectRValueRefArg = 14 +} \ No newline at end of file diff --git a/web/src/environment.ts b/web/src/environment.ts new file mode 100644 index 0000000..df0fe68 --- /dev/null +++ b/web/src/environment.ts @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Runtime environment that provide js libaries calls. + */ +import { Pointer } from "./ctypes"; +import { LibraryProvider } from "./types"; +import { assert } from "./support"; +import * as ctypes from "./ctypes"; + +/** + * Detect library provider from the importObject. + * + * @param importObject The import object. + */ +function detectLibraryProvider( + importObject: Record +): LibraryProvider | undefined { + if ( + importObject["wasmLibraryProvider"] && + importObject["wasmLibraryProvider"]["start"] && + importObject["wasmLibraryProvider"]["imports"] !== undefined + ) { + const item = importObject as { wasmLibraryProvider: LibraryProvider }; + // create provider so that we capture imports in the provider. + return { + imports: item.wasmLibraryProvider.imports, + start: (inst: WebAssembly.Instance): void => { + item.wasmLibraryProvider.start(inst); + }, + }; + } else if (importObject["imports"] && importObject["start"] !== undefined) { + return importObject as LibraryProvider; + } else if (importObject["wasiImport"] && importObject["start"] !== undefined) { + // WASI + return { + imports: { + "wasi_snapshot_preview1": importObject["wasiImport"], + }, + start: (inst: WebAssembly.Instance): void => { + importObject["start"](inst); + } + }; + } else { + return undefined; + } +} + +/** + * Environment to impelement most of the JS library functions. + */ +export class Environment implements LibraryProvider { + logger: (msg: string) => void; + imports: Record; + /** + * Maintains a table of FTVMWasmPackedCFunc that the C part + * can call via TVMWasmPackedCFunc. + * + * We maintain a separate table so that we can have un-limited amount + * of functions that do not maps to the address space. + */ + packedCFuncTable: Array = [ + undefined, + ]; + /** + * Free table index that can be recycled. + */ + packedCFuncTableFreeId: Array = []; + + private libProvider?: LibraryProvider; + + constructor( + importObject: Record = {}, + logger: (msg: string) => void = console.log + ) { + this.logger = logger; + this.libProvider = detectLibraryProvider(importObject); + // get imports from the provider + if (this.libProvider !== undefined) { + this.imports = this.libProvider.imports; + } else { + this.imports = importObject; + } + // update with more functions + this.imports.env = this.environment(this.imports.env); + } + + /** Mark the start of the instance. */ + start(inst: WebAssembly.Instance): void { + if (this.libProvider !== undefined) { + this.libProvider.start(inst); + } + } + + private environment(initEnv: Record): Record { + // default env can be be overriden by libraries. + const defaultEnv = { + "__cxa_thread_atexit": (): void => {}, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + "emscripten_notify_memory_growth": (index: number): void => {} + }; + const wasmPackedCFunc: ctypes.FTVMWasmPackedCFunc = ( + args: Pointer, + typeCodes: Pointer, + nargs: number, + ret: Pointer, + resourceHandle: Pointer + ): number => { + const cfunc = this.packedCFuncTable[resourceHandle]; + assert(cfunc !== undefined); + return cfunc(args, typeCodes, nargs, ret, resourceHandle); + }; + + const wasmPackedCFuncFinalizer: ctypes.FTVMWasmPackedCFuncFinalizer = ( + resourceHandle: Pointer + ): void => { + this.packedCFuncTable[resourceHandle] = undefined; + this.packedCFuncTableFreeId.push(resourceHandle); + }; + + const newEnv = { + TVMWasmPackedCFunc: wasmPackedCFunc, + TVMWasmPackedCFuncFinalizer: wasmPackedCFuncFinalizer, + "__console_log": (msg: string): void => { + this.logger(msg); + } + }; + return Object.assign(defaultEnv, initEnv, newEnv); + } +} \ No newline at end of file diff --git a/web/src/index.ts b/web/src/index.ts new file mode 100644 index 0000000..5d7d7cc --- /dev/null +++ b/web/src/index.ts @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export { + Scalar, DLContext, DLDataType, + PackedFunc, Module, NDArray, Instance, + instantiate +} from "./runtime"; +export { Disposable, LibraryProvider } from "./types"; +export { RPCServer } from "./rpc_server"; +export { wasmPath } from "./support"; \ No newline at end of file diff --git a/web/src/memory.ts b/web/src/memory.ts new file mode 100644 index 0000000..ac737b7 --- /dev/null +++ b/web/src/memory.ts @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Classes to manipulate Wasm memories. + */ +import { Pointer, PtrOffset, SizeOf } from "./ctypes"; +import { Disposable } from "./types"; +import { assert, StringToUint8Array } from "./support"; + +import * as ctypes from "./ctypes"; + +/** + * Wasm Memory wrapper to perform JS side raw memory access. + */ +export class Memory { + memory: WebAssembly.Memory; + wasm32 = true; + private buffer: ArrayBuffer | SharedArrayBuffer; + private viewU8: Uint8Array; + private viewU16: Uint16Array; + private viewI32: Int32Array; + private viewU32: Uint32Array; + private viewF32: Float32Array; + private viewF64: Float64Array; + + constructor(memory: WebAssembly.Memory) { + this.memory = memory; + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } + + loadU8(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU8[ptr >> 0]; + } + + loadU16(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU16[ptr >> 1]; + } + + loadU32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU32[ptr >> 2]; + } + + loadI32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewI32[ptr >> 2]; + } + + loadI64(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const base = ptr >> 2; + // assumes little endian, for now truncate high. + return this.viewI32[base]; + } + + loadF32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF32[ptr >> 2]; + } + + loadF64(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF64[ptr >> 3]; + } + + loadPointer(ptr: Pointer): Pointer { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } else { + return this.loadI64(ptr); + } + } + loadUSize(ptr: Pointer): Pointer { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } else { + return this.loadI64(ptr); + } + } + sizeofPtr(): number { + return this.wasm32 ? SizeOf.I32 : SizeOf.I64; + } + /** + * Load raw bytes from ptr. + * @param ptr The head address + * @param numBytes The number + */ + loadRawBytes(ptr: Pointer, numBytes: number): Uint8Array { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const result = new Uint8Array(numBytes); + result.set(this.viewU8.slice(ptr, ptr + numBytes)); + return result; + } + /** + * Load TVMByteArray from ptr. + * + * @param ptr The address of the header. + */ + loadTVMBytes(ptr: Pointer): Uint8Array { + const data = this.loadPointer(ptr); + const length = this.loadUSize(ptr + this.sizeofPtr()); + return this.loadRawBytes(data, length); + } + /** + * Load null-terminated C-string from ptr. + * @param ptr The head address + */ + loadCString(ptr: Pointer): string { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + // NOTE: the views are still valid for read. + const ret = []; + let ch = 1; + while (ch != 0) { + ch = this.viewU8[ptr]; + if (ch != 0) { + ret.push(String.fromCharCode(ch)); + } + ++ptr; + } + return ret.join(""); + } + /** + * Store raw bytes to the ptr. + * @param ptr The head address. + * @param bytes The bytes content. + */ + storeRawBytes(ptr: Pointer, bytes: Uint8Array): void { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + this.viewU8.set(bytes, ptr); + } + + /** + * Update memory view after the memory growth. + */ + private updateViews(): void { + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } +} + +/** + * Auxiliary call stack for the FFI calls. + * + * Lifecyle of a call stack. + * - Calls into allocXX to allocate space, mixed with storeXXX to store data. + * - Calls into ptrFromOffset, no further allocation(as ptrFromOffset can change), + * can still call into storeXX + * - Calls into commitToWasmMemory once. + * - reset. + */ +export class CachedCallStack implements Disposable { + /** List of temporay arguments that can be disposed during reset. */ + tempArgs: Array = []; + + private memory: Memory; + private cAllocSpace: ctypes.FTVMWasmAllocSpace; + private cFreeSpace: ctypes.FTVMWasmFreeSpace; + + private buffer: ArrayBuffer; + private viewU8: Uint8Array; + private viewI32: Int32Array; + private viewU32: Uint32Array; + private viewF64: Float64Array; + + private stackTop: PtrOffset = 0; + private basePtr: Pointer = 0; + + private addressToSetTargetValue: Array<[PtrOffset, PtrOffset]> = []; + + constructor( + memory: Memory, + allocSpace: ctypes.FTVMWasmAllocSpace, + freeSpace: ctypes.FTVMWasmFreeSpace + ) { + const initCallStackSize = 128; + this.memory = memory; + this.cAllocSpace = allocSpace; + this.cFreeSpace = freeSpace; + this.buffer = new ArrayBuffer(initCallStackSize); + this.basePtr = this.cAllocSpace(initCallStackSize); + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + this.updateViews(); + } + + dispose(): void { + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + this.basePtr = 0; + } + } + /** + * Rest the call stack so that it can be reused again. + */ + reset(): void { + this.stackTop = 0; + assert(this.addressToSetTargetValue.length == 0); + while (this.tempArgs.length != 0) { + (this.tempArgs.pop() as Disposable).dispose(); + } + } + + /** + * Commit all the cached data to WasmMemory. + * This function can only be called once. + * No further store function should be called. + * + * @param nbytes Number of bytes to be stored. + */ + commitToWasmMemory(nbytes: number = this.stackTop): void { + // commit all pointer values. + while (this.addressToSetTargetValue.length != 0) { + const [targetOffset, valueOffset] = this.addressToSetTargetValue.pop() as [ + number, + number + ]; + this.storePtr(targetOffset, this.ptrFromOffset(valueOffset)); + } + this.memory.storeRawBytes(this.basePtr, this.viewU8.slice(0, nbytes)); + } + + /** + * Allocate space by number of bytes + * @param nbytes Number of bytes. + * @note This function always allocate space that aligns to 64bit. + */ + allocRawBytes(nbytes: number): PtrOffset { + // always aligns to 64bit + nbytes = ((nbytes + 7) >> 3) << 3; + + if (this.stackTop + nbytes > this.buffer.byteLength) { + const newSize = Math.max( + this.buffer.byteLength * 2, + this.stackTop + nbytes + ); + const oldU8 = this.viewU8; + this.buffer = new ArrayBuffer(newSize); + this.updateViews(); + this.viewU8.set(oldU8); + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + } + this.basePtr = this.cAllocSpace(newSize); + } + const retOffset = this.stackTop; + this.stackTop += nbytes; + return retOffset; + } + + /** + * Allocate space for pointers. + * @param count Number of pointers. + * @returns The allocated pointer array. + */ + allocPtrArray(count: number): PtrOffset { + return this.allocRawBytes(this.memory.sizeofPtr() * count); + } + + /** + * Get the real pointer from offset values. + * Note that the returned value becomes obsolete if alloc is called on the stack. + * @param offset The allocated offset. + */ + ptrFromOffset(offset: PtrOffset): Pointer { + return this.basePtr + offset; + } + + // Store APIs + storePtr(offset: PtrOffset, value: Pointer): void { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } else { + this.storeI64(offset, value); + } + } + + storeUSize(offset: PtrOffset, value: Pointer): void { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } else { + this.storeI64(offset, value); + } + } + + storeI32(offset: PtrOffset, value: number): void { + this.viewI32[offset >> 2] = value; + } + + storeU32(offset: PtrOffset, value: number): void { + this.viewU32[offset >> 2] = value; + } + + storeI64(offset: PtrOffset, value: number): void { + // For now, just store as 32bit + // NOTE: wasm always uses little endian. + const low = value & 0xffffffff; + const base = offset >> 2; + this.viewI32[base] = low; + this.viewI32[base + 1] = 0; + } + + storeF64(offset: PtrOffset, value: number): void { + this.viewF64[offset >> 3] = value; + } + + storeRawBytes(offset: PtrOffset, bytes: Uint8Array): void { + this.viewU8.set(bytes, offset); + } + + /** + * Allocate then set C-String pointer to the offset. + * This function will call into allocBytes to allocate necessary data. + * The address won't be set immediately(because the possible change of basePtr) + * and will be filled when we commit the data. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgString(offset: PtrOffset, data: string): void { + const strOffset = this.allocRawBytes(data.length + 1); + this.storeRawBytes(strOffset, StringToUint8Array(data)); + this.addressToSetTargetValue.push([offset, strOffset]); + } + /** + * Allocate then set the argument location with a TVMByteArray. + * Allocate new temporary space for bytes. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgBytes(offset: PtrOffset, data: Uint8Array): void { + // Note: size of size_t equals sizeof ptr. + const headerOffset = this.allocRawBytes(this.memory.sizeofPtr() * 2); + const dataOffset = this.allocRawBytes(data.length); + this.storeRawBytes(dataOffset, data); + this.storeUSize(headerOffset + this.memory.sizeofPtr(), data.length); + + this.addressToSetTargetValue.push([offset, headerOffset]); + this.addressToSetTargetValue.push([headerOffset, dataOffset]); + } + + /** + * Update internal cache views. + */ + private updateViews(): void { + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } +} diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts new file mode 100644 index 0000000..054a1b6 --- /dev/null +++ b/web/src/rpc_server.ts @@ -0,0 +1,379 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { SizeOf, TypeCode } from "./ctypes"; +import { assert, StringToUint8Array, Uint8ArrayToString } from "./support"; +import * as runtime from "./runtime"; +import { Class } from "estree"; + +enum RPCServerState { + InitHeader, + InitHeaderKey, + InitServer, + WaitForCallback, + ReceivePacketHeader, + ReceivePacketBody, +} + +/** RPC magic header */ +const RPC_MAGIC = 0xff271; + +/** + * An utility class to read from binary bytes. + */ +class ByteStreamReader { + offset = 0; + bytes: Uint8Array; + + constructor(bytes: Uint8Array) { + this.bytes = bytes; + } + + readU32(): number { + const i = this.offset; + const b = this.bytes; + const val = b[i] | (b[i + 1] << 8) | (b[i + 2] << 16) | (b[i + 3] << 24); + this.offset += 4; + return val; + } + + readU64(): number { + const val = this.readU32(); + this.offset += 4; + return val; + } + + readByteArray(): Uint8Array { + const len = this.readU64(); + assert(this.offset + len <= this.bytes.byteLength); + const ret = new Uint8Array(len); + ret.set(this.bytes.slice(this.offset, this.offset + len)); + this.offset += len; + return ret; + } +} + +/** + * A websocket based RPC + */ +export class RPCServer { + url: string; + key: string; + socket: WebSocket; + state: RPCServerState = RPCServerState.InitHeader; + logger: (msg: string) => void; + getImports: () => Record; + private name: string; + private inst?: runtime.Instance = undefined; + private serverRecvData?: (header: Uint8Array, body: Uint8Array) => void; + private currPacketHeader?: Uint8Array; + private currPacketLength = 0; + private remoteKeyLength = 0; + private pendingBytes = 0; + private buffredBytes = 0; + private messageQueue: Array = []; + + constructor( + url: string, + key: string, + getImports: () => Record, + logger: (msg: string) => void = console.log + ) { + this.url = url; + this.key = key; + this.name = "WebSocketRPCServer[" + this.key + "]: "; + this.getImports = getImports; + this.logger = logger; + + this.checkLittleEndian(); + + if (typeof WebSocket == "undefined") { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const WebSocket = require("ws"); + this.socket = new WebSocket(url); + } else { + this.socket = new (WebSocket as any)(url); + } + + //this.socket = this.getSocket(url); + this.socket.binaryType = "arraybuffer"; + + this.socket.addEventListener("open", (event: Event) => { + return this.onOpen(event); + }); + this.socket.addEventListener("message", (event: MessageEvent) => { + return this.onMessage(event); + }); + this.socket.addEventListener("close", (event: CloseEvent) => { + return this.onClose(event); + }); + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private onClose(_event: CloseEvent): void { + if (this.inst !== undefined) { + this.inst.dispose(); + } + if (this.state == RPCServerState.ReceivePacketHeader) { + this.log("Closing the server in clean state"); + } else { + this.log("Closing the server, final state=" + this.state); + } + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private onOpen(_event: Event): void { + // Send the headers + let bkey = StringToUint8Array("server:" + this.key); + bkey = bkey.slice(0, bkey.length - 1); + const intbuf = new Int32Array(1); + intbuf[0] = RPC_MAGIC; + this.socket.send(intbuf); + intbuf[0] = bkey.length; + this.socket.send(intbuf); + this.socket.send(bkey); + this.log("connected..."); + // request bytes: magic + keylen + this.requestBytes(SizeOf.I32 + SizeOf.I32); + this.state = RPCServerState.InitHeader; + } + + /** Handler for raw message. */ + private onMessage(event: MessageEvent): void { + const buffer = event.data; + this.buffredBytes += buffer.byteLength; + this.messageQueue.push(new Uint8Array(buffer)); + this.processEvents(); + } + /** Process ready events. */ + private processEvents(): void { + while (this.buffredBytes >= this.pendingBytes && this.pendingBytes != 0) { + this.onDataReady(); + } + } + /** State machine to handle each request */ + private onDataReady(): void { + switch (this.state) { + case RPCServerState.InitHeader: { + this.handleInitHeader(); + break; + } + case RPCServerState.InitHeaderKey: { + this.handleInitHeaderKey(); + break; + } + case RPCServerState.ReceivePacketHeader: { + this.currPacketHeader = this.readFromBuffer(SizeOf.I64); + const reader = new ByteStreamReader(this.currPacketHeader); + this.currPacketLength = reader.readU64(); + assert(this.pendingBytes == 0); + this.requestBytes(this.currPacketLength); + this.state = RPCServerState.ReceivePacketBody; + break; + } + case RPCServerState.ReceivePacketBody: { + const body = this.readFromBuffer(this.currPacketLength); + assert(this.pendingBytes == 0); + assert(this.currPacketHeader !== undefined); + this.onPacketReady(this.currPacketHeader, body); + break; + } + case RPCServerState.WaitForCallback: { + assert(this.pendingBytes == 0); + break; + } + default: { + throw new Error("Cannot handle state " + this.state); + } + } + } + + private onPacketReady(header: Uint8Array, body: Uint8Array): void { + if (this.inst === undefined) { + // initialize server. + const reader = new ByteStreamReader(body); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const code = reader.readU32(); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const ver = Uint8ArrayToString(reader.readByteArray()); + const nargs = reader.readU32(); + const tcodes = []; + const args = []; + for (let i = 0; i < nargs; ++i) { + tcodes.push(reader.readU32()); + } + + for (let i = 0; i < nargs; ++i) { + const tcode = tcodes[i]; + if (tcode == TypeCode.TVMStr) { + const str = Uint8ArrayToString(reader.readByteArray()); + args.push(str); + } else if (tcode == TypeCode.TVMBytes) { + args.push(reader.readByteArray()); + } else { + throw new Error("cannot support type code " + tcode); + } + } + this.onInitServer(args, header, body); + } else { + assert(this.serverRecvData !== undefined); + this.serverRecvData(header, body); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + } + } + + /** Event handler during server initialization. */ + private onInitServer( + args: Array, + header: Uint8Array, + body: Uint8Array + ): void { + // start the server + assert(args[0] == "rpc.WasmSession"); + assert(args[1] instanceof Uint8Array); + assert(this.pendingBytes == 0); + + runtime.instantiate(args[1].buffer, this.getImports()) + .then((inst: runtime.Instance) => { + this.inst = inst; + const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer"); + + const messageHandler = fcreate( + (cbytes: Uint8Array): runtime.Scalar => { + assert(this.inst !== undefined); + if (this.socket.readyState == 1) { + this.socket.send(cbytes); + return this.inst.scalar(cbytes.length, "int32"); + } else { + return this.inst.scalar(0, "int32"); + } + }, + this.name, + this.key + ); + + fcreate.dispose(); + const writeFlag = this.inst.scalar(3, "int32"); + + this.serverRecvData = (header: Uint8Array, body: Uint8Array): void => { + if (messageHandler(header, writeFlag) == 0) { + this.socket.close(); + } + if (messageHandler(body, writeFlag) == 0) { + this.socket.close(); + } + }; + + // Forward the same init sequence to the wasm RPC. + // The RPC will look for "rpc.wasmSession" + // and we will redirect it to the correct local session. + // register the callback to redirect the session to local. + const flocal = this.inst.getGlobalFunc("rpc.LocalSession"); + const localSession = flocal(); + flocal.dispose(); + assert(localSession instanceof runtime.Module); + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + this.inst.registerFunc( + "rpc.WasmSession", + // eslint-disable-next-line @typescript-eslint/no-unused-vars + (_args: unknown): runtime.Module => { + return localSession; + } + ); + messageHandler(header, writeFlag); + messageHandler(body, writeFlag); + localSession.dispose(); + + this.log("Finish initializing the Wasm Server.."); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + // call process events in case there are bufferred data. + this.processEvents(); + }); + this.state = RPCServerState.WaitForCallback; + } + + private log(msg: string): void { + this.logger(this.name + msg); + } + + private handleInitHeader(): void { + const reader = new ByteStreamReader(this.readFromBuffer(SizeOf.I32 * 2)); + const magic = reader.readU32(); + if (magic == RPC_MAGIC + 1) { + throw new Error("key: " + this.key + " has already been used in proxy"); + } else if (magic == RPC_MAGIC + 2) { + throw new Error("RPCProxy do not have matching client key " + this.key); + } + assert(magic == RPC_MAGIC, this.url + " is not an RPC Proxy"); + this.remoteKeyLength = reader.readU32(); + assert(this.pendingBytes == 0); + this.requestBytes(this.remoteKeyLength); + this.state = RPCServerState.InitHeaderKey; + } + + private handleInitHeaderKey(): void { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const remoteKey = Uint8ArrayToString( + this.readFromBuffer(this.remoteKeyLength) + ); + assert(this.pendingBytes == 0); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + } + + private checkLittleEndian(): void { + const a = new ArrayBuffer(4); + const b = new Uint8Array(a); + const c = new Uint32Array(a); + b[0] = 0x11; + b[1] = 0x22; + b[2] = 0x33; + b[3] = 0x44; + assert(c[0] === 0x44332211, "RPCServer little endian to work"); + } + + private requestBytes(nbytes: number): void { + this.pendingBytes += nbytes; + } + + private readFromBuffer(nbytes: number): Uint8Array { + const ret = new Uint8Array(nbytes); + let ptr = 0; + while (ptr < nbytes) { + assert(this.messageQueue.length != 0); + const nleft = nbytes - ptr; + if (this.messageQueue[0].byteLength <= nleft) { + const buffer = this.messageQueue.shift() as Uint8Array; + ret.set(buffer, ptr); + ptr += buffer.byteLength; + } else { + const buffer = this.messageQueue[0]; + ret.set(buffer.slice(0, nleft), ptr); + this.messageQueue[0] = buffer.slice(nleft, buffer.byteLength); + ptr += nleft; + } + } + this.buffredBytes -= nbytes; + this.pendingBytes -= nbytes; + return ret; + } +} diff --git a/web/src/runtime.ts b/web/src/runtime.ts new file mode 100644 index 0000000..cd9b967 --- /dev/null +++ b/web/src/runtime.ts @@ -0,0 +1,1113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * TVM JS Wasm Runtime library. + */ +import { Pointer, PtrOffset, SizeOf, TypeCode } from "./ctypes"; +import { Disposable } from "./types"; +import { Memory, CachedCallStack } from "./memory"; +import { assert, StringToUint8Array } from "./support"; +import { Environment } from "./environment"; + +import * as ctypes from "./ctypes"; + +/** + * Type for PackedFunc inthe TVMRuntime. + */ +export type PackedFunc = ((...args: any) => any) & + Disposable & { _tvmPackedCell: PackedFuncCell }; + +/** + * @internal + * FFI Library wrapper, maintains most runtime states. + */ +class FFILibrary implements Disposable { + wasm32: boolean; + memory: Memory; + exports: Record; + private wasmInstance: WebAssembly.Instance; + + private recycledCallStacks: Array = []; + + constructor( + wasmInstance: WebAssembly.Instance, + imports: Record + ) { + this.wasmInstance = wasmInstance; + this.memory = new Memory(this.detectWasmMemory(this.wasmInstance, imports)); + assert( + this.wasmInstance.exports !== undefined, + "Expect the library module contains exports" + ); + this.exports = this.wasmInstance.exports as Record; + this.wasm32 = this.memory.wasm32; + this.validateInstance(); + } + + dispose(): void { + while (this.recycledCallStacks.length != 0) { + (this.recycledCallStacks.pop() as Disposable).dispose(); + } + } + + sizeofPtr(): number { + return this.memory.sizeofPtr(); + } + + checkCall(code: number): void { + if (code != 0) { + const msgPtr = (this.exports + .TVMGetLastError as ctypes.FTVMGetLastError)(); + throw new Error("TVMError: " + this.memory.loadCString(msgPtr)); + } + } + + getOrAllocCallStack(): CachedCallStack { + if (this.recycledCallStacks.length != 0) { + return this.recycledCallStacks.pop() as CachedCallStack; + } + return new CachedCallStack( + this.memory, + this.exports.TVMWasmAllocSpace as ctypes.FTVMWasmAllocSpace, + this.exports.TVMWasmFreeSpace as ctypes.FTVMWasmFreeSpace + ); + } + + recycleCallStack(callstack: CachedCallStack): void { + callstack.reset(); + this.recycledCallStacks.push(callstack); + } + + private validateInstance(): void { + this.checkExports(["TVMWasmAllocSpace", "TVMWasmFreeSpace", "TVMFuncFree"]); + } + + private checkExports(funcNames: Array): void { + const missList = []; + for (const name of funcNames) { + const f = this.exports[name]; + if (!(f instanceof Function)) { + missList.push(name); + } + } + if (missList.length != 0) { + throw new Error("Cannot find " + missList + " in exports"); + } + } + + private detectWasmMemory( + instance: WebAssembly.Instance, + imports: Record + ): WebAssembly.Memory { + if (instance.exports.memory instanceof WebAssembly.Memory) { + return instance.exports.memory; + } + if (imports.env && imports.env.memory instanceof WebAssembly.Memory) { + return imports.env.memory; + } + + throw new Error( + "Cannt detect wasm memory from imports " + + imports + + " or exports" + + instance.exports + ); + } +} + +/** + * A typed scalar constant used to represent a typed number + * argument to PackedFunc calls. + */ +export class Scalar { + /** The value. */ + value: number; + /** The data type of the scalar. */ + dtype: string; + + constructor(value: number, dtype: string) { + this.value = value; + this.dtype = dtype; + } +} + +/** + * Cell holds the PackedFunc object. + */ +class PackedFuncCell implements Disposable { + handle: Pointer; + private lib: FFILibrary; + + constructor(handle: Pointer, lib: FFILibrary) { + this.handle = handle; + this.lib = lib; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMFuncFree as ctypes.FTVMFuncFree)(this.handle) + ); + this.handle = 0; + } + } +} + +const DeviceEnumToStr: Record = { + 1: "cpu", + 2: "gpu", + 4: "opencl", + 7: "vulkan", + 8: "metal", +}; + +const DeviceStrToEnum: Record = { + cpu: 1, + gpu: 2, + cuda: 2, + cl: 4, + opencl: 4, + vulkan: 7, + metal: 8, +}; + +/** + * Represent a runtime context where a NDArray can reside. + */ +export class DLContext { + /** The device type code of the context. */ + deviceType: number; + /** The device index. */ + deviceId: number; + + private lib: FFILibrary; + + constructor(deviceType: number | string, deviceId: number, lib: FFILibrary) { + const tp = typeof deviceType; + if (tp == "string") { + this.deviceType = DeviceStrToEnum[deviceType]; + } else if (tp == "number") { + this.deviceType = deviceType as number; + } else { + throw new Error("Cannot take type " + tp + " as deviceType"); + } + this.deviceId = deviceId; + this.lib = lib; + } + + /** + * Synchronize the context + */ + sync(): void { + this.lib.checkCall( + (this.lib.exports.TVMSynchronize as ctypes.FTVMSynchronize)( + this.deviceType, + this.deviceId, + 0 + ) + ); + } + + toString(): string { + return ( + DeviceEnumToStr[this.deviceType] + "(" + this.deviceId.toString() + ")" + ); + } +} + +const DLDataTypeCodeToStr: Record = { + 0: "int", + 1: "uint", + 2: "float", + 4: "handle", +}; + +/** + * Runtime data type of NDArray. + */ +export class DLDataType { + /** The type code */ + code: number; + /** Number of bits in the data type. */ + bits: number; + /** Number of vector lanes. */ + lanes: number; + + constructor(code: number, bits: number, lanes: number) { + this.code = code; + this.bits = bits; + this.lanes = lanes; + } + + toString(): string { + const ret = DLDataTypeCodeToStr[this.code] + this.bits.toString(); + if (this.lanes != 1) { + return ret + "x" + this.lanes.toString(); + } else { + return ret; + } + } + + numStorageBytes(): number { + return (this.bits * this.lanes + 7) >> 3; + } +} + +/** + * n-dimnesional array. + */ +export class NDArray implements Disposable { + /** Internal array handle. */ + handle: Pointer; + /** Number of dimensions. */ + ndim: number; + /** Data type of the array. */ + dtype: string; + /** Shape of the array. */ + shape: Array; + /** Context of the array. */ + context: DLContext; + + private byteOffset: number; + private dltensor: Pointer; + private lib: FFILibrary; + private dlDataType: DLDataType; + + constructor(handle: Pointer, lib: FFILibrary) { + this.handle = handle; + this.lib = lib; + + this.dltensor = this.getDLTensorFromArrayHandle(this.handle); + // constant offsets. + const arrayOffsetData = 0; + const arrayOffsetContext = arrayOffsetData + this.lib.sizeofPtr(); + const arrayOffsetDevType = arrayOffsetContext; + const arrayOffsetDevId = arrayOffsetContext + SizeOf.I32; + const arrayOffsetNdim = arrayOffsetContext + SizeOf.DLContext; + const arrayOffsetDtype = arrayOffsetNdim + SizeOf.I32; + const arrayOffsetDtypeCode = arrayOffsetDtype; + const arrayOffsetDtypeBits = arrayOffsetDtype + SizeOf.U8; + const arrayOffsetDtypeLanes = arrayOffsetDtypeBits + SizeOf.U8; + const arrayOffsetShape = arrayOffsetDtype + SizeOf.DLDataType; + const arrayOffsetStrides = arrayOffsetShape + this.lib.sizeofPtr(); + const arrayOffsetByteOffset = arrayOffsetStrides + this.lib.sizeofPtr(); + // ndim + this.ndim = lib.memory.loadI32(this.dltensor + arrayOffsetNdim); + // shape + const cshapePtr = lib.memory.loadPointer(this.dltensor + arrayOffsetShape); + this.shape = []; + for (let i = 0; i < this.ndim; ++i) { + this.shape.push(lib.memory.loadI64(cshapePtr + i * SizeOf.I64)); + } + // dtype + const code = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeCode); + const bits = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeBits); + const lanes = lib.memory.loadU16(this.dltensor + arrayOffsetDtypeLanes); + this.dlDataType = new DLDataType(code, bits, lanes); + this.dtype = this.dlDataType.toString(); + + // ctx + const deviceType = lib.memory.loadI32(this.dltensor + arrayOffsetDevType); + const deviceId = lib.memory.loadI32(this.dltensor + arrayOffsetDevId); + this.context = new DLContext(deviceType, deviceId, lib); + + // byte_offset + this.byteOffset = lib.memory.loadI64(this.dltensor + arrayOffsetByteOffset); + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMArrayFree as ctypes.FTVMArrayFree)(this.handle) + ); + this.handle = 0; + } + } + /** + * Copy data from another NDArray or javascript array. + * The number of elements must match. + * + * @param data The source data array. + * @returns this + */ + copyFrom(data: NDArray | Array): this { + if (data instanceof NDArray) { + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)( + data.handle, + this.handle, + 0 + ) + ); + return this; + } else { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + if (data.length != size) { + throw new Error( + "data size and shape mismatch data.length" + + data.length + + " vs " + + size + ); + } + let buffer: ArrayBuffer; + if (this.dtype == "float32") { + buffer = Float32Array.from(data).buffer; + } else if (this.dtype == "float64") { + buffer = Float64Array.from(data).buffer; + } else if (this.dtype == "int32") { + buffer = Int32Array.from(data).buffer; + } else if (this.dtype == "int8") { + buffer = Int8Array.from(data).buffer; + } else if (this.dtype == "uint8") { + buffer = Uint8Array.from(data).buffer; + } else { + throw new Error("Unsupported data type " + this.dtype); + } + return this.copyFromRawBytes(new Uint8Array(buffer)); + } + } + /** + * Copy data from raw bytes. + * @param data Uint8Array of bytes. + * @returns this + */ + copyFromRawBytes(data: Uint8Array): this { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + const nbytes = this.dlDataType.numStorageBytes() * size; + if (nbytes != data.length) { + throw new Error("Expect the data's length equals nbytes=" + nbytes); + } + + const stack = this.lib.getOrAllocCallStack(); + + const tempOffset = stack.allocRawBytes(nbytes); + const tempPtr = stack.ptrFromOffset(tempOffset); + this.lib.memory.storeRawBytes(tempPtr, data); + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyFromBytes as ctypes.FTVMArrayCopyFromBytes)( + this.handle, + tempPtr, + nbytes + ) + ); + + this.lib.recycleCallStack(stack); + return this; + } + /** + * Return a copied Uint8Array of the raw bytes in the NDArray. + * @returns The result array. + */ + toRawBytes(): Uint8Array { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + const nbytes = this.dlDataType.numStorageBytes() * size; + const stack = this.lib.getOrAllocCallStack(); + + const tempOffset = stack.allocRawBytes(nbytes); + const tempPtr = stack.ptrFromOffset(tempOffset); + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyToBytes as ctypes.FTVMArrayCopyToBytes)( + this.handle, + tempPtr, + nbytes + ) + ); + const ret = this.lib.memory.loadRawBytes(tempPtr, nbytes); + + this.lib.recycleCallStack(stack); + return ret; + } + + /** + * Return a TypedArray copy of the NDArray, the specific type depends on + * the dtype of the NDArray. + * @returns The result array. + */ + toArray(): Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array { + const stype = this.dtype; + if (stype == "float32") { + return new Float32Array(this.toRawBytes().buffer); + } else if (stype == "float64") { + return new Float64Array(this.toRawBytes().buffer); + } else if (stype == "int32") { + return new Int32Array(this.toRawBytes().buffer); + } else if (stype == "int8") { + return new Int8Array(this.toRawBytes().buffer); + } else if (stype == "uint8") { + return new Uint8Array(this.toRawBytes().buffer); + } else { + throw new Error("Unsupported data type " + this.dtype); + } + } + + private getDLTensorFromArrayHandle(handle: Pointer): Pointer { + // Note: this depends on the NDArray C ABI. + // keep this function in case of ABI change. + return handle; + } +} + +/** + * Runtime Module. + */ +export class Module implements Disposable { + handle: Pointer; + private lib: FFILibrary; + private makePackedFunc: (ptr: Pointer) => PackedFunc; + + constructor( + handle: Pointer, + lib: FFILibrary, + makePackedFunc: (ptr: Pointer) => PackedFunc + ) { + this.handle = handle; + this.lib = lib; + this.makePackedFunc = makePackedFunc; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMModFree as ctypes.FTVMModFree)(this.handle) + ); + this.handle = 0; + } + } + + /** + * Get a function in the module. + * @param name The name of the function. + * @returns The result function. + */ + getFunction(name: string): PackedFunc { + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)( + this.handle, + stack.ptrFromOffset(nameOffset), + 1, + outPtr + ) + ); + const handle = this.lib.memory.loadPointer(outPtr); + this.lib.recycleCallStack(stack); + if (handle == 0) { + throw Error("Cannot find function " + name); + } + const ret = this.makePackedFunc(handle); + return ret; + } + + /** + * Import another module into the current runtime module. + * @param mod The module to be imported. + */ + importModule(mod: Module): void { + this.lib.checkCall( + (this.lib.exports.TVMModImport as ctypes.FTVMModImport)( + this.handle, + mod.handle + ) + ); + } +} + +/** + * TVM runtime instance. + */ +export class Instance implements Disposable { + memory: Memory; + exports: Record; + private lib: FFILibrary; + private env: Environment; + + /** + * Internal function(registered by the runtime) + */ + private wasmCreateLibraryModule?: PackedFunc & + ((getFunc: PackedFunc, getGlobal: PackedFunc) => PackedFunc); + + /** + * Constructor + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * + * @param wasmModule The input module or instance. + * @param importObject The imports to initialize the wasmInstance if it is not provided. + * @param wasmInstance Additional wasm instance argument for deferred construction. + * @param env Directly specified environment module. + * + * @see Please use the async version {@link instantiate} when targeting browsers. + */ + constructor( + wasmModule: WebAssembly.Module, + importObject: Record = {}, + wasmInstance?: WebAssembly.Instance, + env?: Environment + ) { + if (wasmInstance instanceof WebAssembly.Instance) { + assert( + env instanceof Environment, + "env must be provided when passing in instance" + ); + } else { + assert(env === undefined); + env = new Environment(importObject); + wasmInstance = new WebAssembly.Instance(wasmModule, env.imports); + } + + env.start(wasmInstance); + this.env = env; + this.lib = new FFILibrary(wasmInstance, env.imports); + this.memory = this.lib.memory; + this.exports = this.lib.exports; + this.registerEnvGlobalPackedFuncs(); + } + + dispose(): void { + this.lib.dispose(); + } + /** + * Get system-wide library module in the wasm. + * System lib is a global module that contains self register functions in startup. + * @returns The system library module. + */ + systemLib(): Module { + const getSysLib = this.getGlobalFunc("runtime.SystemLib"); + const mod = getSysLib() as Module; + getSysLib.dispose(); + return mod; + } + /** + * List all the global function names registered in the runtime. + * @returns The name list. + */ + listGlobalFuncNames(): Array { + const stack = this.lib.getOrAllocCallStack(); + + const outSizeOffset = stack.allocPtrArray(2); + + const outSizePtr = stack.ptrFromOffset(outSizeOffset); + const outArrayPtr = stack.ptrFromOffset( + outSizeOffset + this.lib.sizeofPtr() + ); + + this.lib.checkCall( + (this.exports.TVMFuncListGlobalNames as ctypes.FTVMFuncListGlobalNames)( + outSizePtr, + outArrayPtr + ) + ); + + const size = this.memory.loadI32(outSizePtr); + const array = this.memory.loadPointer(outArrayPtr); + const names: Array = []; + + for (let i = 0; i < size; ++i) { + names.push( + this.memory.loadCString( + this.memory.loadPointer(array + this.lib.sizeofPtr() * i) + ) + ); + } + + this.lib.recycleCallStack(stack); + return names; + } + + /** + * Register function to be global function in tvm runtime. + * @param name The name of the function. + * @param f function to be registered. + * @param override Whether overwrite function in existing registry. + */ + registerFunc( + name: string, + func: PackedFunc | Function, + override = false + ): void { + const packedFunc = this.toPackedFunc(func); + const ioverride = override ? 1 : 0; + + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + stack.commitToWasmMemory(); + + this.lib.checkCall( + (this.lib.exports.TVMFuncRegisterGlobal as ctypes.FTVMFuncRegisterGlobal)( + stack.ptrFromOffset(nameOffset), + packedFunc._tvmPackedCell.handle, + ioverride + ) + ); + } + + /** + * Get global PackedFunc from the runtime. + * @param name The name of the function. + * @returns The result function. + */ + getGlobalFunc(name: string): PackedFunc { + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.exports.TVMFuncGetGlobal as ctypes.FTVMFuncGetGlobal)( + stack.ptrFromOffset(nameOffset), + outPtr + ) + ); + const handle = this.memory.loadPointer(outPtr); + this.lib.recycleCallStack(stack); + if (handle == 0) { + throw Error("Cannot find global function " + name); + } + const ret = this.makePackedFunc(handle); + return ret; + } + + /** + * Check if func is PackedFunc. + * + * @param func The input. + * @returns The check result. + */ + isPackedFunc(func: unknown): boolean { + // eslint-disable-next-line no-prototype-builtins + return typeof func == "function" && func.hasOwnProperty("_tvmPackedCell"); + } + + /** + * Convert func to PackedFunc + * + * @param func Input function. + * @returns The converted function. + */ + toPackedFunc(func: Function): PackedFunc { + if (this.isPackedFunc(func)) return func as PackedFunc; + return this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func)); + } + + /** + * Convert dtype to {@link DLDataType} + * + * @param dtype The input dtype string or DLDataType. + * @returns The converted result. + */ + toDLDataType(dtype: string | DLDataType): DLDataType { + if (dtype instanceof DLDataType) return dtype; + if (typeof dtype == "string") { + let pattern = dtype; + let code, + bits = 32, + lanes = 1; + if (pattern.substring(0, 5) == "float") { + pattern = pattern.substring(5, pattern.length); + code = TypeCode.Float; + } else if (pattern.substring(0, 3) == "int") { + pattern = pattern.substring(3, pattern.length); + code = TypeCode.Int; + } else if (pattern.substring(0, 4) == "uint") { + pattern = pattern.substring(4, pattern.length); + code = TypeCode.UInt; + } else if (pattern.substring(0, 6) == "handle") { + pattern = pattern.substring(5, pattern.length); + code = TypeCode.TVMOpaqueHandle; + bits = 64; + } else { + throw new Error("Unknown dtype " + dtype); + } + + const arr = pattern.split("x"); + if (arr.length >= 1) { + const parsed = parseInt(arr[0]); + if (parsed + "" == arr[0]) { + bits = parsed; + } + } + if (arr.length >= 2) { + lanes = parseInt(arr[1]); + } + return new DLDataType(code, bits, lanes); + } else { + throw new Error("Unknown dtype " + dtype); + } + } + + /** + * Create a new {@link Scalar} that can be passed to a PackedFunc. + * @param value The number value. + * @param dtype The dtype string. + * @returns The created scalar. + */ + scalar(value: number, dtype: string): Scalar { + return new Scalar(value, dtype); + } + + /** + * Create a new {@link DLContext} + * @param deviceType The device type. + * @param deviceId The device index. + * @returns The created context. + */ + context(deviceType: number | string, deviceId: number): DLContext { + return new DLContext(deviceType, deviceId, this.lib); + } + + /** + * Create an empty {@link NDArray} with given shape and dtype. + * + * @param shape The shape of the array. + * @param dtype The data type of the array. + * @param ctx The context of the ndarray. + * @returns The created ndarray. + */ + empty( + shape: Array | number, + dtype: string | DLDataType = "float32", + ctx: DLContext = this.context("cpu", 0) + ): NDArray { + dtype = this.toDLDataType(dtype); + shape = typeof shape == "number" ? [shape] : shape; + + const stack = this.lib.getOrAllocCallStack(); + const shapeOffset = stack.allocRawBytes(shape.length * SizeOf.I64); + for (let i = 0; i < shape.length; ++i) { + stack.storeI64(shapeOffset + i * SizeOf.I64, shape[i]); + } + + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.exports.TVMArrayAlloc as ctypes.FTVMArrayAlloc)( + stack.ptrFromOffset(shapeOffset), + shape.length, + dtype.code, + dtype.bits, + dtype.lanes, + ctx.deviceType, + ctx.deviceId, + outPtr + ) + ); + const ret = new NDArray(this.memory.loadPointer(outPtr), this.lib); + this.lib.recycleCallStack(stack); + return ret; + } + + /** Register global packed functions needed by the backend to the env. */ + private registerEnvGlobalPackedFuncs(): void { + // Register the timer function to enable the time_evaluator. + let perf: Performance; + if (typeof performance == "undefined") { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const performanceNode = require('perf_hooks'); + perf = performanceNode.performance as Performance; + } else { + perf = performance as Performance; + } + + const getTimer = (func: PackedFunc) => { + return (n: number): number => { + const nscalar = this.scalar(n, "int32"); + const tstart: number = perf.now(); + func(nscalar); + const tend: number = perf.now(); + return tend - tstart; + } + }; + this.registerFunc("wasm.GetTimer", getTimer); + const rpcWrapTimeEvaluator = this.getGlobalFunc("wasm.RPCTimeEvaluator"); + this.registerFunc("runtime.RPCTimeEvaluator", rpcWrapTimeEvaluator, true); + rpcWrapTimeEvaluator.dispose(); + } + + private createPackedFuncFromCFunc( + func: ctypes.FTVMWasmPackedCFunc + ): PackedFunc { + let findex = this.env.packedCFuncTable.length; + if (this.env.packedCFuncTableFreeId.length != 0) { + findex = this.env.packedCFuncTableFreeId.pop() as number; + } else { + this.env.packedCFuncTable.push(undefined); + } + this.env.packedCFuncTable[findex] = func; + + const stack = this.lib.getOrAllocCallStack(); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + this.lib.checkCall( + (this.exports + .TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)( + findex, + outPtr + ) + ); + const ret = this.makePackedFunc(this.memory.loadPointer(outPtr)); + this.lib.recycleCallStack(stack); + return ret; + } + + /** + * Set packed function arguments into the location indicated by argsValue and argsCode. + * Allocate new temporary space from the stack if necessary. + * + * @parma stack The call stack + * @param args The input arguments. + * @param argsValue The offset of argsValue. + * @param argsCode The offset of argsCode. + */ + setPackedArguments( + stack: CachedCallStack, + args: Array, + argsValue: PtrOffset, + argsCode: PtrOffset + ): void { + for (let i = 0; i < args.length; ++i) { + let val = args[i]; + const tp = typeof val; + const valueOffset = argsValue + i * SizeOf.TVMValue; + const codeOffset = argsCode + i * SizeOf.I32; + if (val instanceof NDArray) { + stack.storePtr(valueOffset, val.handle); + stack.storeI32(codeOffset, TypeCode.TVMNDArrayHandle); + } else if (val instanceof Scalar) { + if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { + stack.storeI64(valueOffset, val.value); + stack.storeI32(codeOffset, TypeCode.Int); + } else if (val.dtype.startsWith("float")) { + stack.storeF64(valueOffset, val.value); + stack.storeI32(codeOffset, TypeCode.Float); + } else { + assert(val.dtype == "handle", "Expect handle"); + stack.storePtr(valueOffset, val.value); + stack.storeI32(codeOffset, TypeCode.TVMOpaqueHandle); + } + } else if (tp == "number") { + stack.storeF64(valueOffset, val); + stack.storeI32(codeOffset, TypeCode.Float); + // eslint-disable-next-line no-prototype-builtins + } else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) { + stack.storePtr(valueOffset, val._tvmPackedCell.handle); + stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle); + } else if (val === null || val == undefined) { + stack.storePtr(valueOffset, 0); + stack.storeI32(codeOffset, TypeCode.Null); + } else if (tp == "string") { + stack.allocThenSetArgString(valueOffset, val); + stack.storeI32(codeOffset, TypeCode.TVMStr); + } else if (val instanceof Uint8Array) { + stack.allocThenSetArgBytes(valueOffset, val); + stack.storeI32(codeOffset, TypeCode.TVMBytes); + } else if (val instanceof Function) { + val = this.toPackedFunc(val); + stack.tempArgs.push(val); + stack.storePtr(valueOffset, val._tvmPackedCell.handle); + stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle); + } else if (val instanceof Module) { + stack.storePtr(valueOffset, val.handle); + stack.storeI32(codeOffset, TypeCode.TVMModuleHandle); + } else { + throw new Error("Unsupported argument type " + tp); + } + } + } + + private wrapJSFuncAsPackedCFunc(func: Function): ctypes.FTVMWasmPackedCFunc { + const lib = this.lib; + return ( + argValues: Pointer, + argCodes: Pointer, + nargs: number, + ret: Pointer, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + _handle: Pointer + ): number => { + const jsArgs = []; + for (let i = 0; i < nargs; ++i) { + const valuePtr = argValues + i * SizeOf.TVMValue; + const codePtr = argCodes + i * SizeOf.I32; + let tcode = lib.memory.loadI32(codePtr); + + if ( + tcode == TypeCode.TVMObjectHandle || + tcode == TypeCode.TVMObjectRValueRefArg || + tcode == TypeCode.TVMPackedFuncHandle || + tcode == TypeCode.TVMModuleHandle + ) { + lib.checkCall( + (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)( + valuePtr, + codePtr + ) + ); + } + tcode = lib.memory.loadI32(codePtr); + jsArgs.push(this.retValueToJS(valuePtr, tcode)); + } + + const rv = func(...jsArgs); + + if (rv !== undefined && rv !== null) { + const stack = lib.getOrAllocCallStack(); + const valueOffset = stack.allocRawBytes(SizeOf.TVMValue); + const codeOffset = stack.allocRawBytes(SizeOf.I32); + this.setPackedArguments(stack, [rv], valueOffset, codeOffset); + const valuePtr = stack.ptrFromOffset(valueOffset); + const codePtr = stack.ptrFromOffset(codeOffset); + stack.commitToWasmMemory(); + lib.checkCall( + (lib.exports.TVMCFuncSetReturn as ctypes.FTVMCFuncSetReturn)( + ret, + valuePtr, + codePtr, + 1 + ) + ); + lib.recycleCallStack(stack); + } + return 0; + }; + } + + private makePackedFunc(handle: Pointer): PackedFunc { + const cell = new PackedFuncCell(handle, this.lib); + + const packedFunc = (...args: any): any => { + const stack = this.lib.getOrAllocCallStack(); + + const valueOffset = stack.allocRawBytes(SizeOf.TVMValue * args.length); + const tcodeOffset = stack.allocRawBytes(SizeOf.I32 * args.length); + + this.setPackedArguments(stack, args, valueOffset, tcodeOffset); + + const rvalueOffset = stack.allocRawBytes(SizeOf.TVMValue); + const rcodeOffset = stack.allocRawBytes(SizeOf.I32); + const rvaluePtr = stack.ptrFromOffset(rvalueOffset); + const rcodePtr = stack.ptrFromOffset(rcodeOffset); + + // commit to wasm memory, till rvalueOffset (the return value don't need to be committed) + stack.commitToWasmMemory(rvalueOffset); + + this.lib.checkCall( + (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)( + handle, + stack.ptrFromOffset(valueOffset), + stack.ptrFromOffset(tcodeOffset), + args.length, + rvaluePtr, + rcodePtr + ) + ); + + const ret = this.retValueToJS(rvaluePtr, this.memory.loadI32(rcodePtr)); + this.lib.recycleCallStack(stack); + return ret; + }; + // Attach attributes to the function type. + // This is because javascript do not allow us to overload call. + const ret: any = packedFunc; + ret.dispose = (): void => { + cell.dispose(); + }; + ret._tvmPackedCell = cell; + return ret as PackedFunc; + } + + private retValueToJS(rvaluePtr: Pointer, tcode: number): any { + switch (tcode) { + case TypeCode.Int: + case TypeCode.UInt: + return this.memory.loadI64(rvaluePtr); + case TypeCode.Float: + return this.memory.loadF64(rvaluePtr); + case TypeCode.TVMNDArrayHandle: { + return new NDArray(this.memory.loadPointer(rvaluePtr), this.lib); + } + case TypeCode.TVMPackedFuncHandle: { + return this.makePackedFunc(this.memory.loadPointer(rvaluePtr)); + } + case TypeCode.TVMModuleHandle: { + return new Module( + this.memory.loadPointer(rvaluePtr), + this.lib, + (ptr: Pointer) => { + return this.makePackedFunc(ptr); + } + ); + } + case TypeCode.Null: + return undefined; + case TypeCode.TVMStr: { + return this.memory.loadCString(this.memory.loadPointer(rvaluePtr)); + } + case TypeCode.TVMBytes: { + return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr)); + } + default: + throw new Error("Unsupported return type code=" + tcode); + } + } +} + +/** + * Asynchrously instantiate a new {@link Instance}. + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * We can take benefit of syslib implementations from the Emscripten + * by passing its generated js Module as the imports. + */ +export function instantiate( + bufferSource: ArrayBuffer, + importObject: Record = {} +): Promise { + const env = new Environment(importObject); + + return WebAssembly.instantiate(bufferSource, env.imports).then( + (result: WebAssembly.WebAssemblyInstantiatedSource): Instance => { + return new Instance(result.module, {}, result.instance, env); + } + ); +} diff --git a/web/src/support.ts b/web/src/support.ts new file mode 100644 index 0000000..7a2667a --- /dev/null +++ b/web/src/support.ts @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Convert string to Uint8array. + * @param str The string. + * @returns The corresponding Uint8Array. + */ +export function StringToUint8Array(str: string): Uint8Array { + const arr = new Uint8Array(str.length + 1); + for (let i = 0; i < str.length; ++i) { + arr[i] = str.charCodeAt(i); + } + arr[str.length] = 0; + return arr; +} + +/** + * Convert Uint8array to string. + * @param array The array. + * @returns The corresponding string. + */ +export function Uint8ArrayToString(arr: Uint8Array): string { + const ret = []; + for (const ch of arr) { + ret.push(String.fromCharCode(ch)); + } + return ret.join(""); +} + +/** + * Internal assert helper + * @param condition condition The condition to fail. + * @param msg msg The message. + */ +export function assert(condition: boolean, msg?: string): asserts condition { + if (!condition) { + throw new Error("AssertError:" + (msg || "")); + } +} + +/** + * Get the path to the wasm library in nodejs. + * @return The wasm path. + */ +export function wasmPath(): string { + return __dirname + "/wasm"; +} \ No newline at end of file diff --git a/web/src/types.ts b/web/src/types.ts new file mode 100644 index 0000000..621375a --- /dev/null +++ b/web/src/types.ts @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** Common type definitions. */ + +/** + * Library interface provider that can provide + * syslibs(e.g. libs provided by WASI and beyond) for the Wasm runtime. + * + * It can be viewed as a generalization of imports used in WebAssembly instance creation. + * + * The {@link LibraryProvider.start} callback will be called + * to allow the library provider to initialize related resources during startup time. + * + * We can use Emscripten generated js Module as a { wasmLibraryProvider: LibraryProvider }. + */ +export interface LibraryProvider { + /** The imports that can be passed to WebAssembly instance creation. */ + imports: Record; + /** + * Callback function to notify the provider the created instance. + * @param inst The created instance. + */ + start: (inst: WebAssembly.Instance) => void; +} + +/** + * Disposable classes that contains resources (WasmMemory, GPU buffer) + * which needs to be explicitly disposed. + */ +export interface Disposable { + /** + * Dispose the internal resource + * This function can be called multiple times, + * only the first call will take effect. + */ + dispose: () => void; +} diff --git a/tests/web/test_module_load.js b/web/tests/node/test_module_load.js similarity index 64% rename from tests/web/test_module_load.js rename to web/tests/node/test_module_load.js index f4c8095..45e84fd 100644 --- a/tests/web/test_module_load.js +++ b/web/tests/node/test_module_load.js @@ -19,14 +19,18 @@ // Load Emscripten Module, need to change path to root/lib const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/test_module.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); +const fs = require("fs"); +const assert = require("assert"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "test_addone.wasm")); + +const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI()); // Load system library -var sysLib = tvm.systemLib(); +const sysLib = tvm.systemLib(); function randomArray(length, max) { return Array.apply(null, Array(length)).map(function() { @@ -36,23 +40,22 @@ function randomArray(length, max) { function testAddOne() { // grab pre-loaded function - var faddOne = sysLib.getFunction("add_one"); - var assert = require('assert'); - tvm.assert(tvm.isPackedFunc(faddOne)); - var n = 124; - var A = tvm.empty(n).copyFrom(randomArray(n, 1)); - var B = tvm.empty(n); + const faddOne = sysLib.getFunction("add_one"); + assert(tvm.isPackedFunc(faddOne)); + const n = 124; + const A = tvm.empty(n).copyFrom(randomArray(n, 1)); + const B = tvm.empty(n); // call the function. faddOne(A, B); - AA = A.asArray(); // retrieve values in js array - BB = B.asArray(); // retrieve values in js array + const AA = A.toArray(); // retrieve values in js array + const BB = B.toArray(); // retrieve values in js array // verify for (var i = 0; i < BB.length; ++i) { assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5); } - faddOne.release(); + faddOne.dispose(); } testAddOne(); -sysLib.release(); +sysLib.dispose(); console.log("Finish verifying test_module_load"); diff --git a/tests/web/test_basic.js b/web/tests/node/test_ndarray.js similarity index 55% rename from tests/web/test_basic.js rename to web/tests/node/test_ndarray.js index 6852319..ba43621 100644 --- a/tests/web/test_basic.js +++ b/web/tests/node/test_ndarray.js @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -16,31 +16,34 @@ * specific language governing permissions and limitations * under the License. */ - -// Load Emscripten Module, need to change path to root/build const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); +const fs = require("fs"); +const assert = require("assert"); +const tvmjs = require("../../dist/tvmjs.bundle") + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI()); // Basic fields. -tvm.assert(tvm.float32 == "float32"); -tvm.assert(tvm.listGlobalFuncNames() !== "undefined"); -var sysLib = tvm.systemLib(); -tvm.assert(typeof sysLib.getFunction !== "undefined"); -sysLib.release(); +assert(tvm.listGlobalFuncNames() !== undefined); // Test ndarray -function testArrayCopy(dtype, arr) { - var data = [1, 2, 3, 4, 5, 6]; - var a = tvm.empty([2, 3], dtype); - a.copyFrom(data); - var ret = a.asArray(); - tvm.assert(ret instanceof arr); - tvm.assert(ret.toString() == arr.from(data)); - a.release(); +function testArrayCopy(dtype, arrayType) { + let data = [1, 2, 3, 4, 5, 6]; + let a = tvm.empty([2, 3], dtype).copyFrom(data); + + assert(a.context.toString() == "cpu(0)"); + assert(a.shape[0] == 2 && a.shape[1] == 3); + + let ret = a.toArray(); + assert(ret instanceof arrayType); + assert(ret.toString() == arrayType.from(data).toString()); + // test multiple dispose. + a.dispose(); + a.dispose(); } testArrayCopy("float32", Float32Array); @@ -48,8 +51,3 @@ testArrayCopy("int", Int32Array); testArrayCopy("int8", Int8Array); testArrayCopy("uint8", Uint8Array); testArrayCopy("float64", Float64Array); - -// Function registration -tvm.registerFunc("xyz", function(x, y) { - return x + y; -}); diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js new file mode 100644 index 0000000..c961f95 --- /dev/null +++ b/web/tests/node/test_packed_func.js @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +const path = require("path"); +const fs = require("fs"); +const assert = require('assert'); +const tvmjs = require("../../dist") + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI()); + +function testGetGlobal() { + let flist = tvm.listGlobalFuncNames(); + let faddOne = tvm.getGlobalFunc("testing.add_one"); + let fecho = tvm.getGlobalFunc("testing.echo"); + + assert(faddOne(tvm.scalar(1, "int")) == 2); + // check function argument with different types. + assert(fecho(1123) == 1123); + assert(fecho("xyz") == "xyz"); + + let bytes = new Uint8Array([1, 2, 3]); + let rbytes = fecho(bytes); + assert(rbytes.length == bytes.length); + + for (let i = 0; i < bytes.length; ++i) { + assert(rbytes[i] == bytes[i]); + } + + assert(fecho(undefined) == undefined); + + let arr = tvm.empty([2, 2]).copyFrom([1, 2, 3, 4]); + let arr2 = fecho(arr); + assert(arr.handle == arr2.handle); + assert(arr2.toArray().toString() == arr.toArray().toString()); + + let mod = tvm.systemLib(); + let ret = fecho(mod); + assert(ret.handle == mod.handle); + assert(flist.length != 0); + + mod.dispose(); + ret.dispose(); + arr.dispose(); + arr2.dispose(); + fecho.dispose(); + faddOne.dispose(); +} + +function testReturnFunc() { + function addy(y) { + function add(x, z) { + return x + y + z; + } + return add; + } + + let fecho = tvm.getGlobalFunc("testing.echo"); + let myf = tvm.toPackedFunc(addy); + assert(tvm.isPackedFunc(myf)); + let myf2 = tvm.toPackedFunc(myf); + assert(myf2._tvmPackedCell.handle === myf._tvmPackedCell.handle); + let f = myf(10); + + assert(tvm.isPackedFunc(f)); + assert(f(11, 0) == 21); + assert(f("x", 1) == "x101"); + assert(f("x", "yz") == "x10yz"); + + fecho.dispose(); + myf.dispose(); + myf2.dispose(); + // test multiple dispose. + f.dispose(); + f.dispose(); +} + +function testRegisterGlobal() { + tvm.registerFunc("xyz", function (x, y) { + return x + y; + }); + + let f = tvm.getGlobalFunc("xyz"); + assert(f(1, 2) == 3); + f.dispose(); + + let syslib = tvm.systemLib(); + syslib.dispose(); +} + +function testTimer() { + const fecho = tvm.getGlobalFunc("testing.echo"); + const fgetTimer = tvm.getGlobalFunc("wasm.GetTimer"); + + let finvoke = (n) => { + let x = "xyz"; + for (let i = 0; i < n; ++i) { + x = fecho(x); + } + }; + const number = 10000; + const invokeTimer = fgetTimer(finvoke); + console.log("Time cost:", number / invokeTimer(number) * 1000, " ops/sec"); + fecho.dispose(); + invokeTimer.dispose(); + fgetTimer.dispose(); +} + +testGetGlobal(); +testRegisterGlobal(); +testReturnFunc(); +testTimer(); diff --git a/tests/web/prepare_test_libs.py b/web/tests/python/prepare_test_libs.py similarity index 69% rename from tests/web/prepare_test_libs.py rename to web/tests/python/prepare_test_libs.py index a0e2c13..ec4eb5b 100644 --- a/tests/web/prepare_test_libs.py +++ b/web/tests/python/prepare_test_libs.py @@ -14,27 +14,28 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# Prepare test library for js. +# Prepare test library for standalone wasm runtime test. + import tvm from tvm import te -from tvm.contrib import emscripten +from tvm.contrib import emcc import os + def prepare_test_libs(base_path): - target = "llvm -target=asmjs-unknown-emscripten -system-lib" + target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib" if not tvm.runtime.enabled(target): raise RuntimeError("Target %s is not enbaled" % target) n = te.var("n") A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) - fadd1 = tvm.build(s, [A, B], target, name="add_one") - obj_path = os.path.join(base_path, "test_add_one.bc") - fadd1.save(obj_path) - emscripten.create_js(os.path.join(base_path, "test_module.js"), obj_path, - options=["-s", "WASM=0", "-s", "USE_GLFW=3", "-s", - "USE_WEBGL2=1", "-lglfw"]) + fadd = tvm.build(s, [A, B], target, name="add_one") + + wasm_path = os.path.join(base_path, "test_addone.wasm") + fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + if __name__ == "__main__": curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - prepare_test_libs(os.path.join(curr_path, "../../build")) + prepare_test_libs(os.path.join(curr_path, "../../dist/wasm")) diff --git a/tests/web/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py similarity index 55% rename from tests/web/websock_rpc_test.py rename to web/tests/python/websock_rpc_test.py index 8be8ce0..7fa0c6b 100644 --- a/tests/web/websock_rpc_test.py +++ b/web/tests/python/websock_rpc_test.py @@ -22,45 +22,61 @@ Connect javascript end to the websocket port and connect to the RPC. import tvm from tvm import te -import os from tvm import rpc -from tvm.contrib import util, emscripten +from tvm.contrib import util, emcc import numpy as np proxy_host = "localhost" proxy_port = 9090 -def test_rpc_array(): +def test_rpc(): if not tvm.runtime.enabled("rpc"): return - # graph - n = tvm.runtime.convert(1024) + # generate the wasm library + target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib" + if not tvm.runtime.enabled(target): + raise RuntimeError("Target %s is not enbaled" % target) + n = te.var("n") A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) - remote = rpc.connect(proxy_host, proxy_port, key="js") - target = "llvm -target=asmjs-unknown-emscripten -system-lib" - def check_remote(): - if not tvm.runtime.enabled(target): - print("Skip because %s is not enabled" % target) - return - temp = util.tempdir() + + fadd = tvm.build(s, [A, B], target, name="addone") + temp = util.tempdir() + + wasm_path = temp.relpath("addone.wasm") + fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + + wasm_binary = open(wasm_path, "rb").read() + + remote = rpc.connect(proxy_host, proxy_port, key="wasm", + session_constructor_args=["rpc.WasmSession", wasm_binary]) + + def check(remote): + # basic function checks. + fecho = remote.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + # run the generated library. + f1 = remote.system_lib() ctx = remote.cpu(0) - f = tvm.build(s, [A, B], target, name="myadd") - path_obj = temp.relpath("dev_lib.bc") - path_dso = temp.relpath("dev_lib.js") - f.save(path_obj) - emscripten.create_js(path_dso, path_obj, side_module=True) - # Upload to suffix as dso so it can be loaded remotely - remote.upload(path_dso, "dev_lib.dso") - data = remote.download("dev_lib.dso") - f1 = remote.load_module("dev_lib.dso") a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) - time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10) + # invoke the function + addone = f1.get_function("addone") + addone(a, b) + + # time evaluator + time_f = f1.time_evaluator("addone", ctx, number=10) + time_f(a, b) cost = time_f(a, b).mean print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) - check_remote() -test_rpc_array() + check(remote) + + +test_rpc() diff --git a/web/tsconfig.json b/web/tsconfig.json new file mode 100644 index 0000000..3c20b3d --- /dev/null +++ b/web/tsconfig.json @@ -0,0 +1,13 @@ +{ + "compilerOptions": { + "module": "commonjs", + "target": "es6", + "outDir": "dist", + "rootDir": "src", + "declaration": true, + "sourceMap": true, + "strict": true, + }, + "include": ["src"], + "exclude": ["node_modules"] +} diff --git a/web/tvm_runtime.js b/web/tvm_runtime.js deleted file mode 100644 index 86ef59c..0000000 --- a/web/tvm_runtime.js +++ /dev/null @@ -1,1274 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/** - * TVM Javascript web runtime library. - * - * @projectname tvm - * @version 0.7.dev1 - */ -/* eslint no-unused-vars: "off" */ -/* eslint no-unexpected-multiline: "off" */ -/* eslint indent: "off" */ -/* eslint no-console: "off" */ -/** - * TVM Runtime namespace. - * Provide tvm_runtime.create to create a {@link tvm.TVMRuntime}. - * - * @namespace tvm_runtime - */ -var tvm_runtime = tvm_runtime || {}; - -/** - * TVM root namespace. - * The classes inside this namespace need to be constructed by factory functions. - * Use {@link tvm_runtime}.create to get started. - * - * @namespace tvm - */ -(function() { - /** - * TVMRuntime object for interacting with TVM runtime. - * This object can be constructed using {@link tvm_runtime}.create - * - * @class - * @memberof tvm - */ - function TVMRuntime() { - "use strict"; - var runtime_ref = this; - // Utility function to throw error - function throwError(message) { - if (typeof runtime_ref.logger !== "undefined") { - runtime_ref.logger(message); - } - if (typeof Error !== "undefined") { - throw new Error(message); - } - throw message; - } - var Module = this.Module; - var Runtime = this.Runtime; - if (typeof Module === "undefined") { - throwError("Emscripten Module is not available"); - } - // constants - var SIZEOF_POINTER = 4; - var SIZEOF_SIZE_T = 4; - var SIZEOF_FLOAT = 4; - var SIZEOF_INT = 4; - var SIZEOF_INT8 = 1; - var SIZEOF_INT64 = 8; - var SIZEOF_DOUBLE = 8; - var SIZEOF_TYPE = 4; - var SIZEOF_CTX = SIZEOF_INT + SIZEOF_INT; - var SIZEOF_TVMVALUE = SIZEOF_DOUBLE; - var ARRAY_OFFSET_DATA = 0; - var ARRAY_OFFSET_CTX = ARRAY_OFFSET_DATA + SIZEOF_POINTER; - var ARRAY_OFFSET_DEV_TYPE = ARRAY_OFFSET_CTX; - var ARRAY_OFFSET_DEV_ID = ARRAY_OFFSET_CTX + SIZEOF_INT; - var ARRAY_OFFSET_NDIM = ARRAY_OFFSET_CTX + SIZEOF_CTX; - var ARRAY_OFFSET_DTYPE = ARRAY_OFFSET_NDIM + SIZEOF_INT; - var ARRAY_OFFSET_DTYPE_CODE = ARRAY_OFFSET_DTYPE; - var ARRAY_OFFSET_DTYPE_BITS = ARRAY_OFFSET_DTYPE_CODE + SIZEOF_INT8; - var ARRAY_OFFSET_DTYPE_LANES = ARRAY_OFFSET_DTYPE_BITS + SIZEOF_INT8; - var ARRAY_OFFSET_SHAPE = ARRAY_OFFSET_DTYPE + SIZEOF_TYPE; - var ARRAY_OFFSET_STRIDES = ARRAY_OFFSET_STRIDES + SIZEOF_POINTER; - var ARRAY_OFFSET_BYTE_OFFSET = ARRAY_OFFSET_STRIDES + SIZEOF_POINTER; - // Type codes - var kInt = 0; - var kUInt = 1; - var kFloat = 2; - var kTVMOpaqueHandle = 3; - var kNull = 4; - var kTVMDataType = 5; - var kTVMContext = 6; - var kTVMDLTensorHandle = 7; - var kTVMObjectHandle = 8; - var kTVMModuleHandle = 9; - var kTVMPackedFuncHandle = 10; - var kTVMStr = 11; - var kTVMBytes = 12; - var kTVMObjectRValueRefArg = 14; - //----------------------------------------- - // TVM CWrap library - // ---------------------------------------- - var TVMGetLastError = Module.cwrap( - "TVMGetLastError", - "string", // const char* - []); - - var TVMAPISetLastError = Module.cwrap - ("TVMAPISetLastError", - null, - ["string" // const char* - ]); - - var TVMModImport = Module.cwrap - ("TVMModImport", - "number", - ["number", // TVMModuleHandle mod - "number" // TVMModuleHandle dep - ]); - - var TVMModGetFunction = Module.cwrap - ("TVMModGetFunction", - "number", - ["number", // TVMModuleHandle mod - "string", // const char* func_name - "number", // int query_imports - "number" // TVMFunctionHandle *out - ]); - - var TVMModFree = Module.cwrap - ("TVMModFree", - "number", - ["number" // TVMModeHandle mod - ]); - - var TVMFuncFree = Module.cwrap - ("TVMFuncFree", - "number", - ["number" // TVMFunctionHandle func - ]); - - var TVMFuncCall = Module.cwrap - ("TVMFuncCall", - "number", - ["number", // TVMFunctionHandle func - "number", // TVMValue* arg_values - "number", // int* arg_tcodes - "number", // int num_args - "number", // int ret_val - "number" // int ret_type_code - ]); - - var TVMCFuncSetReturn = Module.cwrap - ("TVMCFuncSetReturn", - "number", - ["number", // TVMRetValueHandle ret - "number", // TVMValue* value - "number", // int* type_code - "number" // int num_ret - ]); - - var TVMCbArgToReturn = Module.cwrap - ("TVMCbArgToReturn", - "number", - ["number", // TVMValue* value - "number" // int* code - ]); - - var TVMFuncCreateFromCFunc = Module.cwrap - ("TVMFuncCreateFromCFunc", - "number", - ["number", // TVMPackedCFunc func, - "number", // void* resource_handle - "number", // TVMPackedCFuncFinalizer fin - "number" // TVMFunctionHandle *out - ]); - - var TVMFuncRegisterGlobal = Module.cwrap - ("TVMFuncRegisterGlobal", - "number", - ["string", // name - "number", // TVMFunctionHandle f - "number" // int override - ]); - - var TVMFuncGetGlobal = Module.cwrap - ("TVMFuncGetGlobal", - "number", - ["string", // const char* name - "number" // TVMFunctionHandle* out - ]); - - var TVMFuncListGlobalNames = Module.cwrap - ("TVMFuncListGlobalNames", - "number", - ["number", // int* out_size - "number" // const char*** out_array - ]); - - - var TVMArrayAlloc = Module.cwrap - ("TVMArrayAlloc", - "number", - ["number", // const tvm_index_t* shape - "number", // int ndim - "number", // int dtype_code - "number", // int dtype_bits - "number", // int dtype_lanes - "number", // int device_type - "number", // int device_id - "number" // int TVMArrayHandle* out - ]); - - var TVMArrayFree = Module.cwrap - ("TVMArrayFree", - "number", - ["number" // TVMArrayHandle handle - ]); - - var TVMArrayCopyFromTo = Module.cwrap - ("TVMArrayCopyFromTo", - "number", - ["number", // TVMArrayHandle from - "number" // TVMArrayHandle to - ]); - - var TVMArrayCopyFromBytes = Module.cwrap - ("TVMArrayCopyFromBytes", - "number", - ["number", // TVMArrayHandle handle - "number", // int data - "number" // size_t nbytes - ]); - - var TVMArrayCopyToBytes = Module.cwrap - ("TVMArrayCopyToBytes", - "number", - ["number", // TVMArrayHandle handle - "number", // int data - "number" // size_t nbytes - ]); - - var TVMModLoadFromFile = Module.cwrap - ("TVMModLoadFromFile", - "number", - ["string", // const char* file_name - "string", // const char* format - "number" // TVMModuleHandle* out - ]) - - //----------------------------------------- - // Static utility functions - // ---------------------------------------- - this.assert = function(condition, message) { - if (!condition) { - message = message || "assert failed"; - throwError(message); - } - }; - /** - * Logging function. - * Override this to change logger behavior. - * - * @param {string} message - */ - this.logger = function(message) { - console.log(message); - }; - - function logging(message) { - runtime_ref.logger(message); - } - // Override print error to logging - Module.printErr = logging; - var CHECK = this.assert; - - function TVM_CALL(ret) { - if (ret != 0) { - throwError(TVMGetLastError()); - } - } - - function CInt64ArrayToJS(ptr, size) { - var ret = []; - for (var i = 0; i < size; ++i) { - ret.push(Module.getValue(ptr + i * SIZEOF_INT64, "i64")); - } - return ret; - } - - function CStringToJS(ptr) { - var ret = []; - var ch = 1; - while (ch != 0) { - ch = Module.getValue(ptr, "i8"); - if (ch != 0) { - ret.push(String.fromCharCode(ch)); - } - ++ptr; - } - return ret.join(""); - } - - function CBytesToJS(ptr) { - var data = Module.getValue(ptr, "*"); - var size = Module.getValue(ptr + SIZEOF_POINTER, "i32"); - var ret = new Uint8Array(new ArrayBuffer(size)); - ret.set(new Uint8Array(Module.HEAPU8.buffer, data, size)); - return ret; - } - - function StringToUint8Array(str) { - var arr = new Uint8Array(str.length + 1); - for(var i = 0; i < str.length; ++i) { - arr[i] = str.charCodeAt(i); - } - arr[str.length] = 0; - return arr; - } - //----------------------------------------- - // Class declarations - // ---------------------------------------- - function CBuffer(nbytes) { - this.data = Module._malloc(nbytes); - } - - function RefTVMValue() { - this.data = Module._malloc(SIZEOF_TVMVALUE); - } - - function TVMArgs(nargs) { - this.nargs = nargs; - this.value = Module._malloc(SIZEOF_TVMVALUE * nargs); - this.tcode = Module._malloc(SIZEOF_INT * nargs); - this.temp = []; - } - - function TVMType(code, bits, lanes) { - this.code = code; - this.bits = bits; - this.lanes = lanes; - } - /** - * TVM device context. - * @class - * @memberof tvm - */ - function TVMContext(device_type, device_id) { - this.device_type = device_type; - this.device_id = device_id; - } - /** - * TVM n-dimensional array. - * - * Use {@link tvm.TVMRuntime}.empty to create an instance. - * @class - * @memberof tvm - */ - function NDArray(handle) { - this.handle = handle; - this.ndim = Module.getValue(this.handle + ARRAY_OFFSET_NDIM, "i32"); - // shape - var cshape = Module.getValue(this.handle + ARRAY_OFFSET_SHAPE, "*"); - this.shape = CInt64ArrayToJS(cshape, this.ndim); - // dtype - var code = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_CODE, "i8"); - var bits = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_BITS, "i8"); - var lanes = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_LANES, "i16"); - var dtype = new TVMType(code, bits, lanes); - this.dtype = dtype; - this.BYTES_PER_ELEMENT = (dtype.bits * dtype.lanes / 8); - // ctx - var device_type = Module.getValue(this.handle + ARRAY_OFFSET_DEV_TYPE, "i32"); - var device_id = Module.getValue(this.handle + ARRAY_OFFSET_DEV_ID, "i32"); - this.context = new TVMContext(device_type, device_id); - // byte_offset - this.byteOffset = Module.getValue(this.handle + ARRAY_OFFSET_BYTE_OFFSET, "i64"); - } - - function TVMFunction(handle) { - this.handle = handle; - } - /** - * Module container of TVM generated functions. - * - * @class - * @memberof tvm - */ - function TVMModule(handle) { - this.handle = handle; - } - /** - * A typed scalar constant. - * This can be used to pass number as integer types to tvm function. - * Use {@link tvm.TVMRuntime}.constant to create an instance. - * @class - * @memberof tvm - */ - function TVMConstant(value, dtype) { - this.value = value; - this.dtype = dtype; - } - //----------------------------------------- - // Private Functions - // ---------------------------------------- - function getTVMType(dtype) { - if (dtype instanceof TVMType) return dtype; - if (typeof dtype == "string") { - var pattern = dtype; - var code, bits = 32, lanes = 1; - if (pattern.substring(0, 5) == "float") { - pattern = pattern.substring(5, pattern.length); - code = kFloat; - } else if (pattern.substring(0, 3) == "int") { - pattern = pattern.substring(3, pattern.length); - code = kInt; - } else if (pattern.substring(0, 4) == "uint") { - pattern = pattern.substring(4, pattern.length); - code = kUInt; - } else if (pattern.substring(0, 6) == "handle") { - pattern = pattern.substring(5, pattern.length); - code = kTVMOpaqueHandle; - bits = 64; - } else { - throw throwError("Unknown dtype " + dtype); - } - var arr = pattern.split("x"); - if (arr.length >= 1) { - var parsed = parseInt(arr[0]); - if (parsed == arr[0]) { - bits = parsed; - } - } - if (arr.length >= 2) { - lanes = parseInt(arr[1]); - } - return new TVMType(code, bits, lanes); - } else { - throw throwError("Unknown dtype " + dtype); - } - } - - function TVMRetValueToJS(vptr, tcode) { - switch (tcode) { - case kInt: - case kUInt: return Module.getValue(vptr, "i64"); - case kFloat: return Module.getValue(vptr, "double"); - case kTVMPackedFuncHandle: return makeTVMFunction(Module.getValue(vptr, "*")); - case kTVMModuleHandle: return new TVMModule(Module.getValue(vptr, "*")); - case kNull: return null; - case kTVMStr: return CStringToJS(Module.getValue(vptr, "*")); - case kTVMBytes: return CBytesToJS(Module.getValue(vptr, "*")); - default: throwError("Unsupported return type code=" + tcode); - } - } - - function makeTVMFunction(handle) { - var func = new TVMFunction(handle); - var ret = function () { - // alloc - var args = new TVMArgs(arguments.length); - var rvalue = new RefTVMValue(); - var rtcode = new RefTVMValue(); - args.setArguments(arguments); - TVM_CALL(TVMFuncCall(handle, args.value, args.tcode, - args.nargs, rvalue.data, rtcode.data)); - var rv = TVMRetValueToJS(rvalue.data, rtcode.asInt()); - // release - args.release(); - rvalue.release(); - rtcode.release(); - return rv; - }; - var release = function() { - func.release(); - }; - ret._tvm_function = func; - ret.release = release; - return ret; - } - //----------------------------------------- - // Javascript PackedCallback System - // ---------------------------------------- - var funcTable = [0]; - var freeFuncId = []; - - function invokeCallback(arg_value, arg_tcode, nargs, ret, handle) { - var args = []; - for (var i = 0; i < nargs; ++i) { - var vptr = arg_value + i * SIZEOF_TVMVALUE; - var tcodeptr = arg_tcode + i * SIZEOF_INT; - var tcode = Module.getValue(tcodeptr, "i32"); - if (tcode == kTVMObjectHandle || - tcode == kTVMObjectRValueRefArg || - tcode == kTVMPackedFuncHandle || - tcode == kTVMModuleHandle) { - TVM_CALL(TVMCbArgToReturn(vptr, tcodeptr)); - } - tcode = Module.getValue(tcodeptr, "i32"); - args.push(TVMRetValueToJS(vptr, tcode)); - } - var rv = funcTable[handle].apply(null, args); - if (typeof rv !== "undefined") { - // alloc - var rarg = new TVMArgs(1); - rarg.setArguments([rv]); - TVM_CALL(TVMCFuncSetReturn(ret, rarg.value, rarg.tcode, 1)); - // release - rarg.release(); - } - return 0; - } - function freeCallback(handle) { - funcTable[handle] = 0; - freeFuncId.push(handle); - } - var fptrInvokeCallback = null; - var fptrFreeCallback = null; - if (typeof Runtime !== "undefined" && - typeof Runtime.addFunction !== "undefined") { - fptrInvokeCallback = Runtime.addFunction(invokeCallback); - fptrFreeCallback = Runtime.addFunction(freeCallback); - } - /** - * Check if a function is TVM PackedFunc - * @param {Function} f function to be checked. - * @return {boolean} Whether f is PackedFunc - */ - this.isPackedFunc = function(f) { - return (typeof f == "function") && f.hasOwnProperty("_tvm_function"); - }; - var isPackedFunc = this.isPackedFunc; - /** - * Convert a javascript function to TVM function. - * @param {Function} f javascript function. - * @return {Function} The created TVMFunction. - */ - this.convertFunc = function(f) { - if (isPackedFunc(f)) return f; - CHECK(fptrInvokeCallback !== null, - "Emscripten Runtime addFunction is not available"); - var fid; - if (freeFuncId.length != 0) { - fid = freeFuncId.pop(); - } else { - fid = funcTable.length; - funcTable.push(0); - } - funcTable[fid] = f; - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMFuncCreateFromCFunc( - fptrInvokeCallback, fid, fptrFreeCallback, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - return makeTVMFunction(out_handle); - }; - var convertFunc = this.convertFunc; - //----------------------------------------- - // Private Class declarations - // ---------------------------------------- - CBuffer.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.data != 0) { - Module._free(this.data); - this.data = 0; - } - }, - }; - // RefTVMValue - RefTVMValue.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.data != 0) { - Module._free(this.data); - this.data = 0; - } - }, - asInt : function() { - return Module.getValue(this.data, "i32"); - }, - asInt64 : function() { - return Module.getValue(this.data, "i64"); - }, - asDouble : function() { - return Module.getValue(this.data, "double"); - }, - asHandle : function() { - return Module.getValue(this.data, "*"); - } - }; - // TVMArgs - TVMArgs.prototype = { - release : function() { - if (this.value != 0) { - Module._free(this.value); - Module._free(this.tcode); - this.value = 0; - for (var i = 0; i< this.temp.length; ++i) { - if (this.temp[i].release instanceof Function) { - this.temp[i].release(); - } - } - } - }, - setInt : function(index, value) { - Module.setValue(this.tcode + index * SIZEOF_INT, kInt, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "i64"); - }, - setDouble : function(index, value) { - Module.setValue(this.tcode + index * SIZEOF_INT, kFloat, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "double"); - }, - setHandle : function(index, value, tcode) { - Module.setValue(this.tcode + index * SIZEOF_INT, tcode, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "*"); - }, - setString : function(index, value) { - var sdata = new CBuffer(value.length + 1); - Module.HEAPU8.set(StringToUint8Array(value), sdata.data); - this.temp.push(sdata); - Module.setValue(this.tcode + index * SIZEOF_INT, kTVMStr, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, sdata.data, "*"); - }, - setBytes : function(index, value) { - CHECK(value instanceof Uint8Array); - var sdata = new CBuffer(value.length); - var sheader = new CBuffer(SIZEOF_POINTER + SIZEOF_SIZE_T); - Module.HEAPU8.set(new Uint8Array(value), sdata.data); - Module.setValue(sheader.data, sdata.data, "*"); - Module.setValue(sheader.data + SIZEOF_POINTER, value.length, "i32"); - this.temp.push(sdata); - this.temp.push(sheader); - Module.setValue(this.tcode + index * SIZEOF_INT, kTVMBytes, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, sheader.data, "*"); - }, - setArguments : function(args) { - for (var i = 0; i < args.length; ++i) { - var v = args[i]; - var tp = typeof v; - if (v instanceof NDArray) { - this.setHandle(i, v.handle, kTVMDLTensorHandle); - } else if (v instanceof TVMConstant) { - var code = getTVMType(v.dtype).code; - if (code == kInt || code == kUInt) { - this.setInt(i, v.value); - } else if (code == kFloat) { - this.setDouble(i, v.value); - } else { - CHECK(code == kTVMOpaqueHandle); - this.setHandle(i, v.value, kTVMOpaqueHandle); - } - } else if (tp == "number") { - this.setDouble(i, v); - } else if (tp == "function" && v.hasOwnProperty("_tvm_function")) { - this.setString(i, v._tvm_function.handle, kTVMPackedFuncHandle); - } else if (v === null) { - this.setHandle(i, 0, kNull); - } else if (tp == "string") { - this.setString(i, v); - } else if (v instanceof Uint8Array) { - this.setBytes(i, v); - } else if (v instanceof Function) { - v = convertFunc(v); - this.temp.push(v); - this.setHandle(i, v._tvm_function.handle, kTVMPackedFuncHandle); - } else if (v instanceof TVMModule) { - this.setHandle(i, v.handle, kTVMModuleHandle); - } else { - throwError("Unsupported argument type " + tp); - } - } - } - }; - // TVMType - var TYPE_CODE2STR = { - 0 : "int", - 1 : "uint", - 2 : "float", - 4 : "handle" - }; - - TVMType.prototype = { - toString : function() { - var ret = TYPE_CODE2STR[this.code] + this.bits.toString(); - if (this.lanes != 1) { - return ret + "x" + this.lanes.toString(); - } else { - return ret; - } - } - }; - // TVMFunction - TVMFunction.prototype = { - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMFuncFree(this.handle)); - this.handle = 0; - } - } - }; - // TVMContext - var CTX_MASK2STR = { - 1 : "cpu", - 2 : "gpu", - 4 : "opencl", - 7 : "vulkan", - 8 : "metal", - 9 : "vpi", - 11 : "opengl", - }; - var CTX_STR2MASK = { - "cpu": 1, - "gpu": 2, - "cuda": 2, - "cl": 4, - "opencl": 4, - "vulkan": 7, - "metal": 8, - "vpi": 9, - "opengl": 11, - }; - TVMContext.prototype = { - toString : function() { - return CTX_MASK2STR[this.device_type] + "(" + this.device_id.toString() + ")"; - } - }; - //----------------------------------------- - // Public Functions - // ---------------------------------------- - /** - * Construct a TVMContext given device type and id. - * - * @param {number} device_type, string or int, The device type. - * @param {number} device_id, the device id. - * @return {tvm.TVMContext} The created TVMContext - */ - this.context = function(device_type, device_id) { - if (typeof device_type == "string") { - device_type = CTX_STR2MASK[device_type]; - } - return new TVMContext(device_type, device_id); - }; - var context = this.context; - /** - * Create empty ndarray with given shape. - * - * @param {Array.} shape The shape of the array. - * @param {string} dtype The data type of the array, optional, default="float32" - * @param {tvm.TVMContext} ctx The context of the array, optional, default=cpu(0). - * @return {tvm.NDArray} The created ndarray. - */ - this.empty = function(shape, dtype, ctx) { - dtype = (typeof dtype !== "undefined") ? dtype: "float32"; - ctx = (typeof ctx !== "undefined") ? ctx : context("cpu", 0); - shape = (typeof shape == "number") ? [shape] : shape; - // alloc - var cshape = Module._malloc(SIZEOF_INT64 * shape.length); - var out = new RefTVMValue(); - for (var i = 0; i < shape.length; ++i) { - Module.setValue(cshape + i * SIZEOF_INT64, shape[i], "i64"); - } - dtype = getTVMType(dtype); - TVM_CALL(TVMArrayAlloc(cshape, shape.length, - dtype.code, dtype.bits, dtype.lanes, - ctx.device_type, ctx.device_id, - out.data)); - var out_handle = out.asHandle(); - // release - Module._free(cshape); - out.release(); - return new NDArray(out_handle); - }; - /** - * List all global function names in the TVM runtime. - * @return {Array.} List of global function names. - */ - this.listGlobalFuncNames = function() { - // alloc - var out_size = new RefTVMValue(); - var out_array = new RefTVMValue(); - TVM_CALL(TVMFuncListGlobalNames(out_size.data, out_array.data)); - var length = out_size.asInt(); - var base = out_array.asHandle(); - var names = []; - for (var i = 0 ; i < length; ++i) { - names.push( - CStringToJS(Module.getValue(base + i * SIZEOF_POINTER, "*"))); - } - // release - out_size.release(); - out_array.release(); - return names; - }; - var listGlobalFuncNames = this.listGlobalFuncNames; - /** - * Get a global function from TVM runtime. - * - * @param {string} The name of the function. - * @return {Function} The corresponding function, null if function do not exist - */ - this.getGlobalFunc = function (name) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMFuncGetGlobal(name, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle != 0) { - return makeTVMFunction(out_handle); - } else { - return null; - } - }; - var getGlobalFunc = this.getGlobalFunc; - /** - * Register function to be global function in tvm runtime. - * @param {string} name The name of the function. - * @param {Function} f function to be registered. - * @param {boolean} override Whether overwrite function in existing registry. - */ - this.registerFunc = function(name, f, override) { - f = convertFunc(f); - override = (typeof override !== "undefined") ? override: false; - var ioverride = override ? 1 : 0; - TVM_CALL(TVMFuncRegisterGlobal(name, f._tvm_function.handle, ioverride)); - }; - /** - * Create a typed scalar constant. - * This can be used to pass number as integer types to tvm function. - * - * @param {number} value The value of the data. - * @param {string} dtype The data type. - * @param {tvm.TVMConstant} The created typed scalar. - */ - this.constant = function(value, dtype) { - return new TVMConstant(value, dtype); - }; - //----------------------------------------- - // Wrap of TVM Functions. - // ---------------------------------------- - var systemFunc = {}; - /** - * Get system-wide library module singleton.5A - * System lib is a global module that contains self register functions in startup. - * @return {tvm.TVMModule} The system module singleton. - */ - this.systemLib = function() { - if (typeof systemFunc.fGetSystemLib === "undefined") { - systemFunc.fGetSystemLib = getGlobalFunc("runtime.SystemLib"); - } - return systemFunc.fGetSystemLib(); - }; - - this.startRPCServer = function(url, key, counter) { - if (typeof key === "undefined") { - key = ""; - } - if (typeof counter === "undefined") { - counter = 1; - } - // Node js, import websocket - var bkey = StringToUint8Array("server:" + key); - bkey = bkey.slice(0, bkey.length - 1); - var server_name = "WebSocketRPCServer[" + key + "]"; - var RPC_MAGIC = 0xff271; - function checkEndian() { - var a = new ArrayBuffer(4); - var b = new Uint8Array(a); - var c = new Uint32Array(a); - b[0] = 0x11; - b[1] = 0x22; - b[2] = 0x33; - b[3] = 0x44; - CHECK(c[0] === 0x44332211, "Need little endian to work"); - } - checkEndian(); - // start rpc - function RPCServer(counter) { - var socket; - if (typeof module !== "undefined" && module.exports) { - // WebSocket for nodejs - const WebSocket = require("ws"); - socket = new WebSocket(url); - } else { - socket = new WebSocket(url); - } - var self = this; - socket.binaryType = "arraybuffer"; - this.init = true; - this.counter = counter; - - if (typeof systemFunc.fcreateServer === "undefined") { - systemFunc.fcreateServer = - getGlobalFunc("rpc.CreateEventDrivenServer"); - } - if (systemFunc.fcreateServer == null) { - throwError("RPCServer is not included in runtime"); - } - - var message_handler = systemFunc.fcreateServer( - function(cbytes) { - if (socket.readyState == 1) { - socket.send(cbytes); - return new TVMConstant(cbytes.length, "int32"); - } else { - return new TVMConstant(0, "int32"); - } - } , server_name, "%toinit"); - - function on_open(event) { - var intbuf = new Int32Array(1); - intbuf[0] = RPC_MAGIC; - socket.send(intbuf); - intbuf[0] = bkey.length; - socket.send(intbuf); - socket.send(bkey); - logging(server_name + " connected..."); - } - - function on_message(event) { - if (self.init) { - var msg = new Uint8Array(event.data); - CHECK(msg.length >= 4, "Need message header to be bigger than 4"); - var magic = new Int32Array(event.data)[0]; - - if (magic == RPC_MAGIC + 1) { - throwError("key: " + key + " has already been used in proxy"); - } else if (magic == RPC_MAGIC + 2) { - logging(server_name + ": RPCProxy do not have matching client key " + key); - } else { - CHECK(magic == RPC_MAGIC, url + "is not RPC Proxy"); - self.init = false; - } - logging(server_name + "init end..."); - if (msg.length > 4) { - if (message_handler( - new Uint8Array(event.data, 4, msg.length -4), - new TVMConstant(3, "int32")) == 0) { - socket.close(); - } - } - } else { - if (message_handler(new Uint8Array(event.data), - new TVMConstant(3, "int32")) == 0) { - socket.close(); - } - } - } - function on_close(event) { - message_handler.release(); - logging(server_name + ": closed finish..."); - if (!self.init && self.counter != 0) { - logging(server_name + ":reconnect to serve another request, session left=" + counter); - // start a new server. - new RPCServer(counter - 1); - } - } - socket.addEventListener("open", on_open); - socket.addEventListener("message", on_message); - socket.addEventListener("close", on_close); - } - return new RPCServer(counter); - }; - - /** - * Load a TVM module from a library file. - * The file must be present in the Emscripten virtual file system. - * For example, you can pass "--preload-file file" or "--preload-file dir/" - * to "emcc" when compiling the TVM library, in order to populate files into - * the file system. - * For more detail, see: - * https://kripken.github.io/emscripten-site/docs/porting/files/packaging_files - * @param {string} file_name Path of the file to be loaded. The path refers - * to the Emscripten virtual file system. - * @param {string} format The format of the file. - * @return {tvm.TVMModule} The loaded module. - */ - this.loadModuleFromFile = function (file_name, format) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMModLoadFromFile(file_name, format, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle != 0) { - return new TVMModule(out_handle); - } else { - return null; - } - }; - var loadModuleFromFile = this.loadModuleFromFile; - - /** - * Wrapper runtime module. - * Wraps around set_input, load_params, run, and get_output. - * - * @class - * @memberof tvm - */ - function GraphModule(tvm_graph_module, ctx) { - CHECK(tvm_graph_module instanceof TVMModule, - "tvm_graph_module must be TVMModule"); - CHECK(ctx instanceof TVMContext, "ctx must be TVMContext"); - - this.tvm_graph_module = tvm_graph_module; - this.ctx = ctx; - this._set_input = tvm_graph_module.getFunction("set_input"); - this._load_params = tvm_graph_module.getFunction("load_params"); - this._run = tvm_graph_module.getFunction("run"); - this._get_output = tvm_graph_module.getFunction("get_output"); - }; - - GraphModule.prototype = { - /** - * Set input to graph module. - * - * @param {string} key The name of the input. - * @param {NDArray} value The input value. - */ - "set_input" : function(key, value) { - CHECK(typeof key == "string", "key must be string"); - CHECK(value instanceof NDArray, "value must be NDArray"); - this._set_input(key, value); - }, - - /** - * Load parameters from serialized byte array of parameter dict. - * - * @param {Uint8Array} params The serialized parameter dict. - */ - "load_params" : function(params) { - CHECK(params instanceof Uint8Array, "params must be Uint8Array"); - this._load_params(params); - }, - - /** - * Load parameters from serialized base64 string of parameter dict. - * - * @param {string} base64_params The serialized parameter dict. - */ - "load_base64_params" : function(base64_params) { - CHECK(typeof base64_params == "string", "base64_params must be string"); - var decoded_string = atob(base64_params); - var decoded_u8 = new Uint8Array(decoded_string.length); - for (var i = 0; i < decoded_string.length; i++) { - decoded_u8[i] = decoded_string[i].charCodeAt(0); - } - this.load_params(decoded_u8); - }, - - /** - * Run forward execution of the graph. - */ - "run" : function() { - this._run(); - }, - - /** - * Get index-th output to out. - * - * @param {NDArray} out The output array container. - * @return {NDArray} The output array container. - */ - "get_output" : function(index, out) { - CHECK(typeof index == "number", "index must be number"); - CHECK(out instanceof NDArray, "out must be NDArray"); - this._get_output(new TVMConstant(index, "int32"), out); - return out; - } - }; - - /** - * Create a runtime executor module given a graph and a module. - * @param {string} graph_json_str The Json string of the graph. - * @param {TVMModule} libmod The TVM module. - * @param {TVMContext} ctx The context to deploy the module. - * @return {GraphModule} Runtime graph module for executing the graph. - */ - this.createGraphRuntime = function(graph_json_str, libmod, ctx) { - CHECK(typeof graph_json_str == "string", "graph_json_str must be string"); - CHECK(libmod instanceof TVMModule, "libmod must be TVMModule"); - CHECK(ctx instanceof TVMContext, "ctx must be TVMContext"); - - var fcreate = getGlobalFunc("tvm.graph_runtime.create"); - CHECK(fcreate != null, "Cannot find tvm.graph_runtime.create"); - - var tvm_graph_module = fcreate(graph_json_str, libmod, - new TVMConstant(ctx.device_type, "int32"), - new TVMConstant(ctx.device_id, "int32")); - - return new GraphModule(tvm_graph_module, ctx); - }; - - //----------------------------------------- - // Class defintions - // ---------------------------------------- - // NDArray. - NDArray.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMArrayFree(this.handle)); - this.handle = 0; - } - }, - /** - * Copy data from another NDArray or javascript array. - * The number of elements must match. - * - * @param {Array} data The source data array. - */ - copyFrom : function(data) { - if (data instanceof NDArray) { - TVM_CALL(TVMArrayCopyFromTo(data.handle, this.handle)); - } else { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - if (data.length != size) { - throwError("data size and shape mismatch data.length" + data.length + " vs " + size); - } - if (this.dtype == "float32") { - data = Float32Array.from(data); - } else if (this.dtype == "float64") { - data = Float64Array.from(data); - } else if (this.dtype == "int32") { - data = Int32Array.from(data); - } else if (this.dtype == "int8") { - data = Int8Array.from(data); - } else if (this.dtype == "uint8") { - data = Uint8Array.from(data); - } else { - throwError("Unsupported data type " + this.dtype); - } - return this.copyFromRawBytes(new Uint8Array(data.buffer)); - } - }, - /** - * Copy data from raw bytes. - * @param {Uint8Array} data Uint8Array of bytes. - */ - copyFromRawBytes : function(data) { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - var dtype = getTVMType(this.dtype); - var nbytes = this.BYTES_PER_ELEMENT * size; - CHECK(data instanceof Uint8Array); - CHECK(data.length == nbytes, - "Data length and bytes do not match " + data.length + - " vs " + nbytes); - var temp = Module._malloc(nbytes); - Module.HEAPU8.set(data, temp); - TVM_CALL(TVMArrayCopyFromBytes(this.handle, temp, nbytes)); - Module._free(temp); - return this; - }, - /** - * Return a copied Uint8Array of the raw bytes in the NDArray. - * @return {Uint8Array} The created array. - */ - asRawBytes : function() { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - var nbytes = this.BYTES_PER_ELEMENT * size; - var temp = Module._malloc(nbytes); - TVM_CALL(TVMArrayCopyToBytes(this.handle, temp, nbytes)); - var ret = new Uint8Array(new ArrayBuffer(nbytes)); - ret.set(new Uint8Array(Module.HEAPU8.buffer, temp, nbytes)); - Module._free(temp); - return ret; - }, - /** - * Return Array data content as javascript typed array. - * @return {TypedArray} The created array. - */ - asArray : function() { - if (this.dtype == "float32") { - return new Float32Array(this.asRawBytes().buffer); - } else if (this.dtype == "float64") { - return new Float64Array(this.asRawBytes().buffer); - } else if (this.dtype == "int32") { - return new Int32Array(this.asRawBytes().buffer); - } else if (this.dtype == "int8") { - return new Int8Array(this.asRawBytes().buffer); - } else if (this.dtype == "uint8") { - return new Uint8Array(this.asRawBytes().buffer); - } else { - throwError("Unsupported data type " + this.dtype); - } - } - }; - - TVMModule.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMModFree(this.handle)); - this.handle = 0; - } - }, - /** - * Get function from the module. - * @param {string} name The name of the function. - * @return {Function} The correspondin function. - */ - getFunction : function(name) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMModGetFunction(this.handle, name, 0, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle == 0) { - throwError("Module has no function " + name); - } - return makeTVMFunction(out_handle); - }, - /** - * Add module to the import list of current one. - * @param {tvm.TVMModule} mod The other module to be imported. - */ - import_module : function(mod) { - CHECK(mod instanceof TVMModule, "mod must be instance of TVMModule"); - TVM_CALL(TVMModImport(this.handle, mod.handle)); - } - }; - //----------------------------------------- - // Static variables. - // ---------------------------------------- - /** Float32 type */ - this.float32 = "float32"; - /** Int32 type */ - this.int32 = "int32"; - } - /** - * Create a TVM runtime given emscripten module. - * @property {string} create - * @memberof tvm_runtime - * @param Module The emscripten module. - * @return {tvm.TVMRuntime} The created TVM runtime. - */ - this.create = function(Module) { - var tvm = {}; - tvm.Module = Module; - if (typeof Module.addFunction !== "undefined") { - tvm.Runtime = Module; - } else { - tvm.Runtime = Module.Runtime; - } - TVMRuntime.apply(tvm); - return tvm; - }; -}).apply(tvm_runtime); - -// export things in node -if (typeof module !== "undefined" && module.exports) { - module.exports = tvm_runtime; -} diff --git a/web/web_runtime.cc b/web/web_runtime.cc deleted file mode 100644 index 701ded7..0000000 --- a/web/web_runtime.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file web_runtime.cc - */ -#include -#include - -#include "../src/runtime/c_runtime_api.cc" -#include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" -#include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" -#include "../src/runtime/module.cc" -#include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" -#include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" -#include "../src/runtime/rpc/rpc_session.cc" -#include "../src/runtime/rpc/rpc_event_impl.cc" -#include "../src/runtime/rpc/rpc_server_env.cc" -#include "../src/runtime/graph/graph_runtime.cc" -#include "../src/runtime/opengl/opengl_device_api.cc" -#include "../src/runtime/opengl/opengl_module.cc" - -namespace tvm { -namespace contrib { - -struct RPCEnv { - public: - RPCEnv() { - base_ = "/rpc"; - mkdir(&base_[0], 0777); - } - // Get Path. - std::string GetPath(const std::string& file_name) { - return base_ + "/" + file_name; - } - - private: - std::string base_; -}; - -TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") -.set_body_typed([](std::string path) { - static RPCEnv env; - return env.GetPath(path); - }); - -TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") -.set_body_typed([](std::string path) { - std::string file_name = "/rpc/" + path; - LOG(INFO) << "Load module from " << file_name << " ..."; - return Module::LoadFromFile(file_name, ""); - }); -} // namespace contrib -} // namespace tvm - -// dummy parallel runtime -int TVMBackendParallelLaunch( - FTVMParallelLambda flambda, - void* cdata, - int num_task) { - TVMAPISetLastError("Parallel is not supported in Web runtime"); - return -1; -} - -int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { - return 0; -} -- 2.7.4