2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from extensions.front.onnx.upsample_ext import UpsampleFrontExtractor
20 from mo.utils.unittest.graph import build_graph
21 from mo.graph.graph import Node
22 from mo.utils.error import Error
23 from mo.utils.unittest.extractors import BaseExtractorsTestingClass
26 class UpsampleONNXExtractorTest(BaseExtractorsTestingClass):
def _create_node(attrs: dict):
    """Wrap an ONNX ``Upsample`` operator in a one-node graph and return its Node.

    :param attrs: keyword attributes forwarded as-is to ``onnx.helper.make_node``.
    :return: ``Node`` bound to the ``'node_0'`` entry of the freshly built graph.
    """
    node_name = 'node_0'
    # The ONNX operator has a single input "X" and a single output "Y".
    upsample_proto = onnx.helper.make_node("Upsample", ["X"], ["Y"], **attrs)
    # The protobuf is stored under the 'pb' attribute of the only graph node;
    # the second argument of build_graph is an empty list (no connections).
    nodes_attrs = {node_name: {'pb': upsample_proto}}
    return Node(build_graph(nodes_attrs, []), node_name)
35 # Commonly used attributes in the tests
36 # Each test takes these ones and then adds/modifies/deletes particular fields
38 # test input ONNX attributes
44 # reference output Node attributes
47 resample_type='caffe.ResampleParameter.NEAREST',
55 node = __class__._create_node(inp)
56 UpsampleFrontExtractor.extract(node)
59 def _match(self, out, ref):
64 def test_all_valid_default(self):
65 inp, ref = self._base_attrs()
66 out = self._extract(inp)
def test_invalid_mode(self):
    """Extraction must fail when ``mode`` holds an unrecognized value.

    Starts from the input attributes returned by ``_base_attrs``, overrides
    ``mode`` with ``'invalid_mode'`` and expects ``_extract`` to raise
    ``Error`` whose message matches the pattern below.
    """
    # The second value returned by _base_attrs is not needed: no output is produced.
    inp, _ = self._base_attrs()
    inp['mode'] = 'invalid_mode'
    expected_message = '.*decoding Upsample.*supported modes.*'
    with self.assertRaisesRegex(Error, expected_message):
        # Only the raised error matters, so the result is not stored.
        self._extract(inp)
def test_unsupported_linear(self):
    """Extraction must fail when ``mode`` is set to ``'linear'``.

    Starts from the input attributes returned by ``_base_attrs``, overrides
    ``mode`` with ``'linear'`` and expects ``_extract`` to raise ``Error``
    whose message matches the pattern below.
    """
    # The second value returned by _base_attrs is not needed: no output is produced.
    inp, _ = self._base_attrs()
    inp['mode'] = 'linear'
    expected_message = '.*Only nearest is supported.*'
    with self.assertRaisesRegex(Error, expected_message):
        # Only the raised error matters, so the result is not stored.
        self._extract(inp)
def test_unsupported_scale(self):
    """Extraction must fail when a ``scales`` attribute is supplied.

    Starts from the input attributes returned by ``_base_attrs``, adds
    ``scales=[2.0, 2.0]`` and expects ``_extract`` to raise ``Error`` whose
    message matches the pattern below.
    """
    # The second value returned by _base_attrs is not needed: no output is produced.
    inp, _ = self._base_attrs()
    inp['scales'] = [2.0, 2.0]
    expected_message = '.*Only scale_width and scale_height are supported.*'
    with self.assertRaisesRegex(Error, expected_message):
        # Only the raised error matters, so the result is not stored.
        self._extract(inp)
def test_missing_width_scale(self):
    """Extraction must fail when ``width_scale`` is absent.

    Starts from the input attributes returned by ``_base_attrs``, removes
    the ``width_scale`` key and expects ``_extract`` to raise ``Error``
    whose message matches the pattern below.
    """
    # The second value returned by _base_attrs is not needed: no output is produced.
    inp, _ = self._base_attrs()
    del inp['width_scale']
    expected_message = '.*One/both of widths_scale.*and height_scale.*is not defined.*'
    with self.assertRaisesRegex(Error, expected_message):
        # Only the raised error matters, so the result is not stored.
        self._extract(inp)
def test_missing_height_scale(self):
    """Extraction must fail when ``height_scale`` is absent.

    Starts from the input attributes returned by ``_base_attrs``, removes
    the ``height_scale`` key and expects ``_extract`` to raise ``Error``
    whose message matches the pattern below.
    """
    # The second value returned by _base_attrs is not needed: no output is produced.
    inp, _ = self._base_attrs()
    del inp['height_scale']
    expected_message = '.*One/both of widths_scale.*and height_scale.*is not defined.*'
    with self.assertRaisesRegex(Error, expected_message):
        # Only the raised error matters, so the result is not stored.
        self._extract(inp)
def test_different_scales(self):
    """Extraction must fail when the height and width scales differ.

    Starts from the input attributes returned by ``_base_attrs``, sets
    ``height_scale`` to 2.0 and ``width_scale`` to 3.0, and expects
    ``_extract`` to raise ``Error`` whose message matches the pattern below.
    """
    # The second value returned by _base_attrs is not needed: no output is produced.
    inp, _ = self._base_attrs()
    inp['height_scale'] = 2.0
    inp['width_scale'] = 3.0
    expected_message = '.*different widths_scale.*and height_scale.*not supported.*'
    with self.assertRaisesRegex(Error, expected_message):
        # Only the raised error matters, so the result is not stored.
        self._extract(inp)