From 05a5c170930ef649d6f196950e680ca16d30d07a Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Wed, 19 Jun 2019 20:29:12 -0700 Subject: [PATCH] return mod from frontend for autotvm (#3401) --- tutorials/autotvm/tune_relay_arm.py | 3 ++- tutorials/autotvm/tune_relay_cuda.py | 3 ++- tutorials/autotvm/tune_relay_mobile_gpu.py | 3 ++- tutorials/autotvm/tune_relay_x86.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tutorials/autotvm/tune_relay_arm.py b/tutorials/autotvm/tune_relay_arm.py index 2c1dca9..290f975 100644 --- a/tutorials/autotvm/tune_relay_arm.py +++ b/tutorials/autotvm/tune_relay_arm.py @@ -96,7 +96,8 @@ def get_network(name, batch_size): # an example for mxnet model from mxnet.gluon.model_zoo.vision import get_model block = get_model('resnet18_v1', pretrained=True) - net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + net = mod[mod.entry_func] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) else: raise ValueError("Unsupported network: " + name) diff --git a/tutorials/autotvm/tune_relay_cuda.py b/tutorials/autotvm/tune_relay_cuda.py index 571334e..c158e4b 100644 --- a/tutorials/autotvm/tune_relay_cuda.py +++ b/tutorials/autotvm/tune_relay_cuda.py @@ -96,7 +96,8 @@ def get_network(name, batch_size): # an example for mxnet model from mxnet.gluon.model_zoo.vision import get_model block = get_model('resnet18_v1', pretrained=True) - net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + net = mod[mod.entry_func] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) else: raise ValueError("Unsupported network: " + name) diff --git a/tutorials/autotvm/tune_relay_mobile_gpu.py b/tutorials/autotvm/tune_relay_mobile_gpu.py index 1e4cf6d..c011268 100644 --- a/tutorials/autotvm/tune_relay_mobile_gpu.py +++ b/tutorials/autotvm/tune_relay_mobile_gpu.py @@ -97,7 +97,8 @@ def get_network(name, batch_size): # an example for mxnet model from mxnet.gluon.model_zoo.vision import get_model block = get_model('resnet18_v1', pretrained=True) - net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + net = mod[mod.entry_func] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) else: raise ValueError("Unsupported network: " + name) diff --git a/tutorials/autotvm/tune_relay_x86.py b/tutorials/autotvm/tune_relay_x86.py index ad35c19..c8d9def 100644 --- a/tutorials/autotvm/tune_relay_x86.py +++ b/tutorials/autotvm/tune_relay_x86.py @@ -64,7 +64,8 @@ def get_network(name, batch_size): # an example for mxnet model from mxnet.gluon.model_zoo.vision import get_model block = get_model('resnet18_v1', pretrained=True) - net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + net = mod[mod.entry_func] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) else: raise ValueError("Unsupported network: " + name) -- 2.7.4