Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ExpandBroadcastConstPass.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ExpandBroadcastConstPass.h"
18 #include "PassTestGraphs.h"
19
20 #include <luci/IR/CircleNodes.h>
21
22 #include <gtest/gtest.h>
23
24 namespace
25 {
26
27 class ExpandBroadcastConstTest : public ::testing::Test
28 {
29 public:
30   ExpandBroadcastConstTest()
31   {
32     _x = _g.nodes()->create<luci::CircleInput>();
33     _y = _g.nodes()->create<luci::CircleConst>();
34     _add = _g.nodes()->create<luci::CircleAdd>();
35     _output = _g.nodes()->create<luci::CircleOutput>();
36
37     auto graph_input = _g.inputs()->create();
38     graph_input->dtype(loco::DataType::FLOAT32);
39     graph_input->shape({1, H, W, D});
40     _x->index(graph_input->index());
41     _x->dtype(graph_input->dtype());
42     _x->shape({1, H, W, D});
43
44     auto graph_output = _g.outputs()->create();
45     graph_output->dtype(loco::DataType::FLOAT32);
46     graph_output->shape({1, H, W, D});
47     _output->index(graph_output->index());
48     _output->dtype(graph_output->dtype());
49     _output->shape({1, H, W, D});
50
51     _y->dtype(loco::DataType::FLOAT32);
52     _y->shape({1, H, W, 1});
53     _y->size<loco::DataType::FLOAT32>(16);
54
55     _add->dtype(loco::DataType::FLOAT32);
56     _add->fusedActivationFunction(luci::FusedActFunc::NONE);
57     _add->x(_x);
58     _add->y(_y);
59     _add->shape({1, H, W, D});
60
61     _output->from(_add);
62
63     _x->name("input");
64     _output->name("output");
65   }
66
67 protected:
68   uint32_t const H = 4;
69   uint32_t const W = 4;
70   uint32_t const D = 3;
71
72 protected:
73   loco::Graph _g;
74   luci::CircleAdd *_add = nullptr;
75   luci::CircleInput *_x = nullptr;
76   luci::CircleConst *_y = nullptr;
77   luci::CircleOutput *_output = nullptr;
78 };
79
80 } // namespace
81
82 TEST_F(ExpandBroadcastConstTest, name)
83 {
84   luci::ExpandBroadcastConstPass pass;
85   auto const name = pass.name();
86   ASSERT_NE(nullptr, name);
87 }
88
89 TEST_F(ExpandBroadcastConstTest, remove_broadcast)
90 {
91   for (uint32_t i = 0; i < H * W; ++i)
92     _y->at<loco::DataType::FLOAT32>(i) = static_cast<float>(i);
93
94   luci::ExpandBroadcastConstPass pass;
95   ASSERT_TRUE(pass.run(&_g));
96
97   auto broadcasted_const = dynamic_cast<luci::CircleConst *>(_add->y());
98   ASSERT_NE(broadcasted_const, nullptr);
99
100   EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
101   EXPECT_EQ(broadcasted_const->dim(1).value(), H);
102   EXPECT_EQ(broadcasted_const->dim(2).value(), W);
103   EXPECT_EQ(broadcasted_const->dim(3).value(), D);
104   EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), H * W * D);
105
106   for (uint32_t i = 0; i < H * W; ++i)
107   {
108     for (uint32_t d = 0; d < D; ++d)
109     {
110       EXPECT_NEAR(broadcasted_const->at<loco::DataType::FLOAT32>(i + H * W * d),
111                   static_cast<float>(i), std::numeric_limits<float>::min());
112     }
113   }
114 }
115
116 TEST_F(ExpandBroadcastConstTest, remove_broadcast_multiple_successors)
117 {
118   auto const circle_sqrt = _g.nodes()->create<luci::CircleSqrt>();
119   circle_sqrt->dtype(loco::DataType::FLOAT32);
120   circle_sqrt->shape({1, H, W, 1});
121   circle_sqrt->x(_y);
122
123   luci::ExpandBroadcastConstPass pass;
124   ASSERT_TRUE(pass.run(&_g));
125
126   auto broadcasted_const = dynamic_cast<luci::CircleConst *>(_add->y());
127   auto original_const = dynamic_cast<luci::CircleConst *>(circle_sqrt->x());
128
129   ASSERT_NE(broadcasted_const, nullptr);
130   EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
131   EXPECT_EQ(broadcasted_const->dim(3).value(), D);
132   EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), H * W * D);
133
134   // Check if another successor's node was left intact
135   ASSERT_NE(original_const, nullptr);
136   EXPECT_EQ(original_const->dtype(), loco::DataType::FLOAT32);
137   EXPECT_EQ(original_const->dim(3).value(), 1);
138   EXPECT_EQ(original_const->size<loco::DataType::FLOAT32>(), H * W * 1);
139 }
140
141 TEST_F(ExpandBroadcastConstTest, broadcast_impossible_NEG)
142 {
143   _y->shape({1, H, W, 2});
144   _y->size<loco::DataType::FLOAT32>(H * W * (D - 1));
145
146   luci::ExpandBroadcastConstPass pass;
147   ASSERT_FALSE(pass.run(&_g));
148 }