Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FoldCastPass.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FoldCastPass.h"
18 #include "PassTestGraphs.h"
19
20 #include <luci/IR/CircleNodes.h>
21
22 #include <gtest/gtest.h>
23
24 namespace
25 {
26
27 template <loco::DataType FromT, loco::DataType ToT>
28 class FoldCastTest : public luci::ConstantFoldingAddTestGraph
29 {
30 public:
31   FoldCastTest(std::initializer_list<uint32_t> shape)
32     : luci::ConstantFoldingAddTestGraph(shape, ToT)
33   {
34     _cast = _g.nodes()->template create<luci::CircleCast>();
35     _x = _g.nodes()->template create<luci::CircleConst>();
36
37     _cast->dtype(ToT);
38     _x->dtype(FromT);
39
40     _cast->shape(shape);
41     _x->shape(shape);
42
43     uint32_t num_elems = 1;
44     for (auto dim = shape.begin(); dim != shape.end(); dim++)
45       num_elems *= *dim;
46
47     _x->size<FromT>(num_elems);
48     for (uint32_t i = 0; i < num_elems; i++)
49       _x->at<FromT>(i) = i + 1;
50
51     _cast->x(_x);
52
53     _cast->name("cast");
54     _x->name("x");
55   }
56
57   loco::Node *createFoldedPattern() override { return _cast; }
58
59 protected:
60   luci::CircleCast *_cast = nullptr;
61   luci::CircleConst *_x = nullptr;
62 };
63
64 /**
65  *  Graph that has a Cast Op with constant input
66  *
67  *    BEFORE
68  *
69  *         [CircleConst]
70  *               |
71  *            [Cast]
72  *
73  *    AFTER
74  *
75  *         [CircleConst]
76  *
77  */
78 class FoldS64ToS32CastTest : public FoldCastTest<loco::DataType::S64, loco::DataType::S32>,
79                              public ::testing::Test
80 {
81 public:
82   FoldS64ToS32CastTest() : FoldCastTest<loco::DataType::S64, loco::DataType::S32>({3}) {}
83
84   virtual void SetUp() { init(); }
85 };
86
87 } // namespace
88
89 TEST(FoldCastPassTest, name)
90 {
91   luci::FoldCastPass pass;
92   auto const name = pass.name();
93   ASSERT_NE(nullptr, name);
94 }
95
96 TEST_F(FoldS64ToS32CastTest, fold_cast_s64_to_s32)
97 {
98   luci::FoldCastPass pass;
99   while (pass.run(graph()))
100     ;
101
102   auto folded_const = getFoldedPattern();
103   EXPECT_NE(nullptr, folded_const);
104
105   // Check type, shape, values of folded const
106   EXPECT_EQ(loco::DataType::S32, folded_const->dtype());
107   EXPECT_EQ(1, folded_const->rank());
108   EXPECT_EQ(3, folded_const->dim(0).value());
109   EXPECT_EQ(1, folded_const->at<loco::DataType::S32>(0));
110   EXPECT_EQ(2, folded_const->at<loco::DataType::S32>(1));
111   EXPECT_EQ(3, folded_const->at<loco::DataType::S32>(2));
112 }