1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include <initializer_list>
26 #include <ngraph/runtime/host_tensor.hpp>
27 #include "backend.hpp"
28 #include "int_backend_visibility.hpp"
29 #include "ngraph/ops.hpp"
30 #include "ngraph/runtime/aligned_buffer.hpp"
31 #include "ngraph/runtime/reference/abs.hpp"
32 #include "ngraph/runtime/reference/acos.hpp"
33 #include "ngraph/runtime/reference/any.hpp"
34 #include "ngraph/runtime/reference/asin.hpp"
35 #include "ngraph/runtime/reference/atan.hpp"
36 #include "ngraph/runtime/reference/atan2.hpp"
37 #include "ngraph/runtime/reference/avg_pool.hpp"
38 #include "ngraph/runtime/reference/batch_norm.hpp"
39 #include "ngraph/runtime/reference/broadcast.hpp"
40 #include "ngraph/runtime/reference/ceiling.hpp"
41 #include "ngraph/runtime/reference/concat.hpp"
42 #include "ngraph/runtime/reference/constant.hpp"
43 #include "ngraph/runtime/reference/convert.hpp"
44 #include "ngraph/runtime/reference/convolution.hpp"
45 #include "ngraph/runtime/reference/cos.hpp"
46 #include "ngraph/runtime/reference/cosh.hpp"
47 #include "ngraph/runtime/reference/ctc_loss.hpp"
48 #include "ngraph/runtime/reference/cum_sum.hpp"
49 #include "ngraph/runtime/reference/dequantize.hpp"
50 #include "ngraph/runtime/reference/detection_output.hpp"
51 #include "ngraph/runtime/reference/dot.hpp"
52 #include "ngraph/runtime/reference/elu.hpp"
53 #include "ngraph/runtime/reference/embedding_bag_offsets_sum.hpp"
54 #include "ngraph/runtime/reference/embedding_bag_packed_sum.hpp"
55 #include "ngraph/runtime/reference/embedding_segments_sum.hpp"
56 #include "ngraph/runtime/reference/erf.hpp"
57 #include "ngraph/runtime/reference/exp.hpp"
58 #include "ngraph/runtime/reference/extract_image_patches.hpp"
59 #include "ngraph/runtime/reference/floor.hpp"
60 #include "ngraph/runtime/reference/gather.hpp"
61 #include "ngraph/runtime/reference/gather_nd.hpp"
62 #include "ngraph/runtime/reference/log.hpp"
63 #include "ngraph/runtime/reference/lrn.hpp"
64 #include "ngraph/runtime/reference/matmul.hpp"
65 #include "ngraph/runtime/reference/max.hpp"
66 #include "ngraph/runtime/reference/max_pool.hpp"
67 #include "ngraph/runtime/reference/min.hpp"
68 #include "ngraph/runtime/reference/negate.hpp"
69 #include "ngraph/runtime/reference/not.hpp"
70 #include "ngraph/runtime/reference/one_hot.hpp"
71 #include "ngraph/runtime/reference/pad.hpp"
72 #include "ngraph/runtime/reference/product.hpp"
73 #include "ngraph/runtime/reference/quantize.hpp"
74 #include "ngraph/runtime/reference/relu.hpp"
75 #include "ngraph/runtime/reference/replace_slice.hpp"
76 #include "ngraph/runtime/reference/reshape.hpp"
77 #include "ngraph/runtime/reference/result.hpp"
78 #include "ngraph/runtime/reference/reverse.hpp"
79 #include "ngraph/runtime/reference/reverse_sequence.hpp"
80 #include "ngraph/runtime/reference/round.hpp"
81 #include "ngraph/runtime/reference/scatter_nd_update.hpp"
82 #include "ngraph/runtime/reference/scatter_update.hpp"
83 #include "ngraph/runtime/reference/select.hpp"
84 #include "ngraph/runtime/reference/sigmoid.hpp"
85 #include "ngraph/runtime/reference/sign.hpp"
86 #include "ngraph/runtime/reference/sin.hpp"
87 #include "ngraph/runtime/reference/sinh.hpp"
88 #include "ngraph/runtime/reference/softmax.hpp"
89 #include "ngraph/runtime/reference/sqrt.hpp"
90 #include "ngraph/runtime/reference/sum.hpp"
91 #include "ngraph/runtime/reference/tan.hpp"
92 #include "ngraph/runtime/reference/tanh.hpp"
93 #include "ngraph/runtime/reference/topk.hpp"
94 #include "ngraph/runtime/tensor.hpp"
95 #include "op/avg_pool.hpp"
96 #include "op/convolution.hpp"
97 #include "op/group_conv.hpp"
99 NGRAPH_SUPPRESS_DEPRECATED_START
105 namespace interpreter
110 // This expands the op list in op_tbl.hpp into a list of enumerations that look like
117 #define NGRAPH_OP(NAME, NAMESPACE) ID_SUFFIX(NAME),
118 #include "opset_int_tbl.hpp"
122 } // namespace interpreter
123 } // namespace runtime
124 } // namespace ngraph
126 class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : public Executable
128 friend class INTBackend;
131 INTExecutable(const std::shared_ptr<Function>& function,
132 bool enable_performance_collection = false);
134 bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
135 const std::vector<std::shared_ptr<Tensor>>& inputs) override;
137 void set_nan_check(bool enable);
139 std::vector<PerformanceCounter> get_performance_data() const override;
141 std::shared_ptr<runtime::Tensor> create_input_tensor(size_t input_index) override;
143 std::shared_ptr<runtime::Tensor> create_output_tensor(size_t output_index) override;
145 std::vector<std::shared_ptr<runtime::Tensor>>
146 create_input_tensor(size_t input_index, size_t pipeline_depth) override;
148 std::vector<std::shared_ptr<runtime::Tensor>>
149 create_output_tensor(size_t output_index, size_t pipeline_depth) override;
152 std::shared_ptr<ngraph::op::Parameter> get_parameter(size_t index) const;
153 std::shared_ptr<ngraph::op::Result> get_result(size_t index) const;
// Byte alignment used when allocating tensor buffers for this executable; fixed at 64.
int get_alignment() const { return 64; }
155 bool m_is_compiled = false;
156 bool m_nan_check_enabled = false;
157 bool m_performance_counters_enabled = false;
158 std::shared_ptr<Function> m_function;
159 std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
160 std::vector<std::shared_ptr<Node>> m_nodes;
161 std::set<std::string> m_unsupported_op_name_list;
163 static OP_TYPEID get_typeid(const Node& node);
165 static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
166 const Node* op = nullptr);
168 virtual void generate_calls(const element::Type& type,
170 const std::vector<std::shared_ptr<HostTensor>>& outputs,
171 const std::vector<std::shared_ptr<HostTensor>>& inputs);
173 template <typename T>
174 void op_engine(const Node& node,
175 const std::vector<std::shared_ptr<HostTensor>>& out,
176 const std::vector<std::shared_ptr<HostTensor>>& args)
178 // We want to check that every OP_TYPEID enumeration is included in the list.
179 // These GCC flags enable compile-time checking so that if an enumeration
180 // is not in the list an error is generated.
181 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
182 #pragma GCC diagnostic push
183 #pragma GCC diagnostic error "-Wswitch"
184 #pragma GCC diagnostic error "-Wswitch-enum"
186 switch (get_typeid(node))
190 size_t element_count = shape_size(node.get_output_shape(0));
192 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
195 case OP_TYPEID::Acos:
197 size_t element_count = shape_size(node.get_output_shape(0));
199 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
204 const op::Any* any = static_cast<const op::Any*>(&node);
205 reference::any(args[0]->get_data_ptr<const char>(),
206 out[0]->get_data_ptr<char>(),
207 node.get_input_shape(0),
208 any->get_reduction_axes(),
212 case OP_TYPEID::Asin:
214 size_t element_count = shape_size(node.get_output_shape(0));
216 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
219 case OP_TYPEID::Atan:
221 size_t element_count = shape_size(node.get_output_shape(0));
223 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
228 const op::Elu* elu_node = static_cast<const op::Elu*>(&node);
230 size_t element_count = shape_size(node.get_output_shape(0));
231 reference::elu<T>(args[0]->get_data_ptr<const T>(),
232 out[0]->get_data_ptr<T>(),
234 elu_node->get_alpha());
237 case OP_TYPEID::AvgPool:
239 const op::v0::AvgPool* avg_pool = static_cast<const op::v0::AvgPool*>(&node);
241 reference::avg_pool<T>(args[0]->get_data_ptr<const T>(),
242 out[0]->get_data_ptr<T>(),
243 node.get_input_shape(0),
244 node.get_output_shape(0),
245 avg_pool->get_window_shape(),
246 avg_pool->get_window_movement_strides(),
247 avg_pool->get_padding_below(),
248 avg_pool->get_padding_above(),
249 avg_pool->get_include_padding_in_avg_computation());
252 case OP_TYPEID::BatchNormInference:
254 const ngraph::op::BatchNormInference* bn =
255 static_cast<const ngraph::op::BatchNormInference*>(&node);
256 reference::batch_norm_inference<T>(bn->get_eps_value(),
257 args[0]->get_data_ptr<const T>(),
258 args[1]->get_data_ptr<const T>(),
259 args[2]->get_data_ptr<const T>(),
260 args[3]->get_data_ptr<const T>(),
261 args[4]->get_data_ptr<const T>(),
262 out[0]->get_data_ptr<T>(),
263 node.get_input_shape(2));
266 case OP_TYPEID::BroadcastLike: break;
267 case OP_TYPEID::Ceiling:
269 size_t element_count = shape_size(node.get_output_shape(0));
270 reference::ceiling<T>(
271 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
274 case OP_TYPEID::Convert:
276 // const op::Convert* c = static_cast<const op::Convert*>(&node);
277 element::Type type = node.get_element_type();
278 std::stringstream ss;
279 size_t element_count = shape_size(node.get_output_shape(0));
282 case element::Type_t::boolean:
283 reference::convert_to_bool<T>(
284 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<char>(), element_count);
286 case element::Type_t::f32:
287 reference::convert<T>(
288 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<float>(), element_count);
290 case element::Type_t::f64:
291 reference::convert<T>(args[0]->get_data_ptr<const T>(),
292 out[0]->get_data_ptr<double>(),
295 case element::Type_t::i8:
296 reference::convert<T>(args[0]->get_data_ptr<const T>(),
297 out[0]->get_data_ptr<int8_t>(),
300 case element::Type_t::i16:
301 reference::convert<T>(args[0]->get_data_ptr<const T>(),
302 out[0]->get_data_ptr<int16_t>(),
305 case element::Type_t::i32:
306 reference::convert<T>(args[0]->get_data_ptr<const T>(),
307 out[0]->get_data_ptr<int32_t>(),
310 case element::Type_t::i64:
311 reference::convert<T>(args[0]->get_data_ptr<const T>(),
312 out[0]->get_data_ptr<int64_t>(),
315 case element::Type_t::u8:
316 reference::convert<T>(args[0]->get_data_ptr<const T>(),
317 out[0]->get_data_ptr<uint8_t>(),
320 case element::Type_t::u16:
321 reference::convert<T>(args[0]->get_data_ptr<const T>(),
322 out[0]->get_data_ptr<uint16_t>(),
325 case element::Type_t::u32:
326 reference::convert<T>(args[0]->get_data_ptr<const T>(),
327 out[0]->get_data_ptr<uint32_t>(),
330 case element::Type_t::u64:
331 reference::convert<T>(args[0]->get_data_ptr<const T>(),
332 out[0]->get_data_ptr<uint64_t>(),
335 case element::Type_t::undefined:
336 case element::Type_t::dynamic:
337 case element::Type_t::u1:
338 case element::Type_t::bf16:
339 case element::Type_t::f16:
340 ss << "unsupported element type " << type << " op Convert";
341 throw std::runtime_error(ss.str());
345 case OP_TYPEID::Convolution:
347 const op::v0::Convolution* c = static_cast<const op::v0::Convolution*>(&node);
348 reference::convolution<T>(args[0]->get_data_ptr<const T>(),
349 args[1]->get_data_ptr<const T>(),
350 out[0]->get_data_ptr<T>(),
351 node.get_input_shape(0),
352 node.get_input_shape(1),
353 node.get_output_shape(0),
354 c->get_window_movement_strides(),
355 c->get_window_dilation_strides(),
356 c->get_padding_below(),
357 c->get_padding_above(),
358 c->get_data_dilation_strides());
362 case OP_TYPEID::ConvolutionBackpropData:
364 // Note that args[1] and args[0] are switched here from the usual order.
365 const op::v0::ConvolutionBackpropData* c =
366 static_cast<const op::v0::ConvolutionBackpropData*>(&node);
367 reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
368 args[0]->get_data_ptr<const T>(),
369 out[0]->get_data_ptr<T>(),
370 c->get_input_shape(1),
371 c->get_input_shape(0),
372 c->get_data_batch_shape(),
373 c->get_data_dilation_strides_forward(),
374 c->get_window_dilation_strides_forward(),
375 c->compute_backward_delta_out_pad_below(),
376 c->compute_backward_delta_out_pad_above(),
377 c->get_window_movement_strides_forward());
382 size_t element_count = shape_size(node.get_output_shape(0));
384 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
387 case OP_TYPEID::Cosh:
389 size_t element_count = shape_size(node.get_output_shape(0));
391 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
394 case OP_TYPEID::CTCLoss_v4:
396 const op::v4::CTCLoss* ctc_loss = static_cast<const op::v4::CTCLoss*>(&node);
397 auto t_int = node.get_input_element_type(1);
398 if (t_int == element::i32)
400 reference::CTCLoss<T, int32_t>(
401 args[0]->get_data_ptr<const T>(),
402 ctc_loss->get_input_shape(0),
403 args[1]->get_data_ptr<const int32_t>(),
404 args[2]->get_data_ptr<const int32_t>(),
405 args[3]->get_data_ptr<const int32_t>(),
406 args.size() > 4 ? args[4]->get_data_ptr<const int32_t>() : nullptr,
407 ctc_loss->get_preprocess_collapse_repeated(),
408 ctc_loss->get_ctc_merge_repeated(),
409 ctc_loss->get_unique(),
410 out[0]->get_data_ptr<T>());
412 else if (t_int == element::i64)
414 reference::CTCLoss<T, int64_t>(
415 args[0]->get_data_ptr<const T>(),
416 ctc_loss->get_input_shape(0),
417 args[1]->get_data_ptr<const int64_t>(),
418 args[2]->get_data_ptr<const int64_t>(),
419 args[3]->get_data_ptr<const int64_t>(),
420 args.size() > 4 ? args[4]->get_data_ptr<const int64_t>() : nullptr,
421 ctc_loss->get_preprocess_collapse_repeated(),
422 ctc_loss->get_ctc_merge_repeated(),
423 ctc_loss->get_unique(),
424 out[0]->get_data_ptr<T>());
428 case OP_TYPEID::CumSum:
430 const op::CumSum* cumsum = static_cast<const op::CumSum*>(&node);
431 auto axis_et = node.get_input_element_type(1);
432 if (axis_et == element::i32)
434 reference::cumsum<T, int32_t>(args[0]->get_data_ptr<const T>(),
435 args[1]->get_data_ptr<const int32_t>(),
436 out[0]->get_data_ptr<T>(),
437 node.get_input_shape(0),
438 cumsum->is_exclusive(),
439 cumsum->is_reverse());
441 else if (axis_et == element::i64)
443 reference::cumsum<T, int64_t>(args[0]->get_data_ptr<const T>(),
444 args[1]->get_data_ptr<const int64_t>(),
445 out[0]->get_data_ptr<T>(),
446 node.get_input_shape(0),
447 cumsum->is_exclusive(),
448 cumsum->is_reverse());
452 case OP_TYPEID::Dequantize:
454 const op::Dequantize* dequantize = static_cast<const op::Dequantize*>(&node);
455 auto type = dequantize->get_element_type();
457 if (type == element::f32)
459 reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
460 args[1]->get_data_ptr<const float>(),
461 args[2]->get_data_ptr<const T>(),
462 out[0]->get_data_ptr<float>(),
463 node.get_input_shape(0),
464 node.get_input_shape(1),
465 dequantize->get_axes());
467 else if (type == element::f64)
469 reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
470 args[1]->get_data_ptr<const double>(),
471 args[2]->get_data_ptr<const T>(),
472 out[0]->get_data_ptr<double>(),
473 node.get_input_shape(0),
474 node.get_input_shape(1),
475 dequantize->get_axes());
479 std::stringstream ss;
480 ss << "unsupported element type " << type << " op Dequantize";
481 throw std::runtime_error(ss.str());
488 const op::Dot* dot = static_cast<const op::Dot*>(&node);
490 reference::dot(args[0]->get_data_ptr<const T>(),
491 args[1]->get_data_ptr<const T>(),
492 out[0]->get_data_ptr<T>(),
493 node.get_input_shape(0),
494 node.get_input_shape(1),
495 node.get_output_shape(0),
496 dot->get_reduction_axes_count());
499 case OP_TYPEID::EmbeddingBagOffsetsSum_v3:
501 const op::EmbeddingBagOffsetsSum* embed =
502 static_cast<const op::EmbeddingBagOffsetsSum*>(&node);
503 auto indicesType = embed->input(1).get_element_type();
504 size_t indices_num = shape_size(embed->get_input_shape(1));
506 if (indicesType == element::u64 || indicesType == element::i64)
508 reference::embeddingBagOffsetsSum<T, size_t>(
509 args[0]->get_data_ptr<const T>(),
510 args[1]->get_data_ptr<const size_t>(),
511 args[2]->get_data_ptr<const size_t>(),
512 args.size() > 3 ? args[3]->get_data_ptr<const size_t>() : nullptr,
513 args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
514 out[0]->get_data_ptr<T>(),
518 else if (indicesType == element::u32 || indicesType == element::i32)
520 reference::embeddingBagOffsetsSum<T, unsigned>(
521 args[0]->get_data_ptr<const T>(),
522 args[1]->get_data_ptr<const unsigned>(),
523 args[2]->get_data_ptr<const unsigned>(),
524 args.size() > 3 ? args[3]->get_data_ptr<const unsigned>() : nullptr,
525 args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
526 out[0]->get_data_ptr<T>(),
532 throw ngraph_error(std::string("Unsupported index type ") +
533 indicesType.c_type_string() +
534 std::string(" in EmbeddingBagOffsetsSum"));
538 case OP_TYPEID::EmbeddingBagPackedSum_v3:
540 const op::EmbeddingBagPackedSum* embed =
541 static_cast<const op::EmbeddingBagPackedSum*>(&node);
542 auto indicesType = embed->input(1).get_element_type();
544 if (indicesType == element::u64 || indicesType == element::i64)
546 reference::embeddingBagPackedSum<T, size_t>(
547 args[0]->get_data_ptr<const T>(),
548 args[1]->get_data_ptr<const size_t>(),
549 args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
550 out[0]->get_data_ptr<T>(),
551 embed->get_input_shape(1),
554 else if (indicesType == element::u32 || indicesType == element::i32)
556 reference::embeddingBagPackedSum<T, unsigned>(
557 args[0]->get_data_ptr<const T>(),
558 args[1]->get_data_ptr<const unsigned>(),
559 args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
560 out[0]->get_data_ptr<T>(),
561 embed->get_input_shape(1),
566 throw ngraph_error(std::string("Unsupported index type ") +
567 indicesType.c_type_string() +
568 std::string(" in EmbeddingBagPackedSum"));
572 case OP_TYPEID::EmbeddingSegmentsSum_v3:
574 const op::EmbeddingSegmentsSum* embed =
575 static_cast<const op::EmbeddingSegmentsSum*>(&node);
576 auto indicesType = embed->input(1).get_element_type();
577 size_t indices_num = shape_size(embed->get_input_shape(1));
579 if (indicesType == element::u64 || indicesType == element::i64)
581 reference::embeddingSegmentsSum<T, size_t>(
582 args[0]->get_data_ptr<const T>(),
583 args[1]->get_data_ptr<const size_t>(),
584 args[2]->get_data_ptr<const size_t>(),
585 args.size() > 4 ? args[4]->get_data_ptr<const size_t>() : nullptr,
586 args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
587 out[0]->get_data_ptr<T>(),
588 embed->get_input_shape(0),
589 embed->get_input_shape(1),
592 else if (indicesType == element::u32 || indicesType == element::i32)
594 reference::embeddingSegmentsSum<T, unsigned>(
595 args[0]->get_data_ptr<const T>(),
596 args[1]->get_data_ptr<const unsigned>(),
597 args[2]->get_data_ptr<const unsigned>(),
598 args.size() > 4 ? args[4]->get_data_ptr<const unsigned>() : nullptr,
599 args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
600 out[0]->get_data_ptr<T>(),
601 embed->get_input_shape(0),
602 embed->get_input_shape(1),
607 throw ngraph_error(std::string("Unsupported index type ") +
608 indicesType.c_type_string() +
609 std::string(" in EmbeddingSegmentsSum"));
615 size_t element_count = shape_size(node.get_output_shape(0));
617 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
620 case OP_TYPEID::ExtractImagePatches_v3:
622 const op::ExtractImagePatches* extImgPatches =
623 static_cast<const op::ExtractImagePatches*>(&node);
624 reference::extractImagePatches<T, size_t>(extImgPatches,
625 args[0]->get_data_ptr<const T>(),
626 out[0]->get_data_ptr<T>(),
627 extImgPatches->get_input_shape(0),
628 extImgPatches->get_shape());
633 size_t element_count = shape_size(node.get_output_shape(0));
635 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
638 #ifdef INTERPRETER_USE_HYBRID
639 case OP_TYPEID::FunctionCall:
641 auto f = static_cast<const runtime::hybrid::op::FunctionCall*>(&node);
642 auto backend = f->get_backend();
643 auto executable = f->get_executable();
645 std::vector<std::shared_ptr<Tensor>> outputs;
646 std::vector<std::shared_ptr<Tensor>> inputs;
647 for (const std::shared_ptr<HostTensor>& t : out)
649 auto backend_tensor = backend->create_tensor(
650 t->get_element_type(), t->get_shape(), t->get_data_ptr());
651 outputs.push_back(backend_tensor);
653 for (const std::shared_ptr<HostTensor>& t : args)
655 auto backend_tensor = backend->create_tensor(
656 t->get_element_type(), t->get_shape(), t->get_data_ptr());
657 inputs.push_back(backend_tensor);
659 executable->call(outputs, inputs);
663 case OP_TYPEID::Floor:
665 size_t element_count = shape_size(node.get_output_shape(0));
667 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
670 case OP_TYPEID::GatherND:
672 if (node.get_input_element_type(1) == element::i64)
674 reference::gather_nd<T, int64_t>(args[0]->get_data_ptr<T>(),
675 args[1]->get_data_ptr<int64_t>(),
676 out[0]->get_data_ptr<T>(),
677 node.get_input_shape(0),
678 node.get_input_shape(1),
679 node.get_output_shape(0));
681 else if (node.get_input_element_type(1) == element::i32)
683 reference::gather_nd<T, int32_t>(args[0]->get_data_ptr<T>(),
684 args[1]->get_data_ptr<int32_t>(),
685 out[0]->get_data_ptr<T>(),
686 node.get_input_shape(0),
687 node.get_input_shape(1),
688 node.get_output_shape(0));
692 throw ngraph_error("Unexpected type");
698 size_t element_count = shape_size(node.get_output_shape(0));
700 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
705 const op::LRN* lrn = static_cast<const op::LRN*>(&node);
706 reference::lrn<T>(args[0]->get_data_ptr<const T>(),
707 lrn->get_reduction_axes(),
708 out[0]->get_data_ptr<T>(),
709 node.get_input_shape(0),
716 case OP_TYPEID::Negative:
718 size_t element_count = shape_size(node.get_output_shape(0));
719 reference::negate<T>(
720 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
723 case OP_TYPEID::LogicalNot_v1:
726 size_t element_count = shape_size(node.get_output_shape(0));
727 reference::logical_not(
728 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
731 case OP_TYPEID::OneHot:
733 const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
734 reference::one_hot<T>(args[0]->get_data_ptr<const T>(),
735 out[0]->get_data_ptr<T>(),
736 node.get_input_shape(0),
737 node.get_output_shape(0),
738 oh->get_one_hot_axis());
741 case OP_TYPEID::Parameter: break;
742 case OP_TYPEID::Quantize:
744 const op::Quantize* quantize = static_cast<const op::Quantize*>(&node);
745 auto type = quantize->get_element_type();
747 if (type == element::u8)
749 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
750 args[1]->get_data_ptr<const T>(),
751 args[2]->get_data_ptr<const uint8_t>(),
752 out[0]->get_data_ptr<uint8_t>(),
753 node.get_input_shape(0),
754 node.get_input_shape(1),
755 quantize->get_axes(),
756 quantize->get_round_mode());
758 else if (type == element::i8)
760 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
761 args[1]->get_data_ptr<const T>(),
762 args[2]->get_data_ptr<const int8_t>(),
763 out[0]->get_data_ptr<int8_t>(),
764 node.get_input_shape(0),
765 node.get_input_shape(1),
766 quantize->get_axes(),
767 quantize->get_round_mode());
769 else if (type == element::i32)
771 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
772 args[1]->get_data_ptr<const T>(),
773 args[2]->get_data_ptr<const int32_t>(),
774 out[0]->get_data_ptr<int32_t>(),
775 node.get_input_shape(0),
776 node.get_input_shape(1),
777 quantize->get_axes(),
778 quantize->get_round_mode());
782 std::stringstream ss;
783 ss << "unsupported element type " << type << " op Quantize";
784 throw std::runtime_error(ss.str());
790 case OP_TYPEID::QuantizedConvolution:
792 const op::QuantizedConvolution* qc =
793 static_cast<const op::QuantizedConvolution*>(&node);
795 auto input_element_type = qc->get_input_element_type(0);
796 auto filter_element_type = qc->get_input_element_type(1);
797 auto output_element_type = qc->get_output_element_type(0);
799 if (input_element_type == element::u8 && filter_element_type == element::i8 &&
800 output_element_type == element::i8)
802 reference::convolution<uint8_t, int8_t, int8_t, int32_t>(
803 args[0]->get_data_ptr<const uint8_t>(),
804 args[1]->get_data_ptr<const int8_t>(),
805 out[0]->get_data_ptr<int8_t>(),
806 node.get_input_shape(0),
807 node.get_input_shape(1),
808 node.get_output_shape(0),
809 qc->get_window_movement_strides(),
810 qc->get_window_dilation_strides(),
811 qc->get_padding_below(),
812 qc->get_padding_above(),
813 qc->get_data_dilation_strides(),
814 args[2]->get_data_ptr<const float>(),
815 args[3]->get_data_ptr<const uint8_t>(),
816 args[4]->get_data_ptr<const float>(),
817 args[5]->get_data_ptr<const int8_t>(),
818 args[6]->get_data_ptr<const float>(),
819 args[7]->get_data_ptr<const int8_t>());
821 else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
822 output_element_type == element::u8)
824 reference::convolution<uint8_t, uint8_t, uint8_t, int32_t>(
825 args[0]->get_data_ptr<const uint8_t>(),
826 args[1]->get_data_ptr<const uint8_t>(),
827 out[0]->get_data_ptr<uint8_t>(),
828 node.get_input_shape(0),
829 node.get_input_shape(1),
830 node.get_output_shape(0),
831 qc->get_window_movement_strides(),
832 qc->get_window_dilation_strides(),
833 qc->get_padding_below(),
834 qc->get_padding_above(),
835 qc->get_data_dilation_strides(),
836 args[2]->get_data_ptr<const float>(),
837 args[3]->get_data_ptr<const uint8_t>(),
838 args[4]->get_data_ptr<const float>(),
839 args[5]->get_data_ptr<const uint8_t>(),
840 args[6]->get_data_ptr<const float>(),
841 args[7]->get_data_ptr<const uint8_t>());
843 else if (input_element_type == element::u8 && filter_element_type == element::i8 &&
844 output_element_type == element::i32)
846 reference::convolution<uint8_t, int8_t, int32_t, int32_t>(
847 args[0]->get_data_ptr<const uint8_t>(),
848 args[1]->get_data_ptr<const int8_t>(),
849 out[0]->get_data_ptr<int32_t>(),
850 node.get_input_shape(0),
851 node.get_input_shape(1),
852 node.get_output_shape(0),
853 qc->get_window_movement_strides(),
854 qc->get_window_dilation_strides(),
855 qc->get_padding_below(),
856 qc->get_padding_above(),
857 qc->get_data_dilation_strides(),
858 args[2]->get_data_ptr<const float>(),
859 args[3]->get_data_ptr<const uint8_t>(),
860 args[4]->get_data_ptr<const float>(),
861 args[5]->get_data_ptr<const int8_t>(),
862 args[6]->get_data_ptr<const float>(),
863 args[7]->get_data_ptr<const int32_t>());
865 else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
866 output_element_type == element::i32)
868 reference::convolution<uint8_t, uint8_t, int32_t, int32_t>(
869 args[0]->get_data_ptr<const uint8_t>(),
870 args[1]->get_data_ptr<const uint8_t>(),
871 out[0]->get_data_ptr<int32_t>(),
872 node.get_input_shape(0),
873 node.get_input_shape(1),
874 node.get_output_shape(0),
875 qc->get_window_movement_strides(),
876 qc->get_window_dilation_strides(),
877 qc->get_padding_below(),
878 qc->get_padding_above(),
879 qc->get_data_dilation_strides(),
880 args[2]->get_data_ptr<const float>(),
881 args[3]->get_data_ptr<const uint8_t>(),
882 args[4]->get_data_ptr<const float>(),
883 args[5]->get_data_ptr<const uint8_t>(),
884 args[6]->get_data_ptr<const float>(),
885 args[7]->get_data_ptr<const int32_t>());
889 std::stringstream ss;
890 ss << "unsupported element type";
891 throw std::runtime_error(ss.str());
897 case OP_TYPEID::QuantizedDot:
899 const op::QuantizedDot* qd = static_cast<const op::QuantizedDot*>(&node);
901 auto input0_element_type = qd->get_input_element_type(0);
902 auto input1_element_type = qd->get_input_element_type(1);
903 auto output_element_type = qd->get_output_element_type(0);
905 if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
906 output_element_type == element::i8)
908 reference::dot<uint8_t, int8_t, int8_t, int32_t>(
909 args[0]->get_data_ptr<const uint8_t>(),
910 args[1]->get_data_ptr<const int8_t>(),
911 out[0]->get_data_ptr<int8_t>(),
912 node.get_input_shape(0),
913 node.get_input_shape(1),
914 node.get_output_shape(0),
916 args[2]->get_data_ptr<const float>(),
917 args[3]->get_data_ptr<const uint8_t>(),
918 args[4]->get_data_ptr<const float>(),
919 args[5]->get_data_ptr<const int8_t>(),
920 args[6]->get_data_ptr<const float>(),
921 args[7]->get_data_ptr<const int8_t>());
923 else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
924 output_element_type == element::u8)
926 reference::dot<uint8_t, uint8_t, uint8_t, int32_t>(
927 args[0]->get_data_ptr<const uint8_t>(),
928 args[1]->get_data_ptr<const uint8_t>(),
929 out[0]->get_data_ptr<uint8_t>(),
930 node.get_input_shape(0),
931 node.get_input_shape(1),
932 node.get_output_shape(0),
934 args[2]->get_data_ptr<const float>(),
935 args[3]->get_data_ptr<const uint8_t>(),
936 args[4]->get_data_ptr<const float>(),
937 args[5]->get_data_ptr<const uint8_t>(),
938 args[6]->get_data_ptr<const float>(),
939 args[7]->get_data_ptr<const uint8_t>());
941 else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
942 output_element_type == element::i32)
944 reference::dot<uint8_t, uint8_t, int32_t, int32_t>(
945 args[0]->get_data_ptr<const uint8_t>(),
946 args[1]->get_data_ptr<const uint8_t>(),
947 out[0]->get_data_ptr<int32_t>(),
948 node.get_input_shape(0),
949 node.get_input_shape(1),
950 node.get_output_shape(0),
952 args[2]->get_data_ptr<const float>(),
953 args[3]->get_data_ptr<const uint8_t>(),
954 args[4]->get_data_ptr<const float>(),
955 args[5]->get_data_ptr<const uint8_t>(),
956 args[6]->get_data_ptr<const float>(),
957 args[7]->get_data_ptr<const int32_t>());
959 else if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
960 output_element_type == element::i32)
962 reference::dot<uint8_t, int8_t, int32_t, int32_t>(
963 args[0]->get_data_ptr<const uint8_t>(),
964 args[1]->get_data_ptr<const int8_t>(),
965 out[0]->get_data_ptr<int32_t>(),
966 node.get_input_shape(0),
967 node.get_input_shape(1),
968 node.get_output_shape(0),
970 args[2]->get_data_ptr<const float>(),
971 args[3]->get_data_ptr<const uint8_t>(),
972 args[4]->get_data_ptr<const float>(),
973 args[5]->get_data_ptr<const int8_t>(),
974 args[6]->get_data_ptr<const float>(),
975 args[7]->get_data_ptr<const int32_t>());
979 std::stringstream ss;
980 ss << "unsupported element type";
981 throw std::runtime_error(ss.str());
986 case OP_TYPEID::Relu:
988 size_t element_count = shape_size(node.get_output_shape(0));
990 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
993 case OP_TYPEID::ReplaceSlice:
995 const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
996 reference::replace_slice<T>(args[0]->get_data_ptr<const T>(),
997 args[1]->get_data_ptr<const T>(),
998 out[0]->get_data_ptr<T>(),
999 node.get_input_shape(1),
1000 slice->get_lower_bounds(),
1001 slice->get_upper_bounds(),
1002 slice->get_strides(),
1003 node.get_output_shape(0));
1006 case OP_TYPEID::Reverse:
1008 const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
1009 reference::reverse(args[0]->get_data_ptr<const char>(),
1010 out[0]->get_data_ptr<char>(),
1011 node.get_input_shape(0),
1012 node.get_output_shape(0),
1013 reverse->get_reversed_axes(),
1014 args[0]->get_element_type().size());
1017 case OP_TYPEID::ReverseSequence:
1019 const op::ReverseSequence* reverse = static_cast<const op::ReverseSequence*>(&node);
1021 if (node.get_input_element_type(1) == element::i32)
1023 reference::reverse_sequence<T, int32_t>(args[0]->get_data_ptr<const T>(),
1024 out[0]->get_data_ptr<T>(),
1025 node.get_input_shape(0),
1026 reverse->get_batch_axis(),
1027 reverse->get_sequence_axis(),
1028 args[1]->get_data_ptr<const int32_t>());
1032 throw ngraph_error("only int32 indices are supported");
1036 case OP_TYPEID::Round:
1038 size_t element_count = shape_size(node.get_output_shape(0));
1039 reference::round<T>(
1040 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1043 case OP_TYPEID::Select:
1045 size_t element_count = shape_size(node.get_output_shape(0));
1046 reference::select<T>(args[0]->get_data_ptr<const char>(),
1047 args[1]->get_data_ptr<const T>(),
1048 args[2]->get_data_ptr<const T>(),
1049 out[0]->get_data_ptr<T>(),
1053 case OP_TYPEID::Sigmoid:
1055 size_t element_count = shape_size(node.get_output_shape(0));
1056 reference::sigmoid<T>(
1057 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1060 case OP_TYPEID::Sign:
1062 size_t element_count = shape_size(node.get_output_shape(0));
1064 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1067 case OP_TYPEID::Sin:
1069 size_t element_count = shape_size(node.get_output_shape(0));
1071 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1074 case OP_TYPEID::Sinh:
1076 size_t element_count = shape_size(node.get_output_shape(0));
1078 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1081 case OP_TYPEID::Sqrt:
1083 size_t element_count = shape_size(node.get_output_shape(0));
1085 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1088 case OP_TYPEID::Tan:
1090 size_t element_count = shape_size(node.get_output_shape(0));
1092 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1095 case OP_TYPEID::Tanh:
1097 size_t element_count = shape_size(node.get_output_shape(0));
1099 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1102 case OP_TYPEID::TopK:
1104 const op::TopK* topk = static_cast<const op::TopK*>(&node);
1105 if (node.get_output_element_type(0) == element::i64)
1107 reference::topk<T, int64_t>(args[0]->get_data_ptr<const T>(),
1108 out[0]->get_data_ptr<int64_t>(),
1109 out[1]->get_data_ptr<T>(),
1110 node.get_input_shape(0),
1111 node.get_output_shape(0),
1112 topk->get_top_k_axis(),
1114 topk->get_compute_max(),
1117 else if (node.get_output_element_type(0) == element::i32)
1119 reference::topk<T, int32_t>(args[0]->get_data_ptr<const T>(),
1120 out[0]->get_data_ptr<int32_t>(),
1121 out[1]->get_data_ptr<T>(),
1122 node.get_input_shape(0),
1123 node.get_output_shape(0),
1124 topk->get_top_k_axis(),
1126 topk->get_compute_max(),
1131 throw ngraph_error("Unexpected type");
1135 case OP_TYPEID::DetectionOutput_v0:
1137 const op::DetectionOutput* detOut = static_cast<const op::DetectionOutput*>(&node);
1138 reference::referenceDetectionOutput<T> refDetOut(
1139 detOut->get_attrs(), node.get_input_shape(0), node.get_input_shape(2));
1140 if (node.get_input_size() == 3)
1142 refDetOut.run(args[0]->get_data_ptr<const T>(),
1143 args[1]->get_data_ptr<const T>(),
1144 args[2]->get_data_ptr<const T>(),
1147 out[0]->get_data_ptr<T>());
1149 else if (node.get_input_size() == 5)
1151 refDetOut.run(args[0]->get_data_ptr<const T>(),
1152 args[1]->get_data_ptr<const T>(),
1153 args[2]->get_data_ptr<const T>(),
1154 args[3]->get_data_ptr<const T>(),
1155 args[4]->get_data_ptr<const T>(),
1156 out[0]->get_data_ptr<T>());
1160 throw ngraph_error("DetectionOutput layer supports only 3 or 5 inputs");
1165 case OP_TYPEID::ScatterNDUpdate_v3:
1167 const op::ScatterNDUpdate* scatterNDUpd =
1168 static_cast<const op::v3::ScatterNDUpdate*>(&node);
1169 auto idxType = scatterNDUpd->get_input_element_type(1);
1170 if (idxType == element::i32)
1172 reference::scatterNdUpdate<T, int32_t>(args[0]->get_data_ptr<const T>(),
1173 args[1]->get_data_ptr<const int32_t>(),
1174 args[2]->get_data_ptr<const T>(),
1175 out[0]->get_data_ptr<T>(),
1176 node.get_input_shape(0),
1177 node.get_input_shape(1),
1178 node.get_input_shape(2));
1180 else if (idxType == element::i64)
1182 reference::scatterNdUpdate<T, int64_t>(args[0]->get_data_ptr<const T>(),
1183 args[1]->get_data_ptr<const int64_t>(),
1184 args[2]->get_data_ptr<const T>(),
1185 out[0]->get_data_ptr<T>(),
1186 node.get_input_shape(0),
1187 node.get_input_shape(1),
1188 node.get_input_shape(2));
1193 "ScatterNDUpdate layer support only i32 and i64 'indices' input precision!");
1198 case OP_TYPEID::ScatterUpdate_v3:
1200 const op::v3::ScatterUpdate* scatterUpd =
1201 static_cast<const op::v3::ScatterUpdate*>(&node);
1203 if (scatterUpd->get_input_element_type(3) != element::i64)
1205 "ScatterNDUpdate layer support only i64 'axis' input precision!");
1207 auto idxType = scatterUpd->get_input_element_type(1);
1208 if (idxType == element::i32)
1210 reference::scatterUpdate<T, int32_t, int64_t>(
1211 args[0]->get_data_ptr<const T>(),
1212 args[1]->get_data_ptr<const int32_t>(),
1213 args[2]->get_data_ptr<const T>(),
1214 args[3]->get_data_ptr<const int64_t>(),
1215 out[0]->get_data_ptr<T>(),
1216 node.get_input_shape(0),
1217 node.get_input_shape(1),
1218 node.get_input_shape(2));
1220 else if (idxType == element::i64)
1222 reference::scatterUpdate<T, int64_t, int64_t>(
1223 args[0]->get_data_ptr<const T>(),
1224 args[1]->get_data_ptr<const int64_t>(),
1225 args[2]->get_data_ptr<const T>(),
1226 args[3]->get_data_ptr<const int64_t>(),
1227 out[0]->get_data_ptr<T>(),
1228 node.get_input_shape(0),
1229 node.get_input_shape(1),
1230 node.get_input_shape(2));
1235 "ScatterUpdate layer support only i32 and i64 'indices' input precision!");
1241 // Fused Ops are not supported in interpreter. They need to be decomposed before execution
1242 case OP_TYPEID::DepthToSpace:
1243 case OP_TYPEID::FakeQuantize:
1244 case OP_TYPEID::Gather:
1245 case OP_TYPEID::Gelu:
1246 case OP_TYPEID::GRN:
1247 case OP_TYPEID::GroupConvolution:
1248 case OP_TYPEID::GroupConvolutionBackpropData:
1249 case OP_TYPEID::GRUCell:
1250 case OP_TYPEID::HardSigmoid:
1251 case OP_TYPEID::Interpolate:
1252 case OP_TYPEID::LSTMCell:
1253 case OP_TYPEID::LSTMSequence:
1254 case OP_TYPEID::MVN:
1255 case OP_TYPEID::NormalizeL2:
1256 case OP_TYPEID::PRelu:
1257 case OP_TYPEID::RNNCell:
1258 case OP_TYPEID::Selu:
1259 case OP_TYPEID::ShuffleChannels:
1260 case OP_TYPEID::SpaceToDepth:
1261 case OP_TYPEID::Split:
1262 case OP_TYPEID::SquaredDifference:
1263 case OP_TYPEID::StopGradient:
1264 case OP_TYPEID::TensorIterator:
1265 case OP_TYPEID::Tile:
1266 case OP_TYPEID::UnknownOp:
1267 throw unsupported_op("Unsupported op '" + node.description() + "'");
1268 case OP_TYPEID::Add:
1269 case OP_TYPEID::Broadcast:
1270 case OP_TYPEID::Clamp:
1271 case OP_TYPEID::Concat:
1272 case OP_TYPEID::Constant:
1273 case OP_TYPEID::Divide:
1274 case OP_TYPEID::Equal:
1275 case OP_TYPEID::Greater:
1276 case OP_TYPEID::GreaterEq:
1277 case OP_TYPEID::Less:
1278 case OP_TYPEID::LessEq:
1279 case OP_TYPEID::LessEqual_v1:
1280 case OP_TYPEID::LogicalAnd_v1:
1281 case OP_TYPEID::LogicalOr_v1:
1282 case OP_TYPEID::LogicalXor_v1:
1283 case OP_TYPEID::MatMul:
1284 case OP_TYPEID::Max:
1285 case OP_TYPEID::Maximum:
1286 case OP_TYPEID::Min:
1287 case OP_TYPEID::Minimum:
1288 case OP_TYPEID::Multiply:
1289 case OP_TYPEID::NonZero_v3:
1290 case OP_TYPEID::NotEqual:
1292 case OP_TYPEID::Pad:
1293 case OP_TYPEID::Power:
1294 case OP_TYPEID::Product:
1295 case OP_TYPEID::Range:
1296 case OP_TYPEID::Reshape:
1297 case OP_TYPEID::Result:
1298 case OP_TYPEID::ShapeOf_v3:
1299 case OP_TYPEID::ShapeOf:
1300 case OP_TYPEID::Softmax:
1301 case OP_TYPEID::Squeeze:
1302 case OP_TYPEID::Sum:
1303 case OP_TYPEID::Subtract:
1304 case OP_TYPEID::Unsqueeze:
1305 case OP_TYPEID::Xor:
1306 case OP_TYPEID::Slice:
1307 // These ops are handled by op evaluators so nothing to do
1309 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
1310 #pragma GCC diagnostic pop
1316 NGRAPH_SUPPRESS_DEPRECATED_END