Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FoldAddV2Pass.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FoldAddV2Pass.h"
18 #include "PassTestGraphs.h"
19
20 #include <luci/IR/CircleNodes.h>
21
22 #include <gtest/gtest.h>
23
24 namespace
25 {
26
27 /**
28  *  Graph has an AddV2 Op with constant inputs
29  *
30  *    BEFORE
31  *
32  *    [CircleConst] [CircleConst]
33  *               |   |
34  *       [CircleCustom (AddV2)]
35  *                 |
36  *         [CircleCustomOut]
37  *
38  *    AFTER
39  *
40  *           [CircleConst]
41  */
42 template <loco::DataType T> class FoldAddV2Test : public luci::ConstantFoldingAddTestGraph
43 {
44 public:
45   FoldAddV2Test(std::initializer_list<uint32_t> shape) : luci::ConstantFoldingAddTestGraph(shape, T)
46   {
47     _addV2 = _g.nodes()->template create<luci::CircleCustom>(2, 1);
48     _x = _g.nodes()->template create<luci::CircleConst>();
49     _y = _g.nodes()->template create<luci::CircleConst>();
50     _addV2_out = _g.nodes()->template create<luci::CircleCustomOut>();
51
52     _addV2->dtype(T);
53     _x->dtype(T);
54     _y->dtype(T);
55     _addV2_out->dtype(T);
56
57     _addV2->shape(shape);
58     _x->shape(shape);
59     _y->shape(shape);
60     _addV2_out->shape(shape);
61
62     uint32_t num_elems = 1;
63     for (auto dim = shape.begin(); dim != shape.end(); dim++)
64       num_elems *= *dim;
65
66     _x->size<T>(num_elems);
67     _y->size<T>(num_elems);
68
69     for (uint32_t i = 0; i < num_elems; i++)
70     {
71       _x->at<T>(i) = i + 1;
72       _y->at<T>(i) = i + 1;
73     }
74
75     _addV2->custom_code("AddV2");
76     _addV2->inputs(0, _x);
77     _addV2->inputs(1, _y);
78     _addV2_out->input(_addV2);
79
80     _addV2->name("addV2");
81     _x->name("x");
82     _y->name("y");
83   }
84
85   loco::Node *createFoldedPattern() override { return _addV2_out; }
86
87   virtual ~FoldAddV2Test() = default;
88
89 protected:
90   luci::CircleCustom *_addV2 = nullptr;
91   luci::CircleCustomOut *_addV2_out = nullptr;
92   luci::CircleConst *_x = nullptr;
93   luci::CircleConst *_y = nullptr;
94 };
95
96 class FoldS64AddV2Test : public FoldAddV2Test<loco::DataType::S64>, public ::testing::Test
97 {
98 public:
99   FoldS64AddV2Test() : FoldAddV2Test<loco::DataType::S64>({3}) {}
100
101   virtual void SetUp() { init(); }
102 };
103
104 } // namespace
105
106 TEST(FoldAddV2PassTest, name)
107 {
108   luci::FoldAddV2Pass pass;
109   auto const name = pass.name();
110   ASSERT_NE(nullptr, name);
111 }
112
113 TEST_F(FoldS64AddV2Test, fold_addV2)
114 {
115   luci::FoldAddV2Pass pass;
116   while (pass.run(graph()))
117     ;
118
119   auto folded_const = getFoldedPattern();
120   EXPECT_NE(nullptr, folded_const);
121
122   // Check type, shape, values of folded const
123   EXPECT_EQ(loco::DataType::S64, folded_const->dtype());
124   EXPECT_EQ(1, folded_const->rank());
125   EXPECT_EQ(3, folded_const->dim(0).value());
126   EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(0));
127   EXPECT_EQ(4, folded_const->at<loco::DataType::S64>(1));
128   EXPECT_EQ(6, folded_const->at<loco::DataType::S64>(2));
129 }
130
131 TEST_F(FoldS64AddV2Test, input_type_mismatch_NEG)
132 {
133   _x->dtype(loco::DataType::S32);
134
135   luci::FoldAddV2Pass pass;
136   EXPECT_FALSE(pass.run(graph()));
137 }