2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FuseActivationFunctionPass.h"
19 #include <luci/IR/CircleNodes.h>
21 #include <luci/test/TestIOGraph.h>
23 #include <gtest/gtest.h>
28 using namespace luci::test;
/**
 *  Simple graph for test
 *
 *  BEFORE
 *
 *         [Conv1]
 *            |
 *     [Activation func]
 *            |
 *         [Conv2]
 *
 *  AFTER
 *
 *  [Conv1 + Activation func]
 *            |
 *         [Conv2]
 *
 */
48 class ConvReluConvGraphlet
51 ConvReluConvGraphlet() = default;
53 void init(loco::Graph *g)
55 _conv1 = g->nodes()->create<luci::CircleConv2D>();
56 _conv2 = g->nodes()->create<luci::CircleConv2D>();
57 _relu = g->nodes()->create<luci::CircleRelu>();
58 _conv1_f = g->nodes()->create<luci::CircleConst>();
59 _conv1_b = g->nodes()->create<luci::CircleConst>();
60 _conv2_f = g->nodes()->create<luci::CircleConst>();
61 _conv2_b = g->nodes()->create<luci::CircleConst>();
63 _conv1->fusedActivationFunction(luci::FusedActFunc::NONE);
65 _conv1->name("conv1");
66 _conv2->name("conv2");
68 _conv1_f->name("conv1f");
69 _conv1_b->name("conv1b");
70 _conv2_f->name("conv2f");
71 _conv2_b->name("conv2b");
75 luci::CircleRelu *relu() { return _relu; }
76 luci::CircleConv2D *conv1() { return _conv1; }
77 luci::CircleConv2D *conv2() { return _conv2; }
80 luci::CircleConv2D *_conv1 = nullptr;
81 luci::CircleConv2D *_conv2 = nullptr;
82 luci::CircleRelu *_relu = nullptr;
83 luci::CircleConst *_conv1_f = nullptr;
84 luci::CircleConst *_conv1_b = nullptr;
85 luci::CircleConst *_conv2_f = nullptr;
86 luci::CircleConst *_conv2_b = nullptr;
89 class ConvTanhConvGraphlet
92 ConvTanhConvGraphlet() = default;
94 void init(loco::Graph *g)
96 _conv1 = g->nodes()->create<luci::CircleConv2D>();
97 _conv2 = g->nodes()->create<luci::CircleConv2D>();
98 _tanh = g->nodes()->create<luci::CircleTanh>();
99 _conv1_f = g->nodes()->create<luci::CircleConst>();
100 _conv1_b = g->nodes()->create<luci::CircleConst>();
101 _conv2_f = g->nodes()->create<luci::CircleConst>();
102 _conv2_b = g->nodes()->create<luci::CircleConst>();
104 _conv1->fusedActivationFunction(luci::FusedActFunc::NONE);
106 _conv1->name("conv1");
107 _conv2->name("conv2");
109 _conv1_f->name("conv1f");
110 _conv1_b->name("conv1b");
111 _conv2_f->name("conv2f");
112 _conv2_b->name("conv2b");
116 luci::CircleTanh *tanh() { return _tanh; }
117 luci::CircleConv2D *conv1() { return _conv1; }
118 luci::CircleConv2D *conv2() { return _conv2; }
121 luci::CircleConv2D *_conv1 = nullptr;
122 luci::CircleConv2D *_conv2 = nullptr;
123 luci::CircleTanh *_tanh = nullptr;
124 luci::CircleConst *_conv1_f = nullptr;
125 luci::CircleConst *_conv1_b = nullptr;
126 luci::CircleConst *_conv2_f = nullptr;
127 luci::CircleConst *_conv2_b = nullptr;
130 class FuseActTestGraph : public TestIOGraph, public ConvReluConvGraphlet
133 FuseActTestGraph() = default;
137 TestIOGraph::init({1}, {1});
138 ConvReluConvGraphlet::init(g());
140 _conv1->input(input());
141 _conv1->filter(_conv1_f);
142 _conv1->bias(_conv1_b);
144 _relu->features(_conv1);
146 _conv2->input(_relu);
147 _conv2->filter(_conv2_f);
148 _conv2->bias(_conv2_b);
150 output()->from(_conv2);
154 class FuseTanhActTestGraph : public TestIOGraph, public ConvTanhConvGraphlet
157 FuseTanhActTestGraph() = default;
161 TestIOGraph::init({1}, {1});
162 ConvTanhConvGraphlet::init(g());
164 _conv1->input(input());
165 _conv1->filter(_conv1_f);
166 _conv1->bias(_conv1_b);
170 _conv2->input(_tanh);
171 _conv2->filter(_conv2_f);
172 _conv2->bias(_conv2_b);
174 output()->from(_conv2);
178 class ConvHasMultiSuccGraph : public TestIOGraph, public ConvReluConvGraphlet
181 ConvHasMultiSuccGraph() = default;
185 TestIOGraph::init({1}, {1});
186 ConvReluConvGraphlet::init(g());
188 _conv1->input(input());
189 _conv1->filter(_conv1_f);
190 _conv1->bias(_conv1_b);
192 _relu->features(_conv1);
194 _conv2->input(_conv1);
195 _conv2->filter(_conv2_f);
196 _conv2->bias(_conv2_b);
198 output()->from(_relu); // We need to check from relu
// TODO Use ::testing::Test fixtures (TEST_F) instead of building the graph inside each TEST
206 TEST(FuseActivationFunctionPassTest, name)
208 luci::FuseActivationFunctionPass pass;
209 auto const name = pass.name();
210 ASSERT_NE(nullptr, name);
213 TEST(FusePreActivationBatchNorm, fuse_activation_function)
216 luci::FuseActivationFunctionPass pass;
220 EXPECT_TRUE(pass.run(g.g()));
221 EXPECT_EQ(g.conv1(), g.conv2()->input());
224 TEST(FusePreActivationBatchNorm, fuse_activation_function_dup_relu)
227 luci::FuseActivationFunctionPass pass;
230 g.conv1()->fusedActivationFunction(luci::FusedActFunc::RELU);
232 EXPECT_TRUE(pass.run(g.g()));
233 EXPECT_EQ(g.conv1(), g.conv2()->input());
236 TEST(FusePreActivationBatchNorm, fuse_activation_function_mulsucc_NEG)
238 ConvHasMultiSuccGraph g;
239 luci::FuseActivationFunctionPass pass;
243 // Relu input Conv2D has multiple successors
244 EXPECT_FALSE(pass.run(g.g()));
247 TEST(FusePreActivationBatchNorm, fuse_activation_function_tanh_NEG)
250 luci::FuseActivationFunctionPass pass;
253 g.conv1()->fusedActivationFunction(luci::FusedActFunc::TANH);
255 // Relu input Conv2D already has activation function
256 EXPECT_FALSE(pass.run(g.g()));
259 TEST(FusePreActivationBatchNorm, fuse_tanh_NEG)
261 FuseTanhActTestGraph g;
262 luci::FuseActivationFunctionPass pass;
266 // Tanh should not be fused
267 // This can be changed when CONV+TANH is supported by luci-interpreter
268 EXPECT_FALSE(pass.run(g.g()));