/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "luci/Pass/SubstituteSqueezeToReshapePass.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>

#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

26 * @brief return TRUE if all dim is known
27 * @note This pass can be applied even some of dimensions are unknown.
28 For now, do not consider about it and update logic later.
30 bool can_squeeze_shape(const luci::CircleNode *node)
32 for (uint32_t r = 0; r < node->rank(); ++r)
34 if (not node->dim(r).known())
/**
 * @brief return valid unsigned dim value from 0 ~ (rank-1)
 * @note  dim can be -rank to (rank-1); a negative dim counts from the last axis.
 *        dim is NOT range-checked here: a dim outside [-rank, rank-1] yields a
 *        value that is >= rank, so callers must compare the result with rank
 *        before using it as an axis index.
 */
uint32_t valid_unsigned_dim(uint32_t rank, int32_t dim)
{
  if (dim >= 0)
    return static_cast<uint32_t>(dim);

  // widen before adding so a large rank cannot overflow 32-bit signed arithmetic
  const int64_t from_end = static_cast<int64_t>(rank) + static_cast<int64_t>(dim);
  return static_cast<uint32_t>(from_end);
}
51 * @brief return TRUE if input dim is 1 for squeeze_dims values
53 bool is_valid_input(const luci::CircleNode *node, const std::vector<int32_t> &squeeze_dims)
55 auto rank = node->rank();
56 for (auto dim : squeeze_dims)
58 auto udim = valid_unsigned_dim(rank, dim);
59 if (node->dim(udim).value() != 1)
66 * @brief return shape vector from input
68 std::vector<uint32_t> node_shape(const luci::CircleNode *input)
70 std::vector<uint32_t> shape;
71 uint32_t rank = input->rank();
72 for (uint32_t r = 0; r < rank; ++r)
73 shape.push_back(input->dim(r).value());
79 * @brief copy quantparam of src to dst
81 void copy_quantparam(luci::CircleNode *dst, const luci::CircleNode *src)
83 auto q = src->quantparam();
85 dst->quantparam(nullptr);
87 dst->quantparam(std::make_unique<luci::CircleQuantParam>(*q));
91 * @brief return CircleConst ptr with values of new_shape
93 luci::CircleConst *create_shape_const(loco::Graph *graph, const std::vector<uint32_t> &new_shape)
95 // NOTE dim_size can be 0
96 uint32_t dim_size = static_cast<uint32_t>(new_shape.size());
98 auto shape_const = graph->nodes()->create<luci::CircleConst>();
101 shape_const->dtype(loco::DataType::S32);
104 shape_const->rank(1);
105 shape_const->dim(0).set(dim_size);
108 shape_const->rank(0);
109 shape_const->shape_status(luci::ShapeStatus::VALID);
112 shape_const->size<loco::DataType::S32>(dim_size);
113 for (uint32_t i = 0; i < dim_size; ++i)
114 shape_const->at<loco::DataType::S32>(i) = new_shape.at(i);
119 bool substitute_squeeze_to_reshape(luci::CircleSqueeze *squeeze)
121 assert(squeeze != nullptr);
123 auto input = loco::must_cast<luci::CircleNode *>(squeeze->input());
124 // we need input node shape and all dim should be known
125 if (input->shape_status() != luci::ShapeStatus::VALID)
127 if (not can_squeeze_shape(input))
130 // we will use squeeze shape for new shape
131 if (squeeze->shape_status() != luci::ShapeStatus::VALID)
134 auto squeeze_dims = squeeze->squeeze_dims();
135 if (not is_valid_input(input, squeeze_dims))
136 throw std::runtime_error("Invalid values in squeeze_dims: " + squeeze->name());
138 auto name = squeeze->name();
139 assert(name.length() > 0);
141 auto reshape_shape = node_shape(squeeze);
142 auto graph = squeeze->graph();
143 auto reshape = graph->nodes()->create<luci::CircleReshape>();
144 auto shape_const = create_shape_const(graph, reshape_shape);
145 copy_quantparam(reshape, squeeze);
146 reshape->name(name + "/Reshape");
147 luci::add_origin(reshape, luci::get_origin(squeeze));
148 shape_const->name(name + "/Reshape/shape");
151 reshape->tensor(input);
152 reshape->shape(shape_const);
153 replace(squeeze).with(reshape);
175 * [CircleNode] [CircleConst]
177 * [CircleSqueeze] [CircleReshape]
182 bool SubstituteSqueezeToReshapePass::run(loco::Graph *g)
184 bool changed = false;
185 for (auto node : loco::active_nodes(loco::output_nodes(g)))
187 if (auto squeeze = dynamic_cast<luci::CircleSqueeze *>(node))
189 if (substitute_squeeze_to_reshape(squeeze))