Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / OperationValidator.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "OperationValidator.h"
18
19 #include "ir/Graph.h"
20
21 #define OP_REQUIRES(EXP)                                                                         \
22   do                                                                                             \
23   {                                                                                              \
24     if (!(EXP))                                                                                  \
25       throw std::runtime_error("OperationValidator failed at line " + std::to_string(__LINE__)); \
26   } while (0)
27
28 namespace onert
29 {
30 namespace ir
31 {
32
33 OperationValidator::OperationValidator(const Graph &graph)
34     : _operations{graph.operations()}, _operands{graph.operands()}
35 {
36 }
37
38 void OperationValidator::operator()()
39 {
40   _operations.iterate([&](const OperationIndex &, const Operation &node) { node.accept(*this); });
41 }
42
43 DataType OperationValidator::operandType(const OperandIndex &idx)
44 {
45   return _operands.at(idx).typeInfo().type();
46 }
47
48 bool OperationValidator::isConstant(const OperandIndex &idx)
49 {
50   return _operands.at(idx).isConstant();
51 }
52
53 bool OperationValidator::isSameType(const OperandIndex &idx1, const OperandIndex &idx2)
54 {
55   return operandType(idx1) == operandType(idx2);
56 }
57
58 bool OperationValidator::isSameQuantParam(const OperandIndex &idx1, const OperandIndex &idx2)
59 {
60   if (_operands.at(idx1).typeInfo().scale() != _operands.at(idx2).typeInfo().scale())
61     return false;
62
63   if (_operands.at(idx1).typeInfo().offset() != _operands.at(idx2).typeInfo().offset())
64     return false;
65
66   return true;
67 }
68
69 bool OperationValidator::isValidType(const OperandIndex &idx, const DataType &type)
70 {
71   return operandType(idx) == type;
72 }
73
74 bool OperationValidator::isValidType(const OperandIndex &idx,
75                                      std::initializer_list<DataType> valid_types)
76 {
77   for (auto type_to_check : valid_types)
78   {
79     if (isValidType(idx, type_to_check))
80     {
81       return true;
82     }
83   }
84
85   return false;
86 }
87
88 void OperationValidator::visit(const operation::AddN &node)
89 {
90   const auto output_index(node.getOutputs().at(0));
91
92   int size = node.getInputs().size();
93   for (int i = 0; i < size; i++)
94   {
95     const auto input_index(node.getInputs().at(i));
96     OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32}));
97     OP_REQUIRES(isSameType(input_index, output_index));
98   }
99 }
100
101 void OperationValidator::visit(const operation::ArgMinMax &node)
102 {
103   const auto input_index(node.getInputs().at(operation::ArgMinMax::Input::INPUT));
104   const auto axis_index(node.getInputs().at(operation::ArgMinMax::Input::AXIS));
105   const auto output_index(node.getOutputs().at(0));
106   const auto output_type = node.param().output_type;
107
108   OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::UINT8,
109                                         DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
110   OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
111   OP_REQUIRES(isValidType(output_index, {DataType::INT32, DataType::INT64}));
112   OP_REQUIRES(isValidType(output_index, output_type));
113 }
114
115 void OperationValidator::visit(const operation::BatchMatMul &node)
116 {
117   const auto lhs_index(node.getInputs().at(operation::BatchMatMul::Input::LHS));
118   const auto rhs_index(node.getInputs().at(operation::BatchMatMul::Input::RHS));
119   const auto output_index(node.getOutputs().at(0));
120
121   // Constant lhs and rhs is not implemented yet
122   OP_REQUIRES(!isConstant(lhs_index) && !isConstant(rhs_index));
123
124   // Allow hybrid quantization (lhs: float / rhs: qint8 / out: float)
125   OP_REQUIRES(isValidType(lhs_index, {DataType::FLOAT32, DataType::QUANT_INT8_ASYMM}));
126   OP_REQUIRES(isSameType(lhs_index, rhs_index) ||
127               ((operandType(lhs_index) == DataType::FLOAT32) &&
128                (operandType(rhs_index) == DataType::QUANT_INT8_ASYMM)));
129   OP_REQUIRES(isSameType(lhs_index, output_index));
130 }
131
132 void OperationValidator::visit(const operation::BatchToSpaceND &node)
133 {
134   const auto input_index{node.getInputs().at(operation::BatchToSpaceND::Input::INPUT)};
135   const auto output_index{node.getOutputs().at(0)};
136
137   OP_REQUIRES(isSameType(input_index, output_index));
138 }
139
140 void OperationValidator::visit(const operation::BinaryArithmetic &node)
141 {
142   const auto output_index{node.getOutputs().at(0)};
143   const auto lhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::LHS)};
144   const auto rhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::RHS)};
145
146   OP_REQUIRES(isSameType(lhs_index, rhs_index));
147   OP_REQUIRES(isSameType(lhs_index, output_index));
148 }
149
150 void OperationValidator::visit(const operation::Comparison &node)
151 {
152   const auto output_index{node.getOutputs().at(0)};
153
154   const auto lhs_index{node.getInputs().at(operation::Comparison::Input::INPUT0)};
155   const auto rhs_index{node.getInputs().at(operation::Comparison::Input::INPUT1)};
156
157   OP_REQUIRES(isSameType(lhs_index, rhs_index));
158   OP_REQUIRES(isValidType(output_index, DataType::BOOL8));
159 }
160
161 void OperationValidator::visit(const operation::Concat &node)
162 {
163   const auto output_index{node.getOutputs().at(0)};
164
165   for (auto input_index : node.getInputs())
166   {
167     OP_REQUIRES(isSameType(input_index, output_index));
168
169     // Int8 quantization requires same scale and zero point
170     if (isValidType(output_index, DataType::QUANT_INT8_ASYMM))
171     {
172       OP_REQUIRES(isSameQuantParam(input_index, output_index));
173     }
174   }
175 }
176
177 void OperationValidator::visit(const operation::Conv2D &node)
178 {
179   const auto input_index{node.getInputs().at(operation::Conv2D::Input::INPUT)};
180   const auto output_index{node.getOutputs().at(0)};
181
182   uint32_t stride_horizontal = node.param().stride.horizontal;
183   uint32_t stride_vertical = node.param().stride.vertical;
184   uint32_t dilation_width = node.param().dilation.width_factor;
185   uint32_t dilation_height = node.param().dilation.height_factor;
186
187   OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
188   OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
189   OP_REQUIRES(isSameType(input_index, output_index));
190 }
191
192 void OperationValidator::visit(const operation::DepthToSpace &node)
193 {
194   const auto input_index{node.getInputs().at(operation::DepthToSpace::Input::INPUT)};
195   const auto output_index{node.getOutputs().at(0)};
196
197   int32_t block_size = node.param().block_size;
198
199   OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::INT64,
200                                         DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
201   OP_REQUIRES(isSameType(input_index, output_index));
202
203   OP_REQUIRES(block_size > 0);
204 }
205
206 void OperationValidator::visit(const operation::DepthwiseConv2D &node)
207 {
208   const auto input_index{node.getInputs().at(operation::DepthwiseConv2D::Input::INPUT)};
209   const auto output_index{node.getOutputs().at(0)};
210
211   uint32_t stride_horizontal = node.param().stride.horizontal;
212   uint32_t stride_vertical = node.param().stride.vertical;
213   uint32_t dilation_width = node.param().dilation.width_factor;
214   uint32_t dilation_height = node.param().dilation.height_factor;
215
216   OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
217   OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
218   OP_REQUIRES(isSameType(input_index, output_index));
219 }
220
221 void OperationValidator::visit(const operation::ElementwiseActivation &node)
222 {
223   const auto output_index{node.getOutputs().at(0)};
224   const auto input_index{node.getInputs().at(0)};
225
226   // Check if I/O types match
227   OP_REQUIRES(isSameType(output_index, input_index));
228
229   switch (node.param().op_type)
230   {
231     case operation::ElementwiseActivation::Type::ELU:
232       OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
233       break;
234     case operation::ElementwiseActivation::Type::LEAKY_RELU:
235       OP_REQUIRES(
236           isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
237                                     DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
238       break;
239     case operation::ElementwiseActivation::Type::LOGISTIC:
240       OP_REQUIRES(
241           isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
242                                     DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
243       break;
244     case operation::ElementwiseActivation::Type::RELU:
245       OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
246                                             DataType::QUANT_INT8_ASYMM}));
247       break;
248     case operation::ElementwiseActivation::Type::TANH:
249       OP_REQUIRES(
250           isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
251                                     DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
252       break;
253   }
254 }
255
256 void OperationValidator::visit(const operation::ElementwiseBinary &node)
257 {
258   const auto output_index{node.getOutputs().at(0)};
259   const auto lhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::LHS)};
260   const auto rhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::RHS)};
261
262   OP_REQUIRES(isSameType(lhs_index, rhs_index));
263   OP_REQUIRES(isSameType(lhs_index, output_index));
264
265   const auto op_type = node.param().op_type;
266   if (op_type == operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_AND ||
267       op_type == operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_OR)
268   {
269     OP_REQUIRES(isValidType(lhs_index, DataType::BOOL8));
270   }
271 }
272
273 void OperationValidator::visit(const operation::ElementwiseUnary &node)
274 {
275   const auto output_index{node.getOutputs().at(0)};
276   const auto input_index{node.getInputs().at(operation::ElementwiseUnary::Input::INPUT)};
277
278   // Check if I/O types match
279   if (node.param().op_type == operation::ElementwiseUnary::Type::DEQUANTIZE)
280   {
281     // NNAPI allow QUANT_INT8_SYMM type input
282     OP_REQUIRES(isValidType(input_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_SYMM,
283                                           DataType::QUANT_INT8_ASYMM}));
284     OP_REQUIRES(isValidType(output_index, DataType::FLOAT32));
285   }
286   else if (node.param().op_type == operation::ElementwiseUnary::Type::QUANTIZE)
287   {
288     OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
289     OP_REQUIRES(isValidType(output_index, DataType::QUANT_UINT8_ASYMM));
290   }
291   else if (node.param().op_type == operation::ElementwiseUnary::Type::FLOOR)
292   {
293     OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
294     OP_REQUIRES(isSameType(output_index, input_index));
295   }
296   else if (node.param().op_type != operation::ElementwiseUnary::Type::CAST)
297   {
298     OP_REQUIRES(isSameType(output_index, input_index));
299   }
300 }
301
302 void OperationValidator::visit(const operation::EmbeddingLookup &node)
303 {
304   const auto lookups_index{node.getInputs().at(operation::EmbeddingLookup::Input::LOOKUPS)};
305   const auto values_index{node.getInputs().at(operation::EmbeddingLookup::Input::VALUES)};
306   const auto output_index{node.getOutputs().at(0)};
307
308   OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
309
310   // TFLite: Allow hybrid type - value table & output
311   // NNAPI: Require same value table and output type
312   OP_REQUIRES(
313       isSameType(values_index, output_index) ||
314       (isValidType(output_index, DataType::FLOAT32) &&
315        (isValidType(values_index, {DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT8_SYMM}))));
316 }
317
318 void OperationValidator::visit(const operation::ExpandDims &node)
319 {
320   const auto output_index{node.getOutputs().at(0)};
321   const auto input_index{node.getInputs().at(operation::ExpandDims::Input::INPUT)};
322   const auto axis_index{node.getInputs().at(operation::ExpandDims::Input::AXIS)};
323
324   OP_REQUIRES(isSameType(output_index, input_index));
325   OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
326 }
327
328 void OperationValidator::visit(const operation::Fill &node)
329 {
330   const auto output_index{node.getOutputs().at(0)};
331   const auto input_index{node.getInputs().at(operation::Fill::Input::SHAPE)};
332   const auto value_index{node.getInputs().at(operation::Fill::Input::VALUE)};
333
334   OP_REQUIRES(isSameType(output_index, value_index));
335   OP_REQUIRES(isValidType(input_index, {DataType::INT32, DataType::INT64}));
336   OP_REQUIRES(isValidType(output_index,
337                           {DataType::FLOAT32, DataType::INT32, DataType::INT64, DataType::BOOL8}));
338 }
339
340 void OperationValidator::visit(const operation::HashtableLookup &node)
341 {
342   const auto hits_index{node.getOutputs().at(operation::HashtableLookup::Output::HITS)};
343   const auto lookups_index{node.getInputs().at(operation::HashtableLookup::Input::LOOKUPS)};
344   const auto keys_index{node.getInputs().at(operation::HashtableLookup::Input::KEYS)};
345
346   OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
347   OP_REQUIRES(isValidType(keys_index, DataType::INT32));
348   OP_REQUIRES(isValidType(hits_index, DataType::QUANT_UINT8_ASYMM));
349 }
350
351 void OperationValidator::visit(const operation::Pack &node)
352 {
353   const auto num{node.param().num};
354
355   OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size()));
356 }
357
358 void OperationValidator::visit(const operation::Pad &node)
359 {
360   const auto pad_index{node.getInputs().at(operation::Pad::Input::PAD)};
361
362   OP_REQUIRES(isValidType(pad_index, DataType::INT32));
363 }
364
365 void OperationValidator::visit(const operation::Rank &node)
366 {
367   const auto output_index{node.getOutputs().at(0)};
368
369   OP_REQUIRES(isValidType(output_index, DataType::INT32));
370 }
371
372 void OperationValidator::visit(const operation::ResizeBilinear &node)
373 {
374   auto align_corners = node.param().align_corners;
375   auto half_pixel_centers = node.param().half_pixel_centers;
376
377   OP_REQUIRES(!align_corners || !half_pixel_centers);
378 }
379
380 void OperationValidator::visit(const operation::Reverse &node)
381 {
382   const auto output_index{node.getOutputs().at(0)};
383   const auto input_index{node.getInputs().at(operation::Reverse::Input::INPUT)};
384   const auto axis_index{node.getInputs().at(operation::Reverse::Input::AXIS)};
385
386   OP_REQUIRES(isValidType(axis_index, DataType::INT32));
387   OP_REQUIRES(isSameType(output_index, input_index));
388 }
389
390 void OperationValidator::visit(const operation::Select &node)
391 {
392   const auto condition_index{node.getInputs().at(operation::Select::Input::CONDITION)};
393   const auto input_true_index{node.getInputs().at(operation::Select::Input::INPUT_TRUE)};
394   const auto input_false_index{node.getInputs().at(operation::Select::Input::INPUT_FALSE)};
395
396   OP_REQUIRES(isValidType(condition_index, DataType::BOOL8));
397   OP_REQUIRES(isSameType(input_true_index, input_false_index));
398 }
399
400 void OperationValidator::visit(const operation::Shape &node)
401 {
402   const auto output_index{node.getOutputs().at(0)};
403
404   OP_REQUIRES(isValidType(output_index, {DataType::UINT32, DataType::INT32, DataType::INT64}));
405 }
406
407 void OperationValidator::visit(const operation::SpaceToBatchND &node)
408 {
409   const auto block_size_index{node.getInputs().at(operation::SpaceToBatchND::Input::BLOCK_SIZE)};
410   const auto paddings_index{node.getInputs().at(operation::SpaceToBatchND::Input::PADDINGS)};
411
412   // Non-constant block_size and padding is not implemented yet
413   OP_REQUIRES(isConstant(block_size_index));
414   OP_REQUIRES(isConstant(paddings_index));
415 }
416
417 void OperationValidator::visit(const operation::SpaceToDepth &node)
418 {
419   const auto block_size = node.param().block_size;
420   OP_REQUIRES(block_size >= 1);
421 }
422
423 void OperationValidator::visit(const operation::Split &node)
424 {
425   const auto num_splits = node.param().num_splits;
426
427   OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF);
428   OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits));
429 }
430
431 void OperationValidator::visit(const operation::SquaredDifference &node)
432 {
433   const auto output_index{node.getOutputs().at(0)};
434   const auto lhs_index{node.getInputs().at(operation::SquaredDifference::Input::LHS)};
435   const auto rhs_index{node.getInputs().at(operation::SquaredDifference::Input::RHS)};
436
437   OP_REQUIRES(isSameType(output_index, lhs_index));
438   OP_REQUIRES(isSameType(lhs_index, rhs_index));
439 }
440
441 void OperationValidator::visit(const operation::StridedSlice &node)
442 {
443   const auto output_index{node.getOutputs().at(0)};
444   const auto input_index{node.getInputs().at(operation::StridedSlice::Input::INPUT)};
445
446   OP_REQUIRES(isSameType(output_index, input_index));
447 }
448
449 void OperationValidator::visit(const operation::TransposeConv &node)
450 {
451   OP_REQUIRES((node.param().padding.type == PaddingType::SAME) ||
452               (node.param().padding.type == PaddingType::VALID));
453 }
454
455 void OperationValidator::visit(const operation::Unpack &node)
456 {
457   const auto num{node.param().num};
458   OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size()));
459 }
460
461 void OperationValidator::visit(const operation::While &node)
462 {
463   OP_REQUIRES(node.getInputs().size() == node.getOutputs().size());
464 }
465
466 } // namespace compiler
467 } // namespace onert