2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "OperationValidator.h"
22 #include "ir/operation/LowerInfo.h"
24 #include "util/logging.h"
25 #include "util/Utils.h"
// Validation helper: throws std::runtime_error naming the failing source line
// when EXP is false. Wrapped in do { } while (0) so it expands to a single
// statement and is safe inside un-braced if/else bodies.
#define OP_REQUIRES(EXP)                                                                         \
  do                                                                                             \
  {                                                                                              \
    if (!(EXP))                                                                                  \
      throw std::runtime_error("OperationValidator failed at line " + std::to_string(__LINE__)); \
  } while (0)
39 OperationValidator::OperationValidator(const ir::Graph &graph)
40 : _graph{graph}, _ctx{graph.operands()}, _current_op_seq_layout{ir::Layout::UNKNOWN}
44 void OperationValidator::checkUnaryOp(const ir::Operation &node)
46 const auto output_index{node.getOutputs().at(0)};
47 const auto input_index{node.getInputs().at(0)};
49 // Check if I/O types match
50 OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
52 if (_ctx.at(output_index).info().isDynamic())
55 // Check if I/O shapes match
56 OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
59 void OperationValidator::operator()()
61 // There is no reason for each subgraph to have subgraphs since compiler has subgraphs when
63 assert(_graph.subgraphs() == nullptr);
65 _current_op_seq_layout = _graph.layout();
67 _graph.operations().iterate(
68 [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); });
// Abs is elementwise: delegate to the shared unary-op type/shape checks.
void OperationValidator::visit(const ir::operation::Abs &node) { checkUnaryOp(node); }
73 void OperationValidator::visit(const ir::operation::AvgPool2D &node)
75 const auto ofm_index{node.getOutputs().at(0)};
76 if (_ctx.at(ofm_index).info().isDynamic())
79 const auto ifm_index{node.getInputs().at(ir::operation::AvgPool2D::Input::INPUT)};
81 OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
84 void OperationValidator::visit(const ir::operation::BatchMatMul &node)
86 const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
87 const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
88 const auto out_index{node.getOutputs().at(0)};
90 // Constant lhs and rhs is not implemented yet
91 OP_REQUIRES(!_ctx.at(lhs_index).isConstant() && !_ctx.at(rhs_index).isConstant());
93 if (_ctx.at(out_index).info().isDynamic())
96 OP_REQUIRES(_ctx.at(lhs_index).shape().rank() <= 4);
97 OP_REQUIRES(_ctx.at(rhs_index).shape().rank() <= 4);
98 OP_REQUIRES(_ctx.at(lhs_index).shape().rank() >= 2);
99 OP_REQUIRES(_ctx.at(rhs_index).shape().rank() >= 2);
102 void OperationValidator::visit(const ir::operation::BatchToSpaceND &node)
104 const auto ofm_index{node.getOutputs().at(0)};
105 if (_ctx.at(ofm_index).info().isDynamic())
108 const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
109 const auto block_size_index{
110 node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
112 const auto frontend_layout = _current_op_seq_layout;
113 const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
114 const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
116 // All requirement as per NNAPI specification.
117 OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
118 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
119 OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
121 OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
123 OP_REQUIRES(_ctx.at(block_size_index).isConstant());
125 OP_REQUIRES(input_shape.C == output_shape.C);
128 void OperationValidator::visit(const ir::operation::Cast &node)
130 const auto output_index{node.getOutputs().at(0)};
131 if (_ctx.at(output_index).info().isDynamic())
134 const auto input_index{node.getInputs().at(0)};
136 OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
139 void OperationValidator::visit(const ir::operation::Comparison &node)
141 const auto output_index{node.getOutputs().at(0)};
142 // This validator does not check shape. So checking isDynamic() is skipped.
144 const auto lhs_index{node.getInputs().at(ir::operation::Comparison::Input::INPUT0)};
145 const auto rhs_index{node.getInputs().at(ir::operation::Comparison::Input::INPUT1)};
147 OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());
148 OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::BOOL8);
151 void OperationValidator::visit(const ir::operation::Softmax &node)
153 VERBOSE(Softmax) << "Configure SOFTMAX operation" << std::endl;
155 const auto output_index{node.getOutputs().at(0)};
156 if (_ctx.at(output_index).info().isDynamic())
159 const auto input_index{node.getInputs().at(0)};
161 OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
164 void OperationValidator::visit(const ir::operation::InstanceNorm &node)
166 const auto ofm_index{node.getOutputs().at(0)};
167 if (_ctx.at(ofm_index).info().isDynamic())
170 const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
171 const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
172 const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
174 OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
175 OP_REQUIRES(_ctx.at(ifm_index).shape() == _ctx.at(ofm_index).shape());
176 OP_REQUIRES(_ctx.at(gamma_index).shape().rank() == 1);
177 OP_REQUIRES(_ctx.at(beta_index).shape().rank() == 1);
180 void OperationValidator::visit(const ir::operation::Permute &node)
182 VERBOSE(Permute) << "Configure Permute operation" << std::endl;
184 const auto output_index{node.getOutputs().at(0)};
185 if (_ctx.at(output_index).info().isDynamic())
188 const auto input_index{node.getInputs().at(0)};
190 OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
193 void OperationValidator::visit(const ir::operation::Reduce &node)
195 VERBOSE(Permute) << "Configure " + node.name() + " operation" << std::endl;
197 const auto output_index{node.getOutputs().at(0)};
198 if (_ctx.at(output_index).info().isDynamic())
201 const auto input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
202 const auto input_shape = _ctx.at(input_index).shape();
203 const auto output_shape = _ctx.at(output_index).shape();
205 OP_REQUIRES(input_shape.rank() <= 4);
206 OP_REQUIRES(output_shape.rank() <= input_shape.rank());
208 // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only
209 // supports cases reducing height and width or reducing depth.
210 // TODO We have to support all cases of dimensions up to 4.
211 // For correct permuting, we have to set output's shape to be equal in dimension position of the
212 // input. But the positions of the same dimensions in the input and output may be set differently.
213 // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original
214 // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to
215 // extend it in 4 dimensions, it should be {1,1,3,5}.
216 // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of
217 // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the
218 // next operation is not desired.
219 if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank())
221 if (output_shape.rank() == 2)
224 OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) &&
225 input_shape.dim(3) == output_shape.dim(1));
227 else if (output_shape.rank() == 3)
230 // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1)
231 OP_REQUIRES((input_shape.dim(0) == output_shape.dim(0) &&
232 input_shape.dim(1) == output_shape.dim(1) &&
233 input_shape.dim(2) == output_shape.dim(2)) ||
234 (input_shape.dim(0) == output_shape.dim(0) &&
235 (input_shape.dim(1) == output_shape.dim(1) ||
236 input_shape.dim(2) == output_shape.dim(1)) &&
237 input_shape.dim(3) == 1 && output_shape.dim(2) == 1));
242 void OperationValidator::visit(const ir::operation::Transpose &node)
244 const auto output_index{node.getOutputs().at(0)};
245 if (_ctx.at(output_index).info().isDynamic())
248 const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
249 const auto &perm{node.param().perm};
251 const auto &output_shape = _ctx.at(output_index).shape();
252 const auto &input_shape = _ctx.at(input_index).shape();
254 OP_REQUIRES(input_shape.rank() == static_cast<int>(perm.size()));
255 OP_REQUIRES(input_shape.rank() == output_shape.rank());
258 void OperationValidator::visit(const ir::operation::RNN &node)
260 // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
261 // TODO Support dynamic rnn
262 const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
263 if (_ctx.at(output_index).info().isDynamic())
266 const auto hidden_state_out_index{
267 node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)};
269 const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)};
270 const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)};
271 const auto recurrent_weights_index{
272 node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)};
273 const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
274 const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
276 const auto batch_size = _ctx.at(output_index).shape().dim(0);
277 const auto num_units = _ctx.at(output_index).shape().dim(1);
279 OP_REQUIRES(_ctx.at(output_index).shape().rank() == 2 &&
280 _ctx.at(hidden_state_out_index).shape().rank() == 2 &&
281 _ctx.at(input_index).shape().rank() == 2 &&
282 _ctx.at(weights_index).shape().rank() == 2 &&
283 _ctx.at(recurrent_weights_index).shape().rank() == 2 &&
284 _ctx.at(hidden_state_in_index).shape().rank() == 2);
285 OP_REQUIRES(_ctx.at(bias_index).shape().rank() == 1);
287 OP_REQUIRES(batch_size == _ctx.at(input_index).shape().dim(0) &&
288 batch_size == _ctx.at(hidden_state_in_index).shape().dim(0) &&
289 batch_size == _ctx.at(hidden_state_out_index).shape().dim(0));
290 OP_REQUIRES(_ctx.at(input_index).shape().dim(1) == _ctx.at(weights_index).shape().dim(1));
292 OP_REQUIRES(num_units == _ctx.at(weights_index).shape().dim(0) &&
293 num_units == _ctx.at(recurrent_weights_index).shape().dim(0) &&
294 num_units == _ctx.at(bias_index).shape().dim(0));
295 OP_REQUIRES(num_units == _ctx.at(output_index).shape().dim(1) &&
296 num_units == _ctx.at(recurrent_weights_index).shape().dim(1) &&
297 num_units == _ctx.at(hidden_state_in_index).shape().dim(1) &&
298 num_units == _ctx.at(hidden_state_out_index).shape().dim(1));
// Round is elementwise: delegate to the shared unary-op type/shape checks.
void OperationValidator::visit(const ir::operation::Round &node) { checkUnaryOp(node); }
303 void OperationValidator::visit(const ir::operation::SpaceToBatchND &node)
305 const auto ofm_index{node.getOutputs().at(0)};
306 if (_ctx.at(ofm_index).info().isDynamic())
309 const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
310 const auto block_size_index{
311 node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
312 const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
314 const auto frontend_layout = _current_op_seq_layout;
315 const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
316 const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
318 // All requirement as per NNAPI specification.
319 OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
320 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
321 OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
322 OP_REQUIRES(_ctx.at(paddings_index).shape().rank() == 2);
324 OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
325 OP_REQUIRES(_ctx.at(paddings_index).shape().dim(0) == 2);
326 OP_REQUIRES(_ctx.at(paddings_index).shape().dim(1) == 2);
328 OP_REQUIRES(_ctx.at(block_size_index).isConstant());
329 OP_REQUIRES(_ctx.at(paddings_index).isConstant());
331 OP_REQUIRES(input_shape.C == output_shape.C);
334 void OperationValidator::visit(const ir::operation::SpaceToDepth &node)
336 const auto ofm_index{node.getOutputs().at(0)};
337 if (_ctx.at(ofm_index).info().isDynamic())
340 const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
342 const auto frontend_layout = _current_op_seq_layout;
343 const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
344 const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
345 const auto block_size = node.param().block_size;
347 // All assertions as per NNAPI specification.
348 OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
349 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
350 OP_REQUIRES((block_size >= 1) && (input_shape.H % block_size == 0) &&
351 (input_shape.W % block_size == 0));
352 OP_REQUIRES(input_shape.N == output_shape.N);
353 OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
356 void OperationValidator::visit(const ir::operation::EmbeddingLookup &node)
358 const auto output_index{node.getOutputs().at(0)};
359 const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
360 const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
362 const auto &output_obj = _ctx.at(output_index);
363 const auto &lookups_obj = _ctx.at(lookups_index);
364 const auto &values_obj = _ctx.at(values_index);
366 // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying
367 // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
369 OP_REQUIRES(lookups_obj.typeInfo().type() == ir::DataType::INT32);
371 if (_ctx.at(output_index).info().isDynamic())
374 const auto &output_shape = output_obj.shape();
375 const auto &lookups_shape = lookups_obj.shape();
376 const auto &values_shape = values_obj.shape();
378 OP_REQUIRES(lookups_shape.rank() == 1);
379 OP_REQUIRES(values_shape.rank() >= 2);
381 // output should be a n-D tensor with the same rank and shape as the values tensor, except for
382 // the first dimension which has the same size as lookups' only dimension.
383 OP_REQUIRES(output_shape.rank() == values_shape.rank());
384 OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0));
385 for (int n = 1; n < output_shape.rank(); ++n)
387 OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n));
// Exp is elementwise: delegate to the shared unary-op type/shape checks.
void OperationValidator::visit(const ir::operation::Exp &node) { checkUnaryOp(node); }
394 void OperationValidator::visit(const ir::operation::ExpandDims &node)
396 const auto output_index{node.getOutputs().at(0)};
397 const auto input_index{node.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
398 const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
400 OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
401 OP_REQUIRES(_ctx.at(axis_index).typeInfo().type() == ir::DataType::INT32);
403 if (_ctx.at(axis_index).info().isDynamic())
405 OP_REQUIRES(_ctx.at(axis_index).shape().rank() <= 1);
// Floor is elementwise: delegate to the shared unary-op type/shape checks.
void OperationValidator::visit(const ir::operation::Floor &node) { checkUnaryOp(node); }
410 void OperationValidator::visit(const ir::operation::HashtableLookup &node)
412 const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
413 const auto hits_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::HITS)};
415 const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
416 const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
417 const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
419 const auto &output_obj = _ctx.at(output_index);
420 const auto &hits_obj = _ctx.at(hits_index);
422 const auto &lookups_obj = _ctx.at(lookups_index);
423 const auto &keys_obj = _ctx.at(keys_index);
424 const auto &values_obj = _ctx.at(values_index);
426 OP_REQUIRES(lookups_obj.typeInfo().type() == ir::DataType::INT32);
427 OP_REQUIRES(keys_obj.typeInfo().type() == ir::DataType::INT32);
428 OP_REQUIRES(hits_obj.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM);
430 if (_ctx.at(output_index).info().isDynamic())
433 const auto &output_shape = output_obj.shape();
434 const auto &lookups_shape = lookups_obj.shape();
435 const auto &keys_shape = keys_obj.shape();
436 const auto &values_shape = values_obj.shape();
438 OP_REQUIRES(values_shape.rank() == output_shape.rank());
439 OP_REQUIRES(lookups_shape.rank() == 1);
440 OP_REQUIRES(keys_shape.rank() == 1);
441 OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0));
442 OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0));
445 void OperationValidator::visit(const ir::operation::TransposeConv &node)
448 OP_REQUIRES((node.param().padding.type == ir::PaddingType::SAME) ||
449 (node.param().padding.type == ir::PaddingType::VALID));
452 const auto ofm_index{node.getOutputs().at(0)};
453 if (_ctx.at(ofm_index).info().isDynamic())
456 const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
457 const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
459 // Only 4D tensors are supported
460 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
461 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ifm_index).shape().rank());
462 OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ker_index).shape().rank());
464 const auto frontend_layout = _current_op_seq_layout;
465 const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
466 const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
467 // The kernel has only IHWO layout on frontend
468 // So ker_shape is treated here below
473 const auto ker_shape = _ctx.at(ker_index).shape().asFeature(ir::Layout::NHWC);
475 OP_REQUIRES(ifm_shape.N == ofm_shape.N);
476 OP_REQUIRES(ifm_shape.C == ker_shape.C);
477 OP_REQUIRES(ker_shape.N == ofm_shape.C);
480 void OperationValidator::visit(const ir::operation::Gather &node)
482 const auto ofm_index{node.getOutputs().at(0)};
483 if (_ctx.at(ofm_index).info().isDynamic())
486 const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
487 const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
489 const auto ifm_shape = _ctx.at(ifm_index).shape();
490 const auto indices_shape = _ctx.at(indices_index).shape();
491 const auto ofm_shape = _ctx.at(ofm_index).shape();
493 OP_REQUIRES(ifm_shape.rank() <= 4);
494 OP_REQUIRES(indices_shape.rank() <= 3);
495 OP_REQUIRES(ofm_shape.rank() <= 4);
498 void OperationValidator::visit(const ir::operation::Dequantize &node)
500 const auto output_index{node.getOutputs().at(0)};
502 const auto input_index{node.getInputs().at(ir::operation::Dequantize::Input::INPUT)};
504 OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM);
505 OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::FLOAT32);
507 if (_ctx.at(output_index).info().isDynamic())
509 OP_REQUIRES(_ctx.at(input_index).shape().rank() <= 4);
510 OP_REQUIRES(_ctx.at(input_index).shape() == _ctx.at(output_index).shape());
513 void OperationValidator::visit(const ir::operation::DepthToSpace &node)
516 int32_t block_size = node.param().block_size;
518 OP_REQUIRES(block_size > 0);
521 const auto output_index{node.getOutputs().at(0)};
522 if (_ctx.at(output_index).info().isDynamic())
525 const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
527 const auto frontend_layout = _current_op_seq_layout;
528 const auto output_shape = _ctx.at(output_index).shape().asFeature(frontend_layout);
529 const auto input_shape = _ctx.at(input_index).shape().asFeature(frontend_layout);
531 OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
532 OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
535 OP_REQUIRES(output_shape.N == input_shape.N);
536 OP_REQUIRES(output_shape.H == input_shape.H * block_size);
537 OP_REQUIRES(output_shape.W == input_shape.W * block_size);
538 OP_REQUIRES(input_shape.C % (block_size * block_size) == 0);
539 OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size));
543 void OperationValidator::visit(const ir::operation::Pack &node)
546 const auto num{node.param().num};
547 const auto axis{node.param().axis};
548 OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size()));
550 const auto output_index{node.getOutputs().at(0)};
551 if (_ctx.at(output_index).info().isDynamic())
555 const auto &output_shape = _ctx.at(output_index).shape();
556 const auto output_rank = static_cast<int32_t>(output_shape.rank());
558 const auto input1_index{node.getInputs().at(0)};
559 const auto input_shape = _ctx.at(input1_index).shape();
561 OP_REQUIRES(axis >= -output_rank && axis < output_rank);
562 for (const auto &index : node.getInputs())
564 OP_REQUIRES(input_shape == _ctx.at(index).shape());
568 void OperationValidator::visit(const ir::operation::LSTM &node)
570 // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
571 // TODO Support dynamic rnn
572 const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
573 if (_ctx.at(output_index).info().isDynamic())
576 const auto scratch_buffer_index{
577 node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};
578 const auto output_state_out_index{
579 node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};
580 const auto cell_state_out_index{
581 node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)};
583 const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
584 const auto input_to_input_weights_index{
585 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)};
586 const auto input_to_forget_weights_index{
587 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
588 const auto input_to_cell_weights_index{
589 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
590 const auto input_to_output_weights_index{
591 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
592 const auto recurrent_to_input_weights_index{
593 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)};
594 const auto recurrent_to_forget_weights_index{
595 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
596 const auto recurrent_to_cell_weights_index{
597 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
598 const auto recurrent_to_output_weights_index{
599 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
600 const auto cell_to_input_weights_index{
601 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)};
602 const auto cell_to_forget_weights_index{
603 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)};
604 const auto cell_to_output_weights_index{
605 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)};
606 const auto input_gate_bias_index{
607 node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)};
608 const auto forget_gate_bias_index{
609 node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
610 const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
611 const auto output_gate_bias_index{
612 node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
613 const auto projection_weights_index{
614 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)};
615 const auto projection_bias_index{
616 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)};
617 const auto output_state_in_index{
618 node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
619 const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
621 OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().rank() == 2 &&
622 _ctx.at(output_state_out_index).shape().rank() == 2 &&
623 _ctx.at(cell_state_out_index).shape().rank() == 2 &&
624 _ctx.at(output_index).shape().rank() == 2 &&
625 _ctx.at(input_index).shape().rank() == 2 &&
626 _ctx.at(input_to_input_weights_index).shape().rank() == 2 &&
627 _ctx.at(input_to_forget_weights_index).shape().rank() == 2 &&
628 _ctx.at(input_to_cell_weights_index).shape().rank() == 2 &&
629 _ctx.at(input_to_output_weights_index).shape().rank() == 2 &&
630 _ctx.at(recurrent_to_input_weights_index).shape().rank() == 2 &&
631 _ctx.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
632 _ctx.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
633 _ctx.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
634 _ctx.at(projection_weights_index).shape().rank() == 2 &&
635 _ctx.at(output_state_in_index).shape().rank() == 2 &&
636 _ctx.at(cell_state_in_index).shape().rank() == 2);
638 OP_REQUIRES(_ctx.at(cell_to_input_weights_index).shape().rank() == 1 &&
639 _ctx.at(cell_to_forget_weights_index).shape().rank() == 1 &&
640 _ctx.at(cell_to_output_weights_index).shape().rank() == 1 &&
641 _ctx.at(input_gate_bias_index).shape().rank() == 1 &&
642 _ctx.at(forget_gate_bias_index).shape().rank() == 1 &&
643 _ctx.at(cell_bias_index).shape().rank() == 1 &&
644 _ctx.at(output_gate_bias_index).shape().rank() == 1 &&
645 _ctx.at(projection_bias_index).shape().rank() == 1);
648 OP_REQUIRES((_ctx.at(input_to_input_weights_index).shape().dim(0) == 0 &&
649 _ctx.at(input_to_input_weights_index).shape().dim(1) == 0 &&
650 _ctx.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
651 _ctx.at(recurrent_to_input_weights_index).shape().dim(1) == 0 &&
652 _ctx.at(input_gate_bias_index).shape().dim(0) == 0 &&
653 _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0) ||
654 (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
655 _ctx.at(input_to_input_weights_index).shape().dim(1) != 0 &&
656 _ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
657 _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0 &&
658 _ctx.at(input_gate_bias_index).shape().dim(0) != 0));
660 // Peephole assertion
661 OP_REQUIRES((_ctx.at(cell_to_forget_weights_index).shape().dim(0) == 0 &&
662 _ctx.at(cell_to_output_weights_index).shape().dim(0) == 0) ||
663 (_ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0 &&
664 _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0));
666 bool has_input_to_input_weights = _ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
667 _ctx.at(input_to_input_weights_index).shape().dim(1) != 0;
668 bool has_recurrent_to_input_weights =
669 _ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
670 _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0;
671 bool has_input_gate_bias = _ctx.at(input_gate_bias_index).shape().dim(0) != 0;
672 bool has_cell_to_input_weights = _ctx.at(cell_to_input_weights_index).shape().dim(0) != 0;
673 bool has_cell_to_forget_weights = _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0;
674 bool has_cell_to_output_weights = _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0;
675 bool has_projection_weights = _ctx.at(projection_weights_index).shape().dim(0) != 0 &&
676 _ctx.at(projection_weights_index).shape().dim(1) != 0;
677 bool has_projection_bias = _ctx.at(projection_bias_index).shape().dim(0);
679 // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
682 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
684 // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole.
686 // false: no peephole
687 bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
689 // NOTE The projection weights may have data but the projection bias may not.
690 bool has_projection_param = has_projection_weights;
692 const auto batch_size = _ctx.at(input_index).shape().dim(0);
693 OP_REQUIRES(batch_size == _ctx.at(output_state_in_index).shape().dim(0) &&
694 batch_size == _ctx.at(cell_state_in_index).shape().dim(0) &&
695 batch_size == _ctx.at(scratch_buffer_index).shape().dim(0) &&
696 batch_size == _ctx.at(output_state_out_index).shape().dim(0) &&
697 batch_size == _ctx.at(cell_state_out_index).shape().dim(0) &&
698 batch_size == _ctx.at(output_index).shape().dim(0));
700 const auto input_size = _ctx.at(input_index).shape().dim(1);
701 OP_REQUIRES(input_size == _ctx.at(input_to_forget_weights_index).shape().dim(1) &&
702 input_size == _ctx.at(input_to_cell_weights_index).shape().dim(1) &&
703 input_size == _ctx.at(input_to_output_weights_index).shape().dim(1));
705 const auto num_units = _ctx.at(cell_state_out_index).shape().dim(1);
706 OP_REQUIRES(num_units == _ctx.at(input_to_forget_weights_index).shape().dim(0) &&
707 num_units == _ctx.at(input_to_cell_weights_index).shape().dim(0) &&
708 num_units == _ctx.at(input_to_output_weights_index).shape().dim(0) &&
709 num_units == _ctx.at(recurrent_to_forget_weights_index).shape().dim(0) &&
710 num_units == _ctx.at(recurrent_to_cell_weights_index).shape().dim(0) &&
711 num_units == _ctx.at(recurrent_to_output_weights_index).shape().dim(0) &&
712 num_units == _ctx.at(forget_gate_bias_index).shape().dim(0) &&
713 num_units == _ctx.at(cell_bias_index).shape().dim(0) &&
714 num_units == _ctx.at(output_gate_bias_index).shape().dim(0) &&
715 num_units == _ctx.at(cell_state_in_index).shape().dim(1) &&
716 (((num_units * 3) == _ctx.at(scratch_buffer_index).shape().dim(1)) ||
717 ((num_units * 4) == _ctx.at(scratch_buffer_index).shape().dim(1))));
719 const auto output_size = _ctx.at(output_index).shape().dim(1);
720 OP_REQUIRES(output_size == _ctx.at(recurrent_to_forget_weights_index).shape().dim(1) &&
721 output_size == _ctx.at(recurrent_to_cell_weights_index).shape().dim(1) &&
722 output_size == _ctx.at(recurrent_to_output_weights_index).shape().dim(1) &&
723 output_size == _ctx.at(output_state_in_index).shape().dim(1) &&
724 output_size == _ctx.at(output_state_out_index).shape().dim(1));
728 OP_REQUIRES(input_size == _ctx.at(input_to_input_weights_index).shape().dim(1));
729 OP_REQUIRES(num_units == _ctx.at(input_to_input_weights_index).shape().dim(0) &&
730 num_units == _ctx.at(recurrent_to_input_weights_index).shape().dim(0) &&
731 (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) ||
732 _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* non-peephole */) &&
733 num_units == _ctx.at(input_gate_bias_index).shape().dim(0));
734 OP_REQUIRES(output_size == _ctx.at(recurrent_to_input_weights_index).shape().dim(1));
735 OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
736 has_input_gate_bias);
737 if (has_cell_to_input_weights)
739 // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole.
740 OP_REQUIRES(has_peephole_param);
742 OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
746 OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
749 if (has_peephole_param)
751 OP_REQUIRES(num_units == _ctx.at(cell_to_forget_weights_index).shape().dim(0) &&
752 num_units == _ctx.at(cell_to_output_weights_index).shape().dim(0) &&
753 (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) ||
754 _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
757 if (has_projection_param)
759 OP_REQUIRES(num_units == _ctx.at(projection_weights_index).shape().dim(1));
760 OP_REQUIRES(output_size == _ctx.at(projection_weights_index).shape().dim(0));
761 if (has_projection_bias)
763 OP_REQUIRES(output_size == _ctx.at(projection_bias_index).shape().dim(0));
768 void OperationValidator::visit(const ir::operation::L2Normalization &node)
770 const auto ofm_index{node.getOutputs().at(0)};
771 if (_ctx.at(ofm_index).info().isDynamic())
774 const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
776 auto ifm_shape = _ctx.at(ifm_index).shape();
777 auto ofm_shape = _ctx.at(ofm_index).shape();
779 OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
781 for (auto i = 0; i < ifm_shape.rank(); i++)
783 OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
787 void OperationValidator::visit(const ir::operation::Unpack &node)
789 const auto num{node.param().num};
790 OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size()));
791 const auto axis{node.param().axis};
793 const auto output_index{node.getInputs().at(0)};
794 if (_ctx.at(output_index).info().isDynamic())
797 const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
799 const auto &input_shape = _ctx.at(input_index).shape();
800 const auto input_rank = static_cast<int32_t>(input_shape.rank());
802 OP_REQUIRES(axis >= -input_rank && axis < input_rank);
805 void OperationValidator::visit(const ir::operation::Pad &node)
807 const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
808 OP_REQUIRES(_ctx.at(pad_index).typeInfo().type() == ir::DataType::INT32);
810 const auto output_index{node.getInputs().at(0)};
811 if (_ctx.at(output_index).info().isDynamic())
814 const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
816 const auto &pad_shape = _ctx.at(pad_index).shape();
817 const auto input_rank = static_cast<int32_t>(_ctx.at(input_index).shape().rank());
819 OP_REQUIRES(pad_shape.rank() == 2);
820 OP_REQUIRES(pad_shape.dim(0) == input_rank);
821 OP_REQUIRES(pad_shape.dim(1) == 2);
822 OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
825 void OperationValidator::visit(const ir::operation::Min &node)
827 const auto output_index{node.getOutputs().at(0)};
828 // This validator does not check shape. So checking isDynamic() is skipped.
830 const auto lhs_index{node.getInputs().at(ir::operation::Min::Input::LHS)};
831 const auto rhs_index{node.getInputs().at(ir::operation::Min::Input::RHS)};
833 OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());
834 OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(output_index).typeInfo().type());
837 void OperationValidator::visit(const ir::operation::Max &node)
839 const auto output_index{node.getOutputs().at(0)};
840 // This validator does not check shape. So checking isDynamic() is skipped.
842 const auto lhs_index{node.getInputs().at(ir::operation::Max::Input::LHS)};
843 const auto rhs_index{node.getInputs().at(ir::operation::Max::Input::RHS)};
845 OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());
846 OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(output_index).typeInfo().type());
849 void OperationValidator::visit(const ir::operation::Select &node)
851 const auto output_index{node.getOutputs().at(0)};
852 // This validator does not check shape. So checking isDynamic() is skipped.
854 const auto condition_index{node.getInputs().at(ir::operation::Select::Input::CONDITION)};
855 const auto input_true_index{node.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
856 const auto input_false_index{node.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
857 UNUSED_RELEASE(output_index);
858 UNUSED_RELEASE(input_true_index);
859 UNUSED_RELEASE(input_false_index);
861 OP_REQUIRES(_ctx.at(condition_index).typeInfo().type() == ir::DataType::BOOL8);
864 void OperationValidator::visit(const ir::operation::StridedSlice &node)
866 const auto output_index{node.getOutputs().at(0)};
867 const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
868 const auto starts_index{node.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
869 const auto ends_index{node.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
870 const auto strides_index{node.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
872 UNUSED_RELEASE(starts_index);
873 UNUSED_RELEASE(ends_index);
874 UNUSED_RELEASE(strides_index);
876 OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
878 if (_ctx.at(output_index).info().isDynamic())
881 OP_REQUIRES(_ctx.at(input_index).shape().rank() <= 4);
884 void OperationValidator::visit(const ir::operation::Split &node)
886 const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
888 if (_ctx.at(input_index).info().isDynamic())
891 const auto num_splits = node.param().num_splits;
892 const auto input_rank = _ctx.at(input_index).shape().rank();
893 const auto axis = node.param().axis < 0 ? node.param().axis + input_rank : node.param().axis;
895 OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF);
896 OP_REQUIRES(axis >= 0 && axis < input_rank);
897 OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits));
899 OP_REQUIRES(_ctx.at(input_index).shape().dim(axis) % num_splits == 0);
// Element-wise unary operations: delegate to the shared unary check
// (matching I/O element types; matching I/O shapes when the output is static).
void OperationValidator::visit(const ir::operation::Cos &node) { checkUnaryOp(node); }

void OperationValidator::visit(const ir::operation::Sin &node) { checkUnaryOp(node); }

void OperationValidator::visit(const ir::operation::RSQRT &node) { checkUnaryOp(node); }
908 void OperationValidator::visit(const ir::operation::Shape &node)
910 const auto output_index{node.getOutputs().at(0)};
911 if (_ctx.at(output_index).info().isDynamic())
914 const auto input_index{node.getInputs().at(0)};
915 UNUSED_RELEASE(input_index);
916 OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1);
919 void OperationValidator::visit(const ir::operation::ResizeBilinear &node)
921 const auto output_index{node.getOutputs().at(0)};
922 const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
924 if (_ctx.at(output_index).info().isDynamic())
928 OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
929 OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
931 auto align_corners = node.param().align_corners;
932 auto half_pixel_centers = node.param().half_pixel_centers;
934 OP_REQUIRES(!align_corners || !half_pixel_centers);
937 void OperationValidator::visit(const ir::operation::Reverse &node)
939 const auto output_index{node.getOutputs().at(0)};
940 const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
941 const auto axis_index{node.getInputs().at(ir::operation::Reverse::Input::AXIS)};
943 OP_REQUIRES(_ctx.at(axis_index).typeInfo().type() == ir::DataType::INT32);
944 OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
946 if (_ctx.at(output_index).info().isDynamic())
948 OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
// If is intentionally left unvalidated for now; proper validation requires
// inspecting the referenced subgraphs, which this validator does not do yet.
void OperationValidator::visit(const ir::operation::If &)
  // TODO Add to validate with subgraphs
956 void OperationValidator::visit(const ir::operation::While &node)
958 // This validator does not check shape. So checking isDynamic() is skipped.
960 OP_REQUIRES(node.getInputs().size() == node.getOutputs().size());
961 // TODO Add to validate with subgraphs
// Element-wise unary operations: delegate to the shared unary check
// (matching I/O element types; matching I/O shapes when the output is static).
void OperationValidator::visit(const ir::operation::Neg &node) { checkUnaryOp(node); }

void OperationValidator::visit(const ir::operation::Log &node) { checkUnaryOp(node); }

void OperationValidator::visit(const ir::operation::LogicalNot &node) { checkUnaryOp(node); }
// Validates SquaredDifference: the two inputs and the output must share an
// element type, and the input shapes must be broadcast-compatible with the
// output shape (numpy-style, dimensions aligned from the trailing end).
void OperationValidator::visit(const ir::operation::SquaredDifference &node)
  const auto output_index{node.getOutputs().at(0)};
  const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
  const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};

  // Check for Type equivalence
  OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(lhs_index).typeInfo().type());
  OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());

  // Check for dimension constraints
  // Dynamic outputs cannot be shape-checked statically.
  if (_ctx.at(output_index).info().isDynamic())

  auto output_shape = _ctx.at(output_index).shape();
  auto lhs_shape = _ctx.at(lhs_index).shape();
  auto rhs_shape = _ctx.at(rhs_index).shape();
  // Check for output rank: broadcasting yields the higher of the two ranks.
  OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
  auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());

  // Walk the trailing min_rank dimensions of both inputs (idx counts from
  // the back, so idx == 1 is the innermost dimension of each shape).
  for (int idx = 1; idx <= min_rank; idx++)
    int l_idx = lhs_shape.rank() - idx;
    int r_idx = rhs_shape.rank() - idx;
    int out_idx = output_shape.rank() - idx;

    OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0));

    auto l_dims = lhs_shape.dim(l_idx);
    auto r_dims = rhs_shape.dim(r_idx);
    auto out_dims = output_shape.dim(out_idx);

    // Per-dimension broadcast rule: dims are equal, or the size-1 side
    // broadcasts against the other, and the output takes the larger size.
    OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) ||
                ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims)));

  // The leading (non-overlapping) dimensions come solely from whichever
  // input has the higher rank; they must match the output exactly.
  auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape;
  for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++)
    int out_idx = output_shape.rank() - idx;
    int tmp_idx = tmp_shape.rank() - idx;

    OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) &&
                (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx)));
