2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
19 #include <luci/IR/CircleNodes.h>
21 #include <gtest/gtest.h>
// NOTE(review): this listing appears to be a lossy extraction of a gtest file.
// Every line carries a leading number, and there are gaps in that numbering
// (e.g. 21->27, 39->48, 80->92). As a result, the fixture's class/struct header,
// its braces, the graph member `g` and the statements that connect the nodes
// are not visible here. Restore them from the upstream source before building.
27 * Simple graph for test
39 * [Node] [weights] [bias]
// Create the nodes of the test graph in `g`: an input, a Mul, a constant
// `gamma`, an Add, a constant `beta` and an output. They are created through
// g.nodes()->create<...>(), so presumably the graph owns them — TODO confirm.
48 input = g.nodes()->create<luci::CircleInput>();
49 mul = g.nodes()->create<luci::CircleMul>();
50 gamma = g.nodes()->create<luci::CircleConst>();
51 add = g.nodes()->create<luci::CircleAdd>();
52 beta = g.nodes()->create<luci::CircleConst>();
53 output = g.nodes()->create<luci::CircleOutput>();
// Bind the Input/Output nodes to newly created graph-level input/output slots.
55 auto graph_input = g.inputs()->create();
56 input->index(graph_input->index());
57 auto graph_output = g.outputs()->create();
58 output->index(graph_output->index());
// Every node in the fixture uses FLOAT32.
60 input->dtype(loco::DataType::FLOAT32);
61 mul->dtype(loco::DataType::FLOAT32);
62 gamma->dtype(loco::DataType::FLOAT32);
63 add->dtype(loco::DataType::FLOAT32);
64 beta->dtype(loco::DataType::FLOAT32);
65 output->dtype(loco::DataType::FLOAT32);
// Input, Mul, Add and Output are {1, 4, 4, channel_size}; the two constants are
// rank-1 {channel_size}, i.e. one value per element of the last dimension.
67 uint32_t channel_size = 16;
68 input->shape({1, 4, 4, channel_size});
69 mul->shape({1, 4, 4, channel_size});
70 gamma->shape({channel_size});
71 add->shape({1, 4, 4, channel_size});
72 beta->shape({channel_size});
73 output->shape({1, 4, 4, channel_size});
// Fill gamma[i] = beta[i] = i so the tests can check element values after the pass.
75 gamma->size<loco::DataType::FLOAT32>(channel_size);
76 beta->size<loco::DataType::FLOAT32>(channel_size);
77 for (uint32_t i = 0; i < channel_size; i++)
79 gamma->at<loco::DataType::FLOAT32>(i) = i;
80 beta->at<loco::DataType::FLOAT32>(i) = i;
// NOTE(review): the statements that wire the nodes together are not visible
// in this listing (the numbering jumps 80->92). These would set the operands
// of mul and add and what output reads from.
// Raw pointers to the created nodes. The tests use them to inspect the result
// or to rewire the graph.
92 luci::CircleInput *input = nullptr;
93 luci::CircleMul *mul = nullptr;
94 luci::CircleConst *gamma = nullptr;
95 luci::CircleAdd *add = nullptr;
96 luci::CircleConst *beta = nullptr;
97 luci::CircleOutput *output = nullptr;
// Positive case: run the pass until it reports no further change. The graph
// output must then be produced by a CircleDepthwiseConv2D with these inputs:
// - a constant filter of rank 4 whose dim(3) is channel_size;
// - a constant bias of rank 1 and length channel_size.
// Filter element i and bias element i must both equal i.
// NOTE(review): this listing appears to be a lossy extraction (leading line
// numbers, gaps such as 102->106 and 107->110). `g` is used but not declared
// in the visible lines, and the `while` below has no visible (presumably empty)
// body. The braces are also missing. Restore these from upstream before building.
102 TEST(ReplaceMulAddWithDepthwiseConv, simple)
106 luci::ReplaceMulAddWithDepthwiseConvPass pass;
// Apply the pass repeatedly until run() returns false.
107 while (pass.run(&g.g))
110 auto dwconv = dynamic_cast<luci::CircleDepthwiseConv2D *>(g.output->from());
// Fatal check: dwconv is dereferenced below, so stop the test here if it is null.
111 ASSERT_NE(nullptr, dwconv);
113 uint32_t channel_size = 16;
114 auto weights = dynamic_cast<luci::CircleConst *>(dwconv->filter());
115 auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias());
// Fatal checks: weights and bias are dereferenced right after their null checks.
116 ASSERT_NE(nullptr, weights);
117 EXPECT_EQ(4, weights->rank());
118 EXPECT_EQ(channel_size, weights->dim(3).value());
119 ASSERT_NE(nullptr, bias);
120 EXPECT_EQ(1, bias->rank());
121 EXPECT_EQ(channel_size, bias->dim(0).value());
// The index is uint32_t to match channel_size, so `i < channel_size` is not a
// signed/unsigned comparison.
123 for (uint32_t i = 0; i < channel_size; i++)
125 EXPECT_FLOAT_EQ(i, weights->at<loco::DataType::FLOAT32>(i));
126 EXPECT_FLOAT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));
// Negative case: change the fixture so the order is add->mul instead of
// mul->add, then expect a single run of the pass to report no change.
// NOTE(review): numbering gaps (130->133, 133->135, 135->138) suggest lines are
// missing from this listing: the declaration of `g`, the braces, and probably
// statements that rewire the add/mul operands. loco::replace(g.add).with(g.mul)
// presumably only redirects add's users to mul. Confirm against upstream.
130 TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG)
133 // swap mul/add (changed to add->mul)
135 loco::replace(g.add).with(g.mul);
138 luci::ReplaceMulAddWithDepthwiseConvPass pass;
139 auto changed = pass.run(&g.g);
// The pass must not rewrite this graph, so run() is expected to return false.
141 EXPECT_EQ(false, changed);