"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
See the License for the specific language governing permissions and
limitations under the License.
"""
-
-from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
+from extensions.ops.GRU import GRU
+from extensions.ops.LSTM import LSTM
+from extensions.ops.RNN import RNN
from mo.front.extractor import FrontExtractorOp
-from extensions.ops.lstm_sequence import LSTMSequence
+from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg
state_size = attrs.int('state_size', None)
bidirectional = attrs.bool('bidirectional', False)
num_layers = attrs.int('num_layers', 1)
+ layout = attrs.str('layout', 'TNC') # in MXNet RNN by default take data in
+ # format [seq_len, batch_size, inp_size]
node_attrs = {
- 'batch_dim': 1,
- 'sequence_dim': 0,
+ 'batch_dim': layout.index('N'),
+ 'sequence_dim': layout.index('T'),
'blobs_wrb': False,
'hidden_size': state_size,
'has_num_directions': bidirectional,
+ 'direction': 'bidirectional' if bidirectional else 'forward',
+ 'num_layers': num_layers,
'format': 'mxnet',
+ 'multilayers': num_layers != 1,
+ 'gate_order': None,
}
- if bidirectional:
- raise Error(
- "Operation RNN with bidirectional not supported. num_directions = 1 is supported only " +
- refer_to_faq_msg(86))
-
- if num_layers > 1:
- raise Error(
- "Operation RNN with num_layers more then one not supported. num_layers = 1 is supported only " +
- refer_to_faq_msg(86))
-
- if mode == 'lstm':
- LSTMSequence.update_node_stat(node, node_attrs)
+ if mode == 'rnn_tanh':
+ node_attrs['gate_order'] = [0]
+ node_attrs['activations'] = ['tanh']
+ RNN.update_node_stat(node, node_attrs)
+ elif mode == 'rnn_relu':
+ node_attrs['gate_order'] = [0]
+ node_attrs['activations'] = ['relu']
+ RNN.update_node_stat(node, node_attrs)
+ elif mode == 'gru':
+ node_attrs['gate_order'] = [1, 0, 2]
+ node_attrs['linear_before_reset'] = 1
+ GRU.update_node_stat(node, node_attrs)
+ elif mode == 'lstm':
+ node_attrs['gate_order'] = [1, 0, 2, 3]
+ LSTM.update_node_stat(node, node_attrs)
else:
raise Error(
- "Operation RNN with mode '{}' not supported. Please register RNN as custom op. " +
+ "Operation RNN with mode '{}' not supported." +
refer_to_faq_msg(86),
mode)
return __class__.enabled