/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "OperationValidator.h"
// Validation assertion used by every visit() below: when EXP evaluates to
// false, throws std::runtime_error whose message names the source line of the
// failing check. The do/while(0) wrapper makes the macro a single statement,
// so it is safe inside an un-braced if/else.
#define OP_REQUIRES(EXP)                                                                         \
  do                                                                                             \
  {                                                                                              \
    if (!(EXP))                                                                                  \
      throw std::runtime_error("OperationValidator failed at line " + std::to_string(__LINE__)); \
  } while (0)
33 OperationValidator::OperationValidator(const Graph &graph)
34 : _operations{graph.operations()}, _operands{graph.operands()}
38 void OperationValidator::operator()()
40 _operations.iterate([&](const OperationIndex &, const Operation &node) { node.accept(*this); });
43 DataType OperationValidator::operandType(const OperandIndex &idx)
45 return _operands.at(idx).typeInfo().type();
48 bool OperationValidator::isConstant(const OperandIndex &idx)
50 return _operands.at(idx).isConstant();
53 bool OperationValidator::isSameType(const OperandIndex &idx1, const OperandIndex &idx2)
55 return operandType(idx1) == operandType(idx2);
58 bool OperationValidator::isSameQuantParam(const OperandIndex &idx1, const OperandIndex &idx2)
60 if (_operands.at(idx1).typeInfo().scale() != _operands.at(idx2).typeInfo().scale())
63 if (_operands.at(idx1).typeInfo().offset() != _operands.at(idx2).typeInfo().offset())
69 bool OperationValidator::isValidType(const OperandIndex &idx, const DataType &type)
71 return operandType(idx) == type;
74 bool OperationValidator::isValidType(const OperandIndex &idx,
75 std::initializer_list<DataType> valid_types)
77 for (auto type_to_check : valid_types)
79 if (isValidType(idx, type_to_check))
88 void OperationValidator::visit(const operation::AddN &node)
90 const auto output_index(node.getOutputs().at(0));
92 int size = node.getInputs().size();
93 for (int i = 0; i < size; i++)
95 const auto input_index(node.getInputs().at(i));
96 OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32}));
97 OP_REQUIRES(isSameType(input_index, output_index));
101 void OperationValidator::visit(const operation::ArgMinMax &node)
103 const auto input_index(node.getInputs().at(operation::ArgMinMax::Input::INPUT));
104 const auto axis_index(node.getInputs().at(operation::ArgMinMax::Input::AXIS));
105 const auto output_index(node.getOutputs().at(0));
106 const auto output_type = node.param().output_type;
108 OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::UINT8,
109 DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
110 OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
111 OP_REQUIRES(isValidType(output_index, {DataType::INT32, DataType::INT64}));
112 OP_REQUIRES(isValidType(output_index, output_type));
115 void OperationValidator::visit(const operation::BatchMatMul &node)
117 const auto lhs_index(node.getInputs().at(operation::BatchMatMul::Input::LHS));
118 const auto rhs_index(node.getInputs().at(operation::BatchMatMul::Input::RHS));
119 const auto output_index(node.getOutputs().at(0));
121 // Constant lhs and rhs is not implemented yet
122 OP_REQUIRES(!isConstant(lhs_index) && !isConstant(rhs_index));
124 // Allow hybrid quantization (lhs: float / rhs: qint8 / out: float)
125 OP_REQUIRES(isValidType(lhs_index, {DataType::FLOAT32, DataType::QUANT_INT8_ASYMM}));
126 OP_REQUIRES(isSameType(lhs_index, rhs_index) ||
127 ((operandType(lhs_index) == DataType::FLOAT32) &&
128 (operandType(rhs_index) == DataType::QUANT_INT8_ASYMM)));
129 OP_REQUIRES(isSameType(lhs_index, output_index));
132 void OperationValidator::visit(const operation::BatchToSpaceND &node)
134 const auto input_index{node.getInputs().at(operation::BatchToSpaceND::Input::INPUT)};
135 const auto output_index{node.getOutputs().at(0)};
137 OP_REQUIRES(isSameType(input_index, output_index));
140 void OperationValidator::visit(const operation::BinaryArithmetic &node)
142 const auto output_index{node.getOutputs().at(0)};
143 const auto lhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::LHS)};
144 const auto rhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::RHS)};
146 OP_REQUIRES(isSameType(lhs_index, rhs_index));
147 OP_REQUIRES(isSameType(lhs_index, output_index));
150 void OperationValidator::visit(const operation::Comparison &node)
152 const auto output_index{node.getOutputs().at(0)};
154 const auto lhs_index{node.getInputs().at(operation::Comparison::Input::INPUT0)};
155 const auto rhs_index{node.getInputs().at(operation::Comparison::Input::INPUT1)};
157 OP_REQUIRES(isSameType(lhs_index, rhs_index));
158 OP_REQUIRES(isValidType(output_index, DataType::BOOL8));
161 void OperationValidator::visit(const operation::Concat &node)
163 const auto output_index{node.getOutputs().at(0)};
165 for (auto input_index : node.getInputs())
167 OP_REQUIRES(isSameType(input_index, output_index));
169 // Int8 quantization requires same scale and zero point
170 if (isValidType(output_index, DataType::QUANT_INT8_ASYMM))
172 OP_REQUIRES(isSameQuantParam(input_index, output_index));
177 void OperationValidator::visit(const operation::Conv2D &node)
179 const auto input_index{node.getInputs().at(operation::Conv2D::Input::INPUT)};
180 const auto output_index{node.getOutputs().at(0)};
182 uint32_t stride_horizontal = node.param().stride.horizontal;
183 uint32_t stride_vertical = node.param().stride.vertical;
184 uint32_t dilation_width = node.param().dilation.width_factor;
185 uint32_t dilation_height = node.param().dilation.height_factor;
187 OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
188 OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
189 OP_REQUIRES(isSameType(input_index, output_index));
192 void OperationValidator::visit(const operation::DepthToSpace &node)
194 const auto input_index{node.getInputs().at(operation::DepthToSpace::Input::INPUT)};
195 const auto output_index{node.getOutputs().at(0)};
197 int32_t block_size = node.param().block_size;
199 OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::INT64,
200 DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
201 OP_REQUIRES(isSameType(input_index, output_index));
203 OP_REQUIRES(block_size > 0);
206 void OperationValidator::visit(const operation::DepthwiseConv2D &node)
208 const auto input_index{node.getInputs().at(operation::DepthwiseConv2D::Input::INPUT)};
209 const auto output_index{node.getOutputs().at(0)};
211 uint32_t stride_horizontal = node.param().stride.horizontal;
212 uint32_t stride_vertical = node.param().stride.vertical;
213 uint32_t dilation_width = node.param().dilation.width_factor;
214 uint32_t dilation_height = node.param().dilation.height_factor;
216 OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
217 OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
218 OP_REQUIRES(isSameType(input_index, output_index));
221 void OperationValidator::visit(const operation::ElementwiseActivation &node)
223 const auto output_index{node.getOutputs().at(0)};
224 const auto input_index{node.getInputs().at(0)};
226 // Check if I/O types match
227 OP_REQUIRES(isSameType(output_index, input_index));
229 switch (node.param().op_type)
231 case operation::ElementwiseActivation::Type::ELU:
232 OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
234 case operation::ElementwiseActivation::Type::LEAKY_RELU:
236 isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
237 DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
239 case operation::ElementwiseActivation::Type::LOGISTIC:
241 isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
242 DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
244 case operation::ElementwiseActivation::Type::RELU:
245 OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
246 DataType::QUANT_INT8_ASYMM}));
248 case operation::ElementwiseActivation::Type::TANH:
250 isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
251 DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
256 void OperationValidator::visit(const operation::ElementwiseBinary &node)
258 const auto output_index{node.getOutputs().at(0)};
259 const auto lhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::LHS)};
260 const auto rhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::RHS)};
262 OP_REQUIRES(isSameType(lhs_index, rhs_index));
263 OP_REQUIRES(isSameType(lhs_index, output_index));
265 const auto op_type = node.param().op_type;
266 if (op_type == operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_AND ||
267 op_type == operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_OR)
269 OP_REQUIRES(isValidType(lhs_index, DataType::BOOL8));
273 void OperationValidator::visit(const operation::ElementwiseUnary &node)
275 const auto output_index{node.getOutputs().at(0)};
276 const auto input_index{node.getInputs().at(operation::ElementwiseUnary::Input::INPUT)};
278 // Check if I/O types match
279 if (node.param().op_type == operation::ElementwiseUnary::Type::DEQUANTIZE)
281 // NNAPI allow QUANT_INT8_SYMM type input
282 OP_REQUIRES(isValidType(input_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_SYMM,
283 DataType::QUANT_INT8_ASYMM}));
284 OP_REQUIRES(isValidType(output_index, DataType::FLOAT32));
286 else if (node.param().op_type == operation::ElementwiseUnary::Type::QUANTIZE)
288 OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
289 OP_REQUIRES(isValidType(output_index, DataType::QUANT_UINT8_ASYMM));
291 else if (node.param().op_type == operation::ElementwiseUnary::Type::FLOOR)
293 OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
294 OP_REQUIRES(isSameType(output_index, input_index));
296 else if (node.param().op_type != operation::ElementwiseUnary::Type::CAST)
298 OP_REQUIRES(isSameType(output_index, input_index));
302 void OperationValidator::visit(const operation::EmbeddingLookup &node)
304 const auto lookups_index{node.getInputs().at(operation::EmbeddingLookup::Input::LOOKUPS)};
305 const auto values_index{node.getInputs().at(operation::EmbeddingLookup::Input::VALUES)};
306 const auto output_index{node.getOutputs().at(0)};
308 OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
310 // TFLite: Allow hybrid type - value table & output
311 // NNAPI: Require same value table and output type
313 isSameType(values_index, output_index) ||
314 (isValidType(output_index, DataType::FLOAT32) &&
315 (isValidType(values_index, {DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT8_SYMM}))));
318 void OperationValidator::visit(const operation::ExpandDims &node)
320 const auto output_index{node.getOutputs().at(0)};
321 const auto input_index{node.getInputs().at(operation::ExpandDims::Input::INPUT)};
322 const auto axis_index{node.getInputs().at(operation::ExpandDims::Input::AXIS)};
324 OP_REQUIRES(isSameType(output_index, input_index));
325 OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
328 void OperationValidator::visit(const operation::Fill &node)
330 const auto output_index{node.getOutputs().at(0)};
331 const auto input_index{node.getInputs().at(operation::Fill::Input::SHAPE)};
332 const auto value_index{node.getInputs().at(operation::Fill::Input::VALUE)};
334 OP_REQUIRES(isSameType(output_index, value_index));
335 OP_REQUIRES(isValidType(input_index, {DataType::INT32, DataType::INT64}));
336 OP_REQUIRES(isValidType(output_index,
337 {DataType::FLOAT32, DataType::INT32, DataType::INT64, DataType::BOOL8}));
340 void OperationValidator::visit(const operation::HashtableLookup &node)
342 const auto hits_index{node.getOutputs().at(operation::HashtableLookup::Output::HITS)};
343 const auto lookups_index{node.getInputs().at(operation::HashtableLookup::Input::LOOKUPS)};
344 const auto keys_index{node.getInputs().at(operation::HashtableLookup::Input::KEYS)};
346 OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
347 OP_REQUIRES(isValidType(keys_index, DataType::INT32));
348 OP_REQUIRES(isValidType(hits_index, DataType::QUANT_UINT8_ASYMM));
351 void OperationValidator::visit(const operation::Pack &node)
353 const auto num{node.param().num};
355 OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size()));
358 void OperationValidator::visit(const operation::Pad &node)
360 const auto pad_index{node.getInputs().at(operation::Pad::Input::PAD)};
362 OP_REQUIRES(isValidType(pad_index, DataType::INT32));
365 void OperationValidator::visit(const operation::Rank &node)
367 const auto output_index{node.getOutputs().at(0)};
369 OP_REQUIRES(isValidType(output_index, DataType::INT32));
372 void OperationValidator::visit(const operation::ResizeBilinear &node)
374 auto align_corners = node.param().align_corners;
375 auto half_pixel_centers = node.param().half_pixel_centers;
377 OP_REQUIRES(!align_corners || !half_pixel_centers);
380 void OperationValidator::visit(const operation::Reverse &node)
382 const auto output_index{node.getOutputs().at(0)};
383 const auto input_index{node.getInputs().at(operation::Reverse::Input::INPUT)};
384 const auto axis_index{node.getInputs().at(operation::Reverse::Input::AXIS)};
386 OP_REQUIRES(isValidType(axis_index, DataType::INT32));
387 OP_REQUIRES(isSameType(output_index, input_index));
390 void OperationValidator::visit(const operation::Select &node)
392 const auto condition_index{node.getInputs().at(operation::Select::Input::CONDITION)};
393 const auto input_true_index{node.getInputs().at(operation::Select::Input::INPUT_TRUE)};
394 const auto input_false_index{node.getInputs().at(operation::Select::Input::INPUT_FALSE)};
396 OP_REQUIRES(isValidType(condition_index, DataType::BOOL8));
397 OP_REQUIRES(isSameType(input_true_index, input_false_index));
400 void OperationValidator::visit(const operation::Shape &node)
402 const auto output_index{node.getOutputs().at(0)};
404 OP_REQUIRES(isValidType(output_index, {DataType::UINT32, DataType::INT32, DataType::INT64}));
407 void OperationValidator::visit(const operation::SpaceToBatchND &node)
409 const auto block_size_index{node.getInputs().at(operation::SpaceToBatchND::Input::BLOCK_SIZE)};
410 const auto paddings_index{node.getInputs().at(operation::SpaceToBatchND::Input::PADDINGS)};
412 // Non-constant block_size and padding is not implemented yet
413 OP_REQUIRES(isConstant(block_size_index));
414 OP_REQUIRES(isConstant(paddings_index));
417 void OperationValidator::visit(const operation::SpaceToDepth &node)
419 const auto block_size = node.param().block_size;
420 OP_REQUIRES(block_size >= 1);
423 void OperationValidator::visit(const operation::Split &node)
425 const auto num_splits = node.param().num_splits;
427 OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF);
428 OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits));
431 void OperationValidator::visit(const operation::SquaredDifference &node)
433 const auto output_index{node.getOutputs().at(0)};
434 const auto lhs_index{node.getInputs().at(operation::SquaredDifference::Input::LHS)};
435 const auto rhs_index{node.getInputs().at(operation::SquaredDifference::Input::RHS)};
437 OP_REQUIRES(isSameType(output_index, lhs_index));
438 OP_REQUIRES(isSameType(lhs_index, rhs_index));
441 void OperationValidator::visit(const operation::StridedSlice &node)
443 const auto output_index{node.getOutputs().at(0)};
444 const auto input_index{node.getInputs().at(operation::StridedSlice::Input::INPUT)};
446 OP_REQUIRES(isSameType(output_index, input_index));
449 void OperationValidator::visit(const operation::TransposeConv &node)
451 OP_REQUIRES((node.param().padding.type == PaddingType::SAME) ||
452 (node.param().padding.type == PaddingType::VALID));
455 void OperationValidator::visit(const operation::Unpack &node)
457 const auto num{node.param().num};
458 OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size()));
461 void OperationValidator::visit(const operation::While &node)
463 OP_REQUIRES(node.getInputs().size() == node.getOutputs().size());
466 } // namespace compiler