"""
Copyright (c) 2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import copy
import logging as log

import numpy as np

from mo.front.common.layout import nhwc_to_nchw_permute
from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import update_ie_fields
from mo.graph.graph import Graph
from mo.graph.graph import Node, add_opoutput
from mo.middle.replacement import MiddleReplacementPattern
# Names of the auxiliary Transpose permutation constants inserted into the TF sub-graph; the methods of
# CustomSubgraphCall create tf.constant nodes with exactly these names and refer to them by name later.
nchw_to_nhwc_constant_name = 'IE_NCHW_TO_NHWC'
nhwc_to_nchw_constant_name = 'IE_NHWC_TO_NCHW'
class CustomSubgraphCall(MiddleReplacementPattern):
    """
    Prepares 'TFCustomSubgraphCall' nodes for Inference Engine:
    1. Changes placeholder shapes from NHWC to NCHW and inserts Transpose operations into the TF sub-graph where
       a layout conversion is required.
    2. Inserts Transpose operations for 4D output tensors of the sub-graph call (NCHW -> NHWC).
    3. Adds Reshape operations around the sub-graph call so all its input/output tensors are 4D, because the IE
       TF call layer supports 4D tensors only.

    The pass is applicable to TF models only.
    """
    graph_condition = [lambda graph: graph.graph['fw'] == 'tf']

    def run_after(self):
        from extensions.middle.pass_separator import PreMiddleStart
        return [PreMiddleStart]

    def run_before(self):
        from extensions.middle.pass_separator import MiddleStart
        return [MiddleStart]

    @staticmethod
    def update_placeholders(graph: Graph):
        """
        Iterates over all nodes of the graph, finds all TF sub-graph call operations and updates placeholders
        shapes and adds transpose operations if necessary.
        :param graph: graph to operate on
        :return: None
        """
        for node_name in graph.nodes():
            node = Node(graph, node_name)
            if node.kind == 'op' and node.has_valid('op') and node.op == 'TFCustomSubgraphCall':
                CustomSubgraphCall.update_placeholder_shape_and_add_transpose(node)

    @staticmethod
    def update_placeholder_shape_and_add_transpose(node: Node):
        """
        The function changes placeholders shapes from NHWC to NCHW format and adds transpose operations if needed.
        :param node: node to operate on.
        :return: None
        """
        import tensorflow as tf
        from mo.front.common.layout import convert_shape, nhwc_to_nchw_permute, nchw_to_nhwc_permute
        from mo.front.tf.extractors.utils import tf_tensor_shape
        from mo.front.tf.partial_infer.tf import add_node_def_to_subgraph, update_input_in_pbs

        tf.reset_default_graph()

        inputs_replacements = list()

        # transpose permutation constants shared by all inserted Transpose nodes
        nchw_to_nhwc_constant = tf.constant(nchw_to_nhwc_permute, dtype=tf.int32, name=nchw_to_nhwc_constant_name)
        nhwc_to_nchw_constant = tf.constant(nhwc_to_nchw_permute, dtype=tf.int32, name=nhwc_to_nchw_constant_name)

        for placeholder_name in node['input_nodes_names']:
            # dummy node which we can refer to as input in the transpose for the output node;
            # dummy node should be unique for each placeholder
            dummy_node = tf.constant(value=[[[[1]]]], dtype=tf.float32, name='random_dummy_name_' + placeholder_name)

            placeholder = node['pbs'][placeholder_name]
            cur_shape = tf_tensor_shape(placeholder.attr['shape'].shape)
            if len(cur_shape) == 4:  # TODO think about better check that transpose is required
                nchw_shape = convert_shape(cur_shape, nhwc_to_nchw_permute)
                for ind in range(len(cur_shape)):
                    placeholder.attr['shape'].shape.dim[ind].size = nchw_shape[ind]
                transpose_name = placeholder.name + '_transpose'
                transpose = tf.transpose(dummy_node, nchw_to_nhwc_constant, transpose_name)  # NCHW -> NHWC

                # add transpose operations to GraphDef after placeholders
                add_node_def_to_subgraph(node, transpose.op.node_def, transpose_name, len(node['input_nodes_names']))
                inputs_replacements.append((placeholder.name, transpose_name))
                inputs_replacements.append((dummy_node.name, placeholder.name))
                node['real_input_dims'].append(nchw_shape)
            else:
                node['real_input_dims'].append(cur_shape)
        add_node_def_to_subgraph(node, nchw_to_nhwc_constant.op.node_def)
        add_node_def_to_subgraph(node, nhwc_to_nchw_constant.op.node_def)

        # update initial input names to a transposed ones
        for old_input_tensor_name, new_name in inputs_replacements:
            update_input_in_pbs(node, old_input_tensor_name, new_name)

    @staticmethod
    def add_output_nodes_transposes(graph: Graph):
        """
        Iterates over all nodes of the graph, finds all TF sub-graph call operations and adds Transpose operations
        to the output nodes if they are 4D to convert output from NHWC to NCHW.
        :param graph: graph to operate on
        :return: None
        """
        for node_name in graph.nodes():
            node = Node(graph, node_name)
            if node.kind == 'op' and node.has_valid('op') and node.op == 'TFCustomSubgraphCall':
                CustomSubgraphCall.add_sub_graph_call_output_tensors_transposes(node)

    @staticmethod
    def make_shape_4d(shape: np.array):
        """
        Create a 4D shape from a 1D, 2D or 3D one by adding new dimensions of size 1.
        :param shape: shape to extend.
        :return: 4D shape (numpy array of int64).
        """
        new_shape = int64_array(shape)
        old_shape_len = len(shape)

        # TODO think about proper way to add additional dimensions considering layout
        for _ in range(4 - old_shape_len):
            # if the shape is 0D or 1D then we should add additional dimensions to batch dimension
            if len(new_shape) <= 1:
                new_shape = np.insert(new_shape, 0, 1)
            else:
                new_shape = np.insert(new_shape, 1, 1)
        return new_shape

    @staticmethod
    def add_reshape_before_op_node(graph: Graph, data_node_name: str, op_node_name: str, edge_attrs: dict):
        """
        Adds reshape operation which expands dimension of the specified data tensor to 4D.
        :param graph: graph to operate on.
        :param data_node_name: the name of the data node to be reshaped to 4D tensor.
        :param op_node_name: name of the TFCustomSubgraphCall node which produces the tensor.
        :param edge_attrs: edge attributes which should be preserved.
        :return: None
        """
        data_node = Node(graph, data_node_name)

        graph.remove_edge(data_node_name, op_node_name)

        assert data_node['shape'] is not None

        new_shape = CustomSubgraphCall.make_shape_4d(data_node['shape'])

        # reshape shape data node
        reshape_shape_data_node_name = graph.unique_id("Reshape_shape_")
        graph.add_node(reshape_shape_data_node_name, kind='data', precision="FP32", name=reshape_shape_data_node_name,
                       value=new_shape, shape=[1])

        # reshape operation node
        reshape_node_name = graph.unique_id("Reshape_")
        graph.add_node(reshape_node_name, kind='op', precision="FP32", type='Reshape', name=reshape_node_name,
                       op='Reshape',
                       data_type=data_node['data_type'])
        update_ie_fields(graph.node[reshape_node_name])

        # reshaped data node
        reshaped_value = None
        if data_node['value'] is not None:
            reshaped_value = np.reshape(data_node['value'], new_shape)
        reshaped_data_node_name = graph.unique_id("reshaped_data_")
        graph.add_node(reshaped_data_node_name, kind='data', precision="FP32", name=reshaped_data_node_name,
                       shape=new_shape, value=reshaped_value, nchw_layout=True)

        graph.add_edges_from([
            (data_node_name, reshape_node_name, {'in': 0}),
            (reshape_shape_data_node_name, reshape_node_name, {'in': 1}),
            (reshape_node_name, reshaped_data_node_name, {'out': 0}),
            (reshaped_data_node_name, op_node_name, edge_attrs)
        ])

    @staticmethod
    def add_reshape_after_data_node(graph: Graph, data_node_name: str):
        """
        Adds reshape operation which changes shape of the tensor produced by TFSubgraphCall from 4D to the real
        dimension of the tensor. The data_node_name node contains real dimensions of the tensor but they will be
        changed in the add_reshapes_for_tf_subgraph_calls function to 4D because the IE TF call layer supports
        output in 4D only.
        :param graph: graph to operate on.
        :param data_node_name: name of the data node to be reshaped to correct dimensions.
        :return: None
        """
        data_node = Node(graph, data_node_name)

        # if the data node was previously marked as output then we need to mark as output the new reshaped data node
        is_out_node = False
        if len(data_node.out_nodes()) == 1 and data_node.out_node().has('op') and data_node.out_node().op == 'OpOutput':
            is_out_node = True
            graph.remove_node(data_node.out_node().id)

        # save old consumers nodes with edge attributes
        old_consumer_nodes_with_attrs = list()
        for index, out_op in enumerate(data_node.out_nodes()):
            edge_attrs = graph.get_edge_data(data_node_name, out_op.name)[0]
            old_consumer_nodes_with_attrs.append((out_op.name, edge_attrs))

        # remove old consumers from the data node
        for out_op in list(data_node.out_nodes()):
            graph.remove_edge(data_node_name, out_op.name)

        # reshape operation node
        reshape_node_name = graph.unique_id("Reshape_")
        graph.add_node(reshape_node_name, kind='op', precision="FP32", type='Reshape', name=reshape_node_name,
                       op='Reshape',
                       data_type=data_node['data_type'])
        update_ie_fields(graph.node[reshape_node_name])

        # reshape shape data node
        reshape_shape_data_node_name = graph.unique_id("Reshape_shape_")
        graph.add_node(reshape_shape_data_node_name, kind='data', precision="FP32", name=reshape_shape_data_node_name,
                       value=np.array(data_node['shape']), shape=[1])

        # reshaped data node
        reshaped_value = None
        if data_node['value'] is not None:
            reshaped_value = np.array(data_node['value'])
        reshaped_data_node_name = graph.unique_id("reshaped_data_")
        graph.add_node(reshaped_data_node_name, kind='data', precision="FP32", name=reshaped_data_node_name,
                       shape=np.array(data_node['shape']), value=reshaped_value, nchw_layout=True)

        if is_out_node:
            add_opoutput(graph, reshaped_data_node_name, 0, False)

        graph.add_edges_from([
            (data_node_name, reshape_node_name, {'in': 0}),
            (reshape_shape_data_node_name, reshape_node_name, {'in': 1}),
            (reshape_node_name, reshaped_data_node_name, {'out': 0}),
        ])

        # reconnect old consumers to the reshaped data node preserving edge attributes
        for out_node_name, edge_attrs in old_consumer_nodes_with_attrs:
            graph.add_edges_from([
                (reshaped_data_node_name, out_node_name, edge_attrs)
            ])

    @staticmethod
    def add_reshapes_for_tf_subgraph_calls(graph: Graph):
        """
        Input and output tensors of the TFCustomSubgraphCall must be 4D because the IE layer accepts and produces
        only 4D tensors. This function adds reshape operations where it is necessary.
        :param graph: graph to operate on.
        :return: None
        """
        for src_node_name, dst_node_name, edge_attrs in list(graph.edges(data=True)):
            src_node = Node(graph, src_node_name)
            dst_node = Node(graph, dst_node_name)
            if dst_node.kind == 'op' and dst_node.has_valid('type') and dst_node.type == 'TFCustomSubgraphCall' and \
                    src_node.has_valid('shape') and len(src_node.shape) != 4:
                log.info("There is an data tensor of shape '{}' which goes into '{}' node".format(
                    src_node.shape, dst_node.type))
                CustomSubgraphCall.add_reshape_before_op_node(graph, src_node_name, dst_node_name, edge_attrs)

        for node_name in list(graph.nodes()):
            node = Node(graph, node_name)
            if node['kind'] == 'op' and node.has_and_set('type') and node.type == 'TFCustomSubgraphCall':
                for index, data_node in node.out_nodes().items():
                    real_dims_count = len(data_node.shape)
                    if real_dims_count != 4:
                        log.info(
                            "There is an data tensor of shape '{}' with real dims count '{}' which goes out of '{}' "
                            "node".format(data_node.shape, real_dims_count, node.name))
                        CustomSubgraphCall.add_reshape_after_data_node(graph, data_node.id)

                        # need to update shape of the op so IE generates XML with 4D tensors
                        out_shape = CustomSubgraphCall.make_shape_4d(data_node['shape'])

                        data_node['shape'] = out_shape

    @staticmethod
    def add_sub_graph_call_output_tensors_transposes(node: Node):
        """
        Adds transpose operations to the output nodes if they are 4D to change layout from NCHW to NHWC.
        :param node: the node to add transposes to the output nodes to.
        :return: None
        """
        import tensorflow as tf
        from mo.front.tf.partial_infer.tf import get_subgraph_output_tensors, add_node_def_to_subgraph
        _, output_tensors = get_subgraph_output_tensors(node)

        # transpose permutation constant
        nhwc_to_nchw_constant = tf.constant(nhwc_to_nchw_permute, dtype=tf.int32, name=nhwc_to_nchw_constant_name)

        # dummy node which we can refer to as input in the transpose for the output node
        dummy_node = tf.constant(value=[[[[1]]]], dtype=tf.float32, name='random_dummy_name')

        new_out_tensor_names = list()
        for out_tensor_name in node['output_tensors_names']:
            out_name, out_port = out_tensor_name.split(':')
            # TODO think about better check whether transpose is required
            if len(output_tensors[int(out_port)].shape) == 4:
                out_transpose_name = out_name + '_port_' + out_port + '_transpose'
                transpose = tf.transpose(dummy_node, nhwc_to_nchw_constant, name=out_transpose_name)

                # starting from TF 1.8 it is not possible to modify the "node_def" of the "tf.op", so we create a copy,
                # update it and use further
                new_input_names = transpose.op.node_def.input[:]
                new_input_names[0] = out_tensor_name
                new_node_def = copy.deepcopy(transpose.op.node_def)
                new_node_def.input[:] = new_input_names
                add_node_def_to_subgraph(node, new_node_def, position=len(node['nodes_order']))
                new_out_tensor_names.append(out_transpose_name)
            else:
                new_out_tensor_names.append(out_tensor_name)

        # update output tensor names with transposes operations
        node['output_tensors_names'] = new_out_tensor_names

    def find_and_replace_pattern(self, graph: Graph):
        CustomSubgraphCall.update_placeholders(graph)
        CustomSubgraphCall.add_output_nodes_transposes(graph)
        CustomSubgraphCall.add_reshapes_for_tf_subgraph_calls(graph)