2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
18 from defusedxml.minidom import parseString
19 from xml.etree.ElementTree import Element, SubElement, tostring
21 from mo.graph.graph import *
22 from mo.utils.unsupported_ops import UnsupportedOps
23 from mo.utils.utils import refer_to_faq_msg
24 from mo.utils.version import get_version
def serialize_constants(graph: Graph, bin_file_name: str, data_type=np.float32):
    """
    Find all data constants that have output edges with a 'bin' attribute.

    Serialize the content of such constants to a binary file named bin_file_name in raw format, and save
    the offset and length of each serialized area as the 'offset' and 'size' attributes of the data node.
    Sub-graphs are processed recursively and identical blobs are written only once.

    @graph: input graph with op and data nodes
    @bin_file_name: path to the file to write blobs to (overwritten)
    @data_type: numpy data type forwarded to serialize_constants_recursively
    """
    # Shared across the whole recursion so that identical blobs in sub-graphs are deduplicated too.
    bin_hashes = {}
    with open(bin_file_name, 'wb') as bin_file:
        serialize_constants_recursively(graph, bin_file, data_type, bin_hashes)
def serialize_constants_recursively(graph: Graph, bin_file, data_type, bin_hashes):
    """
    Write the raw bytes of constant data nodes of ``graph`` (and of its sub-graphs) to ``bin_file``.

    A data node is written when it has a value and at least one outgoing edge carries a 'bin' attribute.
    Its position in the file is stored in the node's 'offset' and 'size' attributes. A blob whose SHA-512
    hash and content match an already written one is not written again; it reuses that offset/size.

    @graph: graph whose data nodes are serialized
    @bin_file: binary file object opened for writing
    @data_type: forwarded unchanged to the recursive calls; not applied to the blobs in this function
    @bin_hashes: dict shared by all recursive calls, blob hash -> {'offset', 'size', 'blob'}
    """
    import hashlib  # local import keeps this block self-contained

    nodes = sorted(graph.nodes())
    for node_id in nodes:
        node = Node(graph, node_id)
        if node.kind != 'data' or node.value is None:
            continue
        if not any('bin' in d for _, _, d in graph.out_edges(node.node, data=True)):
            continue

        blob = node.value
        blob_bytes = blob.tobytes()  # computed once: used for both the hash and the write
        blob_hash = hashlib.sha512(blob_bytes).hexdigest()
        node_attrs = graph.node[node.node]

        if blob_hash in bin_hashes and np.array_equal(blob, bin_hashes[blob_hash]['blob']):
            # Identical content is already in the file: point to it instead of duplicating the bytes.
            node_attrs['offset'] = bin_hashes[blob_hash]['offset']
            node_attrs['size'] = bin_hashes[blob_hash]['size']
        else:
            start = bin_file.tell()
            bin_file.write(blob_bytes)
            end = bin_file.tell()

            node_attrs['offset'] = start
            node_attrs['size'] = end - start
            bin_hashes[blob_hash] = {'offset': start, 'size': end - start, 'blob': blob}

            # Explicit check instead of `assert`, which is stripped when Python runs with -O.
            if blob.dtype.itemsize * np.prod(node.shape) != end - start:
                raise Error('Size of the serialized blob for node {} ({} bytes) does not match its shape {}.'.format(
                    node.soft_get('name'), end - start, node.shape))

        log.debug(
            "Detected binary for graph: '{}', node: '{}', id: {}, shape: '{}', offset: '{}', size: '{}'".format(
                graph, node.soft_get('name'), node.id, node.shape, node.offset, node.size))

    # separate loop for sub-graphs to dump them after all blobs for more natural blob offset ordering
    # TODO: implement strict order for all blobs in the entire IR
    for node_id in nodes:
        node = Node(graph, node_id)
        # Dump blobs recursively if sub-graphs are present in the node
        if node.has_valid('sub_graphs'):
            for sub_graph_attr_name in node.sub_graphs:
                sub_graph = node[sub_graph_attr_name]
                serialize_constants_recursively(sub_graph, bin_file, data_type, bin_hashes)
def serialize_mean_image(bin_file_name: str, mean_data=None):
    """
    Append mean image data to the binary file ``bin_file_name``.

    Each item of ``mean_data`` is written, in order, as one contiguous chunk at the end of the file.

    Args:
        bin_file_name: path to the binary file; opened in append mode, so existing content is kept.
        mean_data: sequence of chunks accepted by ``file.write`` after ``[:]`` slicing; None means no data.

    Returns:
        Tuple ``(mean_offset, mean_size)`` of lists holding the start offset and byte length of each chunk.
    """
    if mean_data is None:  # avoid a shared mutable default argument
        mean_data = []
    mean_offset = []
    mean_size = []
    with open(bin_file_name, 'ab') as bin_file:
        for chunk in mean_data:
            start = bin_file.tell()
            bin_file.write(chunk[:])
            end = bin_file.tell()
            mean_offset.append(start)
            mean_size.append(end - start)
    return mean_offset, mean_size
