2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
import numpy as np

from mo.front.common.partial_infer.split import tf_split_infer
from mo.front.tf.extractors.bias_add import tf_bias_add_ext
from mo.front.tf.extractors.concat import tf_concat_ext
from mo.front.tf.extractors.const import tf_const_ext
from mo.front.tf.extractors.eltwise import make_tf_eltwise
from mo.front.tf.extractors.expand_dims import tf_expand_dims_ext
from mo.front.tf.extractors.fused_bn import tf_fused_bn_extractor
from mo.front.tf.extractors.lrn import tf_lrn_ext
from mo.front.tf.extractors.matmul import tf_matmul_ext
from mo.front.tf.extractors.mean import tf_mean_ext
from mo.front.tf.extractors.native_tf import native_tf_node_extractor
from mo.front.tf.extractors.pack import tf_pack_ext
from mo.front.tf.extractors.placeholder import tf_placeholder_ext
from mo.front.tf.extractors.prod import tf_reduce_prod_ext
from mo.front.tf.extractors.random_uniform import tf_random_uniform_ext
from mo.front.tf.extractors.range import tf_range_ext
from mo.front.tf.extractors.reshape import tf_reshape_ext
from mo.front.tf.extractors.space_to_batch import tf_space_to_batch_ext, tf_batch_to_space_ext
from mo.front.tf.extractors.split import tf_split_ext
from mo.front.tf.extractors.squeeze import tf_squeeze_ext
from mo.front.tf.extractors.transpose import tf_transpose_ext
from mo.front.tf.extractors.unpack import tf_unpack_ext
from mo.front.tf.extractors.utils import get_tf_node_port
from mo.graph.graph import Node
def get_tf_edges(node: Node):
    """
    By TF/NX node find all inputs and return list of all edges.
    Edge direction represents data flow (from source op to this node).
    So the resulting list contains all input edges for a given node.
    Edge attributes: 'in' is index of input port for a given node, 'out' is an index
    of output port of some other node that produces input data for this node.

    :param node: graph node; its ``pb.input`` is iterated to discover the producers.
    :return: list of ``(src_node, node.id, attrs)`` tuples, one per entry of ``node.pb.input``.
    """
    edge_list = []
    for in_port, src_node_id in enumerate(node.pb.input):
        src_node, src_port = get_tf_node_port(src_node_id)

        # A leading '^' on the producer name marks a control dependency rather than a data input;
        # the marker is stripped so that the edge points at the real producer node.
        # `startswith` is used instead of indexing so an empty name cannot raise IndexError.
        # NOTE(review): the flag assignment was missing from the reviewed text - confirm.
        cf_flag = src_node.startswith('^')
        if cf_flag:
            src_node = src_node[1:]

        edge = (src_node, node.id, {
            # NOTE(review): 'in'/'out' restored from the docstring contract above - confirm.
            'in': in_port,
            'out': src_port,
            'fw_tensor_debug_info': [(src_node_id, src_port)],  # debug anchor for a framework tensor name and port
            'in_attrs': ['in', 'control_flow_edge', 'permutation'],
            'out_attrs': ['out', 'permutation'],
            'data_attrs': ['fw_tensor_debug_info'],
            'control_flow_edge': cf_flag
        })
        edge_list.append(edge)

    return edge_list
def node_pb_arg(pb_extractor: callable):
    """
    Adapt an extractor that works on a protobuf message so it can be called with a graph node.

    :param pb_extractor: callable taking a single argument (the node's ``pb`` attribute).
    :return: callable taking a node and returning ``pb_extractor(node.pb)``.
    """
    def extract_from_node(node):
        # Unwrap the protobuf message and delegate; the result is passed through untouched.
        return pb_extractor(node.pb)

    return extract_from_node
