refine error (#5929)
authorCody Yu <comaniac0422@gmail.com>
Fri, 26 Jun 2020 02:26:03 +0000 (19:26 -0700)
committerGitHub <noreply@github.com>
Fri, 26 Jun 2020 02:26:03 +0000 (19:26 -0700)
python/tvm/relay/frontend/mxnet.py

index 2454a55..321b145 100644 (file)
@@ -2309,6 +2309,20 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
     node_map = {}
     shape_idx = 0
 
+    # Check if there have any unsupported ops
+    unsupported = {}
+    for node in jnodes:
+        op_name = node["op"]
+        if op_name != "null" and op_name not in _convert_map:
+            if op_name not in unsupported:
+                unsupported[op_name] = 0
+            unsupported[op_name] += 1
+
+    if unsupported:
+        msg = '\n'.join(['{}: {}'.format(op_name, cnt) for op_name, cnt in unsupported.items()])
+        raise tvm.error.OpNotImplemented(
+            'One or more operators are not supported in frontend MXNet:\n{}'.format(msg))
+
     for nid, node in enumerate(jnodes):
         children = [node_map[e[0]][e[1]] for e in node["inputs"]]
         attrs = StrAttrsDict(node.get("attrs", {}))
@@ -2330,7 +2344,8 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
             if isinstance(shape_dict, (list, tuple)):
                 shape_idx += 1
             node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
-        elif op_name in _convert_map:
+        else:
+            assert op_name in _convert_map
             op_params = _get_op_params(children, attrs, op_name,
                                        node, params)
             res = _convert_map[op_name](*op_params)
@@ -2344,9 +2359,6 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
             else:
                 raise RuntimeError("unexpected type %s" % type(res))
             node_map[nid] = res
-        else:
-            raise tvm.error.OpNotImplemented(
-                'Operator {} is not supported in frontend MXNet.'.format(op_name))
 
     outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
     outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)