From 5629901033d5949dc130d2371ae419d432b066e2 Mon Sep 17 00:00:00 2001
From: Haichen Shen
Date: Sat, 22 Jun 2019 18:36:07 -0700
Subject: [PATCH] [Frontend][MxNet] Support bidirectional RNN layer (#3397)

* Support bidirectional RNN layer

* tweak

* tweak
---
 python/tvm/relay/frontend/mxnet.py          | 89 +++++++++++++++++++----------
 tests/python/frontend/mxnet/test_forward.py | 25 ++++----
 2 files changed, 75 insertions(+), 39 deletions(-)

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 00cbc70..2f36355 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -748,13 +748,12 @@ def _mx_rnn_layer(inputs, attrs):
     num_layers = attrs.get_int("num_layers", 1)
     mode = attrs.get_str("mode")
+    output_states = attrs.get_bool("state_outputs", False)
     if mode.startswith("rnn"):
         mode, activation = mode.split('_')
     assert mode in ["rnn", "gru", "lstm"]
     bidirectional = attrs.get_bool("bidirectional", False)
-    if bidirectional:
-        raise tvm.error.OpAttributeUnimplemented(
-            "Bidirectional RNN op is not supported yet")
+    direct = 2 if bidirectional else 1
     layout = attrs.get_str("layout", "TNC")
     if layout != "TNC":
         raise tvm.error.OpAttributeUnimplemented(
@@ -765,11 +764,10 @@ def _mx_rnn_layer(inputs, attrs):
     seq_data = inputs[0]
     concat_weight = inputs[1]
     init_states = inputs[2:]
-
     data_shape = ir_pass.infer_type(seq_data).checked_type.shape
     seq_len = int(data_shape[0])
-    assert len(concat_weight) == num_layers * 4
-    output_states = True
+    assert len(concat_weight) == num_layers * 4 * direct
+
     for idx, state in enumerate(init_states[:]):
         if isinstance(state, dict):
             node = state
@@ -787,43 +785,76 @@ def _mx_rnn_layer(inputs, attrs):
                     assert axis >= 0
                     new_shape[i] = int(data_shape[axis])
             init_states[idx] = _op.zeros(new_shape, dtype)
-            output_states = False
 
     weights = []
     bias = []
     states = []
+    back_weights = []
+    back_bias = []
+    back_states = []
     for i in range(num_layers):
-        w = []
-        b = []
+        weights.append([concat_weight[i*2*direct].args[0],
+                        concat_weight[i*2*direct + 1].args[0]])
+        bias.append([concat_weight[(num_layers+i)*2*direct].args[0],
+                     concat_weight[(num_layers+i)*2*direct + 1].args[0]])
         s = []
-        for j in range(2):
-            w.append(concat_weight[i*2 + j].args[0])
-            b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
         for state in init_states:
-            s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
-        weights.append(w)
-        bias.append(b)
+            s.append(_op.take(state, _expr.const(i*direct, "int32"), axis=0))
         states.append(s)
-
-    seq_output = []
-    for t in range(seq_len):
-        data = _op.take(seq_data, _expr.const(t, "int32"), axis=0)
-        for l in range(num_layers):
+        if bidirectional:
+            back_weights.append([concat_weight[i*2*direct + 2].args[0],
+                                 concat_weight[i*2*direct + 3].args[0]])
+            back_bias.append([concat_weight[(num_layers+i)*2*direct + 2].args[0],
+                              concat_weight[(num_layers+i)*2*direct + 3].args[0]])
+            s = []
+            for state in init_states:
+                s.append(_op.take(state, _expr.const(i*direct+1, "int32"), axis=0))
+            back_states.append(s)
+
+    xs = [_op.take(seq_data, _expr.const(t, "int32"), axis=0) for t in range(seq_len)]
+    for l in range(num_layers):
+        outputs = []
+        back_outputs = []
+        for x in xs:
             if mode == "rnn":
-                out, new_states = _rnn_cell(data, states[l], *weights[l], *bias[l], activation)
+                out, new_states = _rnn_cell(x, states[l], *weights[l], *bias[l], activation)
             elif mode == "gru":
-                out, new_states = _gru_cell(data, states[l], *weights[l], *bias[l])
+                out, new_states = _gru_cell(x, states[l], *weights[l], *bias[l])
             else: # mode == "lstm"
-                out, new_states = _lstm_cell(data, states[l], *weights[l], *bias[l])
+                out, new_states = _lstm_cell(x, states[l], *weights[l], *bias[l])
             states[l] = new_states
-            data = out
-        seq_output.append(out)
-
-    outputs = [_op.stack(seq_output, axis=0)]
+            outputs.append(out)
+        if bidirectional:
+            for x in reversed(xs):
+                if mode == "rnn":
+                    out, new_states = _rnn_cell(
+                        x, back_states[l], *back_weights[l], *back_bias[l], activation)
+                elif mode == "gru":
+                    out, new_states = _gru_cell(
+                        x, back_states[l], *back_weights[l], *back_bias[l])
+                else: # mode == "lstm"
+                    out, new_states = _lstm_cell(
+                        x, back_states[l], *back_weights[l], *back_bias[l])
+                back_states[l] = new_states
+                back_outputs.append(out)
+            back_outputs.reverse()
+            concat_outputs = []
+            for t, out in enumerate(outputs):
+                new_out = _op.concatenate([out, back_outputs[t]], axis=-1)
+                concat_outputs.append(new_out)
+            outputs = concat_outputs
+        xs = outputs
+
+    ret = [_op.stack(outputs, axis=0)]
     if output_states:
         for i in range(num_states):
-            outputs.append(_op.stack([s[i] for s in states], axis=0))
-    return outputs
+            inputs = []
+            for l, s in enumerate(states):
+                inputs.append(s[i])
+                if bidirectional:
+                    inputs.append(back_states[l][i])
+            ret.append(_op.stack(inputs, axis=0))
+    return ret
 
 
 # Note: due to attribute conversion constraint
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index d70b222..c82dc20 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -536,29 +536,31 @@ def test_forward_bilinear_resize():
     verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
 
 def test_forward_rnn_layer():
-    def verify(mode, input_size, seq_len, hidden_size, num_layers, init_states=True):
+    def verify(mode, seq_len, input_size, hidden_size, num_layers,
+               batch=1, init_states=True, bidirectional=False):
         if mode == "rnn":
-            layer = gluon.rnn.RNN(hidden_size, num_layers)
+            layer = gluon.rnn.RNN(hidden_size, num_layers, bidirectional=bidirectional)
         elif mode == "gru":
-            layer = gluon.rnn.GRU(hidden_size, num_layers)
+            layer = gluon.rnn.GRU(hidden_size, num_layers, bidirectional=bidirectional)
         else: # mode == "lstm"
-            layer = gluon.rnn.LSTM(hidden_size, num_layers)
+            layer = gluon.rnn.LSTM(hidden_size, num_layers, bidirectional=bidirectional)
         num_states = 2 if mode == "lstm" else 1
         layer.initialize()
         layer.hybridize()
 
         dtype = "float32"
-        batch = 1
+        directions = 2 if bidirectional else 1
         data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
         data_mx = mx.nd.array(data_np)
 
         if init_states:
             shape_dict = {'data0': data_np.shape}
             inputs = {'data0': data_np}
+            state_shape = (num_layers*directions, batch, hidden_size)
             states_np = []
             states_mx = []
             for i in range(num_states):
-                s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
+                s = np.random.uniform(size=state_shape).astype(dtype)
                 states_np.append(s)
                 states_mx.append(mx.nd.array(s))
                 shape_dict['data%s' % (i+1)] = s.shape
@@ -592,10 +594,13 @@ def test_forward_rnn_layer():
                 op_res.asnumpy(), mx_res.asnumpy(), rtol=1e-3)
 
     for mode in ["rnn", "gru", "lstm"]:
-        verify(mode, 64, 10, 64, 1)
-        verify(mode, 64, 10, 64, 2)
-        verify(mode, 64, 10, 32, 2)
-        verify(mode, 64, 10, 64, 2, init_states=False)
+        verify(mode, 1, 64, 64, 1)
+        verify(mode, 10, 64, 64, 2)
+        verify(mode, 10, 64, 32, 2)
+        verify(mode, 10, 64, 32, 2, batch=2)
+        verify(mode, 10, 64, 64, 3, init_states=False)
+        verify(mode, 10, 32, 64, 1, bidirectional=True)
+        verify(mode, 10, 64, 64, 3, batch=2, bidirectional=True, init_states=False)
 
 def test_forward_Crop():
     def verify(xshape, yshape, offset=None):
-- 
2.7.4
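For context, here is a minimal sketch (not part of the patch) of the kind of bidirectional Gluon layer that the updated `_mx_rnn_layer` converter and the new `verify(..., bidirectional=True)` test cases exercise. It uses only standard MXNet/Gluon APIs; the sizes are illustrative, and the shape comments describe the layout the converter reproduces: forward and backward outputs concatenated on the last axis, and per-layer states interleaved by direction, which is what the `direct`-based indexing in the converter assumes.

```python
import mxnet as mx
from mxnet import gluon

seq_len, batch, input_size, hidden_size, num_layers = 10, 2, 32, 64, 1

# bidirectional=True doubles the per-layer weight/bias tensors, which is what
# the converter's `direct = 2 if bidirectional else 1` indexing accounts for.
layer = gluon.rnn.LSTM(hidden_size, num_layers, bidirectional=True)
layer.initialize()
layer.hybridize()

data = mx.nd.random.uniform(shape=(seq_len, batch, input_size))  # "TNC" layout
# An LSTM carries two state tensors (h, c), each shaped
# (num_layers * directions, batch, hidden_size).
states = layer.begin_state(batch_size=batch)

out, new_states = layer(data, states)
print(out.shape)                      # (seq_len, batch, 2 * hidden_size)
print([s.shape for s in new_states])  # two tensors of (num_layers * 2, batch, hidden_size)
```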