2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FoldAddV2Pass.h"
18 #include "PassTestGraphs.h"
20 #include <luci/IR/CircleNodes.h>
22 #include <gtest/gtest.h>
28 * Graph has an AddV2 Op with constant inputs
32 * [CircleConst] [CircleConst]
34 * [CircleCustom (AddV2)]
/**
 * Test graph builder for a single AddV2 custom op whose two inputs are constants.
 *
 * Template parameter T is the element data type. It is forwarded to
 * luci::ConstantFoldingAddTestGraph together with the requested shape.
 *
 * NOTE(review): this view of the file is missing several original lines: braces,
 * access specifiers, dtype/shape setup for _addV2/_x/_y, and both loop bodies.
 * The comments below describe only what is visible.
 */
42 template <loco::DataType T> class FoldAddV2Test : public luci::ConstantFoldingAddTestGraph
// Builds Custom(AddV2)(_x, _y) -> CustomOut inside the base-class graph `_g`.
// @param shape dimensions applied to the nodes; also forwarded to the base graph.
45 FoldAddV2Test(std::initializer_list<uint32_t> shape) : luci::ConstantFoldingAddTestGraph(shape, T)
// Custom node created with (2, 1). Two inputs are wired below; presumably one output.
47 _addV2 = _g.nodes()->create<luci::CircleCustom>(2, 1);
48 _x = _g.nodes()->create<luci::CircleConst>();
49 _y = _g.nodes()->create<luci::CircleConst>();
50 _addV2_out = _g.nodes()->create<luci::CircleCustomOut>();
60 _addV2_out->shape(shape);
// Element count of each constant input.
// NOTE(review): the loop body is not visible here; presumably it multiplies by each dim.
62 uint32_t num_elems = 1;
63 for (auto dim = shape.begin(); dim != shape.end(); dim++)
66 _x->size<T>(num_elems);
67 _y->size<T>(num_elems);
// Fill the constant inputs.
// NOTE(review): the loop body is not visible here. The fold_addV2 test expects
// 2, 4, 6 for shape {3}, which is consistent with x[i] == y[i] == i + 1. Confirm upstream.
69 for (uint32_t i = 0; i < num_elems; i++)
// Mark the custom op as AddV2 and connect x, y -> AddV2 -> CustomOut.
75 _addV2->custom_code("AddV2");
76 _addV2->inputs(0, _x);
77 _addV2->inputs(1, _y);
78 _addV2_out->input(_addV2);
80 _addV2->name("addV2");
// Returns the CustomOut node as the pattern that folding is expected to replace.
// Presumably consumed by the base class; verify against ConstantFoldingAddTestGraph.
85 loco::Node *createFoldedPattern() override { return _addV2_out; }
87 virtual ~FoldAddV2Test() = default;
// Non-owning handles to nodes created through `_g.nodes()`.
// Presumably the graph owns them; confirm against loco's node container.
90 luci::CircleCustom *_addV2 = nullptr;
91 luci::CircleCustomOut *_addV2_out = nullptr;
92 luci::CircleConst *_x = nullptr;
93 luci::CircleConst *_y = nullptr;
// GoogleTest fixture: the AddV2 test graph with S64 elements and a 1-D shape {3}.
96 class FoldS64AddV2Test : public FoldAddV2Test<loco::DataType::S64>, public ::testing::Test
99 FoldS64AddV2Test() : FoldAddV2Test<loco::DataType::S64>({3}) {}
// Calls init() before each test. init() comes from the test-graph base class and
// presumably finalizes the graph built in the constructor. Confirm in PassTestGraphs.h.
101 virtual void SetUp() { init(); }
// The pass must report a non-null name.
106 TEST(FoldAddV2PassTest, name)
108 luci::FoldAddV2Pass pass;
109 auto const name = pass.name();
110 ASSERT_NE(nullptr, name);
// Positive case: AddV2 over two S64 constants of shape {3} folds into one constant.
113 TEST_F(FoldS64AddV2Test, fold_addV2)
115 luci::FoldAddV2Pass pass;
// Re-run the pass while it reports a change.
// NOTE(review): the loop body line is missing from this view. Upstream presumably uses
// the empty statement `;`. Without it, the next statement would become the loop body.
116 while (pass.run(graph()))
// getFoldedPattern() comes from the test-graph base class. It is expected to yield the
// constant that replaced the AddV2 output (nullptr if folding did not happen).
119 auto folded_const = getFoldedPattern();
120 EXPECT_NE(nullptr, folded_const);
122 // Check type, shape, values of folded const
123 EXPECT_EQ(loco::DataType::S64, folded_const->dtype());
124 EXPECT_EQ(1, folded_const->rank());
125 EXPECT_EQ(3, folded_const->dim(0).value());
// Expected element-wise sums of the two constant inputs.
126 EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(0));
127 EXPECT_EQ(4, folded_const->at<loco::DataType::S64>(1));
128 EXPECT_EQ(6, folded_const->at<loco::DataType::S64>(2));
// Negative case: when one input's dtype (S32) differs from the other (S64),
// the pass must leave the graph alone and report no change.
131 TEST_F(FoldS64AddV2Test, input_type_mismatch_NEG)
133 _x->dtype(loco::DataType::S32);
135 luci::FoldAddV2Pass pass;
136 EXPECT_FALSE(pass.run(graph()));