2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
import logging as log
from re import match

import tensorflow as tf
from google.protobuf import text_format

import numpy as np

from mo.front.extractor import node_defs_to_str
from mo.front.tf.extractors.utils import tf_dtype_extractor, tf_tensor_shape, get_tf_node_port
from mo.graph.graph import Node
from mo.utils.graph import node_incoming_neighbourhood, node_outcoming_neighbourhood
def tf_native_tf_node_infer(node: Node):
    """
    The infer function should be used to infer shape and data type of the TF operation not supported by IE.

    It extracts a small sub-graph around the node, runs it through TensorFlow itself (via tf_subgraph_infer) and
    copies the inferred shape/value/data_type back onto the original node's output data nodes.
    :param node: node to infer.
    :return: None
    """
    log.info('Called "tf_native_tf_node_infer" for node "{}"'.format(node.id))

    # create a sub-graph only to make inference. The sub-graph contains desired node and it's inputs neighbourhood of
    # depth 10. The number 10 is quite big to be sure that determine_data_type function will be able to identify the
    # data type of input tensors, but not too huge to contain the whole graph.
    # Also the sub-graph contains names of the output nodes of the node to perform native infer.
    # NOTE(review): the outgoing neighbourhood depth (1) is reconstructed from the original sources — confirm.
    nodes_to_extract = node_incoming_neighbourhood(node.graph, node.id, 10) + \
                       node_outcoming_neighbourhood(node.graph, node.id, 1)
    tmp_graph = node.graph.create_sub_graph_copy(nodes_to_extract)

    tmp_node_attrs = tmp_graph.node[node.id]
    tmp_node = Node(tmp_graph, node.id)

    # node attributes that are required by 'infer_subgraph_output_nodes' function
    # NOTE(review): the tail of this list ('real_input_dims') is reconstructed — confirm against the original file.
    lists_to_init = ['input_nodes_names', 'output_tensors_names', 'nodes_order', 'internal_output_node_name',
                     'real_input_dims']

    for item in lists_to_init:
        tmp_node_attrs[item] = list()
    tmp_node_attrs['pbs'] = {tmp_node.name: tmp_node.pb}
    tmp_node_attrs['nodes_order'].append(tmp_node.id)
    for ind in range(len(tmp_node.out_edges())):
        tmp_node_attrs['output_tensors_names'].append(tmp_node.id + ":" + str(ind))

    tf_subgraph_infer(tmp_node)
    # the shape and value has been inferred and saved to the tmp_node's out nodes attribute. Let's copy it back!
    for tmp_out_port, tmp_out_node in tmp_node.out_nodes().items():
        if tmp_out_node.value is not None:
            node.out_node(tmp_out_port).value = np.array(tmp_out_node.value)
        if tmp_out_node.shape is not None:
            node.out_node(tmp_out_port).shape = np.array(tmp_out_node.shape)
        if tmp_out_node.data_type is not None:
            node.out_node(tmp_out_port).data_type = tmp_out_node.data_type
    # lets cleanup the temporary graph
    tmp_graph.clear()
def generate_feed_dict(graph: tf.Graph, node: Node):
    """
    The first value in the return tuple is True if all inputs for the node has constant values.
    The second returned value is mapping of placeholder tensor to the numpy arrays with the values for these
    placeholders.
    :param graph: the TensorFlow Graph to generate feed dictionary to.
    :param node: the node which represents TensorFlow sub-graph of operations.
    :return: pair where the first element is a flag that specifies that all node inputs are constants and a dictionary
    where key is the input Tensor object and the value is the tensor value.
    """
    all_constants = True
    feed_dict = dict()
    for in_data_node_name, edge_attrs in node.get_inputs():
        if 'control_flow_edge' in edge_attrs and edge_attrs['control_flow_edge']:
            continue
        value = node.in_node(edge_attrs['in']).value
        if value is None:
            # the input has no constant value, so feed the placeholder with ones of the declared shape/dtype just to
            # make the sub-graph executable; the 'all_constants' flag tells the caller the result is not foldable
            all_constants = False
            placeholder_pb = node['pbs'][edge_attrs['placeholder_name']]
            value = np.ones(shape=tf_tensor_shape(placeholder_pb.attr['shape'].shape),
                            dtype=tf_dtype_extractor(placeholder_pb.attr['dtype'].type))
        feed_dict[graph.get_tensor_by_name(edge_attrs['placeholder_name'] + ":0")] = value
    return all_constants, feed_dict
