2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "moco/Pass/Passes/ConstantFoldStridedSlice.h"
19 #include "ConstantFoldHelper.h"
20 #include "TensorSliceEnumerator.h"
22 #include <moco/IR/Nodes/TFStridedSlice.h>
23 #include <moco/IR/Nodes/TFConst.h>
25 #include <moco/Support/NodeAs.h>
26 #include <moco/Support/TFShapeInferenceHelper.h>
28 #include <oops/UserExn.h>
36 loco::TensorShape calc_output_shape(moco::TFStridedSlice *node)
38 auto const_input = loco::must_cast<moco::TFConst *>(node->input());
39 auto const_begin = loco::must_cast<moco::TFConst *>(node->begin());
40 auto const_end = loco::must_cast<moco::TFConst *>(node->end());
41 auto input_rank = const_input->rank();
42 auto output_rank = input_rank;
43 loco::TensorShape output_shape_range;
45 output_shape_range.rank(input_rank);
46 for (uint32_t r = 0; r < input_rank; ++r)
48 // TODO apply begin/end mask
49 // TODO apply ellipsis mask
51 auto end = const_end->at<loco::DataType::S32>(r);
52 auto begin = const_begin->at<loco::DataType::S32>(r);
53 auto size = end - begin;
54 output_shape_range.dim(r).set(size);
57 loco::TensorShape output_tensor_shape;
58 if (node->shrink_axis_mask() != 0)
60 for (uint32_t rs = 0; rs < input_rank; ++rs)
62 int32_t bit = 1 << rs;
63 int32_t mask = node->shrink_axis_mask();
66 // shrink one dimension
67 assert(output_rank > 0);
68 output_rank = output_rank - 1;
71 output_tensor_shape.rank(output_rank);
72 for (uint32_t rs = 0, rd = 0; rs < input_rank; ++rs)
74 int32_t bit = 1 << rs;
75 int32_t mask = node->shrink_axis_mask();
76 if ((bit & mask) == 0)
79 output_tensor_shape.dim(rd).set(output_shape_range.dim(rs).value());
82 // else this dimension is shrink-ed
87 output_tensor_shape = output_shape_range;
90 return output_tensor_shape;
93 moco::u32v_t vector_from_const(moco::TFConst *tfconst)
97 auto rank = tfconst->rank();
99 auto dim = tfconst->dim(0).value();
102 for (uint32_t r = 0; r < dim; ++r)
104 auto val = tfconst->at<loco::DataType::S32>(r);
111 moco::u32v_t operator-(const moco::u32v_t &lhs, const moco::u32v_t &rhs)
113 assert(lhs.size() == rhs.size());
116 res.resize(lhs.size());
117 for (uint32_t r = 0; r < lhs.size(); r++)
119 res.at(r) = lhs.at(r) - rhs.at(r);
124 template <typename T> T tfconst_at(const moco::TFConst *tfconst, const moco::u32v_t &pos);
126 template <> int32_t tfconst_at<int32_t>(const moco::TFConst *tfconst, const moco::u32v_t &pos)
128 uint32_t rank = tfconst->rank();
129 assert(rank == pos.size());
130 uint32_t element = 0;
131 for (uint32_t r = 0; r < rank; ++r)
133 uint32_t dim = tfconst->dim(r).value();
134 element = element * dim + pos.at(r);
136 return tfconst->at<loco::DataType::S32>(element);
139 template <> float tfconst_at<float>(const moco::TFConst *tfconst, const moco::u32v_t &pos)
141 uint32_t rank = tfconst->rank();
142 assert(rank == pos.size());
143 uint32_t element = 0;
144 for (uint32_t r = 0; r < rank; ++r)
146 uint32_t dim = tfconst->dim(r).value();
147 element = element * dim + pos.at(r);
149 return tfconst->at<loco::DataType::FLOAT32>(element);
152 void tfconst_at(moco::TFConst *tfconst, const moco::u32v_t &pos, int32_t value)
154 // tfconst->rank() can be smaller than pos.size()
155 // i.e., tfconst: shape[3] and pos[0,1]
156 // where shape[3] is output result shape
157 // [0,1] is position of input const
158 uint32_t rank = pos.size();
159 uint32_t element = 0;
160 for (uint32_t r = 0; r < rank; ++r)
162 // this is like expand the shape from [3] to [1,3] to use same formula as in reading
163 uint32_t dim = tfconst->rank() < r ? tfconst->dim(r).value() : 1;
164 element = element * dim + pos.at(r);
167 tfconst->at<loco::DataType::S32>(element) = value;
170 void tfconst_at(moco::TFConst *tfconst, const moco::u32v_t &pos, float value)
172 uint32_t rank = pos.size();
173 uint32_t element = 0;
174 for (uint32_t r = 0; r < rank; ++r)
176 uint32_t dim = tfconst->rank() < r ? tfconst->dim(r).value() : 1;
177 element = element * dim + pos.at(r);
180 tfconst->at<loco::DataType::FLOAT32>(element) = value;
183 bool constantfold_stridedslice(moco::TFStridedSlice *node)
185 auto const_input = dynamic_cast<moco::TFConst *>(node->input());
186 if (const_input == nullptr)
188 // input is not TFConst, there's nothing to do
192 // TODO support full mask features: see import codes also
193 assert(node->begin_mask() == 0);
194 assert(node->end_mask() == 0);
195 assert(node->ellipsis_mask() == 0);
196 assert(node->shrink_axis_mask() == 1);
198 // TODO support other dtypes
199 assert(const_input->dtype() == loco::DataType::S32 ||
200 const_input->dtype() == loco::DataType::FLOAT32);
202 auto const_begin = dynamic_cast<moco::TFConst *>(node->begin());
203 auto const_end = dynamic_cast<moco::TFConst *>(node->end());
204 auto const_strides = dynamic_cast<moco::TFConst *>(node->strides());
205 if (const_begin == nullptr || const_end == nullptr || const_strides == nullptr)
210 // NOTE need shape but cannot depend on shape inference service module
211 auto tensor_shape = calc_output_shape(node);
212 auto input_shape = moco::tensor_shape(const_input);
214 auto graph = node->graph();
216 // Create our target TFConst node with shape from begin~end/strides
217 auto const_sliced = moco::new_const(graph, tensor_shape, const_input->dtype());
219 // Copy sliced elements using TensorSliceEnumerator
220 moco::TensorSliceEnumerator etor;
221 auto v_begin = vector_from_const(const_begin);
222 auto v_end = vector_from_const(const_end);
223 moco::u32v_t v_cursor;
224 moco::u32v_t v_offset;
226 etor.shape(input_shape);
230 for (etor.start(); etor.valid(); etor.advance())
232 v_cursor = etor.cursor();
233 v_offset = v_cursor - v_begin;
235 if (const_input->dtype() == loco::DataType::S32)
237 int32_t value = tfconst_at<int32_t>(const_input, v_cursor);
238 tfconst_at(const_sliced, v_offset, value);
240 else if (const_input->dtype() == loco::DataType::FLOAT32)
242 float value = tfconst_at<float>(const_input, v_cursor);
243 tfconst_at(const_sliced, v_offset, value);
248 loco::replace(node).with(const_sliced);
259 * @note This will replace TFStridedSlice with TFConst when 'input' is TFConst
262 * A --- TFStridedSlice --- C
265 * A --- TFStridedSlice
267 * TFConst ---------- C
269 * A,B : inputs of TFStridedSlice
270 * C : a node that uses TFStridedSlice as an input
271 * TFStridedSlice is disconnected from C
272 * Nodes are drawn multiple times to simplify the diagram
274 * Only a limited set of inputs is supported for now
276 bool ConstantFoldStridedSlice::run(loco::Graph *graph)
278 bool changed = false;
279 for (auto node : loco::active_nodes(loco::output_nodes(graph)))
281 if (auto sslice_node = as<moco::TFStridedSlice>(node))
283 if (constantfold_stridedslice(sslice_node))