Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / pooling_ext_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import numpy as np
18
19 from extensions.front.tf.pooling_ext import AvgPoolFrontExtractor, MaxPoolFrontExtractor
20 from mo.utils.unittest.extractors import PB, BaseExtractorsTestingClass
21
22
23 class PoolingExtractorTest(BaseExtractorsTestingClass):
24     @classmethod
25     def setUpClass(cls):
26         cls.strides = [1, 2, 3, 4]
27         cls.ksize = [1, 3, 3, 1]
28         cls.patcher = 'mo.ops.pooling.Pooling.infer'
29
30     def test_pool_defaults(self):
31         pb = PB({'attr': {
32             'data_format': PB({
33                 's': b"NHWC"
34             }),
35             'strides': PB({
36                 'list': PB({
37                     "i": self.strides
38                 })
39             }),
40             'ksize': PB({
41                 'list': PB({"i": self.ksize})
42             }),
43             'padding': PB({
44                 's': b'VALID'
45             })
46         }})
47         self.expected = {
48             'pad': None,  # will be inferred when input shape is known
49             'pad_spatial_shape': None,
50             'type': 'Pooling',
51             'exclude_pad': 'true',
52         }
53         node = PB({'pb': pb})
54         AvgPoolFrontExtractor.extract(node)
55         self.res = node
56         self.res["infer"](None)
57         self.call_args = self.infer_mock.call_args
58         self.expected_call_args = (None, None)
59         self.compare()
60
61     def test_avg_pool_nhwc(self):
62         pb = PB({'attr': {
63             'data_format': PB({
64                 's': b"NHWC"
65             }),
66             'strides': PB({
67                 'list': PB({"i": self.strides})
68             }),
69             'ksize': PB({
70                 'list': PB({"i": self.ksize})
71             }),
72             'padding': PB({
73                 's': b'VALID'
74             })
75         }})
76         self.expected = {
77             'window': np.array(self.ksize, dtype=np.int8),
78             'spatial_dims': [1, 2],
79             'stride': np.array(self.strides, dtype=np.int8),
80             'pool_method': "avg",
81         }
82         node = PB({'pb': pb})
83         AvgPoolFrontExtractor.extract(node)
84         self.res = node
85         self.res["infer"](None)
86         self.call_args = self.infer_mock.call_args
87         self.expected_call_args = (None, "avg")
88         self.compare()
89
90     def test_avg_pool_nchw(self):
91         pb = PB({'attr': {
92             'data_format': PB({
93                 's': b"NCHW"
94             }),
95             'strides': PB({
96                 'list': PB({
97                     "i": self.strides
98                 })
99             }),
100             'ksize': PB({
101                 'list': PB({
102                     "i": self.ksize
103                 })
104             }),
105             'padding': PB({
106                 's': b'VALID'
107             })
108         }})
109         self.expected = {
110             'window': np.array(self.ksize, dtype=np.int8),
111             'spatial_dims': [2, 3],
112             'stride': np.array(self.strides, dtype=np.int8),
113             'pool_method': "avg",
114         }
115         node = PB({'pb': pb})
116         AvgPoolFrontExtractor.extract(node)
117         self.res = node
118         self.res["infer"](None)
119         self.call_args = self.infer_mock.call_args
120         self.expected_call_args = (None, "avg")
121         self.compare()
122
123     def test_max_pool_nhwc(self):
124         pb = PB({'attr': {
125             'data_format': PB({
126                 's': b"NHWC"
127             }),
128             'strides': PB({
129                 'list': PB({
130                     "i": self.strides
131                 })
132             }),
133             'ksize': PB({
134                 'list': PB({
135                     "i": self.ksize
136                 })
137             }),
138             'padding': PB({
139                 's': b'VALID'
140             })
141         }})
142         self.expected = {
143             'window': np.array(self.ksize, dtype=np.int8),
144             'spatial_dims': [1, 2],
145             'stride': np.array(self.strides, dtype=np.int64),
146             'pool_method': "max",
147         }
148         node = PB({'pb': pb})
149         MaxPoolFrontExtractor.extract(node)
150         self.res = node
151         self.res["infer"](None)
152         self.call_args = self.infer_mock.call_args
153         self.expected_call_args = (None, "max")
154         self.compare()