Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / tf / extractor.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import numpy as np
18
19 from mo.front.common.partial_infer.split import tf_split_infer
20 from mo.front.tf.extractors.bias_add import tf_bias_add_ext
21 from mo.front.tf.extractors.concat import tf_concat_ext
22 from mo.front.tf.extractors.const import tf_const_ext
23 from mo.front.tf.extractors.eltwise import make_tf_eltwise
24 from mo.front.tf.extractors.expand_dims import tf_expand_dims_ext
25 from mo.front.tf.extractors.fused_bn import tf_fused_bn_extractor
26 from mo.front.tf.extractors.lrn import tf_lrn_ext
27 from mo.front.tf.extractors.matmul import tf_matmul_ext
28 from mo.front.tf.extractors.mean import tf_mean_ext
29 from mo.front.tf.extractors.native_tf import native_tf_node_extractor
30 from mo.front.tf.extractors.pack import tf_pack_ext
31 from mo.front.tf.extractors.placeholder import tf_placeholder_ext
32 from mo.front.tf.extractors.prod import tf_reduce_prod_ext
33 from mo.front.tf.extractors.random_uniform import tf_random_uniform_ext
34 from mo.front.tf.extractors.range import tf_range_ext
35 from mo.front.tf.extractors.reshape import tf_reshape_ext
36 from mo.front.tf.extractors.space_to_batch import tf_space_to_batch_ext, tf_batch_to_space_ext
37 from mo.front.tf.extractors.split import tf_split_ext
38 from mo.front.tf.extractors.squeeze import tf_squeeze_ext
39 from mo.front.tf.extractors.transpose import tf_transpose_ext
40 from mo.front.tf.extractors.unpack import tf_unpack_ext
41 from mo.front.tf.extractors.utils import get_tf_node_port
42 from mo.graph.graph import Node
43
44
45 def get_tf_edges(node: Node):
46     """
47     By TF/NX node find all inputs and return list of all edges.
48     Edge direction represents data flow (from source op to this node).
49     So the resulting list contains all input edges for a given node.
50     Edge attributes: 'in' is index of input port for a given node, 'out' is an index
51     of output port of some other node that produces input data for this node.
52     """
53     edge_list = []
54     for in_port, src_node_id in enumerate(node.pb.input):
55         src_node, src_port = get_tf_node_port(src_node_id)
56         cf_flag = False
57         if src_node[0] == '^':
58             src_node = src_node[1:]
59             cf_flag = True
60         edge = (src_node, node.id, {
61             'in': in_port,
62             'out': src_port,
63             'fw_tensor_debug_info': [(src_node_id, src_port)],  # debug anchor for a framework tensor name and port
64             'in_attrs': ['in', 'control_flow_edge', 'permutation'],
65             'out_attrs': ['out', 'permutation'],
66             'data_attrs': ['fw_tensor_debug_info'],
67             'control_flow_edge': cf_flag
68         })
69         edge_list.append(edge)
70     return edge_list
71
72
73 def node_pb_arg(pb_extractor: callable):
74     return lambda node: pb_extractor(node.pb)
75
76
77 tf_op_extractors = {
78     'TFCustomSubgraphCall': node_pb_arg(lambda pb: None),
79     'Transpose': node_pb_arg(tf_transpose_ext),
80     'LRN': node_pb_arg(tf_lrn_ext),
81     'Split': node_pb_arg(lambda pb: tf_split_ext(pb, tf_split_infer)),
82     'FusedBatchNorm': node_pb_arg(tf_fused_bn_extractor),
83     'Relu6': node_pb_arg(
84         make_tf_eltwise(lambda a: np.maximum(0, np.minimum(a, 6)), attrs={'type': 'Clamp', 'min': 0, 'max': 6})),
85     'ExpandDims': node_pb_arg(tf_expand_dims_ext),
86     'ConcatV2': node_pb_arg(tf_concat_ext),
87     'MatMul': node_pb_arg(tf_matmul_ext),
88     'Pack': node_pb_arg(tf_pack_ext),
89     'Unpack': node_pb_arg(tf_unpack_ext),
90     'Prod': node_pb_arg(tf_reduce_prod_ext),
91     'Const': node_pb_arg(tf_const_ext),
92     'Placeholder': node_pb_arg(tf_placeholder_ext),
93     'Identity': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
94     'Add': node_pb_arg(
95         make_tf_eltwise(lambda a, b: a + b, attrs={'type': 'Eltwise', 'operation': 'sum', 'can_be_bias': True})),
96     'Mul': node_pb_arg(make_tf_eltwise(lambda a, b: a * b, attrs={'type': 'Eltwise', 'operation': 'mul'})),
97     'Rsqrt': node_pb_arg(make_tf_eltwise(lambda v: np.reciprocal(np.sqrt(v)),
98                                          attrs={'type': 'Power', 'power': -0.5, 'scale': 1, 'shift': 0})),
99     'Neg': node_pb_arg(make_tf_eltwise(lambda v: -v, attrs={'type': 'Power', 'power': 1, 'scale': -1, 'shift': 0})),
100     'Sub': node_pb_arg(make_tf_eltwise(lambda a, b: a - b)),
101     'RealDiv': node_pb_arg(make_tf_eltwise(lambda a, b: a / b, attrs={'op': 'Div'})),
102     'Relu': node_pb_arg(make_tf_eltwise(lambda v: np.maximum(0, v), attrs={'type': 'ReLU'})),  # 0 is an integer
103     'RandomUniform': node_pb_arg(tf_random_uniform_ext),
104     'Mean': node_pb_arg(tf_mean_ext),
105     'BiasAdd': node_pb_arg(tf_bias_add_ext),
106     'Reshape': node_pb_arg(tf_reshape_ext),
107     'Squeeze': node_pb_arg(tf_squeeze_ext),
108     'SpaceToBatchND': node_pb_arg(tf_space_to_batch_ext),
109     'BatchToSpaceND': node_pb_arg(tf_batch_to_space_ext),
110     'Square': node_pb_arg(make_tf_eltwise(lambda a: a * a)),
111     'Minimum': node_pb_arg(make_tf_eltwise(lambda a, b: np.minimum(a, b))),  # can use clamp if one argument is const
112     'Maximum': node_pb_arg(make_tf_eltwise(lambda a, b: np.maximum(a, b), attrs={'type': 'Eltwise',
113                                                                                  'operation': 'max'})),
114     'ReadVariableOp': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
115     'PlaceholderWithDefault': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True}))
116 }
117
118
119 def common_tf_fields(node: Node):
120     return {
121         'kind': 'op',
122         'name': node.pb.name,
123         'op': node.pb.op,
124         'precision': 'FP32'  # TODO use real precision derived from the model
125     }
126
127
128 def tf_op_extractor(node: Node, lowered_keys_map: dict):
129     # all required attributes for the 'TFCustomSubgraphCall' are set during their initialization
130     if (node.has('op') and node.op == 'TFCustomSubgraphCall') or (not node.has_valid('pb')):
131         return True, node.graph.node[node.id]
132
133     result = common_tf_fields(node)
134     node.graph.node[node.id].update(result)
135     supported = False
136     op = result['op'].lower()
137     if op in lowered_keys_map:
138         op = lowered_keys_map[op]
139         assert op in tf_op_extractors
140         attrs = tf_op_extractors[op](node)
141         if attrs:
142             result.update(attrs)
143             supported = True
144     new_attrs = native_tf_node_extractor(node.pb)
145     new_attrs.update(result)
146     result = new_attrs
147     return supported, result