Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / CustomSubgraphCall.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import copy
18 import logging as log
19
20 import numpy as np
21
22 from mo.front.common.layout import nhwc_to_nchw_permute
23 from mo.front.common.partial_infer.utils import int64_array
24 from mo.front.extractor import update_ie_fields
25 from mo.graph.graph import Graph
26 from mo.graph.graph import Node, add_opoutput
27 from mo.middle.replacement import MiddleReplacementPattern
28
29 nchw_to_nhwc_constant_name = 'IE_NCHW_TO_NHWC'
30 nhwc_to_nchw_constant_name = 'IE_NHWC_TO_NCHW'
31
32
33 class CustomSubgraphCall(MiddleReplacementPattern):
34     enabled = True
35     force_clean_up = True
36     graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
37
38     def run_after(self):
39         from extensions.middle.pass_separator import PreMiddleStart
40         return [PreMiddleStart]
41
42     def run_before(self):
43         from extensions.middle.pass_separator import MiddleStart
44         return [MiddleStart]
45
46     @staticmethod
47     def update_placeholders(graph: Graph):
48         """
49         Iterates over all nodes of the graph, find all TF sub-graph call operations and updates placeholders shapes and adds
50         transpose operation if necessary.
51         :param graph: graph to operate on
52         :return: None
53         """
54         for node_name in graph.nodes():
55             node = Node(graph, node_name)
56             if node.kind == 'op' and node.has_valid('op') and node.op == 'TFCustomSubgraphCall':
57                 CustomSubgraphCall.update_placeholder_shape_and_add_transpose(node)
58
59     @staticmethod
60     def update_placeholder_shape_and_add_transpose(node: Node):
61         """
62         The function changes placeholders shapes from NHWC to NCHW format and add transpose operations if needed.
63         :param node: node to operate on.
64         :return: None
65         """
66         import tensorflow as tf
67         from mo.front.common.layout import convert_shape, nhwc_to_nchw_permute, nchw_to_nhwc_permute
68         from mo.front.tf.extractors.utils import tf_tensor_shape
69         from mo.front.tf.partial_infer.tf import add_node_def_to_subgraph, update_input_in_pbs
70
71         tf.reset_default_graph()
72
73         inputs_replacements = list()
74
75         # transpose permutation constant
76         nchw_to_nhwc_constant = tf.constant(nchw_to_nhwc_permute, dtype=tf.int32, name=nchw_to_nhwc_constant_name)
77         nhwc_to_nchw_constant = tf.constant(nhwc_to_nchw_permute, dtype=tf.int32, name=nhwc_to_nchw_constant_name)
78
79         for placeholder_name in node['input_nodes_names']:
80             # dummy node which we can refer to as input in the transpose for the output node
81             # dummy node should be unique for each placeholder
82             dummy_node = tf.constant(value=[[[[1]]]], dtype=tf.float32, name='random_dummy_name_' + placeholder_name)
83
84             placeholder = node['pbs'][placeholder_name]
85             cur_shape = tf_tensor_shape(placeholder.attr['shape'].shape)
86             if len(cur_shape) == 4:  # TODO think about better check that transpose is required
87                 nchw_shape = convert_shape(cur_shape, nhwc_to_nchw_permute)
88                 for ind in range(len(cur_shape)):
89                     placeholder.attr['shape'].shape.dim[ind].size = nchw_shape[ind]
90                 transpose_name = placeholder.name + '_transpose'
91                 transpose = tf.transpose(dummy_node, nchw_to_nhwc_constant, transpose_name)  # NCHW -> NHWC
92
93                 # add transpose operations to GraphDef after placeholders
94                 add_node_def_to_subgraph(node, transpose.op.node_def, transpose_name, len(node['input_nodes_names']))
95                 inputs_replacements.append((placeholder.name, transpose_name))
96                 inputs_replacements.append((dummy_node.name, placeholder.name))
97                 node['real_input_dims'].append(nchw_shape)
98             else:
99                 node['real_input_dims'].append(cur_shape)
100         add_node_def_to_subgraph(node, nchw_to_nhwc_constant.op.node_def)
101         add_node_def_to_subgraph(node, nhwc_to_nchw_constant.op.node_def)
102
103         # update initial input names to a transposed ones
104         for old_input_tensor_name, new_name in inputs_replacements:
105             update_input_in_pbs(node, old_input_tensor_name, new_name)
106
107     @staticmethod
108     def add_output_nodes_transposes(graph: Graph):
109         """
110         Iterates over all nodes of the graph, find all TF sub-graph call operations and adds Transpose operations to the
111         output nodes if they are 4D to covert output from NHWC to NCHW.
112         :param graph: graph to operate on
113         :return: None
114         """
115         for node_name in graph.nodes():
116             node = Node(graph, node_name)
117             if node.kind == 'op' and node.has_valid('op') and node.op == 'TFCustomSubgraphCall':
118                 CustomSubgraphCall.add_sub_graph_call_output_tensors_transposes(node)
119
120     @staticmethod
121     def make_shape_4d(shape: np.array):
122         """
123         Create 4D tensor from 1D, 2D or 3D by adding new dimensions of size 1.
124         :param shape: shape to extend.
125         :return: 4D tensor.
126         """
127         new_shape = int64_array(shape)
128         old_shape_len = len(shape)
129
130         for x in range(
131                 4 - old_shape_len):  # TODO think about proper way to add additional dimensions considering layout
132             if len(
133                     new_shape) <= 1:  # if the shape is 0D or 1D then we should add additional dimensions to batch dimension
134                 new_shape = np.insert(new_shape, 0, 1)
135             #            new_shape = np.array([1, shape[0], 1, 1])
136             else:
137                 new_shape = np.insert(new_shape, 1, 1)
138         return new_shape
139
140     @staticmethod
141     def add_reshape_before_op_node(graph: Graph, data_node_name: str, op_node_name: str, edge_attrs: dict):
142         """
143         Adds reshape operation which expands dimension of the specified data tensor to 4D.
144         :param graph: graph to operate on.
145         :param data_node_name: the name of the data node to be reshaped to 4D tensor.
146         :param op_node_name: name of the TFCustomSubgraphCall node which produces the tensor.
147         :param edge_attrs: edge attributes which should be preserved.
148         :return: None
149         """
150         data_node = Node(graph, data_node_name)
151
152         graph.remove_edge(data_node_name, op_node_name)
153
154         assert data_node['shape'] is not None
155
156         new_shape = CustomSubgraphCall.make_shape_4d(data_node['shape'])
157
158         # reshape shape data node
159         reshape_shape_data_node_name = graph.unique_id("Reshape_shape_")
160         graph.add_node(reshape_shape_data_node_name, kind='data', precision="FP32", name=reshape_shape_data_node_name,
161                        value=new_shape, shape=[1])
162
163         # reshape operation node
164         reshape_node_name = graph.unique_id("Reshape_")
165         graph.add_node(reshape_node_name, kind='op', precision="FP32", type='Reshape', name=reshape_node_name,
166                        op='Reshape',
167                        data_type=data_node['data_type'])
168         update_ie_fields(graph.node[reshape_node_name])
169
170         # reshaped data node
171         reshaped_value = None
172         if data_node['value'] is not None:
173             reshaped_value = np.reshape(data_node['value'], new_shape)
174         reshaped_data_node_name = graph.unique_id("reshaped_data_")
175         graph.add_node(reshaped_data_node_name, kind='data', precision="FP32", name=reshaped_data_node_name,
176                        shape=new_shape, value=reshaped_value, nchw_layout=True)
177
178         graph.add_edges_from([
179             (data_node_name, reshape_node_name, {'in': 0}),
180             (reshape_shape_data_node_name, reshape_node_name, {'in': 1}),
181             (reshape_node_name, reshaped_data_node_name, {'out': 0}),
182             (reshaped_data_node_name, op_node_name, edge_attrs)
183         ])
184
185     @staticmethod
186     def add_reshape_after_data_node(graph: Graph, data_node_name: str):
187         """
188         Adds reshape operation which changes shape of the tensor produced by TFSubgraphCall from 4D to real dimension
189         of the tensor. The data_node_name node contains real dimensions of the tensor but they will be changed in the
190         add_reshapes_for_tf_subgraph_calls function to a 4D because IE TF call layer supports output in 4D only.
191         :param graph: graph to operate on.
192         :param data_node_name: name of the data node to be reshaped to correct dimensions.
193         :return: None
194         """
195         data_node = Node(graph, data_node_name)
196
197         # if the data node was previously marked as output then we need to mark as output new reshaped data node
198         is_out_node = False
199         if len(data_node.out_nodes()) == 1 and data_node.out_node().has('op') and data_node.out_node().op == 'OpOutput':
200             is_out_node = True
201             graph.remove_node(data_node.out_node().id)
202
203         # save old consumers nodes with edge attributes
204         old_consumer_nodes_with_attrs = list()
205         for index, out_op in enumerate(data_node.out_nodes()):
206             edge_attrs = graph.get_edge_data(data_node_name, out_op.name)[0]
207             old_consumer_nodes_with_attrs.append((out_op.name, edge_attrs))
208
209         # remove old consumers from the data node
210         for out_op in list(data_node.out_nodes()):
211             graph.remove_edge(data_node_name, out_op.name)
212
213         # reshape operation node
214         reshape_node_name = graph.unique_id("Reshape_")
215         graph.add_node(reshape_node_name, kind='op', precision="FP32", type='Reshape', name=reshape_node_name,
216                        op='Reshape',
217                        data_type=data_node['data_type'])
218         update_ie_fields(graph.node[reshape_node_name])
219
220         # reshape shape data node
221         reshape_shape_data_node_name = graph.unique_id("Reshape_shape_")
222         graph.add_node(reshape_shape_data_node_name, kind='data', precision="FP32", name=reshape_shape_data_node_name,
223                        value=np.array(data_node['shape']), shape=[1])
224
225         # reshaped data node
226         reshaped_value = None
227         if data_node['value'] is not None:
228             reshaped_value = np.array(data_node['value'])
229         reshaped_data_node_name = graph.unique_id("reshaped_data_")
230         graph.add_node(reshaped_data_node_name, kind='data', precision="FP32", name=reshaped_data_node_name,
231                        shape=np.array(data_node['shape']), value=reshaped_value, nchw_layout=True)
232
233         if is_out_node:
234             add_opoutput(graph, reshaped_data_node_name, 0, False)
235
236         graph.add_edges_from([
237             (data_node_name, reshape_node_name, {'in': 0}),
238             (reshape_shape_data_node_name, reshape_node_name, {'in': 1}),
239             (reshape_node_name, reshaped_data_node_name, {'out': 0}),
240         ])
241
242         for out_node_name, edge_attrs in old_consumer_nodes_with_attrs:
243             graph.add_edges_from([
244                 (reshaped_data_node_name, out_node_name, edge_attrs)
245             ])
246
247     @staticmethod
248     def add_reshapes_for_tf_subgraph_calls(graph: Graph):
249         """
250         Input and output tensors of the TFCustomSubgraphCall must be 4D because IE layer accepts and produces only 4D
251         tensors. This function adds reshape operations where it is necessary.
252         :param graph: graph to operate on.
253         :return: None.
254         """
255         for src_node_name, dst_node_name, edge_attrs in list(graph.edges(data=True)):
256             src_node = Node(graph, src_node_name)
257             dst_node = Node(graph, dst_node_name)
258             if dst_node.kind == 'op' and dst_node.has_valid('type') and dst_node.type == 'TFCustomSubgraphCall' and \
259                     src_node.has_valid('shape') and len(src_node.shape) != 4:
260                 log.info("There is an data tensor of shape '{}' which goes into '{}' node".format(
261                     src_node.shape, dst_node.type))
262                 CustomSubgraphCall.add_reshape_before_op_node(graph, src_node_name, dst_node_name, edge_attrs)
263
264         for node_name in list(graph.nodes()):
265             node = Node(graph, node_name)
266             if node['kind'] == 'op' and node.has_and_set('type') and node.type == 'TFCustomSubgraphCall':
267                 for index, data_node in node.out_nodes().items():
268                     real_dims_count = len(data_node.shape)
269                     if real_dims_count != 4:
270                         log.info(
271                             "There is an data tensor of shape '{}' with real dims count '{}' which goes out of '{}' "
272                             "node".format(data_node.shape, real_dims_count, node.name))
273                         CustomSubgraphCall.add_reshape_after_data_node(graph, data_node.id)
274
275                         # need to update shape of the op so IE generates XML with 4D tensors
276                         out_shape = CustomSubgraphCall.make_shape_4d(data_node['shape'])
277
278                         data_node['shape'] = out_shape
279
280     @staticmethod
281     def add_sub_graph_call_output_tensors_transposes(node: Node):
282         """
283         Adds transpose operations to the output nodes if they are 4D to change layout from NCHW to NHWC.
284         :param node: the node to add transposes to the output nodes to.
285         :return: None
286         """
287         import tensorflow as tf
288         from mo.front.tf.partial_infer.tf import get_subgraph_output_tensors, add_node_def_to_subgraph
289         _, output_tensors = get_subgraph_output_tensors(node)
290
291         # transpose permutation constant
292         nhwc_to_nchw_constant = tf.constant(nhwc_to_nchw_permute, dtype=tf.int32, name=nhwc_to_nchw_constant_name)
293
294         # dummy node which we can refer to as input in the transpose for the output node
295         dummy_node = tf.constant(value=[[[[1]]]], dtype=tf.float32, name='random_dummy_name')
296
297         new_out_tensor_names = list()
298         for out_tensor_name in node['output_tensors_names']:
299             out_name, out_port = out_tensor_name.split(':')
300             if len(output_tensors[
301                        int(out_port)].shape) == 4:  # TODO think about better check whether transpose is required
302                 out_transpose_name = out_name + '_port_' + out_port + '_transpose'
303                 transpose = tf.transpose(dummy_node, nhwc_to_nchw_constant, name=out_transpose_name)
304
305                 # starting from TF 1.8 it is not possible to modify the "node_def" of the "tf.op", so we create a copy,
306                 # update it and use further
307                 new_input_names = transpose.op.node_def.input[:]
308                 new_input_names[0] = out_tensor_name
309                 new_node_def = copy.deepcopy(transpose.op.node_def)
310                 new_node_def.input[:] = new_input_names
311                 add_node_def_to_subgraph(node, new_node_def, position=len(node['nodes_order']))
312                 new_out_tensor_names.append(out_transpose_name)
313             else:
314                 new_out_tensor_names.append(out_tensor_name)
315
316         # update output tensor names with transposes operations
317         node['output_tensors_names'] = new_out_tensor_names
318
319     def find_and_replace_pattern(self, graph: Graph):
320         CustomSubgraphCall.update_placeholders(graph)
321         CustomSubgraphCall.add_output_nodes_transposes(graph)
322         CustomSubgraphCall.add_reshapes_for_tf_subgraph_calls(graph)