2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from generator import generator
21 from mo.graph.graph import Node
22 from mo.ops.slice import Slice
23 from mo.utils.unittest.graph import build_graph
class TestSliceOp(unittest.TestCase):
    """Inference tests for the ``Slice`` operation.

    Every test builds the same graph: ``data_1``, ``begin`` and ``size`` feed
    the ``slice`` op (connected in that order), whose result goes to
    ``data_2``. Each test then seeds per-node attributes, runs
    ``Slice.infer`` and checks the output data node together with the
    ``slices`` attribute recorded on the op node.
    """

    @staticmethod
    def _build_slice_graph(update_attributes):
        """Build the shared Slice test graph with *update_attributes* applied.

        The edge list is created fresh on every call so that one test cannot
        observe anything build_graph may attach to edges while wiring another
        test's graph (NOTE(review): presumed, not verified, behaviour of
        build_graph).
        """
        edges = [('data_1', 'slice'),
                 ('begin', 'slice'),
                 ('size', 'slice'),
                 ('slice', 'data_2')]
        return build_graph(nodes_attributes, edges, update_attributes)

    def test_slice_infer_constant(self):
        # Constant path: data_1 carries a value, so inference must produce
        # the sliced value in addition to the output shape.
        graph = self._build_slice_graph(
            {'data_1': {'shape': np.array([4]), 'value': np.array([1, 3, 224, 224])},
             'slice': {'start': np.array([1]), 'end': np.array([2])},
             'size': {'value': np.array([1])},
             'begin': {'value': np.array([1])}})

        slice_node = Node(graph, 'slice')
        Slice.infer(slice_node)

        self.assertTrue(np.array_equal(slice_node.out_node().value, np.array([3])))
        self.assertTrue(np.array_equal(slice_node.out_node().shape, np.array([1])))
        self.assertTrue(np.array_equal(slice_node['slices'], np.array([slice(1, 2, 1)])))

    def test_slice_infer_non_constant(self):
        # Non-constant path (data_1 has no value) with multiple begin/size
        # entries. Only the shape is inferred; the trailing dimension that
        # begin/size do not cover is kept whole.
        graph = self._build_slice_graph(
            {'data_1': {'shape': np.array([4, 5, 6])},
             'slice': {'start': np.array([1, 2]),
                       'end': np.array([4, 3])},
             'size': {'value': np.array([3, 1])},
             'begin': {'value': np.array([1, 2])}})

        slice_node = Node(graph, 'slice')

        Slice.infer(slice_node)
        self.assertIsNone(slice_node.out_node().value)
        self.assertTrue(np.array_equal(slice_node.out_node().shape, np.array([3, 1, 6])))
        self.assertTrue(np.array_equal(slice_node['slices'],
                                       np.array([slice(1, 4, 1), slice(2, 3, 1), slice(0, 6, 1)])))

    def test_slice_infer_multiply_params(self):
        # size[i] == -1 means all remaining elements in dimension i are
        # included in the slice (here dimension 1: indices 2..4 of 5).
        graph = self._build_slice_graph(
            {'data_1': {'shape': np.array([4, 5, 6])},
             'slice': {'start': np.array([1, 2]),
                       'end': np.array([4, 1])},
             'size': {'value': np.array([3, -1])},
             'begin': {'value': np.array([1, 2])}})

        slice_node = Node(graph, 'slice')

        Slice.infer(slice_node)
        self.assertIsNone(slice_node.out_node().value)
        self.assertTrue(np.array_equal(slice_node.out_node().shape, np.array([3, 3, 6])))
        self.assertTrue(np.array_equal(slice_node['slices'],
                                       np.array([slice(1, 4, 1), slice(2, 5, 1), slice(0, 6, 1)])))