1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include <initializer_list>
26 #include <ngraph/runtime/host_tensor.hpp>
27 #include "backend.hpp"
28 #include "int_backend_visibility.hpp"
29 #include "ngraph/ops.hpp"
30 #include "ngraph/runtime/aligned_buffer.hpp"
31 #include "ngraph/runtime/reference/abs.hpp"
32 #include "ngraph/runtime/reference/acos.hpp"
33 #include "ngraph/runtime/reference/asin.hpp"
34 #include "ngraph/runtime/reference/atan.hpp"
35 #include "ngraph/runtime/reference/atan2.hpp"
36 #include "ngraph/runtime/reference/avg_pool.hpp"
37 #include "ngraph/runtime/reference/batch_norm.hpp"
38 #include "ngraph/runtime/reference/broadcast.hpp"
39 #include "ngraph/runtime/reference/ceiling.hpp"
40 #include "ngraph/runtime/reference/concat.hpp"
41 #include "ngraph/runtime/reference/constant.hpp"
42 #include "ngraph/runtime/reference/convert.hpp"
43 #include "ngraph/runtime/reference/convolution.hpp"
44 #include "ngraph/runtime/reference/cos.hpp"
45 #include "ngraph/runtime/reference/cosh.hpp"
46 #include "ngraph/runtime/reference/ctc_greedy_decoder.hpp"
47 #include "ngraph/runtime/reference/ctc_loss.hpp"
48 #include "ngraph/runtime/reference/cum_sum.hpp"
49 #include "ngraph/runtime/reference/detection_output.hpp"
50 #include "ngraph/runtime/reference/dot.hpp"
51 #include "ngraph/runtime/reference/elu.hpp"
52 #include "ngraph/runtime/reference/embedding_bag_offsets_sum.hpp"
53 #include "ngraph/runtime/reference/embedding_bag_packed_sum.hpp"
54 #include "ngraph/runtime/reference/embedding_segments_sum.hpp"
55 #include "ngraph/runtime/reference/erf.hpp"
56 #include "ngraph/runtime/reference/exp.hpp"
57 #include "ngraph/runtime/reference/extract_image_patches.hpp"
58 #include "ngraph/runtime/reference/floor.hpp"
59 #include "ngraph/runtime/reference/gather.hpp"
60 #include "ngraph/runtime/reference/gather_nd.hpp"
61 #include "ngraph/runtime/reference/gather_tree.hpp"
62 #include "ngraph/runtime/reference/gru_cell.hpp"
63 #include "ngraph/runtime/reference/log.hpp"
64 #include "ngraph/runtime/reference/log_softmax.hpp"
65 #include "ngraph/runtime/reference/lrn.hpp"
66 #include "ngraph/runtime/reference/lstm_cell.hpp"
67 #include "ngraph/runtime/reference/matmul.hpp"
68 #include "ngraph/runtime/reference/max.hpp"
69 #include "ngraph/runtime/reference/max_pool.hpp"
70 #include "ngraph/runtime/reference/min.hpp"
71 #include "ngraph/runtime/reference/negate.hpp"
72 #include "ngraph/runtime/reference/normalize_l2.hpp"
73 #include "ngraph/runtime/reference/not.hpp"
74 #include "ngraph/runtime/reference/one_hot.hpp"
75 #include "ngraph/runtime/reference/pad.hpp"
76 #include "ngraph/runtime/reference/prior_box.hpp"
77 #include "ngraph/runtime/reference/product.hpp"
78 #include "ngraph/runtime/reference/quantize.hpp"
79 #include "ngraph/runtime/reference/region_yolo.hpp"
80 #include "ngraph/runtime/reference/relu.hpp"
81 #include "ngraph/runtime/reference/reorg_yolo.hpp"
82 #include "ngraph/runtime/reference/replace_slice.hpp"
83 #include "ngraph/runtime/reference/reshape.hpp"
84 #include "ngraph/runtime/reference/result.hpp"
85 #include "ngraph/runtime/reference/reverse.hpp"
86 #include "ngraph/runtime/reference/reverse_sequence.hpp"
87 #include "ngraph/runtime/reference/rnn_cell.hpp"
88 #include "ngraph/runtime/reference/round.hpp"
89 #include "ngraph/runtime/reference/scatter_nd_update.hpp"
90 #include "ngraph/runtime/reference/select.hpp"
91 #include "ngraph/runtime/reference/sequences.hpp"
92 #include "ngraph/runtime/reference/sigmoid.hpp"
93 #include "ngraph/runtime/reference/sign.hpp"
94 #include "ngraph/runtime/reference/sin.hpp"
95 #include "ngraph/runtime/reference/sinh.hpp"
96 #include "ngraph/runtime/reference/softmax.hpp"
97 #include "ngraph/runtime/reference/sqrt.hpp"
98 #include "ngraph/runtime/reference/sum.hpp"
99 #include "ngraph/runtime/reference/tan.hpp"
100 #include "ngraph/runtime/reference/tanh.hpp"
101 #include "ngraph/runtime/reference/topk.hpp"
102 #include "ngraph/runtime/tensor.hpp"
103 #include "op/avg_pool.hpp"
104 #include "op/convolution.hpp"
105 #include "op/group_conv.hpp"
107 NGRAPH_SUPPRESS_DEPRECATED_START
113 namespace interpreter
// Op-ID list generation: NGRAPH_OP is defined so that every
// NGRAPH_OP(NAME, NAMESPACE) entry becomes `ID_SUFFIX(NAME),`, and the op
// table is then included so its entries expand in place, one item per op.
// NOTE(review): the comment below names op_tbl.hpp, but the table actually
// included here is opset_int_tbl.hpp - confirm which file lists the ops.
118 // This expands the op list in op_tbl.hpp into a list of enumerations that look like
125 #define NGRAPH_OP(NAME, NAMESPACE) ID_SUFFIX(NAME),
126 #include "opset_int_tbl.hpp"
130 } // namespace interpreter
131 } // namespace runtime
132 } // namespace ngraph
134 class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : public Executable
136 friend class INTBackend;
// Constructs an executable for `function`.
// `enable_performance_collection` defaults to false.
// NOTE(review): presumably it seeds m_performance_counters_enabled - confirm
// against the constructor definition, which is not in this header section.
139 INTExecutable(const std::shared_ptr<Function>& function,
140 bool enable_performance_collection = false);
// Runs the executable. Declared `override`, so it implements a virtual
// inherited from the Executable base class.
// Note the argument order: output tensors first, then input tensors.
// The meaning of the bool result is not visible in this declaration.
142 bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
143 const std::vector<std::shared_ptr<Tensor>>& inputs) override;
// Turns NaN checking on or off according to `enable`.
// NOTE(review): presumably stores the flag in m_nan_check_enabled - confirm
// against the definition.
145 void set_nan_check(bool enable);
// Returns the collected performance counters by value. Const and declared
// `override` (implements a virtual inherited from the base class).
147 std::vector<PerformanceCounter> get_performance_data() const override;
// Creates a single tensor for the input at position `input_index`.
// Declared `override`.
149 std::shared_ptr<runtime::Tensor> create_input_tensor(size_t input_index) override;
// Creates a single tensor for the output at position `output_index`.
// Declared `override`.
151 std::shared_ptr<runtime::Tensor> create_output_tensor(size_t output_index) override;
// Pipelined overload: returns a vector of tensors for the input at
// `input_index`. Declared `override`.
// NOTE(review): presumably one tensor per pipeline stage (pipeline_depth
// entries) - confirm against the definition.
153 std::vector<std::shared_ptr<runtime::Tensor>>
154 create_input_tensor(size_t input_index, size_t pipeline_depth) override;
// Pipelined overload: returns a vector of tensors for the output at
// `output_index`. Declared `override`.
// NOTE(review): presumably one tensor per pipeline stage (pipeline_depth
// entries) - confirm against the definition.
156 std::vector<std::shared_ptr<runtime::Tensor>>
157 create_output_tensor(size_t output_index, size_t pipeline_depth) override;
// Returns the op::Parameter node selected by `index`.
// NOTE(review): presumably the index-th parameter of m_function - confirm.
160 std::shared_ptr<ngraph::op::Parameter> get_parameter(size_t index) const;
// Returns the op::Result node selected by `index`.
// NOTE(review): presumably the index-th result of m_function - confirm.
161 std::shared_ptr<ngraph::op::Result> get_result(size_t index) const;
// Returns the fixed alignment value 64.
// NOTE(review): presumably a byte alignment for tensor buffers - the unit and
// the consumers of this value are not visible here.
162 int get_alignment() const { return 64; }
// Executable state. All three flags start out false.
163 bool m_is_compiled = false;
// NOTE(review): presumably controlled through set_nan_check() - confirm.
164 bool m_nan_check_enabled = false;
// NOTE(review): presumably set from the constructor's
// enable_performance_collection argument - confirm.
165 bool m_performance_counters_enabled = false;
// The function this executable was built from.
166 std::shared_ptr<Function> m_function;
// One stopwatch per node, keyed by the node pointer.
// NOTE(review): presumably the source for get_performance_data() - confirm.
167 std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
// Nodes held by this executable; ordering semantics are not visible here.
168 std::vector<std::shared_ptr<Node>> m_nodes;
// Names of ops flagged as unsupported; a std::set, so each name is unique.
169 std::set<std::string> m_unsupported_op_name_list;
// Maps `node` to its OP_TYPEID value. op_engine switches on this result to
// select the reference kernel for the node.
171 static OP_TYPEID get_typeid(const Node& node);
// Static helper that takes a list of host tensors and an optional node
// pointer; `op` defaults to nullptr.
// NOTE(review): presumably scans the tensors for NaN values and uses `op`
// only to identify the offending node in diagnostics - confirm against the
// definition.
173 static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
174 const Node* op = nullptr);
// Virtual hook that receives an element type together with the output and
// input host tensors (outputs are listed before inputs).
// NOTE(review): presumably selects the op_engine<T> instantiation matching
// `type` - confirm against the definition.
176 virtual void generate_calls(const element::Type& type,
178 const std::vector<std::shared_ptr<HostTensor>>& outputs,
179 const std::vector<std::shared_ptr<HostTensor>>& inputs);
181 template <typename T>
182 void op_engine(const Node& node,
183 const std::vector<std::shared_ptr<HostTensor>>& out,
184 const std::vector<std::shared_ptr<HostTensor>>& args)
186 // We want to check that every OP_TYPEID enumeration is included in the list.
187 // These GCC flags enable compile-time checking so that if an enumeration
188 // is not in the list an error is generated.
189 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
190 #pragma GCC diagnostic push
191 #pragma GCC diagnostic error "-Wswitch"
192 #pragma GCC diagnostic error "-Wswitch-enum"
194 switch (get_typeid(node))
198 size_t element_count = shape_size(node.get_output_shape(0));
200 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
203 case OP_TYPEID::Acos:
205 size_t element_count = shape_size(node.get_output_shape(0));
207 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
210 case OP_TYPEID::Asin:
212 size_t element_count = shape_size(node.get_output_shape(0));
214 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
217 case OP_TYPEID::Atan:
219 size_t element_count = shape_size(node.get_output_shape(0));
221 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
226 const op::Elu* elu_node = static_cast<const op::Elu*>(&node);
228 size_t element_count = shape_size(node.get_output_shape(0));
229 reference::elu<T>(args[0]->get_data_ptr<const T>(),
230 out[0]->get_data_ptr<T>(),
232 elu_node->get_alpha());
235 case OP_TYPEID::AvgPool:
237 const op::v0::AvgPool* avg_pool = static_cast<const op::v0::AvgPool*>(&node);
239 reference::avg_pool<T>(args[0]->get_data_ptr<const T>(),
240 out[0]->get_data_ptr<T>(),
241 node.get_input_shape(0),
242 node.get_output_shape(0),
243 avg_pool->get_window_shape(),
244 avg_pool->get_window_movement_strides(),
245 avg_pool->get_padding_below(),
246 avg_pool->get_padding_above(),
247 avg_pool->get_include_padding_in_avg_computation());
250 case OP_TYPEID::BatchNormInference:
252 const ngraph::op::v0::BatchNormInference* bn =
253 static_cast<const ngraph::op::v0::BatchNormInference*>(&node);
254 reference::batch_norm_inference<T>(bn->get_eps_value(),
255 args[0]->get_data_ptr<const T>(),
256 args[1]->get_data_ptr<const T>(),
257 args[2]->get_data_ptr<const T>(),
258 args[3]->get_data_ptr<const T>(),
259 args[4]->get_data_ptr<const T>(),
260 out[0]->get_data_ptr<T>(),
261 node.get_input_shape(2));
264 case OP_TYPEID::BatchNormInference_v5:
266 const ngraph::op::v5::BatchNormInference* bn =
267 static_cast<const ngraph::op::v5::BatchNormInference*>(&node);
268 reference::batch_norm_inference<T>(bn->get_eps_value(),
269 args[1]->get_data_ptr<const T>(),
270 args[2]->get_data_ptr<const T>(),
271 args[0]->get_data_ptr<const T>(),
272 args[3]->get_data_ptr<const T>(),
273 args[4]->get_data_ptr<const T>(),
274 out[0]->get_data_ptr<T>(),
275 node.get_input_shape(0));
278 case OP_TYPEID::Ceiling:
280 size_t element_count = shape_size(node.get_output_shape(0));
281 reference::ceiling<T>(
282 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
285 case OP_TYPEID::Convert:
287 // const op::Convert* c = static_cast<const op::Convert*>(&node);
288 element::Type type = node.get_element_type();
289 std::stringstream ss;
290 size_t element_count = shape_size(node.get_output_shape(0));
293 case element::Type_t::boolean:
294 reference::convert_to_bool<T>(
295 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<char>(), element_count);
297 case element::Type_t::f32:
298 reference::convert<T>(
299 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<float>(), element_count);
301 case element::Type_t::f64:
302 reference::convert<T>(args[0]->get_data_ptr<const T>(),
303 out[0]->get_data_ptr<double>(),
306 case element::Type_t::i8:
307 reference::convert<T>(args[0]->get_data_ptr<const T>(),
308 out[0]->get_data_ptr<int8_t>(),
311 case element::Type_t::i16:
312 reference::convert<T>(args[0]->get_data_ptr<const T>(),
313 out[0]->get_data_ptr<int16_t>(),
316 case element::Type_t::i32:
317 reference::convert<T>(args[0]->get_data_ptr<const T>(),
318 out[0]->get_data_ptr<int32_t>(),
321 case element::Type_t::i64:
322 reference::convert<T>(args[0]->get_data_ptr<const T>(),
323 out[0]->get_data_ptr<int64_t>(),
326 case element::Type_t::u8:
327 reference::convert<T>(args[0]->get_data_ptr<const T>(),
328 out[0]->get_data_ptr<uint8_t>(),
331 case element::Type_t::u16:
332 reference::convert<T>(args[0]->get_data_ptr<const T>(),
333 out[0]->get_data_ptr<uint16_t>(),
336 case element::Type_t::u32:
337 reference::convert<T>(args[0]->get_data_ptr<const T>(),
338 out[0]->get_data_ptr<uint32_t>(),
341 case element::Type_t::u64:
342 reference::convert<T>(args[0]->get_data_ptr<const T>(),
343 out[0]->get_data_ptr<uint64_t>(),
346 case element::Type_t::undefined:
347 case element::Type_t::dynamic:
348 case element::Type_t::u1:
349 case element::Type_t::bf16:
350 case element::Type_t::f16:
351 ss << "unsupported element type " << type << " op Convert";
352 throw std::runtime_error(ss.str());
356 case OP_TYPEID::Convolution:
358 const op::v0::Convolution* c = static_cast<const op::v0::Convolution*>(&node);
359 reference::convolution<T>(args[0]->get_data_ptr<const T>(),
360 args[1]->get_data_ptr<const T>(),
361 out[0]->get_data_ptr<T>(),
362 node.get_input_shape(0),
363 node.get_input_shape(1),
364 node.get_output_shape(0),
365 c->get_window_movement_strides(),
366 c->get_window_dilation_strides(),
367 c->get_padding_below(),
368 c->get_padding_above(),
369 c->get_data_dilation_strides());
373 case OP_TYPEID::ConvolutionBackpropData:
375 // Note that args[1] and args[0] are switched here from the usual order.
376 const op::v0::ConvolutionBackpropData* c =
377 static_cast<const op::v0::ConvolutionBackpropData*>(&node);
378 reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
379 args[0]->get_data_ptr<const T>(),
380 out[0]->get_data_ptr<T>(),
381 c->get_input_shape(1),
382 c->get_input_shape(0),
383 c->get_data_batch_shape(),
384 c->get_data_dilation_strides_forward(),
385 c->get_window_dilation_strides_forward(),
386 c->compute_backward_delta_out_pad_below(),
387 c->compute_backward_delta_out_pad_above(),
388 c->get_window_movement_strides_forward());
393 size_t element_count = shape_size(node.get_output_shape(0));
395 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
398 case OP_TYPEID::Cosh:
400 size_t element_count = shape_size(node.get_output_shape(0));
402 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
405 case OP_TYPEID::CTCGreedyDecoder_v0:
407 const auto ctc_greedy_dec = static_cast<const op::v0::CTCGreedyDecoder*>(&node);
408 reference::ctc_greedy_decoder<T>(args[0]->get_data_ptr<const T>(),
409 args[1]->get_data_ptr<const T>(),
410 out[0]->get_data_ptr<T>(),
411 args[0]->get_shape(),
412 args[1]->get_shape(),
414 ctc_greedy_dec->get_ctc_merge_repeated());
417 case OP_TYPEID::CTCLoss_v4:
419 const op::v4::CTCLoss* ctc_loss = static_cast<const op::v4::CTCLoss*>(&node);
420 auto t_int = node.get_input_element_type(1);
421 if (t_int == element::i32)
423 reference::CTCLoss<T, int32_t>(
424 args[0]->get_data_ptr<const T>(),
425 ctc_loss->get_input_shape(0),
426 args[1]->get_data_ptr<const int32_t>(),
427 args[2]->get_data_ptr<const int32_t>(),
428 args[3]->get_data_ptr<const int32_t>(),
429 args.size() > 4 ? args[4]->get_data_ptr<const int32_t>() : nullptr,
430 ctc_loss->get_preprocess_collapse_repeated(),
431 ctc_loss->get_ctc_merge_repeated(),
432 ctc_loss->get_unique(),
433 out[0]->get_data_ptr<T>());
435 else if (t_int == element::i64)
437 reference::CTCLoss<T, int64_t>(
438 args[0]->get_data_ptr<const T>(),
439 ctc_loss->get_input_shape(0),
440 args[1]->get_data_ptr<const int64_t>(),
441 args[2]->get_data_ptr<const int64_t>(),
442 args[3]->get_data_ptr<const int64_t>(),
443 args.size() > 4 ? args[4]->get_data_ptr<const int64_t>() : nullptr,
444 ctc_loss->get_preprocess_collapse_repeated(),
445 ctc_loss->get_ctc_merge_repeated(),
446 ctc_loss->get_unique(),
447 out[0]->get_data_ptr<T>());
451 case OP_TYPEID::CumSum:
453 const op::CumSum* cumsum = static_cast<const op::CumSum*>(&node);
454 auto axis_et = node.get_input_element_type(1);
455 if (axis_et == element::i32)
457 reference::cumsum<T, int32_t>(args[0]->get_data_ptr<const T>(),
458 args[1]->get_data_ptr<const int32_t>(),
459 out[0]->get_data_ptr<T>(),
460 node.get_input_shape(0),
461 cumsum->is_exclusive(),
462 cumsum->is_reverse());
464 else if (axis_et == element::i64)
466 reference::cumsum<T, int64_t>(args[0]->get_data_ptr<const T>(),
467 args[1]->get_data_ptr<const int64_t>(),
468 out[0]->get_data_ptr<T>(),
469 node.get_input_shape(0),
470 cumsum->is_exclusive(),
471 cumsum->is_reverse());
477 const op::Dot* dot = static_cast<const op::Dot*>(&node);
479 reference::dot(args[0]->get_data_ptr<const T>(),
480 args[1]->get_data_ptr<const T>(),
481 out[0]->get_data_ptr<T>(),
482 node.get_input_shape(0),
483 node.get_input_shape(1),
484 node.get_output_shape(0),
485 dot->get_reduction_axes_count());
488 case OP_TYPEID::EmbeddingBagOffsetsSum_v3:
490 const op::EmbeddingBagOffsetsSum* embed =
491 static_cast<const op::EmbeddingBagOffsetsSum*>(&node);
492 auto indicesType = embed->input(1).get_element_type();
493 size_t indices_num = shape_size(embed->get_input_shape(1));
495 if (indicesType == element::u64 || indicesType == element::i64)
497 reference::embeddingBagOffsetsSum<T, size_t>(
498 args[0]->get_data_ptr<const T>(),
499 args[1]->get_data_ptr<const size_t>(),
500 args[2]->get_data_ptr<const size_t>(),
501 args.size() > 3 ? args[3]->get_data_ptr<const size_t>() : nullptr,
502 args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
503 out[0]->get_data_ptr<T>(),
507 else if (indicesType == element::u32 || indicesType == element::i32)
509 reference::embeddingBagOffsetsSum<T, unsigned>(
510 args[0]->get_data_ptr<const T>(),
511 args[1]->get_data_ptr<const unsigned>(),
512 args[2]->get_data_ptr<const unsigned>(),
513 args.size() > 3 ? args[3]->get_data_ptr<const unsigned>() : nullptr,
514 args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
515 out[0]->get_data_ptr<T>(),
521 throw ngraph_error(std::string("Unsupported index type ") +
522 indicesType.c_type_string() +
523 std::string(" in EmbeddingBagOffsetsSum"));
527 case OP_TYPEID::EmbeddingBagPackedSum_v3:
529 const op::EmbeddingBagPackedSum* embed =
530 static_cast<const op::EmbeddingBagPackedSum*>(&node);
531 auto indicesType = embed->input(1).get_element_type();
533 if (indicesType == element::u64 || indicesType == element::i64)
535 reference::embeddingBagPackedSum<T, size_t>(
536 args[0]->get_data_ptr<const T>(),
537 args[1]->get_data_ptr<const size_t>(),
538 args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
539 out[0]->get_data_ptr<T>(),
540 embed->get_input_shape(1),
543 else if (indicesType == element::u32 || indicesType == element::i32)
545 reference::embeddingBagPackedSum<T, unsigned>(
546 args[0]->get_data_ptr<const T>(),
547 args[1]->get_data_ptr<const unsigned>(),
548 args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
549 out[0]->get_data_ptr<T>(),
550 embed->get_input_shape(1),
555 throw ngraph_error(std::string("Unsupported index type ") +
556 indicesType.c_type_string() +
557 std::string(" in EmbeddingBagPackedSum"));
561 case OP_TYPEID::EmbeddingSegmentsSum_v3:
563 const op::EmbeddingSegmentsSum* embed =
564 static_cast<const op::EmbeddingSegmentsSum*>(&node);
565 auto indicesType = embed->input(1).get_element_type();
566 size_t indices_num = shape_size(embed->get_input_shape(1));
568 if (indicesType == element::u64 || indicesType == element::i64)
570 reference::embeddingSegmentsSum<T, size_t>(
571 args[0]->get_data_ptr<const T>(),
572 args[1]->get_data_ptr<const size_t>(),
573 args[2]->get_data_ptr<const size_t>(),
574 args.size() > 4 ? args[4]->get_data_ptr<const size_t>() : nullptr,
575 args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
576 out[0]->get_data_ptr<T>(),
577 embed->get_input_shape(0),
578 embed->get_input_shape(1),
581 else if (indicesType == element::u32 || indicesType == element::i32)
583 reference::embeddingSegmentsSum<T, unsigned>(
584 args[0]->get_data_ptr<const T>(),
585 args[1]->get_data_ptr<const unsigned>(),
586 args[2]->get_data_ptr<const unsigned>(),
587 args.size() > 4 ? args[4]->get_data_ptr<const unsigned>() : nullptr,
588 args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
589 out[0]->get_data_ptr<T>(),
590 embed->get_input_shape(0),
591 embed->get_input_shape(1),
596 throw ngraph_error(std::string("Unsupported index type ") +
597 indicesType.c_type_string() +
598 std::string(" in EmbeddingSegmentsSum"));
604 size_t element_count = shape_size(node.get_output_shape(0));
606 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
609 case OP_TYPEID::ExtractImagePatches_v3:
611 const op::ExtractImagePatches* extImgPatches =
612 static_cast<const op::ExtractImagePatches*>(&node);
613 reference::extractImagePatches<T, size_t>(extImgPatches,
614 args[0]->get_data_ptr<const T>(),
615 out[0]->get_data_ptr<T>(),
616 extImgPatches->get_input_shape(0),
617 extImgPatches->get_shape());
622 size_t element_count = shape_size(node.get_output_shape(0));
624 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
627 #ifdef INTERPRETER_USE_HYBRID
628 case OP_TYPEID::FunctionCall:
630 auto f = static_cast<const runtime::hybrid::op::FunctionCall*>(&node);
631 auto backend = f->get_backend();
632 auto executable = f->get_executable();
634 std::vector<std::shared_ptr<Tensor>> outputs;
635 std::vector<std::shared_ptr<Tensor>> inputs;
636 for (const std::shared_ptr<HostTensor>& t : out)
638 auto backend_tensor = backend->create_tensor(
639 t->get_element_type(), t->get_shape(), t->get_data_ptr());
640 outputs.push_back(backend_tensor);
642 for (const std::shared_ptr<HostTensor>& t : args)
644 auto backend_tensor = backend->create_tensor(
645 t->get_element_type(), t->get_shape(), t->get_data_ptr());
646 inputs.push_back(backend_tensor);
648 executable->call(outputs, inputs);
652 case OP_TYPEID::Floor:
654 size_t element_count = shape_size(node.get_output_shape(0));
656 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
659 case OP_TYPEID::GatherND_v5:
661 const op::v5::GatherND* gatherNDNode = static_cast<const op::v5::GatherND*>(&node);
662 if (node.get_input_element_type(1) == element::i64)
664 reference::gather_nd<T, int64_t>(args[0]->get_data_ptr<T>(),
665 args[1]->get_data_ptr<int64_t>(),
666 out[0]->get_data_ptr<T>(),
667 node.get_input_shape(0),
668 node.get_input_shape(1),
669 node.get_output_shape(0),
670 gatherNDNode->get_batch_dims());
672 else if (node.get_input_element_type(1) == element::i32)
674 reference::gather_nd<T, int32_t>(args[0]->get_data_ptr<T>(),
675 args[1]->get_data_ptr<int32_t>(),
676 out[0]->get_data_ptr<T>(),
677 node.get_input_shape(0),
678 node.get_input_shape(1),
679 node.get_output_shape(0),
680 gatherNDNode->get_batch_dims());
684 throw ngraph_error("Unexpected type");
688 case OP_TYPEID::GRUCell_v3:
690 const op::v3::GRUCell* gru_cell = static_cast<const op::v3::GRUCell*>(&node);
691 runtime::reference::gru_cell(args[0]->get_data_ptr<T>(),
692 args[0]->get_shape(),
693 args[1]->get_data_ptr<T>(),
694 args[1]->get_shape(),
695 args[2]->get_data_ptr<T>(),
696 args[2]->get_shape(),
697 args[3]->get_data_ptr<T>(),
698 args[3]->get_shape(),
699 args[4]->get_data_ptr<T>(),
700 args[4]->get_shape(),
701 out[0]->get_data_ptr<T>(),
702 gru_cell->get_activations()[0],
703 gru_cell->get_activations()[1],
704 gru_cell->get_clip(),
705 gru_cell->get_linear_before_reset());
708 case OP_TYPEID::LSTMCell_v4:
710 const op::v4::LSTMCell* lstm_cell = static_cast<const op::v4::LSTMCell*>(&node);
711 runtime::reference::lstm_cell(args[0]->get_data_ptr<T>(),
712 args[0]->get_shape(),
713 args[1]->get_data_ptr<T>(),
714 args[1]->get_shape(),
715 args[2]->get_data_ptr<T>(),
716 args[2]->get_shape(),
717 args[3]->get_data_ptr<T>(),
718 args[3]->get_shape(),
719 args[4]->get_data_ptr<T>(),
720 args[4]->get_shape(),
721 args[5]->get_data_ptr<T>(),
722 args[5]->get_shape(),
723 out[0]->get_data_ptr<T>(),
724 out[1]->get_data_ptr<T>(),
725 lstm_cell->get_activations()[0],
726 lstm_cell->get_activations()[1],
727 lstm_cell->get_activations()[2],
728 lstm_cell->get_clip());
731 case OP_TYPEID::RNNCell_v0:
733 const op::v0::RNNCell* rnn_cell = static_cast<const op::v0::RNNCell*>(&node);
734 runtime::reference::rnn_cell(args[0]->get_data_ptr<T>(),
735 args[0]->get_shape(),
736 args[1]->get_data_ptr<T>(),
737 args[1]->get_shape(),
738 args[2]->get_data_ptr<T>(),
739 args[2]->get_shape(),
740 args[3]->get_data_ptr<T>(),
741 args[3]->get_shape(),
742 args[4]->get_data_ptr<T>(),
743 args[4]->get_shape(),
744 out[0]->get_data_ptr<T>(),
745 rnn_cell->get_activations()[0],
746 rnn_cell->get_clip());
749 case OP_TYPEID::LSTMSequence:
750 case OP_TYPEID::LSTMSequence_v5:
752 auto lstm_seq = static_cast<const op::v5::LSTMSequence*>(&node);
753 runtime::reference::lstm_sequence<T>(args[0]->get_data_ptr<char>(),
754 args[0]->get_shape(),
755 args[1]->get_data_ptr<char>(),
756 args[1]->get_shape(),
757 args[2]->get_data_ptr<char>(),
758 args[2]->get_shape(),
759 args[3]->get_data_ptr<char>(),
760 args[3]->get_shape(),
761 args[4]->get_data_ptr<char>(),
762 args[4]->get_shape(),
763 args[5]->get_data_ptr<char>(),
764 args[5]->get_shape(),
765 args[6]->get_data_ptr<char>(),
766 args[6]->get_shape(),
767 out[0]->get_data_ptr<char>(),
768 out[1]->get_data_ptr<char>(),
769 out[2]->get_data_ptr<char>(),
770 lstm_seq->get_activations()[0],
771 lstm_seq->get_activations()[1],
772 lstm_seq->get_activations()[2],
773 lstm_seq->get_clip(),
774 lstm_seq->get_direction());
777 case OP_TYPEID::GRUSequence_v5:
779 auto gru_seq = static_cast<const op::v5::GRUSequence*>(&node);
780 runtime::reference::gru_sequence<T>(args[0]->get_data_ptr<char>(),
781 args[0]->get_shape(),
782 args[1]->get_data_ptr<char>(),
783 args[1]->get_shape(),
784 args[2]->get_data_ptr<char>(),
785 args[2]->get_shape(),
786 args[3]->get_data_ptr<char>(),
787 args[3]->get_shape(),
788 args[4]->get_data_ptr<char>(),
789 args[4]->get_shape(),
790 args[5]->get_data_ptr<char>(),
791 args[5]->get_shape(),
792 out[0]->get_data_ptr<char>(),
793 out[1]->get_data_ptr<char>(),
794 gru_seq->get_activations()[0],
795 gru_seq->get_activations()[1],
797 gru_seq->get_direction(),
798 gru_seq->get_linear_before_reset());
801 case OP_TYPEID::RNNSequence_v5:
803 auto rnn_seq = static_cast<const op::v5::RNNSequence*>(&node);
804 runtime::reference::rnn_sequence<T>(args[0]->get_data_ptr<char>(),
805 args[0]->get_shape(),
806 args[1]->get_data_ptr<char>(),
807 args[1]->get_shape(),
808 args[2]->get_data_ptr<char>(),
809 args[2]->get_shape(),
810 args[3]->get_data_ptr<char>(),
811 args[3]->get_shape(),
812 args[4]->get_data_ptr<char>(),
813 args[4]->get_shape(),
814 args[5]->get_data_ptr<char>(),
815 args[5]->get_shape(),
816 out[0]->get_data_ptr<char>(),
817 out[1]->get_data_ptr<char>(),
818 rnn_seq->get_activations()[0],
820 rnn_seq->get_direction());
825 size_t element_count = shape_size(node.get_output_shape(0));
827 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
830 case OP_TYPEID::LogSoftmax_v5:
832 const op::v5::LogSoftmax* log_softmax = static_cast<const op::v5::LogSoftmax*>(&node);
833 int64_t i_axis = log_softmax->get_axis();
836 i_axis += args[0]->get_partial_shape().rank().get_length();
838 reference::log_softmax<T>(args[0]->get_data_ptr<const T>(),
839 out[0]->get_data_ptr<T>(),
840 node.get_output_shape(0),
841 AxisSet{(size_t)i_axis});
846 const op::LRN* lrn = static_cast<const op::LRN*>(&node);
847 reference::lrn<T>(args[0]->get_data_ptr<const T>(),
848 lrn->get_reduction_axes(),
849 out[0]->get_data_ptr<T>(),
850 node.get_input_shape(0),
857 case OP_TYPEID::Negative:
859 size_t element_count = shape_size(node.get_output_shape(0));
860 reference::negate<T>(
861 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
864 case OP_TYPEID::LogicalNot_v1:
867 size_t element_count = shape_size(node.get_output_shape(0));
868 reference::logical_not(
869 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
872 case OP_TYPEID::OneHot_v1:
874 const op::v1::OneHot* oh = static_cast<const op::v1::OneHot*>(&node);
875 T on_value = args[2]->get_data_ptr<T>()[0];
876 T off_value = args[3]->get_data_ptr<T>()[0];
878 switch (args[0]->get_element_type())
880 case element::Type_t::i8:
881 reference::one_hot(args[0]->get_data_ptr<const int8_t>(),
882 out[0]->get_data_ptr<T>(),
883 node.get_input_shape(0),
884 node.get_output_shape(0),
889 case element::Type_t::i16:
890 reference::one_hot(args[0]->get_data_ptr<const int16_t>(),
891 out[0]->get_data_ptr<T>(),
892 node.get_input_shape(0),
893 node.get_output_shape(0),
898 case element::Type_t::i32:
899 reference::one_hot(args[0]->get_data_ptr<const int32_t>(),
900 out[0]->get_data_ptr<T>(),
901 node.get_input_shape(0),
902 node.get_output_shape(0),
907 case element::Type_t::i64:
908 reference::one_hot(args[0]->get_data_ptr<const int64_t>(),
909 out[0]->get_data_ptr<T>(),
910 node.get_input_shape(0),
911 node.get_output_shape(0),
916 case element::Type_t::u8:
917 reference::one_hot(args[0]->get_data_ptr<const uint8_t>(),
918 out[0]->get_data_ptr<T>(),
919 node.get_input_shape(0),
920 node.get_output_shape(0),
925 case element::Type_t::u16:
926 reference::one_hot(args[0]->get_data_ptr<const uint16_t>(),
927 out[0]->get_data_ptr<T>(),
928 node.get_input_shape(0),
929 node.get_output_shape(0),
934 case element::Type_t::u32:
935 reference::one_hot(args[0]->get_data_ptr<const uint32_t>(),
936 out[0]->get_data_ptr<T>(),
937 node.get_input_shape(0),
938 node.get_output_shape(0),
943 case element::Type_t::u64:
944 reference::one_hot(args[0]->get_data_ptr<const uint64_t>(),
945 out[0]->get_data_ptr<T>(),
946 node.get_input_shape(0),
947 node.get_output_shape(0),
952 case element::Type_t::undefined:
953 case element::Type_t::dynamic:
954 case element::Type_t::u1:
955 case element::Type_t::boolean:
956 case element::Type_t::bf16:
957 case element::Type_t::f16:
958 case element::Type_t::f32:
959 case element::Type_t::f64:
960 default: NGRAPH_CHECK(false, "Indices input element type must be integer");
965 case OP_TYPEID::Parameter: break;
966 case OP_TYPEID::PriorBox:
968 const op::PriorBox* pbox = static_cast<const op::PriorBox*>(&node);
969 runtime::reference::prior_box<T>(args[0]->get_data_ptr<T>(),
970 args[1]->get_data_ptr<T>(),
971 out[0]->get_data_ptr<float>(),
976 case OP_TYPEID::ReorgYolo_v0:
978 const op::v0::ReorgYolo* reorg_yolo = static_cast<const op::v0::ReorgYolo*>(&node);
979 runtime::reference::reorg_yolo(args[0]->get_data_ptr<char>(),
980 out[0]->get_data_ptr<char>(),
981 args[0]->get_shape(),
982 reorg_yolo->get_strides().at(0),
983 args[0]->get_element_type().size());
986 case OP_TYPEID::Quantize:
988 const op::Quantize* quantize = static_cast<const op::Quantize*>(&node);
989 auto type = quantize->get_element_type();
991 if (type == element::u8)
993 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
994 args[1]->get_data_ptr<const T>(),
995 args[2]->get_data_ptr<const uint8_t>(),
996 out[0]->get_data_ptr<uint8_t>(),
997 node.get_input_shape(0),
998 node.get_input_shape(1),
999 quantize->get_axes(),
1000 quantize->get_round_mode());
1002 else if (type == element::i8)
1004 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
1005 args[1]->get_data_ptr<const T>(),
1006 args[2]->get_data_ptr<const int8_t>(),
1007 out[0]->get_data_ptr<int8_t>(),
1008 node.get_input_shape(0),
1009 node.get_input_shape(1),
1010 quantize->get_axes(),
1011 quantize->get_round_mode());
1013 else if (type == element::i32)
1015 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
1016 args[1]->get_data_ptr<const T>(),
1017 args[2]->get_data_ptr<const int32_t>(),
1018 out[0]->get_data_ptr<int32_t>(),
1019 node.get_input_shape(0),
1020 node.get_input_shape(1),
1021 quantize->get_axes(),
1022 quantize->get_round_mode());
1026 std::stringstream ss;
1027 ss << "unsupported element type " << type << " op Quantize";
1028 throw std::runtime_error(ss.str());
1034 case OP_TYPEID::QuantizedConvolution:
1036 const op::QuantizedConvolution* qc =
1037 static_cast<const op::QuantizedConvolution*>(&node);
1039 auto input_element_type = qc->get_input_element_type(0);
1040 auto filter_element_type = qc->get_input_element_type(1);
1041 auto output_element_type = qc->get_output_element_type(0);
1043 if (input_element_type == element::u8 && filter_element_type == element::i8 &&
1044 output_element_type == element::i8)
1046 reference::convolution<uint8_t, int8_t, int8_t, int32_t>(
1047 args[0]->get_data_ptr<const uint8_t>(),
1048 args[1]->get_data_ptr<const int8_t>(),
1049 out[0]->get_data_ptr<int8_t>(),
1050 node.get_input_shape(0),
1051 node.get_input_shape(1),
1052 node.get_output_shape(0),
1053 qc->get_window_movement_strides(),
1054 qc->get_window_dilation_strides(),
1055 qc->get_padding_below(),
1056 qc->get_padding_above(),
1057 qc->get_data_dilation_strides(),
1058 args[2]->get_data_ptr<const float>(),
1059 args[3]->get_data_ptr<const uint8_t>(),
1060 args[4]->get_data_ptr<const float>(),
1061 args[5]->get_data_ptr<const int8_t>(),
1062 args[6]->get_data_ptr<const float>(),
1063 args[7]->get_data_ptr<const int8_t>());
1065 else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
1066 output_element_type == element::u8)
1068 reference::convolution<uint8_t, uint8_t, uint8_t, int32_t>(
1069 args[0]->get_data_ptr<const uint8_t>(),
1070 args[1]->get_data_ptr<const uint8_t>(),
1071 out[0]->get_data_ptr<uint8_t>(),
1072 node.get_input_shape(0),
1073 node.get_input_shape(1),
1074 node.get_output_shape(0),
1075 qc->get_window_movement_strides(),
1076 qc->get_window_dilation_strides(),
1077 qc->get_padding_below(),
1078 qc->get_padding_above(),
1079 qc->get_data_dilation_strides(),
1080 args[2]->get_data_ptr<const float>(),
1081 args[3]->get_data_ptr<const uint8_t>(),
1082 args[4]->get_data_ptr<const float>(),
1083 args[5]->get_data_ptr<const uint8_t>(),
1084 args[6]->get_data_ptr<const float>(),
1085 args[7]->get_data_ptr<const uint8_t>());
1087 else if (input_element_type == element::u8 && filter_element_type == element::i8 &&
1088 output_element_type == element::i32)
1090 reference::convolution<uint8_t, int8_t, int32_t, int32_t>(
1091 args[0]->get_data_ptr<const uint8_t>(),
1092 args[1]->get_data_ptr<const int8_t>(),
1093 out[0]->get_data_ptr<int32_t>(),
1094 node.get_input_shape(0),
1095 node.get_input_shape(1),
1096 node.get_output_shape(0),
1097 qc->get_window_movement_strides(),
1098 qc->get_window_dilation_strides(),
1099 qc->get_padding_below(),
1100 qc->get_padding_above(),
1101 qc->get_data_dilation_strides(),
1102 args[2]->get_data_ptr<const float>(),
1103 args[3]->get_data_ptr<const uint8_t>(),
1104 args[4]->get_data_ptr<const float>(),
1105 args[5]->get_data_ptr<const int8_t>(),
1106 args[6]->get_data_ptr<const float>(),
1107 args[7]->get_data_ptr<const int32_t>());
1109 else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
1110 output_element_type == element::i32)
1112 reference::convolution<uint8_t, uint8_t, int32_t, int32_t>(
1113 args[0]->get_data_ptr<const uint8_t>(),
1114 args[1]->get_data_ptr<const uint8_t>(),
1115 out[0]->get_data_ptr<int32_t>(),
1116 node.get_input_shape(0),
1117 node.get_input_shape(1),
1118 node.get_output_shape(0),
1119 qc->get_window_movement_strides(),
1120 qc->get_window_dilation_strides(),
1121 qc->get_padding_below(),
1122 qc->get_padding_above(),
1123 qc->get_data_dilation_strides(),
1124 args[2]->get_data_ptr<const float>(),
1125 args[3]->get_data_ptr<const uint8_t>(),
1126 args[4]->get_data_ptr<const float>(),
1127 args[5]->get_data_ptr<const uint8_t>(),
1128 args[6]->get_data_ptr<const float>(),
1129 args[7]->get_data_ptr<const int32_t>());
1133 std::stringstream ss;
1134 ss << "unsupported element type";
1135 throw std::runtime_error(ss.str());
1141 case OP_TYPEID::QuantizedDot:
1143 const op::QuantizedDot* qd = static_cast<const op::QuantizedDot*>(&node);
1145 auto input0_element_type = qd->get_input_element_type(0);
1146 auto input1_element_type = qd->get_input_element_type(1);
1147 auto output_element_type = qd->get_output_element_type(0);
1149 if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
1150 output_element_type == element::i8)
1152 reference::dot<uint8_t, int8_t, int8_t, int32_t>(
1153 args[0]->get_data_ptr<const uint8_t>(),
1154 args[1]->get_data_ptr<const int8_t>(),
1155 out[0]->get_data_ptr<int8_t>(),
1156 node.get_input_shape(0),
1157 node.get_input_shape(1),
1158 node.get_output_shape(0),
1160 args[2]->get_data_ptr<const float>(),
1161 args[3]->get_data_ptr<const uint8_t>(),
1162 args[4]->get_data_ptr<const float>(),
1163 args[5]->get_data_ptr<const int8_t>(),
1164 args[6]->get_data_ptr<const float>(),
1165 args[7]->get_data_ptr<const int8_t>());
1167 else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
1168 output_element_type == element::u8)
1170 reference::dot<uint8_t, uint8_t, uint8_t, int32_t>(
1171 args[0]->get_data_ptr<const uint8_t>(),
1172 args[1]->get_data_ptr<const uint8_t>(),
1173 out[0]->get_data_ptr<uint8_t>(),
1174 node.get_input_shape(0),
1175 node.get_input_shape(1),
1176 node.get_output_shape(0),
1178 args[2]->get_data_ptr<const float>(),
1179 args[3]->get_data_ptr<const uint8_t>(),
1180 args[4]->get_data_ptr<const float>(),
1181 args[5]->get_data_ptr<const uint8_t>(),
1182 args[6]->get_data_ptr<const float>(),
1183 args[7]->get_data_ptr<const uint8_t>());
1185 else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
1186 output_element_type == element::i32)
1188 reference::dot<uint8_t, uint8_t, int32_t, int32_t>(
1189 args[0]->get_data_ptr<const uint8_t>(),
1190 args[1]->get_data_ptr<const uint8_t>(),
1191 out[0]->get_data_ptr<int32_t>(),
1192 node.get_input_shape(0),
1193 node.get_input_shape(1),
1194 node.get_output_shape(0),
1196 args[2]->get_data_ptr<const float>(),
1197 args[3]->get_data_ptr<const uint8_t>(),
1198 args[4]->get_data_ptr<const float>(),
1199 args[5]->get_data_ptr<const uint8_t>(),
1200 args[6]->get_data_ptr<const float>(),
1201 args[7]->get_data_ptr<const int32_t>());
1203 else if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
1204 output_element_type == element::i32)
1206 reference::dot<uint8_t, int8_t, int32_t, int32_t>(
1207 args[0]->get_data_ptr<const uint8_t>(),
1208 args[1]->get_data_ptr<const int8_t>(),
1209 out[0]->get_data_ptr<int32_t>(),
1210 node.get_input_shape(0),
1211 node.get_input_shape(1),
1212 node.get_output_shape(0),
1214 args[2]->get_data_ptr<const float>(),
1215 args[3]->get_data_ptr<const uint8_t>(),
1216 args[4]->get_data_ptr<const float>(),
1217 args[5]->get_data_ptr<const int8_t>(),
1218 args[6]->get_data_ptr<const float>(),
1219 args[7]->get_data_ptr<const int32_t>());
1223 std::stringstream ss;
1224 ss << "unsupported element type";
1225 throw std::runtime_error(ss.str());
1230 case OP_TYPEID::RegionYolo_v0:
1232 const op::RegionYolo* region_yolo = static_cast<const op::RegionYolo*>(&node);
1233 reference::region_yolo<T>(args[0]->get_data_ptr<const T>(),
1234 out[0]->get_data_ptr<T>(),
1235 args[0]->get_shape(),
1236 region_yolo->get_num_coords(),
1237 region_yolo->get_num_classes(),
1238 region_yolo->get_num_regions(),
1239 region_yolo->get_do_softmax(),
1240 region_yolo->get_mask());
1243 case OP_TYPEID::Relu:
1245 size_t element_count = shape_size(node.get_output_shape(0));
1247 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1250 case OP_TYPEID::ReplaceSlice:
1252 const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
1253 reference::replace_slice<T>(args[0]->get_data_ptr<const T>(),
1254 args[1]->get_data_ptr<const T>(),
1255 out[0]->get_data_ptr<T>(),
1256 node.get_input_shape(1),
1257 slice->get_lower_bounds(),
1258 slice->get_upper_bounds(),
1259 slice->get_strides(),
1260 node.get_output_shape(0));
1263 case OP_TYPEID::Reverse:
1265 const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
1266 reference::reverse(args[0]->get_data_ptr<const char>(),
1267 out[0]->get_data_ptr<char>(),
1268 node.get_input_shape(0),
1269 node.get_output_shape(0),
1270 reverse->get_reversed_axes(),
1271 args[0]->get_element_type().size());
1274 case OP_TYPEID::ReverseSequence:
1276 const op::ReverseSequence* reverse = static_cast<const op::ReverseSequence*>(&node);
1278 if (node.get_input_element_type(1) == element::i32)
1280 reference::reverse_sequence<T, int32_t>(args[0]->get_data_ptr<const T>(),
1281 out[0]->get_data_ptr<T>(),
1282 node.get_input_shape(0),
1283 reverse->get_batch_axis(),
1284 reverse->get_sequence_axis(),
1285 args[1]->get_data_ptr<const int32_t>());
1289 throw ngraph_error("only int32 indices are supported");
1293 case OP_TYPEID::Round:
1295 size_t element_count = shape_size(node.get_output_shape(0));
1296 reference::round<T>(args[0]->get_data_ptr<const T>(),
1297 out[0]->get_data_ptr<T>(),
1299 op::v5::Round::RoundMode::HALF_TO_EVEN);
1302 case OP_TYPEID::Select:
1304 size_t element_count = shape_size(node.get_output_shape(0));
1305 reference::select<T>(args[0]->get_data_ptr<const char>(),
1306 args[1]->get_data_ptr<const T>(),
1307 args[2]->get_data_ptr<const T>(),
1308 out[0]->get_data_ptr<T>(),
1312 case OP_TYPEID::Sigmoid:
1314 size_t element_count = shape_size(node.get_output_shape(0));
1315 reference::sigmoid<T>(
1316 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1319 case OP_TYPEID::Sign:
1321 size_t element_count = shape_size(node.get_output_shape(0));
1323 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1326 case OP_TYPEID::Sin:
1328 size_t element_count = shape_size(node.get_output_shape(0));
1330 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1333 case OP_TYPEID::Sinh:
1335 size_t element_count = shape_size(node.get_output_shape(0));
1337 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1340 case OP_TYPEID::Sqrt:
1342 size_t element_count = shape_size(node.get_output_shape(0));
1344 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1347 case OP_TYPEID::Tan:
1349 size_t element_count = shape_size(node.get_output_shape(0));
1351 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1354 case OP_TYPEID::Tanh:
1356 size_t element_count = shape_size(node.get_output_shape(0));
1358 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1361 case OP_TYPEID::TopK:
1363 const op::TopK* topk = static_cast<const op::TopK*>(&node);
1364 if (node.get_output_element_type(0) == element::i64)
1366 reference::topk<T, int64_t>(args[0]->get_data_ptr<const T>(),
1367 out[0]->get_data_ptr<int64_t>(),
1368 out[1]->get_data_ptr<T>(),
1369 node.get_input_shape(0),
1370 node.get_output_shape(0),
1371 topk->get_top_k_axis(),
1373 topk->get_compute_max(),
1376 else if (node.get_output_element_type(0) == element::i32)
1378 reference::topk<T, int32_t>(args[0]->get_data_ptr<const T>(),
1379 out[0]->get_data_ptr<int32_t>(),
1380 out[1]->get_data_ptr<T>(),
1381 node.get_input_shape(0),
1382 node.get_output_shape(0),
1383 topk->get_top_k_axis(),
1385 topk->get_compute_max(),
1390 throw ngraph_error("Unexpected type");
1394 case OP_TYPEID::DetectionOutput_v0:
1396 const op::DetectionOutput* detOut = static_cast<const op::DetectionOutput*>(&node);
1397 reference::referenceDetectionOutput<T> refDetOut(
1398 detOut->get_attrs(), node.get_input_shape(0), node.get_input_shape(2));
1399 if (node.get_input_size() == 3)
1401 refDetOut.run(args[0]->get_data_ptr<const T>(),
1402 args[1]->get_data_ptr<const T>(),
1403 args[2]->get_data_ptr<const T>(),
1406 out[0]->get_data_ptr<T>());
1408 else if (node.get_input_size() == 5)
1410 refDetOut.run(args[0]->get_data_ptr<const T>(),
1411 args[1]->get_data_ptr<const T>(),
1412 args[2]->get_data_ptr<const T>(),
1413 args[3]->get_data_ptr<const T>(),
1414 args[4]->get_data_ptr<const T>(),
1415 out[0]->get_data_ptr<T>());
1419 throw ngraph_error("DetectionOutput layer supports only 3 or 5 inputs");
1424 case OP_TYPEID::ScatterNDUpdate_v3:
1426 const op::ScatterNDUpdate* scatterNDUpd =
1427 static_cast<const op::v3::ScatterNDUpdate*>(&node);
1428 auto idxType = scatterNDUpd->get_input_element_type(1);
1429 if (idxType == element::i32)
1431 reference::scatterNdUpdate<T, int32_t>(args[0]->get_data_ptr<const T>(),
1432 args[1]->get_data_ptr<const int32_t>(),
1433 args[2]->get_data_ptr<const T>(),
1434 out[0]->get_data_ptr<T>(),
1435 node.get_input_shape(0),
1436 node.get_input_shape(1),
1437 node.get_input_shape(2));
1439 else if (idxType == element::i64)
1441 reference::scatterNdUpdate<T, int64_t>(args[0]->get_data_ptr<const T>(),
1442 args[1]->get_data_ptr<const int64_t>(),
1443 args[2]->get_data_ptr<const T>(),
1444 out[0]->get_data_ptr<T>(),
1445 node.get_input_shape(0),
1446 node.get_input_shape(1),
1447 node.get_input_shape(2));
1452 "ScatterNDUpdate layer support only i32 and i64 'indices' input precision!");
1457 case OP_TYPEID::GatherTree_v1:
1459 reference::gather_tree(args[0]->get_data_ptr<const char>(),
1460 args[1]->get_data_ptr<const char>(),
1461 args[2]->get_data_ptr<const char>(),
1462 args[3]->get_data_ptr<const char>(),
1463 out[0]->get_data_ptr<char>(),
1464 node.get_input_shape(0),
1465 node.get_input_shape(1),
1466 node.get_input_shape(2),
1467 node.get_input_shape(3),
1468 args[1]->get_element_type());
1471 case OP_TYPEID::NormalizeL2:
1473 const op::NormalizeL2* norm = static_cast<const op::NormalizeL2*>(&node);
1474 reference::normalize_l2<T>(args[0]->get_data_ptr<const T>(),
1475 out[0]->get_data_ptr<T>(),
1476 node.get_input_shape(0),
1477 norm->get_reduction_axes(),
1479 norm->get_eps_mode());
1483 // Fused ops are not supported in the interpreter; they need to be decomposed before execution.
1484 case OP_TYPEID::DepthToSpace:
1485 case OP_TYPEID::FakeQuantize:
1486 case OP_TYPEID::Gather:
1487 case OP_TYPEID::Gelu:
1488 case OP_TYPEID::GRN:
1489 case OP_TYPEID::GroupConvolution:
1490 case OP_TYPEID::GroupConvolutionBackpropData:
1491 case OP_TYPEID::HardSigmoid:
1492 case OP_TYPEID::Interpolate:
1493 case OP_TYPEID::MVN:
1494 case OP_TYPEID::PRelu:
1495 case OP_TYPEID::ScatterUpdate_v3:
1496 case OP_TYPEID::Selu:
1497 case OP_TYPEID::ShuffleChannels:
1498 case OP_TYPEID::SpaceToDepth:
1499 case OP_TYPEID::Split:
1500 case OP_TYPEID::SquaredDifference:
1501 case OP_TYPEID::StopGradient:
1502 case OP_TYPEID::TensorIterator:
1503 case OP_TYPEID::Tile:
1504 case OP_TYPEID::UnknownOp:
1505 throw unsupported_op("Unsupported op '" + node.description() + "'");
1506 case OP_TYPEID::Add:
1507 case OP_TYPEID::Broadcast:
1508 case OP_TYPEID::Clamp:
1509 case OP_TYPEID::Concat:
1510 case OP_TYPEID::Constant:
1511 case OP_TYPEID::Divide:
1512 case OP_TYPEID::Equal:
1513 case OP_TYPEID::Greater:
1514 case OP_TYPEID::GreaterEq:
1515 case OP_TYPEID::Less:
1516 case OP_TYPEID::LessEq:
1517 case OP_TYPEID::LessEqual_v1:
1518 case OP_TYPEID::LogicalAnd_v1:
1519 case OP_TYPEID::LogicalOr_v1:
1520 case OP_TYPEID::LogicalXor_v1:
1521 case OP_TYPEID::MatMul:
1522 case OP_TYPEID::Max:
1523 case OP_TYPEID::Maximum:
1524 case OP_TYPEID::Min:
1525 case OP_TYPEID::Minimum:
1526 case OP_TYPEID::Multiply:
1527 case OP_TYPEID::NonZero_v3:
1528 case OP_TYPEID::NotEqual:
1530 case OP_TYPEID::Power:
1531 case OP_TYPEID::Range:
1532 case OP_TYPEID::Reshape:
1533 case OP_TYPEID::Result:
1534 case OP_TYPEID::Round_v5:
1535 case OP_TYPEID::ShapeOf_v3:
1536 case OP_TYPEID::ShapeOf:
1537 case OP_TYPEID::Softmax:
1538 case OP_TYPEID::Squeeze:
1539 case OP_TYPEID::Sum:
1540 case OP_TYPEID::Subtract:
1541 case OP_TYPEID::Unsqueeze:
1542 case OP_TYPEID::Xor:
1543 case OP_TYPEID::Slice:
1544 // These ops are handled by their op evaluators, so there is nothing to do here.
1546 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
1547 #pragma GCC diagnostic pop
1553 NGRAPH_SUPPRESS_DEPRECATED_END