# Dispatch table: TensorFlow op type -> extractor taking a graph node.
# Every entry is wrapped with node_pb_arg, so the underlying extractor receives node.pb.
# tf_op_extractor looks entries up by the key obtained from its lowered_keys_map argument.
tf_op_extractors = {
    'TFCustomSubgraphCall': node_pb_arg(lambda pb: None),
    'Transpose': node_pb_arg(tf_transpose_ext),
    'LRN': node_pb_arg(tf_lrn_ext),
    'Split': node_pb_arg(lambda pb: tf_split_ext(pb, tf_split_infer)),
    'FusedBatchNorm': node_pb_arg(tf_fused_bn_extractor),
    # NOTE(review): the key of this entry was missing from the reviewed text; 'Relu6' is inferred
    # from the clamp-to-[0, 6] function and attributes - confirm.
    'Relu6': node_pb_arg(
        make_tf_eltwise(lambda a: np.maximum(0, np.minimum(a, 6)), attrs={'type': 'Clamp', 'min': 0, 'max': 6})),
    'ExpandDims': node_pb_arg(tf_expand_dims_ext),
    'ConcatV2': node_pb_arg(tf_concat_ext),
    'MatMul': node_pb_arg(tf_matmul_ext),
    'Pack': node_pb_arg(tf_pack_ext),
    'Unpack': node_pb_arg(tf_unpack_ext),
    'Prod': node_pb_arg(tf_reduce_prod_ext),
    'Const': node_pb_arg(tf_const_ext),
    'Placeholder': node_pb_arg(tf_placeholder_ext),
    'Identity': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
    # NOTE(review): the key of this entry was missing from the reviewed text; 'Add' is inferred
    # from the element-wise sum function and attributes - confirm.
    'Add': node_pb_arg(
        make_tf_eltwise(lambda a, b: a + b, attrs={'type': 'Eltwise', 'operation': 'sum', 'can_be_bias': True})),
    'Mul': node_pb_arg(make_tf_eltwise(lambda a, b: a * b, attrs={'type': 'Eltwise', 'operation': 'mul'})),
    'Rsqrt': node_pb_arg(make_tf_eltwise(lambda v: np.reciprocal(np.sqrt(v)),
                                         attrs={'type': 'Power', 'power': -0.5, 'scale': 1, 'shift': 0})),
    'Neg': node_pb_arg(make_tf_eltwise(lambda v: -v, attrs={'type': 'Power', 'power': 1, 'scale': -1, 'shift': 0})),
    'Sub': node_pb_arg(make_tf_eltwise(lambda a, b: a - b)),
    'RealDiv': node_pb_arg(make_tf_eltwise(lambda a, b: a / b, attrs={'op': 'Div'})),
    'Relu': node_pb_arg(make_tf_eltwise(lambda v: np.maximum(0, v), attrs={'type': 'ReLU'})),  # 0 is an integer
    'RandomUniform': node_pb_arg(tf_random_uniform_ext),
    'Mean': node_pb_arg(tf_mean_ext),
    'BiasAdd': node_pb_arg(tf_bias_add_ext),
    'Reshape': node_pb_arg(tf_reshape_ext),
    'Squeeze': node_pb_arg(tf_squeeze_ext),
    'SpaceToBatchND': node_pb_arg(tf_space_to_batch_ext),
    'BatchToSpaceND': node_pb_arg(tf_batch_to_space_ext),
    'Square': node_pb_arg(make_tf_eltwise(lambda a: a * a)),
    'Minimum': node_pb_arg(make_tf_eltwise(lambda a, b: np.minimum(a, b))),  # can use clamp if one argument is const
    'Maximum': node_pb_arg(make_tf_eltwise(lambda a, b: np.maximum(a, b), attrs={'type': 'Eltwise',
                                                                                 'operation': 'max'})),
    'ReadVariableOp': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
    'PlaceholderWithDefault': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True}))
}
def common_tf_fields(node: Node):
    """
    Build the attributes that every TensorFlow operation node receives.

    :param node: graph node whose ``pb`` attribute exposes ``name`` and ``op``.
    :return: new dict with the node name, the op type and a fixed 'FP32' precision.
    """
    return {
        # NOTE(review): this entry was missing from the reviewed text; 'kind': 'op' is assumed
        # from the surrounding graph conventions - confirm.
        'kind': 'op',
        'name': node.pb.name,
        # tf_op_extractor reads result['op'], so the op type has to be part of the common fields.
        'op': node.pb.op,
        'precision': 'FP32'  # TODO use real precision derived from the model
    }
def tf_op_extractor(node: Node, lowered_keys_map: dict):
    """
    Extract attributes for a TensorFlow node.

    :param node: graph node to process.
    :param lowered_keys_map: maps a lower-cased op type to the matching key of ``tf_op_extractors``.
    :return: ``(supported, attrs)`` pair. ``supported`` is True when a registered extractor produced
        attributes (or when the node is returned as is, see below). ``attrs`` is the resulting
        attribute dict.
    """
    # all required attributes for the 'TFCustomSubgraphCall' are set during their initialization
    if (node.has('op') and node.op == 'TFCustomSubgraphCall') or (not node.has_valid('pb')):
        return True, node.graph.node[node.id]

    result = common_tf_fields(node)
    node.graph.node[node.id].update(result)

    # NOTE(review): the initialisation of `supported` and the handling of `attrs` were missing from
    # the reviewed text; restored so that the returned flag is always bound - confirm.
    supported = False
    op = result['op'].lower()
    if op in lowered_keys_map:
        op = lowered_keys_map[op]
        assert op in tf_op_extractors
        attrs = tf_op_extractors[op](node)
        # An extractor may return None (the 'TFCustomSubgraphCall' entry does), which leaves the
        # node unsupported.
        if attrs:
            result.update(attrs)
            supported = True

    # Natively extracted attributes act as defaults; values collected above take precedence.
    new_attrs = native_tf_node_extractor(node.pb)
    new_attrs.update(result)

    return supported, new_attrs