2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
import logging as log

import numpy as np

from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_autopad
from mo.ops.pooling import Pooling
from mo.utils.error import Error
class AveragePoolFrontExtractor(FrontExtractorOp):
    """Front extractor for the ONNX AveragePool operation.

    Converts ONNX pooling attributes into Model Optimizer Pooling-node
    attributes via the shared `common_onnx_pool_extractor` helper.
    """
    op = 'AveragePool'
    enabled = True

    @staticmethod
    def extract(node):
        # All ONNX pooling variants share the same attribute extraction.
        attrs = common_onnx_pool_extractor(node)

        Pooling.update_node_stat(node, attrs)
        return __class__.enabled
class MaxPoolFrontExtractor(FrontExtractorOp):
    """Front extractor for the ONNX MaxPool operation.

    Converts ONNX pooling attributes into Model Optimizer Pooling-node
    attributes via the shared `common_onnx_pool_extractor` helper.
    """
    op = 'MaxPool'
    enabled = True

    @staticmethod
    def extract(node):
        # All ONNX pooling variants share the same attribute extraction.
        attrs = common_onnx_pool_extractor(node)

        Pooling.update_node_stat(node, attrs)
        return __class__.enabled
class GlobalAveragePoolFrontExtractor(FrontExtractorOp):
    """Front extractor for the ONNX GlobalAveragePool operation.

    Reuses the common pooling extraction, then overrides the attributes
    that distinguish a global pooling: the whole spatial extent is pooled,
    so `global_pool` is set and the 'full' pooling convention is used.
    """
    op = 'GlobalAveragePool'
    enabled = True

    @staticmethod
    def extract(node):
        attrs = common_onnx_pool_extractor(node)
        attrs.update({'pooling_convention': 'full',
                      'global_pool': True,
                      })

        Pooling.update_node_stat(node, attrs)
        return __class__.enabled
class GlobalMaxPoolFrontExtractor(FrontExtractorOp):
    """Front extractor for the ONNX GlobalMaxPool operation.

    Reuses the common pooling extraction, then overrides the attributes
    that distinguish a global pooling: the whole spatial extent is pooled,
    so `global_pool` is set and the 'full' pooling convention is used.
    """
    op = 'GlobalMaxPool'
    enabled = True

    @staticmethod
    def extract(node):
        attrs = common_onnx_pool_extractor(node)
        attrs.update({'pooling_convention': 'full',
                      'global_pool': True,
                      })

        Pooling.update_node_stat(node, attrs)
        return __class__.enabled
def common_onnx_pool_extractor(node):
    """Extract pooling attributes shared by all ONNX pooling operations.

    Reads the ONNX attributes (`pads`, `strides`, `kernel_shape`,
    `count_include_pad`, `auto_pad`) from `node` and converts them to the
    Model Optimizer Pooling attribute dictionary, padding the spatial-only
    ONNX values with batch/channel entries to full NCHW form.

    :param node: front-phase graph node wrapping an ONNX pooling proto.
    :return: dict of attributes suitable for `Pooling.update_node_stat`.
    :raises Error: if `node.op` is not a supported pooling operation.
    """
    pads = onnx_attr(node, 'pads', 'ints', default=None, dst_type=lambda x: np.array(x, dtype=np.int64))

    # Try to convert slightly incorrect models with insufficient pad parameters.
    # ONNX `pads` must contain begin/end pairs (an even count, and a single
    # pair only for 1D pooling); duplicate the values as "begins == ends".
    if pads is not None and (pads.size == 2 or pads.size % 2 != 0):
        log.warning(
            'Node {} has pad = {} which is ill-formed -- it should consist of N%2==0 elements.'.format(node.name,
                                                                                                       pads))
        pads = np.concatenate([pads, pads])
        log.warning('Extended pads to {}'.format(pads))

    final_pads = None
    if pads is not None:
        assert len(pads) % 2 == 0
        # ONNX layout is [x1_begin, x2_begin, ..., x1_end, x2_end, ...];
        # reshape/transpose gives one [begin, end] pair per spatial dim,
        # then prepend zero pads for the batch and channel dimensions.
        pads = pads.reshape([2, -1])
        pads = np.transpose(pads)
        final_pads = np.array([[0, 0], [0, 0], *[p for p in pads]], dtype=np.int64)

    # Extract strides attribute
    # In case if strides is not specified it will be set to default (1) in infer function
    strides = onnx_attr(node, 'strides', 'ints', default=None, dst_type=lambda x: np.array(x, dtype=np.int64))
    final_strides = np.array([1, 1, *[x for x in strides]], dtype=np.int64) if strides is not None else None

    kernel_shape = onnx_attr(node, 'kernel_shape', 'ints', default=None, dst_type=lambda x: np.array(x, dtype=np.int64))
    final_kernel_shape = np.array([1, 1, *[x for x in kernel_shape]], dtype=np.int64) if kernel_shape is not None else None

    # exclude_pad = True only when count_include_pad == 0
    exclude_pad = onnx_attr(node, 'count_include_pad', 'i', default=0) == 0

    global_pooling = False
    if node.op in ['MaxPool', 'GlobalMaxPool']:
        method = 'max'
    elif node.op in ['AveragePool', 'GlobalAveragePool']:
        method = 'avg'
    else:
        raise Error('Unsupported pooling op {}', node.op)

    # TODO check if it is a correct choice for ONNX
    pooling_convention = 'valid'  # for Caffe rounding type should be ceil

    auto_pad = onnx_attr(node, 'auto_pad', 's', default=None, dst_type=get_onnx_autopad)
    attrs = {
        'op': node.op,
        'auto_pad': auto_pad,
        'window': final_kernel_shape,
        'stride': final_strides,
        'pad': final_pads,
        'pad_spatial_shape': np.array(pads, dtype=np.int64) if pads is not None else None,
        'pool_method': method,
        'exclude_pad': 'true' if exclude_pad else 'false',
        'global_pool': global_pooling,
        'output_spatial_shape': None,

        'spatial_dims': None,
        'channel_dims': np.array([1], dtype=np.int64),
        'batch_dims': np.array([0], dtype=np.int64),
        'layout': 'NCHW',

        'pooling_convention': pooling_convention
    }
    return attrs