[relay][frontend] Return Module from get_workload (#3483)
authorZhi <5145158+zhiics@users.noreply.github.com>
Sat, 6 Jul 2019 04:23:27 +0000 (21:23 -0700)
committerJared Roesch <roeschinc@gmail.com>
Sat, 6 Jul 2019 04:23:27 +0000 (21:23 -0700)
* [relay][frontend] Return Module from get_workload

* pass entry_func to autotvm

* disable tune

* add property to module

* mod.entry_func to main

* .main -> mod["main"]

* fix

87 files changed:
include/tvm/relay/analysis.h
include/tvm/relay/module.h
python/tvm/autotvm/graph_tuner/base_graph_tuner.py
python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
python/tvm/autotvm/graph_tuner/utils/utils.py
python/tvm/relay/backend/interpreter.py
python/tvm/relay/backend/vm.py
python/tvm/relay/build_module.py
python/tvm/relay/frontend/caffe2.py
python/tvm/relay/frontend/common.py
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/module.py
python/tvm/relay/quantize/quantize.py
python/tvm/relay/testing/__init__.py
python/tvm/relay/testing/dcgan.py
python/tvm/relay/testing/densenet.py
python/tvm/relay/testing/dqn.py
python/tvm/relay/testing/inception_v3.py
python/tvm/relay/testing/init.py
python/tvm/relay/testing/lstm.py
python/tvm/relay/testing/mlp.py
python/tvm/relay/testing/mobilenet.py
python/tvm/relay/testing/resnet.py
python/tvm/relay/testing/squeezenet.py
python/tvm/relay/testing/vgg.py
src/relay/backend/build_module.cc
src/relay/backend/vm/vm.cc
src/relay/ir/module.cc
src/relay/pass/fold_constant.cc
src/relay/pass/partial_eval.cc
src/relay/pass/quantize.cc
src/relay/pass/type_infer.cc
tests/cpp/relay_pass_type_infer_test.cc
tests/cpp/relay_transform_sequential.cc
tests/python/frontend/caffe2/test_graph.py
tests/python/frontend/coreml/test_forward.py
tests/python/frontend/mxnet/test_graph.py
tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py
tests/python/relay/benchmarking/benchmark_vm.py
tests/python/relay/test_autotvm_task_extraction.py
tests/python/relay/test_backend_compile_engine.py
tests/python/relay/test_backend_graph_runtime.py
tests/python/relay/test_backend_interpreter.py
tests/python/relay/test_error_reporting.py
tests/python/relay/test_feature.py
tests/python/relay/test_op_grad_level1.py
tests/python/relay/test_op_level1.py
tests/python/relay/test_op_level10.py
tests/python/relay/test_op_level2.py
tests/python/relay/test_op_level3.py
tests/python/relay/test_op_level4.py
tests/python/relay/test_op_level5.py
tests/python/relay/test_pass_alter_op_layout.py
tests/python/relay/test_pass_annotation.py
tests/python/relay/test_pass_canonicalize_cast.py
tests/python/relay/test_pass_combine_parallel_conv2d.py
tests/python/relay/test_pass_dead_code_elimination.py
tests/python/relay/test_pass_eliminate_common_subexpr.py
tests/python/relay/test_pass_eta_expand.py
tests/python/relay/test_pass_fold_constant.py
tests/python/relay/test_pass_fold_scale_axis.py
tests/python/relay/test_pass_fuse_ops.py
tests/python/relay/test_pass_gradient.py
tests/python/relay/test_pass_mac_count.py
tests/python/relay/test_pass_manager.py
tests/python/relay/test_pass_partial_eval.py
tests/python/relay/test_pass_quantize.py
tests/python/relay/test_pass_to_a_normal_form.py
tests/python/relay/test_pass_to_cps.py
tests/python/relay/test_pass_to_graph_normal_form.py
tests/python/relay/test_type_infer.py
tests/python/relay/test_typecall.py
tests/python/relay/test_vm.py
tests/python/unittest/test_graph_tuner_core.py
tests/python/unittest/test_graph_tuner_utils.py
tutorials/autotvm/tune_relay_arm.py
tutorials/autotvm/tune_relay_cuda.py
tutorials/autotvm/tune_relay_mobile_gpu.py
tutorials/autotvm/tune_relay_x86.py
tutorials/frontend/deploy_model_on_rasp.py
tutorials/frontend/from_mxnet.py
tutorials/relay_quick_start.py
vta/python/vta/top/graphpack.py
vta/scripts/tune_resnet.py
vta/tutorials/autotvm/tune_relay_vta.py
vta/tutorials/frontend/deploy_resnet_on_vta.py

index deb9c7dec0c56b362b9e9e383dc1713c040d08a0..3672a22847dbf8c6d664d9ffc8b262db44517f4a 100644 (file)
 namespace tvm {
 namespace relay {
 
-/*!
- * \brief Infer the type of a function as if it is mapped to var in the mod.
- *
- * \param f the function.
- * \param mod The module used for referencing global functions.
- * \param var The global variable corresponding to the function.
- *
- * \return A type checked Function with its checked_type field populated.
- * \note this function mutates mod and is not thread-safe.
- */
-TVM_DLL Function InferType(const Function& f,
-                           const Module& mod,
-                           const GlobalVar& var);
-
 /*!
  * \brief Check that types are well kinded by applying "kinding rules".
  *
index 389f0c1c3eb499c9c0da95efba96f8ca959d3f9b..e888c54c17aca04729a7a06c5b165d06e3dfe3da 100644 (file)
@@ -65,16 +65,12 @@ class ModuleNode : public RelayNode {
   /*! \brief A map from global type vars to ADT type data. */
   tvm::Map<GlobalTypeVar, TypeData> type_definitions;
 
-  /*! \brief The entry function (i.e. "main"). */
-  GlobalVar entry_func;
-
   ModuleNode() {}
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("functions", &functions);
     v->Visit("type_definitions", &type_definitions);
     v->Visit("global_var_map_", &global_var_map_);
-    v->Visit("entry_func", &entry_func);
     v->Visit("global_type_var_map_", &global_type_var_map_);
   }
 
@@ -119,6 +115,13 @@ class ModuleNode : public RelayNode {
    */
   TVM_DLL void Remove(const GlobalVar& var);
 
+  /*!
+   * \brief Check if the global_var_map_ contains a global variable.
+   * \param name The variable name.
+   * \returns true if contains, otherise false.
+   */
+  TVM_DLL bool ContainGlobalVar(const std::string& name) const;
+
   /*!
    * \brief Lookup a global function by its variable.
    * \param str The unique string specifying the global variable.
@@ -180,10 +183,10 @@ class ModuleNode : public RelayNode {
    * Allows one to optionally pass a global function map as
    * well.
    *
-   * \param expr The expression to set as the entry point to the module.
+   * \param expr The expression to set as the main function to the module.
    * \param global_funcs The global function map.
    *
-   * \returns A module with expr set as the entry point.
+   * \returns A module with expr set as the main function.
    */
   TVM_DLL static Module FromExpr(
     const Expr& expr,
index 252882d17ecebae64a7f9774fd5a704ea132a213..cffd42347b35d48da19fe6af3433c604b2bcff7c 100644 (file)
@@ -142,7 +142,7 @@ class BaseGraphTuner(object):
 
         # Generate workload and schedule dictionaries.
         if isinstance(graph, relay.Module):
-            graph = graph[graph.entry_func]
+            graph = graph["main"]
 
         if isinstance(graph, relay.expr.Function):
             node_dict = {}
index 62e409fec1a01cddd3c703563c0040088e45c20e..5d07bd3fbce5f28367bf6931581c5aa4a102e987 100644 (file)
@@ -85,7 +85,7 @@ def _infer_type(node):
     """A method to infer the type of a relay expression."""
     mod = relay.Module.from_expr(node)
     mod = transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(node, relay.Function) else entry.body
 
 
index 797a38ae36983ed4cf187d63ebaf59365c0440ae..b9777ef84459524036bfc7c8c8f5e1f0fd19df71 100644 (file)
@@ -110,5 +110,5 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
 
     mod = relay.Module.from_expr(updated_expr)
     mod = transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(updated_expr, relay.Function) else entry.body
index 5ca09b0f24d588dc055785f7aa2996f268d6a5c5..462dda9488c21b9a3d5bca6ff5b812318e5f17b7 100644 (file)
@@ -289,7 +289,7 @@ class Interpreter(Executor):
             assert self.mod is not None
         def _interp_wrapper(*args, **kwargs):
             if expr is None:
-                args = self._convert_args(self.mod[self.mod.entry_func], args, kwargs)
+                args = self._convert_args(self.mod["main"], args, kwargs)
             else:
                 args = self._convert_args(expr, args, kwargs)
 
@@ -301,17 +301,17 @@ class Interpreter(Executor):
             if expr is None:
                 pass
             elif isinstance(expr, GlobalVar):
-                self.mod[self.mod.entry_func] = self.mod[expr]
+                self.mod["main"] = self.mod[expr]
             else:
                 assert isinstance(expr, Function)
                 func = Function([], Call(expr, relay_args))
                 relay_args = []
                 if self.mod:
-                    self.mod[self.mod.entry_func] = func
+                    self.mod["main"] = func
                 else:
                     self.mod = module.Module.from_expr(func)
 
             mod = self.optimize()
-            opt_expr = Call(mod[self.mod.entry_func.name_hint], relay_args)
+            opt_expr = Call(mod["main"], relay_args)
             return self._intrp(opt_expr)
         return _interp_wrapper
index ceb403fe77174ee69b0e2aeb2de2927ff4cde092..152ee576e7bdb7fbf02fa678166fbae976f9d707 100644 (file)
@@ -45,7 +45,7 @@ def optimize(mod):
     ret : tvm.relay.Module
         The optimized module.
     """
-    main_func = mod[mod.entry_func]
+    main_func = mod["main"]
 
     opt_passes = []
     if not main_func.params and isinstance(main_func.body, GlobalVar):
@@ -134,8 +134,8 @@ class VMExecutor(Executor):
         expr = expr if expr else self.mod
         assert expr, "either expr or self.mod should be not null."
         if isinstance(expr, Expr):
-            self.mod[self.mod.entry_func] = expr
-        main = self.mod[self.mod.entry_func]
+            self.mod["main"] = expr
+        main = self.mod["main"]
 
         def _vm_wrapper(*args, **kwargs):
             args = self._convert_args(main, args, kwargs)
index 6337e629516c3f63a3bb404281f2c7916bbaaf74..404829f74cf785b2084a6562d1d6c54a6e861421 100644 (file)
@@ -177,7 +177,7 @@ def build(mod, target=None, target_host=None, params=None):
         The parameters of the final graph.
     """
     if isinstance(mod, _Module):
-        func = mod[mod.entry_func]
+        func = mod["main"]
     elif isinstance(mod, _expr.Function):
         func = mod
         warnings.warn(
@@ -233,8 +233,8 @@ class GraphExecutor(_interpreter.Executor):
 
     def _make_executor(self, expr=None):
         if expr:
-            self.mod[self.mod.entry_func] = expr
-        ret_type = self.mod[self.mod.entry_func].checked_type.ret_type
+            self.mod["main"] = expr
+        ret_type = self.mod["main"].checked_type.ret_type
         num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
         graph_json, mod, params = build(self.mod, target=self.target)
         gmodule = _graph_rt.create(graph_json, mod, self.ctx)
@@ -242,7 +242,7 @@ class GraphExecutor(_interpreter.Executor):
             gmodule.set_input(**params)
 
         def _graph_wrapper(*args, **kwargs):
-            args = self._convert_args(self.mod[self.mod.entry_func], args, kwargs)
+            args = self._convert_args(self.mod["main"], args, kwargs)
             # Create map of inputs.
             for i, arg in enumerate(args):
                 gmodule.set_input(i, arg)
index 91f0409b39d5ff54f18efa0929e291cda293bb74..43d9d21c09b5d4df5910d4f4566c835bc706b1a1 100644 (file)
@@ -451,7 +451,7 @@ class Caffe2NetDef(object):
             outputs = out[0]
 
         func = _expr.Function(analysis.free_vars(outputs), outputs)
-        self._mod[self._mod.entry_func] = func
+        self._mod["main"] = func
 
         return self._mod, self._params
 
index 6d8e14569e73c2cb4db3973f5eb878fbdea3b158..c5057f35fedef717e4cddddb0dc61ad33a5c1224 100644 (file)
@@ -412,7 +412,7 @@ def infer_type(node):
     """A method to infer the type of an intermediate node in the relay graph."""
     mod = _module.Module.from_expr(node)
     mod = _transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(node, _expr.Function) else entry.body
 
 def infer_shape(inputs):
index 26c357e9c9244bc97c908419e6606ceee0cb9915..e40f1dea61a9f5cdc8e2b13a4e2245f5f6b7f36e 100644 (file)
@@ -45,7 +45,7 @@ def _infer_type(node):
     """A method to infer the type of an intermediate node in the relay graph."""
     mod = _module.Module.from_expr(node)
     mod = transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(node, _expr.Function) else entry.body
 
 def _mx_fully_connected(inputs, attrs):
@@ -1200,5 +1200,5 @@ def from_mxnet(symbol,
     else:
         msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol))
         raise ValueError(msg)
-    mod[mod.entry_func] = func
+    mod["main"] = func
     return mod, params
index e14566f6ab334a4d6b48139e90bf700335128dd9..59e0983e95985182b91b90cdae7c4eb79b9de409 100644 (file)
@@ -240,7 +240,7 @@ def _infer_type(node):
     """A method to infer the type of an intermediate node in the relay graph."""
     mod = _module.Module.from_expr(node)
     mod = _transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(node, _expr.Function) else entry.body
 
 def _infer_shape(node, params=None):
@@ -2122,7 +2122,7 @@ class GraphProto(object):
 
         out = out[0] if len(out) == 1 else _expr.Tuple(out)
         func = _expr.Function(analysis.free_vars(out), out)
-        self._mod[self._mod.entry_func] = func
+        self._mod["main"] = func
         return self._mod, self._params
 
     def _parse_import_prerequisites(self, graph):
index aeeedb864dbc384b6c78f475e8b3a100cc76f977..8ac15f743fc4fabaf8ee762acdc504235299e82f 100644 (file)
@@ -78,8 +78,11 @@ class Module(RelayNode):
     def _add(self, var, val, update=False):
         if isinstance(val, _expr.Expr):
             if isinstance(var, _base.string_types):
-                var = _expr.GlobalVar(var)
-            _make.Module_Add(self, var, val, update)
+                if _module.Module_ContainGlobalVar(self, var):
+                    var = _module.Module_GetGlobalVar(self, var)
+                else:
+                    var = _expr.GlobalVar(var)
+            _module.Module_Add(self, var, val, update)
         else:
             assert isinstance(val, _ty.Type)
             if isinstance(var, _base.string_types):
index b7994217e9640fad5c903d8eb7f0bc41e220e37e..beebceaf8590cb99a702d4fc5a08a89bcf180be7 100644 (file)
@@ -365,4 +365,4 @@ def quantize(graph, params=None, dataset=None):
         mod = optimize(mod)
         mod = quantize_seq(mod)
 
-    return mod[mod.entry_func.name_hint]
+    return mod["main"]
index 9d12529e576f3a6dd7897954bbbd0ac01fa68ba9..de9e55b369d19215fecfcaf09ae08a598dac5608 100644 (file)
@@ -41,7 +41,7 @@ def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, transform.Pass)
     mod = relay.Module.from_expr(expr)
     mod = opt_pass(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
index e9a914ecd69aef55de6cc73849ab7f18fe65fd8d..c6b258badb5b69d08988ef9ad525601e4a33a4e3 100644 (file)
@@ -103,8 +103,8 @@ def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, dtype=
 
     Returns
     -------
-    net : nnvm.symbol
-        The computational graph
+    mod : tvm.relay.Module
+        The relay module that contains a DCGAN network.
     params : dict of str to NDArray
         The parameters.
     """
index 573a4bc367946c0ee0de79d765c7e689290c1b9f..f9b479153bfadd401db053f0c5f66a04169b11fa 100644 (file)
@@ -105,8 +105,8 @@ def get_workload(densenet_size=121, classes=1000, batch_size=4,
 
     Returns
     -------
-    net: relay.Function
-        The computation graph representing densenet.
+    mod: tvm.relay.Module
+        The relay module that contains a DenseNet network.
 
     params : dict of str to NDArray
         The benchmark paraeters.
index fdf46fbc2f7c20220dd57228346f5221aa3e5bda..cdf9d24af996a7ae78dd41260e764d8121f28abf 100644 (file)
@@ -72,8 +72,8 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo
         The data type
     Returns
     -------
-    net : nnvm.symbol
-        The computational graph
+    mod : tvm.relay.Module
+        The relay module that contains a DQN network.
     params : dict of str to NDArray
         The parameters.
     """
index c3f0181f2951f1189da43dd93ce21eb0ebfb3324..4da543257c3184bc363eaa2371c2b74500463620 100644 (file)
@@ -289,8 +289,8 @@ def get_workload(batch_size=1, num_classes=1000,
 
     Returns
     -------
-    net : nnvm.Symbol
-        The computational graph
+    mod : tvm.relay.Module
+        The relay module that contains an Inception V3 network.
 
     params : dict of str to NDArray
         The parameters.
index 20b5156990a7c51e37ef1ab8c0af5fbed402d8c6..0b8ab2b42029b1071b9982d90b66e3bae7273751 100644 (file)
@@ -144,17 +144,16 @@ def create_workload(net, initializer=None, seed=0):
 
     Returns
     -------
-    net : tvm.relay.Function
-        The updated dataflow
+    mod : tvm.relay.Module
+        The created relay module.
 
     params : dict of str to NDArray
         The parameters.
     """
     mod = relay.Module.from_expr(net)
     mod = relay.transform.InferType()(mod)
-    net = mod[mod.entry_func]
     shape_dict = {
-        v.name_hint : v.checked_type for v in net.params}
+        v.name_hint : v.checked_type for v in mod["main"].params}
     np.random.seed(seed)
     initializer = initializer if initializer else Xavier()
     params = {}
@@ -164,4 +163,4 @@ def create_workload(net, initializer=None, seed=0):
         init_value = np.zeros(v.concrete_shape).astype(v.dtype)
         initializer(k, init_value)
         params[k] = tvm.nd.array(init_value, ctx=tvm.cpu(0))
-    return net, params
+    return mod, params
index 9721c26f2a1517401e9a10761ea2c94029e498fe..d0134c1a864d686e4baff84bf59d60e380523a63 100644 (file)
@@ -173,8 +173,8 @@ def get_workload(iterations, num_hidden, batch_size=1, dtype="float32"):
         The data type
     Returns
     -------
-    net : nnvm.symbol
-        The computational graph
+    mod : tvm.relay.Module
+        The relay module that contains a LSTM network.
     params : dict of str to NDArray
         The parameters.
     """
index e178408a6a1bc7280dc30ae674efd1a198c835ce..337bde5d5889edc2e46ac2b280a7a43745413fd0 100644 (file)
@@ -84,8 +84,8 @@ def get_workload(batch_size,
 
     Returns
     -------
-    net : relay.Function
-        The dataflow.
+    mod : tvm.relay.Module
+        The relay module that contains a mlp network.
 
     params : dict of str to NDArray
         The parameters.
index dff103150ab0ade5e41ce8bb7eedf430ef1af3a2..3b068c05a24ed1d2ca69099beff25151801d1709 100644 (file)
@@ -130,8 +130,8 @@ def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224), dtyp
 
     Returns
     -------
-    net : relay.Function
-        The computational graph
+    mod : tvm.relay.Module
+        The relay module that contains a MobileNet network.
 
     params : dict of str to NDArray
         The parameters.
index f67785917384cef5c3142335e3d4e67862ca2cbf..a8e369b7402190ae30539cef3970f19afa0a2894 100644 (file)
@@ -261,8 +261,8 @@ def get_workload(batch_size=1,
 
     Returns
     -------
-    net : relay.Function
-        The computational graph
+    mod : tvm.relay.Module
+        The relay module that contains a ResNet network.
 
     params : dict of str to NDArray
         The parameters.
index 5c90265183ff479d2615538c85bd4812a3536d68..1e9ea73e9360ecfae862332e75c0232402a32970 100644 (file)
@@ -149,8 +149,8 @@ def get_workload(batch_size=1,
 
     Returns
     -------
-    net : nnvm.Symbol
-        The computational graph
+    mod : tvm.relay.Module
+        The relay module that contains a SqueezeNet network.
 
     params : dict of str to NDArray
         The parameters.
index 06d9aa3d2d93c8027d9f5ee21333d7fab0242670..205c5b1fa8e395c7a616500fcc2647e0d760ba97 100644 (file)
@@ -124,8 +124,8 @@ def get_workload(batch_size,
 
     Returns
     -------
-    net : nnvm.Symbol
-        The computational graph
+    mod : tvm.relay.Module
+        The relay module that contains a VGG network.
 
     params : dict of str to NDArray
         The parameters.
index 3ab57f166d900e5bb0d0e4ca7827f114ac373a0b..7de77c8bcfd4b244664b8a122c8c7708d7325c97 100644 (file)
@@ -434,7 +434,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     relay_module = Optimize(relay_module, targets_, params);
     CHECK(relay_module.defined());
     // Get the updated function.
-    func = relay_module->Lookup(relay_module->entry_func->name_hint);
+    func = relay_module->Lookup("main");
 
     // Generate code for the updated function.
     graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
index 4dbcda9abb6f9d246fdc7965483489c561645c4e..2f656c8cef992965eb8e36052296220dd68abbb7 100644 (file)
@@ -52,10 +52,10 @@ Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
   // TODO(zhiics): This measurement is for temporary usage. Remove it later. We
   // need to introduce a better profiling method.
 #if ENABLE_PROFILING
-  DLOG(INFO) << "Entry function is " << module->entry_func << std::endl;
+  DLOG(INFO) << "Entry function is main." << std::endl;
   auto start = std::chrono::high_resolution_clock::now();
 #endif  // ENABLE_PROFILING
-  Object res = vm.Invoke(module->entry_func->name_hint, vm_args);
+  Object res = vm.Invoke("main", vm_args);
 #if ENABLE_PROFILING
   auto end = std::chrono::high_resolution_clock::now();
   auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
index a616f5e4836aa79953f5e7ff94dd21f782de09f1..0ad0a91efd217e97f4d56b4452ccc37dd2e8d80d 100644 (file)
@@ -46,8 +46,6 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
     n->global_var_map_.Set(kv.first->name_hint, kv.first);
   }
 
-  n->entry_func = GlobalVarNode::make("main");
-
   for (const auto& kv : n->type_definitions) {
     // set global typevar map
     CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint))
@@ -59,6 +57,10 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
   return Module(n);
 }
 
+bool ModuleNode::ContainGlobalVar(const std::string& name) const {
+  return global_var_map_.find(name) != global_var_map_.end();
+}
+
 GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
   auto it = global_var_map_.find(name);
   CHECK(it != global_var_map_.end())
@@ -194,7 +196,8 @@ Module ModuleNode::FromExpr(
   } else {
     func = FunctionNode::make({}, expr, Type(), {}, {});
   }
-  mod->Add(mod->entry_func, func);
+  auto main_gv = GlobalVarNode::make("main");
+  mod->Add(main_gv, func);
   return mod;
 }
 
@@ -203,7 +206,7 @@ TVM_REGISTER_NODE_TYPE(ModuleNode);
 TVM_REGISTER_API("relay._make.Module")
 .set_body_typed(ModuleNode::make);
 
-TVM_REGISTER_API("relay._make.Module_Add")
+TVM_REGISTER_API("relay._module.Module_Add")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   Module mod = args[0];
   GlobalVar var = args[1];
@@ -231,6 +234,9 @@ TVM_REGISTER_API("relay._module.Module_AddDef")
 TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
 .set_body_method<Module>(&ModuleNode::GetGlobalVar);
 
+TVM_REGISTER_API("relay._module.Module_ContainGlobalVar")
+.set_body_method<Module>(&ModuleNode::ContainGlobalVar);
+
 TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar")
 .set_body_method<Module>(&ModuleNode::GetGlobalTypeVar);
 
index 71d189b0800fcb7151c2333b3bf072d3a48bc3cc..7b896a8d0f7feba0cb78b350ee9362f26661dae4 100644 (file)
@@ -161,7 +161,7 @@ class ConstantFolder : public ExprMutator {
     auto mod = ModuleNode::FromExpr(expr);
     auto seq = transform::Sequential(passes);
     mod = seq(mod);
-    auto entry_func = mod->Lookup(mod->entry_func);
+    auto entry_func = mod->Lookup("main");
     expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
     return ValueToExpr(executor_(expr));
   }
index b7f12b65751dbd6d29b700684f60631b7552706f..3b7628a10789c9b885485996034fcf6e7d5b9ed9 100644 (file)
@@ -751,7 +751,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
     auto mod = ModuleNode::FromExpr(expr);
     auto seq = transform::Sequential(passes);
     mod = seq(mod);
-    auto entry_func = mod->Lookup(mod->entry_func);
+    auto entry_func = mod->Lookup("main");
     auto fused_infered =
         expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
     return Reify(executor_(fused_infered), ll);
@@ -1018,7 +1018,6 @@ Expr PostProcess(const Expr& e) {
 }  // namespace partial_eval
 
 Module PartialEval(const Module& m) {
-  CHECK(m->entry_func.defined());
   relay::partial_eval::PartialEvaluator pe(m);
   std::vector<GlobalVar> gvs;
   for (const auto& p : m->functions) {
index 7527d2a216286c5a0684fa6ecfd20ee4e4df0b99..dbfbb7ef1de31d659abdf5e556b8dac7220900f3 100644 (file)
@@ -263,7 +263,7 @@ Expr QuantizeRealize(const Call& ref_call,
 Expr FoldConstantOpt(const Expr& expr) {
   auto mod = ModuleNode::FromExpr(expr);
   mod = transform::FoldConstant()(mod);
-  auto entry_func = mod->Lookup(mod->entry_func);
+  auto entry_func = mod->Lookup("main");
   return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
 }
 
index 5ae39084a88f0f2fd2649f93ad5c2f9434c08a60..64f125a9050657581ea6c114991cb67d23743ea3 100644 (file)
@@ -774,7 +774,7 @@ Expr InferType(const Expr& expr, const Module& mod_ref) {
     // type check it anyway; afterwards we can just recover type
     // from the type-checked function to avoid doing unnecessary work.
 
-    Function func = mod->Lookup(mod->entry_func);
+    Function func = mod->Lookup("main");
 
     // FromExpr wraps a naked expression as a function, we will unbox
     // it here.
@@ -784,7 +784,7 @@ Expr InferType(const Expr& expr, const Module& mod_ref) {
       return func->body;
     }
   } else {
-    auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr);
+    auto e = TypeInferencer(mod_ref, mod_ref->GetGlobalVar("main")).Infer(expr);
     CHECK(WellFormed(e));
     auto free_tvars = FreeTypeVars(e, mod_ref);
     CHECK(free_tvars.size() == 0)
index 8257e94db19777c6093adc616f6418cc3fc85837..38a88309ed1efd5c543d72f2b3549c83fb440848 100644 (file)
@@ -35,7 +35,7 @@ TEST(Relay, SelfReference) {
   auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
   auto mod = relay::ModuleNode::FromExpr(fx);
   mod = relay::transform::InferType()(mod);
-  auto type_fx = mod->Lookup(mod->entry_func);
+  auto type_fx = mod->Lookup("main");
 
   auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
   CHECK(AlphaEqual(type_fx->checked_type(), expected));
index a943ba29cc923e24d777ab0292588f346cca5049..0df78fc62f42eb3e416cf7d18ba3032f1e340536 100644 (file)
@@ -84,9 +84,9 @@ TEST(Relay, Sequential) {
   }
 
   CHECK(mod.defined());
-  auto entry_func = mod->entry_func;
+  auto entry_func = mod->GetGlobalVar("main");
   CHECK(entry_func.defined());
-  relay::Function f = mod->Lookup(entry_func->name_hint);
+  relay::Function f = mod->Lookup("main");
   CHECK(f.defined());
 
   // Expected function
@@ -102,7 +102,7 @@ TEST(Relay, Sequential) {
   // Infer type for the expected function.
   auto mod1 = relay::ModuleNode::FromExpr(expected_func);
   mod1 = relay::transform::InferType()(mod1);
-  auto expected = mod1->Lookup(mod1->entry_func);
+  auto expected = mod1->Lookup("main");
   CHECK(relay::AlphaEqual(f, expected));
 }
 
index 98f872ce19b20fa65b7c1da16ec576676d1c60d1..35914ec1f9bfe7361d1e727964a362cca8bd484b 100644 (file)
@@ -20,11 +20,10 @@ from tvm.relay import transform
 from model_zoo import c2_squeezenet, relay_squeezenet
 
 
-def compare_graph(lhs_mod, func):
-    rhs_mod = relay.Module.from_expr(func)
+def compare_graph(lhs_mod, rhs_mod):
+    lhs_mod = transform.InferType()(lhs_mod)
     rhs_mod = transform.InferType()(rhs_mod)
-    assert relay.analysis.alpha_equal(lhs_mod[lhs_mod.entry_func],
-                                      rhs_mod[rhs_mod.entry_func])
+    assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"])
 
 
 def test_squeeze_net():
@@ -32,8 +31,8 @@ def test_squeeze_net():
     dtype_dict = {'data': 'float32'}
     mod, _, = relay.frontend.from_caffe2(
         c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict)
-    relay_func, _ = relay_squeezenet()
-    compare_graph(mod, relay_func)
+    relay_mod, _ = relay_squeezenet()
+    compare_graph(mod, relay_mod)
 
 
 if __name__ == '__main__':
index 13f987c32be78d8f389c3877699eb8e7b98a3cd5..59d4dd6f29b755ecd69c00648aee7cb5a833e39f 100644 (file)
@@ -48,7 +48,7 @@ def run_model_checkonly(model_file, model_name='', input_name='image'):
     shape_dict = {input_name : x.shape}
     mod, params = relay.frontend.from_coreml(model, shape_dict)
     for target, ctx in ctx_list():
-        tvm_output = get_tvm_output(mod[mod.entry_func], x, params, target, ctx)
+        tvm_output = get_tvm_output(mod["main"], x, params, target, ctx)
         print(target, ctx, model_name, 'prediction id: ', np.argmax(tvm_output.flat))
 
 def test_mobilenet_checkonly():
index 37a46f6ce3dc022a1ccb7db033a9e4ae5d0040ef..467d5529d0c14780197964073bf3299d044a7efc 100644 (file)
@@ -19,15 +19,17 @@ from tvm import relay
 from tvm.relay import transform
 import model_zoo
 
-def compare_graph(f1, f2):
-    assert relay.analysis.alpha_equal(f1, f2)
+def compare_graph(lhs_mod, rhs_mod):
+    lhs_mod = transform.InferType()(lhs_mod)
+    rhs_mod = transform.InferType()(rhs_mod)
+    assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"])
 
 def test_mlp():
     shape = {"data": (1, 1, 28, 28)}
     mx_fun = model_zoo.mx_mlp()
     mod, _ = relay.frontend.from_mxnet(mx_fun, shape=shape)
     relay_fun = model_zoo.relay_mlp()
-    compare_graph(mod[mod.entry_func], relay_fun)
+    compare_graph(mod, relay_fun)
 
 
 def test_vgg():
@@ -35,8 +37,8 @@ def test_vgg():
     for n in [11, 13, 16, 19]:
         mx_sym = model_zoo.mx_vgg(n)
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape=shape)
-        relay_sym = model_zoo.relay_vgg(n)
-        compare_graph(mod[mod.entry_func], relay_sym)
+        relay_mod = model_zoo.relay_vgg(n)
+        compare_graph(mod, relay_mod)
 
 
 def test_resnet():
@@ -44,8 +46,8 @@ def test_resnet():
     for n in [18, 34, 50, 101]:
         mx_sym = model_zoo.mx_resnet(n)
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape=shape)
-        relay_sym = model_zoo.relay_resnet(n)
-        compare_graph(mod[mod.entry_func], relay_sym)
+        relay_mod = model_zoo.relay_resnet(n)
+        compare_graph(mod, relay_mod)
 
 
 def test_squeezenet():
@@ -53,32 +55,32 @@ def test_squeezenet():
     for version in ['1.0', '1.1']:
         mx_sym = model_zoo.mx_squeezenet(version)
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
-        relay_sym = model_zoo.relay_squeezenet(version)
-        compare_graph(mod[mod.entry_func], relay_sym)
+        relay_mod = model_zoo.relay_squeezenet(version)
+        compare_graph(mod, relay_mod)
 
 
 def test_inception_v3():
     shape = {"data": (1, 3, 299, 299)}
     mx_sym = model_zoo.mx_inception_v3()
     mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
-    relay_sym = model_zoo.relay_inception_v3()
-    compare_graph(mod[mod.entry_func], relay_sym)
+    relay_mod = model_zoo.relay_inception_v3()
+    compare_graph(mod, relay_mod)
 
 
 def test_dqn():
     shape = {"data": (1, 4, 84, 84)}
     mx_sym = model_zoo.mx_dqn()
     mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
-    relay_sym = model_zoo.relay_dqn()
-    compare_graph(mod[mod.entry_func], relay_sym)
+    relay_mod = model_zoo.relay_dqn()
+    compare_graph(mod, relay_mod)
 
 
 def test_dcgan():
     shape = {"data": (2, 100)}
     mx_sym = model_zoo.mx_dcgan()
     mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
-    relay_sym = model_zoo.relay_dcgan(batch_size=2)
-    compare_graph(mod[mod.entry_func], relay_sym)
+    relay_mod = model_zoo.relay_dcgan(batch_size=2)
+    compare_graph(mod, relay_mod)
 
 
 def test_multi_outputs():
@@ -97,15 +99,13 @@ def test_multi_outputs():
         z = F.split(x, **kwargs)
         z = F.subtract(F.add(z[0], z[2]), y)
         func = relay.Function(relay.analysis.free_vars(z), z)
-        mod = relay.Module.from_expr(func)
-        mod = transform.InferType()(mod)
-        return mod[mod.entry_func]
+        return relay.Module.from_expr(func)
 
     mx_sym = mx_compose(mx, num_outputs=3, axis=1)
     mod, _ = relay.frontend.from_mxnet(
         mx_sym, shape={"x":xshape, "y":yshape})
-    relay_sym = relay_compose(relay, indices_or_sections=3, axis=1)
-    compare_graph(mod[mod.entry_func], relay_sym)
+    relay_mod = relay_compose(relay, indices_or_sections=3, axis=1)
+    compare_graph(mod, relay_mod)
 
 
 if __name__ == "__main__":
index d59fe1830a1840ed9eafcdc46f715cd99524eba8..b7d21912e44ea65b77ce6db3e8fcdecde2af90cf 100644 (file)
@@ -77,7 +77,7 @@ def test_alter_layout_conv2d():
             with autotvm.tophub.context(target):
                 mod = relay.Module.from_expr(N)
                 mod = transform.AlterOpLayout()(mod)
-                O = mod[mod.entry_func]
+                O = mod["main"]
 
                 # graph should differ
                 assert not relay.analysis.alpha_equal(N, O)
index e359ade864e218853523aa672090d03ca0a5487b..26301e90f789d5d9bddc9ebe3296ea92be39af42 100644 (file)
@@ -23,15 +23,15 @@ from tvm import relay
 from tvm.relay import testing
 
 
-def benchmark_execution(net,
+def benchmark_execution(mod,
                         params,
                         measure=False,
                         data_shape=(1, 3, 224, 224),
                         out_shape=(1, 1000),
                         dtype='float32'):
-    def get_tvm_output(net, data, params, target, ctx, dtype='float32'):
+    def get_tvm_output(mod, data, params, target, ctx, dtype='float32'):
         with relay.build_config(opt_level=1):
-            graph, lib, params = relay.build(net, target, params=params)
+            graph, lib, params = relay.build(mod, target, params=params)
 
         m = graph_runtime.create(graph, lib, ctx)
         # set inputs
@@ -50,9 +50,9 @@ def benchmark_execution(net,
 
         return out.asnumpy()
 
-    def get_tvm_vm_output(net, data, params, target, ctx, dtype='float32'):
-        ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
-        result = ex.evaluate(net)(data, **params)
+    def get_tvm_vm_output(mod, data, params, target, ctx, dtype='float32'):
+        ex = relay.create_executor('vm', mod=mod, ctx=ctx)
+        result = ex.evaluate()(data, **params)
         return result.asnumpy().astype(dtype)
 
     # random input
@@ -60,64 +60,64 @@ def benchmark_execution(net,
     target = "llvm"
     ctx = tvm.cpu(0)
 
-    tvm_out = get_tvm_output(net, tvm.nd.array(data.astype(dtype)), params,
+    tvm_out = get_tvm_output(mod, tvm.nd.array(data.astype(dtype)), params,
                              target, ctx, dtype)
-    vm_out = get_tvm_vm_output(net, tvm.nd.array(data.astype(dtype)), params,
+    vm_out = get_tvm_vm_output(mod, tvm.nd.array(data.astype(dtype)), params,
                                target, ctx, dtype)
     tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 def test_mlp():
-    image_shape = (1, 28, 28)
-    net, params = testing.mlp.get_workload(1)
-    benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 10))
+    image_shape = (1, 1, 28, 28)
+    mod, params = testing.mlp.get_workload(1)
+    benchmark_execution(mod, params, data_shape=image_shape, out_shape=(1, 10))
 
 
 def test_vgg():
     for n in [11, 16]:
-        net, params = testing.vgg.get_workload(1, num_layers=n)
-        benchmark_execution(net, params)
+        mod, params = testing.vgg.get_workload(1, num_layers=n)
+        benchmark_execution(mod, params)
 
 
 def test_resnet():
     for n in [18, 50]:
-        net, params = testing.resnet.get_workload(batch_size=1, num_layers=n)
-        benchmark_execution(net, params, True)
+        mod, params = testing.resnet.get_workload(batch_size=1, num_layers=n)
+        benchmark_execution(mod, params, True)
 
 
 def test_squeezenet():
     for version in ['1.0', '1.1']:
-        net, params = testing.squeezenet.get_workload(version=version)
-        benchmark_execution(net, params)
+        mod, params = testing.squeezenet.get_workload(version=version)
+        benchmark_execution(mod, params)
 
 
 def test_inception_v3():
-    image_shape = (3, 299, 299)
-    net, params = testing.inception_v3.get_workload(image_shape=image_shape)
-    benchmark_execution(net, params, data_shape=image_shape)
+    image_shape = (1, 3, 299, 299)
+    mod, params = testing.inception_v3.get_workload(image_shape=image_shape)
+    benchmark_execution(mod, params, data_shape=image_shape)
 
 
 def test_dqn():
-    image_shape = (4, 84, 84)
-    net, params = testing.dqn.get_workload(
+    image_shape = (1, 4, 84, 84)
+    mod, params = testing.dqn.get_workload(
         batch_size=1, image_shape=image_shape)
-    benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 18))
+    benchmark_execution(mod, params, data_shape=image_shape, out_shape=(1, 18))
 
 
 def test_dcgan():
     image_shape = (1, 100)
-    net, params = testing.dcgan.get_workload(batch_size=1)
-    benchmark_execution(net, params, data_shape=image_shape)
+    mod, params = testing.dcgan.get_workload(batch_size=1)
+    benchmark_execution(mod, params, data_shape=image_shape)
 
 
 def test_mobilenet():
-    net, params = testing.mobilenet.get_workload(batch_size=1)
-    benchmark_execution(net, params)
+    mod, params = testing.mobilenet.get_workload(batch_size=1)
+    benchmark_execution(mod, params)
 
 
 def test_densenet():
-    net, params = testing.densenet.get_workload(batch_size=1)
-    benchmark_execution(net, params)
+    mod, params = testing.densenet.get_workload(batch_size=1)
+    benchmark_execution(mod, params)
 
 
 if __name__ == '__main__':
index 7374ab94dde408f6777eceb61061769281ef8454..0bef382cb5d0bc9e1227c3e5a976dc78d4117e56 100644 (file)
@@ -24,48 +24,48 @@ def get_network(name, batch_size):
     input_shape = (batch_size, 3, 224, 224)
 
     if name == 'resnet-18':
-        net, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
+        mod, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
     elif name == 'mobilenet':
-        net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size)
+        mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size)
     elif name == 'dcgan':
-        net, params = relay.testing.dcgan.get_workload(batch_size=batch_size)
+        mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size)
         input_shape = (batch_size, 100)
     else:
         raise ValueError("Unsupported network: " + name)
 
-    return net, params, input_shape
+    return mod, params, input_shape
 
 def test_task_extraction():
     target = 'llvm'
 
-    net, params, input_shape = get_network('resnet-18', batch_size=1)
-    tasks = autotvm.task.extract_from_program(net, target=target,
+    mod, params, input_shape = get_network('resnet-18', batch_size=1)
+    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                               params=params,
                                               ops=(relay.op.nn.conv2d,))
     assert len(tasks) == 12
 
-    net, params, input_shape = get_network('resnet-18', batch_size=1)
-    tasks = autotvm.task.extract_from_program(net, target=target,
-                                            params=params,
-                                            ops=(relay.op.nn.dense,))
+    mod, params, input_shape = get_network('resnet-18', batch_size=1)
+    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.dense,))
     assert len(tasks) == 1
 
-    net, params, input_shape = get_network('resnet-18', batch_size=1)
-    tasks = autotvm.task.extract_from_program(net, target=target,
-                                            params=params,
-                                            ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+    mod, params, input_shape = get_network('resnet-18', batch_size=1)
+    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
     assert len(tasks) == 13
 
-    net, params, input_shape = get_network('mobilenet', batch_size=1)
-    tasks = autotvm.task.extract_from_program(net, target=target,
-                                            params=params,
-                                            ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+    mod, params, input_shape = get_network('mobilenet', batch_size=1)
+    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
     assert len(tasks) == 20
 
-    net, params, input_shape = get_network('dcgan', batch_size=1)
-    tasks = autotvm.task.extract_from_program(net, target=target,
-                                            params=params,
-                                            ops=(relay.op.nn.conv2d_transpose,))
+    mod, params, input_shape = get_network('dcgan', batch_size=1)
+    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d_transpose,))
     assert len(tasks) == 4
 
 if __name__ == '__main__':
index 479c4169a95906f336733481b6af344aa4429a99..ea16a8d6122ebab0f6c8046e4a9c3949f925bcd7 100644 (file)
@@ -29,7 +29,7 @@ def test_compile_engine():
         f = relay.Function([x], z)
         mod = relay.Module.from_expr(f)
         mod = relay.transform.InferType()(mod)
-        return mod[mod.entry_func]
+        return mod["main"]
     z1 = engine.lower(get_func((10,)), "llvm")
     z2 = engine.lower(get_func((10,)), "llvm")
     z3 = engine.lower(get_func(()), "llvm")
index 742e3b4daa9f147bee0d1168f5d8ff0c1b33b971..fbccb94bc67088d2e1ed8dc90bf2ca611c018a00 100644 (file)
@@ -125,7 +125,7 @@ def test_plan_memory():
     func = relay.Function([x, y], z)
     mod = relay.Module.from_expr(func)
     mod = relay.transform.FuseOps(0)(mod)
-    func = mod[mod.entry_func]
+    func = mod["main"]
     smap = relay.backend._backend.GraphPlanMemory(func)
     storage_ids = set()
     device_types = set()
index 3c79fb7605210e8c40c26735ad0dade7c7c7d841..0e5e981a53216a9eb1f25a5a051a5634b999b323 100644 (file)
@@ -224,9 +224,8 @@ def test_tuple_passing():
 
     fn = relay.Function([x], relay.expr.TupleGetItem(x, 0))
     mod = relay.Module({})
-    gv = relay.GlobalVar('fn')
+    gv = relay.GlobalVar('main')
     mod[gv] = fn
-    mod.entry_func = gv
     mod = relay.transform.InferType()(mod)
 
     ctx = tvm.cpu()
index aad4856fa9431bc2fd023fd5ef4a51fd3442423c..c446f361101d84fd57575a1d78c9796c0d7a0b84 100644 (file)
@@ -21,7 +21,7 @@ def check_type_err(expr, msg):
     try:
         mod = relay.Module.from_expr(expr)
         mod = relay.transform.InferType()(mod)
-        entry = mod[mod.entry_func]
+        entry = mod["main"]
         expr = entry if isinstance(expr, relay.Function) else entry.body
         assert False
     except tvm.TVMError as err:
index 9b5010286d4f665e12db090f7172f7fb726987e6..2e0cd374b0251dcec0e827021a2f4785496e81a6 100644 (file)
@@ -49,7 +49,7 @@ def test_ad():
     func = relay.Function([x], x + x)
     mod = relay.Module.from_expr(gradient(func))
     mod = relay.transform.InferType()(mod)
-    back_func = mod[mod.entry_func]
+    back_func = mod["main"]
     feats = detect_feature(back_func)
     assert feats == set([
         Feature.fVar,
index 7da623a45ce6e06e7fd66bc1443f1aaea6db6b45..3dcba4778f5fa2dd1e809a8cd1b4f44bb9fdf14f 100644 (file)
@@ -24,7 +24,7 @@ from tvm.relay.testing import ctx_list
 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
     mod = relay.transform.InferType()(mod)
-    return mod[mod.entry_func]
+    return mod["main"]
 
 
 def sigmoid(x):
index 8baec8c79e9abd7826bb31a69d6d078c56523aaa..b5abafadf49ebe59bcab0ea7bac868b6cadd2be2 100644 (file)
@@ -24,7 +24,7 @@ import topi.testing
 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
     mod = transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 def sigmoid(x):
index bcf6b7f80abd5ab597f4f31a478e7910ffd7ad6e..046da8de5fe89ceb3de2431642da90a57d75f0d1 100644 (file)
@@ -28,7 +28,7 @@ import topi.testing
 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
     mod = transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 def test_collapse_sum_like():
index 722e8d178fab7e92624ab52c322906e6de23cb10..9f49f61c0d5f258147c268696adca3a82db24cd6 100644 (file)
@@ -26,7 +26,7 @@ import topi.testing
 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
     mod = transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 def test_conv2d_infer_type():
index 575996fbe61eea8ae667cd47ad3bb9bd16c6174c..e1a760421349ac270fb7b40ef5ea869728513f18 100644 (file)
@@ -26,7 +26,7 @@ from tvm.relay.testing import ctx_list
 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
     mod = transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 def test_zeros_ones():
index 9bab5d87389a3eddea9d7f5f8198048bb0d1cc63..69fd88b562b7b426b23fb70dd8250950ebd0b4d8 100644 (file)
@@ -24,7 +24,7 @@ import topi.testing
 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
     mod = transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 def test_binary_op():
index cd008e3d19a3aa85abfedcde39f139e46f908dd1..328e4d5b0da34cc1c17c7a723f20cf91cb167511 100644 (file)
@@ -27,7 +27,7 @@ import topi.testing
 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
     mod = transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 def test_resize_infer_type():
index 65fd0b0819ccb7b795046478e8ff0a17de1d3624..6b31eed8f166d4ed4c4c0750f0101676b66a5def 100644 (file)
@@ -28,7 +28,7 @@ def run_opt_pass(expr, passes):
     seq = transform.Sequential(passes)
     with transform.PassContext(opt_level=3):
         mod = seq(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
index 86ebf73d3dd6905137a3d92f41309ef762ccf891..14d53a0e2c2c5db8340dd2e79aeafec19fb7628a 100644 (file)
@@ -31,7 +31,7 @@ def run_opt_pass(expr, passes):
     seq = transform.Sequential(passes)
     with transform.PassContext(opt_level=3):
         mod = seq(mod)
-    return mod[mod.entry_func]
+    return mod["main"]
 
 
 def test_redundant_annotation():
index c7b88a8dc9e3f9674031741c493228e230ab667f..b72ded21ef5219316e901444d5f85905d0896f3a 100644 (file)
@@ -58,7 +58,7 @@ def test_canonicalize_cast():
                                      _transform.InferType()])
         with _transform.PassContext(opt_level=3):
             mod = seq(mod)
-        y = mod[mod.entry_func.name_hint]
+        y = mod["main"]
         y_expected = expected(data, conv_weight, bias1, bias2)
         gv = relay.GlobalVar("expected")
         mod[gv] = y_expected
index 4ea11f42f40dc60047c82457b11b70aa4199a082..599b308b213676a1a9208c03e777fb8d3b40c7eb 100644 (file)
@@ -21,13 +21,13 @@ from tvm.relay import transform
 def run_combine_parallel(expr, min_num_branches=3):
     mod = relay.Module.from_expr(expr)
     mod = transform.CombineParallelConv2D(min_num_branches)(mod)
-    return mod[mod.entry_func]
+    return mod["main"]
 
 def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, transform.Pass)
     mod = relay.Module.from_expr(expr)
     mod = opt_pass(mod)
-    return mod[mod.entry_func]
+    return mod["main"]
 
 
 def test_combine_parallel_conv2d():
index 17a836beecd57de90b806f6c5f081b49719183e2..f3515800e449f5c0f2a8472d14bcd5889c3f31f6 100644 (file)
@@ -49,7 +49,7 @@ def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, transform.Pass)
     mod = relay.Module.from_expr(expr)
     mod = opt_pass(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
index f08d0dfd1f2620eb865754a3f1619bf978cd3bb4..09ea7044daf5f4ca14ba8c11586256d28ef7f6bb 100644 (file)
@@ -24,7 +24,7 @@ def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, transform.Pass)
     mod = relay.Module.from_expr(expr)
     mod = opt_pass(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
index 5308e472129a62bb87e8aa467d7acdfc10889275..73c3a4eb4073c76dec4de8f45c539be0a06b605d 100644 (file)
@@ -26,7 +26,7 @@ def test_eta_expand_basic():
     with _transform.PassContext(opt_level=3):
         mod = seq(mod)
 
-    got = mod[mod.entry_func.name_hint]
+    got = mod["main"]
 
     y = relay.var('y', 'int32')
     expected = relay.Function([y], orig(y))
index 881ec8f912c9c2de5a460446309ba30fc319cc1c..97b20c6b9219a91ec21f2076a681a84b1f5fd1a5 100644 (file)
@@ -25,7 +25,7 @@ def run_opt_pass(expr, opt_pass):
 
     mod = relay.Module.from_expr(expr)
     mod = opt_pass(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
index 70354fbdaa3b81dec888233ca7877ea1070d1358..d6f471bef04a2377106823fa40abc69ee71b944c 100644 (file)
@@ -27,7 +27,7 @@ def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, transform.Pass)
     mod = relay.Module.from_expr(expr)
     mod = opt_pass(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
index b2f7b9340ad42252441fa874e6aa4ce00e9067ef..8bcde88ba1cd824e6c1c8d1f62216f6bd6efef14 100644 (file)
@@ -357,7 +357,7 @@ def test_tuple_intermediate():
     m = fuse2(relay.Module.from_expr(orig))
     relay.build(m, 'llvm')
     after = run_opt_pass(expected(x), transform.InferType())
-    assert relay.analysis.alpha_equal(m[m.entry_func], after)
+    assert relay.analysis.alpha_equal(m["main"], after)
 
 
 def test_tuple_consecutive():
@@ -412,7 +412,7 @@ def test_tuple_consecutive():
     m = fuse2(relay.Module.from_expr(orig))
     relay.build(m, 'llvm')
     after = run_opt_pass(expected(dshape), transform.InferType())
-    assert relay.analysis.alpha_equal(m[m.entry_func], after)
+    assert relay.analysis.alpha_equal(m["main"], after)
 
 
 def test_inception_like():
@@ -479,7 +479,7 @@ def test_inception_like():
     m = fuse2(relay.Module.from_expr(orig))
     relay.build(m, 'llvm')
     after = run_opt_pass(expected(dshape), transform.InferType())
-    assert relay.analysis.alpha_equal(m[m.entry_func], after)
+    assert relay.analysis.alpha_equal(m["main"], after)
 
 
 def test_fuse_parallel_injective():
index 555e418644bcff2abb36535823589eb07ebc2e57..3fc1d74de8768e8024652718d0c628552c8293fd 100644 (file)
@@ -185,9 +185,9 @@ def test_pow():
     i = relay.var("i", t)
     func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
     func = gradient(func, mod=mod)
-    mod[mod.entry_func] = func
+    mod["main"] = func
     m = transform.InferType()(mod)
-    back_func = m[m.entry_func]
+    back_func = m["main"]
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
     i_nd = rand(dtype, *shape)
     ex = create_executor(mod=mod)
index e68c748d1bb183ab5c18d7b8b211d7c1f88f7bde..0ad1e3abe7595b781fab15fd8a8eef00d4b26e0c 100644 (file)
@@ -25,7 +25,7 @@ def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, transform.Pass)
     mod = relay.Module.from_expr(expr)
     mod = opt_pass(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
index 930dbe0451983e4898e67d39516fd9cfb303ccd6..22e9c76b4acafe74100cd830e47b5bbc8bb41003 100644 (file)
@@ -29,7 +29,7 @@ from tvm.relay.testing import ctx_list
 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
     mod = _transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
index 8855b089ba8c04819e3ee95d980c0d674aa31628..452399a2a5eaa2cb68b6a918371aae65924c96aa 100644 (file)
@@ -41,7 +41,7 @@ def run_opt_pass(expr, passes):
     seq = transform.Sequential(passes)
     with transform.PassContext(opt_level=3):
        mod = seq(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
@@ -57,10 +57,10 @@ def dcpe(expr, mod=None, grad=False):
         expr = gradient(expr)
     if mod:
         assert isinstance(expr, Function)
-        mod[mod.entry_func] = expr
+        mod["main"] = expr
         seq = transform.Sequential(passes)
         mod = seq(mod)
-        return mod[mod.entry_func]
+        return mod["main"]
     return run_opt_pass(expr, passes)
 
 
@@ -192,8 +192,8 @@ def test_map():
     orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil()))))
     expected = p.cons((const(1)), p.cons((const(2)), p.cons((const(3)), p.nil())))
     expected = Function([], expected)
-    mod[mod.entry_func] = expected
-    expected = mod[mod.entry_func]
+    mod["main"] = expected
+    expected = mod["main"]
     orig = Function([], orig)
     res = dcpe(orig, mod=mod)
     assert alpha_equal(res.body, expected.body)
@@ -206,8 +206,8 @@ def test_loop():
     loop = GlobalVar("loop")
     mod[loop] = Function([x], loop(x), t, [t])
     expected = Call(loop, [const(1)])
-    mod[mod.entry_func] = Function([], expected)
-    expected = mod[mod.entry_func].body
+    mod["main"] = Function([], expected)
+    expected = mod["main"].body
     call = Function([], loop(const(1)))
     res = dcpe(call, mod=mod)
     assert alpha_equal(res.body, expected)
index 21aa02df7f3a4fe2b37f662964d80dc8b8debaa8..f6f67d6b6ac9328651ea9b38e93868285f2f06ed 100644 (file)
@@ -25,7 +25,7 @@ from tvm.relay import transform
 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
     mod = transform.InferType()(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
index 51b8793f7667dc4da0c81fc34f60a262e95fe449..d5d44fdaec7290cd140242f62cc15505809236d3 100644 (file)
@@ -30,7 +30,7 @@ def run_opt_pass(expr, passes):
     seq = transform.Sequential(passes)
     with transform.PassContext(opt_level=3):
        mod = seq(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
@@ -195,7 +195,7 @@ def test_gradient_if():
     net = relay.Function([cond,x,y], net)
     mod = relay.Module.from_expr(net)
     mod = relay.transform.ToANormalForm()(mod)
-    mod[mod.entry_func] = relay.transform.gradient(mod[mod.entry_func], mode='higher_order')
+    mod["main"] = relay.transform.gradient(mod["main"], mode='higher_order')
     mod = relay.transform.ToANormalForm()(mod)
 
 
index 128fc49b58ca9507a9d1c030e31943ad9a69d7d1..2b6f2ef6b8587122dcd8f6ae23aba76e3c799bdc 100644 (file)
@@ -42,12 +42,12 @@ def test_recursion():
     double = relay.Function([x], x + x)
     i = relay.var("i", t)
     func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
-    mod[mod.entry_func] = func
-    mod[mod.entry_func] = to_cps(mod[mod.entry_func], mod=mod)
-    mod[mod.entry_func] = un_cps(mod[mod.entry_func])
+    mod["main"] = func
+    mod["main"] = to_cps(mod["main"], mod=mod)
+    mod["main"] = un_cps(mod["main"])
     ex = create_executor(mod=mod)
     i_nd = rand(dtype, *shape)
-    forward = ex.evaluate(mod.entry_func)(i_nd)
+    forward = ex.evaluate()(i_nd)
     tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
 
 
index 9e8c5887ac582170b566d9b78ea6712f5dabc863..a29172471d4843aee8a7e6c7e471a21f4e2deb8f 100644 (file)
@@ -24,7 +24,7 @@ from tvm.relay.analysis import detect_feature
 def run_opt_pass(expr, opt_pass):
     mod = relay.Module.from_expr(expr)
     mod = opt_pass(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
index 29b79283a1fc7114e23dce74c425493280b8ecc0..eae05ec7d7e0e342c27c6263423987074a8c28e7 100644 (file)
@@ -25,7 +25,7 @@ def run_infer_type(expr, mod=None):
     if not mod:
         mod = relay.Module.from_expr(expr)
         mod = transform.InferType()(mod)
-        entry = mod[mod.entry_func]
+        entry = mod["main"]
         return entry if isinstance(expr, relay.Function) else entry.body
     else:
         if isinstance(expr, relay.GlobalVar):
@@ -34,7 +34,7 @@ def run_infer_type(expr, mod=None):
             func = expr
             if not isinstance(expr, relay.Function):
                 func = relay.Function(analysis.free_vars(expr), expr)
-            mod[mod.entry_func] = func
+            mod["main"] = func
             gv = "main"
         mod = transform.InferType()(mod)
 
@@ -266,7 +266,7 @@ def test_type_args():
 
 def test_global_var_recursion():
     mod = relay.Module({})
-    gv = relay.GlobalVar("foo")
+    gv = relay.GlobalVar("main")
     x = relay.var('x', shape=[])
     tt = relay.scalar_type('float32')
 
index 963f2ac46846534cbd3331ab471477be5c2692bb..b500a937a6eea827114550f5687673713f4261b3 100644 (file)
@@ -25,7 +25,7 @@ def test_dup_type():
     b = relay.Var("b", t)
     mod = relay.Module.from_expr(make_id(b))
     mod = transform.InferType()(mod)
-    inferred = mod[mod.entry_func].body
+    inferred = mod["main"].body
     assert inferred.checked_type == relay.TupleType([t, t])
 
 
@@ -39,9 +39,9 @@ def test_id_type():
     make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b]))
     t = relay.scalar_type("float32")
     b = relay.Var("b", t)
-    mod[mod.entry_func] = relay.Function([], make_id(b))
+    mod["main"] = relay.Function([], make_id(b))
     mod = transform.InferType()(mod)
-    assert mod[mod.entry_func].body.checked_type == id_type(t)
+    assert mod["main"].body.checked_type == id_type(t)
 
 
 if __name__ == "__main__":
index 302dc553cdb0d1ab3dd273be209229a08cb5831d..f85d21255736f3387c93b698d566a5293cf982e1 100644 (file)
@@ -121,7 +121,7 @@ def test_simple_call():
     mod[sum_up] = func
     i_data = np.array(0, dtype='int32')
     iarg = relay.var('i', shape=[], dtype='int32')
-    mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg))
+    mod["main"] = relay.Function([iarg], sum_up(iarg))
     result = veval(mod, i_data)
     tvm.testing.assert_allclose(result.asnumpy(), i_data)
 
