2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "OperationValidator.h"
20 #include "util/logging.h"
// OP_REQUIRES(EXP): throws std::runtime_error when EXP evaluates to false. The message names the
// source line of the failing check. The do { ... } while (0) wrapper makes each use a single
// statement, so it is safe inside an unbraced if/else. The final line deliberately has no
// trailing backslash, so the macro does not swallow the following source line.
#define OP_REQUIRES(EXP)                                                                         \
  do                                                                                             \
  {                                                                                              \
    if (!(EXP))                                                                                  \
      throw std::runtime_error("OperationValidator failed at line " + std::to_string(__LINE__)); \
  } while (0)
34 OperationValidator::OperationValidator(const Graph &graph)
35 : _operations{graph.operations()}, _operands{graph.operands()}
39 void OperationValidator::operator()()
41 _operations.iterate([&](const OperationIndex &, const Operation &node) { node.accept(*this); });
44 DataType OperationValidator::operandType(const OperandIndex &idx)
46 return _operands.at(idx).typeInfo().type();
49 bool OperationValidator::isConstant(const OperandIndex &idx)
51 return _operands.at(idx).isConstant();
54 bool OperationValidator::isSameType(const OperandIndex &idx1, const OperandIndex &idx2)
56 return operandType(idx1) == operandType(idx2);
59 bool OperationValidator::isSameQuantParam(const OperandIndex &idx1, const OperandIndex &idx2)
61 if (_operands.at(idx1).typeInfo().scale() != _operands.at(idx2).typeInfo().scale())
64 if (_operands.at(idx1).typeInfo().zero_point() != _operands.at(idx2).typeInfo().zero_point())
70 bool OperationValidator::isValidType(const OperandIndex &idx, const DataType &type)
72 return operandType(idx) == type;
75 bool OperationValidator::isValidType(const OperandIndex &idx,
76 std::initializer_list<DataType> valid_types)
78 for (auto type_to_check : valid_types)
80 if (isValidType(idx, type_to_check))
89 void OperationValidator::visit(const operation::AddN &node)
91 const auto output_index(node.getOutputs().at(0));
93 int size = node.getInputs().size();
94 for (int i = 0; i < size; i++)
96 const auto input_index(node.getInputs().at(i));
97 OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32}));
98 OP_REQUIRES(isSameType(input_index, output_index));
102 void OperationValidator::visit(const operation::ArgMinMax &node)
104 const auto input_index(node.getInputs().at(operation::ArgMinMax::Input::INPUT));
105 const auto axis_index(node.getInputs().at(operation::ArgMinMax::Input::AXIS));
106 const auto output_index(node.getOutputs().at(0));
107 const auto output_type = node.param().output_type;
109 OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::UINT8,
110 DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
111 OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
112 OP_REQUIRES(isValidType(output_index, {DataType::INT32, DataType::INT64}));
113 OP_REQUIRES(isValidType(output_index, output_type));
116 void OperationValidator::visit(const operation::BatchMatMul &node)
118 const auto lhs_index(node.getInputs().at(operation::BatchMatMul::Input::LHS));
119 const auto rhs_index(node.getInputs().at(operation::BatchMatMul::Input::RHS));
120 const auto output_index(node.getOutputs().at(0));
122 // Constant lhs and rhs is not implemented yet
123 OP_REQUIRES(!isConstant(lhs_index) && !isConstant(rhs_index));
125 // Allow hybrid quantization (lhs: float / rhs: qint8 / out: float)
126 OP_REQUIRES(isValidType(lhs_index, {DataType::FLOAT32, DataType::QUANT_INT8_ASYMM}));
127 OP_REQUIRES(isSameType(lhs_index, rhs_index) ||
128 ((operandType(lhs_index) == DataType::FLOAT32) &&
129 (operandType(rhs_index) == DataType::QUANT_INT8_ASYMM)));
130 OP_REQUIRES(isSameType(lhs_index, output_index));
133 void OperationValidator::visit(const operation::BatchToSpaceND &node)
135 const auto input_index{node.getInputs().at(operation::BatchToSpaceND::Input::INPUT)};
136 const auto output_index{node.getOutputs().at(0)};
138 OP_REQUIRES(isSameType(input_index, output_index));
141 void OperationValidator::visit(const operation::BinaryArithmetic &node)
143 const auto output_index{node.getOutputs().at(0)};
144 const auto lhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::LHS)};
145 const auto rhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::RHS)};
147 OP_REQUIRES(isSameType(lhs_index, rhs_index));
148 OP_REQUIRES(isSameType(lhs_index, output_index));
151 void OperationValidator::visit(const operation::Comparison &node)
153 const auto output_index{node.getOutputs().at(0)};
155 const auto lhs_index{node.getInputs().at(operation::Comparison::Input::INPUT0)};
156 const auto rhs_index{node.getInputs().at(operation::Comparison::Input::INPUT1)};
158 OP_REQUIRES(isSameType(lhs_index, rhs_index));
159 OP_REQUIRES(isValidType(output_index, DataType::BOOL8));
162 void OperationValidator::visit(const operation::Concat &node)
164 const auto output_index{node.getOutputs().at(0)};
166 for (auto input_index : node.getInputs())
168 OP_REQUIRES(isSameType(input_index, output_index));
170 // Int8 quantization requires same scale and zero point
171 if (isValidType(output_index, DataType::QUANT_INT8_ASYMM))
173 OP_REQUIRES(isSameQuantParam(input_index, output_index));
178 void OperationValidator::visit(const operation::Conv2D &node)
180 const auto input_index{node.getInputs().at(operation::Conv2D::Input::INPUT)};
181 const auto kernel_index{node.getInputs().at(operation::Conv2D::Input::KERNEL)};
182 const auto output_index{node.getOutputs().at(0)};
184 uint32_t stride_horizontal = node.param().stride.horizontal;
185 uint32_t stride_vertical = node.param().stride.vertical;
186 uint32_t dilation_width = node.param().dilation.width_factor;
187 uint32_t dilation_height = node.param().dilation.height_factor;
189 OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
190 OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
191 OP_REQUIRES(isSameType(input_index, output_index));
193 if (isConstant(kernel_index) && operandType(kernel_index) == DataType::QUANT_INT8_ASYMM)
195 for (const auto zeropoint : _operands.at(kernel_index).typeInfo().zero_points())
196 OP_REQUIRES(zeropoint == 0);
200 void OperationValidator::visit(const operation::DepthToSpace &node)
202 const auto input_index{node.getInputs().at(operation::DepthToSpace::Input::INPUT)};
203 const auto output_index{node.getOutputs().at(0)};
205 int32_t block_size = node.param().block_size;
207 OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::INT64,
208 DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
209 OP_REQUIRES(isSameType(input_index, output_index));
211 OP_REQUIRES(block_size > 0);
214 void OperationValidator::visit(const operation::DetectionPostProcess &node)
216 auto param = node.param();
218 // FIXME: number of classes should be 1 for now.
219 OP_REQUIRES(param.num_classes == 1);
222 void OperationValidator::visit(const operation::DepthwiseConv2D &node)
224 const auto input_index{node.getInputs().at(operation::DepthwiseConv2D::Input::INPUT)};
225 const auto kernel_index{node.getInputs().at(operation::DepthwiseConv2D::Input::KERNEL)};
226 const auto output_index{node.getOutputs().at(0)};
228 uint32_t stride_horizontal = node.param().stride.horizontal;
229 uint32_t stride_vertical = node.param().stride.vertical;
230 uint32_t dilation_width = node.param().dilation.width_factor;
231 uint32_t dilation_height = node.param().dilation.height_factor;
233 OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
234 OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
235 OP_REQUIRES(isSameType(input_index, output_index));
237 if (isConstant(kernel_index) && operandType(kernel_index) == DataType::QUANT_INT8_ASYMM)
239 for (const auto zeropoint : _operands.at(kernel_index).typeInfo().zero_points())
240 OP_REQUIRES(zeropoint == 0);
244 void OperationValidator::visit(const operation::ElementwiseActivation &node)
246 const auto output_index{node.getOutputs().at(0)};
247 const auto input_index{node.getInputs().at(0)};
249 // Check if I/O types match
250 OP_REQUIRES(isSameType(output_index, input_index));
252 switch (node.param().op_type)
254 case operation::ElementwiseActivation::Type::ELU:
255 OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
257 case operation::ElementwiseActivation::Type::LEAKY_RELU:
259 isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
260 DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
262 case operation::ElementwiseActivation::Type::LOGISTIC:
264 isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
265 DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
267 case operation::ElementwiseActivation::Type::RELU:
268 OP_REQUIRES(isValidType(
269 input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
271 case operation::ElementwiseActivation::Type::TANH:
273 isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
274 DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
279 void OperationValidator::visit(const operation::ElementwiseBinary &node)
281 const auto output_index{node.getOutputs().at(0)};
282 const auto lhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::LHS)};
283 const auto rhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::RHS)};
285 OP_REQUIRES(isSameType(lhs_index, rhs_index));
286 OP_REQUIRES(isSameType(lhs_index, output_index));
288 const auto op_type = node.param().op_type;
289 if (op_type == operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_AND ||
290 op_type == operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_OR)
292 OP_REQUIRES(isValidType(lhs_index, DataType::BOOL8));
296 void OperationValidator::visit(const operation::ElementwiseUnary &node)
298 const auto output_index{node.getOutputs().at(0)};
299 const auto input_index{node.getInputs().at(operation::ElementwiseUnary::Input::INPUT)};
301 // Check if I/O types match
302 if (node.param().op_type == operation::ElementwiseUnary::Type::DEQUANTIZE)
304 // NNAPI allow QUANT_INT8_SYMM type input
305 OP_REQUIRES(isValidType(input_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_SYMM,
306 DataType::QUANT_INT8_ASYMM}));
307 OP_REQUIRES(isValidType(output_index, DataType::FLOAT32));
309 else if (node.param().op_type == operation::ElementwiseUnary::Type::QUANTIZE)
311 OP_REQUIRES(isValidType(
312 input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
314 isValidType(output_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
316 else if (node.param().op_type == operation::ElementwiseUnary::Type::FLOOR)
318 OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
319 OP_REQUIRES(isSameType(output_index, input_index));
321 else if (node.param().op_type != operation::ElementwiseUnary::Type::CAST)
323 OP_REQUIRES(isSameType(output_index, input_index));
327 void OperationValidator::visit(const operation::EmbeddingLookup &node)
329 const auto lookups_index{node.getInputs().at(operation::EmbeddingLookup::Input::LOOKUPS)};
330 const auto values_index{node.getInputs().at(operation::EmbeddingLookup::Input::VALUES)};
331 const auto output_index{node.getOutputs().at(0)};
333 OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
335 // TFLite: Allow hybrid type - value table & output
336 // NNAPI: Require same value table and output type
338 isSameType(values_index, output_index) ||
339 (isValidType(output_index, DataType::FLOAT32) &&
340 (isValidType(values_index, {DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT8_SYMM}))));
343 void OperationValidator::visit(const operation::ExpandDims &node)
345 const auto output_index{node.getOutputs().at(0)};
346 const auto input_index{node.getInputs().at(operation::ExpandDims::Input::INPUT)};
347 const auto axis_index{node.getInputs().at(operation::ExpandDims::Input::AXIS)};
349 OP_REQUIRES(isSameType(output_index, input_index));
350 OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
353 void OperationValidator::visit(const operation::Fill &node)
355 const auto output_index{node.getOutputs().at(0)};
356 const auto input_index{node.getInputs().at(operation::Fill::Input::SHAPE)};
357 const auto value_index{node.getInputs().at(operation::Fill::Input::VALUE)};
359 OP_REQUIRES(isSameType(output_index, value_index));
360 OP_REQUIRES(isValidType(input_index, {DataType::INT32, DataType::INT64}));
361 OP_REQUIRES(isValidType(output_index,
362 {DataType::FLOAT32, DataType::INT32, DataType::INT64, DataType::BOOL8}));
365 void OperationValidator::visit(const operation::HashtableLookup &node)
367 const auto hits_index{node.getOutputs().at(operation::HashtableLookup::Output::HITS)};
368 const auto lookups_index{node.getInputs().at(operation::HashtableLookup::Input::LOOKUPS)};
369 const auto keys_index{node.getInputs().at(operation::HashtableLookup::Input::KEYS)};
371 OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
372 OP_REQUIRES(isValidType(keys_index, DataType::INT32));
373 OP_REQUIRES(isValidType(hits_index, DataType::QUANT_UINT8_ASYMM));
376 void OperationValidator::visit(const operation::Pack &node)
378 const auto num{node.param().num};
380 OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size()));
383 void OperationValidator::visit(const operation::Pad &node)
385 const auto output_index{node.getOutputs().at(0)};
386 const auto input_index{node.getInputs().at(operation::Pad::Input::INPUT)};
387 const auto pad_index{node.getInputs().at(operation::Pad::Input::PAD)};
389 isValidType(output_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM});
390 bool isPadV2 = node.getInputs().size() == 3 ? true : false;
392 OP_REQUIRES(isValidType(pad_index, DataType::INT32));
393 OP_REQUIRES(isSameType(input_index, output_index));
396 OP_REQUIRES(isSameQuantParam(input_index, output_index));
400 const auto value_index{node.getInputs().at(operation::Pad::Input::VALUE)};
401 const bool cond_same = isSameType(input_index, value_index);
402 const bool cond_same_quant = (!isQuantType || isSameQuantParam(input_index, value_index));
403 const auto input_t = operandType(input_index);
404 const auto value_t = operandType(value_index);
405 // NNAPI accepts this case. scale and zeroPoint are assumed to be the same as in input0.
406 const bool cond_quant8 =
407 ((input_t == DataType::QUANT_UINT8_ASYMM || input_t == DataType::QUANT_INT8_ASYMM) &&
408 value_t == DataType::INT32);
409 OP_REQUIRES((cond_same && cond_same_quant) || cond_quant8);
413 void OperationValidator::visit(const operation::Rank &node)
415 const auto output_index{node.getOutputs().at(0)};
417 OP_REQUIRES(isValidType(output_index, DataType::INT32));
420 void OperationValidator::visit(const operation::ResizeBilinear &node)
422 auto align_corners = node.param().align_corners;
423 auto half_pixel_centers = node.param().half_pixel_centers;
425 OP_REQUIRES(!align_corners || !half_pixel_centers);
428 void OperationValidator::visit(const operation::Reverse &node)
430 const auto output_index{node.getOutputs().at(0)};
431 const auto input_index{node.getInputs().at(operation::Reverse::Input::INPUT)};
432 const auto axis_index{node.getInputs().at(operation::Reverse::Input::AXIS)};
434 OP_REQUIRES(isValidType(axis_index, DataType::INT32));
435 OP_REQUIRES(isSameType(output_index, input_index));
438 void OperationValidator::visit(const operation::Select &node)
440 const auto condition_index{node.getInputs().at(operation::Select::Input::CONDITION)};
441 const auto input_true_index{node.getInputs().at(operation::Select::Input::INPUT_TRUE)};
442 const auto input_false_index{node.getInputs().at(operation::Select::Input::INPUT_FALSE)};
444 OP_REQUIRES(isValidType(condition_index, DataType::BOOL8));
445 OP_REQUIRES(isSameType(input_true_index, input_false_index));
448 void OperationValidator::visit(const operation::Shape &node)
450 const auto output_index{node.getOutputs().at(0)};
452 OP_REQUIRES(isValidType(output_index, {DataType::UINT32, DataType::INT32, DataType::INT64}));
455 void OperationValidator::visit(const operation::Slice &node)
457 const auto begins_index{node.getInputs().at(operation::Slice::BEGINS)};
458 const auto sizes_index{node.getInputs().at(operation::Slice::SIZES)};
460 OP_REQUIRES(isValidType(begins_index, {DataType::INT32, DataType::INT64}));
461 OP_REQUIRES(isSameType(begins_index, sizes_index));
464 void OperationValidator::visit(const operation::Softmax &node)
466 const auto output_index{node.getOutputs().at(0)};
467 const auto input_index{node.getInputs().at(operation::Softmax::INPUT)};
469 OP_REQUIRES(isSameType(input_index, output_index));
470 OP_REQUIRES(isValidType(
471 output_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
474 void OperationValidator::visit(const operation::SpaceToBatchND &node)
476 const auto block_size_index{node.getInputs().at(operation::SpaceToBatchND::Input::BLOCK_SIZE)};
477 const auto paddings_index{node.getInputs().at(operation::SpaceToBatchND::Input::PADDINGS)};
479 // Non-constant block_size and padding is not implemented yet
480 OP_REQUIRES(isConstant(block_size_index));
481 OP_REQUIRES(isConstant(paddings_index));
484 void OperationValidator::visit(const operation::SpaceToDepth &node)
486 const auto block_size = node.param().block_size;
487 OP_REQUIRES(block_size >= 1);
490 void OperationValidator::visit(const operation::Split &node)
492 const auto num_splits = node.param().num_splits;
494 OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF);
495 OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits));
498 void OperationValidator::visit(const operation::SquaredDifference &node)
500 const auto output_index{node.getOutputs().at(0)};
501 const auto lhs_index{node.getInputs().at(operation::SquaredDifference::Input::LHS)};
502 const auto rhs_index{node.getInputs().at(operation::SquaredDifference::Input::RHS)};
504 OP_REQUIRES(isSameType(output_index, lhs_index));
505 OP_REQUIRES(isSameType(lhs_index, rhs_index));
508 void OperationValidator::visit(const operation::StatelessRandomUniform &node)
510 const auto output_index{node.getOutputs().at(0)};
511 const auto shape_index{node.getInputs().at(operation::StatelessRandomUniform::Input::SHAPE)};
512 const auto seed_index{node.getInputs().at(operation::StatelessRandomUniform::Input::SEED)};
514 OP_REQUIRES(isValidType(output_index, DataType::FLOAT32));
515 OP_REQUIRES(isValidType(shape_index, DataType::INT32));
516 OP_REQUIRES(isValidType(seed_index, DataType::INT32));
519 void OperationValidator::visit(const operation::StridedSlice &node)
521 const auto output_index{node.getOutputs().at(0)};
522 const auto input_index{node.getInputs().at(operation::StridedSlice::Input::INPUT)};
524 OP_REQUIRES(isSameType(output_index, input_index));
527 void OperationValidator::visit(const operation::TransposeConv &node)
529 OP_REQUIRES((node.param().padding.type == PaddingType::SAME) ||
530 (node.param().padding.type == PaddingType::VALID));
533 void OperationValidator::visit(const operation::Unpack &node)
535 const auto num{node.param().num};
536 OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size()));
539 void OperationValidator::visit(const operation::While &node)
541 OP_REQUIRES(node.getInputs().size() == node.getOutputs().size());