"""
Copyright (c) 2018-2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log

import numpy as np

from mo.front.caffe.extractors.utils import get_canonical_axis_index
from mo.front.common.layout import get_batch_dim, get_features_dim
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.pooling import Pooling
from mo.ops.power import Power
from mo.ops.reshape import Reshape
class ReduceReplacer(MiddleReplacementPattern):
    """Middle replacement pass that converts a Reduce operation into an
    equivalent Reshape -> Pooling -> Reshape sub-graph (see replace_pattern).

    Only reductions over a run of consecutive axes can be expressed this way,
    which replace_pattern checks before rewriting.
    """
    enabled = True

    # Reduce types this pass knows how to express via pooling.
    supported_reduce_types = ['mean', 'max', 'sum']

    # Maps a (lower-cased) reduce type to the pooling method used to emulate it.
    # 'sum' is emulated as 'avg' followed by a multiplication by the slice size
    # (see step 4 in replace_pattern).
    # NOTE(review): reconstructed from usage — confirm against upstream source.
    pool_method_map = {
        'max': 'max',
        'mean': 'avg',
        'sum': 'avg',
    }

    def run_after(self):
        # Run after the middle-phase start separator.
        from extensions.middle.pass_separator import MiddleStart
        return [MiddleStart]

    def run_before(self):
        # Run before the middle-phase finish separator.
        from extensions.middle.pass_separator import MiddleFinish
        return [MiddleFinish]

    def pattern(self):
        # Match any single operation node with op == 'Reduce'.
        return dict(
            nodes=[
                ('reduce', dict(kind='op', op='Reduce'))
            ],
            edges=[]
        )
58 def replace_pattern(self, graph: Graph, match: dict):
59 node = match['reduce']
60 if not node.has_valid('reduce_type') or node.reduce_type.lower() not in self.supported_reduce_types:
61 log.error("Reduce type {} is not supported for node {}".format(node.soft_get('reduce_type'), node.id))
64 reduce_type = node.reduce_type.lower()
65 if reduce_type not in self.pool_method_map:
66 log.error("Reduce type {} is not included in pool_method_map. Please update pool_method_map with new key "
67 "{}".format(reduce_type, reduce_type))
70 input_data = node.in_node()
71 output_data = node.out_node()
73 input_shape = node.in_node().shape
74 output_shape = node.out_node().shape
76 # normalize node.axis to exclude negative indices
77 node.axis = [get_canonical_axis_index(input_shape, a) for a in node.axis]
81 # Check that values in axis list are consecutive
82 for idx in range(1, len(axis)):
83 if axis[idx] != (axis[idx - 1] + 1):
84 log.error("Reduce with not consecutive axes {} is not supported ".format(axis))
87 layout = graph.graph['layout']
89 # So now we are sure that we can convert Reduce to appropriate operation
91 # 1. Calculate shape that will be used in reduction
92 reduction_dim = np.prod([input_shape[idx] for idx in axis])
93 begin_dims = np.array([input_shape[idx] for idx in range(axis[0])])
94 end_dim = np.prod([input_shape[idx] for idx in range(axis[-1] + 1, len(input_shape))])
96 # 2. Create reshape with appropriate shape
98 if len(begin_dims) > 2:
99 begin_dims = np.array([np.prod(begin_dims[0:-1]), begin_dims[-1]], dtype=np.int64)
101 # Expand begin_dims to 2
102 begin_dims = np.array(np.append(begin_dims, [1] * (2 - len(begin_dims))), dtype=np.int64)
103 reshape_shape = np.array([*begin_dims, reduction_dim, end_dim], dtype=np.int64)
104 pool_window = np.array([1, 1, reduction_dim, 1], dtype=np.int64)
105 elif layout == 'NHWC':
106 begin_dims = np.prod(begin_dims)
107 reshape_shape = np.array([begin_dims, reduction_dim, 1, end_dim], dtype=np.int64)
108 pool_window = np.array([1, reduction_dim, 1, 1], dtype=np.int64)
110 log.error('{} layout currently is not supported'.format(layout))
113 # 3. Reduce => Reshape->Pooling->Reshape
114 reshape_op = Reshape(graph, {'name': node.id + '/Reshape', 'dim': reshape_shape})
115 final_reshape_op = Reshape(graph, {'name': node.id + '/FinalReshape', 'dim': output_shape})
116 pooling_op = Pooling(graph,
117 dict(name=node.id + '/Pool',
119 output_spatial_shape=None,
120 batch_dims=np.array([get_batch_dim(layout, 4)], dtype=np.int64),
121 channel_dims=np.array([get_features_dim(layout, 4)], dtype=np.int64),
122 exclude_pad='false', pool_method=self.pool_method_map[reduce_type]))
124 graph.remove_edge(input_data.id, node.id)
125 graph.remove_edge(node.id, output_data.id)
127 final_reshape_op.create_node_with_data(
128 inputs=[pooling_op.create_node_with_data(
129 inputs=[reshape_op.create_node_with_data(
133 data_nodes=output_data)
135 # 4. If it is reduction with summation, we need to multiply by size of the reduction slice with Mul op
136 if reduce_type == 'sum':
137 output_data.in_node().insert_node_with_data_after(
140 {'name': node.name + '/Mul', 'scale': float(reduction_dim)}