2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/QuantizeWeightsPass.h"
18 #include <luci/IR/CircleNodes.h>
20 #include <gtest/gtest.h>
// Test fixture for luci::QuantizeWeightsPass.
// MakeGraph() populates the fixture graph `_g` with a small float32 model:
// a CircleInput, a CircleConv2D with float32 CircleConst bias and filter,
// and a CircleOutput.
// NOTE(review): the statements that connect these nodes (conv input/filter/bias
// and the output's source) are not visible in this view. The same holds for the
// definitions of N, H, W and the declaration of `_g`. Presumably the wiring is
// input -> conv -> output; confirm.
24 struct QuantizeWeightsPassTest : public ::testing::Test
// Input and output channel counts are equal, so the filter below is {C, H, W, C}.
42 const int C = 3; // IC = OC
44 // graph input and output
45 auto graph_input = _g.inputs()->create();
46 auto graph_output = _g.outputs()->create();
// Graph input node: float32, shape {N, H, W, C}, bound to the graph input slot.
49 auto input = _g.nodes()->create<luci::CircleInput>();
50 input->index(graph_input->index());
51 input->shape({N, H, W, C});
52 input->dtype(loco::DataType::FLOAT32);
// The float32 convolution whose constant operands the pass under test targets.
56 auto conv = _g.nodes()->create<luci::CircleConv2D>();
// Float32 bias constant (its shape assignment is not visible here).
58 auto bias = _g.nodes()->create<luci::CircleConst>();
59 bias->dtype(loco::DataType::FLOAT32);
61 bias->name("conv_bias");
// Float32 filter constant of shape {C, H, W, C}.
63 auto weight = _g.nodes()->create<luci::CircleConst>();
64 weight->dtype(loco::DataType::FLOAT32);
65 weight->shape({C, H, W, C});
// Size the element buffer to match the shape. No element values are assigned
// in the visible lines.
66 weight->size<loco::DataType::FLOAT32>(C * H * W * C);
68 conv->padding(luci::Padding::SAME);
69 conv->fusedActivationFunction(luci::FusedActFunc::NONE);
70 conv->dtype(loco::DataType::FLOAT32);
// Graph output node: float32, shape {N, H, W, C}, bound to the graph output slot.
74 auto output = _g.nodes()->create<luci::CircleOutput>();
75 output->index(graph_output->index());
77 output->shape({N, H, W, C});
78 output->dtype(loco::DataType::FLOAT32);
79 output->name("output");
81 virtual void SetUp() { MakeGraph(); }
87 TEST_F(QuantizeWeightsPassTest, name)
89 luci::QuantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::S8,
90 luci::QuantizationGranularity::ChannelWise);
91 auto const name = pass.name();
92 ASSERT_NE(nullptr, name);
// A pass built from a Context object, instead of the three-argument constructor,
// must also report a non-null pass name.
95 TEST_F(QuantizeWeightsPassTest, name_ctx)
97 auto ctx = std::make_unique<luci::QuantizeWeightsPass::Context>();
// Settings: FLOAT32 input model, S8 output model, channel-wise granularity.
99 ctx->input_model_dtype = loco::DataType::FLOAT32;
100 ctx->output_model_dtype = loco::DataType::S8;
101 ctx->granularity = luci::QuantizationGranularity::ChannelWise;
// `ctx` is moved into the pass and must not be used after this line.
104 luci::QuantizeWeightsPass pass(std::move(ctx));
105 auto const name = pass.name();
106 ASSERT_NE(nullptr, name);
// Negative test: the pass is configured with a U8 input model dtype and run on
// the fixture graph `_g` (whose nodes are float32). run() is expected to throw
// std::runtime_error.
// NOTE(review): presumably only FLOAT32 input models are accepted; confirm
// against QuantizeWeightsPass::run.
109 TEST_F(QuantizeWeightsPassTest, run_input_U8_NEG)
112 luci::QuantizeWeightsPass pass(loco::DataType::U8, loco::DataType::S8,
113 luci::QuantizationGranularity::ChannelWise);
114 EXPECT_THROW(pass.run(&_g), std::runtime_error);
// Negative test: requesting a FLOAT32 output model dtype is expected to make
// run() on the fixture graph `_g` throw std::runtime_error.
// NOTE(review): presumably the output dtype must be a quantized integer type;
// confirm against QuantizeWeightsPass::run.
117 TEST_F(QuantizeWeightsPassTest, run_output_f32_NEG)
120 luci::QuantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::FLOAT32,
121 luci::QuantizationGranularity::ChannelWise);
122 EXPECT_THROW(pass.run(&_g), std::runtime_error);