Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ExpandBroadcastConstPass.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ExpandBroadcastConstPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/Log.h>
21
22 #include <type_traits>
23
24 namespace
25 {
26
27 luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::CircleNode *successor)
28 {
29   LOGGER(l);
30
31   if (successor->rank() != node->rank())
32     return nullptr;
33
34   std::vector<uint32_t> broadcast_dims;
35   for (uint32_t dim = 0; dim < node->rank(); ++dim)
36   {
37     if (node->dim(dim) == successor->dim(dim))
38       continue;
39
40     if (node->dim(dim) == 1)
41       broadcast_dims.push_back(dim);
42   }
43
44   if (broadcast_dims.size() != 1 || broadcast_dims.back() != node->rank() - 1)
45   {
46     WARN(l) << "NYI: Only depth broadcast removal is supported";
47     return nullptr;
48   }
49
50   auto constant = node->graph()->nodes()->create<luci::CircleConst>();
51   constant->name(node->name());
52   constant->dtype(node->dtype());
53   constant->rank(node->rank());
54   constant->shape_status(luci::ShapeStatus::VALID);
55
56   uint32_t node_size = node->size<loco::DataType::FLOAT32>();
57   uint32_t constant_size = 1;
58   for (uint32_t i = 0; i < successor->rank(); ++i)
59   {
60     constant->dim(i).set(successor->dim(i).value());
61     constant_size *= constant->dim(i).value();
62   }
63   constant->size<loco::DataType::FLOAT32>(constant_size);
64
65   auto const node_data = &node->at<loco::DataType::FLOAT32>(0);
66   auto const constant_data = &constant->at<loco::DataType::FLOAT32>(0);
67
68   auto const successor_depth = successor->dim(successor->rank() - 1).value();
69   for (uint32_t d = 0; d < successor_depth; ++d)
70     std::copy(node_data, node_data + node_size, constant_data + d * node_size);
71
72   return constant;
73 }
74
75 template <typename N> bool expand_node_input(luci::CircleConst *node, luci::CircleNode *successor)
76 {
77   static_assert(std::is_base_of<luci::CircleNode, N>::value,
78                 "Successor node should have CircleNode base");
79
80   auto const successor_node = loco::must_cast<N *>(successor);
81   auto const successor_x = loco::must_cast<luci::CircleNode *>(successor_node->x());
82   auto const successor_y = loco::must_cast<luci::CircleNode *>(successor_node->y());
83
84   luci::CircleConst *expanded_const;
85
86   if (node == successor_x)
87   {
88     expanded_const = create_expanded_constant(node, successor_y);
89
90     if (expanded_const == nullptr)
91       return false;
92
93     successor_node->x(expanded_const);
94   }
95   else if (node == successor_y)
96   {
97     expanded_const = create_expanded_constant(node, successor_x);
98
99     if (expanded_const == nullptr)
100       return false;
101
102     successor_node->y(expanded_const);
103   }
104
105   return true;
106 }
107
108 /**
109  * Expand constants following broadcasting rules for binary input nodes (Add, Mul, etc.)
110  *
111  *    BEFORE
112  *
113  *    [CircleInput] [CircleConst (H x W x 1)]
114  *               |     |
115  *             [CircleAdd]
116  *
117  *    AFTER
118  *
119  *    [CircleInput] [CircleConst (H x W x D)]
120  *               |     |
121  *             [CircleAdd]
122  */
123 bool expand_broadcast_const(luci::CircleConst *node)
124 {
125   if (node->dtype() != loco::DataType::FLOAT32)
126     return false; // Unsupported data type
127
128   bool changed = false;
129
130   for (auto successor : loco::succs(node))
131   {
132     auto const circle_successor = loco::must_cast<luci::CircleNode *>(successor);
133     switch (circle_successor->opcode())
134     {
135       case luci::CircleOpcode::ADD:
136         if (expand_node_input<luci::CircleAdd>(node, circle_successor))
137           changed = true;
138         break;
139       case luci::CircleOpcode::MUL:
140         if (expand_node_input<luci::CircleMul>(node, circle_successor))
141           changed = true;
142         break;
143       case luci::CircleOpcode::DIV:
144         if (expand_node_input<luci::CircleDiv>(node, circle_successor))
145           changed = true;
146         break;
147       default:
148         break; // Unsupported successor node
149     }
150   }
151
152   return changed;
153 }
154
155 } // namespace
156
157 namespace luci
158 {
159
160 /**
161  * Broadcast expanding for Const nodes
162  **/
163 bool ExpandBroadcastConstPass::run(loco::Graph *g)
164 {
165   bool changed = false;
166   for (auto node : loco::active_nodes(loco::output_nodes(g)))
167   {
168     auto const_node = dynamic_cast<luci::CircleConst *>(node);
169     if (const_node == nullptr)
170       continue;
171
172     if (expand_broadcast_const(const_node))
173       changed = true;
174   }
175   return changed;
176 }
177
178 } // namespace luci