094dbc0d5d901c501e5e0c5995f30ebc648ab2b8
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / OperationValidator.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "OperationValidator.h"
18
19 #include "ir/Graph.h"
20 #include "util/logging.h"
21
22 #define OP_REQUIRES(EXP)                                                                         \
23   do                                                                                             \
24   {                                                                                              \
25     if (!(EXP))                                                                                  \
26       throw std::runtime_error("OperationValidator failed at line " + std::to_string(__LINE__)); \
27   } while (0)
28
29 namespace onert
30 {
31 namespace ir
32 {
33
34 OperationValidator::OperationValidator(const Graph &graph)
35   : _operations{graph.operations()}, _operands{graph.operands()}
36 {
37 }
38
39 void OperationValidator::operator()()
40 {
41   _operations.iterate([&](const OperationIndex &, const Operation &node) { node.accept(*this); });
42 }
43
44 DataType OperationValidator::operandType(const OperandIndex &idx)
45 {
46   return _operands.at(idx).typeInfo().type();
47 }
48
49 bool OperationValidator::isConstant(const OperandIndex &idx)
50 {
51   return _operands.at(idx).isConstant();
52 }
53
54 bool OperationValidator::isSameType(const OperandIndex &idx1, const OperandIndex &idx2)
55 {
56   return operandType(idx1) == operandType(idx2);
57 }
58
59 bool OperationValidator::isSameQuantParam(const OperandIndex &idx1, const OperandIndex &idx2)
60 {
61   if (_operands.at(idx1).typeInfo().scale() != _operands.at(idx2).typeInfo().scale())
62     return false;
63
64   if (_operands.at(idx1).typeInfo().zero_point() != _operands.at(idx2).typeInfo().zero_point())
65     return false;
66
67   return true;
68 }
69
70 bool OperationValidator::isValidType(const OperandIndex &idx, const DataType &type)
71 {
72   return operandType(idx) == type;
73 }
74
75 bool OperationValidator::isValidType(const OperandIndex &idx,
76                                      std::initializer_list<DataType> valid_types)
77 {
78   for (auto type_to_check : valid_types)
79   {
80     if (isValidType(idx, type_to_check))
81     {
82       return true;
83     }
84   }
85
86   return false;
87 }
88
89 void OperationValidator::visit(const operation::AddN &node)
90 {
91   const auto output_index(node.getOutputs().at(0));
92
93   int size = node.getInputs().size();
94   for (int i = 0; i < size; i++)
95   {
96     const auto input_index(node.getInputs().at(i));
97     OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32}));
98     OP_REQUIRES(isSameType(input_index, output_index));
99   }
100 }
101
102 void OperationValidator::visit(const operation::ArgMinMax &node)
103 {
104   const auto input_index(node.getInputs().at(operation::ArgMinMax::Input::INPUT));
105   const auto axis_index(node.getInputs().at(operation::ArgMinMax::Input::AXIS));
106   const auto output_index(node.getOutputs().at(0));
107   const auto output_type = node.param().output_type;
108
109   OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::UINT8,
110                                         DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
111   OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
112   OP_REQUIRES(isValidType(output_index, {DataType::INT32, DataType::INT64}));
113   OP_REQUIRES(isValidType(output_index, output_type));
114 }
115
116 void OperationValidator::visit(const operation::BatchMatMul &node)
117 {
118   const auto lhs_index(node.getInputs().at(operation::BatchMatMul::Input::LHS));
119   const auto rhs_index(node.getInputs().at(operation::BatchMatMul::Input::RHS));
120   const auto output_index(node.getOutputs().at(0));
121
122   // Constant lhs and rhs is not implemented yet
123   OP_REQUIRES(!isConstant(lhs_index) && !isConstant(rhs_index));
124
125   // Allow hybrid quantization (lhs: float / rhs: qint8 / out: float)
126   OP_REQUIRES(isValidType(lhs_index, {DataType::FLOAT32, DataType::QUANT_INT8_ASYMM}));
127   OP_REQUIRES(isSameType(lhs_index, rhs_index) ||
128               ((operandType(lhs_index) == DataType::FLOAT32) &&
129                (operandType(rhs_index) == DataType::QUANT_INT8_ASYMM)));
130   OP_REQUIRES(isSameType(lhs_index, output_index));
131 }
132
133 void OperationValidator::visit(const operation::BatchToSpaceND &node)
134 {
135   const auto input_index{node.getInputs().at(operation::BatchToSpaceND::Input::INPUT)};
136   const auto output_index{node.getOutputs().at(0)};
137
138   OP_REQUIRES(isSameType(input_index, output_index));
139 }
140
141 void OperationValidator::visit(const operation::BinaryArithmetic &node)
142 {
143   const auto output_index{node.getOutputs().at(0)};
144   const auto lhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::LHS)};
145   const auto rhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::RHS)};
146
147   OP_REQUIRES(isSameType(lhs_index, rhs_index));
148   OP_REQUIRES(isSameType(lhs_index, output_index));
149 }
150
151 void OperationValidator::visit(const operation::Comparison &node)
152 {
153   const auto output_index{node.getOutputs().at(0)};
154
155   const auto lhs_index{node.getInputs().at(operation::Comparison::Input::INPUT0)};
156   const auto rhs_index{node.getInputs().at(operation::Comparison::Input::INPUT1)};
157
158   OP_REQUIRES(isSameType(lhs_index, rhs_index));
159   OP_REQUIRES(isValidType(output_index, DataType::BOOL8));
160 }
161
162 void OperationValidator::visit(const operation::Concat &node)
163 {
164   const auto output_index{node.getOutputs().at(0)};
165
166   for (auto input_index : node.getInputs())
167   {
168     OP_REQUIRES(isSameType(input_index, output_index));
169
170     // Int8 quantization requires same scale and zero point
171     if (isValidType(output_index, DataType::QUANT_INT8_ASYMM))
172     {
173       OP_REQUIRES(isSameQuantParam(input_index, output_index));
174     }
175   }
176 }
177
178 void OperationValidator::visit(const operation::Conv2D &node)
179 {
180   const auto input_index{node.getInputs().at(operation::Conv2D::Input::INPUT)};
181   const auto kernel_index{node.getInputs().at(operation::Conv2D::Input::KERNEL)};
182   const auto output_index{node.getOutputs().at(0)};
183
184   uint32_t stride_horizontal = node.param().stride.horizontal;
185   uint32_t stride_vertical = node.param().stride.vertical;
186   uint32_t dilation_width = node.param().dilation.width_factor;
187   uint32_t dilation_height = node.param().dilation.height_factor;
188
189   OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
190   OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
191   OP_REQUIRES(isSameType(input_index, output_index));
192
193   if (isConstant(kernel_index) && operandType(kernel_index) == DataType::QUANT_INT8_ASYMM)
194   {
195     for (const auto zeropoint : _operands.at(kernel_index).typeInfo().zero_points())
196       OP_REQUIRES(zeropoint == 0);
197   }
198 }
199
200 void OperationValidator::visit(const operation::DepthToSpace &node)
201 {
202   const auto input_index{node.getInputs().at(operation::DepthToSpace::Input::INPUT)};
203   const auto output_index{node.getOutputs().at(0)};
204
205   int32_t block_size = node.param().block_size;
206
207   OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::INT64,
208                                         DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
209   OP_REQUIRES(isSameType(input_index, output_index));
210
211   OP_REQUIRES(block_size > 0);
212 }
213
214 void OperationValidator::visit(const operation::DetectionPostProcess &node)
215 {
216   auto param = node.param();
217
218   // FIXME: number of classes should be 1 for now.
219   OP_REQUIRES(param.num_classes == 1);
220 }
221
222 void OperationValidator::visit(const operation::DepthwiseConv2D &node)
223 {
224   const auto input_index{node.getInputs().at(operation::DepthwiseConv2D::Input::INPUT)};
225   const auto kernel_index{node.getInputs().at(operation::DepthwiseConv2D::Input::KERNEL)};
226   const auto output_index{node.getOutputs().at(0)};
227
228   uint32_t stride_horizontal = node.param().stride.horizontal;
229   uint32_t stride_vertical = node.param().stride.vertical;
230   uint32_t dilation_width = node.param().dilation.width_factor;
231   uint32_t dilation_height = node.param().dilation.height_factor;
232
233   OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
234   OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
235   OP_REQUIRES(isSameType(input_index, output_index));
236
237   if (isConstant(kernel_index) && operandType(kernel_index) == DataType::QUANT_INT8_ASYMM)
238   {
239     for (const auto zeropoint : _operands.at(kernel_index).typeInfo().zero_points())
240       OP_REQUIRES(zeropoint == 0);
241   }
242 }
243
244 void OperationValidator::visit(const operation::ElementwiseActivation &node)
245 {
246   const auto output_index{node.getOutputs().at(0)};
247   const auto input_index{node.getInputs().at(0)};
248
249   // Check if I/O types match
250   OP_REQUIRES(isSameType(output_index, input_index));
251
252   switch (node.param().op_type)
253   {
254     case operation::ElementwiseActivation::Type::ELU:
255       OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
256       break;
257     case operation::ElementwiseActivation::Type::LEAKY_RELU:
258       OP_REQUIRES(
259         isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
260                                   DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
261       break;
262     case operation::ElementwiseActivation::Type::LOGISTIC:
263       OP_REQUIRES(
264         isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
265                                   DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
266       break;
267     case operation::ElementwiseActivation::Type::RELU:
268       OP_REQUIRES(isValidType(
269         input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
270       break;
271     case operation::ElementwiseActivation::Type::TANH:
272       OP_REQUIRES(
273         isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
274                                   DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
275       break;
276   }
277 }
278
279 void OperationValidator::visit(const operation::ElementwiseBinary &node)
280 {
281   const auto output_index{node.getOutputs().at(0)};
282   const auto lhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::LHS)};
283   const auto rhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::RHS)};
284
285   OP_REQUIRES(isSameType(lhs_index, rhs_index));
286   OP_REQUIRES(isSameType(lhs_index, output_index));
287
288   const auto op_type = node.param().op_type;
289   if (op_type == operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_AND ||
290       op_type == operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_OR)
291   {
292     OP_REQUIRES(isValidType(lhs_index, DataType::BOOL8));
293   }
294 }
295
296 void OperationValidator::visit(const operation::ElementwiseUnary &node)
297 {
298   const auto output_index{node.getOutputs().at(0)};
299   const auto input_index{node.getInputs().at(operation::ElementwiseUnary::Input::INPUT)};
300
301   // Check if I/O types match
302   if (node.param().op_type == operation::ElementwiseUnary::Type::DEQUANTIZE)
303   {
304     // NNAPI allow QUANT_INT8_SYMM type input
305     OP_REQUIRES(isValidType(input_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_SYMM,
306                                           DataType::QUANT_INT8_ASYMM}));
307     OP_REQUIRES(isValidType(output_index, DataType::FLOAT32));
308   }
309   else if (node.param().op_type == operation::ElementwiseUnary::Type::QUANTIZE)
310   {
311     OP_REQUIRES(isValidType(
312       input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
313     OP_REQUIRES(
314       isValidType(output_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
315   }
316   else if (node.param().op_type == operation::ElementwiseUnary::Type::FLOOR)
317   {
318     OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
319     OP_REQUIRES(isSameType(output_index, input_index));
320   }
321   else if (node.param().op_type != operation::ElementwiseUnary::Type::CAST)
322   {
323     OP_REQUIRES(isSameType(output_index, input_index));
324   }
325 }
326
327 void OperationValidator::visit(const operation::EmbeddingLookup &node)
328 {
329   const auto lookups_index{node.getInputs().at(operation::EmbeddingLookup::Input::LOOKUPS)};
330   const auto values_index{node.getInputs().at(operation::EmbeddingLookup::Input::VALUES)};
331   const auto output_index{node.getOutputs().at(0)};
332
333   OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
334
335   // TFLite: Allow hybrid type - value table & output
336   // NNAPI: Require same value table and output type
337   OP_REQUIRES(
338     isSameType(values_index, output_index) ||
339     (isValidType(output_index, DataType::FLOAT32) &&
340      (isValidType(values_index, {DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT8_SYMM}))));
341 }
342
343 void OperationValidator::visit(const operation::ExpandDims &node)
344 {
345   const auto output_index{node.getOutputs().at(0)};
346   const auto input_index{node.getInputs().at(operation::ExpandDims::Input::INPUT)};
347   const auto axis_index{node.getInputs().at(operation::ExpandDims::Input::AXIS)};
348
349   OP_REQUIRES(isSameType(output_index, input_index));
350   OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
351 }
352
353 void OperationValidator::visit(const operation::Fill &node)
354 {
355   const auto output_index{node.getOutputs().at(0)};
356   const auto input_index{node.getInputs().at(operation::Fill::Input::SHAPE)};
357   const auto value_index{node.getInputs().at(operation::Fill::Input::VALUE)};
358
359   OP_REQUIRES(isSameType(output_index, value_index));
360   OP_REQUIRES(isValidType(input_index, {DataType::INT32, DataType::INT64}));
361   OP_REQUIRES(isValidType(output_index,
362                           {DataType::FLOAT32, DataType::INT32, DataType::INT64, DataType::BOOL8}));
363 }
364
365 void OperationValidator::visit(const operation::HashtableLookup &node)
366 {
367   const auto hits_index{node.getOutputs().at(operation::HashtableLookup::Output::HITS)};
368   const auto lookups_index{node.getInputs().at(operation::HashtableLookup::Input::LOOKUPS)};
369   const auto keys_index{node.getInputs().at(operation::HashtableLookup::Input::KEYS)};
370
371   OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
372   OP_REQUIRES(isValidType(keys_index, DataType::INT32));
373   OP_REQUIRES(isValidType(hits_index, DataType::QUANT_UINT8_ASYMM));
374 }
375
376 void OperationValidator::visit(const operation::Pack &node)
377 {
378   const auto num{node.param().num};
379
380   OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size()));
381 }
382
383 void OperationValidator::visit(const operation::Pad &node)
384 {
385   const auto output_index{node.getOutputs().at(0)};
386   const auto input_index{node.getInputs().at(operation::Pad::Input::INPUT)};
387   const auto pad_index{node.getInputs().at(operation::Pad::Input::PAD)};
388   bool isQuantType =
389     isValidType(output_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM});
390   bool isPadV2 = node.getInputs().size() == 3 ? true : false;
391
392   OP_REQUIRES(isValidType(pad_index, DataType::INT32));
393   OP_REQUIRES(isSameType(input_index, output_index));
394
395   if (isQuantType)
396     OP_REQUIRES(isSameQuantParam(input_index, output_index));
397
398   if (isPadV2)
399   {
400     const auto value_index{node.getInputs().at(operation::Pad::Input::VALUE)};
401     const bool cond_same = isSameType(input_index, value_index);
402     const bool cond_same_quant = (!isQuantType || isSameQuantParam(input_index, value_index));
403     const auto input_t = operandType(input_index);
404     const auto value_t = operandType(value_index);
405     // NNAPI accepts this case. scale and zeroPoint are assumed to be the same as in input0.
406     const bool cond_quant8 =
407       ((input_t == DataType::QUANT_UINT8_ASYMM || input_t == DataType::QUANT_INT8_ASYMM) &&
408        value_t == DataType::INT32);
409     OP_REQUIRES((cond_same && cond_same_quant) || cond_quant8);
410   }
411 }
412
413 void OperationValidator::visit(const operation::Rank &node)
414 {
415   const auto output_index{node.getOutputs().at(0)};
416
417   OP_REQUIRES(isValidType(output_index, DataType::INT32));
418 }
419
420 void OperationValidator::visit(const operation::ResizeBilinear &node)
421 {
422   auto align_corners = node.param().align_corners;
423   auto half_pixel_centers = node.param().half_pixel_centers;
424
425   OP_REQUIRES(!align_corners || !half_pixel_centers);
426 }
427
428 void OperationValidator::visit(const operation::Reverse &node)
429 {
430   const auto output_index{node.getOutputs().at(0)};
431   const auto input_index{node.getInputs().at(operation::Reverse::Input::INPUT)};
432   const auto axis_index{node.getInputs().at(operation::Reverse::Input::AXIS)};
433
434   OP_REQUIRES(isValidType(axis_index, DataType::INT32));
435   OP_REQUIRES(isSameType(output_index, input_index));
436 }
437
438 void OperationValidator::visit(const operation::Select &node)
439 {
440   const auto condition_index{node.getInputs().at(operation::Select::Input::CONDITION)};
441   const auto input_true_index{node.getInputs().at(operation::Select::Input::INPUT_TRUE)};
442   const auto input_false_index{node.getInputs().at(operation::Select::Input::INPUT_FALSE)};
443
444   OP_REQUIRES(isValidType(condition_index, DataType::BOOL8));
445   OP_REQUIRES(isSameType(input_true_index, input_false_index));
446 }
447
448 void OperationValidator::visit(const operation::Shape &node)
449 {
450   const auto output_index{node.getOutputs().at(0)};
451
452   OP_REQUIRES(isValidType(output_index, {DataType::UINT32, DataType::INT32, DataType::INT64}));
453 }
454
455 void OperationValidator::visit(const operation::Slice &node)
456 {
457   const auto begins_index{node.getInputs().at(operation::Slice::BEGINS)};
458   const auto sizes_index{node.getInputs().at(operation::Slice::SIZES)};
459
460   OP_REQUIRES(isValidType(begins_index, {DataType::INT32, DataType::INT64}));
461   OP_REQUIRES(isSameType(begins_index, sizes_index));
462 }
463
464 void OperationValidator::visit(const operation::Softmax &node)
465 {
466   const auto output_index{node.getOutputs().at(0)};
467   const auto input_index{node.getInputs().at(operation::Softmax::INPUT)};
468
469   OP_REQUIRES(isSameType(input_index, output_index));
470   OP_REQUIRES(isValidType(
471     output_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
472 }
473
474 void OperationValidator::visit(const operation::SpaceToBatchND &node)
475 {
476   const auto block_size_index{node.getInputs().at(operation::SpaceToBatchND::Input::BLOCK_SIZE)};
477   const auto paddings_index{node.getInputs().at(operation::SpaceToBatchND::Input::PADDINGS)};
478
479   // Non-constant block_size and padding is not implemented yet
480   OP_REQUIRES(isConstant(block_size_index));
481   OP_REQUIRES(isConstant(paddings_index));
482 }
483
484 void OperationValidator::visit(const operation::SpaceToDepth &node)
485 {
486   const auto block_size = node.param().block_size;
487   OP_REQUIRES(block_size >= 1);
488 }
489
490 void OperationValidator::visit(const operation::Split &node)
491 {
492   const auto num_splits = node.param().num_splits;
493
494   OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF);
495   OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits));
496 }
497
498 void OperationValidator::visit(const operation::SquaredDifference &node)
499 {
500   const auto output_index{node.getOutputs().at(0)};
501   const auto lhs_index{node.getInputs().at(operation::SquaredDifference::Input::LHS)};
502   const auto rhs_index{node.getInputs().at(operation::SquaredDifference::Input::RHS)};
503
504   OP_REQUIRES(isSameType(output_index, lhs_index));
505   OP_REQUIRES(isSameType(lhs_index, rhs_index));
506 }
507
508 void OperationValidator::visit(const operation::StatelessRandomUniform &node)
509 {
510   const auto output_index{node.getOutputs().at(0)};
511   const auto shape_index{node.getInputs().at(operation::StatelessRandomUniform::Input::SHAPE)};
512   const auto seed_index{node.getInputs().at(operation::StatelessRandomUniform::Input::SEED)};
513
514   OP_REQUIRES(isValidType(output_index, DataType::FLOAT32));
515   OP_REQUIRES(isValidType(shape_index, DataType::INT32));
516   OP_REQUIRES(isValidType(seed_index, DataType::INT32));
517 }
518
519 void OperationValidator::visit(const operation::StridedSlice &node)
520 {
521   const auto output_index{node.getOutputs().at(0)};
522   const auto input_index{node.getInputs().at(operation::StridedSlice::Input::INPUT)};
523
524   OP_REQUIRES(isSameType(output_index, input_index));
525 }
526
527 void OperationValidator::visit(const operation::TransposeConv &node)
528 {
529   OP_REQUIRES((node.param().padding.type == PaddingType::SAME) ||
530               (node.param().padding.type == PaddingType::VALID));
531 }
532
533 void OperationValidator::visit(const operation::Unpack &node)
534 {
535   const auto num{node.param().num};
536   OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size()));
537 }
538
539 void OperationValidator::visit(const operation::While &node)
540 {
541   OP_REQUIRES(node.getInputs().size() == node.getOutputs().size());
542 }
543
544 } // namespace ir
545 } // namespace onert