Remove obsoleted v0::Broadcast and BroadcastLike operators (#2779)
[platform/upstream/dldt.git] / ngraph / test / runtime / interpreter / int_executable.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include <initializer_list>
20 #include <iostream>
21 #include <memory>
22 #include <sstream>
23 #include <string>
24 #include <vector>
25
26 #include <ngraph/runtime/host_tensor.hpp>
27 #include "backend.hpp"
28 #include "int_backend_visibility.hpp"
29 #include "ngraph/ops.hpp"
30 #include "ngraph/runtime/aligned_buffer.hpp"
31 #include "ngraph/runtime/reference/abs.hpp"
32 #include "ngraph/runtime/reference/acos.hpp"
33 #include "ngraph/runtime/reference/asin.hpp"
34 #include "ngraph/runtime/reference/atan.hpp"
35 #include "ngraph/runtime/reference/atan2.hpp"
36 #include "ngraph/runtime/reference/avg_pool.hpp"
37 #include "ngraph/runtime/reference/batch_norm.hpp"
38 #include "ngraph/runtime/reference/broadcast.hpp"
39 #include "ngraph/runtime/reference/ceiling.hpp"
40 #include "ngraph/runtime/reference/concat.hpp"
41 #include "ngraph/runtime/reference/constant.hpp"
42 #include "ngraph/runtime/reference/convert.hpp"
43 #include "ngraph/runtime/reference/convolution.hpp"
44 #include "ngraph/runtime/reference/cos.hpp"
45 #include "ngraph/runtime/reference/cosh.hpp"
46 #include "ngraph/runtime/reference/ctc_greedy_decoder.hpp"
47 #include "ngraph/runtime/reference/ctc_loss.hpp"
48 #include "ngraph/runtime/reference/cum_sum.hpp"
49 #include "ngraph/runtime/reference/detection_output.hpp"
50 #include "ngraph/runtime/reference/dot.hpp"
51 #include "ngraph/runtime/reference/elu.hpp"
52 #include "ngraph/runtime/reference/embedding_bag_offsets_sum.hpp"
53 #include "ngraph/runtime/reference/embedding_bag_packed_sum.hpp"
54 #include "ngraph/runtime/reference/embedding_segments_sum.hpp"
55 #include "ngraph/runtime/reference/erf.hpp"
56 #include "ngraph/runtime/reference/exp.hpp"
57 #include "ngraph/runtime/reference/extract_image_patches.hpp"
58 #include "ngraph/runtime/reference/floor.hpp"
59 #include "ngraph/runtime/reference/gather.hpp"
60 #include "ngraph/runtime/reference/gather_nd.hpp"
61 #include "ngraph/runtime/reference/gather_tree.hpp"
62 #include "ngraph/runtime/reference/gru_cell.hpp"
63 #include "ngraph/runtime/reference/log.hpp"
64 #include "ngraph/runtime/reference/log_softmax.hpp"
65 #include "ngraph/runtime/reference/lrn.hpp"
66 #include "ngraph/runtime/reference/lstm_cell.hpp"
67 #include "ngraph/runtime/reference/matmul.hpp"
68 #include "ngraph/runtime/reference/max.hpp"
69 #include "ngraph/runtime/reference/max_pool.hpp"
70 #include "ngraph/runtime/reference/min.hpp"
71 #include "ngraph/runtime/reference/negate.hpp"
72 #include "ngraph/runtime/reference/normalize_l2.hpp"
73 #include "ngraph/runtime/reference/not.hpp"
74 #include "ngraph/runtime/reference/one_hot.hpp"
75 #include "ngraph/runtime/reference/pad.hpp"
76 #include "ngraph/runtime/reference/prior_box.hpp"
77 #include "ngraph/runtime/reference/product.hpp"
78 #include "ngraph/runtime/reference/quantize.hpp"
79 #include "ngraph/runtime/reference/region_yolo.hpp"
80 #include "ngraph/runtime/reference/relu.hpp"
81 #include "ngraph/runtime/reference/reorg_yolo.hpp"
82 #include "ngraph/runtime/reference/replace_slice.hpp"
83 #include "ngraph/runtime/reference/reshape.hpp"
84 #include "ngraph/runtime/reference/result.hpp"
85 #include "ngraph/runtime/reference/reverse.hpp"
86 #include "ngraph/runtime/reference/reverse_sequence.hpp"
87 #include "ngraph/runtime/reference/rnn_cell.hpp"
88 #include "ngraph/runtime/reference/round.hpp"
89 #include "ngraph/runtime/reference/scatter_nd_update.hpp"
90 #include "ngraph/runtime/reference/select.hpp"
91 #include "ngraph/runtime/reference/sequences.hpp"
92 #include "ngraph/runtime/reference/sigmoid.hpp"
93 #include "ngraph/runtime/reference/sign.hpp"
94 #include "ngraph/runtime/reference/sin.hpp"
95 #include "ngraph/runtime/reference/sinh.hpp"
96 #include "ngraph/runtime/reference/softmax.hpp"
97 #include "ngraph/runtime/reference/sqrt.hpp"
98 #include "ngraph/runtime/reference/sum.hpp"
99 #include "ngraph/runtime/reference/tan.hpp"
100 #include "ngraph/runtime/reference/tanh.hpp"
101 #include "ngraph/runtime/reference/topk.hpp"
102 #include "ngraph/runtime/tensor.hpp"
103 #include "op/avg_pool.hpp"
104 #include "op/convolution.hpp"
105 #include "op/group_conv.hpp"
106
107 NGRAPH_SUPPRESS_DEPRECATED_START
108
109 namespace ngraph
110 {
111     namespace runtime
112     {
113         namespace interpreter
114         {
115             class INTBackend;
116             class INTExecutable;
117
118             // Op identifiers that INTExecutable::op_engine switches on. The enumerators are made by
119             // expanding NGRAPH_OP over the op list in opset_int_tbl.hpp, giving entries like this:
120             // Abs,
121             // Acos,
122             // ...
123             enum class OP_TYPEID
124             { // ID_SUFFIX is not defined in the visible part of this header (presumably the table supplies it; confirm)
125 #define NGRAPH_OP(NAME, NAMESPACE) ID_SUFFIX(NAME),
126 #include "opset_int_tbl.hpp"
127 #undef NGRAPH_OP
128                 UnknownOp // always the last enumerator; presumably the fallback for unrecognized nodes (confirm)
129             };
130         } // namespace interpreter
131     }     // namespace runtime
132 } // namespace ngraph
133
134 class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : public Executable
135 {
136     friend class INTBackend;
137
138 public:
139     INTExecutable(const std::shared_ptr<Function>& function,
140                   bool enable_performance_collection = false); // performance collection is opt-in
141
142     bool call(const std::vector<std::shared_ptr<Tensor>>& outputs, // note: outputs precede inputs
143               const std::vector<std::shared_ptr<Tensor>>& inputs) override; // overrides the base-class call()
144
145     void set_nan_check(bool enable); // presumably toggles m_nan_check_enabled (confirm in the .cpp)
146
147     std::vector<PerformanceCounter> get_performance_data() const override; // base-class override; returns by value
148
149     std::shared_ptr<runtime::Tensor> create_input_tensor(size_t input_index) override; // single-tensor overload
150
151     std::shared_ptr<runtime::Tensor> create_output_tensor(size_t output_index) override; // single-tensor overload
152
153     std::vector<std::shared_ptr<runtime::Tensor>>
154         create_input_tensor(size_t input_index, size_t pipeline_depth) override; // vector overload; presumably pipeline_depth tensors (confirm)
155
156     std::vector<std::shared_ptr<runtime::Tensor>>
157         create_output_tensor(size_t output_index, size_t pipeline_depth) override; // vector overload; presumably pipeline_depth tensors (confirm)
158
159 protected:
160     std::shared_ptr<ngraph::op::Parameter> get_parameter(size_t index) const; // presumably the index-th Parameter of m_function (confirm)
161     std::shared_ptr<ngraph::op::Result> get_result(size_t index) const; // presumably the index-th Result of m_function (confirm)
162     int get_alignment() const { return 64; } // fixed value of 64; the unit (presumably bytes) is not stated here
163     bool m_is_compiled = false;
164     bool m_nan_check_enabled = false;
165     bool m_performance_counters_enabled = false;
166     std::shared_ptr<Function> m_function;
167     std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map; // NOTE(review): <unordered_map> is not included directly by this header
168     std::vector<std::shared_ptr<Node>> m_nodes;
169     std::set<std::string> m_unsupported_op_name_list; // NOTE(review): <set> is not included directly by this header
170
171     static OP_TYPEID get_typeid(const Node& node); // op_engine switches on the value returned here
172
173     static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
174                                   const Node* op = nullptr); // op is optional; presumably used only for diagnostics (confirm)
175
176     virtual void generate_calls(const element::Type& type, // virtual, so subclasses may override
177                                 const Node& op,
178                                 const std::vector<std::shared_ptr<HostTensor>>& outputs,
179                                 const std::vector<std::shared_ptr<HostTensor>>& inputs); // presumably picks op_engine<T> from `type` (confirm)
180
181     template <typename T>
182     void op_engine(const Node& node,
183                    const std::vector<std::shared_ptr<HostTensor>>& out,
184                    const std::vector<std::shared_ptr<HostTensor>>& args)
185     {
186 // We want to check that every OP_TYPEID enumeration is included in the list.
187 // These GCC flags enable compile-time checking so that if an enumeration
188 // is not in the list an error is generated.
189 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
190 #pragma GCC diagnostic push
191 #pragma GCC diagnostic error "-Wswitch"
192 #pragma GCC diagnostic error "-Wswitch-enum"
193 #endif
194         switch (get_typeid(node))
195         {
196         case OP_TYPEID::Abs:
197         {
198             size_t element_count = shape_size(node.get_output_shape(0));
199             reference::abs<T>(
200                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
201             break;
202         }
203         case OP_TYPEID::Acos:
204         {
205             size_t element_count = shape_size(node.get_output_shape(0));
206             reference::acos<T>(
207                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
208             break;
209         }
210         case OP_TYPEID::Asin:
211         {
212             size_t element_count = shape_size(node.get_output_shape(0));
213             reference::asin<T>(
214                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
215             break;
216         }
217         case OP_TYPEID::Atan:
218         {
219             size_t element_count = shape_size(node.get_output_shape(0));
220             reference::atan<T>(
221                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
222             break;
223         }
224         case OP_TYPEID::Elu:
225         {
226             const op::Elu* elu_node = static_cast<const op::Elu*>(&node);
227
228             size_t element_count = shape_size(node.get_output_shape(0));
229             reference::elu<T>(args[0]->get_data_ptr<const T>(),
230                               out[0]->get_data_ptr<T>(),
231                               element_count,
232                               elu_node->get_alpha());
233             break;
234         }
235         case OP_TYPEID::AvgPool:
236         {
237             const op::v0::AvgPool* avg_pool = static_cast<const op::v0::AvgPool*>(&node);
238
239             reference::avg_pool<T>(args[0]->get_data_ptr<const T>(),
240                                    out[0]->get_data_ptr<T>(),
241                                    node.get_input_shape(0),
242                                    node.get_output_shape(0),
243                                    avg_pool->get_window_shape(),
244                                    avg_pool->get_window_movement_strides(),
245                                    avg_pool->get_padding_below(),
246                                    avg_pool->get_padding_above(),
247                                    avg_pool->get_include_padding_in_avg_computation());
248             break;
249         }
250         case OP_TYPEID::BatchNormInference:
251         {
252             const ngraph::op::v0::BatchNormInference* bn =
253                 static_cast<const ngraph::op::v0::BatchNormInference*>(&node);
254             reference::batch_norm_inference<T>(bn->get_eps_value(),
255                                                args[0]->get_data_ptr<const T>(),
256                                                args[1]->get_data_ptr<const T>(),
257                                                args[2]->get_data_ptr<const T>(),
258                                                args[3]->get_data_ptr<const T>(),
259                                                args[4]->get_data_ptr<const T>(),
260                                                out[0]->get_data_ptr<T>(),
261                                                node.get_input_shape(2));
262             break;
263         }
264         case OP_TYPEID::BatchNormInference_v5:
265         {
266             const ngraph::op::v5::BatchNormInference* bn =
267                 static_cast<const ngraph::op::v5::BatchNormInference*>(&node);
268             reference::batch_norm_inference<T>(bn->get_eps_value(),
269                                                args[1]->get_data_ptr<const T>(),
270                                                args[2]->get_data_ptr<const T>(),
271                                                args[0]->get_data_ptr<const T>(),
272                                                args[3]->get_data_ptr<const T>(),
273                                                args[4]->get_data_ptr<const T>(),
274                                                out[0]->get_data_ptr<T>(),
275                                                node.get_input_shape(0));
276             break;
277         }
278         case OP_TYPEID::Ceiling:
279         {
280             size_t element_count = shape_size(node.get_output_shape(0));
281             reference::ceiling<T>(
282                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
283             break;
284         }
285         case OP_TYPEID::Convert:
286         {
287             // const op::Convert* c = static_cast<const op::Convert*>(&node);
288             element::Type type = node.get_element_type();
289             std::stringstream ss;
290             size_t element_count = shape_size(node.get_output_shape(0));
291             switch (type)
292             {
293             case element::Type_t::boolean:
294                 reference::convert_to_bool<T>(
295                     args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<char>(), element_count);
296                 break;
297             case element::Type_t::f32:
298                 reference::convert<T>(
299                     args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<float>(), element_count);
300                 break;
301             case element::Type_t::f64:
302                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
303                                       out[0]->get_data_ptr<double>(),
304                                       element_count);
305                 break;
306             case element::Type_t::i8:
307                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
308                                       out[0]->get_data_ptr<int8_t>(),
309                                       element_count);
310                 break;
311             case element::Type_t::i16:
312                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
313                                       out[0]->get_data_ptr<int16_t>(),
314                                       element_count);
315                 break;
316             case element::Type_t::i32:
317                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
318                                       out[0]->get_data_ptr<int32_t>(),
319                                       element_count);
320                 break;
321             case element::Type_t::i64:
322                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
323                                       out[0]->get_data_ptr<int64_t>(),
324                                       element_count);
325                 break;
326             case element::Type_t::u8:
327                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
328                                       out[0]->get_data_ptr<uint8_t>(),
329                                       element_count);
330                 break;
331             case element::Type_t::u16:
332                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
333                                       out[0]->get_data_ptr<uint16_t>(),
334                                       element_count);
335                 break;
336             case element::Type_t::u32:
337                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
338                                       out[0]->get_data_ptr<uint32_t>(),
339                                       element_count);
340                 break;
341             case element::Type_t::u64:
342                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
343                                       out[0]->get_data_ptr<uint64_t>(),
344                                       element_count);
345                 break;
346             case element::Type_t::undefined:
347             case element::Type_t::dynamic:
348             case element::Type_t::u1:
349             case element::Type_t::bf16:
350             case element::Type_t::f16:
351                 ss << "unsupported element type " << type << " op Convert";
352                 throw std::runtime_error(ss.str());
353             }
354             break;
355         }
356         case OP_TYPEID::Convolution:
357         {
358             const op::v0::Convolution* c = static_cast<const op::v0::Convolution*>(&node);
359             reference::convolution<T>(args[0]->get_data_ptr<const T>(),
360                                       args[1]->get_data_ptr<const T>(),
361                                       out[0]->get_data_ptr<T>(),
362                                       node.get_input_shape(0),
363                                       node.get_input_shape(1),
364                                       node.get_output_shape(0),
365                                       c->get_window_movement_strides(),
366                                       c->get_window_dilation_strides(),
367                                       c->get_padding_below(),
368                                       c->get_padding_above(),
369                                       c->get_data_dilation_strides());
370
371             break;
372         }
373         case OP_TYPEID::ConvolutionBackpropData:
374         {
375             // Note that args[1] and args[0] are switched here from the usual order.
376             const op::v0::ConvolutionBackpropData* c =
377                 static_cast<const op::v0::ConvolutionBackpropData*>(&node);
378             reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
379                                                   args[0]->get_data_ptr<const T>(),
380                                                   out[0]->get_data_ptr<T>(),
381                                                   c->get_input_shape(1),
382                                                   c->get_input_shape(0),
383                                                   c->get_data_batch_shape(),
384                                                   c->get_data_dilation_strides_forward(),
385                                                   c->get_window_dilation_strides_forward(),
386                                                   c->compute_backward_delta_out_pad_below(),
387                                                   c->compute_backward_delta_out_pad_above(),
388                                                   c->get_window_movement_strides_forward());
389             break;
390         }
391         case OP_TYPEID::Cos:
392         {
393             size_t element_count = shape_size(node.get_output_shape(0));
394             reference::cos<T>(
395                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
396             break;
397         }
398         case OP_TYPEID::Cosh:
399         {
400             size_t element_count = shape_size(node.get_output_shape(0));
401             reference::cosh<T>(
402                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
403             break;
404         }
405         case OP_TYPEID::CTCGreedyDecoder_v0:
406         {
407             const auto ctc_greedy_dec = static_cast<const op::v0::CTCGreedyDecoder*>(&node);
408             reference::ctc_greedy_decoder<T>(args[0]->get_data_ptr<const T>(),
409                                              args[1]->get_data_ptr<const T>(),
410                                              out[0]->get_data_ptr<T>(),
411                                              args[0]->get_shape(),
412                                              args[1]->get_shape(),
413                                              out[0]->get_shape(),
414                                              ctc_greedy_dec->get_ctc_merge_repeated());
415             break;
416         }
417         case OP_TYPEID::CTCLoss_v4:
418         {
419             const op::v4::CTCLoss* ctc_loss = static_cast<const op::v4::CTCLoss*>(&node);
420             auto t_int = node.get_input_element_type(1);
421             if (t_int == element::i32)
422             {
423                 reference::CTCLoss<T, int32_t>(
424                     args[0]->get_data_ptr<const T>(),
425                     ctc_loss->get_input_shape(0),
426                     args[1]->get_data_ptr<const int32_t>(),
427                     args[2]->get_data_ptr<const int32_t>(),
428                     args[3]->get_data_ptr<const int32_t>(),
429                     args.size() > 4 ? args[4]->get_data_ptr<const int32_t>() : nullptr,
430                     ctc_loss->get_preprocess_collapse_repeated(),
431                     ctc_loss->get_ctc_merge_repeated(),
432                     ctc_loss->get_unique(),
433                     out[0]->get_data_ptr<T>());
434             }
435             else if (t_int == element::i64)
436             {
437                 reference::CTCLoss<T, int64_t>(
438                     args[0]->get_data_ptr<const T>(),
439                     ctc_loss->get_input_shape(0),
440                     args[1]->get_data_ptr<const int64_t>(),
441                     args[2]->get_data_ptr<const int64_t>(),
442                     args[3]->get_data_ptr<const int64_t>(),
443                     args.size() > 4 ? args[4]->get_data_ptr<const int64_t>() : nullptr,
444                     ctc_loss->get_preprocess_collapse_repeated(),
445                     ctc_loss->get_ctc_merge_repeated(),
446                     ctc_loss->get_unique(),
447                     out[0]->get_data_ptr<T>());
448             }
449             break;
450         }
451         case OP_TYPEID::CumSum:
452         {
453             const op::CumSum* cumsum = static_cast<const op::CumSum*>(&node);
454             auto axis_et = node.get_input_element_type(1);
455             if (axis_et == element::i32)
456             {
457                 reference::cumsum<T, int32_t>(args[0]->get_data_ptr<const T>(),
458                                               args[1]->get_data_ptr<const int32_t>(),
459                                               out[0]->get_data_ptr<T>(),
460                                               node.get_input_shape(0),
461                                               cumsum->is_exclusive(),
462                                               cumsum->is_reverse());
463             }
464             else if (axis_et == element::i64)
465             {
466                 reference::cumsum<T, int64_t>(args[0]->get_data_ptr<const T>(),
467                                               args[1]->get_data_ptr<const int64_t>(),
468                                               out[0]->get_data_ptr<T>(),
469                                               node.get_input_shape(0),
470                                               cumsum->is_exclusive(),
471                                               cumsum->is_reverse());
472             }
473             break;
474         }
475         case OP_TYPEID::Dot:
476         {
477             const op::Dot* dot = static_cast<const op::Dot*>(&node);
478
479             reference::dot(args[0]->get_data_ptr<const T>(),
480                            args[1]->get_data_ptr<const T>(),
481                            out[0]->get_data_ptr<T>(),
482                            node.get_input_shape(0),
483                            node.get_input_shape(1),
484                            node.get_output_shape(0),
485                            dot->get_reduction_axes_count());
486             break;
487         }
488         case OP_TYPEID::EmbeddingBagOffsetsSum_v3:
489         {
490             const op::EmbeddingBagOffsetsSum* embed =
491                 static_cast<const op::EmbeddingBagOffsetsSum*>(&node);
492             auto indicesType = embed->input(1).get_element_type();
493             size_t indices_num = shape_size(embed->get_input_shape(1));
494
495             if (indicesType == element::u64 || indicesType == element::i64)
496             {
497                 reference::embeddingBagOffsetsSum<T, size_t>(
498                     args[0]->get_data_ptr<const T>(),
499                     args[1]->get_data_ptr<const size_t>(),
500                     args[2]->get_data_ptr<const size_t>(),
501                     args.size() > 3 ? args[3]->get_data_ptr<const size_t>() : nullptr,
502                     args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
503                     out[0]->get_data_ptr<T>(),
504                     indices_num,
505                     embed->get_shape());
506             }
507             else if (indicesType == element::u32 || indicesType == element::i32)
508             {
509                 reference::embeddingBagOffsetsSum<T, unsigned>(
510                     args[0]->get_data_ptr<const T>(),
511                     args[1]->get_data_ptr<const unsigned>(),
512                     args[2]->get_data_ptr<const unsigned>(),
513                     args.size() > 3 ? args[3]->get_data_ptr<const unsigned>() : nullptr,
514                     args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
515                     out[0]->get_data_ptr<T>(),
516                     indices_num,
517                     embed->get_shape());
518             }
519             else
520             {
521                 throw ngraph_error(std::string("Unsupported index type ") +
522                                    indicesType.c_type_string() +
523                                    std::string(" in EmbeddingBagOffsetsSum"));
524             }
525             break;
526         }
527         case OP_TYPEID::EmbeddingBagPackedSum_v3:
528         {
529             const op::EmbeddingBagPackedSum* embed =
530                 static_cast<const op::EmbeddingBagPackedSum*>(&node);
531             auto indicesType = embed->input(1).get_element_type();
532
533             if (indicesType == element::u64 || indicesType == element::i64)
534             {
535                 reference::embeddingBagPackedSum<T, size_t>(
536                     args[0]->get_data_ptr<const T>(),
537                     args[1]->get_data_ptr<const size_t>(),
538                     args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
539                     out[0]->get_data_ptr<T>(),
540                     embed->get_input_shape(1),
541                     embed->get_shape());
542             }
543             else if (indicesType == element::u32 || indicesType == element::i32)
544             {
545                 reference::embeddingBagPackedSum<T, unsigned>(
546                     args[0]->get_data_ptr<const T>(),
547                     args[1]->get_data_ptr<const unsigned>(),
548                     args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
549                     out[0]->get_data_ptr<T>(),
550                     embed->get_input_shape(1),
551                     embed->get_shape());
552             }
553             else
554             {
555                 throw ngraph_error(std::string("Unsupported index type ") +
556                                    indicesType.c_type_string() +
557                                    std::string(" in EmbeddingBagPackedSum"));
558             }
559             break;
560         }
561         case OP_TYPEID::EmbeddingSegmentsSum_v3:
562         {
563             const op::EmbeddingSegmentsSum* embed =
564                 static_cast<const op::EmbeddingSegmentsSum*>(&node);
565             auto indicesType = embed->input(1).get_element_type();
566             size_t indices_num = shape_size(embed->get_input_shape(1));
567
568             if (indicesType == element::u64 || indicesType == element::i64)
569             {
570                 reference::embeddingSegmentsSum<T, size_t>(
571                     args[0]->get_data_ptr<const T>(),
572                     args[1]->get_data_ptr<const size_t>(),
573                     args[2]->get_data_ptr<const size_t>(),
574                     args.size() > 4 ? args[4]->get_data_ptr<const size_t>() : nullptr,
575                     args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
576                     out[0]->get_data_ptr<T>(),
577                     embed->get_input_shape(0),
578                     embed->get_input_shape(1),
579                     embed->get_shape());
580             }
581             else if (indicesType == element::u32 || indicesType == element::i32)
582             {
583                 reference::embeddingSegmentsSum<T, unsigned>(
584                     args[0]->get_data_ptr<const T>(),
585                     args[1]->get_data_ptr<const unsigned>(),
586                     args[2]->get_data_ptr<const unsigned>(),
587                     args.size() > 4 ? args[4]->get_data_ptr<const unsigned>() : nullptr,
588                     args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
589                     out[0]->get_data_ptr<T>(),
590                     embed->get_input_shape(0),
591                     embed->get_input_shape(1),
592                     embed->get_shape());
593             }
594             else
595             {
596                 throw ngraph_error(std::string("Unsupported index type ") +
597                                    indicesType.c_type_string() +
598                                    std::string(" in EmbeddingSegmentsSum"));
599             }
600             break;
601         }
602         case OP_TYPEID::Erf:
603         {
604             size_t element_count = shape_size(node.get_output_shape(0));
605             reference::erf<T>(
606                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
607             break;
608         }
609         case OP_TYPEID::ExtractImagePatches_v3:
610         {
611             const op::ExtractImagePatches* extImgPatches =
612                 static_cast<const op::ExtractImagePatches*>(&node);
613             reference::extractImagePatches<T, size_t>(extImgPatches,
614                                                       args[0]->get_data_ptr<const T>(),
615                                                       out[0]->get_data_ptr<T>(),
616                                                       extImgPatches->get_input_shape(0),
617                                                       extImgPatches->get_shape());
618             break;
619         }
620         case OP_TYPEID::Exp:
621         {
622             size_t element_count = shape_size(node.get_output_shape(0));
623             reference::exp<T>(
624                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
625             break;
626         }
627 #ifdef INTERPRETER_USE_HYBRID
628         case OP_TYPEID::FunctionCall:
629         {
630             auto f = static_cast<const runtime::hybrid::op::FunctionCall*>(&node);
631             auto backend = f->get_backend();
632             auto executable = f->get_executable();
633
634             std::vector<std::shared_ptr<Tensor>> outputs;
635             std::vector<std::shared_ptr<Tensor>> inputs;
636             for (const std::shared_ptr<HostTensor>& t : out)
637             {
638                 auto backend_tensor = backend->create_tensor(
639                     t->get_element_type(), t->get_shape(), t->get_data_ptr());
640                 outputs.push_back(backend_tensor);
641             }
642             for (const std::shared_ptr<HostTensor>& t : args)
643             {
644                 auto backend_tensor = backend->create_tensor(
645                     t->get_element_type(), t->get_shape(), t->get_data_ptr());
646                 inputs.push_back(backend_tensor);
647             }
648             executable->call(outputs, inputs);
649             break;
650         }
651 #endif
652         case OP_TYPEID::Floor:
653         {
654             size_t element_count = shape_size(node.get_output_shape(0));
655             reference::floor<T>(
656                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
657             break;
658         }
659         case OP_TYPEID::GatherND:
660         {
661             if (node.get_input_element_type(1) == element::i64)
662             {
663                 reference::gather_nd<T, int64_t>(args[0]->get_data_ptr<T>(),
664                                                  args[1]->get_data_ptr<int64_t>(),
665                                                  out[0]->get_data_ptr<T>(),
666                                                  node.get_input_shape(0),
667                                                  node.get_input_shape(1),
668                                                  node.get_output_shape(0));
669             }
670             else if (node.get_input_element_type(1) == element::i32)
671             {
672                 reference::gather_nd<T, int32_t>(args[0]->get_data_ptr<T>(),
673                                                  args[1]->get_data_ptr<int32_t>(),
674                                                  out[0]->get_data_ptr<T>(),
675                                                  node.get_input_shape(0),
676                                                  node.get_input_shape(1),
677                                                  node.get_output_shape(0));
678             }
679             else
680             {
681                 throw ngraph_error("Unexpected type");
682             }
683             break;
684         }
685         case OP_TYPEID::GatherND_v5:
686         {
687             const op::v5::GatherND* gatherNDNode = static_cast<const op::v5::GatherND*>(&node);
688             if (node.get_input_element_type(1) == element::i64)
689             {
690                 reference::gather_nd<T, int64_t>(args[0]->get_data_ptr<T>(),
691                                                  args[1]->get_data_ptr<int64_t>(),
692                                                  out[0]->get_data_ptr<T>(),
693                                                  node.get_input_shape(0),
694                                                  node.get_input_shape(1),
695                                                  node.get_output_shape(0),
696                                                  gatherNDNode->get_batch_dims());
697             }
698             else if (node.get_input_element_type(1) == element::i32)
699             {
700                 reference::gather_nd<T, int32_t>(args[0]->get_data_ptr<T>(),
701                                                  args[1]->get_data_ptr<int32_t>(),
702                                                  out[0]->get_data_ptr<T>(),
703                                                  node.get_input_shape(0),
704                                                  node.get_input_shape(1),
705                                                  node.get_output_shape(0),
706                                                  gatherNDNode->get_batch_dims());
707             }
708             else
709             {
710                 throw ngraph_error("Unexpected type");
711             }
712             break;
713         }
714         case OP_TYPEID::GRUCell_v3:
715         {
716             const op::v3::GRUCell* gru_cell = static_cast<const op::v3::GRUCell*>(&node);
717             runtime::reference::gru_cell(args[0]->get_data_ptr<T>(),
718                                          args[0]->get_shape(),
719                                          args[1]->get_data_ptr<T>(),
720                                          args[1]->get_shape(),
721                                          args[2]->get_data_ptr<T>(),
722                                          args[2]->get_shape(),
723                                          args[3]->get_data_ptr<T>(),
724                                          args[3]->get_shape(),
725                                          args[4]->get_data_ptr<T>(),
726                                          args[4]->get_shape(),
727                                          out[0]->get_data_ptr<T>(),
728                                          gru_cell->get_activations()[0],
729                                          gru_cell->get_activations()[1],
730                                          gru_cell->get_clip(),
731                                          gru_cell->get_linear_before_reset());
732             break;
733         }
734         case OP_TYPEID::LSTMCell_v4:
735         {
736             const op::v4::LSTMCell* lstm_cell = static_cast<const op::v4::LSTMCell*>(&node);
737             runtime::reference::lstm_cell(args[0]->get_data_ptr<T>(),
738                                           args[0]->get_shape(),
739                                           args[1]->get_data_ptr<T>(),
740                                           args[1]->get_shape(),
741                                           args[2]->get_data_ptr<T>(),
742                                           args[2]->get_shape(),
743                                           args[3]->get_data_ptr<T>(),
744                                           args[3]->get_shape(),
745                                           args[4]->get_data_ptr<T>(),
746                                           args[4]->get_shape(),
747                                           args[5]->get_data_ptr<T>(),
748                                           args[5]->get_shape(),
749                                           out[0]->get_data_ptr<T>(),
750                                           out[1]->get_data_ptr<T>(),
751                                           lstm_cell->get_activations()[0],
752                                           lstm_cell->get_activations()[1],
753                                           lstm_cell->get_activations()[2],
754                                           lstm_cell->get_clip());
755             break;
756         }
757         case OP_TYPEID::RNNCell_v0:
758         {
759             const op::v0::RNNCell* rnn_cell = static_cast<const op::v0::RNNCell*>(&node);
760             runtime::reference::rnn_cell(args[0]->get_data_ptr<T>(),
761                                          args[0]->get_shape(),
762                                          args[1]->get_data_ptr<T>(),
763                                          args[1]->get_shape(),
764                                          args[2]->get_data_ptr<T>(),
765                                          args[2]->get_shape(),
766                                          args[3]->get_data_ptr<T>(),
767                                          args[3]->get_shape(),
768                                          args[4]->get_data_ptr<T>(),
769                                          args[4]->get_shape(),
770                                          out[0]->get_data_ptr<T>(),
771                                          rnn_cell->get_activations()[0],
772                                          rnn_cell->get_clip());
773             break;
774         }
775         case OP_TYPEID::LSTMSequence:
776         case OP_TYPEID::LSTMSequence_v5:
777         {
778             auto lstm_seq = static_cast<const op::v5::LSTMSequence*>(&node);
779             runtime::reference::lstm_sequence<T>(args[0]->get_data_ptr<char>(),
780                                                  args[0]->get_shape(),
781                                                  args[1]->get_data_ptr<char>(),
782                                                  args[1]->get_shape(),
783                                                  args[2]->get_data_ptr<char>(),
784                                                  args[2]->get_shape(),
785                                                  args[3]->get_data_ptr<char>(),
786                                                  args[3]->get_shape(),
787                                                  args[4]->get_data_ptr<char>(),
788                                                  args[4]->get_shape(),
789                                                  args[5]->get_data_ptr<char>(),
790                                                  args[5]->get_shape(),
791                                                  args[6]->get_data_ptr<char>(),
792                                                  args[6]->get_shape(),
793                                                  out[0]->get_data_ptr<char>(),
794                                                  out[1]->get_data_ptr<char>(),
795                                                  out[2]->get_data_ptr<char>(),
796                                                  lstm_seq->get_activations()[0],
797                                                  lstm_seq->get_activations()[1],
798                                                  lstm_seq->get_activations()[2],
799                                                  lstm_seq->get_clip(),
800                                                  lstm_seq->get_direction());
801             break;
802         }
803         case OP_TYPEID::GRUSequence_v5:
804         {
805             auto gru_seq = static_cast<const op::v5::GRUSequence*>(&node);
806             runtime::reference::gru_sequence<T>(args[0]->get_data_ptr<char>(),
807                                                 args[0]->get_shape(),
808                                                 args[1]->get_data_ptr<char>(),
809                                                 args[1]->get_shape(),
810                                                 args[2]->get_data_ptr<char>(),
811                                                 args[2]->get_shape(),
812                                                 args[3]->get_data_ptr<char>(),
813                                                 args[3]->get_shape(),
814                                                 args[4]->get_data_ptr<char>(),
815                                                 args[4]->get_shape(),
816                                                 args[5]->get_data_ptr<char>(),
817                                                 args[5]->get_shape(),
818                                                 out[0]->get_data_ptr<char>(),
819                                                 out[1]->get_data_ptr<char>(),
820                                                 gru_seq->get_activations()[0],
821                                                 gru_seq->get_activations()[1],
822                                                 gru_seq->get_clip(),
823                                                 gru_seq->get_direction(),
824                                                 gru_seq->get_linear_before_reset());
825             break;
826         }
827         case OP_TYPEID::RNNSequence_v5:
828         {
829             auto rnn_seq = static_cast<const op::v5::RNNSequence*>(&node);
830             runtime::reference::rnn_sequence<T>(args[0]->get_data_ptr<char>(),
831                                                 args[0]->get_shape(),
832                                                 args[1]->get_data_ptr<char>(),
833                                                 args[1]->get_shape(),
834                                                 args[2]->get_data_ptr<char>(),
835                                                 args[2]->get_shape(),
836                                                 args[3]->get_data_ptr<char>(),
837                                                 args[3]->get_shape(),
838                                                 args[4]->get_data_ptr<char>(),
839                                                 args[4]->get_shape(),
840                                                 args[5]->get_data_ptr<char>(),
841                                                 args[5]->get_shape(),
842                                                 out[0]->get_data_ptr<char>(),
843                                                 out[1]->get_data_ptr<char>(),
844                                                 rnn_seq->get_activations()[0],
845                                                 rnn_seq->get_clip(),
846                                                 rnn_seq->get_direction());
847             break;
848         }
849         case OP_TYPEID::Log:
850         {
851             size_t element_count = shape_size(node.get_output_shape(0));
852             reference::log<T>(
853                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
854             break;
855         }
856         case OP_TYPEID::LogSoftmax_v5:
857         {
858             const op::v5::LogSoftmax* log_softmax = static_cast<const op::v5::LogSoftmax*>(&node);
859             int64_t i_axis = log_softmax->get_axis();
860             if (i_axis < 0)
861             {
862                 i_axis += args[0]->get_partial_shape().rank().get_length();
863             }
864             reference::log_softmax<T>(args[0]->get_data_ptr<const T>(),
865                                       out[0]->get_data_ptr<T>(),
866                                       node.get_output_shape(0),
867                                       AxisSet{(size_t)i_axis});
868             break;
869         }
870         case OP_TYPEID::LRN:
871         {
872             const op::LRN* lrn = static_cast<const op::LRN*>(&node);
873             reference::lrn<T>(args[0]->get_data_ptr<const T>(),
874                               lrn->get_reduction_axes(),
875                               out[0]->get_data_ptr<T>(),
876                               node.get_input_shape(0),
877                               lrn->get_alpha(),
878                               lrn->get_beta(),
879                               lrn->get_bias(),
880                               lrn->get_nsize());
881             break;
882         }
883         case OP_TYPEID::Negative:
884         {
885             size_t element_count = shape_size(node.get_output_shape(0));
886             reference::negate<T>(
887                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
888             break;
889         }
890         case OP_TYPEID::LogicalNot_v1:
891         case OP_TYPEID::Not:
892         {
893             size_t element_count = shape_size(node.get_output_shape(0));
894             reference::logical_not(
895                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
896             break;
897         }
898         case OP_TYPEID::OneHot:
899         {
900             const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
901             reference::one_hot<T>(args[0]->get_data_ptr<const T>(),
902                                   out[0]->get_data_ptr<T>(),
903                                   node.get_input_shape(0),
904                                   node.get_output_shape(0),
905                                   oh->get_one_hot_axis());
906             break;
907         }
908         case OP_TYPEID::Parameter: break;
909         case OP_TYPEID::PriorBox:
910         {
911             const op::PriorBox* pbox = static_cast<const op::PriorBox*>(&node);
912             runtime::reference::prior_box<T>(args[0]->get_data_ptr<T>(),
913                                              args[1]->get_data_ptr<T>(),
914                                              out[0]->get_data_ptr<float>(),
915                                              out[0]->get_shape(),
916                                              pbox->get_attrs());
917             break;
918         }
919         case OP_TYPEID::ReorgYolo_v0:
920         {
921             const op::v0::ReorgYolo* reorg_yolo = static_cast<const op::v0::ReorgYolo*>(&node);
922             runtime::reference::reorg_yolo(args[0]->get_data_ptr<char>(),
923                                            out[0]->get_data_ptr<char>(),
924                                            args[0]->get_shape(),
925                                            reorg_yolo->get_strides().at(0),
926                                            args[0]->get_element_type().size());
927             break;
928         }
929         case OP_TYPEID::Quantize:
930         {
931             const op::Quantize* quantize = static_cast<const op::Quantize*>(&node);
932             auto type = quantize->get_element_type();
933
934             if (type == element::u8)
935             {
936                 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
937                                        args[1]->get_data_ptr<const T>(),
938                                        args[2]->get_data_ptr<const uint8_t>(),
939                                        out[0]->get_data_ptr<uint8_t>(),
940                                        node.get_input_shape(0),
941                                        node.get_input_shape(1),
942                                        quantize->get_axes(),
943                                        quantize->get_round_mode());
944             }
945             else if (type == element::i8)
946             {
947                 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
948                                        args[1]->get_data_ptr<const T>(),
949                                        args[2]->get_data_ptr<const int8_t>(),
950                                        out[0]->get_data_ptr<int8_t>(),
951                                        node.get_input_shape(0),
952                                        node.get_input_shape(1),
953                                        quantize->get_axes(),
954                                        quantize->get_round_mode());
955             }
956             else if (type == element::i32)
957             {
958                 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
959                                        args[1]->get_data_ptr<const T>(),
960                                        args[2]->get_data_ptr<const int32_t>(),
961                                        out[0]->get_data_ptr<int32_t>(),
962                                        node.get_input_shape(0),
963                                        node.get_input_shape(1),
964                                        quantize->get_axes(),
965                                        quantize->get_round_mode());
966             }
967             else
968             {
969                 std::stringstream ss;
970                 ss << "unsupported element type " << type << " op Quantize";
971                 throw std::runtime_error(ss.str());
972             }
973
974             break;
975         }
976
977         case OP_TYPEID::QuantizedConvolution:
978         {
979             const op::QuantizedConvolution* qc =
980                 static_cast<const op::QuantizedConvolution*>(&node);
981
982             auto input_element_type = qc->get_input_element_type(0);
983             auto filter_element_type = qc->get_input_element_type(1);
984             auto output_element_type = qc->get_output_element_type(0);
985
986             if (input_element_type == element::u8 && filter_element_type == element::i8 &&
987                 output_element_type == element::i8)
988             {
989                 reference::convolution<uint8_t, int8_t, int8_t, int32_t>(
990                     args[0]->get_data_ptr<const uint8_t>(),
991                     args[1]->get_data_ptr<const int8_t>(),
992                     out[0]->get_data_ptr<int8_t>(),
993                     node.get_input_shape(0),
994                     node.get_input_shape(1),
995                     node.get_output_shape(0),
996                     qc->get_window_movement_strides(),
997                     qc->get_window_dilation_strides(),
998                     qc->get_padding_below(),
999                     qc->get_padding_above(),
1000                     qc->get_data_dilation_strides(),
1001                     args[2]->get_data_ptr<const float>(),
1002                     args[3]->get_data_ptr<const uint8_t>(),
1003                     args[4]->get_data_ptr<const float>(),
1004                     args[5]->get_data_ptr<const int8_t>(),
1005                     args[6]->get_data_ptr<const float>(),
1006                     args[7]->get_data_ptr<const int8_t>());
1007             }
1008             else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
1009                      output_element_type == element::u8)
1010             {
1011                 reference::convolution<uint8_t, uint8_t, uint8_t, int32_t>(
1012                     args[0]->get_data_ptr<const uint8_t>(),
1013                     args[1]->get_data_ptr<const uint8_t>(),
1014                     out[0]->get_data_ptr<uint8_t>(),
1015                     node.get_input_shape(0),
1016                     node.get_input_shape(1),
1017                     node.get_output_shape(0),
1018                     qc->get_window_movement_strides(),
1019                     qc->get_window_dilation_strides(),
1020                     qc->get_padding_below(),
1021                     qc->get_padding_above(),
1022                     qc->get_data_dilation_strides(),
1023                     args[2]->get_data_ptr<const float>(),
1024                     args[3]->get_data_ptr<const uint8_t>(),
1025                     args[4]->get_data_ptr<const float>(),
1026                     args[5]->get_data_ptr<const uint8_t>(),
1027                     args[6]->get_data_ptr<const float>(),
1028                     args[7]->get_data_ptr<const uint8_t>());
1029             }
1030             else if (input_element_type == element::u8 && filter_element_type == element::i8 &&
1031                      output_element_type == element::i32)
1032             {
1033                 reference::convolution<uint8_t, int8_t, int32_t, int32_t>(
1034                     args[0]->get_data_ptr<const uint8_t>(),
1035                     args[1]->get_data_ptr<const int8_t>(),
1036                     out[0]->get_data_ptr<int32_t>(),
1037                     node.get_input_shape(0),
1038                     node.get_input_shape(1),
1039                     node.get_output_shape(0),
1040                     qc->get_window_movement_strides(),
1041                     qc->get_window_dilation_strides(),
1042                     qc->get_padding_below(),
1043                     qc->get_padding_above(),
1044                     qc->get_data_dilation_strides(),
1045                     args[2]->get_data_ptr<const float>(),
1046                     args[3]->get_data_ptr<const uint8_t>(),
1047                     args[4]->get_data_ptr<const float>(),
1048                     args[5]->get_data_ptr<const int8_t>(),
1049                     args[6]->get_data_ptr<const float>(),
1050                     args[7]->get_data_ptr<const int32_t>());
1051             }
1052             else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
1053                      output_element_type == element::i32)
1054             {
1055                 reference::convolution<uint8_t, uint8_t, int32_t, int32_t>(
1056                     args[0]->get_data_ptr<const uint8_t>(),
1057                     args[1]->get_data_ptr<const uint8_t>(),
1058                     out[0]->get_data_ptr<int32_t>(),
1059                     node.get_input_shape(0),
1060                     node.get_input_shape(1),
1061                     node.get_output_shape(0),
1062                     qc->get_window_movement_strides(),
1063                     qc->get_window_dilation_strides(),
1064                     qc->get_padding_below(),
1065                     qc->get_padding_above(),
1066                     qc->get_data_dilation_strides(),
1067                     args[2]->get_data_ptr<const float>(),
1068                     args[3]->get_data_ptr<const uint8_t>(),
1069                     args[4]->get_data_ptr<const float>(),
1070                     args[5]->get_data_ptr<const uint8_t>(),
1071                     args[6]->get_data_ptr<const float>(),
1072                     args[7]->get_data_ptr<const int32_t>());
1073             }
1074             else
1075             {
1076                 std::stringstream ss;
1077                 ss << "unsupported element type";
1078                 throw std::runtime_error(ss.str());
1079             }
1080
1081             break;
1082         }
1083
1084         case OP_TYPEID::QuantizedDot:
1085         {
1086             const op::QuantizedDot* qd = static_cast<const op::QuantizedDot*>(&node);
1087
1088             auto input0_element_type = qd->get_input_element_type(0);
1089             auto input1_element_type = qd->get_input_element_type(1);
1090             auto output_element_type = qd->get_output_element_type(0);
1091
1092             if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
1093                 output_element_type == element::i8)
1094             {
1095                 reference::dot<uint8_t, int8_t, int8_t, int32_t>(
1096                     args[0]->get_data_ptr<const uint8_t>(),
1097                     args[1]->get_data_ptr<const int8_t>(),
1098                     out[0]->get_data_ptr<int8_t>(),
1099                     node.get_input_shape(0),
1100                     node.get_input_shape(1),
1101                     node.get_output_shape(0),
1102                     1,
1103                     args[2]->get_data_ptr<const float>(),
1104                     args[3]->get_data_ptr<const uint8_t>(),
1105                     args[4]->get_data_ptr<const float>(),
1106                     args[5]->get_data_ptr<const int8_t>(),
1107                     args[6]->get_data_ptr<const float>(),
1108                     args[7]->get_data_ptr<const int8_t>());
1109             }
1110             else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
1111                      output_element_type == element::u8)
1112             {
1113                 reference::dot<uint8_t, uint8_t, uint8_t, int32_t>(
1114                     args[0]->get_data_ptr<const uint8_t>(),
1115                     args[1]->get_data_ptr<const uint8_t>(),
1116                     out[0]->get_data_ptr<uint8_t>(),
1117                     node.get_input_shape(0),
1118                     node.get_input_shape(1),
1119                     node.get_output_shape(0),
1120                     1,
1121                     args[2]->get_data_ptr<const float>(),
1122                     args[3]->get_data_ptr<const uint8_t>(),
1123                     args[4]->get_data_ptr<const float>(),
1124                     args[5]->get_data_ptr<const uint8_t>(),
1125                     args[6]->get_data_ptr<const float>(),
1126                     args[7]->get_data_ptr<const uint8_t>());
1127             }
1128             else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
1129                      output_element_type == element::i32)
1130             {
1131                 reference::dot<uint8_t, uint8_t, int32_t, int32_t>(
1132                     args[0]->get_data_ptr<const uint8_t>(),
1133                     args[1]->get_data_ptr<const uint8_t>(),
1134                     out[0]->get_data_ptr<int32_t>(),
1135                     node.get_input_shape(0),
1136                     node.get_input_shape(1),
1137                     node.get_output_shape(0),
1138                     1,
1139                     args[2]->get_data_ptr<const float>(),
1140                     args[3]->get_data_ptr<const uint8_t>(),
1141                     args[4]->get_data_ptr<const float>(),
1142                     args[5]->get_data_ptr<const uint8_t>(),
1143                     args[6]->get_data_ptr<const float>(),
1144                     args[7]->get_data_ptr<const int32_t>());
1145             }
1146             else if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
1147                      output_element_type == element::i32)
1148             {
1149                 reference::dot<uint8_t, int8_t, int32_t, int32_t>(
1150                     args[0]->get_data_ptr<const uint8_t>(),
1151                     args[1]->get_data_ptr<const int8_t>(),
1152                     out[0]->get_data_ptr<int32_t>(),
1153                     node.get_input_shape(0),
1154                     node.get_input_shape(1),
1155                     node.get_output_shape(0),
1156                     1,
1157                     args[2]->get_data_ptr<const float>(),
1158                     args[3]->get_data_ptr<const uint8_t>(),
1159                     args[4]->get_data_ptr<const float>(),
1160                     args[5]->get_data_ptr<const int8_t>(),
1161                     args[6]->get_data_ptr<const float>(),
1162                     args[7]->get_data_ptr<const int32_t>());
1163             }
1164             else
1165             {
1166                 std::stringstream ss;
1167                 ss << "unsupported element type";
1168                 throw std::runtime_error(ss.str());
1169             }
1170
1171             break;
1172         }
1173         case OP_TYPEID::RegionYolo_v0:
1174         {
1175             const op::RegionYolo* region_yolo = static_cast<const op::RegionYolo*>(&node);
1176             reference::region_yolo<T>(args[0]->get_data_ptr<const T>(),
1177                                       out[0]->get_data_ptr<T>(),
1178                                       args[0]->get_shape(),
1179                                       region_yolo->get_num_coords(),
1180                                       region_yolo->get_num_classes(),
1181                                       region_yolo->get_num_regions(),
1182                                       region_yolo->get_do_softmax(),
1183                                       region_yolo->get_mask());
1184             break;
1185         }
1186         case OP_TYPEID::Relu:
1187         {
1188             size_t element_count = shape_size(node.get_output_shape(0));
1189             reference::relu<T>(
1190                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1191             break;
1192         }
1193         case OP_TYPEID::ReplaceSlice:
1194         {
1195             const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
1196             reference::replace_slice<T>(args[0]->get_data_ptr<const T>(),
1197                                         args[1]->get_data_ptr<const T>(),
1198                                         out[0]->get_data_ptr<T>(),
1199                                         node.get_input_shape(1),
1200                                         slice->get_lower_bounds(),
1201                                         slice->get_upper_bounds(),
1202                                         slice->get_strides(),
1203                                         node.get_output_shape(0));
1204             break;
1205         }
1206         case OP_TYPEID::Reverse:
1207         {
1208             const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
1209             reference::reverse(args[0]->get_data_ptr<const char>(),
1210                                out[0]->get_data_ptr<char>(),
1211                                node.get_input_shape(0),
1212                                node.get_output_shape(0),
1213                                reverse->get_reversed_axes(),
1214                                args[0]->get_element_type().size());
1215             break;
1216         }
1217         case OP_TYPEID::ReverseSequence:
1218         {
1219             const op::ReverseSequence* reverse = static_cast<const op::ReverseSequence*>(&node);
1220
1221             if (node.get_input_element_type(1) == element::i32)
1222             {
1223                 reference::reverse_sequence<T, int32_t>(args[0]->get_data_ptr<const T>(),
1224                                                         out[0]->get_data_ptr<T>(),
1225                                                         node.get_input_shape(0),
1226                                                         reverse->get_batch_axis(),
1227                                                         reverse->get_sequence_axis(),
1228                                                         args[1]->get_data_ptr<const int32_t>());
1229             }
1230             else
1231             {
1232                 throw ngraph_error("only int32 indices are supported");
1233             }
1234             break;
1235         }
1236         case OP_TYPEID::Round:
1237         {
1238             size_t element_count = shape_size(node.get_output_shape(0));
1239             reference::round<T>(args[0]->get_data_ptr<const T>(),
1240                                 out[0]->get_data_ptr<T>(),
1241                                 element_count,
1242                                 op::v5::Round::RoundMode::HALF_TO_EVEN);
1243             break;
1244         }
1245         case OP_TYPEID::Select:
1246         {
1247             size_t element_count = shape_size(node.get_output_shape(0));
1248             reference::select<T>(args[0]->get_data_ptr<const char>(),
1249                                  args[1]->get_data_ptr<const T>(),
1250                                  args[2]->get_data_ptr<const T>(),
1251                                  out[0]->get_data_ptr<T>(),
1252                                  element_count);
1253             break;
1254         }
1255         case OP_TYPEID::Sigmoid:
1256         {
1257             size_t element_count = shape_size(node.get_output_shape(0));
1258             reference::sigmoid<T>(
1259                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1260             break;
1261         }
1262         case OP_TYPEID::Sign:
1263         {
1264             size_t element_count = shape_size(node.get_output_shape(0));
1265             reference::sign<T>(
1266                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1267             break;
1268         }
1269         case OP_TYPEID::Sin:
1270         {
1271             size_t element_count = shape_size(node.get_output_shape(0));
1272             reference::sin<T>(
1273                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1274             break;
1275         }
1276         case OP_TYPEID::Sinh:
1277         {
1278             size_t element_count = shape_size(node.get_output_shape(0));
1279             reference::sinh<T>(
1280                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1281             break;
1282         }
1283         case OP_TYPEID::Sqrt:
1284         {
1285             size_t element_count = shape_size(node.get_output_shape(0));
1286             reference::sqrt<T>(
1287                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1288             break;
1289         }
1290         case OP_TYPEID::Tan:
1291         {
1292             size_t element_count = shape_size(node.get_output_shape(0));
1293             reference::tan<T>(
1294                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1295             break;
1296         }
1297         case OP_TYPEID::Tanh:
1298         {
1299             size_t element_count = shape_size(node.get_output_shape(0));
1300             reference::tanh<T>(
1301                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1302             break;
1303         }
1304         case OP_TYPEID::TopK:
1305         {
1306             const op::TopK* topk = static_cast<const op::TopK*>(&node);
1307             if (node.get_output_element_type(0) == element::i64)
1308             {
1309                 reference::topk<T, int64_t>(args[0]->get_data_ptr<const T>(),
1310                                             out[0]->get_data_ptr<int64_t>(),
1311                                             out[1]->get_data_ptr<T>(),
1312                                             node.get_input_shape(0),
1313                                             node.get_output_shape(0),
1314                                             topk->get_top_k_axis(),
1315                                             topk->get_k(),
1316                                             topk->get_compute_max(),
1317                                             topk->get_sort());
1318             }
1319             else if (node.get_output_element_type(0) == element::i32)
1320             {
1321                 reference::topk<T, int32_t>(args[0]->get_data_ptr<const T>(),
1322                                             out[0]->get_data_ptr<int32_t>(),
1323                                             out[1]->get_data_ptr<T>(),
1324                                             node.get_input_shape(0),
1325                                             node.get_output_shape(0),
1326                                             topk->get_top_k_axis(),
1327                                             topk->get_k(),
1328                                             topk->get_compute_max(),
1329                                             topk->get_sort());
1330             }
1331             else
1332             {
1333                 throw ngraph_error("Unexpected type");
1334             }
1335             break;
1336         }
1337         case OP_TYPEID::DetectionOutput_v0:
1338         {
1339             const op::DetectionOutput* detOut = static_cast<const op::DetectionOutput*>(&node);
1340             reference::referenceDetectionOutput<T> refDetOut(
1341                 detOut->get_attrs(), node.get_input_shape(0), node.get_input_shape(2));
1342             if (node.get_input_size() == 3)
1343             {
1344                 refDetOut.run(args[0]->get_data_ptr<const T>(),
1345                               args[1]->get_data_ptr<const T>(),
1346                               args[2]->get_data_ptr<const T>(),
1347                               nullptr,
1348                               nullptr,
1349                               out[0]->get_data_ptr<T>());
1350             }
1351             else if (node.get_input_size() == 5)
1352             {
1353                 refDetOut.run(args[0]->get_data_ptr<const T>(),
1354                               args[1]->get_data_ptr<const T>(),
1355                               args[2]->get_data_ptr<const T>(),
1356                               args[3]->get_data_ptr<const T>(),
1357                               args[4]->get_data_ptr<const T>(),
1358                               out[0]->get_data_ptr<T>());
1359             }
1360             else
1361             {
1362                 throw ngraph_error("DetectionOutput layer supports only 3 or 5 inputs");
1363             }
1364
1365             break;
1366         }
1367         case OP_TYPEID::ScatterNDUpdate_v3:
1368         {
1369             const op::ScatterNDUpdate* scatterNDUpd =
1370                 static_cast<const op::v3::ScatterNDUpdate*>(&node);
1371             auto idxType = scatterNDUpd->get_input_element_type(1);
1372             if (idxType == element::i32)
1373             {
1374                 reference::scatterNdUpdate<T, int32_t>(args[0]->get_data_ptr<const T>(),
1375                                                        args[1]->get_data_ptr<const int32_t>(),
1376                                                        args[2]->get_data_ptr<const T>(),
1377                                                        out[0]->get_data_ptr<T>(),
1378                                                        node.get_input_shape(0),
1379                                                        node.get_input_shape(1),
1380                                                        node.get_input_shape(2));
1381             }
1382             else if (idxType == element::i64)
1383             {
1384                 reference::scatterNdUpdate<T, int64_t>(args[0]->get_data_ptr<const T>(),
1385                                                        args[1]->get_data_ptr<const int64_t>(),
1386                                                        args[2]->get_data_ptr<const T>(),
1387                                                        out[0]->get_data_ptr<T>(),
1388                                                        node.get_input_shape(0),
1389                                                        node.get_input_shape(1),
1390                                                        node.get_input_shape(2));
1391             }
1392             else
1393             {
1394                 throw ngraph_error(
1395                     "ScatterNDUpdate layer support only i32 and i64 'indices' input precision!");
1396             }
1397
1398             break;
1399         }
1400         case OP_TYPEID::GatherTree_v1:
1401         {
1402             reference::gather_tree(args[0]->get_data_ptr<const char>(),
1403                                    args[1]->get_data_ptr<const char>(),
1404                                    args[2]->get_data_ptr<const char>(),
1405                                    args[3]->get_data_ptr<const char>(),
1406                                    out[0]->get_data_ptr<char>(),
1407                                    node.get_input_shape(0),
1408                                    node.get_input_shape(1),
1409                                    node.get_input_shape(2),
1410                                    node.get_input_shape(3),
1411                                    args[1]->get_element_type());
1412             break;
1413         }
1414         case OP_TYPEID::NormalizeL2:
1415         {
1416             const op::NormalizeL2* norm = static_cast<const op::NormalizeL2*>(&node);
1417             reference::normalize_l2<T>(args[0]->get_data_ptr<const T>(),
1418                                        out[0]->get_data_ptr<T>(),
1419                                        node.get_input_shape(0),
1420                                        norm->get_reduction_axes(),
1421                                        norm->get_eps(),
1422                                        norm->get_eps_mode());
1423             break;
1424         }
1425
1426         // Fused ops are not supported in the interpreter; decompose them before execution.
1427         case OP_TYPEID::DepthToSpace:
1428         case OP_TYPEID::FakeQuantize:
1429         case OP_TYPEID::Gather:
1430         case OP_TYPEID::Gelu:
1431         case OP_TYPEID::GRN:
1432         case OP_TYPEID::GroupConvolution:
1433         case OP_TYPEID::GroupConvolutionBackpropData:
1434         case OP_TYPEID::HardSigmoid:
1435         case OP_TYPEID::Interpolate:
1436         case OP_TYPEID::MVN:
1437         case OP_TYPEID::PRelu:
1438         case OP_TYPEID::ScatterUpdate_v3:
1439         case OP_TYPEID::Selu:
1440         case OP_TYPEID::ShuffleChannels:
1441         case OP_TYPEID::SpaceToDepth:
1442         case OP_TYPEID::Split:
1443         case OP_TYPEID::SquaredDifference:
1444         case OP_TYPEID::StopGradient:
1445         case OP_TYPEID::TensorIterator:
1446         case OP_TYPEID::Tile:
1447         case OP_TYPEID::UnknownOp:
1448             throw unsupported_op("Unsupported op '" + node.description() + "'");
1449         case OP_TYPEID::Add:
1450         case OP_TYPEID::Broadcast:
1451         case OP_TYPEID::Clamp:
1452         case OP_TYPEID::Concat:
1453         case OP_TYPEID::Constant:
1454         case OP_TYPEID::Divide:
1455         case OP_TYPEID::Equal:
1456         case OP_TYPEID::Greater:
1457         case OP_TYPEID::GreaterEq:
1458         case OP_TYPEID::Less:
1459         case OP_TYPEID::LessEq:
1460         case OP_TYPEID::LessEqual_v1:
1461         case OP_TYPEID::LogicalAnd_v1:
1462         case OP_TYPEID::LogicalOr_v1:
1463         case OP_TYPEID::LogicalXor_v1:
1464         case OP_TYPEID::MatMul:
1465         case OP_TYPEID::Max:
1466         case OP_TYPEID::Maximum:
1467         case OP_TYPEID::Min:
1468         case OP_TYPEID::Minimum:
1469         case OP_TYPEID::Multiply:
1470         case OP_TYPEID::NonZero_v3:
1471         case OP_TYPEID::NotEqual:
1472         case OP_TYPEID::Or:
1473         case OP_TYPEID::Power:
1474         case OP_TYPEID::Product:
1475         case OP_TYPEID::Range:
1476         case OP_TYPEID::Reshape:
1477         case OP_TYPEID::Result:
1478         case OP_TYPEID::Round_v5:
1479         case OP_TYPEID::ShapeOf_v3:
1480         case OP_TYPEID::ShapeOf:
1481         case OP_TYPEID::Softmax:
1482         case OP_TYPEID::Squeeze:
1483         case OP_TYPEID::Sum:
1484         case OP_TYPEID::Subtract:
1485         case OP_TYPEID::Unsqueeze:
1486         case OP_TYPEID::Xor:
1487         case OP_TYPEID::Slice:
1488             // These ops are handled by op evaluators so nothing to do
1489             break;
1490 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
1491 #pragma GCC diagnostic pop
1492 #endif
1493         }
1494     }
1495 };
1496
1497 NGRAPH_SUPPRESS_DEPRECATED_END