Publishing R3
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / Preprocessor.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 import networkx as nx
19
20 from extensions.front.sub import Sub
21 from extensions.front.tf.Pack import Pack
22 from mo.front.subgraph_matcher import SubgraphMatch
23 from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph
24 from mo.graph.graph import create_edge, Node
25 from mo.utils.error import Error
26
27
class PreprocessorReplacement(FrontReplacementFromConfigFileSubGraph):
    """
    The class replaces the "Preprocessor" block resizing input image and applying mean/scale values. Only nodes related
    to applying mean/scaling values are kept.

    After the replacement the Cast node consuming the 'image_tensor' input is connected directly to the kept scaling
    ('Preprocessor/mul', if present) or mean subtraction ('Preprocessor/sub') node, and the 'image_tensor' shape is
    fixed to batch 1 and the pre-processed image height/width.
    """
    replacement_id = 'PreprocessorReplacement'

    # Matched nodes performing input image scaling and mean value subtraction; they must survive the replacement.
    _NODES_TO_KEEP = ('Preprocessor/sub', 'Preprocessor/sub/y', 'Preprocessor/mul', 'Preprocessor/mul/x')

    def run_before(self):
        """
        Return the replacers that must be executed after this one.
        :return: list with the Pack and Sub front replacers.
        """
        return [Pack, Sub]

    def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
        """
        Return the names of the matched nodes to remove: every matched node except the ones that perform input image
        scaling and mean value subtraction.

        A new list is built so the list returned by 'match.matched_nodes_names()' is never modified in place (the
        previous implementation called list.remove() on it, which could alter the match object's own state).
        :param graph: graph to operate on (not used).
        :param match: the matched Preprocessor sub-graph.
        :return: list of node names to remove from the graph.
        """
        return [name for name in match.matched_nodes_names() if name not in self._NODES_TO_KEEP]

    def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
        """
        Re-wire the graph so that the Cast node after 'image_tensor' feeds the mean/scale nodes directly.

        Side effects:
        - stores 'preprocessed_image_height' and 'preprocessed_image_width' in graph.graph;
        - overwrites shape[0..2] of the 'image_tensor' node with [1, height, width] (indices 1 and 2 are treated as
          height and width, i.e. an NHWC layout is assumed);
        - creates an edge from the Cast node to port 1 of 'Preprocessor/mul' or, when there is no Mul, to port 0 of
          'Preprocessor/sub'.
        :param graph: graph to operate on.
        :param match: the matched Preprocessor sub-graph.
        :return: empty dict (no new nodes mapping is produced).
        :raises Error: if the sub-graph does not look like an Object Detection API Preprocessor block, the pre-processed
            image size cannot be determined, 'image_tensor' is missing or has a batch size other than 1 (or -1), or
            its consumer is not a Cast operation.
        """
        print('WARNING: the "{}" is a legacy replacer that will be removed in the future release. Please, consider '
              'using replacers defined in the "extensions/front/tf/ObjectDetectionAPI.py"'.format(self.replacement_id))
        log.debug('PreprocessorReplacement: matched_nodes = %s', match.matched_nodes_names())

        # presumably output_node() returns a (node, port) pair; only the node is needed here
        sub_node = match.output_node(0)[0]
        if not sub_node.has('op') or sub_node.op != 'Sub':
            raise Error('The output op of the Preprocessor sub-graph is not of type "Sub". Looks like the topology is '
                        'not created with TensorFlow Object Detection API.')

        # the optional scaling Mul, if present, is the first input of the mean subtraction Sub
        sub_input_node = sub_node.in_node(0)
        mul_node = None
        if sub_input_node.has('op') and sub_input_node.op == 'Mul':
            log.info('There is image scaling node in the Preprocessor block.')
            mul_node = sub_input_node

        # the size found in the model takes precedence over the values from the configuration file
        config_attrs = match.custom_replacement_desc.custom_attributes
        preprocessed_image_height_width = self.get_preprocessed_image_size_from_model(graph)
        if preprocessed_image_height_width is not None:
            graph.graph['preprocessed_image_height'] = preprocessed_image_height_width[0]
            graph.graph['preprocessed_image_width'] = preprocessed_image_height_width[1]
        elif 'preprocessed_image_width' in config_attrs and 'preprocessed_image_height' in config_attrs:
            graph.graph['preprocessed_image_width'] = config_attrs['preprocessed_image_width']
            graph.graph['preprocessed_image_height'] = config_attrs['preprocessed_image_height']
        else:
            raise Error('Failed to determine the pre-processed image size from the original TensorFlow graph. '
                        'Please, specify "preprocessed_image_width" and "preprocessed_image_height" in the '
                        'topology replacement configuration file in the "custom_attributes" section of the '
                        '"PreprocessorReplacement" replacer. This value is defined in the configuration file '
                        'samples/configs/*.config of the model in the Object Detection model zoo as '
                        '"min_dimension".')

        initial_input_node_name = 'image_tensor'
        if initial_input_node_name not in graph.nodes():
            raise Error('Input node "{}" of the graph is not found. Do not run the Model Optimizer with '
                        '"--input" command line parameter.'.format(initial_input_node_name))
        placeholder_node = Node(graph, initial_input_node_name)

        # -1 means an undefined batch dimension, which is then fixed to 1
        if placeholder_node.shape[0] not in (1, -1):
            raise Error('The faster R-CNN model supports batch size 1 only.')
        placeholder_node.shape[0] = 1  # batch size 1 is supported only
        placeholder_node.shape[1] = graph.graph['preprocessed_image_height']
        placeholder_node.shape[2] = graph.graph['preprocessed_image_width']

        to_float_node = placeholder_node.out_node(0)
        if not to_float_node.has('op') or to_float_node.op != 'Cast':
            raise Error('The output of the "{}" is not Cast operation. Cannot apply replacer.'.format(
                initial_input_node_name))

        # connect to_float_node directly with the node performing scaling or mean value subtraction; port 0 of the
        # Mul is expected to be the 'Preprocessor/mul/x' constant, so the image goes to port 1
        if mul_node is None:
            create_edge(to_float_node, sub_node, 0, 0)
        else:
            create_edge(to_float_node, mul_node, 0, 1)

        print('The Preprocessor block has been removed. Only nodes performing mean value subtraction and scaling (if'
              ' applicable) are kept.')
        return {}

    @staticmethod
    def get_preprocessed_image_size_from_model(graph: nx.MultiDiGraph):
        """
        The function looks for nodes in the Preprocessor block with specific names for resized image shape. If one of
        the nodes exists, it returns the desired size. If neither node exists, it returns None.
        :param graph: graph to operate on.
        :return: the tuple with height and width of the preprocessed image, or None.
        """
        resize_to_range_size_node_name = 'Preprocessor/map/while/ResizeToRange/Const'
        resize_bilinear_node_name = 'Preprocessor/map/while/ResizeImage/ResizeBilinear'

        if resize_to_range_size_node_name in graph.nodes():
            # a single scalar constant is used as both the height and the width
            size = int(Node(graph, resize_to_range_size_node_name).value.item())
            return size, size

        if resize_bilinear_node_name in graph.nodes():
            # the constant on input port 1 of the resize node is read as [height, width]
            size_value = Node(graph, resize_bilinear_node_name).in_node(1).value
            return int(size_value[0]), int(size_value[1])

        return None