platform/upstream/dldt.git: model-optimizer/extensions/middle/Reduce.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """

import logging as log

import numpy as np

from mo.front.caffe.extractors.utils import get_canonical_axis_index
from mo.front.common.layout import get_batch_dim, get_features_dim
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.pooling import Pooling
from mo.ops.power import Power
from mo.ops.reshape import Reshape


class ReduceReplacer(MiddleReplacementPattern):
    op = "Reduce"
    enabled = True

    supported_reduce_types = ['mean', 'max', 'sum']

    pool_method_map = {
        'max': 'max',
        'mean': 'avg',
        'sum': 'avg'
    }
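    # Note: 'sum' intentionally maps to average pooling; the result is scaled
    # back by the size of the reduced slice in step 4 of replace_pattern below,
    # since avg(x) * N == sum(x) over N elements.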

    def run_after(self):
        from extensions.middle.pass_separator import MiddleStart
        return [MiddleStart]

    def run_before(self):
        from extensions.middle.pass_separator import MiddleFinish
        return [MiddleFinish]

    def pattern(self):
        return dict(
            nodes=[
                ('reduce', dict(kind='op', op='Reduce'))
            ],
            edges=[]
        )

    def replace_pattern(self, graph: Graph, match: dict):
        node = match['reduce']
        if not node.has_valid('reduce_type') or node.reduce_type.lower() not in self.supported_reduce_types:
            log.error("Reduce type {} is not supported for node {}".format(node.soft_get('reduce_type'), node.id))
            return

        reduce_type = node.reduce_type.lower()
        if reduce_type not in self.pool_method_map:
            log.error("Reduce type {} is missing from pool_method_map; please add a mapping to an "
                      "equivalent pooling method".format(reduce_type))
            return

        input_data = node.in_node()
        output_data = node.out_node()

        input_shape = node.in_node().shape
        output_shape = node.out_node().shape

        # Normalize node.axis to exclude negative indices
        node.axis = [get_canonical_axis_index(input_shape, a) for a in node.axis]

        axis = node.axis

        # Check that the values in the axis list are consecutive
        for idx in range(1, len(axis)):
            if axis[idx] != (axis[idx - 1] + 1):
                log.error("Reduce with non-consecutive axes {} is not supported".format(axis))
                return

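        # For example (hypothetical values), axis == [2, 3] is convertible,
        # while a gap as in axis == [1, 3] is rejected by the check above.
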
        layout = graph.graph['layout']

        # At this point the Reduce is guaranteed to be convertible to the equivalent operation

        # 1. Calculate the shape that will be used in the reduction
        reduction_dim = np.prod([input_shape[idx] for idx in axis])
        begin_dims = np.array([input_shape[idx] for idx in range(axis[0])])
        end_dim = np.prod([input_shape[idx] for idx in range(axis[-1] + 1, len(input_shape))])
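        # Illustration with hypothetical shapes: for input_shape = [1, 64, 56, 56]
        # and axis = [2, 3], this gives begin_dims = [1, 64],
        # reduction_dim = 56 * 56 = 3136 and end_dim = 1
        # (np.prod of an empty list is 1).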

        # 2. Create reshape with appropriate shape
        if layout == 'NCHW':
            if len(begin_dims) > 2:
                begin_dims = np.array([np.prod(begin_dims[0:-1]), begin_dims[-1]], dtype=np.int64)
            else:
                # Expand begin_dims to 2
                begin_dims = np.array(np.append(begin_dims, [1] * (2 - len(begin_dims))), dtype=np.int64)
            reshape_shape = np.array([*begin_dims, reduction_dim, end_dim], dtype=np.int64)
            pool_window = np.array([1, 1, reduction_dim, 1], dtype=np.int64)
        elif layout == 'NHWC':
            begin_dims = np.prod(begin_dims)
            reshape_shape = np.array([begin_dims, reduction_dim, 1, end_dim], dtype=np.int64)
            pool_window = np.array([1, reduction_dim, 1, 1], dtype=np.int64)
        else:
            log.error('{} layout is currently not supported'.format(layout))
            return

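        # Continuing the illustration above (NCHW, input [1, 64, 56, 56],
        # axis = [2, 3]): reshape_shape = [1, 64, 3136, 1] and
        # pool_window = [1, 1, 3136, 1], so the pooling window spans the whole
        # reduced slice. The analogous NHWC case ([1, 56, 56, 64],
        # axis = [1, 2]) yields reshape_shape = [1, 3136, 1, 64] and
        # pool_window = [1, 3136, 1, 1].
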
        # 3. Reduce => Reshape->Pooling->Reshape
        reshape_op = Reshape(graph, {'name': node.id + '/Reshape', 'dim': reshape_shape})
        final_reshape_op = Reshape(graph, {'name': node.id + '/FinalReshape', 'dim': output_shape})
        pooling_op = Pooling(graph,
                             dict(name=node.id + '/Pool',
                                  window=pool_window,
                                  output_spatial_shape=None,
                                  batch_dims=np.array([get_batch_dim(layout, 4)], dtype=np.int64),
                                  channel_dims=np.array([get_features_dim(layout, 4)], dtype=np.int64),
                                  exclude_pad='false', pool_method=self.pool_method_map[reduce_type]))
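        # For a rank-4 tensor, batch_dims/channel_dims resolve to [0]/[1] for
        # NCHW and [0]/[3] for NHWC, so only the reshaped reduction dimension
        # is pooled over.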

        graph.remove_edge(input_data.id, node.id)
        graph.remove_edge(node.id, output_data.id)

        final_reshape_op.create_node_with_data(
            inputs=[pooling_op.create_node_with_data(
                inputs=[reshape_op.create_node_with_data(
                    inputs=[input_data]
                )]
            )],
            data_nodes=output_data)

        # 4. For a 'sum' reduction, multiply the pooled result by the size of the
        # reduced slice, since average pooling was used in place of summation
        if reduce_type == 'sum':
            output_data.in_node().insert_node_with_data_after(
                output_data,
                Power,
                {'name': node.name + '/Mul', 'scale': float(reduction_dim)}
            )
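
        # Sanity check of the equivalence with hypothetical values: the average
        # of [1, 2, 3, 4] is 2.5; scaling by the slice size 4 restores the sum,
        # 10. Since only 'scale' is set here, the Power op, with its default
        # power=1 and shift=0, degenerates to this scalar multiplication.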