"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
See the License for the specific language governing permissions and
limitations under the License.
"""
-
-
import numpy as np
-from extensions.ops.lstm_sequence import LSTMSequence
+from extensions.ops.LSTM import LSTM
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
-from mo.ops.op import Op
class LSTMFrontExtractor(FrontExtractorOp):
@staticmethod
def extract(node):
-
- def split_helper(node, index: int, direction: str):
- return Op._create_data_node(
- node.graph,
- name=node.name + '/SplittedBiLSTM/{}/'.format(direction),
- attrs={'value': node.value[index], 'shape': np.array(node.value[index].shape, dtype=np.int64)}
- )
+ activation_alpha = onnx_attr(node, 'activation_alpha', 'floats',
+ default=None, dst_type=lambda x: np.array(x, dtype=np.float32))
+ activation_beta = onnx_attr(node, 'activation_beta', 'floats',
+ default=None, dst_type=lambda x: np.array(x, dtype=np.float32))
+ activations = onnx_attr(node, 'activations', 'strings', default=None,
+ dst_type=lambda x: list(map(lambda s: s.decode(encoding="utf-8").lower(), list(x))))
+ clip = onnx_attr(node, 'clip', 'f', default=None)
+ input_forget = onnx_attr(node, 'input_forget', 'i', default=0)
attrs = {
- 'hidden_size': np.array(onnx_attr(node, 'hidden_size', 'i'), dtype=np.int64),
'batch_dim': 1,
'sequence_dim': 0,
'blobs_wrb': True,
'has_num_directions': True,
- 'direction': onnx_attr(node, 'direction', 's', b'forward').decode().lower(),
+ 'num_layers': 1,
'format': 'onnx',
- 'blob_bidirectional_split': lambda node: (
- split_helper(node, 0, 'forward'),
- split_helper(node, 1, 'reverse')
- )
+ 'multilayers': False,
+ 'gate_order': [2, 0, 3, 1], # iofc --> fico
+
+ # ONNX attrs
+ 'activation_alpha': activation_alpha,
+ 'activation_beta': activation_beta,
+ 'activations': activations,
+ 'clip': clip,
+ 'direction': onnx_attr(node, 'direction', 's', b'forward').decode().lower(),
+ 'hidden_size': np.array(onnx_attr(node, 'hidden_size', 'i'), dtype=np.int64),
+ 'input_forget': input_forget,
}
- LSTMSequence.update_node_stat(node, attrs)
+ LSTM.update_node_stat(node, attrs)
return __class__.enabled