@@ -140,7 +140,7 @@ def test_count_loop():
     mod[sum_up] = func
     i_data = np.array(0, dtype='int32')
     iarg = relay.var('i', shape=[], dtype='int32')
-    mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg))
+    mod["main"] = relay.Function([iarg], sum_up(iarg))
     result = veval(mod, i_data)
     tvm.testing.assert_allclose(result.asnumpy(), i_data)
 
@@ -163,7 +163,7 @@ def test_sum_loop():
     accum_data = np.array(0, dtype='int32')
     iarg = relay.var('i', shape=[], dtype='int32')
     aarg = relay.var('accum', shape=[], dtype='int32')
-    mod[mod.entry_func] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
+    mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
     result = veval(mod, i_data, accum_data)
     tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
 
@@ -212,7 +212,7 @@ def test_list_constructor():
     one4 = cons(relay.const(3), one3)
     f = relay.Function([], one4)
 
-    mod[mod.entry_func] = f
+    mod["main"] = f
 
     result = veval(mod)()
     obj = to_list(result)
@@ -284,7 +284,7 @@ def test_compose():
     mod[add_one] = add_one_func
 
     f = relay.Function([y], add_two_body)
-    mod[mod.entry_func] = f
+    mod["main"] = f
 
     x_data = np.array(np.random.rand()).astype('float32')
     result = veval(mod)(x_data)