def xml_shape(shape: np.ndarray, element: Element):
    """
    Append one <dim> child to ``element`` for every entry of ``shape``.

    Each entry is validated before its <dim> is created, so a failing entry leaves no empty <dim> behind.
    Entries that are not ``np.int64`` are logged with a warning and converted with ``int()``.

    Raises:
        Error: if an entry is less than or equal to 0, or is not an integer value.
    """
    for d in shape:
        if d <= 0:
            raise Error('The value "{}" for shape is less or equal to 0. May be the input shape of the topology is '
                        'wrong.'.format(d))
        if int(d) != d:
            raise Error('The value "{}" for shape is not integer.'.format(d))
        if not isinstance(d, np.int64):
            log.warning('The element of shape is not np.int64 value. Converting the value "{}" to integer'.format(d))
            d = int(d)
        dim = SubElement(element, 'dim')
        dim.text = str(d)
def xml_ports(node: Node, element: Element, edges: Element):
    """
    Emit the <input>/<output> port sections of ``node`` into ``element`` and its incoming edges into ``edges``.

    Input edges with a 'bin' attribute (serialized as blobs by xml_consts) and edges whose 'xml_skip'
    attribute is truthy are ignored. An <input>/<output> section is created only when at least one port
    is emitted into it. For each emitted input port whose data node has a producer, an <edge> element
    linking the producer's output port to this node's input port is added to ``edges``.

    Raises:
        Error: if a port's data node has no shape, if an input is not a data node, or if an input
            data node has more than one producer.
    """
    # input ports
    inputs = None  # will create input section only if at least one input is available
    for u, d in node.get_sorted_inputs():
        if 'bin' in d or d.get('xml_skip'):
            continue
        if inputs is None:
            inputs = SubElement(element, 'input')
        port = SubElement(inputs, 'port')
        port.set('id', str(d['in']))

        in_data_attrs = node.graph.node[u]
        # Explicit errors instead of `assert`, which disappears under `python -O`.
        if in_data_attrs['shape'] is None:
            raise Error('Input shape is not calculated properly for node {}'.format(node.id))
        xml_shape(in_data_attrs['shape'], port)

        # u is a data node that has a single producer, let's find it
        if in_data_attrs['kind'] != 'data':
            raise Error('Input {} of node {} is not a data node'.format(u, node.id))
        in_nodes = list(node.graph.in_edges(u, data=True))
        if len(in_nodes) > 1:
            raise Error('Data node {} feeding node {} has more than one producer'.format(u, node.id))
        if len(in_nodes) == 1:
            src, _, out_attrs = in_nodes[0]
            edge = SubElement(edges, 'edge')
            edge.set('from-layer', str(src))
            edge.set('from-port', str(out_attrs['out']))
            edge.set('to-layer', str(node.node))
            edge.set('to-port', str(d['in']))

    # output ports
    outputs = None  # will create output section only if at least one output is available
    for v, d in node.get_sorted_outputs():
        if d.get('xml_skip'):
            continue
        if outputs is None:
            outputs = SubElement(element, 'output')
        port = SubElement(outputs, 'port')
        port.set('id', str(d['out']))
        out_shape = node.graph.node[v]['shape']
        if out_shape is None:
            raise Error('Output shape is not calculated properly for node {}'.format(node.id))
        xml_shape(out_shape, port)
def xml_consts(graph: Graph, node: Node, element: Element):
    """
    Emit a <blobs> section into ``element`` describing the binary constants consumed by ``node``.

    Every input edge carrying a 'bin' attribute yields a child of <blobs> whose tag is the value of that
    attribute, with 'offset' and 'size' copied from the input data node (these attributes are set by
    serialize_constants_recursively). <blobs> is created only when at least one such input exists.

    Raises:
        Error: if the input data node has no 'offset' or 'size' attribute.
    """
    blobs = None  # sub-element that will be created on-demand
    for u, d in node.get_sorted_inputs():
        if 'bin' not in d:
            continue
        # `is None` rather than truthiness: an Element without children evaluates as False.
        if blobs is None:
            blobs = SubElement(element, 'blobs')
        const = SubElement(blobs, d['bin'])
        try:
            const.set('offset', str(graph.node[u]['offset']))
            const.set('size', str(graph.node[u]['size']))
        except KeyError as e:
            raise Error('Unable to access binary attributes ("offset" and/or "size") '
                        'for blobs for node {}. Details: {}'.format(node.soft_get('name'), e)) from e
