[Relay][Frontend][ONNX] operator support: Tile (#3941)
authorNeo Chien <cchung100m@cs.ccu.edu.tw>
Fri, 20 Sep 2019 19:17:11 +0000 (03:17 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 20 Sep 2019 19:17:11 +0000 (12:17 -0700)
* [Relay][Frontend][ONNX] operator support: Tile

* Trigger notification

python/tvm/relay/frontend/onnx.py
tests/python/frontend/onnx/test_forward.py

index b7fe2cf..822a431 100644 (file)
@@ -885,6 +885,18 @@ class And(Elemwise):
         return _op.logical_and(inputs[0], inputs[1])
 
 
+class Tile(Elemwise):
+    """Operator converter for Tile
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        if 'repeats' not in attr:
+            raise tvm.error.OpAttributeInvalid('Attribute "repeats" should be set '
+                                               'for operator Tile.')
+        reps = attr.pop('repeats')  # The number of times repeating the tensor data.
+        return _op.tile(inputs[0], reps)
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -1002,7 +1014,8 @@ def _get_convert_map(opset):
         'Sign': Sign.get_converter(opset),
         'Equal': Equal.get_converter(opset),
         'Not': Not.get_converter(opset),
-        'And': And.get_converter(opset)
+        'And': And.get_converter(opset),
+        'Tile': Tile.get_converter(opset)
     }
 
 
index 7e0e11f..cdcc596 100644 (file)
@@ -1205,6 +1205,27 @@ def test_and():
     verify_and(indata=[x, y], dtype=bool)
 
 
+def verify_tile(indata, outdata, **kwargs):
+    node = helper.make_node('Tile', inputs=['in'], outputs=['out'], **kwargs)
+    graph = helper.make_graph([node],
+                              'tile_test',
+                              inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
+                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))])
+
+    model = helper.make_model(graph, producer_name='tile_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape)
+        tvm.testing.assert_allclose(outdata, tvm_out)
+
+
+def test_tile():
+    x = np.random.rand(2, 3, 4, 5).astype(np.float32)
+    repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
+    z = np.tile(x, repeats)
+    verify_tile(x, z, repeats=repeats)
+
+
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -1250,3 +1271,4 @@ if __name__ == '__main__':
     test_sign()
     test_not()
     test_and()
+    test_tile()