index 1c3171944bc9dc447ea8531115115926f846afe2..ee37b12e74d030434ff1d2320dea0cfcbbbb4dc9 100644 (file)
@@ -44,8 +44,8 @@ def _create_data(target, dshape, dtype, layout):
     conv2 = relay.nn.conv2d(conv1, w2, channels=32, kernel_size=(3, 3), padding=(1, 1))
     out = relay.add(conv1, conv2)
     net = relay.Function(relay.analysis.free_vars(out), out)
-    net, params = relay.testing.create_workload(net)
-    tasks = autotvm.task.extract_from_program(net,
+    mod, params = relay.testing.create_workload(net)
+    tasks = autotvm.task.extract_from_program(mod["main"],
                                               target=target,
                                               params=params,
                                               ops=(relay.op.nn.conv2d,))
@@ -160,7 +160,7 @@ def test_DPTuner_run():
 
     g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
     mod = relay.module.Module()
-    mod[mod.entry_func] = g
+    mod["main"] = g
     costs = [0.02, 0.02, 0.045]
     config_list = []
     cfg_dict = {"i": -1,
index 5bbd1c4860c26ff7c386cb161a4e61a6836c039f..c66854ac65a9de26db8333b9f182275dc1777337 100644 (file)
@@ -64,7 +64,7 @@ def test_has_multiple_inputs():
 
 
 def test_expr2graph():
-    net, _ = resnet.get_workload(num_layers=50, batch_size=1)
+    mod, _ = resnet.get_workload(num_layers=50, batch_size=1)
     node_dict = {}
     node_list = []
     target_ops = ["conv2d"]
@@ -80,9 +80,9 @@ def test_expr2graph():
             op_name_list.append("Tuple")
         else:
             op_name_list.append("null")
-    relay.analysis.post_order_visit(net, _count_node)
+    relay.analysis.post_order_visit(mod["main"], _count_node)
 
-    expr2graph(net, target_ops, node_dict, node_list)
+    expr2graph(mod["main"], target_ops, node_dict, node_list)
     for i, item in enumerate(zip(op_name_list, node_list)):
         op_name, node = item
         assert op_name == node["op"], "%dth Node operator mismatch: expecting %s but got %s" \
index 290f9756f1955a85c8b0d64ff0ed659caef6618c..b671f214022613b8ac3ae6a892fd537a3084568e 100644 (file)
@@ -81,28 +81,29 @@ def get_network(name, batch_size):
 
     if "resnet" in name:
         n_layer = int(name.split('-')[1])
-        net, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
     elif "vgg" in name:
         n_layer = int(name.split('-')[1])
-        net, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
     elif name == 'mobilenet':
-        net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size)
+        mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size)
     elif name == 'squeezenet_v1.1':
