Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / Reverse.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Reverse.h"
19 #include "kernels/TestUtils.h"
20
21 namespace luci_interpreter
22 {
23 namespace kernels
24 {
25 namespace
26 {
27
28 using namespace testing;
29
30 template <typename T> class ReverseTest : public ::testing::Test
31 {
32 };
33
34 using DataTypes = ::testing::Types<float, uint8_t>;
35 TYPED_TEST_CASE(ReverseTest, DataTypes);
36
37 TYPED_TEST(ReverseTest, MultiDimensions)
38 {
39   // TypeParam
40   std::vector<TypeParam> input_data{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
41                                     13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
42   Shape input_shape{4, 3, 2};
43   std::vector<int32_t> axis_data{1};
44   Shape axis_shape{1};
45
46   std::vector<TypeParam> output_data{5,  6,  3,  4,  1,  2,  11, 12, 9,  10, 7,  8,
47                                      17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20};
48   std::vector<int32_t> output_shape{4, 3, 2};
49
50   Tensor input_tensor = makeInputTensor<getElementType<TypeParam>()>(input_shape, input_data);
51   Tensor axis_tensor = makeInputTensor<DataType::S32>(axis_shape, axis_data);
52
53   Tensor output_tensor = makeOutputTensor(getElementType<TypeParam>());
54
55   Reverse kernel = Reverse(&input_tensor, &axis_tensor, &output_tensor);
56   kernel.configure();
57   kernel.execute();
58
59   EXPECT_THAT(extractTensorData<TypeParam>(output_tensor),
60               ::testing::ElementsAreArray(output_data));
61   EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
62 }
63
64 } // namespace
65 } // namespace kernels
66 } // namespace luci_interpreter