Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / onnx / pooling_ext.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import numpy as np
20
21 from mo.front.extractor import FrontExtractorOp
22 from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_autopad
23 from mo.ops.pooling import Pooling
24 from mo.utils.error import Error
25
26
27 class AveragePoolFrontExtractor(FrontExtractorOp):
28     op = 'AveragePool'
29     enabled = True
30
31     @staticmethod
32     def extract(node):
33         attrs = common_onnx_pool_extractor(node)
34
35         Pooling.update_node_stat(node, attrs)
36         return __class__.enabled
37
38
39 class MaxPoolFrontExtractor(FrontExtractorOp):
40     op = 'MaxPool'
41     enabled = True
42
43     @staticmethod
44     def extract(node):
45         attrs = common_onnx_pool_extractor(node)
46
47         Pooling.update_node_stat(node, attrs)
48         return __class__.enabled
49
50
51 class GlobalAveragePoolFrontExtractor(FrontExtractorOp):
52     op = 'GlobalAveragePool'
53     enabled = True
54
55     @staticmethod
56     def extract(node):
57         attrs = common_onnx_pool_extractor(node)
58         attrs.update({'pooling_convention': 'full',
59                       'global_pool': True,
60                      })
61
62         Pooling.update_node_stat(node, attrs)
63         return __class__.enabled
64
65
66 class GlobalMaxPoolFrontExtractor(FrontExtractorOp):
67     op = 'GlobalMaxPool'
68     enabled = True
69
70     @staticmethod
71     def extract(node):
72         attrs = common_onnx_pool_extractor(node)
73         attrs.update({'pooling_convention': 'full',
74                       'global_pool': True,
75                      })
76
77         Pooling.update_node_stat(node, attrs)
78         return __class__.enabled
79
80
81 def common_onnx_pool_extractor(node):
82     pads = onnx_attr(node, 'pads', 'ints', default=None, dst_type=lambda x: np.array(x, dtype=np.int64))
83
84     # Try to convert slightly incorrect models with insufficient pad parameters
85     if pads is not None and (pads.size == 2 or pads.size % 2 != 0):
86         log.warning(
87             'Node {} has pad = {} which is ill-formed -- it should consist of N%2==0 elements.'.format(node.name,
88                                                                                                        pads))
89         pads = np.concatenate([pads, pads])
90         log.warning('Extended pads to {}'.format(pads))
91
92     final_pads = None
93     if pads is not None:
94         assert len(pads) % 2 == 0
95         pads = pads.reshape([2, -1])
96         pads = np.transpose(pads)
97         final_pads = np.array([[0, 0], [0, 0], *[p for p in pads]], dtype=np.int64)
98
99     # Extract dilations attribute
100     # In case if dilations is not specified it will be set in default (1) in infer function
101     strides = onnx_attr(node, 'strides', 'ints', default=None, dst_type=lambda x: np.array(x, dtype=np.int64))
102     final_strides = np.array([1, 1, *[x for x in strides]], dtype=np.int64) if strides is not None else None
103
104     kernel_shape = onnx_attr(node, 'kernel_shape', 'ints', default=None, dst_type=lambda x: np.array(x, dtype=np.int64))
105     final_kernel_shape = np.array([1, 1, *[x for x in kernel_shape]], dtype=np.int64) if kernel_shape is not None else None
106
107     # exclude_pad = True only when count_include_pad == 0
108     exclude_pad = onnx_attr(node, 'count_include_pad', 'i', default=0) == 0
109
110     global_pooling = 0
111     if node.op in ['MaxPool', 'GlobalMaxPool']:
112         method = 'max'
113     elif node.op in ['AveragePool', 'GlobalAveragePool']:
114         method = 'avg'
115     else:
116         raise Error('Unsupported pooling op {}', node.op)
117
118     # TODO check if it is a correct choice for ONNX
119     pooling_convention = 'valid'  # for Caffe rounding type should be ceil
120     rt = 'floor'
121
122     auto_pad = onnx_attr(node, 'auto_pad', 's', default=None, dst_type=get_onnx_autopad)
123     if auto_pad:
124         rt = 'ceil'
125
126     attrs = {
127         'op': node.op,
128         'auto_pad': auto_pad,
129         'window': final_kernel_shape,
130         'stride': final_strides,
131         'pad': final_pads,
132         'pad_spatial_shape': np.array(pads, dtype=np.int64) if pads is not None else None,
133         'pool_method': method,
134         'exclude_pad': 'true' if exclude_pad else 'false',
135         'global_pool': global_pooling,
136         'output_spatial_shape': None,
137         'rounding_type': rt,
138
139         'spatial_dims': None,
140         'channel_dims': np.array([1], dtype=np.int64),
141         'batch_dims': np.array([0], dtype=np.int64),
142         'layout': 'NCHW',
143
144         'pooling_convention': pooling_convention
145     }
146     return attrs