Revert "[Relay][Frontend][ONNX] Fix reshape precompute, and type error (#3230)" ...
authorTianqi Chen <tqchen@users.noreply.github.com>
Mon, 17 Jun 2019 23:27:53 +0000 (16:27 -0700)
committerGitHub <noreply@github.com>
Mon, 17 Jun 2019 23:27:53 +0000 (16:27 -0700)
This reverts commit df6957a5ea49806b3073bbb81e339ae379cbbb1c.

python/tvm/relay/frontend/onnx.py
tests/python/frontend/onnx/test_forward.py

index a28981c..468a748 100644 (file)
@@ -409,24 +409,21 @@ class Reshape(OnnxOpConverter):
             shape = tuple(params[inputs[1].name_hint].asnumpy())
             out = _op.reshape(inputs[0], shape)
         else:
-            data, shape = inputs
-            logging.warning("Constant evaluating Reshape's shape argument, may reduce performance")
-            shape_params = ir_pass.free_vars(shape)
-            func = _expr.Function(shape_params, shape)
-            func = ir_pass.infer_type(func)
-            func = ir_pass.fold_constant(func)
-            shape_params = ir_pass.free_vars(func.body)
-            func = _expr.Function(shape_params, func.body)
+            # Try to infer shape by precompute prune if possible.
+            # TODO: good to check inputs to be in params.
+            #       to be enhanced when relay support list_input_names API of NNVM
+            logging.warning("Infering Reshape argument by precompute")
+            func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
             with tvm.relay.build_config(opt_level=0):
-                ex = tvm.relay.create_executor("debug")
-                inputs = []
-                for sp in shape_params:
-                    if not sp.name_hint in params:
-                        sh = [int(i) for i in sp.type_annotation.shape]
-                        inputs.append(
-                            tvm.nd.array(np.random.rand(*sh).astype('float32')))
-                static_shape = ex.evaluate(func)(*inputs, **params)
-            out = _op.reshape(data, newshape=tuple(static_shape.asnumpy()))
+                graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+            ctx = tvm.context("llvm", 0)
+            from tvm.contrib import graph_runtime
+            m = graph_runtime.create(graph, lib, ctx)
+            m.set_input(**params)
+            m.run()
+            params_new = m.get_output(0)
+            inputs.pop(1)
+            out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten()))
 
         return out
 
@@ -571,7 +568,6 @@ class Shape(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        # TODO(@jroesch): use shape_of once it has been fixed
         return _op.shape_of(inputs[0])
 
 class Cast(OnnxOpConverter):
@@ -1062,15 +1058,8 @@ class GraphProto(object):
             if op_name == "Constant":
                 t_proto = self._parse_attr(node.attribute)["value"]
                 self._num_param += 1
-                # We should convert scalar integers to int32, to normalize.
-                array = self._parse_array(t_proto)
-                if len(array.shape) == 0 and array.dtype == 'int64':
-                    array = _nd.array(array.asnumpy().astype('int32'))
-                self._params[node.output[0]] = array
-                self._nodes[node.output[0]] = new_var(
-                    node.output[0],
-                    shape=list(t_proto.dims),
-                    dtype=array.dtype)
+                self._params[node.output[0]] = self._parse_array(t_proto)
+                self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims))
             else:
                 if op_name == "ConstantFill":
                     fill_value = attr.get('value', 0.0)
index d4c8ee9..7371a88 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import attr
 import numpy as np
 import math
-import torch
-import torchvision
 import topi
 import topi.testing
 import tvm
@@ -1075,47 +1072,6 @@ def test_LogSoftmax():
                               'LogSoftmax',
                               {'axis': 1})
 
-def check_torch_conversion(model, input_size):
-    dummy_input = torch.randn(*input_size)
-    file_name = '{}.onnx'.format(model.__name__)
-    # Set verbose=True for more output
-    torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)
-    onnx_model = onnx.load(file_name)
-    shapes = { '0' : input_size }
-    expr, params = relay.frontend.from_onnx(onnx_model, shape=shapes)
-
-def test_resnet():
-    check_torch_conversion(torchvision.models.resnet18, (1,3,224,224))
-    # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))
-
-# def test_alexnet():
-    # Torch's ONNX export does not support the adaptive pooling used by AlexNet?
-    # check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
-
-# Torch's ONNX export does not support the adaptive pooling used by vgg16?
-# def test_vgg16():
-#     check_torch_conversion(torchvision.models.vgg16, (1,3,224,224))
-
-# TODO(@jroesch): Update Torch + ONNX to support this import.
-# def test_squeezenet():
-#     # Torch's ONNX export does not support the max pooling used by Squezenet
-#     check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))
-
-def test_densenet():
-    check_torch_conversion(torchvision.models.densenet161, (1,3,224,224))
-
-def test_inception():
-    check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224))
-
-# TODO(@jroesch): Update Torch + ONNX to support this import.
-# def test_googlenet():
-#     check_torch_conversion(torchvision.models.googlenet, (1,3,224,224))
-
-# TODO(@jroesch): Update Torch + ONNX to support this import.
-# def test_shufflenetv2():
-#     check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))
-
-
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -1155,6 +1111,3 @@ if __name__ == '__main__':
     test_ParametricSoftplus()
     test_Scale()
     test_LogSoftmax()
-    test_resnet()
-    test_inception()
-    test_densenet()