1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include <initializer_list>
26 #include <ngraph/runtime/host_tensor.hpp>
27 #include "backend.hpp"
28 #include "int_backend_visibility.hpp"
29 #include "ngraph/ops.hpp"
30 #include "ngraph/runtime/aligned_buffer.hpp"
31 #include "ngraph/runtime/reference/abs.hpp"
32 #include "ngraph/runtime/reference/acos.hpp"
33 #include "ngraph/runtime/reference/asin.hpp"
34 #include "ngraph/runtime/reference/atan.hpp"
35 #include "ngraph/runtime/reference/atan2.hpp"
36 #include "ngraph/runtime/reference/avg_pool.hpp"
37 #include "ngraph/runtime/reference/batch_norm.hpp"
38 #include "ngraph/runtime/reference/broadcast.hpp"
39 #include "ngraph/runtime/reference/ceiling.hpp"
40 #include "ngraph/runtime/reference/concat.hpp"
41 #include "ngraph/runtime/reference/constant.hpp"
42 #include "ngraph/runtime/reference/convert.hpp"
43 #include "ngraph/runtime/reference/convolution.hpp"
44 #include "ngraph/runtime/reference/cos.hpp"
45 #include "ngraph/runtime/reference/cosh.hpp"
46 #include "ngraph/runtime/reference/ctc_greedy_decoder.hpp"
47 #include "ngraph/runtime/reference/ctc_loss.hpp"
48 #include "ngraph/runtime/reference/cum_sum.hpp"
49 #include "ngraph/runtime/reference/detection_output.hpp"
50 #include "ngraph/runtime/reference/dot.hpp"
51 #include "ngraph/runtime/reference/elu.hpp"
52 #include "ngraph/runtime/reference/embedding_bag_offsets_sum.hpp"
53 #include "ngraph/runtime/reference/embedding_bag_packed_sum.hpp"
54 #include "ngraph/runtime/reference/embedding_segments_sum.hpp"
55 #include "ngraph/runtime/reference/erf.hpp"
56 #include "ngraph/runtime/reference/exp.hpp"
57 #include "ngraph/runtime/reference/extract_image_patches.hpp"
58 #include "ngraph/runtime/reference/floor.hpp"
59 #include "ngraph/runtime/reference/gather.hpp"
60 #include "ngraph/runtime/reference/gather_nd.hpp"
61 #include "ngraph/runtime/reference/gather_tree.hpp"
62 #include "ngraph/runtime/reference/gru_cell.hpp"
63 #include "ngraph/runtime/reference/log.hpp"
64 #include "ngraph/runtime/reference/log_softmax.hpp"
65 #include "ngraph/runtime/reference/lrn.hpp"
66 #include "ngraph/runtime/reference/lstm_cell.hpp"
67 #include "ngraph/runtime/reference/matmul.hpp"
68 #include "ngraph/runtime/reference/max.hpp"
69 #include "ngraph/runtime/reference/max_pool.hpp"
70 #include "ngraph/runtime/reference/min.hpp"
71 #include "ngraph/runtime/reference/negate.hpp"
72 #include "ngraph/runtime/reference/normalize_l2.hpp"
73 #include "ngraph/runtime/reference/not.hpp"
74 #include "ngraph/runtime/reference/one_hot.hpp"
75 #include "ngraph/runtime/reference/pad.hpp"
76 #include "ngraph/runtime/reference/prior_box.hpp"
77 #include "ngraph/runtime/reference/product.hpp"
78 #include "ngraph/runtime/reference/quantize.hpp"
79 #include "ngraph/runtime/reference/region_yolo.hpp"
80 #include "ngraph/runtime/reference/relu.hpp"
81 #include "ngraph/runtime/reference/reorg_yolo.hpp"
82 #include "ngraph/runtime/reference/replace_slice.hpp"
83 #include "ngraph/runtime/reference/reshape.hpp"
84 #include "ngraph/runtime/reference/result.hpp"
85 #include "ngraph/runtime/reference/reverse.hpp"
86 #include "ngraph/runtime/reference/reverse_sequence.hpp"
87 #include "ngraph/runtime/reference/rnn_cell.hpp"
88 #include "ngraph/runtime/reference/round.hpp"
89 #include "ngraph/runtime/reference/scatter_nd_update.hpp"
90 #include "ngraph/runtime/reference/select.hpp"
91 #include "ngraph/runtime/reference/sequences.hpp"
92 #include "ngraph/runtime/reference/sigmoid.hpp"
93 #include "ngraph/runtime/reference/sign.hpp"
94 #include "ngraph/runtime/reference/sin.hpp"
95 #include "ngraph/runtime/reference/sinh.hpp"
96 #include "ngraph/runtime/reference/softmax.hpp"
97 #include "ngraph/runtime/reference/sqrt.hpp"
98 #include "ngraph/runtime/reference/sum.hpp"
99 #include "ngraph/runtime/reference/tan.hpp"
100 #include "ngraph/runtime/reference/tanh.hpp"
101 #include "ngraph/runtime/reference/topk.hpp"
102 #include "ngraph/runtime/tensor.hpp"
103 #include "op/avg_pool.hpp"
104 #include "op/convolution.hpp"
105 #include "op/group_conv.hpp"
107 NGRAPH_SUPPRESS_DEPRECATED_START
113 namespace interpreter
118 // This expands the op list in opset_int_tbl.hpp into a list of enumerations that look like
125 #define NGRAPH_OP(NAME, NAMESPACE) ID_SUFFIX(NAME),
126 #include "opset_int_tbl.hpp"
130 } // namespace interpreter
131 } // namespace runtime
132 } // namespace ngraph
134 class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : public Executable
136 friend class INTBackend;
139 INTExecutable(const std::shared_ptr<Function>& function,
140 bool enable_performance_collection = false);
142 bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
143 const std::vector<std::shared_ptr<Tensor>>& inputs) override;
145 void set_nan_check(bool enable);
147 std::vector<PerformanceCounter> get_performance_data() const override;
149 std::shared_ptr<runtime::Tensor> create_input_tensor(size_t input_index) override;
151 std::shared_ptr<runtime::Tensor> create_output_tensor(size_t output_index) override;
153 std::vector<std::shared_ptr<runtime::Tensor>>
154 create_input_tensor(size_t input_index, size_t pipeline_depth) override;
156 std::vector<std::shared_ptr<runtime::Tensor>>
157 create_output_tensor(size_t output_index, size_t pipeline_depth) override;
160 std::shared_ptr<ngraph::op::Parameter> get_parameter(size_t index) const;
161 std::shared_ptr<ngraph::op::Result> get_result(size_t index) const;
162 int get_alignment() const { return 64; }
163 bool m_is_compiled = false;
164 bool m_nan_check_enabled = false;
165 bool m_performance_counters_enabled = false;
166 std::shared_ptr<Function> m_function;
167 std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
168 std::vector<std::shared_ptr<Node>> m_nodes;
169 std::set<std::string> m_unsupported_op_name_list;
171 static OP_TYPEID get_typeid(const Node& node);
173 static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
174 const Node* op = nullptr);
176 virtual void generate_calls(const element::Type& type,
178 const std::vector<std::shared_ptr<HostTensor>>& outputs,
179 const std::vector<std::shared_ptr<HostTensor>>& inputs);
181 template <typename T>
182 void op_engine(const Node& node,
183 const std::vector<std::shared_ptr<HostTensor>>& out,
184 const std::vector<std::shared_ptr<HostTensor>>& args)
186 // We want to check that every OP_TYPEID enumeration is included in the list.
187 // These GCC flags enable compile-time checking so that if an enumeration
188 // is not in the list an error is generated.
189 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
190 #pragma GCC diagnostic push
191 #pragma GCC diagnostic error "-Wswitch"
192 #pragma GCC diagnostic error "-Wswitch-enum"
194 switch (get_typeid(node))
198 size_t element_count = shape_size(node.get_output_shape(0));
200 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
203 case OP_TYPEID::Acos:
205 size_t element_count = shape_size(node.get_output_shape(0));
207 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
210 case OP_TYPEID::Asin:
212 size_t element_count = shape_size(node.get_output_shape(0));
214 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
217 case OP_TYPEID::Atan:
219 size_t element_count = shape_size(node.get_output_shape(0));
221 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
226 const op::Elu* elu_node = static_cast<const op::Elu*>(&node);
228 size_t element_count = shape_size(node.get_output_shape(0));
229 reference::elu<T>(args[0]->get_data_ptr<const T>(),
230 out[0]->get_data_ptr<T>(),
232 elu_node->get_alpha());
235 case OP_TYPEID::AvgPool:
237 const op::v0::AvgPool* avg_pool = static_cast<const op::v0::AvgPool*>(&node);
239 reference::avg_pool<T>(args[0]->get_data_ptr<const T>(),
240 out[0]->get_data_ptr<T>(),
241 node.get_input_shape(0),
242 node.get_output_shape(0),
243 avg_pool->get_window_shape(),
244 avg_pool->get_window_movement_strides(),
245 avg_pool->get_padding_below(),
246 avg_pool->get_padding_above(),
247 avg_pool->get_include_padding_in_avg_computation());
250 case OP_TYPEID::BatchNormInference:
252 const ngraph::op::v0::BatchNormInference* bn =
253 static_cast<const ngraph::op::v0::BatchNormInference*>(&node);
254 reference::batch_norm_inference<T>(bn->get_eps_value(),
255 args[0]->get_data_ptr<const T>(),
256 args[1]->get_data_ptr<const T>(),
257 args[2]->get_data_ptr<const T>(),
258 args[3]->get_data_ptr<const T>(),
259 args[4]->get_data_ptr<const T>(),
260 out[0]->get_data_ptr<T>(),
261 node.get_input_shape(2));
264 case OP_TYPEID::BatchNormInference_v5:
266 const ngraph::op::v5::BatchNormInference* bn =
267 static_cast<const ngraph::op::v5::BatchNormInference*>(&node);
268 reference::batch_norm_inference<T>(bn->get_eps_value(),
269 args[1]->get_data_ptr<const T>(),
270 args[2]->get_data_ptr<const T>(),
271 args[0]->get_data_ptr<const T>(),
272 args[3]->get_data_ptr<const T>(),
273 args[4]->get_data_ptr<const T>(),
274 out[0]->get_data_ptr<T>(),
275 node.get_input_shape(0));
278 case OP_TYPEID::Ceiling:
280 size_t element_count = shape_size(node.get_output_shape(0));
281 reference::ceiling<T>(
282 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
285 case OP_TYPEID::Convert:
287 // const op::Convert* c = static_cast<const op::Convert*>(&node);
288 element::Type type = node.get_element_type();
289 std::stringstream ss;
290 size_t element_count = shape_size(node.get_output_shape(0));
293 case element::Type_t::boolean:
294 reference::convert_to_bool<T>(
295 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<char>(), element_count);
297 case element::Type_t::f32:
298 reference::convert<T>(
299 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<float>(), element_count);
301 case element::Type_t::f64:
302 reference::convert<T>(args[0]->get_data_ptr<const T>(),
303 out[0]->get_data_ptr<double>(),
306 case element::Type_t::i8:
307 reference::convert<T>(args[0]->get_data_ptr<const T>(),
308 out[0]->get_data_ptr<int8_t>(),
311 case element::Type_t::i16:
312 reference::convert<T>(args[0]->get_data_ptr<const T>(),
313 out[0]->get_data_ptr<int16_t>(),
316 case element::Type_t::i32:
317 reference::convert<T>(args[0]->get_data_ptr<const T>(),
318 out[0]->get_data_ptr<int32_t>(),
321 case element::Type_t::i64:
322 reference::convert<T>(args[0]->get_data_ptr<const T>(),
323 out[0]->get_data_ptr<int64_t>(),
326 case element::Type_t::u8:
327 reference::convert<T>(args[0]->get_data_ptr<const T>(),
328 out[0]->get_data_ptr<uint8_t>(),
331 case element::Type_t::u16:
332 reference::convert<T>(args[0]->get_data_ptr<const T>(),
333 out[0]->get_data_ptr<uint16_t>(),
336 case element::Type_t::u32:
337 reference::convert<T>(args[0]->get_data_ptr<const T>(),
338 out[0]->get_data_ptr<uint32_t>(),
341 case element::Type_t::u64:
342 reference::convert<T>(args[0]->get_data_ptr<const T>(),
343 out[0]->get_data_ptr<uint64_t>(),
346 case element::Type_t::undefined:
347 case element::Type_t::dynamic:
348 case element::Type_t::u1:
349 case element::Type_t::bf16:
350 case element::Type_t::f16:
351 ss << "unsupported element type " << type << " op Convert";
352 throw std::runtime_error(ss.str());
356 case OP_TYPEID::Convolution:
358 const op::v0::Convolution* c = static_cast<const op::v0::Convolution*>(&node);
359 reference::convolution<T>(args[0]->get_data_ptr<const T>(),
360 args[1]->get_data_ptr<const T>(),
361 out[0]->get_data_ptr<T>(),
362 node.get_input_shape(0),
363 node.get_input_shape(1),
364 node.get_output_shape(0),
365 c->get_window_movement_strides(),
366 c->get_window_dilation_strides(),
367 c->get_padding_below(),
368 c->get_padding_above(),
369 c->get_data_dilation_strides());
373 case OP_TYPEID::ConvolutionBackpropData:
375 // Note that args[1] and args[0] are switched here from the usual order.
376 const op::v0::ConvolutionBackpropData* c =
377 static_cast<const op::v0::ConvolutionBackpropData*>(&node);
378 reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
379 args[0]->get_data_ptr<const T>(),
380 out[0]->get_data_ptr<T>(),
381 c->get_input_shape(1),
382 c->get_input_shape(0),
383 c->get_data_batch_shape(),
384 c->get_data_dilation_strides_forward(),
385 c->get_window_dilation_strides_forward(),
386 c->compute_backward_delta_out_pad_below(),
387 c->compute_backward_delta_out_pad_above(),
388 c->get_window_movement_strides_forward());
393 size_t element_count = shape_size(node.get_output_shape(0));
395 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
398 case OP_TYPEID::Cosh:
400 size_t element_count = shape_size(node.get_output_shape(0));
402 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
405 case OP_TYPEID::CTCGreedyDecoder_v0:
407 const auto ctc_greedy_dec = static_cast<const op::v0::CTCGreedyDecoder*>(&node);
408 reference::ctc_greedy_decoder<T>(args[0]->get_data_ptr<const T>(),
409 args[1]->get_data_ptr<const T>(),
410 out[0]->get_data_ptr<T>(),
411 args[0]->get_shape(),
412 args[1]->get_shape(),
414 ctc_greedy_dec->get_ctc_merge_repeated());
417 case OP_TYPEID::CTCLoss_v4:
419 const op::v4::CTCLoss* ctc_loss = static_cast<const op::v4::CTCLoss*>(&node);
420 auto t_int = node.get_input_element_type(1);
421 if (t_int == element::i32)
423 reference::CTCLoss<T, int32_t>(
424 args[0]->get_data_ptr<const T>(),
425 ctc_loss->get_input_shape(0),
426 args[1]->get_data_ptr<const int32_t>(),
427 args[2]->get_data_ptr<const int32_t>(),
428 args[3]->get_data_ptr<const int32_t>(),
429 args.size() > 4 ? args[4]->get_data_ptr<const int32_t>() : nullptr,
430 ctc_loss->get_preprocess_collapse_repeated(),
431 ctc_loss->get_ctc_merge_repeated(),
432 ctc_loss->get_unique(),
433 out[0]->get_data_ptr<T>());
435 else if (t_int == element::i64)
437 reference::CTCLoss<T, int64_t>(
438 args[0]->get_data_ptr<const T>(),
439 ctc_loss->get_input_shape(0),
440 args[1]->get_data_ptr<const int64_t>(),
441 args[2]->get_data_ptr<const int64_t>(),
442 args[3]->get_data_ptr<const int64_t>(),
443 args.size() > 4 ? args[4]->get_data_ptr<const int64_t>() : nullptr,
444 ctc_loss->get_preprocess_collapse_repeated(),
445 ctc_loss->get_ctc_merge_repeated(),
446 ctc_loss->get_unique(),
447 out[0]->get_data_ptr<T>());
451 case OP_TYPEID::CumSum:
453 const op::CumSum* cumsum = static_cast<const op::CumSum*>(&node);
454 auto axis_et = node.get_input_element_type(1);
455 if (axis_et == element::i32)
457 reference::cumsum<T, int32_t>(args[0]->get_data_ptr<const T>(),
458 args[1]->get_data_ptr<const int32_t>(),
459 out[0]->get_data_ptr<T>(),
460 node.get_input_shape(0),
461 cumsum->is_exclusive(),
462 cumsum->is_reverse());
464 else if (axis_et == element::i64)
466 reference::cumsum<T, int64_t>(args[0]->get_data_ptr<const T>(),
467 args[1]->get_data_ptr<const int64_t>(),
468 out[0]->get_data_ptr<T>(),
469 node.get_input_shape(0),
470 cumsum->is_exclusive(),
471 cumsum->is_reverse());
477 const op::Dot* dot = static_cast<const op::Dot*>(&node);
479 reference::dot(args[0]->get_data_ptr<const T>(),
480 args[1]->get_data_ptr<const T>(),
481 out[0]->get_data_ptr<T>(),
482 node.get_input_shape(0),
483 node.get_input_shape(1),
484 node.get_output_shape(0),
485 dot->get_reduction_axes_count());
488 case OP_TYPEID::EmbeddingBagOffsetsSum_v3:
490 const op::EmbeddingBagOffsetsSum* embed =
491 static_cast<const op::EmbeddingBagOffsetsSum*>(&node);
492 auto indicesType = embed->input(1).get_element_type();
493 size_t indices_num = shape_size(embed->get_input_shape(1));
495 if (indicesType == element::u64 || indicesType == element::i64)
497 reference::embeddingBagOffsetsSum<T, size_t>(
498 args[0]->get_data_ptr<const T>(),
499 args[1]->get_data_ptr<const size_t>(),
500 args[2]->get_data_ptr<const size_t>(),
501 args.size() > 3 ? args[3]->get_data_ptr<const size_t>() : nullptr,
502 args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
503 out[0]->get_data_ptr<T>(),
507 else if (indicesType == element::u32 || indicesType == element::i32)
509 reference::embeddingBagOffsetsSum<T, unsigned>(
510 args[0]->get_data_ptr<const T>(),
511 args[1]->get_data_ptr<const unsigned>(),
512 args[2]->get_data_ptr<const unsigned>(),
513 args.size() > 3 ? args[3]->get_data_ptr<const unsigned>() : nullptr,
514 args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
515 out[0]->get_data_ptr<T>(),
521 throw ngraph_error(std::string("Unsupported index type ") +
522 indicesType.c_type_string() +
523 std::string(" in EmbeddingBagOffsetsSum"));
527 case OP_TYPEID::EmbeddingBagPackedSum_v3:
529 const op::EmbeddingBagPackedSum* embed =
530 static_cast<const op::EmbeddingBagPackedSum*>(&node);
531 auto indicesType = embed->input(1).get_element_type();
533 if (indicesType == element::u64 || indicesType == element::i64)
535 reference::embeddingBagPackedSum<T, size_t>(
536 args[0]->get_data_ptr<const T>(),
537 args[1]->get_data_ptr<const size_t>(),
538 args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
539 out[0]->get_data_ptr<T>(),
540 embed->get_input_shape(1),
543 else if (indicesType == element::u32 || indicesType == element::i32)
545 reference::embeddingBagPackedSum<T, unsigned>(
546 args[0]->get_data_ptr<const T>(),
547 args[1]->get_data_ptr<const unsigned>(),
548 args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
549 out[0]->get_data_ptr<T>(),
550 embed->get_input_shape(1),
555 throw ngraph_error(std::string("Unsupported index type ") +
556 indicesType.c_type_string() +
557 std::string(" in EmbeddingBagPackedSum"));
561 case OP_TYPEID::EmbeddingSegmentsSum_v3:
563 const op::EmbeddingSegmentsSum* embed =
564 static_cast<const op::EmbeddingSegmentsSum*>(&node);
565 auto indicesType = embed->input(1).get_element_type();
566 size_t indices_num = shape_size(embed->get_input_shape(1));
568 if (indicesType == element::u64 || indicesType == element::i64)
570 reference::embeddingSegmentsSum<T, size_t>(
571 args[0]->get_data_ptr<const T>(),
572 args[1]->get_data_ptr<const size_t>(),
573 args[2]->get_data_ptr<const size_t>(),
574 args.size() > 4 ? args[4]->get_data_ptr<const size_t>() : nullptr,
575 args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
576 out[0]->get_data_ptr<T>(),
577 embed->get_input_shape(0),
578 embed->get_input_shape(1),
581 else if (indicesType == element::u32 || indicesType == element::i32)
583 reference::embeddingSegmentsSum<T, unsigned>(
584 args[0]->get_data_ptr<const T>(),
585 args[1]->get_data_ptr<const unsigned>(),
586 args[2]->get_data_ptr<const unsigned>(),
587 args.size() > 4 ? args[4]->get_data_ptr<const unsigned>() : nullptr,
588 args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
589 out[0]->get_data_ptr<T>(),
590 embed->get_input_shape(0),
591 embed->get_input_shape(1),
596 throw ngraph_error(std::string("Unsupported index type ") +
597 indicesType.c_type_string() +
598 std::string(" in EmbeddingSegmentsSum"));
604 size_t element_count = shape_size(node.get_output_shape(0));
606 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
609 case OP_TYPEID::ExtractImagePatches_v3:
611 const op::ExtractImagePatches* extImgPatches =
612 static_cast<const op::ExtractImagePatches*>(&node);
613 reference::extractImagePatches<T, size_t>(extImgPatches,
614 args[0]->get_data_ptr<const T>(),
615 out[0]->get_data_ptr<T>(),
616 extImgPatches->get_input_shape(0),
617 extImgPatches->get_shape());
622 size_t element_count = shape_size(node.get_output_shape(0));
624 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
627 #ifdef INTERPRETER_USE_HYBRID
628 case OP_TYPEID::FunctionCall:
630 auto f = static_cast<const runtime::hybrid::op::FunctionCall*>(&node);
631 auto backend = f->get_backend();
632 auto executable = f->get_executable();
634 std::vector<std::shared_ptr<Tensor>> outputs;
635 std::vector<std::shared_ptr<Tensor>> inputs;
636 for (const std::shared_ptr<HostTensor>& t : out)
638 auto backend_tensor = backend->create_tensor(
639 t->get_element_type(), t->get_shape(), t->get_data_ptr());
640 outputs.push_back(backend_tensor);
642 for (const std::shared_ptr<HostTensor>& t : args)
644 auto backend_tensor = backend->create_tensor(
645 t->get_element_type(), t->get_shape(), t->get_data_ptr());
646 inputs.push_back(backend_tensor);
648 executable->call(outputs, inputs);
652 case OP_TYPEID::Floor:
654 size_t element_count = shape_size(node.get_output_shape(0));
656 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
659 case OP_TYPEID::GatherND:
661 if (node.get_input_element_type(1) == element::i64)
663 reference::gather_nd<T, int64_t>(args[0]->get_data_ptr<T>(),
664 args[1]->get_data_ptr<int64_t>(),
665 out[0]->get_data_ptr<T>(),
666 node.get_input_shape(0),
667 node.get_input_shape(1),
668 node.get_output_shape(0));
670 else if (node.get_input_element_type(1) == element::i32)
672 reference::gather_nd<T, int32_t>(args[0]->get_data_ptr<T>(),
673 args[1]->get_data_ptr<int32_t>(),
674 out[0]->get_data_ptr<T>(),
675 node.get_input_shape(0),
676 node.get_input_shape(1),
677 node.get_output_shape(0));
681 throw ngraph_error("Unexpected type");
685 case OP_TYPEID::GatherND_v5:
687 const op::v5::GatherND* gatherNDNode = static_cast<const op::v5::GatherND*>(&node);
688 if (node.get_input_element_type(1) == element::i64)
690 reference::gather_nd<T, int64_t>(args[0]->get_data_ptr<T>(),
691 args[1]->get_data_ptr<int64_t>(),
692 out[0]->get_data_ptr<T>(),
693 node.get_input_shape(0),
694 node.get_input_shape(1),
695 node.get_output_shape(0),
696 gatherNDNode->get_batch_dims());
698 else if (node.get_input_element_type(1) == element::i32)
700 reference::gather_nd<T, int32_t>(args[0]->get_data_ptr<T>(),
701 args[1]->get_data_ptr<int32_t>(),
702 out[0]->get_data_ptr<T>(),
703 node.get_input_shape(0),
704 node.get_input_shape(1),
705 node.get_output_shape(0),
706 gatherNDNode->get_batch_dims());
710 throw ngraph_error("Unexpected type");
714 case OP_TYPEID::GRUCell_v3:
716 const op::v3::GRUCell* gru_cell = static_cast<const op::v3::GRUCell*>(&node);
717 runtime::reference::gru_cell(args[0]->get_data_ptr<T>(),
718 args[0]->get_shape(),
719 args[1]->get_data_ptr<T>(),
720 args[1]->get_shape(),
721 args[2]->get_data_ptr<T>(),
722 args[2]->get_shape(),
723 args[3]->get_data_ptr<T>(),
724 args[3]->get_shape(),
725 args[4]->get_data_ptr<T>(),
726 args[4]->get_shape(),
727 out[0]->get_data_ptr<T>(),
728 gru_cell->get_activations()[0],
729 gru_cell->get_activations()[1],
730 gru_cell->get_clip(),
731 gru_cell->get_linear_before_reset());
734 case OP_TYPEID::LSTMCell_v4:
736 const op::v4::LSTMCell* lstm_cell = static_cast<const op::v4::LSTMCell*>(&node);
737 runtime::reference::lstm_cell(args[0]->get_data_ptr<T>(),
738 args[0]->get_shape(),
739 args[1]->get_data_ptr<T>(),
740 args[1]->get_shape(),
741 args[2]->get_data_ptr<T>(),
742 args[2]->get_shape(),
743 args[3]->get_data_ptr<T>(),
744 args[3]->get_shape(),
745 args[4]->get_data_ptr<T>(),
746 args[4]->get_shape(),
747 args[5]->get_data_ptr<T>(),
748 args[5]->get_shape(),
749 out[0]->get_data_ptr<T>(),
750 out[1]->get_data_ptr<T>(),
751 lstm_cell->get_activations()[0],
752 lstm_cell->get_activations()[1],
753 lstm_cell->get_activations()[2],
754 lstm_cell->get_clip());
757 case OP_TYPEID::RNNCell_v0:
759 const op::v0::RNNCell* rnn_cell = static_cast<const op::v0::RNNCell*>(&node);
760 runtime::reference::rnn_cell(args[0]->get_data_ptr<T>(),
761 args[0]->get_shape(),
762 args[1]->get_data_ptr<T>(),
763 args[1]->get_shape(),
764 args[2]->get_data_ptr<T>(),
765 args[2]->get_shape(),
766 args[3]->get_data_ptr<T>(),
767 args[3]->get_shape(),
768 args[4]->get_data_ptr<T>(),
769 args[4]->get_shape(),
770 out[0]->get_data_ptr<T>(),
771 rnn_cell->get_activations()[0],
772 rnn_cell->get_clip());
775 case OP_TYPEID::LSTMSequence:
776 case OP_TYPEID::LSTMSequence_v5:
778 auto lstm_seq = static_cast<const op::v5::LSTMSequence*>(&node);
779 runtime::reference::lstm_sequence<T>(args[0]->get_data_ptr<char>(),
780 args[0]->get_shape(),
781 args[1]->get_data_ptr<char>(),
782 args[1]->get_shape(),
783 args[2]->get_data_ptr<char>(),
784 args[2]->get_shape(),
785 args[3]->get_data_ptr<char>(),
786 args[3]->get_shape(),
787 args[4]->get_data_ptr<char>(),
788 args[4]->get_shape(),
789 args[5]->get_data_ptr<char>(),
790 args[5]->get_shape(),
791 args[6]->get_data_ptr<char>(),
792 args[6]->get_shape(),
793 out[0]->get_data_ptr<char>(),
794 out[1]->get_data_ptr<char>(),
795 out[2]->get_data_ptr<char>(),
796 lstm_seq->get_activations()[0],
797 lstm_seq->get_activations()[1],
798 lstm_seq->get_activations()[2],
799 lstm_seq->get_clip(),
800 lstm_seq->get_direction());
803 case OP_TYPEID::GRUSequence_v5:
805 auto gru_seq = static_cast<const op::v5::GRUSequence*>(&node);
806 runtime::reference::gru_sequence<T>(args[0]->get_data_ptr<char>(),
807 args[0]->get_shape(),
808 args[1]->get_data_ptr<char>(),
809 args[1]->get_shape(),
810 args[2]->get_data_ptr<char>(),
811 args[2]->get_shape(),
812 args[3]->get_data_ptr<char>(),
813 args[3]->get_shape(),
814 args[4]->get_data_ptr<char>(),
815 args[4]->get_shape(),
816 args[5]->get_data_ptr<char>(),
817 args[5]->get_shape(),
818 out[0]->get_data_ptr<char>(),
819 out[1]->get_data_ptr<char>(),
820 gru_seq->get_activations()[0],
821 gru_seq->get_activations()[1],
823 gru_seq->get_direction(),
824 gru_seq->get_linear_before_reset());
827 case OP_TYPEID::RNNSequence_v5:
829 auto rnn_seq = static_cast<const op::v5::RNNSequence*>(&node);
830 runtime::reference::rnn_sequence<T>(args[0]->get_data_ptr<char>(),
831 args[0]->get_shape(),
832 args[1]->get_data_ptr<char>(),
833 args[1]->get_shape(),
834 args[2]->get_data_ptr<char>(),
835 args[2]->get_shape(),
836 args[3]->get_data_ptr<char>(),
837 args[3]->get_shape(),
838 args[4]->get_data_ptr<char>(),
839 args[4]->get_shape(),
840 args[5]->get_data_ptr<char>(),
841 args[5]->get_shape(),
842 out[0]->get_data_ptr<char>(),
843 out[1]->get_data_ptr<char>(),
844 rnn_seq->get_activations()[0],
846 rnn_seq->get_direction());
851 size_t element_count = shape_size(node.get_output_shape(0));
853 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
856 case OP_TYPEID::LogSoftmax_v5:
858 const op::v5::LogSoftmax* log_softmax = static_cast<const op::v5::LogSoftmax*>(&node);
859 int64_t i_axis = log_softmax->get_axis();
862 i_axis += args[0]->get_partial_shape().rank().get_length();
864 reference::log_softmax<T>(args[0]->get_data_ptr<const T>(),
865 out[0]->get_data_ptr<T>(),
866 node.get_output_shape(0),
867 AxisSet{(size_t)i_axis});
872 const op::LRN* lrn = static_cast<const op::LRN*>(&node);
873 reference::lrn<T>(args[0]->get_data_ptr<const T>(),
874 lrn->get_reduction_axes(),
875 out[0]->get_data_ptr<T>(),
876 node.get_input_shape(0),
883 case OP_TYPEID::Negative:
885 size_t element_count = shape_size(node.get_output_shape(0));
886 reference::negate<T>(
887 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
890 case OP_TYPEID::LogicalNot_v1:
893 size_t element_count = shape_size(node.get_output_shape(0));
894 reference::logical_not(
895 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
898 case OP_TYPEID::OneHot:
900 const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
901 reference::one_hot<T>(args[0]->get_data_ptr<const T>(),
902 out[0]->get_data_ptr<T>(),
903 node.get_input_shape(0),
904 node.get_output_shape(0),
905 oh->get_one_hot_axis());
908 case OP_TYPEID::Parameter: break;
909 case OP_TYPEID::PriorBox:
911 const op::PriorBox* pbox = static_cast<const op::PriorBox*>(&node);
912 runtime::reference::prior_box<T>(args[0]->get_data_ptr<T>(),
913 args[1]->get_data_ptr<T>(),
914 out[0]->get_data_ptr<float>(),
919 case OP_TYPEID::ReorgYolo_v0:
921 const op::v0::ReorgYolo* reorg_yolo = static_cast<const op::v0::ReorgYolo*>(&node);
922 runtime::reference::reorg_yolo(args[0]->get_data_ptr<char>(),
923 out[0]->get_data_ptr<char>(),
924 args[0]->get_shape(),
925 reorg_yolo->get_strides().at(0),
926 args[0]->get_element_type().size());
929 case OP_TYPEID::Quantize:
931 const op::Quantize* quantize = static_cast<const op::Quantize*>(&node);
932 auto type = quantize->get_element_type();
934 if (type == element::u8)
936 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
937 args[1]->get_data_ptr<const T>(),
938 args[2]->get_data_ptr<const uint8_t>(),
939 out[0]->get_data_ptr<uint8_t>(),
940 node.get_input_shape(0),
941 node.get_input_shape(1),
942 quantize->get_axes(),
943 quantize->get_round_mode());
945 else if (type == element::i8)
947 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
948 args[1]->get_data_ptr<const T>(),
949 args[2]->get_data_ptr<const int8_t>(),
950 out[0]->get_data_ptr<int8_t>(),
951 node.get_input_shape(0),
952 node.get_input_shape(1),
953 quantize->get_axes(),
954 quantize->get_round_mode());
956 else if (type == element::i32)
958 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
959 args[1]->get_data_ptr<const T>(),
960 args[2]->get_data_ptr<const int32_t>(),
961 out[0]->get_data_ptr<int32_t>(),
962 node.get_input_shape(0),
963 node.get_input_shape(1),
964 quantize->get_axes(),
965 quantize->get_round_mode());
969 std::stringstream ss;
970 ss << "unsupported element type " << type << " op Quantize";
971 throw std::runtime_error(ss.str());
977 case OP_TYPEID::QuantizedConvolution:
979 const op::QuantizedConvolution* qc =
980 static_cast<const op::QuantizedConvolution*>(&node);
982 auto input_element_type = qc->get_input_element_type(0);
983 auto filter_element_type = qc->get_input_element_type(1);
984 auto output_element_type = qc->get_output_element_type(0);
986 if (input_element_type == element::u8 && filter_element_type == element::i8 &&
987 output_element_type == element::i8)
989 reference::convolution<uint8_t, int8_t, int8_t, int32_t>(
990 args[0]->get_data_ptr<const uint8_t>(),
991 args[1]->get_data_ptr<const int8_t>(),
992 out[0]->get_data_ptr<int8_t>(),
993 node.get_input_shape(0),
994 node.get_input_shape(1),
995 node.get_output_shape(0),
996 qc->get_window_movement_strides(),
997 qc->get_window_dilation_strides(),
998 qc->get_padding_below(),
999 qc->get_padding_above(),
1000 qc->get_data_dilation_strides(),
1001 args[2]->get_data_ptr<const float>(),
1002 args[3]->get_data_ptr<const uint8_t>(),
1003 args[4]->get_data_ptr<const float>(),
1004 args[5]->get_data_ptr<const int8_t>(),
1005 args[6]->get_data_ptr<const float>(),
1006 args[7]->get_data_ptr<const int8_t>());
1008 else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
1009 output_element_type == element::u8)
1011 reference::convolution<uint8_t, uint8_t, uint8_t, int32_t>(
1012 args[0]->get_data_ptr<const uint8_t>(),
1013 args[1]->get_data_ptr<const uint8_t>(),
1014 out[0]->get_data_ptr<uint8_t>(),
1015 node.get_input_shape(0),
1016 node.get_input_shape(1),
1017 node.get_output_shape(0),
1018 qc->get_window_movement_strides(),
1019 qc->get_window_dilation_strides(),
1020 qc->get_padding_below(),
1021 qc->get_padding_above(),
1022 qc->get_data_dilation_strides(),
1023 args[2]->get_data_ptr<const float>(),
1024 args[3]->get_data_ptr<const uint8_t>(),
1025 args[4]->get_data_ptr<const float>(),
1026 args[5]->get_data_ptr<const uint8_t>(),
1027 args[6]->get_data_ptr<const float>(),
1028 args[7]->get_data_ptr<const uint8_t>());
1030 else if (input_element_type == element::u8 && filter_element_type == element::i8 &&
1031 output_element_type == element::i32)
1033 reference::convolution<uint8_t, int8_t, int32_t, int32_t>(
1034 args[0]->get_data_ptr<const uint8_t>(),
1035 args[1]->get_data_ptr<const int8_t>(),
1036 out[0]->get_data_ptr<int32_t>(),
1037 node.get_input_shape(0),
1038 node.get_input_shape(1),
1039 node.get_output_shape(0),
1040 qc->get_window_movement_strides(),
1041 qc->get_window_dilation_strides(),
1042 qc->get_padding_below(),
1043 qc->get_padding_above(),
1044 qc->get_data_dilation_strides(),
1045 args[2]->get_data_ptr<const float>(),
1046 args[3]->get_data_ptr<const uint8_t>(),
1047 args[4]->get_data_ptr<const float>(),
1048 args[5]->get_data_ptr<const int8_t>(),
1049 args[6]->get_data_ptr<const float>(),
1050 args[7]->get_data_ptr<const int32_t>());
1052 else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
1053 output_element_type == element::i32)
1055 reference::convolution<uint8_t, uint8_t, int32_t, int32_t>(
1056 args[0]->get_data_ptr<const uint8_t>(),
1057 args[1]->get_data_ptr<const uint8_t>(),
1058 out[0]->get_data_ptr<int32_t>(),
1059 node.get_input_shape(0),
1060 node.get_input_shape(1),
1061 node.get_output_shape(0),
1062 qc->get_window_movement_strides(),
1063 qc->get_window_dilation_strides(),
1064 qc->get_padding_below(),
1065 qc->get_padding_above(),
1066 qc->get_data_dilation_strides(),
1067 args[2]->get_data_ptr<const float>(),
1068 args[3]->get_data_ptr<const uint8_t>(),
1069 args[4]->get_data_ptr<const float>(),
1070 args[5]->get_data_ptr<const uint8_t>(),
1071 args[6]->get_data_ptr<const float>(),
1072 args[7]->get_data_ptr<const int32_t>());
1076 std::stringstream ss;
1077 ss << "unsupported element type";
1078 throw std::runtime_error(ss.str());
1084 case OP_TYPEID::QuantizedDot:
1086 const op::QuantizedDot* qd = static_cast<const op::QuantizedDot*>(&node);
1088 auto input0_element_type = qd->get_input_element_type(0);
1089 auto input1_element_type = qd->get_input_element_type(1);
1090 auto output_element_type = qd->get_output_element_type(0);
1092 if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
1093 output_element_type == element::i8)
1095 reference::dot<uint8_t, int8_t, int8_t, int32_t>(
1096 args[0]->get_data_ptr<const uint8_t>(),
1097 args[1]->get_data_ptr<const int8_t>(),
1098 out[0]->get_data_ptr<int8_t>(),
1099 node.get_input_shape(0),
1100 node.get_input_shape(1),
1101 node.get_output_shape(0),
1103 args[2]->get_data_ptr<const float>(),
1104 args[3]->get_data_ptr<const uint8_t>(),
1105 args[4]->get_data_ptr<const float>(),
1106 args[5]->get_data_ptr<const int8_t>(),
1107 args[6]->get_data_ptr<const float>(),
1108 args[7]->get_data_ptr<const int8_t>());
1110 else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
1111 output_element_type == element::u8)
1113 reference::dot<uint8_t, uint8_t, uint8_t, int32_t>(
1114 args[0]->get_data_ptr<const uint8_t>(),
1115 args[1]->get_data_ptr<const uint8_t>(),
1116 out[0]->get_data_ptr<uint8_t>(),
1117 node.get_input_shape(0),
1118 node.get_input_shape(1),
1119 node.get_output_shape(0),
1121 args[2]->get_data_ptr<const float>(),
1122 args[3]->get_data_ptr<const uint8_t>(),
1123 args[4]->get_data_ptr<const float>(),
1124 args[5]->get_data_ptr<const uint8_t>(),
1125 args[6]->get_data_ptr<const float>(),
1126 args[7]->get_data_ptr<const uint8_t>());
1128 else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
1129 output_element_type == element::i32)
1131 reference::dot<uint8_t, uint8_t, int32_t, int32_t>(
1132 args[0]->get_data_ptr<const uint8_t>(),
1133 args[1]->get_data_ptr<const uint8_t>(),
1134 out[0]->get_data_ptr<int32_t>(),
1135 node.get_input_shape(0),
1136 node.get_input_shape(1),
1137 node.get_output_shape(0),
1139 args[2]->get_data_ptr<const float>(),
1140 args[3]->get_data_ptr<const uint8_t>(),
1141 args[4]->get_data_ptr<const float>(),
1142 args[5]->get_data_ptr<const uint8_t>(),
1143 args[6]->get_data_ptr<const float>(),
1144 args[7]->get_data_ptr<const int32_t>());
1146 else if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
1147 output_element_type == element::i32)
1149 reference::dot<uint8_t, int8_t, int32_t, int32_t>(
1150 args[0]->get_data_ptr<const uint8_t>(),
1151 args[1]->get_data_ptr<const int8_t>(),
1152 out[0]->get_data_ptr<int32_t>(),
1153 node.get_input_shape(0),
1154 node.get_input_shape(1),
1155 node.get_output_shape(0),
1157 args[2]->get_data_ptr<const float>(),
1158 args[3]->get_data_ptr<const uint8_t>(),
1159 args[4]->get_data_ptr<const float>(),
1160 args[5]->get_data_ptr<const int8_t>(),
1161 args[6]->get_data_ptr<const float>(),
1162 args[7]->get_data_ptr<const int32_t>());
1166 std::stringstream ss;
1167 ss << "unsupported element type";
1168 throw std::runtime_error(ss.str());
1173 case OP_TYPEID::RegionYolo_v0:
1175 const op::RegionYolo* region_yolo = static_cast<const op::RegionYolo*>(&node);
1176 reference::region_yolo<T>(args[0]->get_data_ptr<const T>(),
1177 out[0]->get_data_ptr<T>(),
1178 args[0]->get_shape(),
1179 region_yolo->get_num_coords(),
1180 region_yolo->get_num_classes(),
1181 region_yolo->get_num_regions(),
1182 region_yolo->get_do_softmax(),
1183 region_yolo->get_mask());
1186 case OP_TYPEID::Relu:
1188 size_t element_count = shape_size(node.get_output_shape(0));
1190 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1193 case OP_TYPEID::ReplaceSlice:
1195 const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
1196 reference::replace_slice<T>(args[0]->get_data_ptr<const T>(),
1197 args[1]->get_data_ptr<const T>(),
1198 out[0]->get_data_ptr<T>(),
1199 node.get_input_shape(1),
1200 slice->get_lower_bounds(),
1201 slice->get_upper_bounds(),
1202 slice->get_strides(),
1203 node.get_output_shape(0));
1206 case OP_TYPEID::Reverse:
1208 const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
1209 reference::reverse(args[0]->get_data_ptr<const char>(),
1210 out[0]->get_data_ptr<char>(),
1211 node.get_input_shape(0),
1212 node.get_output_shape(0),
1213 reverse->get_reversed_axes(),
1214 args[0]->get_element_type().size());
1217 case OP_TYPEID::ReverseSequence:
1219 const op::ReverseSequence* reverse = static_cast<const op::ReverseSequence*>(&node);
1221 if (node.get_input_element_type(1) == element::i32)
1223 reference::reverse_sequence<T, int32_t>(args[0]->get_data_ptr<const T>(),
1224 out[0]->get_data_ptr<T>(),
1225 node.get_input_shape(0),
1226 reverse->get_batch_axis(),
1227 reverse->get_sequence_axis(),
1228 args[1]->get_data_ptr<const int32_t>());
1232 throw ngraph_error("only int32 indices are supported");
1236 case OP_TYPEID::Round:
1238 size_t element_count = shape_size(node.get_output_shape(0));
1239 reference::round<T>(args[0]->get_data_ptr<const T>(),
1240 out[0]->get_data_ptr<T>(),
1242 op::v5::Round::RoundMode::HALF_TO_EVEN);
1245 case OP_TYPEID::Select:
1247 size_t element_count = shape_size(node.get_output_shape(0));
1248 reference::select<T>(args[0]->get_data_ptr<const char>(),
1249 args[1]->get_data_ptr<const T>(),
1250 args[2]->get_data_ptr<const T>(),
1251 out[0]->get_data_ptr<T>(),
1255 case OP_TYPEID::Sigmoid:
1257 size_t element_count = shape_size(node.get_output_shape(0));
1258 reference::sigmoid<T>(
1259 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1262 case OP_TYPEID::Sign:
1264 size_t element_count = shape_size(node.get_output_shape(0));
1266 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1269 case OP_TYPEID::Sin:
1271 size_t element_count = shape_size(node.get_output_shape(0));
1273 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1276 case OP_TYPEID::Sinh:
1278 size_t element_count = shape_size(node.get_output_shape(0));
1280 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1283 case OP_TYPEID::Sqrt:
1285 size_t element_count = shape_size(node.get_output_shape(0));
1287 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1290 case OP_TYPEID::Tan:
1292 size_t element_count = shape_size(node.get_output_shape(0));
1294 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1297 case OP_TYPEID::Tanh:
1299 size_t element_count = shape_size(node.get_output_shape(0));
1301 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1304 case OP_TYPEID::TopK:
1306 const op::TopK* topk = static_cast<const op::TopK*>(&node);
1307 if (node.get_output_element_type(0) == element::i64)
1309 reference::topk<T, int64_t>(args[0]->get_data_ptr<const T>(),
1310 out[0]->get_data_ptr<int64_t>(),
1311 out[1]->get_data_ptr<T>(),
1312 node.get_input_shape(0),
1313 node.get_output_shape(0),
1314 topk->get_top_k_axis(),
1316 topk->get_compute_max(),
1319 else if (node.get_output_element_type(0) == element::i32)
1321 reference::topk<T, int32_t>(args[0]->get_data_ptr<const T>(),
1322 out[0]->get_data_ptr<int32_t>(),
1323 out[1]->get_data_ptr<T>(),
1324 node.get_input_shape(0),
1325 node.get_output_shape(0),
1326 topk->get_top_k_axis(),
1328 topk->get_compute_max(),
1333 throw ngraph_error("Unexpected type");
1337 case OP_TYPEID::DetectionOutput_v0:
1339 const op::DetectionOutput* detOut = static_cast<const op::DetectionOutput*>(&node);
1340 reference::referenceDetectionOutput<T> refDetOut(
1341 detOut->get_attrs(), node.get_input_shape(0), node.get_input_shape(2));
1342 if (node.get_input_size() == 3)
1344 refDetOut.run(args[0]->get_data_ptr<const T>(),
1345 args[1]->get_data_ptr<const T>(),
1346 args[2]->get_data_ptr<const T>(),
1349 out[0]->get_data_ptr<T>());
1351 else if (node.get_input_size() == 5)
1353 refDetOut.run(args[0]->get_data_ptr<const T>(),
1354 args[1]->get_data_ptr<const T>(),
1355 args[2]->get_data_ptr<const T>(),
1356 args[3]->get_data_ptr<const T>(),
1357 args[4]->get_data_ptr<const T>(),
1358 out[0]->get_data_ptr<T>());
1362 throw ngraph_error("DetectionOutput layer supports only 3 or 5 inputs");
1367 case OP_TYPEID::ScatterNDUpdate_v3:
1369 const op::ScatterNDUpdate* scatterNDUpd =
1370 static_cast<const op::v3::ScatterNDUpdate*>(&node);
1371 auto idxType = scatterNDUpd->get_input_element_type(1);
1372 if (idxType == element::i32)
1374 reference::scatterNdUpdate<T, int32_t>(args[0]->get_data_ptr<const T>(),
1375 args[1]->get_data_ptr<const int32_t>(),
1376 args[2]->get_data_ptr<const T>(),
1377 out[0]->get_data_ptr<T>(),
1378 node.get_input_shape(0),
1379 node.get_input_shape(1),
1380 node.get_input_shape(2));
1382 else if (idxType == element::i64)
1384 reference::scatterNdUpdate<T, int64_t>(args[0]->get_data_ptr<const T>(),
1385 args[1]->get_data_ptr<const int64_t>(),
1386 args[2]->get_data_ptr<const T>(),
1387 out[0]->get_data_ptr<T>(),
1388 node.get_input_shape(0),
1389 node.get_input_shape(1),
1390 node.get_input_shape(2));
1395 "ScatterNDUpdate layer support only i32 and i64 'indices' input precision!");
1400 case OP_TYPEID::GatherTree_v1:
1402 reference::gather_tree(args[0]->get_data_ptr<const char>(),
1403 args[1]->get_data_ptr<const char>(),
1404 args[2]->get_data_ptr<const char>(),
1405 args[3]->get_data_ptr<const char>(),
1406 out[0]->get_data_ptr<char>(),
1407 node.get_input_shape(0),
1408 node.get_input_shape(1),
1409 node.get_input_shape(2),
1410 node.get_input_shape(3),
1411 args[1]->get_element_type());
1414 case OP_TYPEID::NormalizeL2:
1416 const op::NormalizeL2* norm = static_cast<const op::NormalizeL2*>(&node);
1417 reference::normalize_l2<T>(args[0]->get_data_ptr<const T>(),
1418 out[0]->get_data_ptr<T>(),
1419 node.get_input_shape(0),
1420 norm->get_reduction_axes(),
1422 norm->get_eps_mode());
1426 // Fused Ops are not supported in interpreter. They need to be decomposed before execution
1427 case OP_TYPEID::DepthToSpace:
1428 case OP_TYPEID::FakeQuantize:
1429 case OP_TYPEID::Gather:
1430 case OP_TYPEID::Gelu:
1431 case OP_TYPEID::GRN:
1432 case OP_TYPEID::GroupConvolution:
1433 case OP_TYPEID::GroupConvolutionBackpropData:
1434 case OP_TYPEID::HardSigmoid:
1435 case OP_TYPEID::Interpolate:
1436 case OP_TYPEID::MVN:
1437 case OP_TYPEID::PRelu:
1438 case OP_TYPEID::ScatterUpdate_v3:
1439 case OP_TYPEID::Selu:
1440 case OP_TYPEID::ShuffleChannels:
1441 case OP_TYPEID::SpaceToDepth:
1442 case OP_TYPEID::Split:
1443 case OP_TYPEID::SquaredDifference:
1444 case OP_TYPEID::StopGradient:
1445 case OP_TYPEID::TensorIterator:
1446 case OP_TYPEID::Tile:
1447 case OP_TYPEID::UnknownOp:
1448 throw unsupported_op("Unsupported op '" + node.description() + "'");
1449 case OP_TYPEID::Add:
1450 case OP_TYPEID::Broadcast:
1451 case OP_TYPEID::Clamp:
1452 case OP_TYPEID::Concat:
1453 case OP_TYPEID::Constant:
1454 case OP_TYPEID::Divide:
1455 case OP_TYPEID::Equal:
1456 case OP_TYPEID::Greater:
1457 case OP_TYPEID::GreaterEq:
1458 case OP_TYPEID::Less:
1459 case OP_TYPEID::LessEq:
1460 case OP_TYPEID::LessEqual_v1:
1461 case OP_TYPEID::LogicalAnd_v1:
1462 case OP_TYPEID::LogicalOr_v1:
1463 case OP_TYPEID::LogicalXor_v1:
1464 case OP_TYPEID::MatMul:
1465 case OP_TYPEID::Max:
1466 case OP_TYPEID::Maximum:
1467 case OP_TYPEID::Min:
1468 case OP_TYPEID::Minimum:
1469 case OP_TYPEID::Multiply:
1470 case OP_TYPEID::NonZero_v3:
1471 case OP_TYPEID::NotEqual:
1473 case OP_TYPEID::Power:
1474 case OP_TYPEID::Product:
1475 case OP_TYPEID::Range:
1476 case OP_TYPEID::Reshape:
1477 case OP_TYPEID::Result:
1478 case OP_TYPEID::Round_v5:
1479 case OP_TYPEID::ShapeOf_v3:
1480 case OP_TYPEID::ShapeOf:
1481 case OP_TYPEID::Softmax:
1482 case OP_TYPEID::Squeeze:
1483 case OP_TYPEID::Sum:
1484 case OP_TYPEID::Subtract:
1485 case OP_TYPEID::Unsqueeze:
1486 case OP_TYPEID::Xor:
1487 case OP_TYPEID::Slice:
1488 // These ops are handled by op evaluators so nothing to do
1490 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
1491 #pragma GCC diagnostic pop
1497 NGRAPH_SUPPRESS_DEPRECATED_END