3797e3ccc5052eca3df272f2d47fa0266ff25790
[platform/core/ml/nnfw.git] / compiler / mir / unittests / ShapeRange.cpp
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "gtest/gtest.h"
18 #include "mir/ShapeRange.h"
19
20 using namespace mir;
21
22 namespace
23 {
24
25 struct ParamType
26 {
27   int32_t actual_length;
28   Shape shape;
29
30   template <typename... Args>
31   explicit ParamType(int32_t actual_len, Args &&... args)
32     : actual_length(actual_len), shape({static_cast<int32_t>(args)...})
33   {
34   }
35 };
36
37 class ShapeIteratorTest : public ::testing::TestWithParam<ParamType>
38 {
39 };
40
41 TEST_P(ShapeIteratorTest, ElementCount)
42 {
43   Shape sh(GetParam().shape);
44   ShapeRange r(sh);
45
46   int32_t cnt = 0;
47   for (auto &idx : r)
48   {
49     (void)idx;
50     cnt++;
51   }
52
53   ASSERT_EQ(cnt, GetParam().actual_length);
54 }
55
56 std::vector<ParamType> test_data{ParamType{6, 1, 2, 3}, ParamType{16, 2, 2, 4},
57                                  ParamType{1, 1, 1, 1, 1, 1}, ParamType{5, 5, 1, 1, 1, 1, 1}};
58
59 INSTANTIATE_TEST_CASE_P(SimpleInput, ShapeIteratorTest, ::testing::ValuesIn(test_data));
60
61 TEST(ShapeRange, Contains)
62 {
63   const int h = 2;
64   const int w = 3;
65   Shape shape{static_cast<int32_t>(h), static_cast<int32_t>(w)};
66   ShapeRange range(shape);
67   Index index{0, 0, 0, 0};
68   for (int32_t row = -2; row < h + 1; ++row)
69     for (int32_t col = -2; col < w + 1; ++col)
70     {
71       Index idx{row, col};
72       if (row < 0 || row >= h || col < 0 || col >= w)
73         ASSERT_FALSE(range.contains(idx));
74       else
75         ASSERT_TRUE(range.contains(idx));
76     }
77 }
78 } // namespace