From 93843536010af46e8fa3668e423e027cab53deda Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 17 Dec 2019 22:17:51 -0800 Subject: [PATCH] Update legacy places from nnvm to relay. (#4535) * Update legacy places from nnvm to relay. This PR prepares the current mainline to remove nnvm compiler dep. * remove legacy stage --- Jenkinsfile | 11 --- apps/benchmark/util.py | 4 +- apps/bundle_deploy/Makefile | 14 ++-- apps/bundle_deploy/build_model.py | 21 +++--- apps/bundle_deploy/bundle.cc | 9 ++- apps/bundle_deploy/runtime.cc | 4 +- apps/howto_deploy/Makefile | 5 +- apps/rocm_rpc/Makefile | 1 - apps/sgx/README.md | 2 +- apps/sgx/enclave/src/build_model.py | 17 +++-- rust/Cargo.toml | 2 +- rust/frontend/Cargo.toml | 2 +- rust/frontend/README.md | 16 ++--- rust/frontend/examples/resnet/README.md | 6 +- rust/runtime/Cargo.toml | 2 +- rust/runtime/src/graph.rs | 2 +- rust/runtime/src/threading.rs | 2 +- rust/runtime/tests/build_model.py | 47 ++++--------- .../tests/{test_nnvm => test_nn}/Cargo.toml | 2 +- rust/runtime/tests/{test_nnvm => test_nn}/build.rs | 0 rust/runtime/tests/test_nn/src/build_test_graph.py | 54 ++++++++++++++ .../tests/{test_nnvm => test_nn}/src/main.rs | 0 .../tests/test_nnvm/src/build_test_graph.py | 82 ---------------------- tests/python/frontend/onnx/test_forward.py | 2 +- tests/python/relay/test_py_converter.py | 2 - tests/scripts/task_golang.sh | 2 +- tests/scripts/task_python_frontend.sh | 3 - tests/scripts/task_rust.sh | 8 +-- 28 files changed, 129 insertions(+), 193 deletions(-) rename rust/runtime/tests/{test_nnvm => test_nn}/Cargo.toml (98%) rename rust/runtime/tests/{test_nnvm => test_nn}/build.rs (100%) create mode 100755 rust/runtime/tests/test_nn/src/build_test_graph.py rename rust/runtime/tests/{test_nnvm => test_nn}/src/main.rs (100%) delete mode 100755 rust/runtime/tests/test_nnvm/src/build_test_graph.py diff --git a/Jenkinsfile b/Jenkinsfile index b5dbcee..6bb1da6 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -310,17 +310,6 @@ stage('Integration Test') { } } }, - 'legacy: GPU': { - node('GPU') { - ws(per_exec_ws("tvm/legacy-python-gpu")) { - init_git() - unpack_lib('gpu', tvm_multilib) - timeout(time: max_time, unit: 'MINUTES') { - sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_legacy.sh" - } - } - } - }, 'docs: GPU': { node('GPU') { ws(per_exec_ws("tvm/docs-python-gpu")) { diff --git a/apps/benchmark/util.py b/apps/benchmark/util.py index 0af1669..c7de3a1 100644 --- a/apps/benchmark/util.py +++ b/apps/benchmark/util.py @@ -34,8 +34,8 @@ def get_network(name, batch_size, dtype='float32'): Returns ------- - net: nnvm.symbol - The NNVM symbol of network definition + net: relay.Module + The relay function of network definition params: dict The random parameters for benchmark input_shape: tuple diff --git a/apps/bundle_deploy/Makefile b/apps/bundle_deploy/Makefile index 8550a0e..57e4843 100644 --- a/apps/bundle_deploy/Makefile +++ b/apps/bundle_deploy/Makefile @@ -16,15 +16,15 @@ # under the License. # Makefile Example to bundle TVM modules. + TVM_ROOT=$(shell cd ../..; pwd) -NNVM_PATH=nnvm DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core -PKG_CFLAGS = -std=c++14 -Oz -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${DMLC_CORE}/include\ - -I${TVM_ROOT}/3rdparty/dlpack/include\ + -I${TVM_ROOT}/3rdparty/dlpack/include -PKG_LDFLAGS = -L${TVM_ROOT}/build +PKG_LDFLAGS = -pthread build_dir := build @@ -33,7 +33,7 @@ test: $(build_dir)/demo $(build_dir)/bundle.so $(build_dir)/demo: demo.cc @mkdir -p $(@D) - $(CXX) $(PKG_CFLAGS) -o $@ $^ + $(CXX) $(PKG_CFLAGS) -o $@ $^ -ldl # Serialize our graph.json file. $(build_dir)/graph.json.cc: $(build_dir)/graph.json @@ -44,13 +44,13 @@ $(build_dir)/params.bin.cc: $(build_dir)/params.bin xxd -i $^ > $@ $(build_dir)/model.o $(build_dir)/graph.json $(build_dir)/params.bin: build_model.py - python $< -o $(build_dir) + python3 $< -o $(build_dir) # Build our bundle against the serialized bundle.cc API, the runtime.cc API, and # the serialized graph.json and params.bin $(build_dir)/bundle.so: bundle.cc runtime.cc $(build_dir)/model.o $(build_dir)/graph.json.cc $(build_dir)/params.bin.cc @mkdir -p $(@D) - $(CXX) $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) -shared + $(CXX) -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) clean: rm -r $(build_dir) diff --git a/apps/bundle_deploy/build_model.py b/apps/bundle_deploy/build_model.py index dc4c14b..de9e735 100644 --- a/apps/bundle_deploy/build_model.py +++ b/apps/bundle_deploy/build_model.py @@ -18,8 +18,7 @@ import argparse import os -import nnvm.compiler -import nnvm.testing +from tvm import relay import tvm import logging @@ -34,22 +33,24 @@ def main(): dshape = (1, 3, 224, 224) from mxnet.gluon.model_zoo.vision import get_model block = get_model('mobilenet0.25', pretrained=True) - net, params = nnvm.frontend.from_mxnet(block) - net = nnvm.sym.softmax(net) + shape_dict = {'data': dshape} + mod, params = relay.frontend.from_mxnet(block, shape_dict) + func = mod["main"] + func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) + + with relay.build_config(opt_level=3): + graph, lib, params = relay.build( + func, 'llvm --system-lib', params=params) - with nnvm.compiler.build_config(opt_level=3): - graph, lib, params = nnvm.compiler.build( - net, 'llvm --system-lib', shape={'data': dshape}, params=params) - print(graph.symbol().debug_str()) build_dir = os.path.abspath(opts.out_dir) if not os.path.isdir(build_dir): os.makedirs(build_dir) lib.save(os.path.join(build_dir, 'model.o')) with open(os.path.join(build_dir, 'graph.json'), 'w') as f_graph_json: - f_graph_json.write(graph.json()) + f_graph_json.write(graph) with open(os.path.join(build_dir, 'params.bin'), 'wb') as f_params: - f_params.write(nnvm.compiler.save_param_dict(params)) + f_params.write(relay.save_param_dict(params)) if __name__ == '__main__': diff --git a/apps/bundle_deploy/bundle.cc b/apps/bundle_deploy/bundle.cc index 61169f1..14f0b7e 100644 --- a/apps/bundle_deploy/bundle.cc +++ b/apps/bundle_deploy/bundle.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,7 +26,9 @@ extern unsigned int build_graph_json_len; extern unsigned char build_params_bin[]; extern unsigned int build_params_bin_len; -#define TVM_BUNDLE_FUNCTION __attribute__((visibility("default"))) extern "C" +#define TVM_BUNDLE_FUNCTION __attribute__((visibility("default"))) + +extern "C" { TVM_BUNDLE_FUNCTION void *tvm_runtime_create() { const std::string json_data(&build_graph_json[0], @@ -64,3 +66,4 @@ TVM_BUNDLE_FUNCTION void tvm_runtime_get_output(void *handle, int index, reinterpret_cast(handle)->GetFunction("get_output")( index, reinterpret_cast(tensor)); } +} diff --git a/apps/bundle_deploy/runtime.cc b/apps/bundle_deploy/runtime.cc index f1c2ba2..7a116e8 100644 --- a/apps/bundle_deploy/runtime.cc +++ b/apps/bundle_deploy/runtime.cc @@ -25,7 +25,7 @@ #include "../../src/runtime/c_runtime_api.cc" #include "../../src/runtime/cpu_device_api.cc" #include "../../src/runtime/workspace_pool.cc" -#include "../../src/runtime/module_util.cc" +#include "../../src/runtime/library_module.cc" #include "../../src/runtime/module.cc" #include "../../src/runtime/registry.cc" #include "../../src/runtime/file_util.cc" @@ -33,5 +33,5 @@ #include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/ndarray.cc" #include "../../src/runtime/object.cc" -#include "../../src/runtime/system_lib_module.cc" +#include "../../src/runtime/system_library.cc" #include "../../src/runtime/graph/graph_runtime.cc" diff --git a/apps/howto_deploy/Makefile b/apps/howto_deploy/Makefile index 5c4a6d6..a260e89 100644 --- a/apps/howto_deploy/Makefile +++ b/apps/howto_deploy/Makefile @@ -17,7 +17,6 @@ # Makefile Example to deploy TVM modules. TVM_ROOT=$(shell cd ../..; pwd) -NNVM_PATH=nnvm DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core PKG_CFLAGS = -std=c++11 -O2 -fPIC\ @@ -25,7 +24,7 @@ PKG_CFLAGS = -std=c++11 -O2 -fPIC\ -I${DMLC_CORE}/include\ -I${TVM_ROOT}/3rdparty/dlpack/include\ -PKG_LDFLAGS = -L${TVM_ROOT}/build -ldl -lpthread +PKG_LDFLAGS = -L${TVM_ROOT}/build -ldl -pthread .PHONY: clean all @@ -39,7 +38,7 @@ lib/libtvm_runtime_pack.o: tvm_runtime_pack.cc # The code library built by TVM lib/test_addone_sys.o: prepare_test_libs.py @mkdir -p $(@D) - python prepare_test_libs.py + python3 prepare_test_libs.py # Deploy using the all in one TVM package library lib/cpp_deploy_pack: cpp_deploy.cc lib/test_addone_sys.o lib/libtvm_runtime_pack.o diff --git a/apps/rocm_rpc/Makefile b/apps/rocm_rpc/Makefile index 8d30fb6..36eb415 100644 --- a/apps/rocm_rpc/Makefile +++ b/apps/rocm_rpc/Makefile @@ -19,7 +19,6 @@ ROCM_PATH=/opt/rocm TVM_ROOT=$(shell cd ../..; pwd) -NNVM_PATH=nnvm DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core PKG_CFLAGS = -std=c++11 -O2 -fPIC\ diff --git a/apps/sgx/README.md b/apps/sgx/README.md index 13f72b0..ad87be4 100644 --- a/apps/sgx/README.md +++ b/apps/sgx/README.md @@ -49,7 +49,7 @@ mkdir build && cd build cmake .. -DUSE_LLVM=ON -DUSE_SGX=/opt/sgxsdk -DRUST_SGX_SDK=/opt/rust-sgx-sdk make -j4 cd .. -pip install -e python -e topi/python -e nnvm/python +pip install -e python -e topi/python cd apps/sgx ``` diff --git a/apps/sgx/enclave/src/build_model.py b/apps/sgx/enclave/src/build_model.py index 5a6b10c..dff5716 100644 --- a/apps/sgx/enclave/src/build_model.py +++ b/apps/sgx/enclave/src/build_model.py @@ -20,8 +20,8 @@ import argparse import os from os import path as osp -import nnvm.compiler -import nnvm.testing +from tvm import relay +from tvm.relay import testing import tvm @@ -30,14 +30,13 @@ def main(): parser.add_argument('-o', '--out-dir', default='.') opts = parser.parse_args() - # from tutorials/nnvm_quick_start.py dshape = (1, 3, 224, 224) - net, params = nnvm.testing.resnet.get_workload( + net, params = relay.testing.resnet.get_workload( layers=18, batch_size=dshape[0], image_shape=dshape[1:]) - with nnvm.compiler.build_config(opt_level=3): - graph, lib, params = nnvm.compiler.build( - net, 'llvm --system-lib', shape={'data': dshape}, params=params) + with relay.build_config(opt_level=3): + graph, lib, params = relay.build( + net, 'llvm --system-lib', params=params) build_dir = osp.abspath(opts.out_dir) if not osp.isdir(build_dir): @@ -45,9 +44,9 @@ def main(): lib.save(osp.join(build_dir, 'model.bc')) with open(osp.join(build_dir, 'graph.json'), 'w') as f_graph_json: - f_graph_json.write(graph.json()) + f_graph_json.write(graph) with open(osp.join(build_dir, 'params.bin'), 'wb') as f_params: - f_params.write(nnvm.compiler.save_param_dict(params)) + f_params.write(relay.save_param_dict(params)) if __name__ == '__main__': diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 02e2c7c..8467f6a 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -22,7 +22,7 @@ members = [ "runtime", "runtime/tests/test_tvm_basic", "runtime/tests/test_tvm_dso", - "runtime/tests/test_nnvm", + "runtime/tests/test_nn", "frontend", "frontend/tests/basics", "frontend/tests/callback", diff --git a/rust/frontend/Cargo.toml b/rust/frontend/Cargo.toml index c6b5680..3f99188 100644 --- a/rust/frontend/Cargo.toml +++ b/rust/frontend/Cargo.toml @@ -23,7 +23,7 @@ description = "Rust frontend support for TVM" repository = "https://github.com/apache/incubator-tvm" homepage = "https://github.com/apache/incubator-tvm" readme = "README.md" -keywords = ["rust", "tvm", "nnvm"] +keywords = ["rust", "tvm"] categories = ["api-bindings", "science"] authors = ["TVM Contributors"] edition = "2018" diff --git a/rust/frontend/README.md b/rust/frontend/README.md index b77a4bd..c61ba84 100644 --- a/rust/frontend/README.md +++ b/rust/frontend/README.md @@ -35,14 +35,12 @@ Here's a Python snippet for downloading and building a pretrained Resnet18 via A ```python block = get_model('resnet18_v1', pretrained=True) - -sym, params = nnvm.frontend.from_mxnet(block) -# add the softmax layer for prediction -net = nnvm.sym.softmax(sym) + +sym, params = relay.frontend.from_mxnet(block, shape_dict) # compile the model -with nnvm.compiler.build_config(opt_level=opt_level): - graph, lib, params = nnvm.compiler.build( - net, target, shape={"data": data_shape}, params=params) +with relay.build_config(opt_level=opt_level): + graph, lib, params = relay.build( + net, target, params=params) # same the model artifacts lib.save(os.path.join(target_dir, "deploy_lib.o")) cc.create_shared(os.path.join(target_dir, "deploy_lib.so"), @@ -51,7 +49,7 @@ cc.create_shared(os.path.join(target_dir, "deploy_lib.so"), with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo: fo.write(graph.json()) with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo: - fo.write(nnvm.compiler.save_param_dict(params)) + fo.write(relay.save_param_dict(params)) ``` Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image @@ -113,7 +111,7 @@ and the model correctly predicts the input image as **tiger cat**. Please follow TVM [installations](https://docs.tvm.ai/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. -*Note:* To run the end-to-end examples and tests, `tvm`, `nnvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually. +*Note:* To run the end-to-end examples and tests, `tvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually. ## Supported TVM Functionalities diff --git a/rust/frontend/examples/resnet/README.md b/rust/frontend/examples/resnet/README.md index 3ce4a77..2927474 100644 --- a/rust/frontend/examples/resnet/README.md +++ b/rust/frontend/examples/resnet/README.md @@ -18,11 +18,11 @@ ## Resnet example This end-to-end example shows how to: -* build `Resnet 18` with `tvm` and `nnvm` from Python +* build `Resnet 18` with `tvm` from Python * use the provided Rust frontend API to test for an input image -To run the example with pretrained resnet weights, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` -and to install `tvm` and `nnvm` with `llvm` follow the [TVM installation guide](https://docs.tvm.ai/install/index.html). +To run the example with pretrained resnet weights, first `tvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` +and to install `tvm` with `llvm` follow the [TVM installation guide](https://docs.tvm.ai/install/index.html). * **Build the example**: `cargo build diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml index 34acc77..f0d2459 100644 --- a/rust/runtime/Cargo.toml +++ b/rust/runtime/Cargo.toml @@ -22,7 +22,7 @@ license = "Apache-2.0" description = "A static TVM runtime" repository = "https://github.com/apache/incubator-tvm" readme = "README.md" -keywords = ["tvm", "nnvm"] +keywords = ["tvm"] categories = ["api-bindings", "science"] authors = ["TVM Contributors"] edition = "2018" diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index cacd7a3..42b9458 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -440,7 +440,7 @@ named!( ) ); -/// Loads a param dict saved using `nnvm.compiler.save_param_dict`. +/// Loads a param dict saved using `relay.save_param_dict`. pub fn load_param_dict(bytes: &[u8]) -> Result, GraphFormatError> { if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) { if remaining_bytes.len() == 0 { diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs index 3f25309..f05faf7 100644 --- a/rust/runtime/src/threading.rs +++ b/rust/runtime/src/threading.rs @@ -296,7 +296,7 @@ pub(crate) fn sgx_join_threads() { ocall_packed!("__sgx_thread_group_join__", 0); } -// @see https://github.com/apache/incubator-tvm/issues/988 for information on why this function is used. +// @see issue 988 for information on why this function is used. #[no_mangle] pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) { let barrier: &Arc = unsafe { &*((*penv).sync_handle as *const Arc) }; diff --git a/rust/runtime/tests/build_model.py b/rust/runtime/tests/build_model.py index bed3c0a..e3da95f 100755 --- a/rust/runtime/tests/build_model.py +++ b/rust/runtime/tests/build_model.py @@ -16,56 +16,37 @@ # specific language governing permissions and limitations # under the License. -"""Builds a simple NNVM graph for testing.""" +"""Builds a simple graph for testing.""" from os import path as osp -import nnvm -from nnvm import sym -from nnvm.compiler import graph_util -from nnvm.testing import init import numpy as np import tvm +from tvm import relay +from tvm.relay import testing CWD = osp.dirname(osp.abspath(osp.expanduser(__file__))) - def _get_model(dshape): - data = sym.Variable('data', shape=dshape) - fc1 = sym.dense(data, units=dshape[-1]*2, use_bias=True) - left, right = sym.split(fc1, indices_or_sections=2, axis=1) - return sym.Group(((left + 1), (right - 1))) - + data = relay.var('data', shape=dshape) + fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2) + fc = relay.nn.bias_add(data, relay.var("dense_bias")) + left, right = relay.split(fc, indices_or_sections=2, axis=1) + one = relay.const(1, dtype="float32") + return relay.Tuple([(left + one), (right - one), fc]) -def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10): - if isinstance(graph, sym.Symbol): - graph = nnvm.graph.create(graph) - ishapes, _ = graph_util.infer_shape(graph, **input_shapes) - param_shapes = dict(zip(graph.index.input_names, ishapes)) - np.random.seed(seed) - params = {} - for param, shape in param_shapes.items(): - if param in {'data', 'label'} or not shape: - continue - init_value = np.empty(shape).astype('float32') - initializer(param, init_value) - params[param] = tvm.nd.array(init_value) - return params def main(): dshape = (32, 16) net = _get_model(dshape) - ishape_dict = {'data': dshape} - params = _init_params(net, ishape_dict) - graph, lib, params = nnvm.compiler.build(net, 'llvm', - shape=ishape_dict, - params=params, - dtype='float32') + mod, params = testing.create_workload(net) + graph, lib, params = relay.build( + mod, 'llvm', params=params) with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet: - f_resnet.write(graph.json()) + f_resnet.write(graph) with open(osp.join(CWD, 'graph.params'), 'wb') as f_params: - f_params.write(nnvm.compiler.save_param_dict(params)) + f_params.write(relay.save_param_dict(params)) if __name__ == '__main__': main() diff --git a/rust/runtime/tests/test_nnvm/Cargo.toml b/rust/runtime/tests/test_nn/Cargo.toml similarity index 98% rename from rust/runtime/tests/test_nnvm/Cargo.toml rename to rust/runtime/tests/test_nn/Cargo.toml index 93fdef4..afd2188 100644 --- a/rust/runtime/tests/test_nnvm/Cargo.toml +++ b/rust/runtime/tests/test_nn/Cargo.toml @@ -16,7 +16,7 @@ # under the License. [package] -name = "test-nnvm" +name = "test-nn" version = "0.0.0" license = "Apache-2.0" authors = ["TVM Contributors"] diff --git a/rust/runtime/tests/test_nnvm/build.rs b/rust/runtime/tests/test_nn/build.rs similarity index 100% rename from rust/runtime/tests/test_nnvm/build.rs rename to rust/runtime/tests/test_nn/build.rs diff --git a/rust/runtime/tests/test_nn/src/build_test_graph.py b/rust/runtime/tests/test_nn/src/build_test_graph.py new file mode 100755 index 0000000..dd7621b --- /dev/null +++ b/rust/runtime/tests/test_nn/src/build_test_graph.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Builds a simple graph for testing.""" + +from os import path as osp +import sys + +import numpy as np +import tvm +from tvm import relay +from tvm.relay import testing + + +def _get_model(dshape): + data = relay.var('data', shape=dshape) + fc = relay.nn.dense(data, relay.var("dense_weight"), units=dshape[-1]*2) + fc = relay.nn.bias_add(data, relay.var("dense_bias")) + left, right = relay.split(fc, indices_or_sections=2, axis=1) + one = relay.const(1, dtype="float32") + return relay.Tuple([(left + one), (right - one), fc]) + +def main(): + dshape = (4, 8) + net = _get_model(dshape) + mod, params = testing.create_workload(net) + graph, lib, params = relay.build( + mod, 'llvm --system-lib', params=params) + + out_dir = sys.argv[1] + lib.save(osp.join(sys.argv[1], 'graph.o')) + with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet: + f_resnet.write(graph) + + with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params: + f_params.write(relay.save_param_dict(params)) + +if __name__ == '__main__': + main() diff --git a/rust/runtime/tests/test_nnvm/src/main.rs b/rust/runtime/tests/test_nn/src/main.rs similarity index 100% rename from rust/runtime/tests/test_nnvm/src/main.rs rename to rust/runtime/tests/test_nn/src/main.rs diff --git a/rust/runtime/tests/test_nnvm/src/build_test_graph.py b/rust/runtime/tests/test_nnvm/src/build_test_graph.py deleted file mode 100755 index 69ec6d2..0000000 --- a/rust/runtime/tests/test_nnvm/src/build_test_graph.py +++ /dev/null @@ -1,82 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Builds a simple NNVM graph for testing.""" - -from os import path as osp -import sys - -import nnvm -from nnvm import sym -from nnvm.compiler import graph_util -from nnvm.testing import init -import numpy as np -import tvm - - -def _get_model(dshape): - data = sym.Variable('data', shape=dshape) - fc = sym.dense(data, units=dshape[-1]*2, use_bias=True) - left, right = sym.split(fc, indices_or_sections=2, axis=1) - return sym.Group(((left + 1), (right - 1), fc)) - - -def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10): - if isinstance(graph, sym.Symbol): - graph = nnvm.graph.create(graph) - - ishapes, _ = graph_util.infer_shape(graph, **input_shapes) - param_shapes = dict(zip(graph.index.input_names, ishapes)) - np.random.seed(seed) - params = {} - for param, shape in param_shapes.items(): - if param in {'data', 'label'} or not shape: - continue - - init_value = np.arange(np.product(shape), 0, -1).reshape(*shape).astype('float32') - if param.endswith('_bias'): - params[param] = tvm.nd.array(init_value) - continue - - init_value = np.empty(shape).astype('float32') - initializer(param, init_value) - # init_value /= init_value.sum() + 1e-10 - params[param] = tvm.nd.array(init_value) - - return params - -def main(): - dshape = (4, 8) - net = _get_model(dshape) - ishape_dict = {'data': dshape} - params = _init_params(net, ishape_dict) - graph, lib, params = nnvm.compiler.build(net, 'llvm --system-lib', - shape=ishape_dict, - params=params, - dtype='float32') - - out_dir = sys.argv[1] - lib.save(osp.join(sys.argv[1], 'graph.o')) - with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet: - f_resnet.write(graph.json()) - - with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params: - f_params.write(nnvm.compiler.save_param_dict(params)) - -if __name__ == '__main__': - main() diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index e074bac..399d947 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -23,7 +23,7 @@ import topi.testing import tvm from tvm import relay from tvm.contrib import graph_runtime -from nnvm.testing.config import ctx_list +from tvm.relay.testing.config import ctx_list import onnx from onnx import helper, TensorProto, mapping import scipy diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py index 49a2219..2a07e95 100644 --- a/tests/python/relay/test_py_converter.py +++ b/tests/python/relay/test_py_converter.py @@ -510,7 +510,6 @@ def test_op_stack(): # test an op with a tuple output # adapted from test_split_infer_type in test_op_level3 -# and test_split in nnvm's test_top_level1 def test_split(): def verify_split(shape, indices_or_sections, axis=0): x = np.random.normal(size=shape).astype('float32') @@ -529,7 +528,6 @@ def test_split(): # ensure we can generate code for batch_norm, since it requires simplify_inference -# adapted from test_batchnorm in nnvm's test_top_level1 def test_batch_norm(): def verify_batch_norm(shapes): data = [np.absolute(np.random.normal(size=shape).astype('float32')) diff --git a/tests/scripts/task_golang.sh b/tests/scripts/task_golang.sh index ee9ec19..4996579 100755 --- a/tests/scripts/task_golang.sh +++ b/tests/scripts/task_golang.sh @@ -22,7 +22,7 @@ set -u export LD_LIBRARY_PATH="lib:${LD_LIBRARY_PATH:-}" tvm_root="$(git rev-parse --show-toplevel)" -export PYTHONPATH="$tvm_root/python":"$tvm_root/nnvm/python":"$tvm_root/topi/python" +export PYTHONPATH="$tvm_root/python":"$tvm_root/topi/python" # Golang tests make -C golang tests diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index 7a7bcac..fdb7ef6 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -42,9 +42,6 @@ python3 -m pytest -v tests/python/frontend/onnx echo "Running relay CoreML frontend test..." python3 -m pytest -v tests/python/frontend/coreml -echo "Running nnvm to relay frontend test..." -python3 -m pytest -v tests/python/frontend/nnvm_to_relay - echo "Running relay Tensorflow frontend test..." python3 -m pytest -v tests/python/frontend/tensorflow diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index cdf777c..140563a 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -21,8 +21,8 @@ set -u export TVM_HOME="$(git rev-parse --show-toplevel)" -export LD_LIBRARY_PATH="$TVM_HOME/lib:$TVM_HOME/build:$TVM_HOME/nnvm:${LD_LIBRARY_PATH:-}" -export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/nnvm/python":"$TVM_HOME/topi/python" +export LD_LIBRARY_PATH="$TVM_HOME/lib:$TVM_HOME/build:${LD_LIBRARY_PATH:-}" +export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/topi/python" export RUST_DIR="$TVM_HOME/rust" cd $RUST_DIR @@ -52,8 +52,8 @@ cd tests/test_tvm_dso cargo run cd - -# run NNVM graph test -cd tests/test_nnvm +# run nn graph test +cd tests/test_nn cargo run cd - -- 2.7.4