2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FuseAddWithFullyConnectedPass.h"
19 #include "helpers/CreateCircleConst.h"
21 #include <luci/IR/CircleNodes.h>
23 #include <luci/test/TestIOGraph.h>
25 #include <gtest/gtest.h>
30 using namespace luci::test;
33 * Simple graph for test
43 * [FC w/ Relu] (bias updated)
49 FCAddGraphlet() = default;
51 void init(loco::Graph *g)
53 std::vector<float> weights_val(16 * 4);
54 _fc_f = luci::create_const_node(g, loco::DataType::FLOAT32, {16, 4}, weights_val);
56 std::vector<float> bias_val(16);
57 _fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 16}, bias_val);
59 _fc = g->nodes()->create<luci::CircleFullyConnected>();
62 _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
63 _fc->dtype(loco::DataType::FLOAT32);
67 std::vector<float> addition_val;
68 for (uint32_t i = 0; i < 16; i++)
69 addition_val.push_back(static_cast<float>(i));
70 _add_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val);
72 _add = g->nodes()->create<luci::CircleAdd>();
75 _add->fusedActivationFunction(luci::FusedActFunc::RELU);
76 _add->dtype(loco::DataType::FLOAT32);
82 luci::CircleFullyConnected *fc() { return _fc; }
87 assert(_fc != nullptr); // FIX_ME_UNLESS
89 auto new_fc = _fc->graph()->nodes()->create<luci::CircleFullyConnected>();
94 luci::CircleFullyConnected *_fc = nullptr;
95 luci::CircleAdd *_add = nullptr;
96 luci::CircleConst *_fc_f = nullptr;
97 luci::CircleConst *_fc_b = nullptr;
98 luci::CircleConst *_add_c = nullptr;
101 class FuseAddWithFCTestGraph : public TestIOGraph, public FCAddGraphlet
104 FuseAddWithFCTestGraph() = default;
108 TestIOGraph::init({1, 4}, {1, 16});
109 FCAddGraphlet::init(g());
113 output()->from(_add);
117 class FuseAddWithFullyConnectedPassTest : public ::testing::Test
120 FuseAddWithFCTestGraph g;
121 luci::FuseAddWithFullyConnectedPass pass;
126 TEST_F(FuseAddWithFullyConnectedPassTest, simple_test)
130 auto ret = pass.run(g.g());
131 EXPECT_EQ(true, ret);
133 auto fc = dynamic_cast<luci::CircleFullyConnected *>(g.output()->from());
134 EXPECT_NE(nullptr, fc);
136 auto bias = loco::must_cast<luci::CircleConst *>(g.fc()->bias());
137 for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++)
139 EXPECT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));
143 TEST_F(FuseAddWithFullyConnectedPassTest, fm_bias_NEG)
147 // Bias is a feature map. Add is not fused.
150 auto ret = pass.run(g.g());
151 EXPECT_EQ(false, ret);