hotfix for onnx (#3387)
authorZhi <5145158+zhiics@users.noreply.github.com>
Tue, 18 Jun 2019 04:51:24 +0000 (21:51 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Tue, 18 Jun 2019 04:51:24 +0000 (21:51 -0700)
python/tvm/relay/frontend/onnx.py
tests/python/frontend/onnx/test_forward.py

index 468a748..bb968ec 100644 (file)
@@ -23,6 +23,7 @@ import numpy as np
 import tvm
 from ... import nd as _nd
 from .. import ir_pass
+from .. import transform as _transform
 from .. import expr as _expr
 from .. import module as _module
 from .. import op as _op
@@ -409,21 +410,27 @@ class Reshape(OnnxOpConverter):
             shape = tuple(params[inputs[1].name_hint].asnumpy())
             out = _op.reshape(inputs[0], shape)
         else:
-            # Try to infer shape by precompute prune if possible.
-            # TODO: good to check inputs to be in params.
-            #       to be enhanced when relay support list_input_names API of NNVM
-            logging.warning("Infering Reshape argument by precompute")
-            func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
+            data, shape = inputs
+            logging.warning("Constant evaluating Reshape's shape argument, may reduce performance")
+            shape_params = ir_pass.free_vars(shape)
+            func = _expr.Function(shape_params, shape)
+            mod = _module.Module.from_expr(func)
+            seq = _transform.Sequential([_transform.InferType(),
+                                         _transform.FoldConstant(),
+                                         _transform.FuseOps(0),
+                                         _transform.InferType()])
+            with tvm.relay.PassContext(opt_level=2):
+                mod = seq(mod)
             with tvm.relay.build_config(opt_level=0):
-                graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
-            ctx = tvm.context("llvm", 0)
-            from tvm.contrib import graph_runtime
-            m = graph_runtime.create(graph, lib, ctx)
-            m.set_input(**params)
-            m.run()
-            params_new = m.get_output(0)
-            inputs.pop(1)
-            out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten()))
+                ex = tvm.relay.create_executor("debug", mod=mod)
+                inputs = []
+                for sp in shape_params:
+                    if not sp.name_hint in params:
+                        sh = [int(i) for i in sp.type_annotation.shape]
+                        inputs.append(
+                            tvm.nd.array(np.random.rand(*sh).astype('float32')))
+                static_shape = ex.evaluate()(*inputs, **params)
+            out = _op.reshape(data, newshape=tuple(static_shape.asnumpy()))
 
         return out
 
@@ -568,6 +575,7 @@ class Shape(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
+        # TODO(@jroesch): use shape_of once it has been fixed)
         return _op.shape_of(inputs[0])
 
 class Cast(OnnxOpConverter):
@@ -1058,8 +1066,15 @@ class GraphProto(object):
             if op_name == "Constant":
                 t_proto = self._parse_attr(node.attribute)["value"]
                 self._num_param += 1
-                self._params[node.output[0]] = self._parse_array(t_proto)
-                self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims))
+                # We should convert scalar integers to int32, to normalize.
+                array = self._parse_array(t_proto)
+                if len(array.shape) == 0 and array.dtype == 'int64':
+                    array = _nd.array(array.asnumpy().astype('int32'))
+                self._params[node.output[0]] = array
+                self._nodes[node.output[0]] = new_var(
+                    node.output[0],
+                    shape=list(t_proto.dims),
+                    dtype=array.dtype)
             else:
                 if op_name == "ConstantFill":
                     fill_value = attr.get('value', 0.0)
index 7371a88..a52e3f0 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import attr
 import numpy as np
 import math
+import torch
+import torchvision
 import topi
 import topi.testing
 import tvm
@@ -1072,6 +1075,48 @@ def test_LogSoftmax():
                               'LogSoftmax',
                               {'axis': 1})
 
+
+def check_torch_conversion(model, input_size):
+    dummy_input = torch.randn(*input_size)
+    file_name = '{}.onnx'.format(model.__name__)
+    # Set verbose=True for more output
+    torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)
+    onnx_model = onnx.load(file_name)
+    shapes = { '0' : input_size }
+    expr, params = relay.frontend.from_onnx(onnx_model, shape=shapes)
+
+def test_resnet():
+    check_torch_conversion(torchvision.models.resnet18, (1,3,224,224))
+    # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))
+
+# def test_alexnet():
+    # Torch's ONNX export does not support the adaptive pooling used by AlexNet?
+    # check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
+
+# Torch's ONNX export does not support the adaptive pooling used by vgg16?
+# def test_vgg16():
+#     check_torch_conversion(torchvision.models.vgg16, (1,3,224,224))
+
+# TODO(@jroesch): Update Torch + ONNX to support this import.
+# def test_squeezenet():
+#     # Torch's ONNX export does not support the max pooling used by Squezenet
+#     check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))
+
+def test_densenet():
+    check_torch_conversion(torchvision.models.densenet161, (1,3,224,224))
+
+def test_inception():
+    check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224))
+
+# TODO(@jroesch): Update Torch + ONNX to support this import.
+# def test_googlenet():
+#     check_torch_conversion(torchvision.models.googlenet, (1,3,224,224))
+
+# TODO(@jroesch): Update Torch + ONNX to support this import.
+# def test_shufflenetv2():
+#     check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))
+
+
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -1111,3 +1156,6 @@ if __name__ == '__main__':
     test_ParametricSoftplus()
     test_Scale()
     test_LogSoftmax()
+    test_resnet()
+    test_inception()
+    test_densenet()