def soft_get(node, attr):
    """Return ``node.soft_get(attr)`` when ``node`` offers a callable ``soft_get``, else '<SUB-ELEMENT>'.

    The placeholder marks plain dictionary-like sub-elements in error messages; serialize_network later
    substitutes the real layer name and id for it.
    """
    getter = getattr(node, 'soft_get', None)
    if not callable(getter):
        return '<SUB-ELEMENT>'
    return getter(attr)
def serialize_element(
        graph: Graph,
        node,
        schema: list,
        parent_element: Element,
        edges: Element,
        unsupported):
    """
    Serialize one ``(name, attrs, subelements)`` schema entry of ``node`` as a child of ``parent_element``.

    Entries of ``attrs`` can be:
      * a tuple ``(key, source)``: ``source`` is either a callable applied to ``node`` or the name of a node
        attribute; its value is written to the XML attribute ``key``;
      * a dict: for every key present in the node attributes, that attribute's dictionary is copied
        item by item into XML attributes;
      * anything else: the name of a node attribute, written under the same XML attribute name.
    Values that are None are not written. ``subelements`` are then serialized recursively, and the element
    is removed again if it ends up with neither attributes nor children.

    Raises:
        Error: if a tuple entry cannot be evaluated for ``node``.
    """
    name, attrs, subelements = schema
    element = SubElement(parent_element, name)
    for attr in attrs:
        if isinstance(attr, tuple):
            key = attr[0]
            try:
                if callable(attr[1]):
                    value = attr[1](node)
                else:
                    value = node[attr[1]] if attr[1] in node else None
            except TypeError as e:
                raise Error('Unable to extract {} from layer {}', key, soft_get(node, 'name')) from e
            except Exception as e:
                raise Error(
                    'Cannot emit value for attribute {} for layer {}. '
                    'Internal attribute template: {}.',
                    key,
                    soft_get(node, 'name'),
                    attr
                ) from e
        elif isinstance(attr, dict):
            node_attrs = node.graph.node[node.id] if isinstance(node, Node) else node
            for key in attr.keys():
                if key in node_attrs:
                    for k, v in node_attrs[key].items():
                        element.set(k, str(v))
            # Attributes were written directly; do not fall through to the single-value setter below,
            # which would otherwise reuse a stale `value` from a previous entry.
            continue
        else:
            key = attr
            value = node[attr] if attr in node else None
        if value is not None:
            element.set(key, str(value))

    serialize_node_attributes(graph, node, subelements, element, edges, unsupported)
    # len(element) counts children; Element.getchildren() no longer exists since Python 3.9.
    if len(element.attrib) == 0 and len(element) == 0:
        parent_element.remove(element)
def serialize_meta_list(graph, node, schema, element, edges, unsupported):
    """
    Serialize a list-valued schema entry ``(tag, list_accessor, sub_schema)``.

    ``list_accessor(node)`` returns the items to emit; each item is serialized into ``element`` with
    ``sub_schema`` as its only schema entry.
    """
    _, list_accessor, sub_schema = schema
    items = list_accessor(node)  # this is a list of dictionary-like objects
    for item in items:
        serialize_node_attributes(graph, item, [sub_schema], element, edges, unsupported)
def serialize_node_attributes(
        graph: Graph,  # the current network graph
        node,  # dictionary-like object that should be serialized
        schema: list,
        parent_element: Element,
        edges: Element,
        unsupported):
    """
    Serialize ``node`` into ``parent_element`` following the list of schema entries ``schema``.

    String entries are special tags: '@ports' emits input/output ports and edges (xml_ports), '@consts'
    emits the blob section (xml_consts); other strings are reported as unknown. Tuple entries are
    dispatched on their first item: '@list' -> serialize_meta_list, '@network' -> serialize_network on
    the sub-graph stored in ``node[entry[1]]``, anything else -> serialize_element.

    Raises:
        Error: wrapping any failure raised while emitting the node's attributes.
    """
    try:
        for s in schema:
            if not isinstance(s, tuple):
                if s == '@ports':
                    try:
                        # TODO make sure that edges are generated regardless of the existence of @ports
                        xml_ports(node, parent_element, edges)
                    except Exception as e:
                        # Format only our own text: the FAQ message is appended verbatim so that any
                        # braces it might contain are not interpreted by str.format.
                        raise Error('Unable to create ports for node with id {}. '.format(node.id) +
                                    refer_to_faq_msg(3)) from e
                elif s == '@consts':
                    xml_consts(graph, node, parent_element)
                else:
                    log.warning('Unknown xml schema tag: {}'.format(s))
            else:
                name = s[0]
                if name == '@list':
                    serialize_meta_list(graph, node, s, parent_element, edges, unsupported)
                elif name == '@network':
                    serialize_network(node[s[1]], parent_element, unsupported)
                else:
                    serialize_element(graph, node, s, parent_element, edges, unsupported)
    except Exception as e:
        raise Error(
            'Error while emitting attributes for layer {} (id = {}). '
            'It usually means that there is unsupported pattern around this node or unsupported combination of attributes.',
            soft_get(node, 'name'),
            node.id
        ) from e
