/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "luci/Pass/FuseAddWithFullyConnectedPass.h"

#include <luci/IR/CircleNodes.h>

#include <luci/test/TestIOGraph.h>

#include <gtest/gtest.h>

#include <cstdint>
#include <numeric>
#include <vector>

using namespace luci::test;
30 // TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
32 luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
33 const std::vector<uint32_t> &shape,
34 const std::vector<T> &values)
36 auto node = g->nodes()->create<luci::CircleConst>();
38 node->rank(shape.size());
41 for (uint32_t i = 0; i < shape.size(); ++i)
43 node->dim(i) = shape.at(i);
46 node->shape_status(luci::ShapeStatus::VALID);
48 #define INIT_VALUES(DT) \
50 node->size<DT>(size); \
51 for (uint32_t i = 0; i < values.size(); ++i) \
52 node->at<DT>(i) = values[i]; \
57 case loco::DataType::U8:
58 INIT_VALUES(loco::DataType::U8);
60 case loco::DataType::S16:
61 INIT_VALUES(loco::DataType::S16);
63 case loco::DataType::S32:
64 INIT_VALUES(loco::DataType::S32);
66 case loco::DataType::FLOAT32:
67 INIT_VALUES(loco::DataType::FLOAT32)
70 INTERNAL_EXN("create_const_node called with unsupported type");
77 * Simple graph for test
87 * [FC w/ Relu] (bias updated)
93 FCAddGraphlet() = default;
95 void init(loco::Graph *g)
97 std::vector<float> weights_val(16 * 4);
98 _fc_f = create_const_node(g, loco::DataType::FLOAT32, {16, 4}, weights_val);
100 std::vector<float> bias_val(16);
101 _fc_b = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, bias_val);
103 _fc = g->nodes()->create<luci::CircleFullyConnected>();
106 _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
107 _fc->dtype(loco::DataType::FLOAT32);
111 std::vector<float> addition_val;
112 for (uint32_t i = 0; i < 16; i++)
113 addition_val.push_back(static_cast<float>(i));
114 _add_c = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val);
116 _add = g->nodes()->create<luci::CircleAdd>();
119 _add->fusedActivationFunction(luci::FusedActFunc::RELU);
120 _add->dtype(loco::DataType::FLOAT32);
121 _add->shape({1, 16});
126 luci::CircleFullyConnected *fc() { return _fc; }
129 luci::CircleFullyConnected *_fc = nullptr;
130 luci::CircleAdd *_add = nullptr;
131 luci::CircleConst *_fc_f = nullptr;
132 luci::CircleConst *_fc_b = nullptr;
133 luci::CircleConst *_add_c = nullptr;
136 class FuseAddWithFCTestGraph : public TestIOGraph, public FCAddGraphlet
139 FuseAddWithFCTestGraph() = default;
143 TestIOGraph::init({1, 4}, {1, 16});
144 FCAddGraphlet::init(g());
148 output()->from(_add);
152 class FuseAddWithFullyConnectedPassTest : public ::testing::Test
155 FuseAddWithFCTestGraph g;
156 luci::FuseAddWithFullyConnectedPass pass;
161 TEST_F(FuseAddWithFullyConnectedPassTest, simple_test)
165 auto ret = pass.run(g.g());
166 EXPECT_EQ(true, ret);
168 auto fc = dynamic_cast<luci::CircleFullyConnected *>(g.output()->from());
169 EXPECT_NE(nullptr, fc);
171 auto bias = loco::must_cast<luci::CircleConst *>(g.fc()->bias());
172 for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++)
174 EXPECT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));