-        net, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
+        mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
     elif name == 'inception_v3':
         input_shape = (1, 3, 299, 299)
-        net, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
     elif name == 'mxnet':
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
         block = get_model('resnet18_v1', pretrained=True)
         mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
-        net = mod[mod.entry_func]
+        net = mod["main"]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
+        mod = relay.Module.from_expr(net)
     else:
         raise ValueError("Unsupported network: " + name)
 
-    return net, params, input_shape, output_shape
+    return mod, params, input_shape, output_shape
 
 
 #################################################################
@@ -316,10 +317,10 @@ def tune_tasks(tasks,
 def tune_and_evaluate(tuning_opt):
     # extract workloads from relay program
     print("Extract tasks...")
-    net, params, input_shape, _ = get_network(network, batch_size=1)
-    tasks = autotvm.task.extract_from_program(net, target=target,
-                                            params=params,
-                                            ops=(relay.op.nn.conv2d,))
+    mod, params, input_shape, _ = get_network(network, batch_size=1)
+    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d,))
 
     # run tuning tasks
     print("Tuning...")
@@ -330,7 +331,7 @@ def tune_and_evaluate(tuning_opt):
         print("Compile...")
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build_module.build(
-                net, target=target, params=params)
+                mod, target=target, params=params)
 
         # export library
         tmp = tempdir()