def create_pre_process_block_for_image(net: Element, ref_layer_names: list, mean_offset: tuple,
                                       mean_size: tuple):
    """
    Add a <pre-process> block that references mean image data stored in the binary file.

    One <channel> child is created per entry of ``mean_size``; its <mean> element records the offset
    and size of that chunk in the binary file.

    Args:
        net: root XML element the block is appended to
        ref_layer_names: names of input layers; only the first one is referenced for now
        mean_offset: byte offset of each mean chunk (indexed in step with ``mean_size``)
        mean_size: byte length of each mean chunk

    Returns:
        the created pre-process XML element

    Raises:
        Error: if ``ref_layer_names`` is empty.
    """
    if not ref_layer_names:
        raise Error('Unable to create the pre-process block for the mean image: no reference layer name is given.')
    pre_process = SubElement(net, 'pre-process')
    pre_process.set('mean-precision', 'FP32')  # TODO: to think about need to output FP16 mean values
    # TODO: extend it for several inputs
    pre_process.set('reference-layer-name', ref_layer_names[0])
    for idx, size in enumerate(mean_size):
        channel_xml = SubElement(pre_process, 'channel')
        channel_xml.set('id', str(idx))
        mean_xml = SubElement(channel_xml, 'mean')
        mean_xml.set('offset', str(mean_offset[idx]))
        mean_xml.set('size', str(size))
    return pre_process
def create_pre_process_block(net, ref_layer_name, means, scales=None):
    """
    Generate the pre-process block for the IR XML.

    One <channel> is created per entry of ``means``, holding a <mean value=...> element and, when
    ``scales`` is non-empty, a <scale value=...> element taken from the same index of ``scales``.

    Args:
        net: root XML element
        ref_layer_name: name of the layer the block refers to
        means: tuple of per-channel mean values
        scales: optional tuple of per-channel scale values

    Returns:
        pre-process XML element
    """
    pre_process = SubElement(net, 'pre-process')
    pre_process.set('reference-layer-name', ref_layer_name)

    for idx, mean in enumerate(means):
        channel_xml = SubElement(pre_process, 'channel')
        channel_xml.set('id', str(idx))

        mean_xml = SubElement(channel_xml, 'mean')
        mean_xml.set('value', str(mean))

        # scales defaults to None: only emit <scale> when values were actually provided.
        if scales:
            scale_xml = SubElement(channel_xml, 'scale')
            scale_xml.set('value', str(scales[idx]))

    return pre_process
def add_quantization_statistics(graph, net_element):
    """
    Append a <statistics> section to ``net_element`` when ``graph.graph`` has a 'statistics' entry.

    ``graph.graph['statistics']`` is iterated as a mapping of tensor name -> interval, where each interval
    provides 'min' and 'max'. Each pair becomes a <layer> with <name>, <min> and <max> children. Values are
    converted with str(), because ElementTree cannot serialize non-string text.
    """
    if 'statistics' not in graph.graph:
        return
    stats = SubElement(net_element, 'statistics')
    for tensor, interval in graph.graph['statistics'].items():
        layer = SubElement(stats, 'layer')
        # Local names avoid shadowing the built-in min()/max().
        name_xml = SubElement(layer, 'name')
        name_xml.text = str(tensor)
        min_xml = SubElement(layer, 'min')
        min_xml.text = str(interval['min'])
        max_xml = SubElement(layer, 'max')
        max_xml.text = str(interval['max'])
    log.info('Statistics were inserted to IR')
