2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from extensions.front.tf.conv_ext import Conv2DFrontExtractor, DepthwiseConv2dNativeFrontExtractor
20 from mo.utils.unittest.extractors import PB, BaseExtractorsTestingClass
23 class ConvExtractorTest(BaseExtractorsTestingClass):
26 cls.strides = [1, 2, 3, 4]
27 cls.dilations = [1, 1, 1, 1]
29 def test_conv_2d_defaults(self):
30 node = PB({'pb': PB({'attr': {
35 'list': PB({"i": self.strides})
41 'list': PB({"i": [1, 1, 1, 1]})
46 'dilation': np.array([1, 1, 1, 1], dtype=np.int8),
47 'type': 'Convolution',
50 Conv2DFrontExtractor.extract(node)
52 self.expected_call_args = (None, False)
55 def test_conv2d_nhwc(self):
56 node = PB({'pb': PB({'attr': {
61 'list': PB({"i": self.strides})
67 'list': PB({"i": [1, 1, 1, 1]})
71 # spatial_dims = [1, 2] will be detected in infer function
74 "input_feature_channel": 2,
75 "output_feature_channel": 3,
76 'dilation': np.array([1, 1, 1, 1], dtype=np.int8),
77 'stride': np.array(self.strides, dtype=np.int8),
79 Conv2DFrontExtractor.extract(node)
81 self.expected_call_args = (None, False)
84 def test_conv2d_nchw(self):
85 node = PB({'pb': PB({'attr': {
90 'list': PB({"i": self.strides})
96 'list': PB({"i": [1, 1, 1, 1]})
100 # spatial_dims = [2, 3] will be detected in infer function
103 "input_feature_channel": 2,
104 "output_feature_channel": 3,
105 'dilation': np.array([1, 1, 1, 1], dtype=np.int8),
106 'stride': np.array(self.strides, dtype=np.int8),
108 Conv2DFrontExtractor.extract(node)
110 self.expected_call_args = (None, False)
113 def test_conv2d_depthwise(self):
114 node = PB({'pb': PB({'attr': {
119 'list': PB({"i": self.strides}),
122 'list': PB({"i": self.dilations}),
129 # spatial_dims = [1, 2] will be detected in infer function
132 "input_feature_channel": 2,
133 "output_feature_channel": 2,
134 'dilation': np.array([1, 1, 1, 1], dtype=np.int8),
135 'stride': np.array(self.strides, dtype=np.int8),
137 DepthwiseConv2dNativeFrontExtractor.extract(node)
139 self.expected_call_args = (None, True)