2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
19 #include <luci/test/TestIOGraph.h>
20 #include <luci/IR/CircleNodes.h>
22 #include <gtest/gtest.h>
27 using namespace luci::test;
29 // TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
// Test helper: creates a luci::CircleConst node in graph `g`, gives it rank/dims
// taken from `shape`, marks its shape VALID, and stores `values` in it using the
// element type selected by `dtype`.
// Handled dtypes (the `case` labels below): U8, S16, S32 and FLOAT32. The
// INTERNAL_EXN("... unsupported type") call is presumably the `default:` branch,
// whose label is not visible in this view.
// NOTE(review): `T` and `size` are used here but declared on lines not visible in
// this view — presumably a `template <typename T>` header and an element count
// computed from `shape`; confirm against the full file.
31 luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
32 const std::vector<uint32_t> &shape,
33 const std::vector<T> &values)
35 auto node = g->nodes()->create<luci::CircleConst>();
37 node->rank(shape.size());
// Copy each entry of `shape` onto the node's dimensions, then mark the shape as known.
40 for (uint32_t i = 0; i < shape.size(); ++i)
42 node->dim(i) = shape.at(i);
45 node->shape_status(luci::ShapeStatus::VALID);
// INIT_VALUES(DT): calls node->size<DT>(size), then assigns values[i] to
// node->at<DT>(i) for every index of `values`.
// NOTE(review): the copy loop is bounded by values.size(), not by the `size` passed
// to size<DT>(); callers are presumably expected to supply exactly as many values
// as the shape holds (true for both call sites in this file: {2}/{1,0} and {l}/l values).
// The case labels that follow dispatch INIT_VALUES on `dtype`.
47 #define INIT_VALUES(DT) \
49 node->size<DT>(size); \
50 for (uint32_t i = 0; i < values.size(); ++i) \
51 node->at<DT>(i) = values[i]; \
56 case loco::DataType::U8:
57 INIT_VALUES(loco::DataType::U8);
59 case loco::DataType::S16:
60 INIT_VALUES(loco::DataType::S16);
62 case loco::DataType::S32:
63 INIT_VALUES(loco::DataType::S32);
65 case loco::DataType::FLOAT32:
// NOTE(review): unlike the other cases, this INIT_VALUES invocation has no trailing ';'.
66 INIT_VALUES(loco::DataType::FLOAT32)
69 INTERNAL_EXN("create_const_node called with unsupported type");
76 * Simple graph for test
80 * [IFM1] [IFM2] [BIAS]
89 * [BatchMatMul] [BIAS]
// Default-constructible graphlet; its nodes are only created when init() is called.
99 FCGraphlet() = default;
100 virtual ~FCGraphlet() = default;
// Creates, in graph `g`:
//  - `_tr_y`: a CircleTranspose whose perm is the S32 constant {1, 0} (swaps the two axes);
//  - `_fc`: a CircleFullyConnected with no fused activation and FLOAT32 dtype;
//  - a FLOAT32 bias constant for `_fc` of length `l` (= `_fc`'s last dimension),
//    with every element set to `bv`.
// NOTE(review): `r_shape`, `_x` and `_y` are not referenced on the visible lines, yet
// `_fc->rank()`/`dim()` are read below, so `_fc`'s shape (and the node wiring) is
// presumably set on lines not shown here — confirm against the full file.
102 void init(loco::Graph *g, const ShapeU32 r_shape, const float bv)
104 _tr_y = g->nodes()->create<luci::CircleTranspose>();
106 std::vector<int32_t> tr_val = {1, 0};
107 _tr_y->perm(create_const_node(g, loco::DataType::S32, {2}, tr_val));
109 _fc = g->nodes()->create<luci::CircleFullyConnected>();
112 _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
113 _fc->dtype(loco::DataType::FLOAT32);
// Bias length follows the FC's innermost (last) dimension.
115 auto l = _fc->dim(_fc->rank() - 1).value();
116 std::vector<float> bias_val(l, bv);
117 _fc->bias(create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val));
// Returns the FullyConnected node built by init(); nullptr until init() has run.
122 luci::CircleFullyConnected *fc() { return _fc; }
// Non-owning pointers to nodes created through a graph's nodes()->create<>();
// all start as nullptr.
125 luci::CircleFullyConnected *_fc = nullptr;
126 luci::CircleTranspose *_tr_y = nullptr;
127 luci::CircleInput *_x = nullptr;
128 luci::CircleInput *_y = nullptr;
// Full test graph combining the input graphlet (TestIsGraphlet<2>), the output
// graphlet (TestOGraphlet) and the FullyConnected subgraph from FCGraphlet.
131 struct FCGraph : public TestIsGraphlet<2>, public TestOGraphlet, public FCGraphlet
134 virtual ~FCGraph() = default;
// x_shape / y_shape: shapes handed to the input graphlet (presumably one graph input each);
// r_shape: shape handed to the output graphlet and to FCGraphlet::init;
// bv: value written to every bias element by FCGraphlet::init.
135 void init(const ShapeU32 x_shape, const ShapeU32 y_shape, const ShapeU32 r_shape, const float bv)
137 TestIsGraphlet<2>::init(g(), {x_shape, y_shape});
138 TestOGraphlet::init(g(), r_shape);
// NOTE(review): lines around this call are not visible here; presumably they bind
// _x/_y to the graph inputs and connect the graph output to _fc — confirm.
141 FCGraphlet::init(g(), r_shape, bv);
// GoogleTest fixture: each TEST_F gets its own fixture object and therefore its own
// ReplaceNonConstFCWithBatchMatMulPass instance.
// NOTE(review): the tests below also use a fixture member `g`, declared on lines not
// visible here (presumably an FCGraph).
146 class ReplaceNonConstFCWithBatchMatMulPassTest : public ::testing::Test
150 luci::ReplaceNonConstFCWithBatchMatMulPass pass;
// Zero bias: the pass must report a change, and the graph output must afterwards be
// produced by a CircleReshape.
155 TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, simple_test)
// Inputs {2, 3} and {2, 3}, result shape {2, 2}, bias value 0.0f.
157 g.init({2, 3}, {2, 3}, {2, 2}, 0.0f);
159 auto ret = pass.run(g.g());
160 EXPECT_EQ(true, ret);
// dynamic_cast yields nullptr if the output's producer is not a CircleReshape.
162 auto res = dynamic_cast<luci::CircleReshape *>(g.output()->from());
163 EXPECT_NE(nullptr, res);
// Non-zero bias (1.0f): the pass must report a change, and the graph output must
// afterwards be produced by a CircleAdd (presumably the bias added after the
// replacement BatchMatMul).
166 TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, nonzero_bias_test)
168 g.init({2, 3}, {2, 3}, {2, 2}, 1.0f);
170 auto ret = pass.run(g.g());
171 EXPECT_EQ(true, ret);
// NOTE(review): `mm` holds the CircleAdd cast result, not a BatchMatMul; a name such
// as `add` would describe it more accurately.
173 auto mm = dynamic_cast<luci::CircleAdd *>(g.output()->from());
174 EXPECT_NE(nullptr, mm);
// Negative case: on a graph whose visible setup creates only a CircleInput and a
// CircleRelu (no FullyConnected), run() is expected to return false (nothing changed).
// NOTE(review): `g` here is used as a plain graph (`g.nodes()`, `&g`); its declaration
// and the node wiring are on lines not visible in this view.
177 TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, wrong_op_NEG)
181 auto inp = g.nodes()->create<luci::CircleInput>();
182 auto relu = g.nodes()->create<luci::CircleRelu>();
// This local `pass` shadows the fixture's `pass` member.
185 luci::ReplaceNonConstFCWithBatchMatMulPass pass;
186 auto changed = pass.run(&g);
188 EXPECT_EQ(false, changed);