[RELAY] Support concatenate. (#2298)
authorziheng <ziheng@apache.org>
Mon, 17 Dec 2018 18:18:21 +0000 (10:18 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 17 Dec 2018 18:18:21 +0000 (10:18 -0800)
nnvm/python/nnvm/to_relay.py

index 318ff1ee92dd8c5dab9a3b18935e16ea910b2401..a168f4fd88d273129ce28ff41ff8a07438cdfca8 100644 (file)
@@ -364,6 +364,11 @@ def _squeeze(children, attrs, odtype='float32'):
 
     return op.squeeze(children[0], axis)
 
+def _concatenate(children, attrs, odtype='float32'):
+    axis = attrs.get_int('axis', None)
+    return op.concatenate(children, axis)
+
+
 NNVM_OP_2_RELAY_OP = {
     'flatten': _nn_batch_flatten,
     'dense': _dense,
@@ -422,6 +427,7 @@ NNVM_OP_2_RELAY_OP = {
     'strided_slice': _strided_slice,
     'split': _split,
     'squeeze': _squeeze,
+    'concatenate': _concatenate,
 }
 
 
@@ -436,7 +442,7 @@ def to_relay(graph, shape_dict, dtype_dict, params):
     shape_dict : dict of str to shape
        The input shape.
 
-    dtype_dict : dict of str to shape
+    dtype_dict : dict of str to str/dtype
        The input shape.
 
     params : dict of str to array