Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / onnx / upsample_ext_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import onnx
18
19 from extensions.front.onnx.upsample_ext import UpsampleFrontExtractor
20 from mo.utils.unittest.graph import build_graph
21 from mo.graph.graph import Node
22 from mo.utils.error import Error
23 from mo.utils.unittest.extractors import BaseExtractorsTestingClass
24
25
class UpsampleONNXExtractorTest(BaseExtractorsTestingClass):
    """Unit tests for ``UpsampleFrontExtractor`` on ONNX ``Upsample`` nodes.

    Every test starts from the attribute pair returned by ``_base_attrs``,
    tweaks the ONNX-side attributes, runs the extractor, and then either
    compares the node against the reference attributes or checks that an
    ``Error`` with the expected message is raised.
    """

    @staticmethod
    def _create_node(attrs: dict):
        """Build a single-node graph holding an ONNX ``Upsample`` proto.

        :param attrs: ONNX attributes forwarded verbatim to
            ``onnx.helper.make_node``.
        :return: ``Node`` wrapping graph node ``'node_0'``, whose ``pb``
            attribute is the generated NodeProto.
        """
        pb = onnx.helper.make_node("Upsample", ["X"], ["Y"], **attrs)
        graph = build_graph({'node_0': {'pb': pb}}, [])
        return Node(graph, 'node_0')

    @staticmethod
    def _base_attrs():
        """Return a fresh ``(onnx_attrs, reference_attrs)`` pair shared by the tests.

        New dicts are built on every call, so a test may add, modify or delete
        fields without leaking the change into other tests.
        """
        return (
            # test input ONNX attributes
            dict(
                mode='nearest',
                width_scale=2.0,
                height_scale=2.0,
            ),
            # reference output Node attributes
            dict(
                type='Resample',
                resample_type='caffe.ResampleParameter.NEAREST',
                factor=2,
                antialias=0,
            ),
        )

    @classmethod
    def _extract(cls, inp):
        """Create a node from ONNX attributes ``inp`` and run the extractor on it.

        :param inp: ONNX attributes for the ``Upsample`` node.
        :return: the node after ``UpsampleFrontExtractor.extract`` has run.
        """
        # A classmethod reaches the sibling helper through ``cls``. The previous
        # version relied on the implicit ``__class__`` cell inside a staticmethod.
        node = cls._create_node(inp)
        UpsampleFrontExtractor.extract(node)
        return node

    def _match(self, out, ref):
        """Compare the extracted node ``out`` with the reference attributes ``ref``.

        Both values are stored on ``self`` as ``res`` / ``expected`` before
        calling the inherited ``compare()``. Presumably ``compare()`` reads
        those attributes; confirm in ``BaseExtractorsTestingClass``.
        """
        self.res = out
        self.expected = ref
        self.compare()

    def test_all_valid_default(self):
        """Nearest mode with equal 2.0 scales maps to the reference Resample attrs."""
        inp, ref = self._base_attrs()
        self._match(self._extract(inp), ref)

    def test_invalid_mode(self):
        """An unknown ``mode`` value is rejected with an Error."""
        inp, _ = self._base_attrs()
        inp['mode'] = 'invalid_mode'
        with self.assertRaisesRegex(Error, '.*decoding Upsample.*supported modes.*'):
            self._extract(inp)

    def test_unsupported_linear(self):
        """``mode='linear'`` is rejected; only nearest is supported."""
        inp, _ = self._base_attrs()
        inp['mode'] = 'linear'
        with self.assertRaisesRegex(Error, '.*Only nearest is supported.*'):
            self._extract(inp)

    def test_unsupported_scale(self):
        """A ``scales`` attribute is rejected in favour of width/height scales."""
        inp, _ = self._base_attrs()
        inp['scales'] = [2.0, 2.0]
        with self.assertRaisesRegex(Error, '.*Only scale_width and scale_height are supported.*'):
            self._extract(inp)

    def test_missing_width_scale(self):
        """Omitting ``width_scale`` raises an Error."""
        inp, _ = self._base_attrs()
        del inp['width_scale']
        with self.assertRaisesRegex(Error, '.*One/both of widths_scale.*and height_scale.*is not defined.*'):
            self._extract(inp)

    def test_missing_height_scale(self):
        """Omitting ``height_scale`` raises an Error."""
        inp, _ = self._base_attrs()
        del inp['height_scale']
        with self.assertRaisesRegex(Error, '.*One/both of widths_scale.*and height_scale.*is not defined.*'):
            self._extract(inp)

    def test_different_scales(self):
        """Unequal width and height scales are rejected."""
        inp, _ = self._base_attrs()
        # Set both explicitly so the mismatch is visible at a glance.
        inp['height_scale'] = 2.0
        inp['width_scale'] = 3.0
        with self.assertRaisesRegex(Error, '.*different widths_scale.*and height_scale.*not supported.*'):
            self._extract(inp)