1016 void OperationValidator::visit(const ir::operation::Tile &node)
1018 const auto output_index{node.getOutputs().at(0)};
1019 if (_ctx.at(output_index).info().isDynamic())
1022 const auto input_index{node.getInputs().at(0)};
1023 const auto multiple_index{node.getInputs().at(1)};
1025 OP_REQUIRES(_ctx.at(multiple_index).shape().rank() == 1);
1026 OP_REQUIRES(_ctx.at(multiple_index).shape().dim(0) == _ctx.at(input_index).shape().rank());
1027 OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
1030 void OperationValidator::visit(const ir::operation::LogicalOr &node)
1032 const auto output_index{node.getOutputs().at(0)};
1033 const auto lhs_index{node.getInputs().at(0)};
1034 const auto rhs_index{node.getInputs().at(1)};
1036 OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());
1037 OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(output_index).typeInfo().type());
1040 void OperationValidator::visit(const ir::operation::Range &node)
1042 const auto output_index{node.getOutputs().at(0)};
1043 const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
1044 const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
1045 const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
1047 // Check for dimension constraints
1048 if (_ctx.at(output_index).info().isDynamic())
1051 OP_REQUIRES(_ctx.at(start_index).shape().rank() == 0);
1052 OP_REQUIRES(_ctx.at(limit_index).shape().rank() == 0);
1053 OP_REQUIRES(_ctx.at(delta_index).shape().rank() == 0);
// Validates MatrixBandPart: the input must be at least a 2-D matrix and the
// two diagonal-count operands must be scalars.
void OperationValidator::visit(const ir::operation::MatrixBandPart &node)
  const auto output_index{node.getOutputs().at(0)};
  const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)};
  const auto num_lower_index{
      node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_LOWER_DIAG)};
  const auto num_upper_index{
      node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)};

  // Check for dimension constraints (skipped for a dynamic output).
  if (_ctx.at(output_index).info().isDynamic())

  // NOTE The two trailing comments were previously swapped with respect to
  // the operands they annotate; the checks themselves were already correct.
  OP_REQUIRES(_ctx.at(input_index).shape().rank() >= 2);     // input must be rank 2 or higher
  OP_REQUIRES(_ctx.at(num_upper_index).shape().rank() == 0); // num_upper must be scalar
  OP_REQUIRES(_ctx.at(num_lower_index).shape().rank() == 0); // num_lower must be scalar
1074 void OperationValidator::visit(const ir::operation::LogSoftmax &node)
1076 VERBOSE(LogSoftmax) << "Configure LOGSOFTMAX operation" << std::endl;
1078 const auto output_index{node.getOutputs().at(0)};
1079 if (_ctx.at(output_index).info().isDynamic())
1082 const auto input_index{node.getInputs().at(0)};
1084 OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
1087 void OperationValidator::visit(const ir::operation::Quantize &node)
1089 VERBOSE(Quantize) << "Configure Quantize operation" << std::endl;
1091 OP_REQUIRES(node.getInputs().size() == 1);
1092 OP_REQUIRES(node.getOutputs().size() == 1);
1094 const auto input_index{node.getInputs().at(0)};
1095 const auto output_index{node.getOutputs().at(0)};
1097 OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::FLOAT32);
1099 if (_ctx.at(output_index).info().isDynamic())
1102 OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM);
1104 OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
1106 } // namespace compiler
1107 } // namespace onert