2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/ExpandBroadcastConstPass.h"
19 #include <luci/IR/CircleNodes.h>
22 #include <type_traits>
27 luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::CircleNode *successor)
31 if (successor->rank() != node->rank())
34 std::vector<uint32_t> broadcast_dims;
35 for (uint32_t dim = 0; dim < node->rank(); ++dim)
37 if (node->dim(dim) == successor->dim(dim))
40 if (node->dim(dim) == 1)
41 broadcast_dims.push_back(dim);
44 if (broadcast_dims.size() != 1 || broadcast_dims.back() != node->rank() - 1)
46 WARN(l) << "NYI: Only depth broadcast removal is supported";
50 auto constant = node->graph()->nodes()->create<luci::CircleConst>();
51 constant->name(node->name());
52 constant->dtype(node->dtype());
53 constant->rank(node->rank());
54 constant->shape_status(luci::ShapeStatus::VALID);
56 uint32_t node_size = node->size<loco::DataType::FLOAT32>();
57 uint32_t constant_size = 1;
58 for (uint32_t i = 0; i < successor->rank(); ++i)
60 constant->dim(i).set(successor->dim(i).value());
61 constant_size *= constant->dim(i).value();
63 constant->size<loco::DataType::FLOAT32>(constant_size);
65 auto const node_data = &node->at<loco::DataType::FLOAT32>(0);
66 auto const constant_data = &constant->at<loco::DataType::FLOAT32>(0);
68 auto const successor_depth = successor->dim(successor->rank() - 1).value();
69 for (uint32_t d = 0; d < successor_depth; ++d)
70 std::copy(node_data, node_data + node_size, constant_data + d * node_size);
75 template <typename N> bool expand_node_input(luci::CircleConst *node, luci::CircleNode *successor)
77 static_assert(std::is_base_of<luci::CircleNode, N>::value,
78 "Successor node should have CircleNode base");
80 auto const successor_node = loco::must_cast<N *>(successor);
81 auto const successor_x = loco::must_cast<luci::CircleNode *>(successor_node->x());
82 auto const successor_y = loco::must_cast<luci::CircleNode *>(successor_node->y());
84 luci::CircleConst *expanded_const;
86 if (node == successor_x)
88 expanded_const = create_expanded_constant(node, successor_y);
90 if (expanded_const == nullptr)
93 successor_node->x(expanded_const);
95 else if (node == successor_y)
97 expanded_const = create_expanded_constant(node, successor_x);
99 if (expanded_const == nullptr)
102 successor_node->y(expanded_const);
109 * Expand constants following broadcasting rules for binary input nodes (Add, Mul, etc.)
113 * [CircleInput] [CircleConst (H x W x 1)]
119 * [CircleInput] [CircleConst (H x W x D)]
123 bool expand_broadcast_const(luci::CircleConst *node)
125 if (node->dtype() != loco::DataType::FLOAT32)
126 return false; // Unsupported data type
128 bool changed = false;
130 for (auto successor : loco::succs(node))
132 auto const circle_successor = loco::must_cast<luci::CircleNode *>(successor);
133 switch (circle_successor->opcode())
135 case luci::CircleOpcode::ADD:
136 if (expand_node_input<luci::CircleAdd>(node, circle_successor))
139 case luci::CircleOpcode::MUL:
140 if (expand_node_input<luci::CircleMul>(node, circle_successor))
143 case luci::CircleOpcode::DIV:
144 if (expand_node_input<luci::CircleDiv>(node, circle_successor))
148 break; // Unsupported successor node
161 * Broadcast expanding for Const nodes
163 bool ExpandBroadcastConstPass::run(loco::Graph *g)
165 bool changed = false;
166 for (auto node : loco::active_nodes(loco::output_nodes(g)))
168 auto const_node = dynamic_cast<luci::CircleConst *>(node);
169 if (const_node == nullptr)
172 if (expand_broadcast_const(const_node))