index c158e4b9fe3634f50d9dc9b9b67ad1f6a2459e22..5044bd13a7d0ad57b58d7ba8db66452b1e362939 100644 (file)
@@ -81,28 +81,29 @@ def get_network(name, batch_size):
 
     if "resnet" in name:
         n_layer = int(name.split('-')[1])
-        net, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
     elif "vgg" in name:
         n_layer = int(name.split('-')[1])
-        net, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
     elif name == 'mobilenet':
-        net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
     elif name == 'squeezenet_v1.1':
-        net, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
+        mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
     elif name == 'inception_v3':
         input_shape = (1, 3, 299, 299)
-        net, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
     elif name == 'mxnet':
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
         block = get_model('resnet18_v1', pretrained=True)
         mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
-        net = mod[mod.entry_func]
+        net = mod["main"]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
+        mod = relay.Module.from_expr(net)
     else:
         raise ValueError("Unsupported network: " + name)
 
-    return net, params, input_shape, output_shape
+    return mod, params, input_shape, output_shape
 
 ###########################################
 # Set Tuning Options
@@ -218,9 +219,9 @@ def tune_tasks(tasks,
 def tune_and_evaluate(tuning_opt):
     # extract workloads from relay program
     print("Extract tasks...")
-    net, params, input_shape, out_shape = get_network(network, batch_size=1)
-    tasks = autotvm.task.extract_from_program(net, target=target,
-                                            params=params, ops=(relay.op.nn.conv2d,))
+    mod, params, input_shape, out_shape = get_network(network, batch_size=1)
+    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
+                                              params=params, ops=(relay.op.nn.conv2d,))
 
     # run tuning tasks
     print("Tuning...")
