From cdc7ae492e1ce30f6080d9567eecf89c97ed939f Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 9 May 2020 16:59:18 -0700 Subject: [PATCH] [WEB] WebGPU support (#5545) This PR introduces WebGPU support to tvm. The WebGPU runtime is directly built in javascript(as WebGPU uses JS as the first class citizen API) and exposes back to the tvm's runtime via PackedFuncs. One important note is that `ctx.sync` is not async. This is due to the fact that WebGPU is a purely async API and we cannot block in the web environment. So the current best way to use the js api is to wrap things in an async function. When copy a GPU array to CPU, `await ctx.sync()` need to be called to wait for copy completion. We use a AsyncIO rpc server to serve the async functions to the clients. --- include/tvm/runtime/c_runtime_api.h | 1 + include/tvm/runtime/device_api.h | 1 + python/tvm/_ffi/libinfo.py | 1 + python/tvm/_ffi/runtime_ctypes.py | 2 + python/tvm/autotvm/tophub.py | 1 + python/tvm/contrib/emcc.py | 1 + python/tvm/exec/rpc_proxy.py | 16 +- python/tvm/rpc/client.py | 4 + python/tvm/rpc/proxy.py | 7 +- python/tvm/runtime/module.py | 2 +- python/tvm/runtime/ndarray.py | 16 ++ src/runtime/cuda/cuda_common.h | 4 +- src/runtime/vulkan/vulkan.cc | 9 +- src/target/spirv/build_vulkan.cc | 21 ++- src/target/spirv/codegen_spirv.cc | 11 +- src/target/spirv/codegen_spirv.h | 5 +- src/target/spirv/intrin_rule_spirv.cc | 32 ++++ src/target/spirv/ir_builder.cc | 44 ++++- src/target/spirv/ir_builder.h | 4 + src/target/target.cc | 11 +- web/Makefile | 4 +- web/README.md | 13 ++ web/apps/browser/rpc_server.html | 2 +- web/emcc/tvmjs_support.cc | 276 ++++++++++++++++++++------- web/emcc/webgpu_runtime.cc | 253 +++++++++++++++++++++++++ web/package.json | 5 +- web/rollup.config.js | 6 +- web/src/compact.ts | 47 +++++ web/src/index.ts | 4 +- web/src/rpc_server.ts | 69 +++++-- web/src/runtime.ts | 342 +++++++++++++++++++++++++++++----- web/src/webgpu.ts | 337 +++++++++++++++++++++++++++++++++ web/tests/python/webgpu_rpc_test.py | 79 ++++++++ web/tests/python/websock_rpc_test.py | 6 +- web/tsconfig.json | 2 +- 35 files changed, 1459 insertions(+), 179 deletions(-) create mode 100644 web/emcc/webgpu_runtime.cc create mode 100644 web/src/compact.ts create mode 100644 web/src/webgpu.ts create mode 100644 web/tests/python/webgpu_rpc_test.py diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 79bcdc6..5d371ee 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -83,6 +83,7 @@ typedef enum { kOpenGL = 11, kDLMicroDev = 13, kDLHexagon = 14, + kDLWebGPU = 15 // AddExtraTVMType which is not in DLPack here } TVMDeviceExtType; diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 1206918..4ccaa3c 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -226,6 +226,7 @@ inline const char* DeviceName(int type) { case kDLROCM: return "rocm"; case kOpenGL: return "opengl"; case kDLExtDev: return "ext_dev"; + case kDLWebGPU: return "webgpu"; case kDLMicroDev: return "micro_dev"; case kDLHexagon: return "hexagon"; default: LOG(FATAL) << "unknown type =" << type; return "Unknown"; diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index de8f7b5..a1483a1 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -90,6 +90,7 @@ def find_lib_path(name=None, search_path=None, optional=False): if os.path.isdir(source_dir): dll_path.append(os.path.join(source_dir, "web", "dist", "wasm")) + dll_path.append(os.path.join(source_dir, "web", "dist")) dll_path = [os.path.realpath(x) for x in dll_path] if search_path is not None: diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 6b06ad0..0d6e5ac 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -147,6 +147,7 @@ class TVMContext(ctypes.Structure): 12: 'ext_dev', 13: 'micro_dev', 14: 'hexagon', + 15: 'webgpu' } STR2MASK = { 'llvm': 1, @@ -169,6 +170,7 @@ class TVMContext(ctypes.Structure): 'ext_dev': 12, 'micro_dev': 13, 'hexagon': 14, + 'webgpu': 15, } def __init__(self, device_type, device_id): super(TVMContext, self).__init__() diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index b34ab15..3fbccfe 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -66,6 +66,7 @@ def _alias(name): 'vtacpu': 'vta', 'metal': 'opencl', + 'webgpu': 'opencl', 'vulkan': 'opencl', 'nvptx': 'cuda', } diff --git a/python/tvm/contrib/emcc.py b/python/tvm/contrib/emcc.py index 6df205a..6e7e997 100644 --- a/python/tvm/contrib/emcc.py +++ b/python/tvm/contrib/emcc.py @@ -61,6 +61,7 @@ def create_tvmjs_wasm(output, objects += [find_lib_path("wasm_runtime.bc")[0]] objects += [find_lib_path("tvmjs_support.bc")[0]] + objects += [find_lib_path("webgpu_runtime.bc")[0]] cmd += ["-o", output] cmd += objects diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py index 59da8fa..eb80286 100644 --- a/python/tvm/exec/rpc_proxy.py +++ b/python/tvm/exec/rpc_proxy.py @@ -31,14 +31,20 @@ def find_example_resource(): curr_path = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) base_path = os.path.abspath(os.path.join(curr_path, "..", "..", "..")) index_page = os.path.join(base_path, "web", "apps", "browser", "rpc_server.html") - js_files = [ - os.path.join(base_path, "web/dist/tvmjs.bundle.js"), - os.path.join(base_path, "web/dist/wasm/tvmjs_runtime.wasi.js") + resource_files = [ + os.path.join(base_path, "web", "dist", "tvmjs.bundle.js"), + os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js") ] - for fname in [index_page] + js_files: + resource_base = os.path.join(base_path, "web", "dist", "www") + if os.path.isdir(resource_base): + for fname in os.listdir(resource_base): + full_name = os.path.join(resource_base, fname) + if os.path.isfile(full_name): + resource_files.append(full_name) + for fname in [index_page] + resource_files: if not os.path.exists(fname): raise RuntimeError("Cannot find %s" % fname) - return index_page, js_files + return index_page, resource_files def main(args): diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index 9997673..3f38c4f 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -190,6 +190,10 @@ class RPCSession(object): """Construct extension device.""" return self.context(12, dev_id) + def webgpu(self, dev_id=0): + """Construct WebGPU device.""" + return self.context(15, dev_id) + class LocalSession(RPCSession): """RPCSession interface backed by local environment. diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index 03746da..994e230 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -130,7 +130,7 @@ class ForwardHandler(object): def on_close_event(self): """on close event""" assert not self._done - logging.info("RPCProxy:on_close %s ...", self.name()) + logging.info("RPCProxy:on_close_event %s ...", self.name()) if self.match_key: key = self.match_key if self._proxy._client_pool.get(key, None) == self: @@ -158,10 +158,12 @@ class TCPHandler(tornado_util.TCPHandler, ForwardHandler): self.on_data(message) def on_close(self): + logging.info("RPCProxy: on_close %s ...", self.name()) + self._close_process = True + if self.forward_proxy: self.forward_proxy.signal_close() self.forward_proxy = None - logging.info("%s Close socket..", self.name()) self.on_close_event() @@ -187,6 +189,7 @@ class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler): self.on_error(err) def on_close(self): + logging.info("RPCProxy: on_close %s ...", self.name()) if self.forward_proxy: self.forward_proxy.signal_close() self.forward_proxy = None diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index b580e3f..3cdb28f 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -319,7 +319,7 @@ class Module(object): if self.imported_modules: if enabled("llvm") and llvm_target_triple: - path_obj = temp.relpath("devc.o") + path_obj = temp.relpath("devc." + object_format) m = _ffi_api.ModulePackImportsToLLVM(self, is_system_lib, llvm_target_triple) m.save(path_obj) files.append(path_obj) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 9b7e7c5..9f5f0f6 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -478,6 +478,22 @@ def hexagon(dev_id=0): return TVMContext(14, dev_id) +def webgpu(dev_id=0): + """Construct a webgpu device. + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + ctx : TVMContext + The created context + """ + return TVMContext(15, dev_id) + + cl = opencl mtl = metal diff --git a/src/runtime/cuda/cuda_common.h b/src/runtime/cuda/cuda_common.h index 87cf3be..b7d9ecb 100644 --- a/src/runtime/cuda/cuda_common.h +++ b/src/runtime/cuda/cuda_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index 4e2f8cb..48fbdc7 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -750,8 +750,10 @@ class VulkanModuleNode final : public runtime::ModuleNode { } } - std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, - size_t num_pack_args) { + std::shared_ptr GetPipeline( + size_t device_id, + const std::string& func_name, + size_t num_pack_args) { const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); std::lock_guard lock(mutex_); const auto& cp = ecache_[device_id][func_name]; @@ -776,6 +778,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { std::vector arg_binding; std::vector arg_template; uint32_t num_pod = 0, num_buffer = 0; + { auto fit = fmap_.find(func_name); CHECK(fit != fmap_.end()); @@ -931,8 +934,6 @@ class VulkanModuleNode final : public runtime::ModuleNode { } private: - // the binary data - std::vector data_; // function information table. std::unordered_map smap_; // function information table. diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 4873557..825bdcb 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -70,7 +70,7 @@ class SPIRVTools { spv_context ctx_; }; -runtime::Module BuildSPIRV(IRModule mod, std::string target) { +runtime::Module BuildSPIRV(IRModule mod, std::string target, bool webgpu_restriction) { using tvm::runtime::Registry; using tvm::runtime::VulkanShader; @@ -98,7 +98,15 @@ runtime::Module BuildSPIRV(IRModule mod, std::string target) { std::string f_name = global_symbol.value(); VulkanShader shader; - shader.data = cg.BuildFunction(f); + std::string entry = webgpu_restriction ? "main" : f_name; + shader.data = cg.BuildFunction(f, entry); + + if (webgpu_restriction) { + for (auto param : f->params) { + CHECK(param.dtype().is_handle()) + << "WebGPU does not yet support non-buffer arguments"; + } + } if (postproc != nullptr) { TVMByteArray arr; @@ -119,7 +127,14 @@ runtime::Module BuildSPIRV(IRModule mod, std::string target) { } TVM_REGISTER_GLOBAL("target.build.vulkan") -.set_body_typed(BuildSPIRV); +.set_body_typed([](IRModule mod, std::string target) { + return BuildSPIRV(mod, target, false); +}); + +TVM_REGISTER_GLOBAL("target.build.webgpu") +.set_body_typed([](IRModule mod, std::string target) { + return BuildSPIRV(mod, target, true); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index be058b7..032a72a 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -30,7 +30,9 @@ namespace tvm { namespace codegen { -std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { +std::vector CodeGenSPIRV::BuildFunction( + const PrimFunc& f, + const std::string& name) { this->InitFuncState(); CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; @@ -77,12 +79,7 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { builder_->MakeInst(spv::OpReturn); builder_->MakeInst(spv::OpFunctionEnd); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) - << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - - builder_->CommitKernelFunction( - func_ptr, static_cast(global_symbol.value())); + builder_->CommitKernelFunction(func_ptr, name); return builder_->Finalize(); } diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index b51e8ed..adbb59b 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -32,6 +32,7 @@ #include #include #include +#include #include "ir_builder.h" #include "../../runtime/thread_storage_scope.h" @@ -51,9 +52,11 @@ class CodeGenSPIRV: /*! * \brief Compile and add function f to the current module. * \param f The function to be added. + * \param name The name of the target function. * \return The final spirv module. */ - virtual std::vector BuildFunction(const PrimFunc& f); + virtual std::vector BuildFunction(const PrimFunc& f, + const std::string& name); /*! * \brief Create Value for expression e * \param e The expression to be created value for. diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index ead6952..d8b9e71 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -65,6 +65,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs") TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp") .set_body(DispatchGLSLPureIntrin); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log") .set_body(DispatchGLSLPureIntrin); @@ -77,6 +78,37 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow") TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh") .set_body(DispatchGLSLPureIntrin); +// WebGPU rules. +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.floor") +.set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.ceil") +.set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.round") +.set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.trunc") +.set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.fabs") +.set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.exp") +.set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.log") +.set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.sqrt") +.set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.pow") +.set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.tanh") +.set_body(DispatchGLSLPureIntrin); + } // namespace spirv } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index bf43f11..7573b47 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -32,10 +32,14 @@ namespace spirv { void IRBuilder::InitHeader() { CHECK_EQ(header_.size(), 0U); header_.push_back(spv::MagicNumber); - // Use SPIR-V v1.0. This needs to be kept in sync (or at least behind) - // `VkApplicationInfo.apiVersion` in `vulkan.cc` to ensure Vulkan API - // validation passes. + + // Use the spirv version as indicated in the SDK. +#if SPV_VERSION >= 0x10300 + header_.push_back(0x10300); +#else header_.push_back(0x10000); +#endif + // generator: set to 0, unknown header_.push_back(0U); // Bound: set during Finalize @@ -146,11 +150,20 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, ib_.Begin(spv::OpMemberDecorate) .AddSeq(struct_type, 0, spv::DecorationOffset, 0) .Commit(&decorate_); + + +#if SPV_VERSION < 0x10300 + // NOTE: BufferBlock was deprecated in SPIRV 1.3 + // use StorageClassStorageBuffer instead. // runtime array are always decorated as BufferBlock(shader storage buffer) if (num_elems == 0) { this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock); } +#else + this->Decorate(spv::OpDecorate, + struct_type, spv::DecorationBlock); +#endif struct_array_type_tbl_[key] = struct_type; return struct_type; } @@ -190,11 +203,21 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding) { + // NOTE: BufferBlock was deprecated in SPIRV 1.3 + // use StorageClassStorageBuffer instead. +#if SPV_VERSION >= 0x10300 + spv::StorageClass storage_class = spv::StorageClassStorageBuffer; +#else + spv::StorageClass storage_class = spv::StorageClassUniform; +#endif + SType sarr_type = GetStructArrayType(value_type, 0); - SType ptr_type = GetPointerType(sarr_type, spv::StorageClassUniform); + SType ptr_type = GetPointerType(sarr_type, storage_class); Value val = NewValue(ptr_type, kStructArrayPtr); + ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, spv::StorageClassUniform).Commit(&global_); + .AddSeq(ptr_type, val, storage_class).Commit(&global_); + this->Decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, descriptor_set); this->Decorate(spv::OpDecorate, @@ -262,10 +285,13 @@ void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) void IRBuilder::StartFunction(const Value& func) { CHECK_EQ(func.flag, kFunction); - this->MakeInst( - spv::OpFunction, t_void_, func, 0, t_void_func_); + // add function declaration to the header. + ib_.Begin(spv::OpFunction).AddSeq( + t_void_, func, 0, t_void_func_).Commit(&func_header_); + spirv::Label start_label = this->NewLabel(); - this->StartLabel(start_label); + ib_.Begin(spv::OpLabel).AddSeq(start_label).Commit(&func_header_); + curr_label_ = start_label; } void IRBuilder::SetLocalSize(const Value& func, @@ -286,7 +312,7 @@ Value IRBuilder::Allocate(const SType& value_type, Value val = NewValue(ptr_type, kStructArrayPtr); if (storage_class == spv::StorageClassFunction) { ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, storage_class).Commit(&function_); + .AddSeq(ptr_type, val, storage_class).Commit(&func_header_); } else { ib_.Begin(spv::OpVariable) .AddSeq(ptr_type, val, storage_class).Commit(&global_); diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index bdfea4f..e9e04e8 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -40,6 +40,7 @@ namespace tvm { namespace codegen { namespace spirv { + /*! \brief Represent the SPIRV Type */ struct SType { /*! \brief The Id to represent type */ @@ -301,6 +302,7 @@ class IRBuilder { data.insert(data.end(), debug_.begin(), debug_.end()); data.insert(data.end(), decorate_.begin(), decorate_.end()); data.insert(data.end(), global_.begin(), global_.end()); + data.insert(data.end(), func_header_.begin(), func_header_.end()); data.insert(data.end(), function_.begin(), function_.end()); return data; } @@ -612,6 +614,8 @@ class IRBuilder { std::vector decorate_; /*! \brief Global segment: types, variables, types */ std::vector global_; + /*! \brief Function header segment */ + std::vector func_header_; /*! \brief Function segment */ std::vector function_; }; diff --git a/src/target/target.cc b/src/target/target.cc index 2cb72a2..c733eae 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -110,11 +110,15 @@ Target CreateTarget(const std::string& target_name, if (t->device_name == "intel_graphics") { t->thread_warp_size = 16; } - } else if (target_name == "metal" || target_name == "vulkan") { + } else if (target_name == "metal" || + target_name == "vulkan" || + target_name == "webgpu") { if (target_name == "metal") { t->device_type = kDLMetal; - } else { + } else if (target_name == "vulkan") { t->device_type = kDLVulkan; + } else { + t->device_type = kDLWebGPU; } t->keys_array.push_back(target_name); t->keys_array.push_back("gpu"); @@ -139,6 +143,9 @@ Target CreateTarget(const std::string& target_name, } else if (target_name == "hexagon") { t->keys_array.push_back("hexagon"); t->device_type = kDLHexagon; + } else if (target_name == "webgpu") { + t->keys_array.push_back("webgpu"); + t->device_type = kDLWebGPU; } else { LOG(ERROR) << "Unknown target name " << target_name << "; falling back to stackvm"; return target::stackvm(); diff --git a/web/Makefile b/web/Makefile index be7fa19..c0b8f07 100644 --- a/web/Makefile +++ b/web/Makefile @@ -20,7 +20,7 @@ TVM_ROOT=$(shell cd ..; pwd) INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\ -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include -.PHONY: clean all +.PHONY: clean all removetypedep all: dist/wasm/tvmjs_runtime.wasm dist/wasm/tvmjs_runtime.wasi.js @@ -37,7 +37,7 @@ dist/wasm/%.bc: emcc/%.cc $(EMCC) $(EMCC_CFLAGS) -c -o dist/wasm/$*.bc $< -dist/wasm/tvmjs_runtime.wasm: dist/wasm/wasm_runtime.bc dist/wasm/tvmjs_support.bc +dist/wasm/tvmjs_runtime.wasm: dist/wasm/wasm_runtime.bc dist/wasm/tvmjs_support.bc dist/wasm/webgpu_runtime.bc @mkdir -p $(@D) $(EMCC) $(EMCC_CFLAGS) -o dist/wasm/tvmjs_runtime.js $+ $(EMCC_LDFLAGS) diff --git a/web/README.md b/web/README.md index 66a64a3..358884c 100644 --- a/web/README.md +++ b/web/README.md @@ -82,3 +82,16 @@ The following is an example to reproduce this. - Browswer version: open https://localhost:8888, click connect to proxy - NodeJS version: `npm run rpc` - run `python tests/node/websock_rpc_test.py` to run the rpc client. + + +## WebGPU Experiments + +Web gpu is still experimental, so apis can change. +Right now we use the SPIRV to generate shaders that can be accepted by Chrome and Firefox. + +- Obtain a browser that support webgpu. + - So far only Chrome Canary on MacOS works + - Firefox should be close pending the support of Fence. +- Download vulkan SDK (1.1 or higher) that supports SPIRV 1.3 +- Start the WebSocket RPC +- run `python tests/node/webgpu_rpc_test.py` diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html index 22907f1..6d353e2 100644 --- a/web/apps/browser/rpc_server.html +++ b/web/apps/browser/rpc_server.html @@ -74,6 +74,6 @@
- + diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index 97099e7..9ea65d0 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -37,6 +37,7 @@ #include #include #include +#include "../../src/runtime/rpc/rpc_local_session.h" extern "C" { // --- Additional C API for the Wasm runtime --- @@ -108,85 +109,224 @@ int TVMWasmFuncCreateFromCFunc(void* resource_handle, namespace tvm { namespace runtime { -// chrono in the WASI does not provide very accurate time support -// and also have problems in the i64 support in browser. -// We redirect the timer to a JS side time using performance.now -PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, - TVMContext ctx, - int number, - int repeat, - int min_repeat_ms) { - auto ftimer = [pf, ctx, number, repeat, min_repeat_ms]( - TVMArgs args, TVMRetValue *rv) { - - TVMRetValue temp; - auto finvoke = [&](int n) { - // start timing - for (int i = 0; i < n; ++i) { - pf.CallPacked(args, &temp); +// A special local session that can interact with async +// functions in the JS runtime. +class AsyncLocalSession : public LocalSession { + public: + AsyncLocalSession() { + } + + PackedFuncHandle GetFunction(const std::string& name) final { + if (name == "runtime.RPCTimeEvaluator") { + return get_time_eval_placeholder_.get(); + } else if (auto* fp = tvm::runtime::Registry::Get(name)) { + // return raw handle because the remote need to explicitly manage it. + return new PackedFunc(*fp); + } else if(auto* fp = tvm::runtime::Registry::Get("__async." + name)) { + auto* rptr = new PackedFunc(*fp); + async_func_set_.insert(rptr); + return rptr; + } else { + return nullptr; + } + } + + void FreeHandle(void* handle, int type_code) final { + if (type_code == kTVMPackedFuncHandle) { + auto it = async_func_set_.find(handle); + if (it != async_func_set_.end()) { + async_func_set_.erase(it); } - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - }; + } + if (handle != get_time_eval_placeholder_.get()) { + LocalSession::FreeHandle(handle, type_code); + } + } - auto* get_timer = runtime::Registry::Get("wasm.GetTimer"); - CHECK(get_timer != nullptr) << "Cannot find wasm.GetTimer in the global function"; - TypedPackedFunc timer_ms = (*get_timer)( - TypedPackedFunc(finvoke)); + void AsyncCallFunc(PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + FAsyncCallback callback) final { + auto it = async_func_set_.find(func); + if (it != async_func_set_.end()) { + PackedFunc packed_callback([callback, this](TVMArgs args, TVMRetValue*) { + int code = args[0]; + TVMRetValue rv; + rv = args[1]; + this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) { + callback(RPCCode::kReturn, encoded_args); + }); + }); - std::ostringstream os; - finvoke(1); + TVMRetValue temp; + std::vector values(arg_values, arg_values + num_args); + std::vector type_codes(arg_type_codes, arg_type_codes + num_args); + values.emplace_back(TVMValue()); + type_codes.emplace_back(0); - int setup_number = number; + TVMArgsSetter setter(&values[0], &type_codes[0]); + // pass the callback as the last argument. + setter(num_args, packed_callback); - for (int i = 0; i < repeat; ++i) { - double duration_ms = 0.0; + auto* pf = static_cast(func); + pf->CallPacked(TVMArgs(values.data(), type_codes.data(), num_args + 1), &temp); + } else if (func == get_time_eval_placeholder_.get()) { + // special handle time evaluator. + try { + TVMArgs args(arg_values, arg_type_codes, num_args); + PackedFunc retfunc = this->GetTimeEvaluator( + args[0], args[1], args[2], args[3], args[4], args[5], args[6]); + TVMRetValue rv; + rv = retfunc; + this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) { + // mark as async. + async_func_set_.insert(encoded_args.values[1].v_handle); + callback(RPCCode::kReturn, encoded_args); + }); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); + } + } else { + LocalSession::AsyncCallFunc(func, arg_values, arg_type_codes, num_args, callback); + } + } - do { - if (duration_ms > 0.0) { - setup_number = static_cast( - std::max((min_repeat_ms / (duration_ms / number) + 1), - number * 1.618)); // 1.618 is chosen by random - } - duration_ms = timer_ms(setup_number); - } while (duration_ms < min_repeat_ms); + void AsyncCopyToRemote(void* local_from, + size_t local_from_offset, + void* remote_to, + size_t remote_to_offset, + size_t nbytes, + TVMContext remote_ctx_to, + DLDataType type_hint, + FAsyncCallback on_complete) final { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + try { + this->GetDeviceAPI(remote_ctx_to)->CopyDataFromTo( + local_from, local_from_offset, + remote_to, remote_to_offset, + nbytes, cpu_ctx, remote_ctx_to, type_hint, nullptr); + this->AsyncStreamWait(remote_ctx_to, nullptr, on_complete); + } catch (const std::runtime_error& e) { + this->SendException(on_complete, e.what()); + } + } - double speed = duration_ms / setup_number / 1000; - os.write(reinterpret_cast(&speed), sizeof(speed)); + void AsyncCopyFromRemote(void* remote_from, + size_t remote_from_offset, + void* local_to, + size_t local_to_offset, + size_t nbytes, + TVMContext remote_ctx_from, + DLDataType type_hint, + FAsyncCallback on_complete) final { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + try { + this->GetDeviceAPI(remote_ctx_from)->CopyDataFromTo( + remote_from, remote_from_offset, + local_to, local_to_offset, + nbytes, remote_ctx_from, cpu_ctx, type_hint, nullptr); + this->AsyncStreamWait(remote_ctx_from, nullptr, on_complete); + } catch (const std::runtime_error& e) { + this->SendException(on_complete, e.what()); } + } - std::string blob = os.str(); - TVMByteArray arr; - arr.size = blob.length(); - arr.data = blob.data(); - // return the time. - *rv = arr; - }; - return PackedFunc(ftimer); -} + void AsyncStreamWait(TVMContext ctx, + TVMStreamHandle stream, + FAsyncCallback on_complete) final { + if (ctx.device_type == kDLCPU) { + TVMValue value; + int32_t tcode = kTVMNullptr; + value.v_handle = nullptr; + on_complete(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + } else { + CHECK(ctx.device_type == static_cast(kDLWebGPU)); + if (async_wait_ == nullptr) { + async_wait_ = tvm::runtime::Registry::Get("__async.wasm.WebGPUWaitForTasks"); + } + CHECK(async_wait_ != nullptr); + PackedFunc packed_callback([on_complete](TVMArgs args, TVMRetValue*) { + int code = args[0]; + on_complete(static_cast(code), + TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1)); + }); + (*async_wait_)(packed_callback); + } + } -TVM_REGISTER_GLOBAL("wasm.RPCTimeEvaluator") -.set_body_typed([](Optional opt_mod, - std::string name, - int device_type, - int device_id, - int number, - int repeat, - int min_repeat_ms) { - TVMContext ctx; - ctx.device_type = static_cast(device_type); - ctx.device_id = device_id; - - if (opt_mod.defined()) { - Module m = opt_mod.value(); - std::string tkey = m->type_key(); - return WrapWasmTimeEvaluator( - m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); - } else { - auto* pf = runtime::Registry::Get(name); - CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; - return WrapWasmTimeEvaluator( - *pf, ctx, number, repeat, min_repeat_ms); + bool IsAsync() const final { + return true; } + + private: + std::unordered_set async_func_set_; + std::unique_ptr get_time_eval_placeholder_ = std::make_unique(); + const PackedFunc* async_wait_{nullptr}; + + // time evaluator + PackedFunc GetTimeEvaluator(Optional opt_mod, + std::string name, + int device_type, + int device_id, + int number, + int repeat, + int min_repeat_ms) { + TVMContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + + if (opt_mod.defined()) { + Module m = opt_mod.value(); + std::string tkey = m->type_key(); + return WrapWasmTimeEvaluator( + m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); + } else { + auto* pf = runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; + return WrapWasmTimeEvaluator( + *pf, ctx, number, repeat, min_repeat_ms); + } + } + + // time evaluator + PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, + TVMContext ctx, + int number, + int repeat, + int min_repeat_ms) { + auto ftimer = [pf, ctx, number, repeat, min_repeat_ms]( + TVMArgs args, TVMRetValue *rv) { + // the function is a async function. + PackedFunc on_complete = args[args.size() - 1]; + // keep argument alive in finvoke so that they + // can be used throughout the async benchmark + std::vector values(args.values, args.values + args.size() - 1); + std::vector type_codes(args.type_codes, args.type_codes + args.size() - 1); + + auto finvoke = [pf, values, type_codes](int n) { + TVMRetValue temp; + TVMArgs invoke_args(values.data(), type_codes.data(), values.size()); + for (int i = 0; i < n; ++i) { + pf.CallPacked(invoke_args, &temp); + } + }; + auto* time_exec = runtime::Registry::Get("__async.wasm.TimeExecution"); + CHECK(time_exec != nullptr) << "Cannot find wasm.GetTimer in the global function"; + (*time_exec)(TypedPackedFunc(finvoke), + ctx, number, repeat, min_repeat_ms, on_complete); + }; + return PackedFunc(ftimer); + } +}; + +TVM_REGISTER_GLOBAL("wasm.LocalSession") +.set_body_typed([]() { + return CreateRPCSessionModule(std::make_shared()); }); } // namespace runtime diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc new file mode 100644 index 0000000..537ab18 --- /dev/null +++ b/web/emcc/webgpu_runtime.cc @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file webgpu_runtime.cc + * \brief WebGPU runtime based on the TVM JS. + */ + +// configurations for the dmlc log. +#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + +#include +#include +#include +#include +#include +#include "../../src/runtime/meta_data.h" +#include "../../src/runtime/workspace_pool.h" +#include "../../src/runtime/vulkan/vulkan_shader.h" + +namespace tvm { +namespace runtime { + +/*! \brief Thread local workspace */ +class WebGPUThreadEntry { + public: + /*! \brief thread local pool*/ + WorkspacePool pool; + /*! \brief constructor */ + WebGPUThreadEntry(); + // get the threadlocal workspace + static WebGPUThreadEntry* ThreadLocal(); +}; + + +// All the implementations are redirectly to the JS side. +class WebGPUDeviceAPI : public DeviceAPI { + public: + WebGPUDeviceAPI() { + auto* fp = tvm::runtime::Registry::Get("wasm.WebGPUDeviceAPI"); + CHECK(fp != nullptr) << "Cannot find wasm.WebGPUContext in the env"; + auto getter = TypedPackedFunc(*fp); + alloc_space_ = getter("deviceAllocDataSpace"); + free_space_ = getter("deviceFreeDataSpace"); + copy_to_gpu_ = getter("deviceCopyToGPU"); + copy_from_gpu_ = getter("deviceCopyFromGPU"); + copy_within_gpu_ = getter("deviceCopyWithinGPU"); + } + + void SetDevice(TVMContext ctx) final { + } + void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { + if (kind == kExist) { + *rv = 1; + } + } + + void* AllocDataSpace(TVMContext ctx, + size_t nbytes, + size_t alignment, + DLDataType type_hint) final { + + double ptr_number = alloc_space_(nbytes); + return reinterpret_cast(static_cast(ptr_number)); + } + + void FreeDataSpace(TVMContext ctx, void* ptr) final { + return free_space_(ptr); + } + + void CopyDataFromTo(const void* from, + size_t from_offset, + void* to, size_t to_offset, size_t size, + TVMContext ctx_from, + TVMContext ctx_to, DLDataType type_hint, + TVMStreamHandle stream) final { + if (static_cast(ctx_from.device_type) == kDLWebGPU && + static_cast(ctx_to.device_type) == kDLWebGPU) { + CHECK_EQ(ctx_from.device_id, ctx_to.device_id); + copy_within_gpu_(const_cast(from), from_offset, to, to_offset, size); + } else if (static_cast(ctx_from.device_type) == kDLWebGPU && + ctx_to.device_type == kDLCPU) { + void* to_ptr = static_cast(to) + to_offset; + copy_from_gpu_(const_cast(from), from_offset, to_ptr, size); + } else if (ctx_from.device_type == kDLCPU && + static_cast(ctx_to.device_type) == kDLWebGPU) { + void* from_ptr = static_cast(const_cast(from)) + from_offset; + copy_to_gpu_(from_ptr, to, to_offset, size); + } else { + LOG(FATAL) << "expect copy from/to WebGPU or between WebGPU"; + } + } + + TVMStreamHandle CreateStream(TVMContext ctx) final { + LOG(FATAL) << "Not implemented"; + return nullptr; + } + + void FreeStream(TVMContext ctx, TVMStreamHandle stream) final { + LOG(FATAL) << "Not implemented"; + return; + } + + void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) { + LOG(FATAL) << "Not implemented"; + return; + } + + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { + LOG(FATAL) << "Not implemented"; + } + + void SetStream(TVMContext ctx, TVMStreamHandle stream) final { + LOG(FATAL) << "Not implemented"; + return; + } + + void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final { + return WebGPUThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); + } + + void FreeWorkspace(TVMContext ctx, void* data) final { + WebGPUThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data); + } + + static const std::shared_ptr& Global() { + static std::shared_ptr inst = + std::make_shared(); + return inst; + } + + private: + // NOTE: js return number as double. + TypedPackedFunc alloc_space_; + TypedPackedFunc free_space_; + TypedPackedFunc copy_to_gpu_; + TypedPackedFunc copy_from_gpu_; + TypedPackedFunc copy_within_gpu_; +}; + + +typedef dmlc::ThreadLocalStore WebGPUThreadStore; + +WebGPUThreadEntry::WebGPUThreadEntry() + : pool(static_cast(kDLWebGPU), WebGPUDeviceAPI::Global()) { +} + +WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { + return WebGPUThreadStore::Get(); +} + + +class WebGPUModuleNode final : public runtime::ModuleNode { + public: + explicit WebGPUModuleNode(std::unordered_map smap, + std::unordered_map fmap, + std::string source) + : smap_(smap), fmap_(fmap), source_(source) { + auto* fp = tvm::runtime::Registry::Get("wasm.WebGPUCreateShader"); + CHECK(fp != nullptr); + create_shader_ = *fp; + } + + const char* type_key() const final { return "webgpu"; } + + PackedFunc GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) final { + auto it = smap_.find(name); + if (it != smap_.end()) { + FunctionInfo info = fmap_.at(name); + info.name = name; + std::ostringstream os; + dmlc::JSONWriter writer(&os); + info.Save(&writer); + TVMByteArray arr; + arr.data = reinterpret_cast(it->second.data.data()); + arr.size = it->second.data.size() * sizeof(it->second.data[0]); + return create_shader_(os.str(), arr); + } else { + return PackedFunc(nullptr); + } + } + + void SaveToFile(const std::string& file_name, const std::string& format) final { + LOG(FATAL) << "Not implemented"; + } + + void SaveToBinary(dmlc::Stream* stream) final { + LOG(FATAL) << "Not implemented"; + } + + std::string GetSource(const std::string& format) final { + // can only return source code. + return source_; + } + + private: + // function information table. + std::unordered_map smap_; + // function information table. + std::unordered_map fmap_; + // The source + std::string source_; + // Callback to get the GPU function. + TypedPackedFunc create_shader_; +}; + + +Module WebGPUModuleLoadBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::unordered_map smap; + std::unordered_map fmap; + + std::string fmt; + stream->Read(&fmt); + stream->Read(&fmap); + stream->Read(&smap); + return Module(make_object(smap, fmap, "")); +} + +// for now webgpu is hosted via a vulkan module. +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan") +.set_body_typed(WebGPUModuleLoadBinary); + +TVM_REGISTER_GLOBAL("device_api.webgpu") +.set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = WebGPUDeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); + +} // namespace runtime +} // namespace tvm diff --git a/web/package.json b/web/package.json index 76aa111..f6b700d 100644 --- a/web/package.json +++ b/web/package.json @@ -5,7 +5,6 @@ "version": "0.7.0", "scripts": { "build": "tsc -b", - "watch": "tsc -b -w", "lint": "eslint -c .eslintrc.json .", "bundle": "npm run build && rollup -c rollup.config.js", "example": "npm run bundle && node apps/node/example.js", @@ -15,6 +14,7 @@ "devDependencies": { "typescript": "^3.8.3", "@types/node": "^12.12.37", + "@webgpu/types": "^0.0.24", "eslint": "^6.8.0", "@typescript-eslint/eslint-plugin": "^2.29.0", "@typescript-eslint/parser": "^2.29.0", @@ -24,6 +24,5 @@ "@rollup/plugin-commonjs": "^11.1.0", "@rollup/plugin-node-resolve": "^7.1.3", "rollup-plugin-typescript2": "^0.27.0" - }, - "dependencies": {} + } } diff --git a/web/rollup.config.js b/web/rollup.config.js index 0046e44..9090c77 100644 --- a/web/rollup.config.js +++ b/web/rollup.config.js @@ -27,8 +27,10 @@ export default { format: 'umd', name: 'tvmjs', exports: 'named', - globals: {'ws': 'ws'} + globals: {'ws': 'ws', + 'perf_hooks': 'perf_hooks', + '@webgpu/types': 'webgputypes'} }, plugins: [commonjs(), resolve()], - external: ['ws'] + external: ['ws', 'perf_hooks', '@webgpu/types'] }; diff --git a/web/src/compact.ts b/web/src/compact.ts new file mode 100644 index 0000000..29569b5 --- /dev/null +++ b/web/src/compact.ts @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** NodeJS and Web compact layer */ + +/** + * Get performance masurement. + */ +export function getPeformance(): Performance { + if (typeof performance == "undefined") { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const performanceNode = require("perf_hooks"); + return performanceNode.performance as Performance; + } else { + return performance as Performance; + } +} + +/** + * Create a new websocket for a given URL + * @param url The url. + */ +export function createWebSocket(url: string): WebSocket { + if (typeof WebSocket == "undefined") { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const WebSocket = require("ws"); + return new WebSocket(url); + } else { + return new (WebSocket as any)(url); + } + +} \ No newline at end of file diff --git a/web/src/index.ts b/web/src/index.ts index 5d7d7cc..2d99fc9 100644 --- a/web/src/index.ts +++ b/web/src/index.ts @@ -24,4 +24,6 @@ export { } from "./runtime"; export { Disposable, LibraryProvider } from "./types"; export { RPCServer } from "./rpc_server"; -export { wasmPath } from "./support"; \ No newline at end of file +export { wasmPath } from "./support"; +export { detectGPUDevice } from "./webgpu"; +export { assert } from "./support"; \ No newline at end of file diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts index 054a1b6..50227dc 100644 --- a/web/src/rpc_server.ts +++ b/web/src/rpc_server.ts @@ -19,8 +19,9 @@ import { SizeOf, TypeCode } from "./ctypes"; import { assert, StringToUint8Array, Uint8ArrayToString } from "./support"; +import { detectGPUDevice } from "./webgpu"; +import * as compact from "./compact"; import * as runtime from "./runtime"; -import { Class } from "estree"; enum RPCServerState { InitHeader, @@ -79,6 +80,7 @@ export class RPCServer { state: RPCServerState = RPCServerState.InitHeader; logger: (msg: string) => void; getImports: () => Record; + private pendingSend: Promise = Promise.resolve(); private name: string; private inst?: runtime.Instance = undefined; private serverRecvData?: (header: Uint8Array, body: Uint8Array) => void; @@ -102,16 +104,7 @@ export class RPCServer { this.logger = logger; this.checkLittleEndian(); - - if (typeof WebSocket == "undefined") { - // eslint-disable-next-line @typescript-eslint/no-var-requires - const WebSocket = require("ws"); - this.socket = new WebSocket(url); - } else { - this.socket = new (WebSocket as any)(url); - } - - //this.socket = this.getSocket(url); + this.socket = compact.createWebSocket(url); this.socket.binaryType = "arraybuffer"; this.socket.addEventListener("open", (event: Event) => { @@ -132,6 +125,8 @@ export class RPCServer { } if (this.state == RPCServerState.ReceivePacketHeader) { this.log("Closing the server in clean state"); + this.log("Automatic reconnecting.."); + new RPCServer(this.url, this.key, this.getImports, this.logger); } else { this.log("Closing the server, final state=" + this.state); } @@ -247,11 +242,26 @@ export class RPCServer { ): void { // start the server assert(args[0] == "rpc.WasmSession"); - assert(args[1] instanceof Uint8Array); assert(this.pendingBytes == 0); - runtime.instantiate(args[1].buffer, this.getImports()) - .then((inst: runtime.Instance) => { + const asyncInitServer = async (): Promise => { + assert(args[1] instanceof Uint8Array); + const inst = await runtime.instantiate( + args[1].buffer, + this.getImports(), + this.logger + ); + try { + const gpuDevice: GPUDevice | undefined = await detectGPUDevice(); + if (gpuDevice !== undefined) { + const label = gpuDevice.label?.toString() || "WebGPU"; + this.log("Initialize GPU device: " + label); + inst.initWebGPU(gpuDevice); + } + } catch (err) { + this.log("Cannnot initialize WebGPU, " + err.toString()); + } + this.inst = inst; const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer"); @@ -259,7 +269,30 @@ export class RPCServer { (cbytes: Uint8Array): runtime.Scalar => { assert(this.inst !== undefined); if (this.socket.readyState == 1) { - this.socket.send(cbytes); + // WebSocket will automatically close the socket + // if we burst send data that exceeds its internal buffer + // wait a bit before we send next one. + const sendDataWithCongestionControl = async (): Promise => { + const packetSize = 4 << 10; + const maxBufferAmount = 4 * packetSize; + const waitTimeMs = 20; + for ( + let offset = 0; + offset < cbytes.length; + offset += packetSize + ) { + const end = Math.min(offset + packetSize, cbytes.length); + while (this.socket.bufferedAmount >= maxBufferAmount) { + await new Promise((r) => setTimeout(r, waitTimeMs)); + } + this.socket.send(cbytes.slice(offset, end)); + } + }; + // Chain up the pending send so that the async send is always in-order. + this.pendingSend = this.pendingSend.then( + sendDataWithCongestionControl + ); + // Directly return since the data are "sent" from the caller's pov. return this.inst.scalar(cbytes.length, "int32"); } else { return this.inst.scalar(0, "int32"); @@ -285,7 +318,7 @@ export class RPCServer { // The RPC will look for "rpc.wasmSession" // and we will redirect it to the correct local session. // register the callback to redirect the session to local. - const flocal = this.inst.getGlobalFunc("rpc.LocalSession"); + const flocal = this.inst.getGlobalFunc("wasm.LocalSession"); const localSession = flocal(); flocal.dispose(); assert(localSession instanceof runtime.Module); @@ -307,8 +340,10 @@ export class RPCServer { this.state = RPCServerState.ReceivePacketHeader; // call process events in case there are bufferred data. this.processEvents(); - }); + }; + this.state = RPCServerState.WaitForCallback; + asyncInitServer(); } private log(msg: string): void { diff --git a/web/src/runtime.ts b/web/src/runtime.ts index cd9b967..bcf7be7 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -25,7 +25,9 @@ import { Disposable } from "./types"; import { Memory, CachedCallStack } from "./memory"; import { assert, StringToUint8Array } from "./support"; import { Environment } from "./environment"; +import { WebGPUContext } from "./webgpu"; +import * as compact from "./compact"; import * as ctypes from "./ctypes"; /** @@ -42,8 +44,8 @@ class FFILibrary implements Disposable { wasm32: boolean; memory: Memory; exports: Record; + webGPUContext?: WebGPUContext; private wasmInstance: WebAssembly.Instance; - private recycledCallStacks: Array = []; constructor( @@ -174,8 +176,8 @@ const DeviceEnumToStr: Record = { 1: "cpu", 2: "gpu", 4: "opencl", - 7: "vulkan", 8: "metal", + 15: "webgpu" }; const DeviceStrToEnum: Record = { @@ -186,6 +188,7 @@ const DeviceStrToEnum: Record = { opencl: 4, vulkan: 7, metal: 8, + webgpu: 15 }; /** @@ -203,6 +206,9 @@ export class DLContext { const tp = typeof deviceType; if (tp == "string") { this.deviceType = DeviceStrToEnum[deviceType]; + if (this.deviceType == undefined) { + throw new Error("Cannot recogonize deviceType " + deviceType); + } } else if (tp == "number") { this.deviceType = deviceType as number; } else { @@ -215,14 +221,11 @@ export class DLContext { /** * Synchronize the context */ - sync(): void { - this.lib.checkCall( - (this.lib.exports.TVMSynchronize as ctypes.FTVMSynchronize)( - this.deviceType, - this.deviceId, - 0 - ) - ); + async sync(): Promise { + if (this.deviceType == DeviceStrToEnum.webgpu) { + assert(this.lib.webGPUContext !== undefined); + await this.lib.webGPUContext.sync(); + } } toString(): string { @@ -284,17 +287,24 @@ export class NDArray implements Disposable { shape: Array; /** Context of the array. */ context: DLContext; - + /** Whether it is a temporary view that can become invalid after the call. */ + private isView: boolean; private byteOffset: number; private dltensor: Pointer; + private dataPtr: Pointer; private lib: FFILibrary; private dlDataType: DLDataType; - constructor(handle: Pointer, lib: FFILibrary) { + constructor(handle: Pointer, isView: boolean, lib: FFILibrary) { this.handle = handle; + this.isView = isView; this.lib = lib; - this.dltensor = this.getDLTensorFromArrayHandle(this.handle); + if (this.isView) { + this.dltensor = handle; + } else { + this.dltensor = this.getDLTensorFromArrayHandle(this.handle); + } // constant offsets. const arrayOffsetData = 0; const arrayOffsetContext = arrayOffsetData + this.lib.sizeofPtr(); @@ -308,6 +318,8 @@ export class NDArray implements Disposable { const arrayOffsetShape = arrayOffsetDtype + SizeOf.DLDataType; const arrayOffsetStrides = arrayOffsetShape + this.lib.sizeofPtr(); const arrayOffsetByteOffset = arrayOffsetStrides + this.lib.sizeofPtr(); + // dataPtr + this.dataPtr = lib.memory.loadPointer(this.dltensor); // ndim this.ndim = lib.memory.loadI32(this.dltensor + arrayOffsetNdim); // shape @@ -333,7 +345,7 @@ export class NDArray implements Disposable { } dispose(): void { - if (this.handle != 0) { + if (this.handle != 0 && !this.isView) { this.lib.checkCall( (this.lib.exports.TVMArrayFree as ctypes.FTVMArrayFree)(this.handle) ); @@ -347,7 +359,7 @@ export class NDArray implements Disposable { * @param data The source data array. * @returns this */ - copyFrom(data: NDArray | Array): this { + copyFrom(data: NDArray | Array | Float32Array): this { if (data instanceof NDArray) { this.lib.checkCall( (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)( @@ -421,9 +433,13 @@ export class NDArray implements Disposable { * @returns The result array. */ toRawBytes(): Uint8Array { + if (this.context.deviceType != DeviceStrToEnum.cpu) { + throw new Error("Can only synchronize copy for GPU array, use copyfrom instead."); + } const size = this.shape.reduce((a, b) => { return a * b; }, 1); + const nbytes = this.dlDataType.numStorageBytes() * size; const stack = this.lib.getOrAllocCallStack(); @@ -545,6 +561,114 @@ export class Module implements Disposable { } /** + * Graph runtime. + * + * This is a thin wrapper of the underlying TVM module. + * you can also directly call set_input, run, and get_output + * of underlying module functions + */ +class GraphRuntime implements Disposable { + module: Module; + private packedSetInput: PackedFunc; + private packedRun: PackedFunc; + private packedGetOutput: PackedFunc; + private packedLoadParams: PackedFunc; + + /** + * COnstructor + * @param module The underlying module. + */ + constructor(module: Module) { + this.module = module; + this.packedSetInput = module.getFunction("set_input"); + this.packedRun = module.getFunction("run"); + this.packedGetOutput = module.getFunction("get_output"); + this.packedLoadParams = module.getFunction("load_params"); + } + + dispose(): void { + this.packedSetInput.dispose(); + this.packedRun.dispose(); + this.packedGetOutput.dispose(); + } + + /** + * Set input to the executor. + * + * @param key The input key. + * @param value The value to get set. + */ + setInput(key: number | string, value: NDArray): void { + if (typeof key == "number") { + this.packedSetInput(new Scalar(key, "int32"), value); + } else { + this.packedSetInput(key, value); + + } + } + + /** + * Execute the underlying graph. + */ + run(): void { + this.packedRun(); + } + + /** + * Get index-th output. + * @param index The index number. + * @param out The optional output storage parameters. + * @returns The output array. + */ + getOutput(index: number, out: NDArray | undefined = undefined): NDArray { + if (out !== undefined) { + this.packedGetOutput(new Scalar(index, "int32"), out) + return out; + } else { + return this.packedGetOutput(new Scalar(index, "int32")); + } + } + + /** + * Load parameters from parameter binary. + * @param paramBinary The parameter binary. + */ + loadParams(paramBinary: Uint8Array): void { + this.packedLoadParams(paramBinary); + } + + /** + * Benchmark stable execution of the graph(without data copy). + * @params ctx The context to sync during each run. + * @number The number of times to compute the average. + * @repeat The number of times to repeat the run. + */ + async benchmarkRuns(ctx: DLContext, number=10, repeat=4): Promise { + // Skip first run as it can involve GPU warmup and module loading time. + const perf = compact.getPeformance(); + const results = []; + this.run(); + await ctx.sync(); + for (let k = 0; k < repeat; ++k) { + const tstart = perf.now(); + for (let i = 0; i < number; ++i) { + this.run(); + } + await ctx.sync(); + const tend = perf.now(); + results.push((tend - tstart) / number); + } + return results; + } +} + +/** Code used as the first argument of the async callback. */ +const enum AyncCallbackCode { + kReturn = 4, + kException = 5, +} + +/** * TVM runtime instance. */ export class Instance implements Disposable { @@ -789,11 +913,27 @@ export class Instance implements Disposable { * @param deviceId The device index. * @returns The created context. */ - context(deviceType: number | string, deviceId: number): DLContext { + context(deviceType: number | string, deviceId = 0): DLContext { return new DLContext(deviceType, deviceId, this.lib); } /** + * Create a new cpu {@link DLContext} + * @param deviceId The device index. + */ + cpu(deviceId = 0): DLContext { + return this.context("cpu", deviceId); + } + + /** + * Create a new webgpu {@link DLContext} + * @param deviceId The device index. + */ + webgpu(deviceId = 0): DLContext { + return this.context("webgpu", deviceId); + } + + /** * Create an empty {@link NDArray} with given shape and dtype. * * @param shape The shape of the array. @@ -831,36 +971,125 @@ export class Instance implements Disposable { outPtr ) ); - const ret = new NDArray(this.memory.loadPointer(outPtr), this.lib); + const ret = new NDArray(this.memory.loadPointer(outPtr), false, this.lib); this.lib.recycleCallStack(stack); return ret; } + /** + * Create a new graph runtime. + * + * @param graphJson The graph runtime json file. + * @param lib The underlying library. + * @param ctx The execution context of the graph. + */ + createGraphRuntime( + graphJson: string, + lib: Module, + ctx: DLContext + ): GraphRuntime { + const fcreate = this.getGlobalFunc("tvm.graph_runtime.create"); + const module = fcreate( + graphJson, + lib, + this.scalar(ctx.deviceType, "int32"), + this.scalar(ctx.deviceId, "int32")) as Module; + return new GraphRuntime(module); + } + + + /** + * Register an asyncfunction to be global function in the server. + * @param name The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + * + * @note The async function will only be used for serving remote calls in the rpc. + */ + registerAsyncServerFunc( + name: string, + func: Function, + override = false + ): void { + const asyncVariant = (...args: Array): void => { + const fargs = args.slice(0, args.length - 1); + const callback = args[args.length - 1] as PackedFunc; + const promise: Promise = func(...fargs); + promise.then((rv: any) => { + callback(this.scalar(AyncCallbackCode.kReturn, "int32"), rv); + }); + }; + this.registerFunc("__async." + name, asyncVariant, override); + } + + /** + * Initialize webgpu in the runtime. + * @param device The given GPU device. + */ + initWebGPU(device: GPUDevice): void { + const webGPUContext = new WebGPUContext( + this.memory, device + ); + this.registerFunc("wasm.WebGPUDeviceAPI", (name: string) => { + return webGPUContext.getDeviceAPI(name); + }); + this.registerFunc("wasm.WebGPUCreateShader", (info: string, data: Uint8Array) => { + return webGPUContext.createShader(info, data); + }); + this.registerAsyncServerFunc("wasm.WebGPUWaitForTasks", async () => { + await webGPUContext.sync(); + }); + this.lib.webGPUContext = webGPUContext; + } + /** Register global packed functions needed by the backend to the env. */ private registerEnvGlobalPackedFuncs(): void { // Register the timer function to enable the time_evaluator. - let perf: Performance; - if (typeof performance == "undefined") { - // eslint-disable-next-line @typescript-eslint/no-var-requires - const performanceNode = require('perf_hooks'); - perf = performanceNode.performance as Performance; - } else { - perf = performance as Performance; - } - - const getTimer = (func: PackedFunc) => { - return (n: number): number => { - const nscalar = this.scalar(n, "int32"); - const tstart: number = perf.now(); - func(nscalar); - const tend: number = perf.now(); - return tend - tstart; + const perf = compact.getPeformance(); + + // Helper function to time the finvoke + const timeExecution = async ( + finvoke: PackedFunc, + ctx: DLContext, + nstep: number, + repeat: number, + minRepeatMs: number + ): Promise => { + finvoke(this.scalar(1, "int32")); + await ctx.sync(); + const result = []; + let setupNumber: number = nstep; + + for (let i = 0; i < repeat; ++i) { + let durationMs = 0.0; + do { + if (durationMs > 0.0) { + setupNumber = Math.floor( + Math.max(minRepeatMs / (durationMs / nstep) + 1, nstep * 1.618) + ); + } + const tstart: number = perf.now(); + finvoke(this.scalar(setupNumber, "int32")); + await ctx.sync(); + const tend: number = perf.now(); + + durationMs = tend - tstart; + } while (durationMs < minRepeatMs); + const speed = durationMs / setupNumber / 1000; + result.push(speed); } + const ret = new Float64Array(result.length); + ret.set(result); + return new Uint8Array(ret.buffer); + }; + + const addOne = async (x: number): Promise => { + await new Promise(resolve => setTimeout(resolve, 100)); + return x + 1; }; - this.registerFunc("wasm.GetTimer", getTimer); - const rpcWrapTimeEvaluator = this.getGlobalFunc("wasm.RPCTimeEvaluator"); - this.registerFunc("runtime.RPCTimeEvaluator", rpcWrapTimeEvaluator, true); - rpcWrapTimeEvaluator.dispose(); + + this.registerAsyncServerFunc("wasm.TimeExecution", timeExecution); + this.registerAsyncServerFunc("testing.asyncAddOne", addOne); } private createPackedFuncFromCFunc( @@ -924,6 +1153,10 @@ export class Instance implements Disposable { stack.storePtr(valueOffset, val.value); stack.storeI32(codeOffset, TypeCode.TVMOpaqueHandle); } + } else if (val instanceof DLContext) { + stack.storeI32(valueOffset, val.deviceType); + stack.storeI32(valueOffset + SizeOf.I32, val.deviceType); + stack.storeI32(codeOffset, TypeCode.TVMContext); } else if (tp == "number") { stack.storeF64(valueOffset, val); stack.storeI32(codeOffset, TypeCode.Float); @@ -984,7 +1217,7 @@ export class Instance implements Disposable { ); } tcode = lib.memory.loadI32(codePtr); - jsArgs.push(this.retValueToJS(valuePtr, tcode)); + jsArgs.push(this.retValueToJS(valuePtr, tcode, true)); } const rv = func(...jsArgs); @@ -1041,7 +1274,7 @@ export class Instance implements Disposable { ) ); - const ret = this.retValueToJS(rvaluePtr, this.memory.loadI32(rcodePtr)); + const ret = this.retValueToJS(rvaluePtr, this.memory.loadI32(rcodePtr), false); this.lib.recycleCallStack(stack); return ret; }; @@ -1055,15 +1288,22 @@ export class Instance implements Disposable { return ret as PackedFunc; } - private retValueToJS(rvaluePtr: Pointer, tcode: number): any { + private retValueToJS(rvaluePtr: Pointer, tcode: number, callbackArg: boolean): any { switch (tcode) { case TypeCode.Int: case TypeCode.UInt: return this.memory.loadI64(rvaluePtr); case TypeCode.Float: return this.memory.loadF64(rvaluePtr); + case TypeCode.TVMOpaqueHandle: { + return this.memory.loadPointer(rvaluePtr); + } case TypeCode.TVMNDArrayHandle: { - return new NDArray(this.memory.loadPointer(rvaluePtr), this.lib); + return new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib); + } + case TypeCode.TVMDLTensorHandle: { + assert(callbackArg); + return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib); } case TypeCode.TVMPackedFuncHandle: { return this.makePackedFunc(this.memory.loadPointer(rvaluePtr)); @@ -1077,10 +1317,15 @@ export class Instance implements Disposable { } ); } - case TypeCode.Null: - return undefined; + case TypeCode.Null: return undefined; + case TypeCode.TVMContext: { + const deviceType = this.memory.loadI32(rvaluePtr); + const deviceId = this.memory.loadI32(rvaluePtr + SizeOf.I32); + return this.context(deviceType, deviceId); + } case TypeCode.TVMStr: { - return this.memory.loadCString(this.memory.loadPointer(rvaluePtr)); + const ret = this.memory.loadCString(this.memory.loadPointer(rvaluePtr)); + return ret; } case TypeCode.TVMBytes: { return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr)); @@ -1098,12 +1343,17 @@ export class Instance implements Disposable { * a WASI object, or an object containing wasmLibraryProvider field. * We can take benefit of syslib implementations from the Emscripten * by passing its generated js Module as the imports. + * + * @param bufferSource The source to be compiled. + * @param importObject The import objects. + * @param logger The system logger. */ export function instantiate( bufferSource: ArrayBuffer, - importObject: Record = {} + importObject: Record = {}, + logger: (msg: string) => void = console.log ): Promise { - const env = new Environment(importObject); + const env = new Environment(importObject, logger); return WebAssembly.instantiate(bufferSource, env.imports).then( (result: WebAssembly.WebAssemblyInstantiatedSource): Instance => { diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts new file mode 100644 index 0000000..640f7b4 --- /dev/null +++ b/web/src/webgpu.ts @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import "@webgpu/types"; +import { assert } from "./support"; +import { Pointer } from "./ctypes"; +import { Memory } from "./memory"; + +/** A pointer to points to the raw address space. */ +export type GPUPointer = number; + +/** + * DetectGPU device in the environment. + */ +export async function detectGPUDevice(): Promise { + if (typeof navigator !== "undefined" && navigator.gpu !== undefined) { + const adapter = await navigator.gpu.requestAdapter(); + return await adapter.requestDevice(); + } else { + return undefined; + } +} + +interface FunctionInfo { + name: string; + arg_types: Array; + thread_axis_tags: Array; +} + +/** + * WebGPU context + * Manages all the webgpu resources here. + */ +export class WebGPUContext { + device: GPUDevice; + memory: Memory; + + //private readBuffer:; + private bufferTable: Array = [undefined]; + private bufferTableFreeId: Array = []; + private pendingRead: Promise = Promise.resolve(); + private numPendingReads = 0; + + constructor(memory: Memory, device: GPUDevice) { + this.memory = memory; + this.device = device; + } + + /** + * Wait for all pending GPU tasks to complete + */ + async sync(): Promise { + const fence = this.device.defaultQueue.createFence(); + this.device.defaultQueue.signal(fence, 1); + if (this.numPendingReads != 0) { + // eslint-disable-next-line @typescript-eslint/no-empty-function + await Promise.all([fence.onCompletion(1), this.pendingRead]); + } else { + await fence.onCompletion(1); + } + } + + /** + * Create a PackedFunc that runs the given shader + * + * @param info The function information in json. + * @param data The shader data(in SPIRV) + */ + createShader(info: string, data: Uint8Array): Function { + const finfo = JSON.parse(info); + const layoutEntries: Array = []; + for (let i = 0; i < finfo.arg_types.length; ++i) { + const dtype = finfo.arg_types[i]; + if (dtype == "handle") { + layoutEntries.push({ + binding: i, + visibility: GPUShaderStage.COMPUTE, + type: "storage-buffer" + }); + } else { + throw new Error("Cannot handle argument type " + dtype + " in WebGPU shader"); + } + } + const bindGroupLayout = this.device.createBindGroupLayout({ + entries: layoutEntries + }); + + const pipeline = this.device.createComputePipeline({ + layout: this.device.createPipelineLayout({ + bindGroupLayouts: [ bindGroupLayout ] + }), + computeStage: { + module: this.device.createShaderModule({ + code: new Uint32Array(data.buffer) + }), + entryPoint: "main" + } + }); + + const dispatchToDim: Array = []; + + for (let i = 0; i < finfo.thread_axis_tags.length; ++i) { + const tag: string = finfo.thread_axis_tags[i]; + if (tag.startsWith("blockIdx.")) { + const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); + assert(target >= 0 && target < 3); + dispatchToDim.push(target); + } else if (tag.startsWith("threadIdx.")) { + const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); + assert(target >= 0 && target < 3); + dispatchToDim.push(target + 3); + } else { + throw new Error("Cannot handle thread_axis " + tag); + } + } + + const submitShader = (...args: Array): void => { + const commandEncoder = this.device.createCommandEncoder(); + const compute = commandEncoder.beginComputePass(); + compute.setPipeline(pipeline); + const bindGroupEntries: Array = []; + assert(args.length == layoutEntries.length + dispatchToDim.length); + + for (let i = 0; i < layoutEntries.length; ++i) { + bindGroupEntries.push({ + binding: i, + resource: { + buffer: this.gpuBufferFromPtr(args[i]) + } + }); + } + + compute.setBindGroup(0, this.device.createBindGroup({ + layout: bindGroupLayout, + entries: bindGroupEntries + })); + const wl: Array = [1, 1, 1, 1, 1, 1]; + for (let i = 0; i < dispatchToDim.length; ++i) { + wl[dispatchToDim[i]] = args[layoutEntries.length + i]; + } + compute.dispatch(wl[0], wl[1], wl[2]); + compute.endPass(); + const command = commandEncoder.finish(); + this.device.defaultQueue.submit([command]); + }; + + return submitShader; + } + + /** + * Get the device API according to its name + * @param The name of the API. + * @returns The corresponding device api. + */ + getDeviceAPI(name: string): Function { + if (name == "deviceAllocDataSpace") { + return (nbytes: number): GPUPointer => { + return this.deviceAllocDataSpace(nbytes); + }; + } else if (name == "deviceFreeDataSpace") { + return (ptr: GPUPointer): void => { + return this.deviceFreeDataSpace(ptr); + }; + } else if (name == "deviceCopyToGPU") { + return ( + from: Pointer, + to: GPUPointer, + toOffset: number, + nbytes: number + ): void => { + this.deviceCopyToGPU(from, to, toOffset, nbytes); + }; + } else if (name == "deviceCopyFromGPU") { + return ( + from: GPUPointer, + fromOffset: number, + to: Pointer, + nbytes: number + ): void => { + this.deviceCopyFromGPU(from, fromOffset, to, nbytes); + }; + } else if (name == "deviceCopyWithinGPU") { + return ( + from: GPUPointer, + fromOffset: number, + to: Pointer, + toOffset: number, + nbytes: number + ): void => { + this.deviceCopyWithinGPU(from, fromOffset, to, toOffset, nbytes); + }; + } else { + throw new Error("Unknown DeviceAPI function " + name); + } + + } + + // DeviceAPI + private deviceAllocDataSpace(nbytes: number): GPUPointer { + const buffer = this.device.createBuffer({ + size: nbytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, + }); + return this.attachToBufferTable(buffer); + } + + private deviceFreeDataSpace(ptr: GPUPointer): void { + const idx = ptr; + const buffer = this.bufferTable[idx]; + this.bufferTable[idx] = undefined; + assert(buffer !== undefined); + this.bufferTableFreeId.push(idx); + buffer.destroy(); + } + + private deviceCopyToGPU( + from: Pointer, + to: GPUPointer, + toOffset: number, + nbytes: number + ): void { + // Perhaps it would be more useful to use a staging buffer? + const [gpuTemp, cpuTemp] = this.device.createBufferMapped({ + size: nbytes, + usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC, + }); + + const viewU8 = new Uint8Array(cpuTemp); + viewU8.set(this.memory.loadRawBytes(from, nbytes)); + gpuTemp.unmap(); + + const copyEncoder = this.device.createCommandEncoder(); + copyEncoder.copyBufferToBuffer( + gpuTemp, + 0, + this.gpuBufferFromPtr(to), + toOffset, + nbytes + ); + const copyCommands = copyEncoder.finish(); + this.device.defaultQueue.submit([copyCommands]); + gpuTemp.destroy(); + } + + private deviceCopyFromGPU( + from: GPUPointer, + fromOffset: number, + to: Pointer, + nbytes: number + ): void { + // Perhaps it would be more useful to resuse a staging buffer? + const gpuTemp = this.device.createBuffer({ + size: nbytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + + const copyEncoder = this.device.createCommandEncoder(); + copyEncoder.copyBufferToBuffer( + this.gpuBufferFromPtr(from), + fromOffset, + gpuTemp, + 0, + nbytes + ); + const copyCommands = copyEncoder.finish(); + this.device.defaultQueue.submit([copyCommands]); + + this.numPendingReads += 1; + const readEvent = gpuTemp.mapReadAsync().then((data: ArrayBuffer) => { + this.memory.storeRawBytes(to, new Uint8Array(data)); + this.numPendingReads -= 1; + gpuTemp.destroy(); + }); + + if (this.numPendingReads == 1) { + this.pendingRead = readEvent; + } else { + this.pendingRead = Promise.all([ + this.pendingRead, + readEvent, + // eslint-disable-next-line @typescript-eslint/no-empty-function + ]).then(() => {}); + } + } + + private deviceCopyWithinGPU( + from: GPUPointer, + fromOffset: number, + to: Pointer, + toOffset: number, + nbytes: number + ): void { + const copyEncoder = this.device.createCommandEncoder(); + copyEncoder.copyBufferToBuffer( + this.gpuBufferFromPtr(from), + fromOffset, + this.gpuBufferFromPtr(to), + toOffset, + nbytes + ); + const copyCommands = copyEncoder.finish(); + this.device.defaultQueue.submit([copyCommands]); + } + + private gpuBufferFromPtr(ptr: GPUPointer): GPUBuffer { + const buffer = this.bufferTable[ptr]; + assert(buffer !== undefined); + return buffer; + } + + private attachToBufferTable(buffer: GPUBuffer): GPUPointer { + if (this.bufferTableFreeId.length != 0) { + const idx = this.bufferTableFreeId.pop() as number; + this.bufferTable[idx] = buffer; + return idx; + } else { + const idx = this.bufferTable.length; + this.bufferTable.push(buffer); + return idx; + } + } +} diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py new file mode 100644 index 0000000..d16ba3f --- /dev/null +++ b/web/tests/python/webgpu_rpc_test.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Simple testcode to test Javascript RPC + +To use it, start a rpc proxy with "python -m tvm.exec.rpc_proxy". +Connect javascript end to the websocket port and connect to the RPC. +""" + +import tvm +from tvm import te +from tvm import rpc +from tvm.contrib import util, emcc +import numpy as np + +proxy_host = "localhost" +proxy_port = 9090 + + +def test_rpc(): + if not tvm.runtime.enabled("rpc"): + return + # generate the wasm library + target_device = "webgpu" + target_host = "llvm -target=wasm32-unknown-unknown-wasm -system-lib" + if not tvm.runtime.enabled(target_host): + raise RuntimeError("Target %s is not enbaled" % target_host) + + n = 2048 + A = te.placeholder((n,), name='A') + B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') + s = te.create_schedule(B.op) + + num_thread = 2 + xo, xi = s[B].split(B.op.axis[0], factor=num_thread) + s[B].bind(xi, te.thread_axis("threadIdx.x")) + s[B].bind(xo, te.thread_axis("blockIdx.x")) + + + fadd = tvm.build(s, [A, B], target_device, target_host=target_host, name="addone") + temp = util.tempdir() + + wasm_path = temp.relpath("addone_gpu.wasm") + fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + + wasm_binary = open(wasm_path, "rb").read() + remote = rpc.connect(proxy_host, proxy_port, key="wasm", + session_constructor_args=["rpc.WasmSession", wasm_binary]) + + def check(remote): + # basic function checks. + ctx = remote.webgpu(0) + adata = np.random.uniform(size=n).astype(A.dtype) + a = tvm.nd.array(adata, ctx) + b = tvm.nd.array(np.zeros(n, dtype=A.dtype), ctx) + + np.testing.assert_equal(a.asnumpy(), adata) + f1 = remote.system_lib() + addone = f1.get_function("addone") + addone(a, b) + np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + print("Test pass..") + + check(remote) + +test_rpc() diff --git a/web/tests/python/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py index 7fa0c6b..f7c0792 100644 --- a/web/tests/python/websock_rpc_test.py +++ b/web/tests/python/websock_rpc_test.py @@ -54,7 +54,10 @@ def test_rpc(): def check(remote): # basic function checks. + faddone = remote.get_function("testing.asyncAddOne") fecho = remote.get_function("testing.echo") + assert(faddone(100) == 101) + assert(fecho(1, 2, 3) == 1) assert(fecho(1, 2, 3) == 1) assert(fecho(100, 2, 3) == 100) assert(fecho("xyz") == "xyz") @@ -70,7 +73,7 @@ def test_rpc(): addone(a, b) # time evaluator - time_f = f1.time_evaluator("addone", ctx, number=10) + time_f = f1.time_evaluator("addone", ctx, number=100, repeat=10) time_f(a, b) cost = time_f(a, b).mean print('%g secs/op' % cost) @@ -78,5 +81,4 @@ def test_rpc(): check(remote) - test_rpc() diff --git a/web/tsconfig.json b/web/tsconfig.json index 3c20b3d..6aec448 100644 --- a/web/tsconfig.json +++ b/web/tsconfig.json @@ -6,7 +6,7 @@ "rootDir": "src", "declaration": true, "sourceMap": true, - "strict": true, + "strict": true }, "include": ["src"], "exclude": ["node_modules"] -- 2.7.4