def add_meta_data(net: Element, meta_info: dict):
    """
    Append a <meta_data> section with the MO version and the command-line parameters in ``meta_info``.

    Every key except 'unset' becomes a <key value=...> element under <cli_parameters>, in sorted key order.
    The sorted names listed under 'unset' are joined into the 'unset_cli_parameters' attribute of an
    <unset> element; a missing 'unset' key yields an empty list instead of a KeyError.
    """
    meta = SubElement(net, 'meta_data')
    SubElement(meta, 'MO_version').set('value', get_version())
    parameters = SubElement(meta, 'cli_parameters')
    # Plain loop: the previous list comprehension was evaluated only for its side effects.
    for key in sorted(meta_info.keys()):
        if key != 'unset':
            SubElement(parameters, str(key)).set('value', str(meta_info[key]))
    unset_params = meta_info.get('unset', [])
    SubElement(parameters, 'unset').set('unset_cli_parameters', ', '.join(sorted(unset_params)))
def serialize_network(graph, net_element, unsupported):
    """
    Serialize the nodes of ``graph`` into new <layers> and <edges> children of ``net_element``.

    Nodes are visited in sorted order. Nodes without an 'IE' schema are skipped. 'op' nodes without a
    'type' are recorded in ``unsupported`` and not serialized. Every other node is emitted through its
    'IE' schema.

    Raises:
        Error: if serializing a node fails; the '<SUB-ELEMENT>' placeholder in the message is replaced
            with the node's name and id.
    """
    layers = SubElement(net_element, 'layers')
    edges = SubElement(net_element, 'edges')
    # A '@network' schema entry may reference a missing sub-graph: emit empty sections for it.
    if graph is None:
        return
    for node_id in sorted(graph.nodes()):
        node = Node(graph, node_id)
        if not node.has('IE'):
            continue
        if node.kind == 'op' and (not node.has('type') or node.type is None):
            unsupported.add(node)
            continue
        try:
            serialize_node_attributes(graph, node, node.IE, layers, edges, unsupported)
        except Exception as e:
            raise Error(str(e).replace('<SUB-ELEMENT>', '{} (id = {})'.format(node.soft_get('name'), node.id))) from e
def generate_ie_ir(graph: Graph, file_name: str, input_names: tuple = (), mean_offset: tuple = (),
                   mean_size: tuple = (), meta_info: dict = None):
    """
    Build the IE IR XML for ``graph`` and write it to ``file_name``.

    Extracts IE/IR attributes from kind='op' nodes in three ways:
      (1) the node.IE xml scheme, which maps existing attributes to generated xml elements
      (2) input/output edges that don't have 'bin' attributes are transformed to input/output ports
      (3) input edges that have 'bin' attributes are handled in a special way, like weights/biases

    Args:
        graph: nx graph with FW-independent model. graph.graph must contain 'ir_version'. It may contain
            'mean_values', a mapping from input name to per-channel mean values, and 'statistics'.
        file_name: name of the resulting IR XML file
        input_names: names of input layers of the topology to add the mean file to
        mean_offset: offsets in the binary file where the mean file values start
        mean_size: sizes of the mean file values
        meta_info: command-line parameters recorded in the meta_data section; None is treated as {}

    Raises:
        Error: if some operations cannot be converted. The partially built XML is logged at debug level
            and the file is not written.
    """
    if meta_info is None:  # avoid a shared mutable default argument
        meta_info = {}

    net = Element('net')
    net.set('name', graph.name)
    net.set('version', str(graph.graph['ir_version']))
    net.set('batch', '1')  # TODO substitute real batches here (is it a number or is it an index?)

    if mean_size or mean_offset:
        create_pre_process_block_for_image(net, input_names, mean_offset, mean_size)

    if 'mean_values' in graph.graph:
        for input_name, values in graph.graph['mean_values'].items():
            create_pre_process_block(net, input_name, values)

    unsupported = UnsupportedOps(graph)

    serialize_network(graph, net, unsupported)
    add_quantization_statistics(graph, net)
    add_meta_data(net, meta_info)

    pretty_xml_as_string = parseString(tostring(net)).toprettyxml()
    if len(unsupported.unsupported):
        log.debug('Partially correct IR XML:\n{}'.format(pretty_xml_as_string))
        unsupported.report(log.error, "List of operations that cannot be converted to Inference Engine IR:")
        raise Error('Part of the nodes was not converted to IR. Stopped. ' +
                    refer_to_faq_msg(24))

    # toprettyxml() without an encoding writes no encoding declaration, so XML readers assume UTF-8.
    # Write UTF-8 explicitly instead of relying on the platform default encoding.
    with open(file_name, 'w', encoding='utf-8') as file:
        file.write(pretty_xml_as_string)
394 def port_renumber(graph: Graph):
395 for node in list(graph.nodes()):
396 node = Node(graph, node)
397 if node.kind == 'op':
399 for u, d in node.get_sorted_inputs():
402 for v, d in node.get_sorted_outputs():