@@ -231,7 +232,7 @@ def tune_and_evaluate(tuning_opt):
         print("Compile...")
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build_module.build(
-                net, target=target, params=params)
+                mod, target=target, params=params)
 
         # export library
         tmp = tempdir()
index c011268fda512d858d5ecdeb5f8686df39a9c705..94a86248c9353ea3468952c6f9a732a65e4d04c2 100644 (file)
@@ -82,28 +82,29 @@ def get_network(name, batch_size):
 
     if "resnet" in name:
         n_layer = int(name.split('-')[1])
-        net, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
     elif "vgg" in name:
         n_layer = int(name.split('-')[1])
-        net, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
     elif name == 'mobilenet':
-        net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
     elif name == 'squeezenet_v1.1':
-        net, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
+        mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
     elif name == 'inception_v3':
         input_shape = (1, 3, 299, 299)
-        net, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
     elif name == 'mxnet':
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
         block = get_model('resnet18_v1', pretrained=True)
         mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
-        net = mod[mod.entry_func]
+        net = mod["main"]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
+        mod = relay.Module.from_expr(net)
     else:
         raise ValueError("Unsupported network: " + name)
 
-    return net, params, input_shape, output_shape
+    return mod, params, input_shape, output_shape
 
 
 #################################################################
@@ -300,8 +301,10 @@ def tune_tasks(tasks,
 def tune_and_evaluate(tuning_opt):
     # extract workloads from relay program
     print("Extract tasks...")
-    net, params, input_shape, _ = get_network(network, batch_size=1)
-    tasks = autotvm.task.extract_from_program(net, target=target, target_host=target_host,
+    mod, params, input_shape, _ = get_network(network, batch_size=1)
+    tasks = autotvm.task.extract_from_program(mod["main"],
+                                              target=target,
+                                              target_host=target_host,
                                               params=params, ops=(relay.op.nn.conv2d,))
 
     # run tuning tasks
@@ -313,7 +316,7 @@ def tune_and_evaluate(tuning_opt):
         print("Compile...")
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build_module.build(
-                net, target=target, params=params, target_host=target_host)
+                mod, target=target, params=params, target_host=target_host)
         # export library
         tmp = tempdir()
         if use_android:
index c8d9def206fe363e850e7d60ec5f34576db035e5..b53b3c12178b5565b5e3760e27674b27064840a6 100644 (file)
@@ -49,28 +49,29 @@ def get_network(name, batch_size):
 
     if "resnet" in name:
         n_layer = int(name.split('-')[1])
-        net, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
     elif "vgg" in name:
         n_layer = int(name.split('-')[1])
-        net, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
     elif name == 'mobilenet':
-        net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
     elif name == 'squeezenet_v1.1':
-        net, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
+        mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
     elif name == 'inception_v3':
         input_shape = (1, 3, 299, 299)
-        net, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
+        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
     elif name == 'mxnet':
         # an example for mxnet model
         from mxnet.gluon.model_zoo.vision import get_model
         block = get_model('resnet18_v1', pretrained=True)
         mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
-        net = mod[mod.entry_func]
+        net = mod["main"]
         net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
+        mod = relay.Module.from_expr(net)
     else:
         raise ValueError("Unsupported network: " + name)
 
-    return net, params, input_shape, output_shape
+    return mod, params, input_shape, output_shape
 
 # Replace "llvm" with the correct target of your CPU.
 # For example, for AWS EC2 c5 instance with Intel Xeon
@@ -177,21 +178,21 @@ def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
 def tune_and_evaluate(tuning_opt):
     # extract workloads from relay program
     print("Extract tasks...")
-    net, params, data_shape, out_shape = get_network(model_name, batch_size)
-    tasks = autotvm.task.extract_from_program(net, target=target,
+    mod, params, data_shape, out_shape = get_network(model_name, batch_size)
+    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                               params=params, ops=(relay.op.nn.conv2d,))
 
     # run tuning tasks
     print("Tuning...")
     tune_kernels(tasks, **tuning_opt)
-    tune_graph(net, data_shape, log_file, graph_opt_sch_file)
+    tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file)
 
     # compile kernels with graph-level best records
     with autotvm.apply_graph_best(graph_opt_sch_file):
         print("Compile...")
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build_module.build(
-                net, target=target,  params=params)
+                mod, target=target, params=params)
 
         # upload parameters to device
         ctx = tvm.cpu()
