Fix bug in ONNX importer (#3084)
authorJared Roesch <roeschinc@gmail.com>
Mon, 29 Apr 2019 19:54:16 +0000 (12:54 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 29 Apr 2019 19:54:16 +0000 (12:54 -0700)
python/tvm/relay/frontend/onnx.py
tests/python/frontend/onnx/test_forward.py

index 53f104c..d91ee4b 100644 (file)
@@ -944,7 +944,10 @@ class GraphProto(object):
                                               dtype=self._params[i_name].dtype)
             else:
                 self._num_input += 1
-                tshape = self._shape[i_name] if i_name in self._shape else ()
+                if i_name in self._shape:
+                    tshape = self._shape[i_name]
+                else:
+                    raise ValueError("Must provide an input shape for `{0}`.".format(i_name))
                 if isinstance(self._dtype, dict):
                     dtype = self._dtype[i_name] if i_name in self._dtype else d_type
                 else:
index 2564d83..7be6bb6 100644 (file)
@@ -724,10 +724,15 @@ def verify_constantfill(is_shape, input_dim, out_dim, value, dtype, **kwargs):
     else:
         fill_node = helper.make_node("ConstantFill", ["input_a"], ["out"], value=value, dtype=dtype, **kwargs)
 
+    if is_shape == True:
+        inputs = []
+    else:
+        inputs = [helper.make_tensor_value_info("input_a",
+                  TensorProto.FLOAT, list(input_dim))]
+
     graph = helper.make_graph([fill_node],
                               "fill_test",
-                              inputs = [helper.make_tensor_value_info("input_a",
-                                            TensorProto.FLOAT, list(input_dim))],
+                              inputs,
                               outputs = [helper.make_tensor_value_info("out",
                                             TensorProto.FLOAT, list(out.shape))])