2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "ShapeValidator.h"
22 #include "ir/operation/LowerInfo.h"
24 #include "util/logging.h"
25 #include "util/Utils.h"
// Validation helper: throw std::runtime_error (tagged with the source line)
// when the shape requirement EXP does not hold.
// Wrapped in do { ... } while (0) so the macro expands to a single statement
// and is safe inside un-braced if/else bodies; the visible form threw
// unconditionally, rejecting every graph.
#define OP_REQUIRES(EXP)                                                                       \
  do                                                                                           \
  {                                                                                            \
    if (!(EXP))                                                                                \
      throw std::runtime_error("ShapeValidator failed at line " + std::to_string(__LINE__));   \
  } while (0)
39 ShapeValidator::ShapeValidator(const ir::Graph &graph)
40 : _graph{graph}, _ctx{graph.operands()}, _current_layout{ir::Layout::UNKNOWN}
44 void ShapeValidator::checkUnaryOp(const ir::Operation &node)
46 const auto output_index{node.getOutputs().at(0)};
47 const auto input_index{node.getInputs().at(0)};
49 if (_ctx.at(output_index).info().isDynamic())
52 // Check if I/O shapes match
53 OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
56 void ShapeValidator::operator()()
58 // There is no reason for each subgraph to have subgraphs since compiler has subgraphs when
60 assert(_graph.subgraphs() == nullptr);
62 _current_layout = _graph.layout();
64 _graph.operations().iterate(
65 [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); });
68 void ShapeValidator::visit(const ir::operation::BatchMatMul &node)
70 const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
71 const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
72 const auto out_index{node.getOutputs().at(0)};
74 if (_ctx.at(out_index).info().isDynamic())
77 OP_REQUIRES(_ctx.at(lhs_index).shape().rank() <= 4);
78 OP_REQUIRES(_ctx.at(rhs_index).shape().rank() <= 4);
79 OP_REQUIRES(_ctx.at(lhs_index).shape().rank() >= 2);
80 OP_REQUIRES(_ctx.at(rhs_index).shape().rank() >= 2);
83 void ShapeValidator::visit(const ir::operation::BatchToSpaceND &node)
85 const auto ofm_index{node.getOutputs().at(0)};
86 if (_ctx.at(ofm_index).info().isDynamic())
89 const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
90 const auto block_size_index{
91 node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
93 const auto frontend_layout = _current_layout;
94 const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
95 const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
97 // All requirement as per NNAPI specification.
98 OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
99 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
100 OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
102 OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
104 if (node.getInputs().size() != 2)
106 const auto crops_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::CROPS_DATA)};
107 OP_REQUIRES(_ctx.at(crops_index).shape().rank() == 2);
108 OP_REQUIRES(_ctx.at(crops_index).shape().dim(0) == (_ctx.at(ifm_index).shape().rank() - 2));
109 OP_REQUIRES(_ctx.at(crops_index).shape().dim(1) == 2);
112 OP_REQUIRES(input_shape.C == output_shape.C);
115 void ShapeValidator::visit(const ir::operation::BCQFullyConnected &node)
117 const auto ofm_index{node.getOutputs().at(0)};
118 if (_ctx.at(ofm_index).info().isDynamic())
121 const auto ifm_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
122 const auto weight_scales_index{
123 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_SCALES)};
124 const auto weight_binary_index{
125 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_BINARY)};
126 const auto weight_cluster_index{
127 node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
128 // const auto bias_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::BIAS)};
130 OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 2);
131 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 2);
132 OP_REQUIRES(_ctx.at(weight_scales_index).shape().rank() == 1);
133 OP_REQUIRES(_ctx.at(weight_binary_index).shape().rank() == 2);
134 OP_REQUIRES(_ctx.at(weight_cluster_index).shape().rank() == 2);
136 OP_REQUIRES(_ctx.at(ifm_index).shape().dim(1) == _ctx.at(ofm_index).shape().dim(1));
138 OP_REQUIRES(_ctx.at(weight_cluster_index).shape().dim(0) > 0);
139 OP_REQUIRES(_ctx.at(weight_cluster_index).shape().dim(1) == 2);
141 // more shape validation will be done inside kernel.
143 // TODO Check bias dimension (can be null tensor)
146 void ShapeValidator::visit(const ir::operation::BCQGather &node)
148 const auto ofm_index{node.getOutputs().at(0)};
149 if (_ctx.at(ofm_index).info().isDynamic())
152 const auto indices_index{node.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
153 const auto input_binary_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
154 const auto input_scales_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_SCALES)};
155 const auto input_clusters_index{
156 node.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
158 OP_REQUIRES(_ctx.at(indices_index).shape().rank() <= 2); // TODO : support rank up to 4 or more
159 OP_REQUIRES(_ctx.at(input_binary_index).shape().rank() == 2);
160 OP_REQUIRES(_ctx.at(input_scales_index).shape().rank() == 1);
161 OP_REQUIRES(_ctx.at(input_clusters_index).shape().rank() == 2);
163 OP_REQUIRES(_ctx.at(input_clusters_index).shape().dim(0) > 0);
164 OP_REQUIRES(_ctx.at(input_clusters_index).shape().dim(1) == 2);
166 // more shape validation will be done inside kernel.
169 void ShapeValidator::visit(const ir::operation::Comparison &)
171 // TODO Shape validation of comparison
174 void ShapeValidator::visit(const ir::operation::Softmax &node)
176 const auto output_index{node.getOutputs().at(0)};
177 if (_ctx.at(output_index).info().isDynamic())
180 const auto input_index{node.getInputs().at(0)};
182 OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
185 void ShapeValidator::visit(const ir::operation::InstanceNorm &node)
187 const auto ofm_index{node.getOutputs().at(0)};
188 if (_ctx.at(ofm_index).info().isDynamic())
191 const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
192 const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
193 const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
195 OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
196 OP_REQUIRES(_ctx.at(ifm_index).shape() == _ctx.at(ofm_index).shape());
197 OP_REQUIRES(_ctx.at(gamma_index).shape().rank() == 1);
198 OP_REQUIRES(_ctx.at(beta_index).shape().rank() == 1);
201 void ShapeValidator::visit(const ir::operation::Pool2D &node)
203 const auto ofm_index{node.getOutputs().at(0)};
204 if (_ctx.at(ofm_index).info().isDynamic())
207 const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
209 OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
212 void ShapeValidator::visit(const ir::operation::Permute &node)
214 const auto output_index{node.getOutputs().at(0)};
215 if (_ctx.at(output_index).info().isDynamic())
218 const auto input_index{node.getInputs().at(0)};
220 OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
223 void ShapeValidator::visit(const ir::operation::Reduce &node)
225 const auto output_index{node.getOutputs().at(0)};
226 if (_ctx.at(output_index).info().isDynamic())
229 const auto input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
230 const auto input_shape = _ctx.at(input_index).shape();
231 const auto output_shape = _ctx.at(output_index).shape();
233 OP_REQUIRES(input_shape.rank() <= 4);
234 OP_REQUIRES(output_shape.rank() <= input_shape.rank());
236 // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only
237 // supports cases reducing height and width or reducing depth.
238 // TODO We have to support all cases of dimensions up to 4.
239 // For correct permuting, we have to set output's shape to be equal in dimension position of the
240 // input. But the positions of the same dimensions in the input and output may be set differently.
241 // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original
242 // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to
243 // extend it in 4 dimensions, it should be {1,1,3,5}.
244 // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of
245 // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the
246 // next operation is not desired.
247 if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank())
249 if (output_shape.rank() == 2)
252 OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) &&
253 input_shape.dim(3) == output_shape.dim(1));
255 else if (output_shape.rank() == 3)
258 // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1)
259 OP_REQUIRES((input_shape.dim(0) == output_shape.dim(0) &&
260 input_shape.dim(1) == output_shape.dim(1) &&
261 input_shape.dim(2) == output_shape.dim(2)) ||
262 (input_shape.dim(0) == output_shape.dim(0) &&
263 (input_shape.dim(1) == output_shape.dim(1) ||
264 input_shape.dim(2) == output_shape.dim(1)) &&
265 input_shape.dim(3) == 1 && output_shape.dim(2) == 1));
270 void ShapeValidator::visit(const ir::operation::Transpose &node)
272 const auto output_index{node.getOutputs().at(0)};
273 if (_ctx.at(output_index).info().isDynamic())
276 const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
277 const auto perm_index{node.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
279 const auto &output_shape = _ctx.at(output_index).shape();
280 const auto &input_shape = _ctx.at(input_index).shape();
282 OP_REQUIRES(_ctx.at(perm_index).shape().num_elements() == 0 ||
283 input_shape.rank() == static_cast<int>(_ctx.at(perm_index).shape().num_elements()));
284 OP_REQUIRES(input_shape.rank() == output_shape.rank());
287 void ShapeValidator::visit(const ir::operation::RNN &node)
289 // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
290 // TODO Support dynamic rnn
291 const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
292 if (_ctx.at(output_index).info().isDynamic())
295 const auto hidden_state_out_index{
296 node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)};
298 const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)};
299 const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)};
300 const auto recurrent_weights_index{
301 node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)};
302 const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
303 const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
305 const auto batch_size = _ctx.at(output_index).shape().dim(0);
306 const auto num_units = _ctx.at(output_index).shape().dim(1);
308 OP_REQUIRES(_ctx.at(output_index).shape().rank() == 2 &&
309 _ctx.at(hidden_state_out_index).shape().rank() == 2 &&
310 _ctx.at(input_index).shape().rank() == 2 &&
311 _ctx.at(weights_index).shape().rank() == 2 &&
312 _ctx.at(recurrent_weights_index).shape().rank() == 2 &&
313 _ctx.at(hidden_state_in_index).shape().rank() == 2);
314 OP_REQUIRES(_ctx.at(bias_index).shape().rank() == 1);
316 OP_REQUIRES(batch_size == _ctx.at(input_index).shape().dim(0) &&
317 batch_size == _ctx.at(hidden_state_in_index).shape().dim(0) &&
318 batch_size == _ctx.at(hidden_state_out_index).shape().dim(0));
319 OP_REQUIRES(_ctx.at(input_index).shape().dim(1) == _ctx.at(weights_index).shape().dim(1));
321 OP_REQUIRES(num_units == _ctx.at(weights_index).shape().dim(0) &&
322 num_units == _ctx.at(recurrent_weights_index).shape().dim(0) &&
323 num_units == _ctx.at(bias_index).shape().dim(0));
324 OP_REQUIRES(num_units == _ctx.at(output_index).shape().dim(1) &&
325 num_units == _ctx.at(recurrent_weights_index).shape().dim(1) &&
326 num_units == _ctx.at(hidden_state_in_index).shape().dim(1) &&
327 num_units == _ctx.at(hidden_state_out_index).shape().dim(1));
330 void ShapeValidator::visit(const ir::operation::SpaceToBatchND &node)
332 const auto ofm_index{node.getOutputs().at(0)};
333 if (_ctx.at(ofm_index).info().isDynamic())
336 const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
337 const auto block_size_index{
338 node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
339 const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
341 const auto frontend_layout = _current_layout;
342 const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
343 const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
345 // All requirement as per NNAPI specification.
346 OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
347 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
348 OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
349 OP_REQUIRES(_ctx.at(paddings_index).shape().rank() == 2);
351 OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
352 OP_REQUIRES(_ctx.at(paddings_index).shape().dim(0) == 2);
353 OP_REQUIRES(_ctx.at(paddings_index).shape().dim(1) == 2);
355 OP_REQUIRES(input_shape.C == output_shape.C);
358 void ShapeValidator::visit(const ir::operation::SpaceToDepth &node)
360 const auto ofm_index{node.getOutputs().at(0)};
361 if (_ctx.at(ofm_index).info().isDynamic())
364 const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
366 const auto frontend_layout = _current_layout;
367 const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
368 const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
369 const auto block_size = node.param().block_size;
371 // All assertions as per NNAPI specification.
372 OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
373 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
374 OP_REQUIRES((input_shape.H % block_size == 0) && (input_shape.W % block_size == 0));
375 OP_REQUIRES(input_shape.N == output_shape.N);
376 OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
379 void ShapeValidator::visit(const ir::operation::ElementwiseActivation &node) { checkUnaryOp(node); }
381 void ShapeValidator::visit(const ir::operation::ElementwiseBinary &)
383 // TODO Shape validation of ElementwiseBinary
386 void ShapeValidator::visit(const ir::operation::ElementwiseUnary &node)
388 const auto output_index{node.getOutputs().at(0)};
389 const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)};
391 if (_ctx.at(output_index).info().isDynamic())
394 OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
397 void ShapeValidator::visit(const ir::operation::EmbeddingLookup &node)
399 const auto output_index{node.getOutputs().at(0)};
400 const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
401 const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
403 const auto &output_obj = _ctx.at(output_index);
404 const auto &lookups_obj = _ctx.at(lookups_index);
405 const auto &values_obj = _ctx.at(values_index);
407 // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying
408 // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
410 if (_ctx.at(output_index).info().isDynamic())
413 const auto &output_shape = output_obj.shape();
414 const auto &lookups_shape = lookups_obj.shape();
415 const auto &values_shape = values_obj.shape();
417 OP_REQUIRES(lookups_shape.rank() == 1);
418 OP_REQUIRES(values_shape.rank() >= 2);
420 // output should be a n-D tensor with the same rank and shape as the values tensor, except for
421 // the first dimension which has the same size as lookups' only dimension.
422 OP_REQUIRES(output_shape.rank() == values_shape.rank());
423 OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0));
424 for (int n = 1; n < output_shape.rank(); ++n)
426 OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n));
431 void ShapeValidator::visit(const ir::operation::ExpandDims &node)
433 const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
435 if (_ctx.at(axis_index).info().isDynamic())
437 OP_REQUIRES(_ctx.at(axis_index).shape().rank() <= 1);
440 void ShapeValidator::visit(const ir::operation::HashtableLookup &node)
442 const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
443 const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
444 const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
445 const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
447 const auto &output_obj = _ctx.at(output_index);
448 const auto &lookups_obj = _ctx.at(lookups_index);
449 const auto &keys_obj = _ctx.at(keys_index);
450 const auto &values_obj = _ctx.at(values_index);
452 if (_ctx.at(output_index).info().isDynamic())
455 const auto &output_shape = output_obj.shape();
456 const auto &lookups_shape = lookups_obj.shape();
457 const auto &keys_shape = keys_obj.shape();
458 const auto &values_shape = values_obj.shape();
460 OP_REQUIRES(values_shape.rank() == output_shape.rank());
461 OP_REQUIRES(lookups_shape.rank() == 1);
462 OP_REQUIRES(keys_shape.rank() == 1);
463 OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0));
464 OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0));
467 void ShapeValidator::visit(const ir::operation::TransposeConv &node)
470 const auto ofm_index{node.getOutputs().at(0)};
471 if (_ctx.at(ofm_index).info().isDynamic())
474 const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
475 const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
477 // Only 4D tensors are supported
478 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
479 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ifm_index).shape().rank());
480 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ker_index).shape().rank());
482 const auto frontend_layout = _current_layout;
483 const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
484 const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
485 // The kernel has only IHWO layout on frontend
486 // So ker_shape is treated here below
491 const auto ker_shape = _ctx.at(ker_index).shape().asFeature(ir::Layout::NHWC);
493 OP_REQUIRES(ifm_shape.N == ofm_shape.N);
494 OP_REQUIRES(ifm_shape.C == ker_shape.C);
495 OP_REQUIRES(ker_shape.N == ofm_shape.C);
498 void ShapeValidator::visit(const ir::operation::Gather &node)
500 const auto ofm_index{node.getOutputs().at(0)};
501 if (_ctx.at(ofm_index).info().isDynamic())
504 const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
505 const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
507 const auto ifm_shape = _ctx.at(ifm_index).shape();
508 const auto indices_shape = _ctx.at(indices_index).shape();
509 const auto ofm_shape = _ctx.at(ofm_index).shape();
511 OP_REQUIRES(ifm_shape.rank() <= 4);
512 OP_REQUIRES(indices_shape.rank() <= 3);
513 OP_REQUIRES(ofm_shape.rank() <= 4);
516 void ShapeValidator::visit(const ir::operation::DepthToSpace &node)
518 int32_t block_size = node.param().block_size;
521 const auto output_index{node.getOutputs().at(0)};
522 if (_ctx.at(output_index).info().isDynamic())
525 const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
527 const auto frontend_layout = _current_layout;
528 const auto output_shape = _ctx.at(output_index).shape().asFeature(frontend_layout);
529 const auto input_shape = _ctx.at(input_index).shape().asFeature(frontend_layout);
531 OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
532 OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
535 OP_REQUIRES(output_shape.N == input_shape.N);
536 OP_REQUIRES(output_shape.H == input_shape.H * block_size);
537 OP_REQUIRES(output_shape.W == input_shape.W * block_size);
538 OP_REQUIRES(input_shape.C % (block_size * block_size) == 0);
539 OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size));
543 void ShapeValidator::visit(const ir::operation::Pack &node)
545 const auto axis{node.param().axis};
546 const auto output_index{node.getOutputs().at(0)};
547 if (_ctx.at(output_index).info().isDynamic())
551 const auto &output_shape = _ctx.at(output_index).shape();
552 const auto output_rank = static_cast<int32_t>(output_shape.rank());
554 const auto input1_index{node.getInputs().at(0)};
555 const auto input_shape = _ctx.at(input1_index).shape();
557 OP_REQUIRES(axis >= -output_rank && axis < output_rank);
558 for (const auto &index : node.getInputs())
560 OP_REQUIRES(input_shape == _ctx.at(index).shape());
564 void ShapeValidator::visit(const ir::operation::LSTM &node)
566 // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
567 // TODO Support dynamic rnn
568 const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
569 if (_ctx.at(output_index).info().isDynamic())
572 const auto scratch_buffer_index{
573 node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; // Optional
574 const auto output_state_out_index{
575 node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; // Optional
576 const auto cell_state_out_index{
577 node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; // Optional
579 const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
580 const auto input_to_input_weights_index{
581 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // Optional
582 const auto input_to_forget_weights_index{
583 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
584 const auto input_to_cell_weights_index{
585 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
586 const auto input_to_output_weights_index{
587 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
588 const auto recurrent_to_input_weights_index{
589 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // Optional
590 const auto recurrent_to_forget_weights_index{
591 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
592 const auto recurrent_to_cell_weights_index{
593 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
594 const auto recurrent_to_output_weights_index{
595 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
596 const auto cell_to_input_weights_index{
597 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // Optional
598 const auto cell_to_forget_weights_index{
599 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // Optional
600 const auto cell_to_output_weights_index{
601 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // Optional
602 const auto input_gate_bias_index{
603 node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)}; // Optional
604 const auto forget_gate_bias_index{
605 node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
606 const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
607 const auto output_gate_bias_index{
608 node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
609 const auto projection_weights_index{
610 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // Optional
611 const auto projection_bias_index{
612 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // Optional
613 const auto output_state_in_index{
614 node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
615 const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
617 OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
618 for (int i = 0; i < _ctx.at(input_index).shape().rank() - 1; ++i)
620 OP_REQUIRES(_ctx.at(input_index).shape().dim(i) == _ctx.at(output_index).shape().dim(i));
623 (_ctx.at(output_index).shape().rank() == 2 || _ctx.at(output_index).shape().rank() == 3) &&
624 (_ctx.at(input_index).shape().rank() == 2 || _ctx.at(input_index).shape().rank() == 3) &&
625 (!_ctx.exist(input_to_input_weights_index) ||
626 _ctx.at(input_to_input_weights_index).shape().rank() == 2) &&
627 _ctx.at(input_to_forget_weights_index).shape().rank() == 2 &&
628 _ctx.at(input_to_cell_weights_index).shape().rank() == 2 &&
629 _ctx.at(input_to_output_weights_index).shape().rank() == 2 &&
630 (!_ctx.exist(recurrent_to_input_weights_index) ||
631 _ctx.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
632 _ctx.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
633 _ctx.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
634 _ctx.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
635 (!_ctx.exist(projection_weights_index) ||
636 _ctx.at(projection_weights_index).shape().rank() == 2) &&
637 _ctx.at(output_state_in_index).shape().rank() == 2 &&
638 _ctx.at(cell_state_in_index).shape().rank() == 2);
641 (!_ctx.exist(cell_to_input_weights_index) ||
642 _ctx.at(cell_to_input_weights_index).shape().rank() == 1) &&
643 (!_ctx.exist(cell_to_forget_weights_index) ||
644 _ctx.at(cell_to_forget_weights_index).shape().rank() == 1) &&
645 (!_ctx.exist(cell_to_output_weights_index) ||
646 _ctx.at(cell_to_output_weights_index).shape().rank() == 1) &&
647 (!_ctx.exist(input_gate_bias_index) || _ctx.at(input_gate_bias_index).shape().rank() == 1) &&
648 _ctx.at(forget_gate_bias_index).shape().rank() == 1 &&
649 _ctx.at(cell_bias_index).shape().rank() == 1 &&
650 _ctx.at(output_gate_bias_index).shape().rank() == 1 &&
651 (!_ctx.exist(projection_bias_index) || _ctx.at(projection_bias_index).shape().rank() == 1));
655 ((!_ctx.exist(input_to_input_weights_index) ||
656 (_ctx.at(input_to_input_weights_index).shape().dim(0) == 0 &&
657 _ctx.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
658 (!_ctx.exist(recurrent_to_input_weights_index) ||
659 (_ctx.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
660 _ctx.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
661 (!_ctx.exist(input_gate_bias_index) || _ctx.at(input_gate_bias_index).shape().dim(0) == 0) &&
662 (!_ctx.exist(cell_to_input_weights_index) ||
663 _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
664 ((_ctx.exist(input_to_input_weights_index) &&
665 (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
666 _ctx.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
667 (_ctx.exist(recurrent_to_input_weights_index) &&
668 (_ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
669 _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
670 (_ctx.exist(input_gate_bias_index) && _ctx.at(input_gate_bias_index).shape().dim(0) != 0)));
672 // Peephole assertion
673 OP_REQUIRES(((!_ctx.exist(cell_to_forget_weights_index) ||
674 _ctx.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
675 (!_ctx.exist(cell_to_output_weights_index) ||
676 _ctx.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
677 ((_ctx.exist(cell_to_forget_weights_index) &&
678 _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
679 (_ctx.exist(cell_to_output_weights_index) &&
680 _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0)));
682 bool has_input_to_input_weights = _ctx.exist(input_to_input_weights_index) &&
683 (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
684 _ctx.at(input_to_input_weights_index).shape().dim(1) != 0);
685 bool has_recurrent_to_input_weights =
686 _ctx.exist(recurrent_to_input_weights_index) &&
687 (_ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
688 _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
689 bool has_input_gate_bias =
690 _ctx.exist(input_gate_bias_index) && _ctx.at(input_gate_bias_index).shape().dim(0) != 0;
691 bool has_cell_to_input_weights = _ctx.exist(cell_to_input_weights_index) &&
692 _ctx.at(cell_to_input_weights_index).shape().dim(0) != 0;
693 bool has_cell_to_forget_weights = _ctx.exist(cell_to_forget_weights_index) &&
694 _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0;
695 bool has_cell_to_output_weights = _ctx.exist(cell_to_output_weights_index) &&
696 _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0;
697 bool has_projection_weights = _ctx.exist(projection_weights_index) &&
698 (_ctx.at(projection_weights_index).shape().dim(0) != 0 &&
699 _ctx.at(projection_weights_index).shape().dim(1) != 0);
700 bool has_projection_bias =
701 _ctx.exist(projection_bias_index) && _ctx.at(projection_bias_index).shape().dim(0) != 0;
703 // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
706 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
708 // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole.
710 // false: no peephole
711 bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
713 // NOTE The projection weights may have data but the projection bias may not.
714 bool has_projection_param = has_projection_weights;
716 const auto batch_size = (_ctx.at(input_index).shape().rank() == 3 && node.param().time_major)
717 ? _ctx.at(input_index).shape().dim(1)
718 : _ctx.at(input_index).shape().dim(0);
719 OP_REQUIRES(batch_size == _ctx.at(output_state_in_index).shape().dim(0) &&
720 batch_size == _ctx.at(cell_state_in_index).shape().dim(0));
722 const auto input_size = _ctx.at(input_index).shape().dim(_ctx.at(input_index).shape().rank() - 1);
723 OP_REQUIRES(input_size == _ctx.at(input_to_forget_weights_index).shape().dim(1) &&
724 input_size == _ctx.at(input_to_cell_weights_index).shape().dim(1) &&
725 input_size == _ctx.at(input_to_output_weights_index).shape().dim(1));
727 const auto num_units = _ctx.at(input_to_output_weights_index).shape().dim(0);
728 OP_REQUIRES(num_units == _ctx.at(input_to_cell_weights_index).shape().dim(0) &&
729 num_units == _ctx.at(input_to_output_weights_index).shape().dim(0) &&
730 num_units == _ctx.at(recurrent_to_forget_weights_index).shape().dim(0) &&
731 num_units == _ctx.at(recurrent_to_cell_weights_index).shape().dim(0) &&
732 num_units == _ctx.at(recurrent_to_output_weights_index).shape().dim(0) &&
733 num_units == _ctx.at(forget_gate_bias_index).shape().dim(0) &&
734 num_units == _ctx.at(cell_bias_index).shape().dim(0) &&
735 num_units == _ctx.at(output_gate_bias_index).shape().dim(0) &&
736 num_units == _ctx.at(cell_state_in_index).shape().dim(1));
738 const auto output_size =
739 _ctx.at(output_index).shape().dim(_ctx.at(output_index).shape().rank() - 1);
740 OP_REQUIRES(output_size == _ctx.at(recurrent_to_forget_weights_index).shape().dim(1) &&
741 output_size == _ctx.at(recurrent_to_cell_weights_index).shape().dim(1) &&
742 output_size == _ctx.at(recurrent_to_output_weights_index).shape().dim(1) &&
743 output_size == _ctx.at(output_state_in_index).shape().dim(1));
747 OP_REQUIRES(input_size == _ctx.at(input_to_input_weights_index).shape().dim(1));
748 OP_REQUIRES(num_units == _ctx.at(input_to_input_weights_index).shape().dim(0) &&
749 num_units == _ctx.at(recurrent_to_input_weights_index).shape().dim(0) &&
750 ((_ctx.exist(cell_to_input_weights_index) &&
751 num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0)) ||
752 (!_ctx.exist(cell_to_input_weights_index) ||
753 _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0) /* non-peephole */) &&
754 num_units == _ctx.at(input_gate_bias_index).shape().dim(0));
755 OP_REQUIRES(output_size == _ctx.at(recurrent_to_input_weights_index).shape().dim(1));
756 OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
757 has_input_gate_bias);
758 if (has_cell_to_input_weights)
760 // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole.
761 OP_REQUIRES(has_peephole_param);
763 if (_ctx.exist(scratch_buffer_index))
764 OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
768 if (_ctx.exist(scratch_buffer_index))
769 OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
772 if (has_peephole_param)
774 OP_REQUIRES(num_units == _ctx.at(cell_to_forget_weights_index).shape().dim(0) &&
775 num_units == _ctx.at(cell_to_output_weights_index).shape().dim(0) &&
776 (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) ||
777 _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
780 if (has_projection_param)
782 OP_REQUIRES(num_units == _ctx.at(projection_weights_index).shape().dim(1));
783 OP_REQUIRES(output_size == _ctx.at(projection_weights_index).shape().dim(0));
784 if (has_projection_bias)
786 OP_REQUIRES(output_size == _ctx.at(projection_bias_index).shape().dim(0));
790 if (_ctx.exist(scratch_buffer_index))
792 OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().rank() == 2);
793 OP_REQUIRES(batch_size == _ctx.at(scratch_buffer_index).shape().dim(0));
796 if (_ctx.exist(output_state_out_index))
798 OP_REQUIRES(_ctx.at(output_state_out_index).shape().rank() == 2);
799 OP_REQUIRES(batch_size == _ctx.at(output_state_out_index).shape().dim(0));
800 OP_REQUIRES(output_size == _ctx.at(output_state_out_index).shape().dim(1));
803 if (_ctx.exist(cell_state_out_index))
805 OP_REQUIRES(_ctx.at(cell_state_out_index).shape().rank() == 2);
806 OP_REQUIRES(batch_size == _ctx.at(cell_state_out_index).shape().dim(0));
807 OP_REQUIRES(num_units == _ctx.at(cell_state_out_index).shape().dim(1));
811 void ShapeValidator::visit(const ir::operation::L2Normalization &node)
813 const auto ofm_index{node.getOutputs().at(0)};
814 if (_ctx.at(ofm_index).info().isDynamic())
817 const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
819 auto ifm_shape = _ctx.at(ifm_index).shape();
820 auto ofm_shape = _ctx.at(ofm_index).shape();
822 OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
824 for (auto i = 0; i < ifm_shape.rank(); i++)
826 OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
830 void ShapeValidator::visit(const ir::operation::Unpack &node)
832 const auto axis{node.param().axis};
833 const auto output_index{node.getInputs().at(0)};
834 if (_ctx.at(output_index).info().isDynamic())
837 const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
839 const auto &input_shape = _ctx.at(input_index).shape();
840 const auto input_rank = static_cast<int32_t>(input_shape.rank());
842 OP_REQUIRES(axis >= -input_rank && axis < input_rank);
845 void ShapeValidator::visit(const ir::operation::Pad &node)
847 const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
848 OP_REQUIRES(_ctx.at(pad_index).typeInfo().type() == ir::DataType::INT32);
850 const auto output_index{node.getInputs().at(0)};
851 if (_ctx.at(output_index).info().isDynamic())
854 const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
856 const auto &pad_shape = _ctx.at(pad_index).shape();
857 const auto input_rank = static_cast<int32_t>(_ctx.at(input_index).shape().rank());
859 OP_REQUIRES(pad_shape.rank() == 2);
860 OP_REQUIRES(pad_shape.dim(0) == input_rank);
861 OP_REQUIRES(pad_shape.dim(1) == 2);
862 OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
void ShapeValidator::visit(const ir::operation::Select &)
{
  // Intentionally empty: Select is accepted without static shape checks.
  // TODO Shape validation of select
}
870 void ShapeValidator::visit(const ir::operation::StridedSlice &node)
872 const auto output_index{node.getOutputs().at(0)};
873 const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
875 if (_ctx.at(output_index).info().isDynamic())
878 OP_REQUIRES(_ctx.at(input_index).shape().rank() <= 4);
881 void ShapeValidator::visit(const ir::operation::Split &node)
883 const auto output_index{node.getOutputs().at(0)};
884 if (_ctx.at(output_index).info().isDynamic())
887 const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
888 const auto axis_index{node.getInputs().at(ir::operation::Split::Input::AXIS)};
890 const auto num_splits = node.param().num_splits;
891 const auto input_rank = _ctx.at(input_index).shape().rank();
892 auto axis = *reinterpret_cast<const int32_t *>(_ctx.at(axis_index).data()->base());
893 axis = axis < 0 ? axis + input_rank : axis;
895 OP_REQUIRES(axis >= 0 && axis < input_rank);
896 OP_REQUIRES(_ctx.at(input_index).shape().dim(axis) % num_splits == 0);
899 void ShapeValidator::visit(const ir::operation::Shape &node)
901 const auto output_index{node.getOutputs().at(0)};
902 if (_ctx.at(output_index).info().isDynamic())
905 const auto input_index{node.getInputs().at(0)};
906 UNUSED_RELEASE(input_index);
907 OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1);
910 void ShapeValidator::visit(const ir::operation::ResizeBilinear &node)
912 const auto output_index{node.getOutputs().at(0)};
913 const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
915 if (_ctx.at(output_index).info().isDynamic())
919 OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
920 OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
923 void ShapeValidator::visit(const ir::operation::Reverse &node)
925 const auto output_index{node.getOutputs().at(0)};
926 const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
928 if (_ctx.at(output_index).info().isDynamic())
930 OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
void ShapeValidator::visit(const ir::operation::If &)
{
  // Intentionally empty: If delegates its bodies to subgraphs, which this
  // validator does not traverse.
  // TODO Add to validate with subgraphs
}
void ShapeValidator::visit(const ir::operation::While &)
{
  // Intentionally empty: While delegates cond/body to subgraphs.
  // This validator does not check shape. So checking isDynamic() is skipped.
  // TODO Add to validate with subgraphs
}
void ShapeValidator::visit(const ir::operation::SquaredDifference &node)
{
  // Validates numpy-style broadcasting between LHS and RHS: dimensions are
  // compared from the trailing (innermost) axis; each pair must be equal or
  // one of them must be 1, and the output takes the broadcast size.
  const auto output_index{node.getOutputs().at(0)};
  const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
  const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};

  // Check for dimension constraints
  if (_ctx.at(output_index).info().isDynamic())
    return;

  auto output_shape = _ctx.at(output_index).shape();
  auto lhs_shape = _ctx.at(lhs_index).shape();
  auto rhs_shape = _ctx.at(rhs_index).shape();
  // Check for output rank: broadcasting yields the larger of the two ranks.
  OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
  auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());

  // Walk the axes that both operands have, counting back from the last axis
  // (idx == 1 is the innermost dimension).
  for (int idx = 1; idx <= min_rank; idx++)
  {
    int l_idx = lhs_shape.rank() - idx;
    int r_idx = rhs_shape.rank() - idx;
    int out_idx = output_shape.rank() - idx;

    OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0));

    auto l_dims = lhs_shape.dim(l_idx);
    auto r_dims = rhs_shape.dim(r_idx);
    auto out_dims = output_shape.dim(out_idx);

    // Equal dims pass through; a dim of 1 broadcasts against the other side.
    OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) ||
                ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims)));
  }
  // The remaining leading axes exist only on the higher-rank operand and must
  // be copied verbatim into the output.
  auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape;
  for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++)
  {
    int out_idx = output_shape.rank() - idx;
    int tmp_idx = tmp_shape.rank() - idx;

    OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) &&
                (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx)));
  }
}
986 void ShapeValidator::visit(const ir::operation::Tile &node)
988 const auto output_index{node.getOutputs().at(0)};
989 if (_ctx.at(output_index).info().isDynamic())
992 const auto input_index{node.getInputs().at(0)};
993 const auto multiple_index{node.getInputs().at(1)};
995 OP_REQUIRES(_ctx.at(multiple_index).shape().rank() == 1);
996 OP_REQUIRES(_ctx.at(multiple_index).shape().dim(0) == _ctx.at(input_index).shape().rank());
997 OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
1000 void ShapeValidator::visit(const ir::operation::Range &node)
1002 const auto output_index{node.getOutputs().at(0)};
1003 const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
1004 const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
1005 const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
1007 // Check for dimension constraints
1008 if (_ctx.at(output_index).info().isDynamic())
1011 OP_REQUIRES(_ctx.at(start_index).shape().rank() == 0);
1012 OP_REQUIRES(_ctx.at(limit_index).shape().rank() == 0);
1013 OP_REQUIRES(_ctx.at(delta_index).shape().rank() == 0);
void ShapeValidator::visit(const ir::operation::MatrixBandPart &node)
{
  // Validates MatrixBandPart: the input must be at least a 2-D matrix, and
  // both diagonal-count operands must be scalars.
  const auto output_index{node.getOutputs().at(0)};
  const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)};
  const auto num_lower_index{
    node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_LOWER_DIAG)};
  const auto num_upper_index{
    node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)};

  // Check for dimension constraints
  if (_ctx.at(output_index).info().isDynamic())
    return;

  OP_REQUIRES(_ctx.at(input_index).shape().rank() >= 2); // input must be more than 2 dim matrix
  OP_REQUIRES(_ctx.at(num_upper_index).shape().rank() == 0); // num_upper must be scalar
  OP_REQUIRES(_ctx.at(num_lower_index).shape().rank() == 0); // num_lower must be scalar
}
1034 void ShapeValidator::visit(const ir::operation::LogSoftmax &node)
1036 const auto output_index{node.getOutputs().at(0)};
1037 if (_ctx.at(output_index).info().isDynamic())
1040 const auto input_index{node.getInputs().at(0)};
1042 OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
1045 } // namespace compiler
1046 } // namespace onert