index 78377849c10b2d42c054cda7183a751d5b230fc8..d19805bdc2fbfabbf24f227cf9d75faa37cb1170 100644 (file)
@@ -142,7 +142,7 @@ with open(synset_path) as f:
 shape_dict = {'data': x.shape}
 mod, params = relay.frontend.from_mxnet(block, shape_dict)
 # we want a probability so add a softmax operator
-func = mod[mod.entry_func]
+func = mod["main"]
 func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)
 
 ######################################################################
index 1109fd9b7d1cfae27e92605edd3b419b98ab10b1..d0e4c4ab0d1805c210914b978084c0ab890db1ff 100644 (file)
@@ -84,7 +84,7 @@ print('x', x.shape)
 shape_dict = {'data': x.shape}
 mod, params = relay.frontend.from_mxnet(block, shape_dict)
 ## we want a probability so add a softmax operator
-func = mod[mod.entry_func]
+func = mod["main"]
 func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)
 
 ######################################################################
index b21f4fc5571c0f6941d228e489e7c8ec72c8540e..26157f07d83e6b61304de8e59887ea46ff791eee 100644 (file)
@@ -65,11 +65,11 @@ image_shape = (3, 224, 224)
 data_shape = (batch_size,) + image_shape
 out_shape = (batch_size, num_class)
 
