from __future__ import unicode_literals
import numpy as np
import six
+from numpy.testing import assert_array_equal
from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase
workspace.RunOperatorOnce(core.CreateOperator(
'LengthsToShape', ['l'], ['s']))
workspace.FeedBlob('res', np.array([3, 200], dtype=np.int32))
- assert ((workspace.FetchBlob('s') == workspace.FetchBlob('res')).all())
+ assert_array_equal(workspace.FetchBlob('s'), workspace.FetchBlob('res'))
def test_reshape_ops(self):
workspace.FeedBlob('res', np.array([[0, 0, 0, 0]], dtype=np.float32))
workspace.FeedBlob('input', np.zeros((2, 2), dtype=np.float32))
workspace.RunOperatorOnce(core.CreateOperator(
'Reshape', ['input', 'shape'], ['output', 'old_shape']))
- assert ((workspace.FetchBlob('output') ==
- workspace.FetchBlob('res')).all())
+ assert_array_equal(workspace.FetchBlob('output'),
+ workspace.FetchBlob('res'))
def test_basic_reshape(self):
_test_reshape(old_shape=(4, 2, 1), new_shape=(2, 4))
ws = Workspace()
with core.DeviceScope(device_option):
ws.FeedBlob(op_inputs[data_input_index], data)
+ if opset_version >= 5:
+ # Some newer models from ONNX Zoo come with pre-set "data_0" input
+ ws.FeedBlob("data_0", data)
ws.RunOperatorsOnce([op])
output_values = [ws.FetchBlob(name) for name in op_outputs]
Y_trt = namedtupledict('Outputs', op_outputs)(*output_values)
np.testing.assert_allclose(Y_c2, Y_trt, rtol=1e-3)
- @unittest.skip("Until fixing Reshape op")
+ @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
def test_resnet50(self):
- self._test_onnx_importer('resnet50', 0)
+ self._test_onnx_importer('resnet50', 0, 9)
@unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
def test_bvlc_alexnet(self):
- self._test_onnx_importer('bvlc_alexnet', 0)
+ self._test_onnx_importer('bvlc_alexnet', 0, 9)
@unittest.skip("Until fixing Unsqueeze op")
def test_densenet121(self):
self._test_onnx_importer('densenet121', -1, 3)
- @unittest.skip("Until fixing Reshape op")
+ @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
def test_inception_v1(self):
- self._test_onnx_importer('inception_v1', -1, 3)
+ self._test_onnx_importer('inception_v1', -3, 9)
- @unittest.skip("Until fixing Reshape op")
+ @unittest.skip("Until fixing Unsqueeze op")
def test_inception_v2(self):
- self._test_onnx_importer('inception_v2', 0, 3)
+ self._test_onnx_importer('inception_v2', 0, 9)
@unittest.skip('Need to revisit our ChannelShuffle exporter to avoid generating 5D tensor')
def test_shufflenet(self):
@unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
def test_squeezenet(self):
- self._test_onnx_importer('squeezenet', -1)
+ self._test_onnx_importer('squeezenet', -1, 9)
@unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
def test_vgg16(self):
- self._test_onnx_importer('vgg16', 0)
+ self._test_onnx_importer('vgg16', 0, 9)
- @unittest.skip("Until fixing Reshape op")
+ @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support")
def test_vgg19(self):
- self._test_onnx_importer('vgg19', -1, 3)
+ self._test_onnx_importer('vgg19', -2, 9)
class TensorRTTransformTest(DownloadingTestCase):