2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FoldDepthwiseConv2DPass.h"
18 #include "PassTestGraphs.h"
20 #include <luci/IR/CircleNodes.h>
22 #include <gtest/gtest.h>
28 * Graph has an DepthwiseConv2D Op with constant inputs
32 * [CircleConst] [CircleConst]
34 * [CircleDepthwiseConv2D]
40 class FoldDepthwiseConv2DTest : public luci::ConstantFoldingTestGraph, public ::testing::Test
43 FoldDepthwiseConv2DTest() : luci::ConstantFoldingTestGraph({1, 4, 4, 1}, loco::DataType::FLOAT32)
45 _dconv = _g.nodes()->create<luci::CircleDepthwiseConv2D>();
46 _dconv_input = _g.nodes()->create<luci::CircleConst>();
47 _dconv_filter = _g.nodes()->create<luci::CircleConst>();
48 _dconv_bias = _g.nodes()->create<luci::CircleConst>();
50 _dconv->dtype(loco::DataType::FLOAT32);
51 _dconv->padding(luci::Padding::VALID);
52 _dconv->fusedActivationFunction(luci::FusedActFunc::NONE);
53 _dconv->input(_dconv_input);
54 _dconv->filter(_dconv_filter);
55 _dconv->bias(_dconv_bias);
56 _dconv->shape({1, 4, 4, 1});
57 _dconv->stride()->h(1);
58 _dconv->stride()->w(1);
59 _dconv->depthMultiplier(1);
61 _dconv_input->dtype(loco::DataType::FLOAT32);
62 _dconv_input->shape({1, 4, 4, 1});
63 _dconv_input->size<loco::DataType::FLOAT32>(16);
65 _dconv_filter->dtype(loco::DataType::FLOAT32);
66 _dconv_filter->shape({1, 1, 1, 1});
67 _dconv_filter->size<loco::DataType::FLOAT32>(1);
69 _dconv_bias->dtype(loco::DataType::FLOAT32);
70 _dconv_bias->shape({1});
71 _dconv_bias->size<loco::DataType::FLOAT32>(1);
73 _output->from(_dconv);
80 loco::Node *createFoldedPattern() final { return nullptr; }
83 luci::CircleConst *getFoldedPattern() final
85 return loco::must_cast<luci::CircleConst *>(_output->from());
89 luci::CircleDepthwiseConv2D *_dconv = nullptr;
90 luci::CircleConst *_dconv_input = nullptr;
91 luci::CircleConst *_dconv_filter = nullptr;
92 luci::CircleConst *_dconv_bias = nullptr;
97 TEST(FoldDepthwiseConv2DPass, name)
99 luci::FoldDepthwiseConv2DPass pass;
100 auto const name = pass.name();
101 ASSERT_NE(nullptr, name);
104 TEST_F(FoldDepthwiseConv2DTest, fold_depthwise_conv2d)
106 for (uint32_t i = 0; i < 16; ++i)
107 _dconv_input->at<loco::DataType::FLOAT32>(i) = 0.5;
108 _dconv_filter->at<loco::DataType::FLOAT32>(0) = 0.5;
110 luci::FoldDepthwiseConv2DPass pass;
111 ASSERT_TRUE(pass.run(&_g));
113 auto folded_const = getFoldedPattern();
114 EXPECT_EQ(folded_const->dtype(), loco::DataType::FLOAT32);
115 EXPECT_NEAR(folded_const->at<loco::DataType::FLOAT32>(0), 0.25,
116 std::numeric_limits<float>::min());
117 EXPECT_NEAR(folded_const->at<loco::DataType::FLOAT32>(15), 0.25,
118 std::numeric_limits<float>::min());
121 TEST_F(FoldDepthwiseConv2DTest, fold_non_constant_NEG)
123 _dconv->input(_input);
125 luci::FoldDepthwiseConv2DPass pass;
126 ASSERT_FALSE(pass.run(&_g));