Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / tf / partial_infer / tf.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 from re import match
19
20 import numpy as np
21 import tensorflow as tf
22 from google.protobuf import text_format
23
24 from mo.front.extractor import node_defs_to_str
25 from mo.front.tf.extractors.utils import tf_dtype_extractor, tf_tensor_shape, get_tf_node_port
26 from mo.graph.graph import Node
27 from mo.utils.graph import node_incoming_neighbourhood, node_outcoming_neighbourhood
28
29
30 def tf_native_tf_node_infer(node: Node):
31     """
32     The infer function should be used to infer shape and data type of the TF operation not supported by IE.
33     :param node: node to infer.
34     :return: None
35     """
36     log.info('Called "tf_native_tf_node_infer" for node "{}"'.format(node.id))
37
38     # create a sub-graph only to make inference. The sub-graph contains desired node and it's inputs neighbourhood of
39     # depth 10. The number 10 is quite big to be sure that determine_data_type function will be able to identify the
40     # data type of input tensors, but not too huge to contain the whole graph.
41     # Also the sub-graph contains names of the output nodes of the node to perform native infer.
42     nodes_to_extract = node_incoming_neighbourhood(node.graph, node.id, 10) + node_outcoming_neighbourhood(node.graph,
43                                                                                                            node.id, 1)
44     tmp_graph = node.graph.create_sub_graph_copy(nodes_to_extract)
45
46     tmp_node_attrs = tmp_graph.node[node.id]
47     tmp_node = Node(tmp_graph, node.id)
48
49     # node attributes that are required by 'infer_subgraph_output_nodes' function
50     lists_to_init = ['input_nodes_names', 'output_tensors_names', 'nodes_order', 'internal_output_node_name',
51                      'real_input_dims']
52
53     for item in lists_to_init:
54         tmp_node_attrs[item] = list()
55     tmp_node_attrs['pbs'] = {tmp_node.name: tmp_node.pb}
56     tmp_node_attrs['nodes_order'].append(tmp_node.id)
57     for ind in range(len(tmp_node.out_edges())):
58         tmp_node_attrs['output_tensors_names'].append(tmp_node.id + ":" + str(ind))
59
60     tf_subgraph_infer(tmp_node)
61     # the shape and value has been inferred and saved to the tmp_node's out nodes attribute. Let's copy it back!
62     for tmp_out_port, tmp_out_node in tmp_node.out_nodes().items():
63         if tmp_out_node.value is not None:
64             node.out_node(tmp_out_port).value = np.array(tmp_out_node.value)
65         if tmp_out_node.shape is not None:
66             node.out_node(tmp_out_port).shape = np.array(tmp_out_node.shape)
67         if tmp_out_node.data_type is not None:
68             node.out_node(tmp_out_port).data_type = tmp_out_node.data_type
69     # lets cleanup the temporary graph
70     tmp_graph.clear()
71
72
73 def generate_feed_dict(graph: tf.Graph, node: Node):
74     """
75     The first value in the return tuple is True if all inputs for the node has constant values.
76     The second returned value is mapping of placeholder tensor to the numpy arrays with the values for these
77     placeholders.
78     :param graph: the TensorFlow Graph to generate feed dictionary to.
79     :param node: the node which represents TensorFlow sub-graph of operations.
80     :return: pair where the first element is a flag that specifies that all node inputs are constants and a dictionary
81     where key is the input Tensor object and the value is the tensor value.
82     """
83     all_constants = True
84     feed_dict = dict()
85     for in_data_node_name, edge_attrs in node.get_inputs():
86         if 'control_flow_edge' in edge_attrs and edge_attrs['control_flow_edge']:
87             continue
88         value = node.in_node(edge_attrs['in']).value
89         if value is None:
90             all_constants = False
91             placeholder_pb = node['pbs'][edge_attrs['placeholder_name']]
92             value = np.ones(shape=tf_tensor_shape(placeholder_pb.attr['shape'].shape),
93                             dtype=tf_dtype_extractor(placeholder_pb.attr['dtype'].type))
94         feed_dict[graph.get_tensor_by_name(edge_attrs['placeholder_name'] + ":0")] = value
95     return all_constants, feed_dict
96
97
98 def get_subgraph_output_tensors(node: Node):
99     """
100     Infer output shapes of the node. The function uses TF to infer the values of output tensors and then getting tensor
101     shape.
102     TODO: try to not infer values but just infer the output tensors shapes.
103     :param node: sub-graph node to infer.
104     :return: pair where the first element is a flag that specifies that all node inputs are constants and a dictionary
105     where key is the output port and the value is the tensor value.
106     """
107     result = dict()
108     graph_def = tf.GraphDef()
109     text_format.Merge(node_defs_to_str(node), graph_def)
110     tf.reset_default_graph()
111     graph = tf.Graph()
112     sess = tf.Session(graph=graph)
113     with graph.as_default():  # pylint: disable=not-context-manager
114         with sess.as_default():  # pylint: disable=not-context-manager
115             tf.import_graph_def(graph_def, name='')
116             all_constants, feed_dict = generate_feed_dict(graph, node)
117             for out_port, out_tensor_name in enumerate(node['output_tensors_names']):
118                 if not match('.*:\d+', out_tensor_name):
119                     out_tensor_name = out_tensor_name + ":" + str(out_port)
120                 result_tensor = sess.run(graph.get_tensor_by_name(out_tensor_name), feed_dict=feed_dict)
121                 result[out_port] = result_tensor
122     return all_constants, result
123
124
125 def tf_subgraph_infer(node: Node):
126     """
127     Infer output shapes of the node using TF to infer the values of output tensors and then getting tensor shapes.
128     If all inputs of the node are constants then the node's attribute 'value' is updated also.
129     :param node: sub-graph node to infer. The function updates 'shape' and 'data_type' attributes of the node.
130     :return: None
131     """
132     # TODO: try to not infer values but just infer the output tensors shapes.
133     add_placeholders_to_subgraph(node)
134
135     all_constants, output_tensors = get_subgraph_output_tensors(node)
136     for out_port, tensor_value in output_tensors.items():
137         out_node = node.out_node(out_port)
138         out_node.shape = np.array([dim for dim in tensor_value.shape])
139         out_node.data_type = tensor_value.dtype
140         log.debug("Inferred shape of the output tensor with index '{}' of the node '{}': '{}'".format(str(out_port),
141                                                                                                       node.name,
142                                                                                                       out_node.shape))
143         if all_constants:
144             out_node.value = tensor_value
145
146
147 def add_node_def_to_subgraph(subgraph_node: Node, node_def: tf.NodeDef, name: str = None, position: int = 0,
148                              is_input: bool = False):
149     """
150     Adds NodeDef definition of the node to the internal structures of the sub-graph's_node object that represents a
151     sub-graph of operations.
152     :param subgraph_node: the node that represents sub-graph where new node should be added.
153     :param node_def: the NodeDef (TF operation, variable or constant) to be added to the sub-graph.
154     :param name: name how to save added node. Default value is None which means take name from the NodeDef.
155     :param position: position in the GraphDef where to put the NodeDef. Default value is 0.
156     :param is_input: flag that specifies whether the node is input for the sub-graph. Default value is False.
157     :return: None
158     """
159     name = name or node_def.name
160     assert (name not in subgraph_node['pbs'].keys())
161     if is_input:
162         subgraph_node['input_nodes_names'].append(name)
163     subgraph_node['pbs'][node_def.name] = node_def
164     subgraph_node['nodes_order'].insert(position, name)
165
166
167 def determine_data_type(node: Node):
168     """
169     Tries to determine data type of the node. The input node could be either data or op node. If we don't know the data
170     type of the node then we recursively check the first parent of the node.
171     :param node: node to determine data type.
172     :return: data type of the node output in the numpy format.
173     """
174     if node.has_and_set('data_type'):
175         return node.data_type
176     if node.has_and_set('kind') and node.kind == 'op':
177         if node.has_and_set('pb'):
178             if 'dtype' in node.pb.attr:
179                 return tf_dtype_extractor(node.pb.attr['dtype'].type)
180             if 'T' in node.pb.attr:
181                 return tf_dtype_extractor(node.pb.attr['T'].type)
182     if node.has_and_set('kind') and node.kind == 'data':
183         if 'value' in node and node.value is not None:
184             return node.value.dtype
185     if len(node.in_nodes()) != 0:  # try to guess data type from the first parent
186         return determine_data_type(node.in_node(0))
187     log.error('Failed to determine data type for node "{}"'.format(node.name))
188     return None
189
190
191 def add_placeholders_to_subgraph(node: Node):
192     """
193     Adds placeholders to the node's list of protobufs based on input nodes to the subgraph (the value of
194     'internal_input_node_name' property).
195     The function also updates input tensors for nodes which consume output of nodes that were replaced with
196     placeholders.
197     :param node: the node to add placeholders to.
198     :return: None
199     """
200     inputs_replacements = list()
201     for index, (in_data_node, edge_attrs) in enumerate(node.get_sorted_inputs()):
202         if 'control_flow_edge' in edge_attrs and edge_attrs['control_flow_edge']:
203             continue
204
205         if 'internal_input_node_name' in edge_attrs.keys():
206             input_tensor_name = edge_attrs['internal_input_node_name']
207         else:
208             input_tensor_name = node['pb'].input[index]
209
210         input_node_name, port = get_tf_node_port(input_tensor_name)
211
212         placeholder_name = placeholder_name_for_node(input_node_name, port)
213         edge_attrs['placeholder_name'] = placeholder_name
214         in_node = node.in_node(index)
215
216         assert in_node.shape is not None
217
218         if placeholder_name not in node['pbs'].keys():
219             placeholder = tf.placeholder(determine_data_type(in_node), in_node.shape, placeholder_name)
220             inputs_replacements.append((input_tensor_name, placeholder_name))
221             add_node_def_to_subgraph(node, placeholder.op.node_def, is_input=True)
222             log.debug("Added placeholder with name '{}'".format(placeholder_name))
223
224     # update initial input names to a transposed ones
225     for old_input_tensor_name, new_name in inputs_replacements:
226         update_input_in_pbs(node, old_input_tensor_name, new_name)
227
228
229 def update_input_in_pbs(node: Node, old_input_tensor_name: str, new_input_name: str):
230     """
231     The function replaces all inputs with name 'old_input_tensor_name' with a
232     new input with name 'new_input_name'. This transformation is applied
233     for all NodeDef objects in the 'pbs' list.
234     """
235     log.debug("update_input_in_pbs: replace input '%s' with input '%s'" % (old_input_tensor_name, new_input_name))
236     old_input_tensor_name_without_port = old_input_tensor_name.split(":")[0]
237     for pb in node['pbs'].values():
238         if hasattr(pb, 'input'):
239             for ind in range(len(pb.input)):
240                 if pb.input[ind] == old_input_tensor_name or pb.input[ind] == old_input_tensor_name_without_port:
241                     pb.input[ind] = new_input_name
242                     log.debug("Replacing input '{}' of the node '{}' with placeholder '{}'".format(ind, pb.name,
243                                                                                                    new_input_name))
244
245
246 def placeholder_name_for_node(node_name: str, output_port: int):
247     return node_name + "_port_" + str(output_port) + "_ie_placeholder"