-net, params = relay.testing.resnet.get_workload(
+mod, params = relay.testing.resnet.get_workload(
     num_layers=18, batch_size=batch_size, image_shape=image_shape)
 
 # set show_meta_data=True if you want to show meta data
-print(net.astext(show_meta_data=False))
+print(mod.astext(show_meta_data=False))
 
 ######################################################################
 # Compilation
@@ -98,7 +98,7 @@ opt_level = 3
 target = tvm.target.cuda()
 with relay.build_config(opt_level=opt_level):
     graph, lib, params = relay.build_module.build(
-        net, target, params=params)
+        mod, target, params=params)
 
 #####################################################################
 # Run the generate library
index f7d7be8c8047757cea828a69327c2f8b15947d3a..98dcab2c07f9cb76074dd5a233437c30030efd01 100644 (file)
@@ -26,7 +26,7 @@ def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, transform.Pass)
     mod = relay.Module.from_expr(expr)
     mod = opt_pass(mod)
-    entry = mod[mod.entry_func]
+    entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
 def _to_shape(shape):
index 43bc6acc15d597666cc2e34d85c9199a618e875e..80a213ccd3ffa2f85b22f28e3623696015b0e2ca 100644 (file)
@@ -127,7 +127,7 @@ def compile_network(opt, env, target):
     # Perform quantization in Relay
     with relay.quantize.qconfig(global_scale=8.0,
                                 skip_conv_layers=[0]):
-        relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params)
+        relay_prog = relay.quantize.quantize(mod["main"], params=params)
 
     # Perform graph packing and constant folding for VTA target
     if target.device_name == "vta":
index 9f734bc65d929041b6752aeadc7825f443697f17..2bf33bcdf79f9b41a8fe45385fd9a2034e565081 100644 (file)
@@ -91,7 +91,7 @@ def compile_network(env, target, model, start_pack, stop_pack):
     # Perform quantization in Relay
     with relay.quantize.qconfig(global_scale=8.0,
                                 skip_conv_layers=[0]):
-        relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params)
+        relay_prog = relay.quantize.quantize(mod["main"], params=params)
 
     # Perform graph packing and constant folding for VTA target
     if target.device_name == "vta":
index 3035decb2160d13b4e3716aed9c8a962ddfafd98..c4e7aaf246b44ade7228ef2fd0ec30f79167e911 100644 (file)
@@ -160,7 +160,7 @@ with autotvm.tophub.context(target):
     # Perform quantization in Relay
     with relay.quantize.qconfig(global_scale=8.0,
                                 skip_conv_layers=[0]):
-        relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params)
+        relay_prog = relay.quantize.quantize(mod["main"], params=params)
 
     # Perform graph packing and constant folding for VTA target
     if target.device_name == "vta":