Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / YOLO.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 from extensions.front.no_op_eraser import NoOpEraser
18 from extensions.front.standalone_const_eraser import StandaloneConstEraser
19 from extensions.ops.regionyolo import RegionYoloOp
20 from mo.front.tf.replacement import FrontReplacementFromConfigFileGeneral
21 from mo.graph.graph import Node, Graph
22 from mo.ops.output import Output
23 from mo.utils.error import Error
24
25
26 class YoloRegionAddon(FrontReplacementFromConfigFileGeneral):
27     """
28     Replaces all OpOutput nodes in graph with YoloRegion->OpOutput nodes chain.
29     YoloRegion node attributes are taken from configuration file
30     """
31     replacement_id = 'TFYOLO'
32
33     def run_after(self):
34         return [NoOpEraser, StandaloneConstEraser]
35
36     def transform_graph(self, graph: Graph, replacement_descriptions):
37         op_outputs = [n for n, d in graph.nodes(data=True) if 'op' in d and d['op'] == 'OpOutput']
38         for op_output in op_outputs:
39             last_node = Node(graph, op_output).in_node(0)
40             op_params = dict(name=last_node.id + '/YoloRegion', axis=1, end_axis=-1)
41             op_params.update(replacement_descriptions)
42             region_layer = RegionYoloOp(graph, op_params)
43             region_layer_node = region_layer.create_node([last_node])
44             # here we remove 'axis' from 'dim_attrs' to avoid permutation from axis = 1 to axis = 2
45             region_layer_node.dim_attrs.remove('axis')
46             Output(graph).create_node([region_layer_node])
47
48
49 class YoloV3RegionAddon(FrontReplacementFromConfigFileGeneral):
50     """
51     Replaces all OpOutput nodes in graph with YoloRegion->OpOutput nodes chain.
52     YoloRegion node attributes are taken from configuration file
53     """
54     replacement_id = 'TFYOLOV3'
55
56     def transform_graph(self, graph: Graph, replacement_descriptions):
57         graph.remove_nodes_from(graph.get_nodes_with_attributes(op='OpOutput'))
58         for input_node_name in replacement_descriptions['entry_points']:
59             if input_node_name not in graph.nodes():
60                 raise Error('TensorFlow YOLO V3 conversion mechanism was enabled. '
61                             'Entry points "{}" were provided in the configuration file. '
62                             'Entry points are nodes that feed YOLO Region layers. '
63                             'Node with name {} doesn\'t exist in the graph. '
64                             'Refer to documentation about converting YOLO models for more information.'.format(
65                     ', '.join(replacement_descriptions['entry_points']), input_node_name))
66             last_node = Node(graph, input_node_name).in_node(0)
67             op_params = dict(name=last_node.id + '/YoloRegion', axis=1, end_axis=-1, do_softmax=0)
68             op_params.update(replacement_descriptions)
69             region_layer_node = RegionYoloOp(graph, op_params).create_node([last_node])
70             # TODO: do we need change axis for further permutation
71             region_layer_node.dim_attrs.remove('axis')
72             Output(graph, {'name': region_layer_node.id + '/OpOutput'}).create_node([region_layer_node])