Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / OperationValidator.cc
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "OperationValidator.h"
18
19 #include <typeinfo>
20
21 #include "ir/Graph.h"
22 #include "ir/operation/LowerInfo.h"
23
24 #include "util/logging.h"
25 #include "util/Utils.h"
26
27 #define OP_REQUIRES(EXP)                                                                         \
28   do                                                                                             \
29   {                                                                                              \
30     if (!(EXP))                                                                                  \
31       throw std::runtime_error("OperationValidator failed at line " + std::to_string(__LINE__)); \
32   } while (0)
33
34 namespace onert
35 {
36 namespace compiler
37 {
38
39 OperationValidator::OperationValidator(const ir::Graph &graph)
40     : _graph{graph}, _ctx{graph.operands()}, _current_op_seq_layout{ir::Layout::UNKNOWN}
41 {
42 }
43
44 void OperationValidator::checkUnaryOp(const ir::Operation &node)
45 {
46   const auto output_index{node.getOutputs().at(0)};
47   const auto input_index{node.getInputs().at(0)};
48
49   // Check if I/O types match
50   OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
51
52   if (_ctx.at(output_index).info().isDynamic())
53     return;
54
55   // Check if I/O shapes match
56   OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
57 }
58
59 void OperationValidator::operator()()
60 {
61   // There is no reason for each subgraph to have subgraphs since compiler has subgraphs when
62   // creating Compiler
63   assert(_graph.subgraphs() == nullptr);
64
65   _current_op_seq_layout = _graph.layout();
66
67   _graph.operations().iterate(
68       [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); });
69 }
70
71 void OperationValidator::visit(const ir::operation::Abs &node) { checkUnaryOp(node); }
72
73 void OperationValidator::visit(const ir::operation::AvgPool2D &node)
74 {
75   const auto ofm_index{node.getOutputs().at(0)};
76   if (_ctx.at(ofm_index).info().isDynamic())
77     return;
78
79   const auto ifm_index{node.getInputs().at(ir::operation::AvgPool2D::Input::INPUT)};
80
81   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
82 }
83
84 void OperationValidator::visit(const ir::operation::BatchMatMul &node)
85 {
86   const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
87   const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
88   const auto out_index{node.getOutputs().at(0)};
89
90   // Constant lhs and rhs is not implemented yet
91   OP_REQUIRES(!_ctx.at(lhs_index).isConstant() && !_ctx.at(rhs_index).isConstant());
92
93   if (_ctx.at(out_index).info().isDynamic())
94     return;
95
96   OP_REQUIRES(_ctx.at(lhs_index).shape().rank() <= 4);
97   OP_REQUIRES(_ctx.at(rhs_index).shape().rank() <= 4);
98   OP_REQUIRES(_ctx.at(lhs_index).shape().rank() >= 2);
99   OP_REQUIRES(_ctx.at(rhs_index).shape().rank() >= 2);
100 }
101
102 void OperationValidator::visit(const ir::operation::BatchToSpaceND &node)
103 {
104   const auto ofm_index{node.getOutputs().at(0)};
105   if (_ctx.at(ofm_index).info().isDynamic())
106     return;
107
108   const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
109   const auto block_size_index{
110       node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
111
112   const auto frontend_layout = _current_op_seq_layout;
113   const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
114   const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
115
116   // All requirement as per NNAPI specification.
117   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
118   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
119   OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
120
121   OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
122
123   OP_REQUIRES(_ctx.at(block_size_index).isConstant());
124
125   OP_REQUIRES(input_shape.C == output_shape.C);
126 }
127
128 void OperationValidator::visit(const ir::operation::Cast &node)
129 {
130   const auto output_index{node.getOutputs().at(0)};
131   if (_ctx.at(output_index).info().isDynamic())
132     return;
133
134   const auto input_index{node.getInputs().at(0)};
135
136   OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
137 }
138
139 void OperationValidator::visit(const ir::operation::Comparison &node)
140 {
141   const auto output_index{node.getOutputs().at(0)};
142   // This validator does not check shape. So checking isDynamic() is skipped.
143
144   const auto lhs_index{node.getInputs().at(ir::operation::Comparison::Input::INPUT0)};
145   const auto rhs_index{node.getInputs().at(ir::operation::Comparison::Input::INPUT1)};
146
147   OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());
148   OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::BOOL8);
149 }
150
151 void OperationValidator::visit(const ir::operation::Softmax &node)
152 {
153   VERBOSE(Softmax) << "Configure SOFTMAX operation" << std::endl;
154
155   const auto output_index{node.getOutputs().at(0)};
156   if (_ctx.at(output_index).info().isDynamic())
157     return;
158
159   const auto input_index{node.getInputs().at(0)};
160
161   OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
162 }
163
164 void OperationValidator::visit(const ir::operation::InstanceNorm &node)
165 {
166   const auto ofm_index{node.getOutputs().at(0)};
167   if (_ctx.at(ofm_index).info().isDynamic())
168     return;
169
170   const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
171   const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
172   const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
173
174   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
175   OP_REQUIRES(_ctx.at(ifm_index).shape() == _ctx.at(ofm_index).shape());
176   OP_REQUIRES(_ctx.at(gamma_index).shape().rank() == 1);
177   OP_REQUIRES(_ctx.at(beta_index).shape().rank() == 1);
178 }
179
180 void OperationValidator::visit(const ir::operation::Permute &node)
181 {
182   VERBOSE(Permute) << "Configure Permute operation" << std::endl;
183
184   const auto output_index{node.getOutputs().at(0)};
185   if (_ctx.at(output_index).info().isDynamic())
186     return;
187
188   const auto input_index{node.getInputs().at(0)};
189
190   OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
191 }
192
193 void OperationValidator::visit(const ir::operation::Reduce &node)
194 {
195   VERBOSE(Permute) << "Configure " + node.name() + " operation" << std::endl;
196
197   const auto output_index{node.getOutputs().at(0)};
198   if (_ctx.at(output_index).info().isDynamic())
199     return;
200
201   const auto input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
202   const auto input_shape = _ctx.at(input_index).shape();
203   const auto output_shape = _ctx.at(output_index).shape();
204
205   OP_REQUIRES(input_shape.rank() <= 4);
206   OP_REQUIRES(output_shape.rank() <= input_shape.rank());
207
208   // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only
209   // supports cases reducing height and width or reducing depth.
210   // TODO We have to support all cases of dimensions up to 4.
211   // For correct permuting, we have to set output's shape to be equal in dimension position of the
212   // input. But the positions of the same dimensions in the input and output may be set differently.
213   // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original
214   // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to
215   // extend it in 4 dimensions, it should be {1,1,3,5}.
216   // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of
217   // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the
218   // next operation is not desired.
219   if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank())
220   {
221     if (output_shape.rank() == 2)
222     {
223       // Reducing HW
224       OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) &&
225                   input_shape.dim(3) == output_shape.dim(1));
226     }
227     else if (output_shape.rank() == 3)
228     {
229       // Reducing C or
230       // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1)
231       OP_REQUIRES((input_shape.dim(0) == output_shape.dim(0) &&
232                    input_shape.dim(1) == output_shape.dim(1) &&
233                    input_shape.dim(2) == output_shape.dim(2)) ||
234                   (input_shape.dim(0) == output_shape.dim(0) &&
235                    (input_shape.dim(1) == output_shape.dim(1) ||
236                     input_shape.dim(2) == output_shape.dim(1)) &&
237                    input_shape.dim(3) == 1 && output_shape.dim(2) == 1));
238     }
239   }
240 }
241
242 void OperationValidator::visit(const ir::operation::Transpose &node)
243 {
244   const auto output_index{node.getOutputs().at(0)};
245   if (_ctx.at(output_index).info().isDynamic())
246     return;
247
248   const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
249   const auto &perm{node.param().perm};
250
251   const auto &output_shape = _ctx.at(output_index).shape();
252   const auto &input_shape = _ctx.at(input_index).shape();
253
254   OP_REQUIRES(input_shape.rank() == static_cast<int>(perm.size()));
255   OP_REQUIRES(input_shape.rank() == output_shape.rank());
256 }
257
258 void OperationValidator::visit(const ir::operation::RNN &node)
259 {
260   // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
261   // TODO Support dynamic rnn
262   const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
263   if (_ctx.at(output_index).info().isDynamic())
264     return;
265
266   const auto hidden_state_out_index{
267       node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)};
268
269   const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)};
270   const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)};
271   const auto recurrent_weights_index{
272       node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)};
273   const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
274   const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
275
276   const auto batch_size = _ctx.at(output_index).shape().dim(0);
277   const auto num_units = _ctx.at(output_index).shape().dim(1);
278
279   OP_REQUIRES(_ctx.at(output_index).shape().rank() == 2 &&
280               _ctx.at(hidden_state_out_index).shape().rank() == 2 &&
281               _ctx.at(input_index).shape().rank() == 2 &&
282               _ctx.at(weights_index).shape().rank() == 2 &&
283               _ctx.at(recurrent_weights_index).shape().rank() == 2 &&
284               _ctx.at(hidden_state_in_index).shape().rank() == 2);
285   OP_REQUIRES(_ctx.at(bias_index).shape().rank() == 1);
286
287   OP_REQUIRES(batch_size == _ctx.at(input_index).shape().dim(0) &&
288               batch_size == _ctx.at(hidden_state_in_index).shape().dim(0) &&
289               batch_size == _ctx.at(hidden_state_out_index).shape().dim(0));
290   OP_REQUIRES(_ctx.at(input_index).shape().dim(1) == _ctx.at(weights_index).shape().dim(1));
291
292   OP_REQUIRES(num_units == _ctx.at(weights_index).shape().dim(0) &&
293               num_units == _ctx.at(recurrent_weights_index).shape().dim(0) &&
294               num_units == _ctx.at(bias_index).shape().dim(0));
295   OP_REQUIRES(num_units == _ctx.at(output_index).shape().dim(1) &&
296               num_units == _ctx.at(recurrent_weights_index).shape().dim(1) &&
297               num_units == _ctx.at(hidden_state_in_index).shape().dim(1) &&
298               num_units == _ctx.at(hidden_state_out_index).shape().dim(1));
299 }
300
301 void OperationValidator::visit(const ir::operation::Round &node) { checkUnaryOp(node); }
302
303 void OperationValidator::visit(const ir::operation::SpaceToBatchND &node)
304 {
305   const auto ofm_index{node.getOutputs().at(0)};
306   if (_ctx.at(ofm_index).info().isDynamic())
307     return;
308
309   const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
310   const auto block_size_index{
311       node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
312   const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
313
314   const auto frontend_layout = _current_op_seq_layout;
315   const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
316   const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
317
318   // All requirement as per NNAPI specification.
319   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
320   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
321   OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
322   OP_REQUIRES(_ctx.at(paddings_index).shape().rank() == 2);
323
324   OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
325   OP_REQUIRES(_ctx.at(paddings_index).shape().dim(0) == 2);
326   OP_REQUIRES(_ctx.at(paddings_index).shape().dim(1) == 2);
327
328   OP_REQUIRES(_ctx.at(block_size_index).isConstant());
329   OP_REQUIRES(_ctx.at(paddings_index).isConstant());
330
331   OP_REQUIRES(input_shape.C == output_shape.C);
332 }
333
334 void OperationValidator::visit(const ir::operation::SpaceToDepth &node)
335 {
336   const auto ofm_index{node.getOutputs().at(0)};
337   if (_ctx.at(ofm_index).info().isDynamic())
338     return;
339
340   const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
341
342   const auto frontend_layout = _current_op_seq_layout;
343   const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
344   const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
345   const auto block_size = node.param().block_size;
346
347   // All assertions as per NNAPI specification.
348   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
349   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
350   OP_REQUIRES((block_size >= 1) && (input_shape.H % block_size == 0) &&
351               (input_shape.W % block_size == 0));
352   OP_REQUIRES(input_shape.N == output_shape.N);
353   OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
354 }
355
356 void OperationValidator::visit(const ir::operation::EmbeddingLookup &node)
357 {
358   const auto output_index{node.getOutputs().at(0)};
359   const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
360   const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
361
362   const auto &output_obj = _ctx.at(output_index);
363   const auto &lookups_obj = _ctx.at(lookups_index);
364   const auto &values_obj = _ctx.at(values_index);
365
366   // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying
367   // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
368   {
369     OP_REQUIRES(lookups_obj.typeInfo().type() == ir::DataType::INT32);
370
371     if (_ctx.at(output_index).info().isDynamic())
372       return;
373
374     const auto &output_shape = output_obj.shape();
375     const auto &lookups_shape = lookups_obj.shape();
376     const auto &values_shape = values_obj.shape();
377
378     OP_REQUIRES(lookups_shape.rank() == 1);
379     OP_REQUIRES(values_shape.rank() >= 2);
380
381     // output should be a n-D tensor with the same rank and shape as the values tensor, except for
382     // the first dimension which has the same size as lookups' only dimension.
383     OP_REQUIRES(output_shape.rank() == values_shape.rank());
384     OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0));
385     for (int n = 1; n < output_shape.rank(); ++n)
386     {
387       OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n));
388     }
389   }
390 }
391
392 void OperationValidator::visit(const ir::operation::Exp &node) { checkUnaryOp(node); }
393
394 void OperationValidator::visit(const ir::operation::ExpandDims &node)
395 {
396   const auto output_index{node.getOutputs().at(0)};
397   const auto input_index{node.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
398   const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
399
400   OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
401   OP_REQUIRES(_ctx.at(axis_index).typeInfo().type() == ir::DataType::INT32);
402
403   if (_ctx.at(axis_index).info().isDynamic())
404     return;
405   OP_REQUIRES(_ctx.at(axis_index).shape().rank() <= 1);
406 }
407
408 void OperationValidator::visit(const ir::operation::Floor &node) { checkUnaryOp(node); }
409
410 void OperationValidator::visit(const ir::operation::HashtableLookup &node)
411 {
412   const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
413   const auto hits_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::HITS)};
414
415   const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
416   const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
417   const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
418
419   const auto &output_obj = _ctx.at(output_index);
420   const auto &hits_obj = _ctx.at(hits_index);
421
422   const auto &lookups_obj = _ctx.at(lookups_index);
423   const auto &keys_obj = _ctx.at(keys_index);
424   const auto &values_obj = _ctx.at(values_index);
425
426   OP_REQUIRES(lookups_obj.typeInfo().type() == ir::DataType::INT32);
427   OP_REQUIRES(keys_obj.typeInfo().type() == ir::DataType::INT32);
428   OP_REQUIRES(hits_obj.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM);
429
430   if (_ctx.at(output_index).info().isDynamic())
431     return;
432
433   const auto &output_shape = output_obj.shape();
434   const auto &lookups_shape = lookups_obj.shape();
435   const auto &keys_shape = keys_obj.shape();
436   const auto &values_shape = values_obj.shape();
437
438   OP_REQUIRES(values_shape.rank() == output_shape.rank());
439   OP_REQUIRES(lookups_shape.rank() == 1);
440   OP_REQUIRES(keys_shape.rank() == 1);
441   OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0));
442   OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0));
443 }
444
445 void OperationValidator::visit(const ir::operation::TransposeConv &node)
446 {
447   // param check
448   OP_REQUIRES((node.param().padding.type == ir::PaddingType::SAME) ||
449               (node.param().padding.type == ir::PaddingType::VALID));
450
451   // shape check
452   const auto ofm_index{node.getOutputs().at(0)};
453   if (_ctx.at(ofm_index).info().isDynamic())
454     return;
455
456   const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
457   const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
458
459   // Only 4D tensors are supported
460   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
461   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ifm_index).shape().rank());
462   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ker_index).shape().rank());
463
464   const auto frontend_layout = _current_op_seq_layout;
465   const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
466   const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
467   // The kernel has only IHWO layout on frontend
468   // So ker_shape is treated here below
469   // I -> N
470   // H -> H
471   // W -> W
472   // O -> C
473   const auto ker_shape = _ctx.at(ker_index).shape().asFeature(ir::Layout::NHWC);
474
475   OP_REQUIRES(ifm_shape.N == ofm_shape.N);
476   OP_REQUIRES(ifm_shape.C == ker_shape.C);
477   OP_REQUIRES(ker_shape.N == ofm_shape.C);
478 }
479
480 void OperationValidator::visit(const ir::operation::Gather &node)
481 {
482   const auto ofm_index{node.getOutputs().at(0)};
483   if (_ctx.at(ofm_index).info().isDynamic())
484     return;
485
486   const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
487   const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
488
489   const auto ifm_shape = _ctx.at(ifm_index).shape();
490   const auto indices_shape = _ctx.at(indices_index).shape();
491   const auto ofm_shape = _ctx.at(ofm_index).shape();
492
493   OP_REQUIRES(ifm_shape.rank() <= 4);
494   OP_REQUIRES(indices_shape.rank() <= 3);
495   OP_REQUIRES(ofm_shape.rank() <= 4);
496 }
497
498 void OperationValidator::visit(const ir::operation::Dequantize &node)
499 {
500   const auto output_index{node.getOutputs().at(0)};
501
502   const auto input_index{node.getInputs().at(ir::operation::Dequantize::Input::INPUT)};
503
504   OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM);
505   OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::FLOAT32);
506
507   if (_ctx.at(output_index).info().isDynamic())
508     return;
509   OP_REQUIRES(_ctx.at(input_index).shape().rank() <= 4);
510   OP_REQUIRES(_ctx.at(input_index).shape() == _ctx.at(output_index).shape());
511 }
512
513 void OperationValidator::visit(const ir::operation::DepthToSpace &node)
514 {
515   // param check
516   int32_t block_size = node.param().block_size;
517
518   OP_REQUIRES(block_size > 0);
519
520   // shape check
521   const auto output_index{node.getOutputs().at(0)};
522   if (_ctx.at(output_index).info().isDynamic())
523     return;
524
525   const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
526
527   const auto frontend_layout = _current_op_seq_layout;
528   const auto output_shape = _ctx.at(output_index).shape().asFeature(frontend_layout);
529   const auto input_shape = _ctx.at(input_index).shape().asFeature(frontend_layout);
530
531   OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
532   OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
533
534   {
535     OP_REQUIRES(output_shape.N == input_shape.N);
536     OP_REQUIRES(output_shape.H == input_shape.H * block_size);
537     OP_REQUIRES(output_shape.W == input_shape.W * block_size);
538     OP_REQUIRES(input_shape.C % (block_size * block_size) == 0);
539     OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size));
540   }
541 }
542
543 void OperationValidator::visit(const ir::operation::Pack &node)
544 {
545   // param check
546   const auto num{node.param().num};
547   const auto axis{node.param().axis};
548   OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size()));
549
550   const auto output_index{node.getOutputs().at(0)};
551   if (_ctx.at(output_index).info().isDynamic())
552     return;
553
554   // shape check
555   const auto &output_shape = _ctx.at(output_index).shape();
556   const auto output_rank = static_cast<int32_t>(output_shape.rank());
557
558   const auto input1_index{node.getInputs().at(0)};
559   const auto input_shape = _ctx.at(input1_index).shape();
560
561   OP_REQUIRES(axis >= -output_rank && axis < output_rank);
562   for (const auto &index : node.getInputs())
563   {
564     OP_REQUIRES(input_shape == _ctx.at(index).shape());
565   }
566 }
567
568 void OperationValidator::visit(const ir::operation::LSTM &node)
569 {
570   // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
571   // TODO Support dynamic rnn
572   const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
573   if (_ctx.at(output_index).info().isDynamic())
574     return;
575
576   const auto scratch_buffer_index{
577       node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};
578   const auto output_state_out_index{
579       node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};
580   const auto cell_state_out_index{
581       node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)};
582
583   const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
584   const auto input_to_input_weights_index{
585       node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)};
586   const auto input_to_forget_weights_index{
587       node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
588   const auto input_to_cell_weights_index{
589       node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
590   const auto input_to_output_weights_index{
591       node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
592   const auto recurrent_to_input_weights_index{
593       node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)};
594   const auto recurrent_to_forget_weights_index{
595       node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
596   const auto recurrent_to_cell_weights_index{
597       node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
598   const auto recurrent_to_output_weights_index{
599       node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
600   const auto cell_to_input_weights_index{
601       node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)};
602   const auto cell_to_forget_weights_index{
603       node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)};
604   const auto cell_to_output_weights_index{
605       node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)};
606   const auto input_gate_bias_index{
607       node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)};
608   const auto forget_gate_bias_index{
609       node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
610   const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
611   const auto output_gate_bias_index{
612       node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
613   const auto projection_weights_index{
614       node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)};
615   const auto projection_bias_index{
616       node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)};
617   const auto output_state_in_index{
618       node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
619   const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
620
621   OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().rank() == 2 &&
622               _ctx.at(output_state_out_index).shape().rank() == 2 &&
623               _ctx.at(cell_state_out_index).shape().rank() == 2 &&
624               _ctx.at(output_index).shape().rank() == 2 &&
625               _ctx.at(input_index).shape().rank() == 2 &&
626               _ctx.at(input_to_input_weights_index).shape().rank() == 2 &&
627               _ctx.at(input_to_forget_weights_index).shape().rank() == 2 &&
628               _ctx.at(input_to_cell_weights_index).shape().rank() == 2 &&
629               _ctx.at(input_to_output_weights_index).shape().rank() == 2 &&
630               _ctx.at(recurrent_to_input_weights_index).shape().rank() == 2 &&
631               _ctx.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
632               _ctx.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
633               _ctx.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
634               _ctx.at(projection_weights_index).shape().rank() == 2 &&
635               _ctx.at(output_state_in_index).shape().rank() == 2 &&
636               _ctx.at(cell_state_in_index).shape().rank() == 2);
637
638   OP_REQUIRES(_ctx.at(cell_to_input_weights_index).shape().rank() == 1 &&
639               _ctx.at(cell_to_forget_weights_index).shape().rank() == 1 &&
640               _ctx.at(cell_to_output_weights_index).shape().rank() == 1 &&
641               _ctx.at(input_gate_bias_index).shape().rank() == 1 &&
642               _ctx.at(forget_gate_bias_index).shape().rank() == 1 &&
643               _ctx.at(cell_bias_index).shape().rank() == 1 &&
644               _ctx.at(output_gate_bias_index).shape().rank() == 1 &&
645               _ctx.at(projection_bias_index).shape().rank() == 1);
646
647   // CIFG assertion
648   OP_REQUIRES((_ctx.at(input_to_input_weights_index).shape().dim(0) == 0 &&
649                _ctx.at(input_to_input_weights_index).shape().dim(1) == 0 &&
650                _ctx.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
651                _ctx.at(recurrent_to_input_weights_index).shape().dim(1) == 0 &&
652                _ctx.at(input_gate_bias_index).shape().dim(0) == 0 &&
653                _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0) ||
654               (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
655                _ctx.at(input_to_input_weights_index).shape().dim(1) != 0 &&
656                _ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
657                _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0 &&
658                _ctx.at(input_gate_bias_index).shape().dim(0) != 0));
659
660   // Peephole assertion
661   OP_REQUIRES((_ctx.at(cell_to_forget_weights_index).shape().dim(0) == 0 &&
662                _ctx.at(cell_to_output_weights_index).shape().dim(0) == 0) ||
663               (_ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0 &&
664                _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0));
665
666   bool has_input_to_input_weights = _ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
667                                     _ctx.at(input_to_input_weights_index).shape().dim(1) != 0;
668   bool has_recurrent_to_input_weights =
669       _ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
670       _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0;
671   bool has_input_gate_bias = _ctx.at(input_gate_bias_index).shape().dim(0) != 0;
672   bool has_cell_to_input_weights = _ctx.at(cell_to_input_weights_index).shape().dim(0) != 0;
673   bool has_cell_to_forget_weights = _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0;
674   bool has_cell_to_output_weights = _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0;
675   bool has_projection_weights = _ctx.at(projection_weights_index).shape().dim(0) != 0 &&
676                                 _ctx.at(projection_weights_index).shape().dim(1) != 0;
677   bool has_projection_bias = _ctx.at(projection_bias_index).shape().dim(0);
678
679   // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
680   // true: no CIFG
681   // false: CIFG
682   bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
683
684   // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole.
685   // true: peephole
686   // false: no peephole
687   bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
688
689   // NOTE The projection weights may have data but the projection bias may not.
690   bool has_projection_param = has_projection_weights;
691
692   const auto batch_size = _ctx.at(input_index).shape().dim(0);
693   OP_REQUIRES(batch_size == _ctx.at(output_state_in_index).shape().dim(0) &&
694               batch_size == _ctx.at(cell_state_in_index).shape().dim(0) &&
695               batch_size == _ctx.at(scratch_buffer_index).shape().dim(0) &&
696               batch_size == _ctx.at(output_state_out_index).shape().dim(0) &&
697               batch_size == _ctx.at(cell_state_out_index).shape().dim(0) &&
698               batch_size == _ctx.at(output_index).shape().dim(0));
699
700   const auto input_size = _ctx.at(input_index).shape().dim(1);
701   OP_REQUIRES(input_size == _ctx.at(input_to_forget_weights_index).shape().dim(1) &&
702               input_size == _ctx.at(input_to_cell_weights_index).shape().dim(1) &&
703               input_size == _ctx.at(input_to_output_weights_index).shape().dim(1));
704
705   const auto num_units = _ctx.at(cell_state_out_index).shape().dim(1);
706   OP_REQUIRES(num_units == _ctx.at(input_to_forget_weights_index).shape().dim(0) &&
707               num_units == _ctx.at(input_to_cell_weights_index).shape().dim(0) &&
708               num_units == _ctx.at(input_to_output_weights_index).shape().dim(0) &&
709               num_units == _ctx.at(recurrent_to_forget_weights_index).shape().dim(0) &&
710               num_units == _ctx.at(recurrent_to_cell_weights_index).shape().dim(0) &&
711               num_units == _ctx.at(recurrent_to_output_weights_index).shape().dim(0) &&
712               num_units == _ctx.at(forget_gate_bias_index).shape().dim(0) &&
713               num_units == _ctx.at(cell_bias_index).shape().dim(0) &&
714               num_units == _ctx.at(output_gate_bias_index).shape().dim(0) &&
715               num_units == _ctx.at(cell_state_in_index).shape().dim(1) &&
716               (((num_units * 3) == _ctx.at(scratch_buffer_index).shape().dim(1)) ||
717                ((num_units * 4) == _ctx.at(scratch_buffer_index).shape().dim(1))));
718
719   const auto output_size = _ctx.at(output_index).shape().dim(1);
720   OP_REQUIRES(output_size == _ctx.at(recurrent_to_forget_weights_index).shape().dim(1) &&
721               output_size == _ctx.at(recurrent_to_cell_weights_index).shape().dim(1) &&
722               output_size == _ctx.at(recurrent_to_output_weights_index).shape().dim(1) &&
723               output_size == _ctx.at(output_state_in_index).shape().dim(1) &&
724               output_size == _ctx.at(output_state_out_index).shape().dim(1));
725
726   if (has_cifg_param)
727   {
728     OP_REQUIRES(input_size == _ctx.at(input_to_input_weights_index).shape().dim(1));
729     OP_REQUIRES(num_units == _ctx.at(input_to_input_weights_index).shape().dim(0) &&
730                 num_units == _ctx.at(recurrent_to_input_weights_index).shape().dim(0) &&
731                 (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) ||
732                  _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* non-peephole */) &&
733                 num_units == _ctx.at(input_gate_bias_index).shape().dim(0));
734     OP_REQUIRES(output_size == _ctx.at(recurrent_to_input_weights_index).shape().dim(1));
735     OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
736                 has_input_gate_bias);
737     if (has_cell_to_input_weights)
738     {
739       // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole.
740       OP_REQUIRES(has_peephole_param);
741     }
742     OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
743   }
744   else
745   {
746     OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
747   }
748
749   if (has_peephole_param)
750   {
751     OP_REQUIRES(num_units == _ctx.at(cell_to_forget_weights_index).shape().dim(0) &&
752                 num_units == _ctx.at(cell_to_output_weights_index).shape().dim(0) &&
753                 (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) ||
754                  _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
755   }
756
757   if (has_projection_param)
758   {
759     OP_REQUIRES(num_units == _ctx.at(projection_weights_index).shape().dim(1));
760     OP_REQUIRES(output_size == _ctx.at(projection_weights_index).shape().dim(0));
761     if (has_projection_bias)
762     {
763       OP_REQUIRES(output_size == _ctx.at(projection_bias_index).shape().dim(0));
764     }
765   }
766 }
767
768 void OperationValidator::visit(const ir::operation::L2Normalization &node)
769 {
770   const auto ofm_index{node.getOutputs().at(0)};
771   if (_ctx.at(ofm_index).info().isDynamic())
772     return;
773
774   const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
775
776   auto ifm_shape = _ctx.at(ifm_index).shape();
777   auto ofm_shape = _ctx.at(ofm_index).shape();
778
779   OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
780
781   for (auto i = 0; i < ifm_shape.rank(); i++)
782   {
783     OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
784   }
785 }
786
787 void OperationValidator::visit(const ir::operation::Unpack &node)
788 {
789   const auto num{node.param().num};
790   OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size()));
791   const auto axis{node.param().axis};
792
793   const auto output_index{node.getInputs().at(0)};
794   if (_ctx.at(output_index).info().isDynamic())
795     return;
796
797   const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
798
799   const auto &input_shape = _ctx.at(input_index).shape();
800   const auto input_rank = static_cast<int32_t>(input_shape.rank());
801
802   OP_REQUIRES(axis >= -input_rank && axis < input_rank);
803 }
804
805 void OperationValidator::visit(const ir::operation::Pad &node)
806 {
807   const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
808   OP_REQUIRES(_ctx.at(pad_index).typeInfo().type() == ir::DataType::INT32);
809
810   const auto output_index{node.getInputs().at(0)};
811   if (_ctx.at(output_index).info().isDynamic())
812     return;
813
814   const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
815
816   const auto &pad_shape = _ctx.at(pad_index).shape();
817   const auto input_rank = static_cast<int32_t>(_ctx.at(input_index).shape().rank());
818
819   OP_REQUIRES(pad_shape.rank() == 2);
820   OP_REQUIRES(pad_shape.dim(0) == input_rank);
821   OP_REQUIRES(pad_shape.dim(1) == 2);
822   OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
823 }
824
825 void OperationValidator::visit(const ir::operation::Min &node)
826 {
827   const auto output_index{node.getOutputs().at(0)};
828   // This validator does not check shape. So checking isDynamic() is skipped.
829
830   const auto lhs_index{node.getInputs().at(ir::operation::Min::Input::LHS)};
831   const auto rhs_index{node.getInputs().at(ir::operation::Min::Input::RHS)};
832
833   OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());
834   OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(output_index).typeInfo().type());
835 }
836
837 void OperationValidator::visit(const ir::operation::Max &node)
838 {
839   const auto output_index{node.getOutputs().at(0)};
840   // This validator does not check shape. So checking isDynamic() is skipped.
841
842   const auto lhs_index{node.getInputs().at(ir::operation::Max::Input::LHS)};
843   const auto rhs_index{node.getInputs().at(ir::operation::Max::Input::RHS)};
844
845   OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());
846   OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(output_index).typeInfo().type());
847 }
848
849 void OperationValidator::visit(const ir::operation::Select &node)
850 {
851   const auto output_index{node.getOutputs().at(0)};
852   // This validator does not check shape. So checking isDynamic() is skipped.
853
854   const auto condition_index{node.getInputs().at(ir::operation::Select::Input::CONDITION)};
855   const auto input_true_index{node.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
856   const auto input_false_index{node.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
857   UNUSED_RELEASE(output_index);
858   UNUSED_RELEASE(input_true_index);
859   UNUSED_RELEASE(input_false_index);
860
861   OP_REQUIRES(_ctx.at(condition_index).typeInfo().type() == ir::DataType::BOOL8);
862 }
863
864 void OperationValidator::visit(const ir::operation::StridedSlice &node)
865 {
866   const auto output_index{node.getOutputs().at(0)};
867   const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
868   const auto starts_index{node.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
869   const auto ends_index{node.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
870   const auto strides_index{node.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
871
872   UNUSED_RELEASE(starts_index);
873   UNUSED_RELEASE(ends_index);
874   UNUSED_RELEASE(strides_index);
875
876   OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
877
878   if (_ctx.at(output_index).info().isDynamic())
879     return;
880
881   OP_REQUIRES(_ctx.at(input_index).shape().rank() <= 4);
882 }
883
884 void OperationValidator::visit(const ir::operation::Split &node)
885 {
886   const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
887
888   if (_ctx.at(input_index).info().isDynamic())
889     return;
890
891   const auto num_splits = node.param().num_splits;
892   const auto input_rank = _ctx.at(input_index).shape().rank();
893   const auto axis = node.param().axis < 0 ? node.param().axis + input_rank : node.param().axis;
894
895   OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF);
896   OP_REQUIRES(axis >= 0 && axis < input_rank);
897   OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits));
898
899   OP_REQUIRES(_ctx.at(input_index).shape().dim(axis) % num_splits == 0);
900 }
901
902 void OperationValidator::visit(const ir::operation::Cos &node) { checkUnaryOp(node); }
903
904 void OperationValidator::visit(const ir::operation::Sin &node) { checkUnaryOp(node); }
905
906 void OperationValidator::visit(const ir::operation::RSQRT &node) { checkUnaryOp(node); }
907
908 void OperationValidator::visit(const ir::operation::Shape &node)
909 {
910   const auto output_index{node.getOutputs().at(0)};
911   if (_ctx.at(output_index).info().isDynamic())
912     return;
913
914   const auto input_index{node.getInputs().at(0)};
915   UNUSED_RELEASE(input_index);
916   OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1);
917 }
918
919 void OperationValidator::visit(const ir::operation::ResizeBilinear &node)
920 {
921   const auto output_index{node.getOutputs().at(0)};
922   const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
923
924   if (_ctx.at(output_index).info().isDynamic())
925   {
926     return;
927   }
928   OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
929   OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
930
931   auto align_corners = node.param().align_corners;
932   auto half_pixel_centers = node.param().half_pixel_centers;
933
934   OP_REQUIRES(!align_corners || !half_pixel_centers);
935 }
936
937 void OperationValidator::visit(const ir::operation::Reverse &node)
938 {
939   const auto output_index{node.getOutputs().at(0)};
940   const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
941   const auto axis_index{node.getInputs().at(ir::operation::Reverse::Input::AXIS)};
942
943   OP_REQUIRES(_ctx.at(axis_index).typeInfo().type() == ir::DataType::INT32);
944   OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
945
946   if (_ctx.at(output_index).info().isDynamic())
947     return;
948   OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
949 }
950
951 void OperationValidator::visit(const ir::operation::If &)
952 {
953   // TODO Add to validate with subgraphs
954 }
955
956 void OperationValidator::visit(const ir::operation::While &node)
957 {
958   // This validator does not check shape. So checking isDynamic() is skipped.
959
960   OP_REQUIRES(node.getInputs().size() == node.getOutputs().size());
961   // TODO Add to validate with subgraphs
962 }
963
964 void OperationValidator::visit(const ir::operation::Neg &node) { checkUnaryOp(node); }
965
966 void OperationValidator::visit(const ir::operation::Log &node) { checkUnaryOp(node); }
967
968 void OperationValidator::visit(const ir::operation::LogicalNot &node) { checkUnaryOp(node); }
969
970 void OperationValidator::visit(const ir::operation::SquaredDifference &node)
971 {
972   const auto output_index{node.getOutputs().at(0)};
973   const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
974   const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};
975
976   // Check for Type equivalence
977   OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(lhs_index).typeInfo().type());
978   OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());
979
980   // Check for dimension constraints
981   if (_ctx.at(output_index).info().isDynamic())
982     return;
983
984   auto output_shape = _ctx.at(output_index).shape();
985   auto lhs_shape = _ctx.at(lhs_index).shape();
986   auto rhs_shape = _ctx.at(rhs_index).shape();
987   // Check for output rank
988   OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
989   auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());
990
991   for (int idx = 1; idx <= min_rank; idx++)
992   {
993     int l_idx = lhs_shape.rank() - idx;
994     int r_idx = rhs_shape.rank() - idx;
995     int out_idx = output_shape.rank() - idx;
996
997     OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0));
998
999     auto l_dims = lhs_shape.dim(l_idx);
1000     auto r_dims = rhs_shape.dim(r_idx);
1001     auto out_dims = output_shape.dim(out_idx);
1002
1003     OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) ||
1004                 ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims)));
1005   }
1006   auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape;
1007   for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++)
1008   {
1009     int out_idx = output_shape.rank() - idx;
1010     int tmp_idx = tmp_shape.rank() - idx;
1011
1012     OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) &&
1013                 (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx)));
1014   }
1015 }
1016 void OperationValidator::visit(const ir::operation::Tile &node)
1017 {
1018   const auto output_index{node.getOutputs().at(0)};
1019   if (_ctx.at(output_index).info().isDynamic())
1020     return;
1021
1022   const auto input_index{node.getInputs().at(0)};
1023   const auto multiple_index{node.getInputs().at(1)};
1024
1025   OP_REQUIRES(_ctx.at(multiple_index).shape().rank() == 1);
1026   OP_REQUIRES(_ctx.at(multiple_index).shape().dim(0) == _ctx.at(input_index).shape().rank());
1027   OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
1028 }
1029
1030 void OperationValidator::visit(const ir::operation::LogicalOr &node)
1031 {
1032   const auto output_index{node.getOutputs().at(0)};
1033   const auto lhs_index{node.getInputs().at(0)};
1034   const auto rhs_index{node.getInputs().at(1)};
1035
1036   OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());
1037   OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(output_index).typeInfo().type());
1038 }
1039
1040 void OperationValidator::visit(const ir::operation::Range &node)
1041 {
1042   const auto output_index{node.getOutputs().at(0)};
1043   const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
1044   const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
1045   const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
1046
1047   // Check for dimension constraints
1048   if (_ctx.at(output_index).info().isDynamic())
1049     return;
1050
1051   OP_REQUIRES(_ctx.at(start_index).shape().rank() == 0);
1052   OP_REQUIRES(_ctx.at(limit_index).shape().rank() == 0);
1053   OP_REQUIRES(_ctx.at(delta_index).shape().rank() == 0);
1054 }
1055
1056 void OperationValidator::visit(const ir::operation::MatrixBandPart &node)
1057 {
1058   const auto output_index{node.getOutputs().at(0)};
1059   const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)};
1060   const auto num_lower_index{
1061       node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_LOWER_DIAG)};
1062   const auto num_upper_index{
1063       node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)};
1064
1065   // Check for dimension constraints
1066   if (_ctx.at(output_index).info().isDynamic())
1067     return;
1068
1069   OP_REQUIRES(_ctx.at(input_index).shape().rank() >= 2);     // input must be more than 2 dim matrix
1070   OP_REQUIRES(_ctx.at(num_upper_index).shape().rank() == 0); // num_lower must be scalar
1071   OP_REQUIRES(_ctx.at(num_lower_index).shape().rank() == 0); // num_upper must be scalar
1072 }
1073
1074 void OperationValidator::visit(const ir::operation::LogSoftmax &node)
1075 {
1076   VERBOSE(LogSoftmax) << "Configure LOGSOFTMAX operation" << std::endl;
1077
1078   const auto output_index{node.getOutputs().at(0)};
1079   if (_ctx.at(output_index).info().isDynamic())
1080     return;
1081
1082   const auto input_index{node.getInputs().at(0)};
1083
1084   OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
1085 }
1086
1087 void OperationValidator::visit(const ir::operation::Quantize &node)
1088 {
1089   VERBOSE(Quantize) << "Configure Quantize operation" << std::endl;
1090
1091   OP_REQUIRES(node.getInputs().size() == 1);
1092   OP_REQUIRES(node.getOutputs().size() == 1);
1093
1094   const auto input_index{node.getInputs().at(0)};
1095   const auto output_index{node.getOutputs().at(0)};
1096
1097   OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::FLOAT32);
1098
1099   if (_ctx.at(output_index).info().isDynamic())
1100     return;
1101
1102   OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM);
1103
1104   OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
1105 }
1106 } // namespace compiler
1107 } // namespace onert