Caffe 2: Reshape Op upgrade (#15380)
authorSergei Nikolaev <drnikolaev@users.noreply.github.com>
Mon, 14 Jan 2019 06:46:39 +0000 (22:46 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 14 Jan 2019 06:49:40 +0000 (22:49 -0800)
Summary:
This is follow up on #13945 where we had to turn off some TRT tests because some ops were not ready to accept ONNX opset 9+ models. This PR fixes Reshape.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15380

Differential Revision: D13649825

Pulled By: houseroad

fbshipit-source-id: b72e62803de5b63cc001c3fe4b3bf64dfa996e94

caffe2/python/operator_test/reshape_ops_test.py
caffe2/python/trt/test_trt.py
caffe2/python/trt/transform.py

index 98189b8..0cef958 100644 (file)
@@ -4,6 +4,7 @@ from __future__ import print_function
 from __future__ import unicode_literals
 import numpy as np
 import six
+from numpy.testing import assert_array_equal
 
 from caffe2.python import core, workspace
 from caffe2.python.test_util import TestCase
@@ -16,7 +17,7 @@ class TestLengthsToShapeOps(TestCase):
         workspace.RunOperatorOnce(core.CreateOperator(
             'LengthsToShape', ['l'], ['s']))
         workspace.FeedBlob('res', np.array([3, 200], dtype=np.int32))
-        assert ((workspace.FetchBlob('s') == workspace.FetchBlob('res')).all())
+        assert_array_equal(workspace.FetchBlob('s'), workspace.FetchBlob('res'))
 
     def test_reshape_ops(self):
         workspace.FeedBlob('res', np.array([[0, 0, 0, 0]], dtype=np.float32))
@@ -24,8 +25,8 @@ class TestLengthsToShapeOps(TestCase):
         workspace.FeedBlob('input', np.zeros((2, 2), dtype=np.float32))
         workspace.RunOperatorOnce(core.CreateOperator(
             'Reshape', ['input', 'shape'], ['output', 'old_shape']))
-        assert ((workspace.FetchBlob('output') ==
-                 workspace.FetchBlob('res')).all())
+        assert_array_equal(workspace.FetchBlob('output'),
+                           workspace.FetchBlob('res'))
 
     def test_basic_reshape(self):
         _test_reshape(old_shape=(4, 2, 1), new_shape=(2, 4))
index eb21917..cc21822 100644 (file)
@@ -131,30 +131,33 @@ class TensorRTOpTest(TestCase):
         ws = Workspace()
         with core.DeviceScope(device_option):
             ws.FeedBlob(op_inputs[data_input_index], data)
+            if opset_version >= 5:
+                # Some newer models from ONNX Zoo come with pre-set "data_0" input
+                ws.FeedBlob("data_0", data)
             ws.RunOperatorsOnce([op])
             output_values = [ws.FetchBlob(name) for name in op_outputs]
             Y_trt = namedtupledict('Outputs', op_outputs)(*output_values)
         np.testing.assert_allclose(Y_c2, Y_trt, rtol=1e-3)
 
-    @unittest.skip("Until fixing Reshape op")
+    @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
     def test_resnet50(self):
-        self._test_onnx_importer('resnet50', 0)
+        self._test_onnx_importer('resnet50', 0, 9)
 
     @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
     def test_bvlc_alexnet(self):
-        self._test_onnx_importer('bvlc_alexnet', 0)
+        self._test_onnx_importer('bvlc_alexnet', 0, 9)
 
     @unittest.skip("Until fixing Unsqueeze op")
     def test_densenet121(self):
         self._test_onnx_importer('densenet121', -1, 3)
 
-    @unittest.skip("Until fixing Reshape op")
+    @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
     def test_inception_v1(self):
-        self._test_onnx_importer('inception_v1', -1, 3)
+        self._test_onnx_importer('inception_v1', -3, 9)
 
-    @unittest.skip("Until fixing Reshape op")
+    @unittest.skip("Until fixing Unsqueeze op")
     def test_inception_v2(self):
-        self._test_onnx_importer('inception_v2', 0, 3)
+        self._test_onnx_importer('inception_v2', 0, 9)
 
     @unittest.skip('Need to revisit our ChannelShuffle exporter to avoid generating 5D tensor')
     def test_shufflenet(self):
@@ -162,15 +165,15 @@ class TensorRTOpTest(TestCase):
 
     @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
     def test_squeezenet(self):
-        self._test_onnx_importer('squeezenet', -1)
+        self._test_onnx_importer('squeezenet', -1, 9)
 
     @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
     def test_vgg16(self):
-        self._test_onnx_importer('vgg16', 0)
+        self._test_onnx_importer('vgg16', 0, 9)
 
-    @unittest.skip("Until fixing Reshape op")
+    @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
     def test_vgg19(self):
-        self._test_onnx_importer('vgg19', -1, 3)
+        self._test_onnx_importer('vgg19', -2, 9)
 
 
 class TensorRTTransformTest(DownloadingTestCase):
index fbc6c11..489defe 100644 (file)
@@ -35,7 +35,7 @@ def check_gpu_():
        raise Exception("TensorRT related functions require CUDA support")
 
 def convert_onnx_model_to_trt_op(onnx_model,
-        max_batch_size=50,
+        max_batch_size=64,
         max_workspace_size=2*1024*1024,
         verbosity=1,
         debug_builder=False):
@@ -77,7 +77,7 @@ def transform_caffe2_net(
         pred_net,
         input_shapes,
         populate_shapes = False,
-        max_batch_size=50,
+        max_batch_size=64,
         max_workspace_size=2*1024*1024,
         verbosity=1,
         debug_builder=False,