2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
19 #include "helpers/CreateCircleConst.h"
21 #include <luci/test/TestIOGraph.h>
22 #include <luci/IR/CircleNodes.h>
24 #include <gtest/gtest.h>
29 using namespace luci::test;
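 *  Simple graph for test
 *
 *  BEFORE
 *
 *   [IFM1] [IFM2] [BIAS]
 *       \     |     /
 *         \   |   /
 *           [FC]
 *             |
 *          [Output]
 *
 *  (IFM2 reaches the FC weights through a Transpose with perm {1, 0}.)
 *
 *  AFTER (non-zero bias)
 *
 *   [IFM1] [IFM2]
 *       \    |
 *    [BatchMatMul] [BIAS]
 *            \      /
 *             [Add]
 *               |
 *           [Output]
 *
 *  With an all-zero bias the test instead expects the output to be fed by a Reshape.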
55 FCGraphlet() = default;
56 virtual ~FCGraphlet() = default;
58 void init(loco::Graph *g, const ShapeU32 r_shape, const float bv)
60 _tr_y = g->nodes()->create<luci::CircleTranspose>();
62 std::vector<int32_t> tr_val = {1, 0};
63 _tr_y->perm(luci::create_const_node(g, loco::DataType::S32, {2}, tr_val));
65 _fc = g->nodes()->create<luci::CircleFullyConnected>();
68 _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
69 _fc->dtype(loco::DataType::FLOAT32);
71 auto l = _fc->dim(_fc->rank() - 1).value();
72 std::vector<float> bias_val(l, bv);
73 _fc->bias(luci::create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val));
78 luci::CircleFullyConnected *fc() { return _fc; }
81 luci::CircleFullyConnected *_fc = nullptr;
82 luci::CircleTranspose *_tr_y = nullptr;
83 luci::CircleInput *_x = nullptr;
84 luci::CircleInput *_y = nullptr;
87 struct FCGraph : public TestIsGraphlet<2>, public TestOGraphlet, public FCGraphlet
90 virtual ~FCGraph() = default;
91 void init(const ShapeU32 x_shape, const ShapeU32 y_shape, const ShapeU32 r_shape, const float bv)
93 TestIsGraphlet<2>::init(g(), {x_shape, y_shape});
94 TestOGraphlet::init(g(), r_shape);
97 FCGraphlet::init(g(), r_shape, bv);
102 class ReplaceNonConstFCWithBatchMatMulPassTest : public ::testing::Test
106 luci::ReplaceNonConstFCWithBatchMatMulPass pass;
111 TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, simple_test)
113 g.init({2, 3}, {2, 3}, {2, 2}, 0.0f);
115 auto ret = pass.run(g.g());
116 EXPECT_EQ(true, ret);
118 auto res = dynamic_cast<luci::CircleReshape *>(g.output()->from());
119 EXPECT_NE(nullptr, res);
122 TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, nonzero_bias_test)
124 g.init({2, 3}, {2, 3}, {2, 2}, 1.0f);
126 auto ret = pass.run(g.g());
127 EXPECT_EQ(true, ret);
129 auto mm = dynamic_cast<luci::CircleAdd *>(g.output()->from());
130 EXPECT_NE(nullptr, mm);
133 TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, wrong_op_NEG)
137 auto inp = g.nodes()->create<luci::CircleInput>();
138 auto relu = g.nodes()->create<luci::CircleRelu>();
141 luci::ReplaceNonConstFCWithBatchMatMulPass pass;
142 auto changed = pass.run(&g);
144 EXPECT_EQ(false, changed);