def get_subgraph_output_tensors(node: Node):
    """
    Infer output shapes of the node. The function uses TF to infer the values of output tensors and then getting tensor
    shapes from them.
    TODO: try to not infer values but just infer the output tensors shapes.
    :param node: sub-graph node to infer.
    :return: pair where the first element is a flag that specifies that all node inputs are constants and a dictionary
    where key is the output port and the value is the tensor value.
    """
    result = dict()
    graph_def = tf.GraphDef()
    text_format.Merge(node_defs_to_str(node), graph_def)
    tf.reset_default_graph()
    graph = tf.Graph()
    sess = tf.Session(graph=graph)
    with graph.as_default():  # pylint: disable=not-context-manager
        with sess.as_default():  # pylint: disable=not-context-manager
            tf.import_graph_def(graph_def, name='')
            all_constants, feed_dict = generate_feed_dict(graph, node)
            for out_port, out_tensor_name in enumerate(node['output_tensors_names']):
                # raw string so '\d' is a regex digit class, not an (invalid) Python escape
                if not match(r'.*:\d+', out_tensor_name):
                    # the stored name has no port suffix; TF tensor names require one, so append the port index
                    out_tensor_name = out_tensor_name + ":" + str(out_port)
                result_tensor = sess.run(graph.get_tensor_by_name(out_tensor_name), feed_dict=feed_dict)
                result[out_port] = result_tensor
    return all_constants, result
def tf_subgraph_infer(node: Node):
    """
    Infer output shapes of the node using TF to infer the values of output tensors and then getting tensor shapes.
    If all inputs of the node are constants then the node's attribute 'value' is updated also.
    :param node: sub-graph node to infer. The function updates 'shape' and 'data_type' attributes of the node.
    :return: None
    """
    # TODO: try to not infer values but just infer the output tensors shapes.
    add_placeholders_to_subgraph(node)

    all_constants, output_tensors = get_subgraph_output_tensors(node)
    for out_port, tensor_value in output_tensors.items():
        out_node = node.out_node(out_port)
        out_node.shape = np.array([dim for dim in tensor_value.shape])
        out_node.data_type = tensor_value.dtype
        log.debug("Inferred shape of the output tensor with index '{}' of the node '{}': '{}'".format(
            str(out_port), node.name, str(out_node.shape)))
        # the computed value is meaningful only when every input was a real constant (see docstring)
        if all_constants:
            out_node.value = tensor_value
def add_node_def_to_subgraph(subgraph_node: Node, node_def: tf.NodeDef, name: str = None, position: int = 0,
                             is_input: bool = False):
    """
    Adds NodeDef definition of the node to the internal structures of the sub-graph's_node object that represents a
    sub-graph of operations.
    :param subgraph_node: the node that represents sub-graph where new node should be added.
    :param node_def: the NodeDef (TF operation, variable or constant) to be added to the sub-graph.
    :param name: name how to save added node. Default value is None which means take name from the NodeDef.
    :param position: position in the GraphDef where to put the NodeDef. Default value is 0.
    :param is_input: flag that specifies whether the node is input for the sub-graph. Default value is False.
    :return: None
    """
    name = name or node_def.name
    assert (name not in subgraph_node['pbs'].keys())
    if is_input:
        subgraph_node['input_nodes_names'].append(name)
    # NOTE(review): the pb is keyed by node_def.name while 'nodes_order' uses 'name'; the two diverge when a custom
    # 'name' is passed, even though the assert above checks 'name' — confirm this asymmetry is intentional.
    subgraph_node['pbs'][node_def.name] = node_def
    subgraph_node['nodes_order'].insert(position, name)
