Add MXNet converter for RNN layer ops (#3125)
authorHaichen Shen <shenhaichen@gmail.com>
Thu, 2 May 2019 15:59:22 +0000 (08:59 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 2 May 2019 15:59:22 +0000 (11:59 -0400)
python/tvm/relay/build_module.py
python/tvm/relay/frontend/mxnet.py
tests/python/frontend/mxnet/test_forward.py

index b16b5e28bf34f0c53e9eeb04f6eeb1b977217b50..a4929d0b839d8ef46b6073fe3450e7b68083a7be 100644 (file)
@@ -26,6 +26,7 @@ from .. import nd as _nd, target as _target, autotvm
 from ..contrib import graph_runtime as _graph_rt
 from . import ir_pass
 from . import expr as _expr
+from . import ty as _ty
 from .backend import interpreter as _interpreter
 from .backend import graph_runtime_codegen as _graph_gen
 
@@ -427,6 +428,8 @@ class GraphExecutor(_interpreter.Executor):
         self.target = target
 
     def _make_executor(self, func):
+        ret_type = ir_pass.infer_type(func).ret_type
+        num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
         graph_json, mod, params = build(func, target=self.target)
         gmodule = _graph_rt.create(graph_json, mod, self.ctx)
         if params:
@@ -440,7 +443,12 @@ class GraphExecutor(_interpreter.Executor):
             # Run the module, and fetch the output.
             gmodule.run()
             # make a copy so multiple invocation won't hurt perf.
-            return gmodule.get_output(0).copyto(_nd.cpu(0))
+            if num_outputs == 1:
+                return gmodule.get_output(0).copyto(_nd.cpu(0))
+            outputs = []
+            for i in range(num_outputs):
+                outputs.append(gmodule.get_output(i).copyto(_nd.cpu(0)))
+            return outputs
 
         return _graph_wrapper
 
index f1bf6788ea20a6b6660dccb1492a7c10e47cc15b..b93bd5b244eb872afef31eecdb77d7b926552f59 100644 (file)
@@ -34,6 +34,12 @@ from .nnvm_common import _warn_not_used
 
 __all__ = ['from_mxnet']
 
+_activation_map = {
+    "sigmoid": _op.sigmoid,
+    "tanh"   : _op.tanh,
+    "relu"   : _op.nn.relu
+}
+
 def _mx_fully_connected(inputs, attrs):
     import mxnet as mx
     units = attrs.get_int("num_hidden")
@@ -66,12 +72,6 @@ def _get_channel_axis(layout, op_name):
 def _mx_activations(inputs, attrs):
     act_type = attrs.get_str("act_type")
     assert len(inputs) == 1
-    if act_type == "sigmoid":
-        return _op.sigmoid(inputs[0])
-    if act_type == "tanh":
-        return _op.tanh(inputs[0])
-    if act_type == "relu":
-        return _op.nn.relu(inputs[0])
     if act_type == "softrelu":
         def _stable_softrelu(x):
             # log(1 + exp(-abs(x))) + relu(x)
@@ -80,8 +80,10 @@ def _mx_activations(inputs, attrs):
             return _op.add(_op.log(_op.add(one, exp_neg_abs_x)),
                            _op.nn.relu(x))
         return _stable_softrelu(inputs[0])
-    raise tvm.error.OpNotImplemented(
-        'Operator {} is not supported for frontend MXNet.'.format(act_type))
+    if act_type not in _activation_map:
+        raise tvm.error.OpNotImplemented(
+            'Operator {} is not supported for frontend MXNet.'.format(act_type))
+    return _activation_map[act_type](inputs[0])
 
 
 def _mx_compare(new_op, wrapper):
@@ -189,7 +191,8 @@ def _mx_pooling(inputs, attrs):
 def _mx_adaptive_avg_pooling(inputs, attrs):
     output_size = attrs.get_int_tuple("output_size", [])
     if output_size != (1,):
-        raise RuntimeError("AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
+        raise tvm.error.OpAttributeUnimplemented(
+            "AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
     return _op.nn.global_avg_pool2d(inputs[0])
 
 
@@ -471,7 +474,7 @@ def _mx_take(inputs, attrs):
     assert len(inputs) == 2
     mode = attrs.get_str("mode", "clip")
     if mode == "raise":
-        raise RuntimeError("take doesn't support raise mode")
+        raise tvm.error.OpAttributeUnimplemented("take with raise mode is not supported yet")
     axis = attrs.get_int("axis", 0)
     return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)
 
@@ -571,13 +574,13 @@ def _mx_l2_normalize(inputs, attrs):
 def _mx_shape_array(inputs, attrs):
     assert len(inputs) == 1
     if attrs.get_int("lhs_begin", None) is not None:
-        raise RuntimeError("shape_array doesn't support lhs_begin")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support lhs_begin")
     if attrs.get_int("lhs_end", None) is not None:
-        raise RuntimeError("shape_array doesn't support lhs_end")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support lhs_end")
     if attrs.get_int("rhs_begin", None) is not None:
-        raise RuntimeError("shape_array doesn't support rhs_begin")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_begin")
     if attrs.get_int("rhs_end", None) is not None:
-        raise RuntimeError("shape_array doesn't support rhs_end")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_end")
     return _op.shape_of(inputs[0], dtype='int64')
 
 
@@ -657,6 +660,101 @@ def _mx_argsort(inputs, attrs):
     return _op.argsort(inputs[0], **new_attrs)
 
 
+def _mx_rnn_param_concat(inputs, _):
+    # We don't need to concatenate RNN params because we will unravel the RNN op
+    return [inputs]
+
+
+def _mx_rnn_layer(inputs, attrs):
+    def _rnn_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias, activation):
+        i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
+        h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
+        out = _activation_map[activation](i2h + h2h)
+        return out, [out]
+
+    def _gru_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
+        dtype = ir_pass.infer_type(data).checked_type.dtype
+        i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
+        h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
+        i2h_r, i2h_z, i2h = _op.split(i2h, indices_or_sections=3, axis=1)
+        h2h_r, h2h_z, h2h = _op.split(h2h, indices_or_sections=3, axis=1)
+        reset_gate = _activation_map["sigmoid"](i2h_r + h2h_r)
+        update_gate = _activation_map["sigmoid"](i2h_z + h2h_z)
+        next_h_tmp = _activation_map["tanh"](reset_gate * h2h + i2h)
+        next_h = (_expr.const(1, dtype) - update_gate) * next_h_tmp + update_gate * states[0]
+        return next_h, [next_h]
+
+    def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
+        i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
+        h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
+        gates = i2h + h2h
+        slice_gates = _op.split(gates, indices_or_sections=4, axis=1)
+        in_gate = _activation_map["sigmoid"](slice_gates[0])
+        forget_gate = _activation_map["sigmoid"](slice_gates[1])
+        in_transform = _activation_map["tanh"](slice_gates[2])
+        out_gate = _activation_map["sigmoid"](slice_gates[3])
+        next_c = forget_gate * states[1] + in_gate * in_transform
+        next_h = out_gate * _activation_map["tanh"](next_c)
+        return next_h, [next_h, next_c]
+
+    num_layers = attrs.get_int("num_layers", 1)
+    mode = attrs.get_str("mode")
+    if mode.startswith("rnn"):
+        mode, activation = mode.split('_')
+    assert mode in ["rnn", "gru", "lstm"]
+    bidirectional = attrs.get_bool("bidirectional", False)
+    if bidirectional:
+        raise tvm.error.OpAttributeUnimplemented(
+            "Bidirectional RNN op is not supported yet")
+    layout = attrs.get_str("layout", "TNC")
+    if layout != "TNC":
+        raise tvm.error.OpAttributeUnimplemented(
+            "RNN with layout other than TNC is not supported yet")
+    num_states = 2 if mode == 'lstm' else 1
+    assert len(inputs) == num_states + 2
+
+    seq_data = inputs[0]
+    concat_weight = inputs[1]
+    concat_states = inputs[2:]
+    seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0])
+    assert len(concat_weight) == num_layers * 4
+
+    weights = []
+    bias = []
+    states = []
+    for i in range(num_layers):
+        w = []
+        b = []
+        s = []
+        for j in range(2):
+            w.append(concat_weight[i*2 + j].args[0])
+            b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
+        for state in concat_states:
+            s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
+        weights.append(w)
+        bias.append(b)
+        states.append(s)
+
+    seq_output = []
+    for t in range(seq_len):
+        data = _op.take(seq_data, _expr.const(t, "int32"), axis=0)
+        for l in range(num_layers):
+            if mode == "rnn":
+                out, new_states = _rnn_cell(data, states[l], *weights[l], *bias[l], activation)
+            elif mode == "gru":
+                out, new_states = _gru_cell(data, states[l], *weights[l], *bias[l])
+            else: # mode == "lstm"
+                out, new_states = _lstm_cell(data, states[l], *weights[l], *bias[l])
+            states[l] = new_states
+            data = out
+        seq_output.append(out)
+
+    outputs = [_op.stack(seq_output, axis=0)]
+    for i in range(num_states):
+        outputs.append(_op.stack([s[i] for s in states], axis=0))
+    return outputs
+
+
 # Note: due to attribute conversion constraint
 # ops in the identity set must be attribute free
 _identity_list = [
@@ -807,6 +905,9 @@ _convert_map = {
     "_contrib_box_nms" : _mx_box_nms,
     "_contrib_DeformableConvolution" : _mx_deformable_convolution,
     "_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling,
+    # NLP
+    "RNN"               : _mx_rnn_layer,
+    "_rnn_param_concat" : _mx_rnn_param_concat,
     # List of missing operators that are present in NNVMv1
     # TODO(tvm-tvm): support all operators.
     #
index d00efb39e16f088d9a55312034878994f65e703d..067c356830bbb81345d073f79ec85035d8a3b26a 100644 (file)
@@ -527,6 +527,54 @@ def test_forward_bilinear_resize():
     mx_sym = mx.sym.contrib.BilinearResize2D(data, height=5, width=10)
     verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
 
+def test_forward_rnn_layer():
+    def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
+        if mode == "rnn":
+            layer = gluon.rnn.RNN(hidden_size, num_layers)
+        elif mode == "gru":
+            layer = gluon.rnn.GRU(hidden_size, num_layers)
+        else: # mode == "lstm"
+            layer = gluon.rnn.LSTM(hidden_size, num_layers)
+        num_states = 2 if mode == "lstm" else 1
+        layer.initialize()
+
+        dtype = "float32"
+        data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
+        states_np = []
+        states_mx = []
+        shape_dict = {'data0': data_np.shape}
+        inputs = {'data0': data_np}
+        for i in range(num_states):
+            s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
+            states_np.append(s)
+            states_mx.append(mx.nd.array(s))
+            shape_dict['data%s' % (i+1)] = s.shape
+            inputs['data%s' % (i+1)] = s
+
+        layer.hybridize()
+        mx_out, mx_states = layer(mx.nd.array(data_np), states_mx)
+        mx_res = [mx_out] + mx_states
+        mx_sym = layer._cached_graph[1]
+        mx_params = {}
+        for name, param in layer.collect_params().items():
+            mx_params[name] = param._reduce()
+
+        new_sym, params = relay.frontend.from_mxnet(
+            mx_sym, shape=shape_dict, arg_params=mx_params)
+        for target, ctx in ctx_list():
+            # only test graph runtime because debug runtime is too slow
+            for kind in ["graph"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(**inputs, **params)
+                assert len(op_res) == len(mx_res)
+                for i, val in enumerate(op_res):
+                    tvm.testing.assert_allclose(val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
+
+    for mode in ["rnn", "gru", "lstm"]:
+        verify(mode, 64, 10, 64, 1)
+        verify(mode, 64, 10, 64, 2)
+        verify(mode, 64, 10, 32, 2)
+
 
 if __name__ == '__main__':
     test_forward_mlp()
@@ -566,3 +614,4 @@ if __name__ == '__main__':
     test_forward_take()
     test_forward_gather_nd()
     test_forward_bilinear_resize()
+    test_forward_rnn_layer()