da08e81fc8fce04fd5a100a042f252a751a89ff0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / OperationValidator.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "OperationValidator.h"
18
19 #include "ir/Graph.h"
20
21 #define OP_REQUIRES(EXP)                                                                         \
22   do                                                                                             \
23   {                                                                                              \
24     if (!(EXP))                                                                                  \
25       throw std::runtime_error("OperationValidator failed at line " + std::to_string(__LINE__)); \
26   } while (0)
27
28 namespace onert
29 {
30 namespace ir
31 {
32
33 OperationValidator::OperationValidator(const Graph &graph)
34     : _operations{graph.operations()}, _operands{graph.operands()}
35 {
36 }
37
38 void OperationValidator::operator()()
39 {
40   _operations.iterate([&](const OperationIndex &, const Operation &node) { node.accept(*this); });
41 }
42
43 DataType OperationValidator::operandType(const OperandIndex &idx)
44 {
45   return _operands.at(idx).typeInfo().type();
46 }
47
48 bool OperationValidator::isConstant(const OperandIndex &idx)
49 {
50   return _operands.at(idx).isConstant();
51 }
52
53 bool OperationValidator::isSameType(const OperandIndex &idx1, const OperandIndex &idx2)
54 {
55   return operandType(idx1) == operandType(idx2);
56 }
57
58 bool OperationValidator::isValidType(const OperandIndex &idx, const DataType &type)
59 {
60   return operandType(idx) == type;
61 }
62
63 bool OperationValidator::isValidType(const OperandIndex &idx,
64                                      std::initializer_list<DataType> valid_types)
65 {
66   for (auto type_to_check : valid_types)
67   {
68     if (isValidType(idx, type_to_check))
69     {
70       return true;
71     }
72   }
73
74   return false;
75 }
76
77 void OperationValidator::visit(const operation::AddN &node)
78 {
79   int size = node.getInputs().size();
80   for (int i = 0; i < size; i++)
81   {
82     const auto input_index(node.getInputs().at(i));
83     OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32}));
84   }
85 }
86
87 void OperationValidator::visit(const operation::BatchMatMul &node)
88 {
89   const auto lhs_index(node.getInputs().at(operation::BatchMatMul::Input::LHS));
90   const auto rhs_index(node.getInputs().at(operation::BatchMatMul::Input::RHS));
91
92   // Constant lhs and rhs is not implemented yet
93   OP_REQUIRES(!isConstant(lhs_index) && !isConstant(rhs_index));
94 }
95
96 void OperationValidator::visit(const operation::BatchToSpaceND &node)
97 {
98   const auto block_size_index{node.getInputs().at(operation::BatchToSpaceND::Input::BLOCK_SIZE)};
99
100   // Non-constant block_size is not implemented yet
101   OP_REQUIRES(isConstant(block_size_index));
102 }
103
104 void OperationValidator::visit(const operation::BinaryArithmetic &node)
105 {
106   const auto output_index{node.getOutputs().at(0)};
107   const auto lhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::LHS)};
108   const auto rhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::RHS)};
109
110   OP_REQUIRES(isSameType(lhs_index, rhs_index));
111   OP_REQUIRES(isSameType(lhs_index, output_index));
112 }
113
114 void OperationValidator::visit(const operation::Comparison &node)
115 {
116   const auto output_index{node.getOutputs().at(0)};
117
118   const auto lhs_index{node.getInputs().at(operation::Comparison::Input::INPUT0)};
119   const auto rhs_index{node.getInputs().at(operation::Comparison::Input::INPUT1)};
120
121   OP_REQUIRES(isSameType(lhs_index, rhs_index));
122   OP_REQUIRES(isValidType(output_index, DataType::BOOL8));
123 }
124
125 void OperationValidator::visit(const operation::DepthToSpace &node)
126 {
127   int32_t block_size = node.param().block_size;
128
129   OP_REQUIRES(block_size > 0);
130 }
131
132 void OperationValidator::visit(const operation::DepthwiseConv2D &node)
133 {
134   const auto input_index{node.getInputs().at(operation::DepthwiseConv2D::Input::INPUT)};
135   const auto output_index{node.getOutputs().at(0)};
136
137   uint32_t stride_horizontal = node.param().stride.horizontal;
138   uint32_t stride_vertical = node.param().stride.vertical;
139   uint32_t dilation_width = node.param().dilation.width_factor;
140   uint32_t dilation_height = node.param().dilation.height_factor;
141
142   OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
143   OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
144   OP_REQUIRES(isSameType(input_index, output_index));
145 }
146
147 void OperationValidator::visit(const operation::ElementwiseActivation &node)
148 {
149   const auto output_index{node.getOutputs().at(0)};
150   const auto input_index{node.getInputs().at(0)};
151
152   // Check if I/O types match
153   OP_REQUIRES(isSameType(output_index, input_index));
154 }
155
156 void OperationValidator::visit(const operation::ElementwiseBinary &node)
157 {
158   const auto output_index{node.getOutputs().at(0)};
159   const auto lhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::LHS)};
160   const auto rhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::RHS)};
161
162   OP_REQUIRES(isSameType(lhs_index, rhs_index));
163   OP_REQUIRES(isSameType(lhs_index, output_index));
164 }
165
166 void OperationValidator::visit(const operation::ElementwiseUnary &node)
167 {
168   const auto output_index{node.getOutputs().at(0)};
169   const auto input_index{node.getInputs().at(operation::ElementwiseUnary::Input::INPUT)};
170
171   // Check if I/O types match
172   if (node.param().op_type == operation::ElementwiseUnary::Type::DEQUANTIZE)
173   {
174     // NNAPI allow QUANT_INT8_SYMM type input
175     OP_REQUIRES(isValidType(input_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_SYMM,
176                                           DataType::QUANT_INT8_ASYMM}));
177     OP_REQUIRES(isValidType(output_index, DataType::FLOAT32));
178   }
179   else if (node.param().op_type == operation::ElementwiseUnary::Type::QUANTIZE)
180   {
181     OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
182     OP_REQUIRES(isValidType(output_index, DataType::QUANT_UINT8_ASYMM));
183   }
184   else if (node.param().op_type == operation::ElementwiseUnary::Type::FLOOR)
185   {
186     OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
187     OP_REQUIRES(isSameType(output_index, input_index));
188   }
189   else if (node.param().op_type != operation::ElementwiseUnary::Type::CAST)
190   {
191     OP_REQUIRES(isSameType(output_index, input_index));
192   }
193 }
194
195 void OperationValidator::visit(const operation::EmbeddingLookup &node)
196 {
197   const auto lookups_index{node.getInputs().at(operation::EmbeddingLookup::Input::LOOKUPS)};
198
199   OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
200 }
201
202 void OperationValidator::visit(const operation::ExpandDims &node)
203 {
204   const auto output_index{node.getOutputs().at(0)};
205   const auto input_index{node.getInputs().at(operation::ExpandDims::Input::INPUT)};
206   const auto axis_index{node.getInputs().at(operation::ExpandDims::Input::AXIS)};
207
208   OP_REQUIRES(isSameType(output_index, input_index));
209   OP_REQUIRES(isValidType(axis_index, DataType::INT32));
210 }
211
212 void OperationValidator::visit(const operation::HashtableLookup &node)
213 {
214   const auto hits_index{node.getOutputs().at(operation::HashtableLookup::Output::HITS)};
215   const auto lookups_index{node.getInputs().at(operation::HashtableLookup::Input::LOOKUPS)};
216   const auto keys_index{node.getInputs().at(operation::HashtableLookup::Input::KEYS)};
217
218   OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
219   OP_REQUIRES(isValidType(keys_index, DataType::INT32));
220   OP_REQUIRES(isValidType(hits_index, DataType::QUANT_UINT8_ASYMM));
221 }
222
223 void OperationValidator::visit(const operation::Pack &node)
224 {
225   const auto num{node.param().num};
226
227   OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size()));
228 }
229
230 void OperationValidator::visit(const operation::Pad &node)
231 {
232   const auto pad_index{node.getInputs().at(operation::Pad::Input::PAD)};
233
234   OP_REQUIRES(isValidType(pad_index, DataType::INT32));
235 }
236
237 void OperationValidator::visit(const operation::Rank &node)
238 {
239   const auto output_index{node.getOutputs().at(0)};
240
241   OP_REQUIRES(isValidType(output_index, DataType::INT32));
242 }
243
244 void OperationValidator::visit(const operation::ResizeBilinear &node)
245 {
246   auto align_corners = node.param().align_corners;
247   auto half_pixel_centers = node.param().half_pixel_centers;
248
249   OP_REQUIRES(!align_corners || !half_pixel_centers);
250 }
251
252 void OperationValidator::visit(const operation::Reverse &node)
253 {
254   const auto output_index{node.getOutputs().at(0)};
255   const auto input_index{node.getInputs().at(operation::Reverse::Input::INPUT)};
256   const auto axis_index{node.getInputs().at(operation::Reverse::Input::AXIS)};
257
258   OP_REQUIRES(isValidType(axis_index, DataType::INT32));
259   OP_REQUIRES(isSameType(output_index, input_index));
260 }
261
262 void OperationValidator::visit(const operation::Select &node)
263 {
264   const auto condition_index{node.getInputs().at(operation::Select::Input::CONDITION)};
265   const auto input_true_index{node.getInputs().at(operation::Select::Input::INPUT_TRUE)};
266   const auto input_false_index{node.getInputs().at(operation::Select::Input::INPUT_FALSE)};
267
268   OP_REQUIRES(isValidType(condition_index, DataType::BOOL8));
269   OP_REQUIRES(isSameType(input_true_index, input_false_index));
270 }
271
272 void OperationValidator::visit(const operation::Shape &node)
273 {
274   const auto output_index{node.getOutputs().at(0)};
275
276   OP_REQUIRES(isValidType(output_index, {DataType::UINT32, DataType::INT32, DataType::INT64}));
277 }
278
279 void OperationValidator::visit(const operation::SpaceToBatchND &node)
280 {
281   const auto block_size_index{node.getInputs().at(operation::SpaceToBatchND::Input::BLOCK_SIZE)};
282   const auto paddings_index{node.getInputs().at(operation::SpaceToBatchND::Input::PADDINGS)};
283
284   // Non-constant block_size and padding is not implemented yet
285   OP_REQUIRES(isConstant(block_size_index));
286   OP_REQUIRES(isConstant(paddings_index));
287 }
288
289 void OperationValidator::visit(const operation::SpaceToDepth &node)
290 {
291   const auto block_size = node.param().block_size;
292   OP_REQUIRES(block_size >= 1);
293 }
294
295 void OperationValidator::visit(const operation::Split &node)
296 {
297   const auto num_splits = node.param().num_splits;
298
299   OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF);
300   OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits));
301 }
302
303 void OperationValidator::visit(const operation::SquaredDifference &node)
304 {
305   const auto output_index{node.getOutputs().at(0)};
306   const auto lhs_index{node.getInputs().at(operation::SquaredDifference::Input::LHS)};
307   const auto rhs_index{node.getInputs().at(operation::SquaredDifference::Input::RHS)};
308
309   OP_REQUIRES(isSameType(output_index, lhs_index));
310   OP_REQUIRES(isSameType(lhs_index, rhs_index));
311 }
312
313 void OperationValidator::visit(const operation::StridedSlice &node)
314 {
315   const auto output_index{node.getOutputs().at(0)};
316   const auto input_index{node.getInputs().at(operation::StridedSlice::Input::INPUT)};
317
318   OP_REQUIRES(isSameType(output_index, input_index));
319 }
320
321 void OperationValidator::visit(const operation::TransposeConv &node)
322 {
323   OP_REQUIRES((node.param().padding.type == PaddingType::SAME) ||
324               (node.param().padding.type == PaddingType::VALID));
325 }
326
327 void OperationValidator::visit(const operation::Unpack &node)
328 {
329   const auto num{node.param().num};
330   OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size()));
331 }
332
333 void OperationValidator::visit(const operation::While &node)
334 {
335   OP_REQUIRES(node.getInputs().size() == node.getOutputs().size());
336 }
337
338 } // namespace compiler
339 } // namespace onert