2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/ExpandBroadcastConstPass.h"
18 #include "PassTestGraphs.h"
20 #include <luci/IR/CircleNodes.h>
22 #include <gtest/gtest.h>
// Test fixture that builds a small FLOAT32 graph for ExpandBroadcastConstPass:
//   _x      : CircleInput,  shape {1, H, W, D}, bound to the single graph input
//   _y      : CircleConst,  shape {1, H, W, 1} (differs from {1, H, W, D} only in the last axis)
//   _add    : CircleAdd,    shape {1, H, W, D}, no fused activation
//   _output : CircleOutput, shape {1, H, W, D}, bound to the single graph output
// NOTE(review): the graph member _g and the dimensions H, W, D are declared on lines that
// are not present in this listing — presumably fixture constants; confirm their values.
27 class ExpandBroadcastConstTest : public ::testing::Test
// Constructor: creates the nodes, registers graph input/output, and sets dtypes/shapes.
30 ExpandBroadcastConstTest()
// Create the four nodes through _g's node context.
32 _x = _g.nodes()->create<luci::CircleInput>();
33 _y = _g.nodes()->create<luci::CircleConst>();
34 _add = _g.nodes()->create<luci::CircleAdd>();
35 _output = _g.nodes()->create<luci::CircleOutput>();
// Register a FLOAT32 graph input of shape {1, H, W, D} and give _x its index, dtype and shape.
37 auto graph_input = _g.inputs()->create();
38 graph_input->dtype(loco::DataType::FLOAT32);
39 graph_input->shape({1, H, W, D});
40 _x->index(graph_input->index());
41 _x->dtype(graph_input->dtype());
42 _x->shape({1, H, W, D});
// Same for the graph output, mirrored onto _output.
44 auto graph_output = _g.outputs()->create();
45 graph_output->dtype(loco::DataType::FLOAT32);
46 graph_output->shape({1, H, W, D});
47 _output->index(graph_output->index());
48 _output->dtype(graph_output->dtype());
49 _output->shape({1, H, W, D});
51 _y->dtype(loco::DataType::FLOAT32);
52 _y->shape({1, H, W, 1});
53 _y->size<loco::DataType::FLOAT32>(16);
// Add node: FLOAT32, no fused activation, output shape {1, H, W, D}.
// NOTE(review): the operand assignments (presumably _add->x(_x); _add->y(_y);) and
// _output->from(_add) are not present between these lines in this listing. The tests read
// _add->y() as a CircleConst, so they depend on that wiring — confirm it exists.
55 _add->dtype(loco::DataType::FLOAT32);
56 _add->fusedActivationFunction(luci::FusedActFunc::NONE);
59 _add->shape({1, H, W, D});
64 _output->name("output");
// Raw pointers to the nodes created in the constructor; nullptr before assignment.
74 luci::CircleAdd *_add = nullptr;
75 luci::CircleInput *_x = nullptr;
76 luci::CircleConst *_y = nullptr;
77 luci::CircleOutput *_output = nullptr;
82 TEST_F(ExpandBroadcastConstTest, name)
84 luci::ExpandBroadcastConstPass pass;
85 auto const name = pass.name();
86 ASSERT_NE(nullptr, name);
89 TEST_F(ExpandBroadcastConstTest, remove_broadcast)
91 for (uint32_t i = 0; i < H * W; ++i)
92 _y->at<loco::DataType::FLOAT32>(i) = static_cast<float>(i);
94 luci::ExpandBroadcastConstPass pass;
95 ASSERT_TRUE(pass.run(&_g));
97 auto broadcasted_const = dynamic_cast<luci::CircleConst *>(_add->y());
98 ASSERT_NE(broadcasted_const, nullptr);
100 EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
101 EXPECT_EQ(broadcasted_const->dim(1).value(), H);
102 EXPECT_EQ(broadcasted_const->dim(2).value(), W);
103 EXPECT_EQ(broadcasted_const->dim(3).value(), D);
104 EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), H * W * D);
106 for (uint32_t i = 0; i < H * W; ++i)
108 for (uint32_t d = 0; d < D; ++d)
110 EXPECT_NEAR(broadcasted_const->at<loco::DataType::FLOAT32>(i + H * W * d),
111 static_cast<float>(i), std::numeric_limits<float>::min());
// When the const has a second consumer (a CircleSqrt), the Add must receive a const expanded
// to depth D while the Sqrt still sees a const with depth 1 and H * W elements.
116 TEST_F(ExpandBroadcastConstTest, remove_broadcast_multiple_successors)
118 auto const circle_sqrt = _g.nodes()->create<luci::CircleSqrt>();
119 circle_sqrt->dtype(loco::DataType::FLOAT32);
120 circle_sqrt->shape({1, H, W, 1});
// NOTE(review): the statement connecting circle_sqrt to the const (presumably
// circle_sqrt->x(_y);) is not present in this listing. Without it, original_const below is
// nullptr and the ASSERT_NE fails — confirm it exists in the real file.
123 luci::ExpandBroadcastConstPass pass;
124 ASSERT_TRUE(pass.run(&_g));
126 auto broadcasted_const = dynamic_cast<luci::CircleConst *>(_add->y());
127 auto original_const = dynamic_cast<luci::CircleConst *>(circle_sqrt->x());
// The Add's operand: expanded to depth D.
129 ASSERT_NE(broadcasted_const, nullptr);
130 EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
131 EXPECT_EQ(broadcasted_const->dim(3).value(), D);
132 EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), H * W * D);
134 // Check that the other successor (Sqrt) still reads the original, unexpanded const
135 ASSERT_NE(original_const, nullptr);
136 EXPECT_EQ(original_const->dtype(), loco::DataType::FLOAT32);
// NOTE(review): compares an unsigned dim against the int literal 1; 1u avoids a sign-compare warning.
137 EXPECT_EQ(original_const->dim(3).value(), 1);
138 EXPECT_EQ(original_const->size<loco::DataType::FLOAT32>(), H * W * 1);
141 TEST_F(ExpandBroadcastConstTest, broadcast_impossible_NEG)
143 _y->shape({1, H, W, 2});
144 _y->size<loco::DataType::FLOAT32>(H * W * (D - 1));
146 luci::ExpandBroadcastConstPass pass;
147 ASSERT_FALSE(pass.run(&_g));