2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
20 from extensions.front.sub import Sub
21 from extensions.front.tf.Pack import Pack
22 from mo.front.subgraph_matcher import SubgraphMatch
23 from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph
24 from mo.graph.graph import create_edge, Node
25 from mo.utils.error import Error
28 class PreprocessorReplacement(FrontReplacementFromConfigFileSubGraph):
30 The class replaces the "Preprocessor" block resizing input image and applying mean/scale values. Only nodes related
31 to applying mean/scaling values are kept.
33 replacement_id = 'PreprocessorReplacement'
def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
    """
    Returns names of the matched nodes that must be removed from the graph.
    The nodes performing input image scaling and mean value subtraction ('Preprocessor/sub', 'Preprocessor/sub/y',
    'Preprocessor/mul' and 'Preprocessor/mul/x') are excluded from the result so that they stay in the graph.
    :param graph: graph to operate on (not used by this implementation).
    :param match: matched "Preprocessor" sub-graph.
    :return: new list with the names of the nodes to remove, in the order reported by the match.
    """
    nodes_to_keep = {'Preprocessor/sub', 'Preprocessor/sub/y', 'Preprocessor/mul', 'Preprocessor/mul/x'}
    # Filter into a new list instead of calling remove() on the list returned by matched_nodes_names(): that list
    # may be the match object's own storage (NOTE(review): confirm in SubgraphMatch), and mutating it would change
    # what any later caller of matched_nodes_names() sees.
    return [name for name in match.matched_nodes_names() if name not in nodes_to_keep]
def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
    """
    Removes the image resizing part of the "Preprocessor" block, keeping only mean/scale nodes.
    The Sub node producing the block output (and the Mul node feeding it, if there is one) are kept. The output of the
    Cast node consuming the "image_tensor" input is connected directly to the first kept node, and the shape of the
    "image_tensor" node is set to batch 1 with the pre-processed image height and width.
    :param graph: graph to operate on. The 'preprocessed_image_height' and 'preprocessed_image_width' entries of
    graph.graph are set by this method.
    :param match: matched "Preprocessor" sub-graph.
    :return: empty dictionary.
    :raises Error: if the output of the sub-graph is not a Sub node, the pre-processed image size can be determined
    neither from the model nor from the custom attributes, the "image_tensor" node is missing, its batch size is not
    1 (or -1), or its first consumer is not a Cast node.
    """
    print('WARNING: the "{}" is a legacy replacer that will be removed in the future release. Please, consider '
          'using replacers defined in the "extensions/front/tf/ObjectDetectionAPI.py"'.format(self.replacement_id))
    log.debug('PreprocessorReplacement: matched_nodes = {}'.format(match.matched_nodes_names()))

    sub_node = match.output_node(0)[0]
    if not sub_node.has('op') or sub_node.op != 'Sub':
        raise Error('The output op of the Preprocessor sub-graph is not of type "Sub". Looks like the topology is '
                    'not created with TensorFlow Object Detection API.')

    # The scaling Mul is optional: mul_node stays None when the Sub node is fed directly by something else.
    mul_node = None
    sub_input_node = sub_node.in_node(0)
    if sub_input_node.has('op') and sub_input_node.op == 'Mul':
        log.info('There is image scaling node in the Preprocessor block.')
        mul_node = sub_input_node

    config_attrs = match.custom_replacement_desc.custom_attributes
    preprocessed_image_height_width = self.get_preprocessed_image_size_from_model(graph)
    if preprocessed_image_height_width is None:
        # the size could not be read from the model, so it must be provided in the replacement configuration file
        if 'preprocessed_image_width' not in config_attrs or 'preprocessed_image_height' not in config_attrs:
            raise Error('Failed to determine the pre-processed image size from the original TensorFlow graph. '
                        'Please, specify "preprocessed_image_width" and "preprocessed_image_height" in the '
                        'topology replacement configuration file in the "custom_attributes" section of the '
                        '"PreprocessorReplacement" replacer. This value is defined in the configuration file '
                        'samples/configs/*.config of the model in the Object Detection model zoo as '
                        '"min_dimension".')
        graph.graph['preprocessed_image_width'] = config_attrs['preprocessed_image_width']
        graph.graph['preprocessed_image_height'] = config_attrs['preprocessed_image_height']
    else:
        graph.graph['preprocessed_image_height'] = preprocessed_image_height_width[0]
        graph.graph['preprocessed_image_width'] = preprocessed_image_height_width[1]

    initial_input_node_name = 'image_tensor'
    if initial_input_node_name not in graph.nodes():
        raise Error('Input node "{}" of the graph is not found. Do not run the Model Optimizer with '
                    '"--input" command line parameter.'.format(initial_input_node_name))
    placeholder_node = Node(graph, initial_input_node_name)

    # -1 denotes an undefined batch dimension, which is pinned to 1 below
    if placeholder_node.shape[0] != 1 and placeholder_node.shape[0] != -1:
        raise Error('The faster R-CNN model support batch size 1 only.')
    placeholder_node.shape[0] = 1  # batch size 1 is supported only
    placeholder_node.shape[1] = graph.graph['preprocessed_image_height']
    placeholder_node.shape[2] = graph.graph['preprocessed_image_width']

    to_float_node = placeholder_node.out_node(0)
    if not to_float_node.has('op') or to_float_node.op != 'Cast':
        raise Error('The output of the "{}" is not Cast operation. Cannot apply replacer.'.format(
            initial_input_node_name))

    # connect to_float_node directly with the first kept node: the scaling Mul if present (its input port 1),
    # otherwise the mean value subtraction Sub (its input port 0)
    if mul_node is None:
        create_edge(to_float_node, sub_node, 0, 0)
    else:
        create_edge(to_float_node, mul_node, 0, 1)

    print('The Preprocessor block has been removed. Only nodes performing mean value subtraction and scaling (if'
          ' applicable) are kept.')
    # no new nodes are created by this replacer
    return {}
106 def get_preprocessed_image_size_from_model(graph: nx.MultiDiGraph):
108 The function looks for nodes in the Preprocessor block with specific names for resized image shape. If one of
109 the nodes exist return the desired size. If nodes do not exist then return None.
110 :param graph: graph to operate on.
111 :return: the tuple with height and width of the preprocessed image.
113 preprocess_resize_to_range_size_node_name = 'Preprocessor/map/while/ResizeToRange/Const'
114 preprocess_resize_bilinear_node_name = 'Preprocessor/map/while/ResizeImage/ResizeBilinear'
116 if preprocess_resize_to_range_size_node_name in graph.nodes():
117 preprocess_size_node = Node(graph, preprocess_resize_to_range_size_node_name)
118 result = (int(preprocess_size_node.value.item()), int(preprocess_size_node.value.item()))
119 elif preprocess_resize_bilinear_node_name in graph.nodes():
120 preprocess_size_node = Node(graph, preprocess_resize_bilinear_node_name)
121 result = (int(preprocess_size_node.in_node(1).value[0]), int(preprocess_size_node.in_node(1).value[1]))