ab17ee131dd5ccfed35302e1174e9d9ffa26222a
[platform/upstream/dldt.git] / model-optimizer / mo / front / tf / extractor.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import numpy as np
18
19 from mo.front.common.partial_infer.split import tf_split_infer
20 from mo.front.tf.extractors.concat import tf_concat_ext
21 from mo.front.tf.extractors.const import tf_const_ext
22 from mo.front.tf.extractors.eltwise import make_tf_eltwise
23 from mo.front.tf.extractors.fused_bn import tf_fused_bn_extractor
24 from mo.front.tf.extractors.lrn import tf_lrn_ext
25 from mo.front.tf.extractors.matmul import tf_matmul_ext, tf_batchmatmul_ext
26 from mo.front.tf.extractors.native_tf import native_tf_node_extractor
27 from mo.front.tf.extractors.pack import tf_pack_ext
28 from mo.front.tf.extractors.random_uniform import tf_random_uniform_ext
29 from mo.front.tf.extractors.space_to_batch import tf_space_to_batch_ext, tf_batch_to_space_ext
30 from mo.front.tf.extractors.split import tf_split_ext
31 from mo.front.tf.extractors.unpack import tf_unpack_ext
32 from mo.front.tf.extractors.utils import get_tf_node_port
33 from mo.graph.graph import Node
34
35
36 def get_tf_edges(node: Node):
37     """
38     By TF/NX node find all inputs and return list of all edges.
39     Edge direction represents data flow (from source op to this node).
40     So the resulting list contains all input edges for a given node.
41     Edge attributes: 'in' is index of input port for a given node, 'out' is an index
42     of output port of some other node that produces input data for this node.
43     """
44     edge_list = []
45     for in_port, src_node_id in enumerate(node.pb.input):
46         src_node, src_port = get_tf_node_port(src_node_id)
47         cf_flag = False
48         if src_node[0] == '^':
49             src_node = src_node[1:]
50             cf_flag = True
51         edge = (src_node, node.id, {
52             'in': in_port,
53             'out': src_port,
54             'fw_tensor_debug_info': [(src_node_id, src_port)],  # debug anchor for a framework tensor name and port
55             'in_attrs': ['in', 'control_flow_edge', 'permutation'],
56             'out_attrs': ['out', 'permutation'],
57             'data_attrs': ['fw_tensor_debug_info'],
58             'control_flow_edge': cf_flag
59         })
60         edge_list.append(edge)
61     return edge_list
62
63
64 def node_pb_arg(pb_extractor: callable):
65     return lambda node: pb_extractor(node.pb)
66
67
68 tf_op_extractors = {
69     'TFCustomSubgraphCall': node_pb_arg(lambda pb: None),
70     'LRN': node_pb_arg(tf_lrn_ext),
71     'Split': node_pb_arg(lambda pb: tf_split_ext(pb, tf_split_infer)),
72     'FusedBatchNorm': node_pb_arg(tf_fused_bn_extractor),
73     'ConcatV2': node_pb_arg(tf_concat_ext),
74     'MatMul': node_pb_arg(tf_matmul_ext),
75     'BatchMatMul': node_pb_arg(tf_batchmatmul_ext),
76     'BatchMatMulV2': node_pb_arg(tf_batchmatmul_ext),
77     'Pack': node_pb_arg(tf_pack_ext),
78     'Unpack': node_pb_arg(tf_unpack_ext),
79     'Const': node_pb_arg(tf_const_ext),
80     'Identity': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
81     'RandomUniform': node_pb_arg(tf_random_uniform_ext),
82     'SpaceToBatchND': node_pb_arg(tf_space_to_batch_ext),
83     'BatchToSpaceND': node_pb_arg(tf_batch_to_space_ext),
84     'ReadVariableOp': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
85     'PlaceholderWithDefault': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True}))
86 }
87
88
89 def common_tf_fields(node: Node):
90     return {
91         'kind': 'op',
92         'name': node.pb.name,
93         'op': node.pb.op,
94         'precision': 'FP32'  # TODO use real precision derived from the model
95     }
96
97
98 def tf_op_extractor(node: Node, lowered_keys_map: dict):
99     # all required attributes for the 'TFCustomSubgraphCall' are set during their initialization
100     if (node.has('op') and node.op == 'TFCustomSubgraphCall') or (not node.has_valid('pb')):
101         return True, node.graph.node[node.id]
102
103     result = common_tf_fields(node)
104     node.graph.node[node.id].update(result)
105     supported = False
106     op = result['op'].lower()
107     if op in lowered_keys_map:
108         op = lowered_keys_map[op]
109         assert op in tf_op_extractors
110         attrs = tf_op_extractors[op](node)
111         if attrs:
112             result.update(attrs)
113             supported = True
114     new_attrs = native_tf_node_extractor(node.pb)
115     new_attrs.update(result)
116     result = new_attrs
117     return supported, result