return mod from frontend for autotvm (#3401)
authorZhi <5145158+zhiics@users.noreply.github.com>
Thu, 20 Jun 2019 03:29:12 +0000 (20:29 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 20 Jun 2019 03:29:12 +0000 (20:29 -0700)
tutorials/autotvm/tune_relay_arm.py
tutorials/autotvm/tune_relay_cuda.py
tutorials/autotvm/tune_relay_mobile_gpu.py
tutorials/autotvm/tune_relay_x86.py

index 2c1dca9921ebc474d0e8eed76dc265f142e94671..290f9756f1955a85c8b0d64ff0ed659caef6618c 100644 (file)
@@ -96,7 +96,8 @@ def get_network(name, batch_size):
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
         block = get_model('resnet18_v1', pretrained=True)
-        net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        net = mod[mod.entry_func]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
     else:
         raise ValueError("Unsupported network: " + name)
index 571334e8c106379a5a24dd73ceea847ef65f4514..c158e4b9fe3634f50d9dc9b9b67ad1f6a2459e22 100644 (file)
@@ -96,7 +96,8 @@ def get_network(name, batch_size):
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
         block = get_model('resnet18_v1', pretrained=True)
-        net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        net = mod[mod.entry_func]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
     else:
         raise ValueError("Unsupported network: " + name)
index 1e4cf6d52adeb061cb01bfa34d1a18fd1a8bfdbc..c011268fda512d858d5ecdeb5f8686df39a9c705 100644 (file)
@@ -97,7 +97,8 @@ def get_network(name, batch_size):
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
         block = get_model('resnet18_v1', pretrained=True)
-        net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        net = mod[mod.entry_func]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
     else:
         raise ValueError("Unsupported network: " + name)
index ad35c198bc777423ec1949813b8773d78c9c51ab..c8d9def206fe363e850e7d60ec5f34576db035e5 100644 (file)
@@ -64,7 +64,8 @@ def get_network(name, batch_size):
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
         block = get_model('resnet18_v1', pretrained=True)
-        net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+        net = mod[mod.entry_func]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
     else:
         raise ValueError("Unsupported network: " + name)