Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / mxnet / RNN_ext.py
index 1ae8e31..9842838 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  See the License for the specific language governing permissions and
  limitations under the License.
 """
-
-from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
+from extensions.ops.GRU import GRU
+from extensions.ops.LSTM import LSTM
+from extensions.ops.RNN import RNN
 from mo.front.extractor import FrontExtractorOp
-from extensions.ops.lstm_sequence import LSTMSequence
+from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
 from mo.utils.error import Error
 from mo.utils.utils import refer_to_faq_msg
 
@@ -32,31 +33,40 @@ class RNNFrontExtractor(FrontExtractorOp):
         state_size = attrs.int('state_size', None)
         bidirectional = attrs.bool('bidirectional', False)
         num_layers = attrs.int('num_layers', 1)
+        layout = attrs.str('layout', 'TNC')  # in MXNet RNN by default take data in
+                                             # format [seq_len, batch_size, inp_size]
 
         node_attrs = {
-            'batch_dim': 1,
-            'sequence_dim': 0,
+            'batch_dim': layout.index('N'),
+            'sequence_dim': layout.index('T'),
             'blobs_wrb': False,
             'hidden_size': state_size,
             'has_num_directions': bidirectional,
+            'direction': 'bidirectional' if bidirectional else 'forward',
+            'num_layers': num_layers,
             'format': 'mxnet',
+            'multilayers': num_layers != 1,
+            'gate_order':  None,
         }
 
-        if bidirectional:
-            raise Error(
-                "Operation RNN with bidirectional not supported. num_directions = 1 is supported only " +
-                refer_to_faq_msg(86))
-
-        if num_layers > 1:
-            raise Error(
-                "Operation RNN with num_layers more then one not supported. num_layers = 1 is supported only " +
-                refer_to_faq_msg(86))
-
-        if mode == 'lstm':
-            LSTMSequence.update_node_stat(node, node_attrs)
+        if mode == 'rnn_tanh':
+            node_attrs['gate_order'] = [0]
+            node_attrs['activations'] = ['tanh']
+            RNN.update_node_stat(node, node_attrs)
+        elif mode == 'rnn_relu':
+            node_attrs['gate_order'] = [0]
+            node_attrs['activations'] = ['relu']
+            RNN.update_node_stat(node, node_attrs)
+        elif mode == 'gru':
+            node_attrs['gate_order'] = [1, 0, 2]
+            node_attrs['linear_before_reset'] = 1
+            GRU.update_node_stat(node, node_attrs)
+        elif mode == 'lstm':
+            node_attrs['gate_order'] = [1, 0, 2, 3]
+            LSTM.update_node_stat(node, node_attrs)
         else:
             raise Error(
-                "Operation RNN with mode '{}' not supported. Please register RNN as custom op. " +
+                "Operation RNN with mode '{}' not supported." +
                 refer_to_faq_msg(86),
                 mode)
         return __class__.enabled