2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from extensions.front.tf.pooling_ext import AvgPoolFrontExtractor, MaxPoolFrontExtractor
20 from mo.utils.unittest.extractors import PB, BaseExtractorsTestingClass
23 class PoolingExtractorTest(BaseExtractorsTestingClass):
26 cls.strides = [1, 2, 3, 4]
27 cls.ksize = [1, 3, 3, 1]
28 cls.patcher = 'mo.ops.pooling.Pooling.infer'
30 def test_pool_defaults(self):
41 'list': PB({"i": self.ksize})
48 'pad': None, # will be inferred when input shape is known
49 'pad_spatial_shape': None,
51 'exclude_pad': 'true',
54 AvgPoolFrontExtractor.extract(node)
56 self.res["infer"](None)
57 self.call_args = self.infer_mock.call_args
58 self.expected_call_args = (None, None)
61 def test_avg_pool_nhwc(self):
67 'list': PB({"i": self.strides})
70 'list': PB({"i": self.ksize})
77 'window': np.array(self.ksize, dtype=np.int8),
78 'spatial_dims': [1, 2],
79 'stride': np.array(self.strides, dtype=np.int8),
83 AvgPoolFrontExtractor.extract(node)
85 self.res["infer"](None)
86 self.call_args = self.infer_mock.call_args
87 self.expected_call_args = (None, "avg")
90 def test_avg_pool_nchw(self):
110 'window': np.array(self.ksize, dtype=np.int8),
111 'spatial_dims': [2, 3],
112 'stride': np.array(self.strides, dtype=np.int8),
113 'pool_method': "avg",
115 node = PB({'pb': pb})
116 AvgPoolFrontExtractor.extract(node)
118 self.res["infer"](None)
119 self.call_args = self.infer_mock.call_args
120 self.expected_call_args = (None, "avg")
123 def test_max_pool_nhwc(self):
143 'window': np.array(self.ksize, dtype=np.int8),
144 'spatial_dims': [1, 2],
145 'stride': np.array(self.strides, dtype=np.int64),
146 'pool_method': "max",
148 node = PB({'pb': pb})
149 MaxPoolFrontExtractor.extract(node)
151 self.res["infer"](None)
152 self.call_args = self.infer_mock.call_args
153 self.expected_call_args = (None, "max")