2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "moco/Support/TFShapeInferenceHelper.h"
19 #include <loco/Service/ShapeInference.h>
21 #include <oops/InternalExn.h>
28 // TODO Use codes in loco and remove duplicate broadcast_shape() and related
30 * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
34 * auto expanded_tensor_shape = expand(tensor_shape).to(N);
36 class TensorShapeExpander
39 TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
45 loco::TensorShape to(uint32_t output_rank)
47 auto const &input_shape = _shape;
48 uint32_t const input_rank = input_shape.rank();
50 assert(input_rank <= output_rank && "Cannot shrink rank");
51 uint32_t const axis_shift = output_rank - input_rank;
53 loco::TensorShape output_shape;
55 output_shape.rank(output_rank);
56 for (uint32_t axis = 0; axis < output_rank; ++axis)
58 output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
65 const loco::TensorShape _shape;
69 * @breif Expand shape x and y to same rank by align right and filling with 1
71 void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
73 auto x_rank = x.rank();
74 auto y_rank = y.rank();
79 TensorShapeExpander x_exp(x);
80 TensorShapeExpander y_exp(y);
82 auto xy_rank = std::max(x_rank, y_rank);
84 x = x_rank > y_rank ? x : x_exp.to(xy_rank);
85 y = y_rank > x_rank ? y : y_exp.to(xy_rank);
89 * @breif Returns shape of expanded dimension of input x and y having same rank
91 loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
93 assert(x.rank() == y.rank());
97 loco::TensorShape output_shape;
99 output_shape.rank(rank);
100 for (uint32_t axis = 0; axis < rank; ++axis)
102 assert(x.dim(axis).known() && y.dim(axis).known());
104 auto x_dim = x.dim(axis).value();
105 auto y_dim = y.dim(axis).value();
107 // each dimension of x and y should be same or one must be 1 if different
108 if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
110 // TODO may need to refine message
111 INTERNAL_EXN("ShapeInference: Input shapes don't match");
114 output_shape.dim(axis) = std::max(x_dim, y_dim);
125 loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
130 expand_rank(x_match, y_match);
132 auto output_shape = expand_dimension(x_match, y_match);
142 loco::NodeShape node_shape(const loco::Node *node)
144 loco::NodeShape nodeshape; // default domain is Unknown
146 if (loco::shape_known(node))
148 nodeshape = loco::shape_get(node);
154 bool node_shape(const loco::Node *node, loco::NodeShape &nodeshape)
156 nodeshape = node_shape(node);
157 return (nodeshape.domain() != loco::Domain::Unknown);
160 loco::TensorShape as_tensor_shape(const loco::FeatureShape &feature_shape,
161 const TFDataLayout &data_layout)
163 loco::TensorShape tensor_shape;
165 tensor_shape.rank(4);
166 if (data_layout == "NHWC")
168 tensor_shape.dim(0) = feature_shape.count();
169 tensor_shape.dim(1) = feature_shape.height();
170 tensor_shape.dim(2) = feature_shape.width();
171 tensor_shape.dim(3) = feature_shape.depth();
173 else if (data_layout == "NCHW")
175 tensor_shape.dim(0) = feature_shape.count();
176 tensor_shape.dim(1) = feature_shape.depth();
177 tensor_shape.dim(2) = feature_shape.height();
178 tensor_shape.dim(3) = feature_shape.width();
182 // TODO support for other data_layout if needed
183 INTERNAL_EXN_V("ShapeInference: Unknown data_format", data_layout);
189 loco::FeatureShape as_feature_shape(const loco::NodeShape &nodeshape,
190 const TFDataLayout &data_layout)
192 if (nodeshape.domain() == loco::Domain::Feature)
193 return nodeshape.as<loco::FeatureShape>();
195 loco::FeatureShape feature_shape;
197 // only convert from tensor to feature
198 if (nodeshape.domain() != loco::Domain::Tensor)
200 INTERNAL_EXN("ShapeInference: Invalid shape information");
203 loco::TensorShape tensor_shape = nodeshape.as<loco::TensorShape>();
205 if (tensor_shape.rank() != 4)
207 INTERNAL_EXN("ShapeInference: Rank is not 4");
210 if (data_layout == "NHWC")
212 feature_shape.count() = tensor_shape.dim(0);
213 feature_shape.height() = tensor_shape.dim(1);
214 feature_shape.width() = tensor_shape.dim(2);
215 feature_shape.depth() = tensor_shape.dim(3);
217 else if (data_layout == "NCHW")
219 feature_shape.count() = tensor_shape.dim(0);
220 feature_shape.depth() = tensor_shape.dim(1);
221 feature_shape.height() = tensor_shape.dim(2);
222 feature_shape.width() = tensor_shape.dim(3);
226 // TODO support for other data_layout if needed
227 INTERNAL_EXN_V("ShapeInference: Unknown data_format", data_layout);
230 return feature_shape;
238 PlaneShape make_plane_shape(const loco::FeatureShape &feature_shape)
240 PlaneShape plane_shape;
242 plane_shape.height = feature_shape.height();
243 plane_shape.width = feature_shape.width();
248 FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
250 return FeatureShapeUpdater{&feature_shape};
/**
 * @brief Enumeration of the TensorFlow "data_format" attr values understood here
 */
enum class DataLayout
{
  NHWC,
  NCHW,
};
267 DataLayout as_data_layout(const std::string &tf_layout_str)
269 if (tf_layout_str == "NHWC")
270 return DataLayout::NHWC;
271 else if (tf_layout_str == "NCHW")
272 return DataLayout::NCHW;
274 /// @note data layout tag in TensorFlow is 'data_format'
275 INTERNAL_EXN_V("ShapeInference: Unknown data_format", tf_layout_str);
283 loco::Stride<2> stride_of(const TFStrides &strides, const TFDataLayout &datalayout)
285 loco::Stride<2> stride;
287 auto data_layout = as_data_layout(datalayout);
288 if (data_layout == DataLayout::NHWC)
290 stride.vertical(strides[1]);
291 stride.horizontal(strides[2]);
293 else if (data_layout == DataLayout::NCHW)
295 stride.vertical(strides[2]);
296 stride.horizontal(strides[3]);
300 // TODO add more datalayout supports if needed
301 INTERNAL_EXN("ShapeInference: Unknown data_format");
307 loco::Window<2> window_of(const TFKSize &ksize, const TFDataLayout &datalayout)
309 loco::Window<2> window;
311 auto data_layout = as_data_layout(datalayout);
312 if (data_layout == DataLayout::NHWC)
314 window.vertical(ksize[1]);
315 window.horizontal(ksize[2]);
317 else if (data_layout == DataLayout::NCHW)
319 window.vertical(ksize[2]);
320 window.horizontal(ksize[3]);
324 // TODO add more datalayout supports if needed
325 INTERNAL_EXN("ShapeInference: Unknown data_format");
331 loco::Window<2> window_of(const loco::TensorShape &shape, const TFDataLayout &datalayout)
333 loco::Window<2> window;
335 if (datalayout == "HWIO")
337 window.vertical(shape.dim(0).value());
338 window.horizontal(shape.dim(1).value());
340 else if (datalayout == "HWCM")
342 window.vertical(shape.dim(0).value());
343 window.horizontal(shape.dim(1).value());
347 // TODO add more datalayout supports if needed
348 INTERNAL_EXN_V("ShapeInference: Unknown data_format", datalayout);