def determine_data_type(node: Node):
    """
    Tries to determine data type of the node. The input node could be either data or op node. If we don't know the data
    type of the node then we recursively check the first parent of the node.
    :param node: node to determine data type.
    :return: data type of the node output in the numpy format, or None if it cannot be determined.
    """
    if node.has_and_set('data_type'):
        return node.data_type
    if node.has_and_set('kind') and node.kind == 'op':
        if node.has_and_set('pb'):
            # TF operations commonly carry the type in either the 'dtype' or the 'T' attribute
            if 'dtype' in node.pb.attr:
                return tf_dtype_extractor(node.pb.attr['dtype'].type)
            if 'T' in node.pb.attr:
                return tf_dtype_extractor(node.pb.attr['T'].type)
    if node.has_and_set('kind') and node.kind == 'data':
        if 'value' in node and node.value is not None:
            return node.value.dtype
    if len(node.in_nodes()) != 0:  # try to guess data type from the first parent
        return determine_data_type(node.in_node(0))
    log.error('Failed to determine data type for node "{}"'.format(node.name))
    return None
def add_placeholders_to_subgraph(node: Node):
    """
    Adds placeholders to the node's list of protobufs based on input nodes to the subgraph (the value of
    'internal_input_node_name' property).
    The function also updates input tensors for nodes which consume output of nodes that were replaced with
    placeholders.
    :param node: the node to add placeholders to.
    :return: None
    """
    inputs_replacements = list()
    for index, (in_data_node, edge_attrs) in enumerate(node.get_sorted_inputs()):
        if 'control_flow_edge' in edge_attrs and edge_attrs['control_flow_edge']:
            continue
        if 'internal_input_node_name' in edge_attrs.keys():
            input_tensor_name = edge_attrs['internal_input_node_name']
        else:
            input_tensor_name = node['pb'].input[index]

        input_node_name, port = get_tf_node_port(input_tensor_name)

        placeholder_name = placeholder_name_for_node(input_node_name, port)
        edge_attrs['placeholder_name'] = placeholder_name
        in_node = node.in_node(index)

        assert in_node.shape is not None

        if placeholder_name not in node['pbs'].keys():
            placeholder = tf.placeholder(determine_data_type(in_node), in_node.shape, placeholder_name)
            inputs_replacements.append((input_tensor_name, placeholder_name))
            add_node_def_to_subgraph(node, placeholder.op.node_def, is_input=True)
            log.debug("Added placeholder with name '{}'".format(placeholder_name))

    # update initial input names to a transposed ones
    for old_input_tensor_name, new_name in inputs_replacements:
        update_input_in_pbs(node, old_input_tensor_name, new_name)
def update_input_in_pbs(node: Node, old_input_tensor_name: str, new_input_name: str):
    """
    The function replaces all inputs with name 'old_input_tensor_name' with a
    new input with name 'new_input_name'. This transformation is applied
    for all NodeDef objects in the 'pbs' list.
    :param node: the node that holds the 'pbs' dictionary of NodeDef objects to update.
    :param old_input_tensor_name: input tensor name to be replaced (may carry a ':port' suffix).
    :param new_input_name: replacement input name.
    :return: None
    """
    log.debug("update_input_in_pbs: replace input '%s' with input '%s'" % (old_input_tensor_name, new_input_name))
    # inputs may be referenced either with or without the port suffix, so match both forms
    old_input_tensor_name_without_port = old_input_tensor_name.split(":")[0]
    for pb in node['pbs'].values():
        if hasattr(pb, 'input'):
            for ind in range(len(pb.input)):
                if pb.input[ind] == old_input_tensor_name or pb.input[ind] == old_input_tensor_name_without_port:
                    pb.input[ind] = new_input_name
                    log.debug("Replacing input '{}' of the node '{}' with placeholder '{}'".format(
                        ind, pb.name, new_input_name))
def placeholder_name_for_node(node_name: str, output_port: int):
    """Return the canonical name of the IE placeholder that substitutes the given output port of a node."""
    port_part = "_port_" + str(output_port)
    return node_name + port_part + "_ie_placeholder"