return mod from frontend for autotvm (#3401)
authorZhi <5145158+zhiics@users.noreply.github.com>
Thu, 20 Jun 2019 03:29:12 +0000 (20:29 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 20 Jun 2019 03:29:12 +0000 (20:29 -0700)
tutorials/autotvm/tune_relay_arm.py
tutorials/autotvm/tune_relay_cuda.py
tutorials/autotvm/tune_relay_mobile_gpu.py
tutorials/autotvm/tune_relay_x86.py

index 2c1dca9..290f975 100644 (file)
@@ -96,7 +96,8 @@ def get_network(name, batch_size):
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
         block = get_model('resnet18_v1', pretrained=True)
-        net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        net = mod[mod.entry_func]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
     else:
         raise ValueError("Unsupported network: " + name)
index 571334e..c158e4b 100644 (file)
@@ -96,7 +96,8 @@ def get_network(name, batch_size):
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
         block = get_model('resnet18_v1', pretrained=True)
-        net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        net = mod[mod.entry_func]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
     else:
         raise ValueError("Unsupported network: " + name)
index 1e4cf6d..c011268 100644 (file)
@@ -97,7 +97,8 @@ def get_network(name, batch_size):
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
         block = get_model('resnet18_v1', pretrained=True)
-        net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        net = mod[mod.entry_func]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
     else:
         raise ValueError("Unsupported network: " + name)
index ad35c19..c8d9def 100644 (file)
@@ -64,7 +64,8 @@ def get_network(name, batch_size):
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
         block = get_model('resnet18_v1', pretrained=True)
-        net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        net = mod[mod.entry_func]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
     else:
         raise ValueError("Unsupported network: " + name)