add (#4311)
authorHaichen Shen <shenhaichen@gmail.com>
Tue, 12 Nov 2019 19:54:56 +0000 (11:54 -0800)
committerLeyuan Wang <laurawly@gmail.com>
Tue, 12 Nov 2019 19:54:56 +0000 (11:54 -0800)
python/tvm/relay/frontend/mxnet.py
tests/python/frontend/mxnet/test_forward.py

index 25c0626..abef45d 100644 (file)
@@ -20,10 +20,12 @@ from __future__ import absolute_import as _abs
 
 import json
 import tvm
+from topi.util import get_const_tuple
 from .. import analysis
 from .. import expr as _expr
 from .. import op as _op
 from .. import module as _module
+from .. import scope_builder as _scope_builder
 from ... import nd as _nd
 
 from .common import StrAttrsDict
@@ -1037,6 +1039,47 @@ def _mx_contrib_fifo_buffer(inputs, attrs):
     new_attrs['axis'] = attrs.get_int('axis')
     return _op.nn.fifo_buffer(*inputs, **new_attrs)
 
+def _mx_cond(inputs, attrs, subgraphs):
+    assert len(subgraphs) == 3
+    cond_input_locs = json.loads(attrs.get_str("cond_input_locs"))
+    then_input_locs = json.loads(attrs.get_str("then_input_locs"))
+    else_input_locs = json.loads(attrs.get_str("else_input_locs"))
+    num_outputs = attrs.get_int("num_outputs")
+
+    input_args = []
+    for i, arg in enumerate(inputs):
+        var = _expr.var("arg%s" % i, _infer_type(arg).checked_type)
+        input_args.append(var)
+    cond_args = [input_args[i] for i in cond_input_locs]
+    then_args = [input_args[i] for i in then_input_locs]
+    else_args = [input_args[i] for i in else_input_locs]
+
+    cond_arg_shapes = [arg.type_annotation.shape for arg in cond_args]
+    cond_arg_dtype_info = [arg.type_annotation.dtype for arg in cond_args]
+    cond_func = _from_mxnet_impl(subgraphs[0], cond_arg_shapes, cond_arg_dtype_info)
+    cond = _expr.Call(cond_func, cond_args).astype("bool")
+    cond_shape = get_const_tuple(_infer_type(cond).checked_type.shape)
+    if len(cond_shape) > 0:
+        assert len(cond_shape) == 1 and cond_shape[0] == 1, "Condition is not scalar"
+        cond = _op.take(cond, _expr.const(1, "int"))
+
+    sb = _scope_builder.ScopeBuilder()
+    with sb.if_scope(cond):
+        then_arg_shapes = [arg.type_annotation.shape for arg in then_args]
+        then_arg_dtype_info = [arg.type_annotation.dtype for arg in then_args]
+        then_func = _from_mxnet_impl(subgraphs[1], then_arg_shapes, then_arg_dtype_info)
+        sb.ret(_expr.Call(then_func, then_args))
+    with sb.else_scope():
+        else_arg_shapes = [arg.type_annotation.shape for arg in else_args]
+        else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args]
+        else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes, else_arg_dtype_info)
+        sb.ret(_expr.Call(else_func, else_args))
+    func = _expr.Function(input_args, sb.get())
+    ret = _expr.Call(func, inputs)
+    if num_outputs > 1:
+        ret = _expr.TupleWrapper(ret, num_outputs)
+    return ret
+
 
 # Note: due to attribute conversion constraint
 # ops in the identity set must be attribute free
@@ -1204,6 +1247,8 @@ _convert_map = {
     # NLP
     "RNN"               : _mx_rnn_layer,
     "_rnn_param_concat" : _mx_rnn_param_concat,
+    # control flow
+    "_cond"             : _mx_cond,
     # Depricated:
     "Crop"              : _mx_crop_like,
     # List of missing operators that are present in NNVMv1
@@ -1245,9 +1290,13 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None):
         Converted relay Function
     """
     assert symbol is not None
-    jgraph = json.loads(symbol.tojson())
+    if isinstance(symbol, dict):
+        jgraph = symbol
+    else:
+        jgraph = json.loads(symbol.tojson())
     jnodes = jgraph["nodes"]
     node_map = {}
+    shape_idx = 0
 
     for nid, node in enumerate(jnodes):
         children = [node_map[e[0]][e[1]] for e in node["inputs"]]
@@ -1255,14 +1304,27 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None):
         node_name = node["name"]
         op_name = node["op"]
         if op_name == "null":
-            shape = shape_dict[node_name] if node_name in shape_dict else None
+            if isinstance(shape_dict, dict):
+                shape = shape_dict[node_name] if node_name in shape_dict else None
+            elif isinstance(shape_dict, (list, tuple)):
+                shape = shape_dict[shape_idx]
+            else:
+                raise ValueError("Unknown type of shape_dict: %s" + type(shape_dict))
             if isinstance(dtype_info, dict):
                 dtype = dtype_info[node_name] if node_name in dtype_info else "float32"
+            elif isinstance(dtype_info, (list, tuple)):
+                dtype = dtype_info[shape_idx]
             else:
                 dtype = dtype_info
+            if isinstance(shape_dict, (list, tuple)):
+                shape_idx += 1
             node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
         elif op_name in _convert_map:
-            res = _convert_map[op_name](children, attrs)
+            if op_name in ['_cond', '_foreach', '_while_loop']:
+                subgraphs = node['subgraphs']
+                res = _convert_map[op_name](children, attrs, subgraphs)
+            else:
+                res = _convert_map[op_name](children, attrs)
             if res is None:
                 # defer conversion, used in RNN state initialization
                 res = [node]
index f45f152..be4436d 100644 (file)
@@ -909,6 +909,31 @@ def test_forward_deconvolution():
     verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
 
+def test_forward_cond():
+    def verify(a_np, b_np):
+        a_nd, b_nd = mx.nd.array(a_np), mx.nd.array(b_np)
+        pred = a_nd * b_nd < 5
+        then_func = lambda: (a_nd + 5) * (b_nd + 5)
+        else_func = lambda: (a_nd - 5) * (b_nd - 5)
+        ref_res = mx.nd.contrib.cond(pred, then_func, else_func)
+
+        a_sym, b_sym = mx.sym.var("a"), mx.sym.var("b")
+        pred = a_sym * b_sym < 5
+        then_func = lambda: (a_sym + 5) * (b_sym + 5)
+        else_func = lambda: (a_sym - 5) * (b_sym - 5)
+        mx_sym = mx.sym.contrib.cond(pred, then_func, else_func)
+
+        shape_dict = {"a": a_np.shape, "b": b_np.shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+        for target, ctx in ctx_list():
+            for kind in ["debug", "vm"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(a_np, b_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+
+    verify(np.asarray([1.0], 'float32'), np.asarray([2.0],'float32'))
+    verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))
+
 
 if __name__ == '__main__':
     test_forward_mlp()
@@ -963,3 +988,4 @@ if __name__ == '__main__':
     test_forward_one_hot()
     test_forward_convolution()
     test_forward_deconvolution()
+    test_forward_cond()