cebb82f5427f57b072b49e30232772c437b85926
[platform/upstream/dldt.git] / ngraph / test / runtime / interpreter / int_executable.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include <initializer_list>
20 #include <iostream>
21 #include <memory>
22 #include <sstream>
23 #include <string>
24 #include <vector>
25
26 #include <ngraph/runtime/host_tensor.hpp>
27 #include "backend.hpp"
28 #include "int_backend_visibility.hpp"
29 #include "ngraph/ops.hpp"
30 #include "ngraph/runtime/aligned_buffer.hpp"
31 #include "ngraph/runtime/reference/abs.hpp"
32 #include "ngraph/runtime/reference/acos.hpp"
33 #include "ngraph/runtime/reference/asin.hpp"
34 #include "ngraph/runtime/reference/atan.hpp"
35 #include "ngraph/runtime/reference/atan2.hpp"
36 #include "ngraph/runtime/reference/avg_pool.hpp"
37 #include "ngraph/runtime/reference/batch_norm.hpp"
38 #include "ngraph/runtime/reference/broadcast.hpp"
39 #include "ngraph/runtime/reference/ceiling.hpp"
40 #include "ngraph/runtime/reference/concat.hpp"
41 #include "ngraph/runtime/reference/constant.hpp"
42 #include "ngraph/runtime/reference/convert.hpp"
43 #include "ngraph/runtime/reference/convolution.hpp"
44 #include "ngraph/runtime/reference/cos.hpp"
45 #include "ngraph/runtime/reference/cosh.hpp"
46 #include "ngraph/runtime/reference/ctc_greedy_decoder.hpp"
47 #include "ngraph/runtime/reference/ctc_loss.hpp"
48 #include "ngraph/runtime/reference/cum_sum.hpp"
49 #include "ngraph/runtime/reference/detection_output.hpp"
50 #include "ngraph/runtime/reference/dot.hpp"
51 #include "ngraph/runtime/reference/elu.hpp"
52 #include "ngraph/runtime/reference/embedding_bag_offsets_sum.hpp"
53 #include "ngraph/runtime/reference/embedding_bag_packed_sum.hpp"
54 #include "ngraph/runtime/reference/embedding_segments_sum.hpp"
55 #include "ngraph/runtime/reference/erf.hpp"
56 #include "ngraph/runtime/reference/exp.hpp"
57 #include "ngraph/runtime/reference/extract_image_patches.hpp"
58 #include "ngraph/runtime/reference/floor.hpp"
59 #include "ngraph/runtime/reference/gather.hpp"
60 #include "ngraph/runtime/reference/gather_nd.hpp"
61 #include "ngraph/runtime/reference/gather_tree.hpp"
62 #include "ngraph/runtime/reference/gru_cell.hpp"
63 #include "ngraph/runtime/reference/log.hpp"
64 #include "ngraph/runtime/reference/log_softmax.hpp"
65 #include "ngraph/runtime/reference/lrn.hpp"
66 #include "ngraph/runtime/reference/lstm_cell.hpp"
67 #include "ngraph/runtime/reference/matmul.hpp"
68 #include "ngraph/runtime/reference/max.hpp"
69 #include "ngraph/runtime/reference/max_pool.hpp"
70 #include "ngraph/runtime/reference/min.hpp"
71 #include "ngraph/runtime/reference/negate.hpp"
72 #include "ngraph/runtime/reference/normalize_l2.hpp"
73 #include "ngraph/runtime/reference/not.hpp"
74 #include "ngraph/runtime/reference/one_hot.hpp"
75 #include "ngraph/runtime/reference/pad.hpp"
76 #include "ngraph/runtime/reference/prior_box.hpp"
77 #include "ngraph/runtime/reference/product.hpp"
78 #include "ngraph/runtime/reference/quantize.hpp"
79 #include "ngraph/runtime/reference/region_yolo.hpp"
80 #include "ngraph/runtime/reference/relu.hpp"
81 #include "ngraph/runtime/reference/reorg_yolo.hpp"
82 #include "ngraph/runtime/reference/replace_slice.hpp"
83 #include "ngraph/runtime/reference/reshape.hpp"
84 #include "ngraph/runtime/reference/result.hpp"
85 #include "ngraph/runtime/reference/reverse.hpp"
86 #include "ngraph/runtime/reference/reverse_sequence.hpp"
87 #include "ngraph/runtime/reference/rnn_cell.hpp"
88 #include "ngraph/runtime/reference/round.hpp"
89 #include "ngraph/runtime/reference/scatter_nd_update.hpp"
90 #include "ngraph/runtime/reference/select.hpp"
91 #include "ngraph/runtime/reference/sequences.hpp"
92 #include "ngraph/runtime/reference/sigmoid.hpp"
93 #include "ngraph/runtime/reference/sign.hpp"
94 #include "ngraph/runtime/reference/sin.hpp"
95 #include "ngraph/runtime/reference/sinh.hpp"
96 #include "ngraph/runtime/reference/softmax.hpp"
97 #include "ngraph/runtime/reference/sqrt.hpp"
98 #include "ngraph/runtime/reference/sum.hpp"
99 #include "ngraph/runtime/reference/tan.hpp"
100 #include "ngraph/runtime/reference/tanh.hpp"
101 #include "ngraph/runtime/reference/topk.hpp"
102 #include "ngraph/runtime/tensor.hpp"
103 #include "op/avg_pool.hpp"
104 #include "op/convolution.hpp"
105 #include "op/group_conv.hpp"
106
107 NGRAPH_SUPPRESS_DEPRECATED_START
108
109 namespace ngraph
110 {
111     namespace runtime
112     {
113         namespace interpreter
114         {
115             class INTBackend;
116             class INTExecutable;
117
118             // NGRAPH_OP is defined so that each entry of opset_int_tbl.hpp expands to one enumerator,
119             // ID_SUFFIX(NAME) (ID_SUFFIX is not defined in this file), giving a list that looks like this:
120             // Abs,
121             // Acos,
122             // ...
123             enum class OP_TYPEID
124             {
125 #define NGRAPH_OP(NAME, NAMESPACE) ID_SUFFIX(NAME),
126 #include "opset_int_tbl.hpp"
127 #undef NGRAPH_OP
128                 UnknownOp // declared after every table-generated enumerator
129             };
130         } // namespace interpreter
131     }     // namespace runtime
132 } // namespace ngraph
133
134 class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : public Executable
135 {
136     friend class INTBackend;
137
138 public:
139     INTExecutable(const std::shared_ptr<Function>& function,
140                   bool enable_performance_collection = false);
141
142     bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
143               const std::vector<std::shared_ptr<Tensor>>& inputs) override;
144
145     void set_nan_check(bool enable);
146
147     std::vector<PerformanceCounter> get_performance_data() const override;
148
149     std::shared_ptr<runtime::Tensor> create_input_tensor(size_t input_index) override;
150
151     std::shared_ptr<runtime::Tensor> create_output_tensor(size_t output_index) override;
152
153     std::vector<std::shared_ptr<runtime::Tensor>>
154         create_input_tensor(size_t input_index, size_t pipeline_depth) override;
155
156     std::vector<std::shared_ptr<runtime::Tensor>>
157         create_output_tensor(size_t output_index, size_t pipeline_depth) override;
158
159 protected:
160     std::shared_ptr<ngraph::op::Parameter> get_parameter(size_t index) const;
161     std::shared_ptr<ngraph::op::Result> get_result(size_t index) const;
162     int get_alignment() const { return 64; }
163     bool m_is_compiled = false;
164     bool m_nan_check_enabled = false;
165     bool m_performance_counters_enabled = false;
166     std::shared_ptr<Function> m_function;
167     std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
168     std::vector<std::shared_ptr<Node>> m_nodes;
169     std::set<std::string> m_unsupported_op_name_list;
170
171     static OP_TYPEID get_typeid(const Node& node);
172
173     static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
174                                   const Node* op = nullptr);
175
176     virtual void generate_calls(const element::Type& type,
177                                 const Node& op,
178                                 const std::vector<std::shared_ptr<HostTensor>>& outputs,
179                                 const std::vector<std::shared_ptr<HostTensor>>& inputs);
180
181     template <typename T>
182     void op_engine(const Node& node,
183                    const std::vector<std::shared_ptr<HostTensor>>& out,
184                    const std::vector<std::shared_ptr<HostTensor>>& args)
185     {
186 // We want to check that every OP_TYPEID enumeration is included in the list.
187 // These GCC flags enable compile-time checking so that if an enumeration
188 // is not in the list an error is generated.
189 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
190 #pragma GCC diagnostic push
191 #pragma GCC diagnostic error "-Wswitch"
192 #pragma GCC diagnostic error "-Wswitch-enum"
193 #endif
194         switch (get_typeid(node))
195         {
196         case OP_TYPEID::Abs:
197         {
198             size_t element_count = shape_size(node.get_output_shape(0));
199             reference::abs<T>(
200                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
201             break;
202         }
203         case OP_TYPEID::Acos:
204         {
205             size_t element_count = shape_size(node.get_output_shape(0));
206             reference::acos<T>(
207                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
208             break;
209         }
210         case OP_TYPEID::Asin:
211         {
212             size_t element_count = shape_size(node.get_output_shape(0));
213             reference::asin<T>(
214                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
215             break;
216         }
217         case OP_TYPEID::Atan:
218         {
219             size_t element_count = shape_size(node.get_output_shape(0));
220             reference::atan<T>(
221                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
222             break;
223         }
224         case OP_TYPEID::Elu:
225         {
226             const op::Elu* elu_node = static_cast<const op::Elu*>(&node);
227
228             size_t element_count = shape_size(node.get_output_shape(0));
229             reference::elu<T>(args[0]->get_data_ptr<const T>(),
230                               out[0]->get_data_ptr<T>(),
231                               element_count,
232                               elu_node->get_alpha());
233             break;
234         }
235         case OP_TYPEID::AvgPool:
236         {
237             const op::v0::AvgPool* avg_pool = static_cast<const op::v0::AvgPool*>(&node);
238
239             reference::avg_pool<T>(args[0]->get_data_ptr<const T>(),
240                                    out[0]->get_data_ptr<T>(),
241                                    node.get_input_shape(0),
242                                    node.get_output_shape(0),
243                                    avg_pool->get_window_shape(),
244                                    avg_pool->get_window_movement_strides(),
245                                    avg_pool->get_padding_below(),
246                                    avg_pool->get_padding_above(),
247                                    avg_pool->get_include_padding_in_avg_computation());
248             break;
249         }
250         case OP_TYPEID::BatchNormInference:
251         {
252             const ngraph::op::v0::BatchNormInference* bn =
253                 static_cast<const ngraph::op::v0::BatchNormInference*>(&node);
254             reference::batch_norm_inference<T>(bn->get_eps_value(),
255                                                args[0]->get_data_ptr<const T>(),
256                                                args[1]->get_data_ptr<const T>(),
257                                                args[2]->get_data_ptr<const T>(),
258                                                args[3]->get_data_ptr<const T>(),
259                                                args[4]->get_data_ptr<const T>(),
260                                                out[0]->get_data_ptr<T>(),
261                                                node.get_input_shape(2));
262             break;
263         }
264         case OP_TYPEID::BatchNormInference_v5:
265         {
266             const ngraph::op::v5::BatchNormInference* bn =
267                 static_cast<const ngraph::op::v5::BatchNormInference*>(&node);
268             reference::batch_norm_inference<T>(bn->get_eps_value(),
269                                                args[1]->get_data_ptr<const T>(),
270                                                args[2]->get_data_ptr<const T>(),
271                                                args[0]->get_data_ptr<const T>(),
272                                                args[3]->get_data_ptr<const T>(),
273                                                args[4]->get_data_ptr<const T>(),
274                                                out[0]->get_data_ptr<T>(),
275                                                node.get_input_shape(0));
276             break;
277         }
278         case OP_TYPEID::Ceiling:
279         {
280             size_t element_count = shape_size(node.get_output_shape(0));
281             reference::ceiling<T>(
282                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
283             break;
284         }
285         case OP_TYPEID::Convert:
286         {
287             // const op::Convert* c = static_cast<const op::Convert*>(&node);
288             element::Type type = node.get_element_type();
289             std::stringstream ss;
290             size_t element_count = shape_size(node.get_output_shape(0));
291             switch (type)
292             {
293             case element::Type_t::boolean:
294                 reference::convert_to_bool<T>(
295                     args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<char>(), element_count);
296                 break;
297             case element::Type_t::f32:
298                 reference::convert<T>(
299                     args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<float>(), element_count);
300                 break;
301             case element::Type_t::f64:
302                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
303                                       out[0]->get_data_ptr<double>(),
304                                       element_count);
305                 break;
306             case element::Type_t::i8:
307                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
308                                       out[0]->get_data_ptr<int8_t>(),
309                                       element_count);
310                 break;
311             case element::Type_t::i16:
312                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
313                                       out[0]->get_data_ptr<int16_t>(),
314                                       element_count);
315                 break;
316             case element::Type_t::i32:
317                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
318                                       out[0]->get_data_ptr<int32_t>(),
319                                       element_count);
320                 break;
321             case element::Type_t::i64:
322                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
323                                       out[0]->get_data_ptr<int64_t>(),
324                                       element_count);
325                 break;
326             case element::Type_t::u8:
327                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
328                                       out[0]->get_data_ptr<uint8_t>(),
329                                       element_count);
330                 break;
331             case element::Type_t::u16:
332                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
333                                       out[0]->get_data_ptr<uint16_t>(),
334                                       element_count);
335                 break;
336             case element::Type_t::u32:
337                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
338                                       out[0]->get_data_ptr<uint32_t>(),
339                                       element_count);
340                 break;
341             case element::Type_t::u64:
342                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
343                                       out[0]->get_data_ptr<uint64_t>(),
344                                       element_count);
345                 break;
346             case element::Type_t::undefined:
347             case element::Type_t::dynamic:
348             case element::Type_t::u1:
349             case element::Type_t::bf16:
350             case element::Type_t::f16:
351                 ss << "unsupported element type " << type << " op Convert";
352                 throw std::runtime_error(ss.str());
353             }
354             break;
355         }
356         case OP_TYPEID::Convolution:
357         {
358             const op::v0::Convolution* c = static_cast<const op::v0::Convolution*>(&node);
359             reference::convolution<T>(args[0]->get_data_ptr<const T>(),
360                                       args[1]->get_data_ptr<const T>(),
361                                       out[0]->get_data_ptr<T>(),
362                                       node.get_input_shape(0),
363                                       node.get_input_shape(1),
364                                       node.get_output_shape(0),
365                                       c->get_window_movement_strides(),
366                                       c->get_window_dilation_strides(),
367                                       c->get_padding_below(),
368                                       c->get_padding_above(),
369                                       c->get_data_dilation_strides());
370
371             break;
372         }
373         case OP_TYPEID::ConvolutionBackpropData:
374         {
375             // Note that args[1] and args[0] are switched here from the usual order.
376             const op::v0::ConvolutionBackpropData* c =
377                 static_cast<const op::v0::ConvolutionBackpropData*>(&node);
378             reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
379                                                   args[0]->get_data_ptr<const T>(),
380                                                   out[0]->get_data_ptr<T>(),
381                                                   c->get_input_shape(1),
382                                                   c->get_input_shape(0),
383                                                   c->get_data_batch_shape(),
384                                                   c->get_data_dilation_strides_forward(),
385                                                   c->get_window_dilation_strides_forward(),
386                                                   c->compute_backward_delta_out_pad_below(),
387                                                   c->compute_backward_delta_out_pad_above(),
388                                                   c->get_window_movement_strides_forward());
389             break;
390         }
391         case OP_TYPEID::Cos:
392         {
393             size_t element_count = shape_size(node.get_output_shape(0));
394             reference::cos<T>(
395                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
396             break;
397         }
398         case OP_TYPEID::Cosh:
399         {
400             size_t element_count = shape_size(node.get_output_shape(0));
401             reference::cosh<T>(
402                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
403             break;
404         }
405         case OP_TYPEID::CTCGreedyDecoder_v0:
406         {
407             const auto ctc_greedy_dec = static_cast<const op::v0::CTCGreedyDecoder*>(&node);
408             reference::ctc_greedy_decoder<T>(args[0]->get_data_ptr<const T>(),
409                                              args[1]->get_data_ptr<const T>(),
410                                              out[0]->get_data_ptr<T>(),
411                                              args[0]->get_shape(),
412                                              args[1]->get_shape(),
413                                              out[0]->get_shape(),
414                                              ctc_greedy_dec->get_ctc_merge_repeated());
415             break;
416         }
417         case OP_TYPEID::CTCLoss_v4:
418         {
419             const op::v4::CTCLoss* ctc_loss = static_cast<const op::v4::CTCLoss*>(&node);
420             auto t_int = node.get_input_element_type(1);
421             if (t_int == element::i32)
422             {
423                 reference::CTCLoss<T, int32_t>(
424                     args[0]->get_data_ptr<const T>(),
425                     ctc_loss->get_input_shape(0),
426                     args[1]->get_data_ptr<const int32_t>(),
427                     args[2]->get_data_ptr<const int32_t>(),
428                     args[3]->get_data_ptr<const int32_t>(),
429                     args.size() > 4 ? args[4]->get_data_ptr<const int32_t>() : nullptr,
430                     ctc_loss->get_preprocess_collapse_repeated(),
431                     ctc_loss->get_ctc_merge_repeated(),
432                     ctc_loss->get_unique(),
433                     out[0]->get_data_ptr<T>());
434             }
435             else if (t_int == element::i64)
436             {
437                 reference::CTCLoss<T, int64_t>(
438                     args[0]->get_data_ptr<const T>(),
439                     ctc_loss->get_input_shape(0),
440                     args[1]->get_data_ptr<const int64_t>(),
441                     args[2]->get_data_ptr<const int64_t>(),
442                     args[3]->get_data_ptr<const int64_t>(),
443                     args.size() > 4 ? args[4]->get_data_ptr<const int64_t>() : nullptr,
444                     ctc_loss->get_preprocess_collapse_repeated(),
445                     ctc_loss->get_ctc_merge_repeated(),
446                     ctc_loss->get_unique(),
447                     out[0]->get_data_ptr<T>());
448             }
449             break;
450         }
451         case OP_TYPEID::CumSum:
452         {
453             const op::CumSum* cumsum = static_cast<const op::CumSum*>(&node);
454             auto axis_et = node.get_input_element_type(1);
455             if (axis_et == element::i32)
456             {
457                 reference::cumsum<T, int32_t>(args[0]->get_data_ptr<const T>(),
458                                               args[1]->get_data_ptr<const int32_t>(),
459                                               out[0]->get_data_ptr<T>(),
460                                               node.get_input_shape(0),
461                                               cumsum->is_exclusive(),
462                                               cumsum->is_reverse());
463             }
464             else if (axis_et == element::i64)
465             {
466                 reference::cumsum<T, int64_t>(args[0]->get_data_ptr<const T>(),
467                                               args[1]->get_data_ptr<const int64_t>(),
468                                               out[0]->get_data_ptr<T>(),
469                                               node.get_input_shape(0),
470                                               cumsum->is_exclusive(),
471                                               cumsum->is_reverse());
472             }
473             break;
474         }
475         case OP_TYPEID::Dot:
476         {
477             const op::Dot* dot = static_cast<const op::Dot*>(&node);
478
479             reference::dot(args[0]->get_data_ptr<const T>(),
480                            args[1]->get_data_ptr<const T>(),
481                            out[0]->get_data_ptr<T>(),
482                            node.get_input_shape(0),
483                            node.get_input_shape(1),
484                            node.get_output_shape(0),
485                            dot->get_reduction_axes_count());
486             break;
487         }
488         case OP_TYPEID::EmbeddingBagOffsetsSum_v3:
489         {
490             const op::EmbeddingBagOffsetsSum* embed =
491                 static_cast<const op::EmbeddingBagOffsetsSum*>(&node);
492             auto indicesType = embed->input(1).get_element_type();
493             size_t indices_num = shape_size(embed->get_input_shape(1));
494
495             if (indicesType == element::u64 || indicesType == element::i64)
496             {
497                 reference::embeddingBagOffsetsSum<T, size_t>(
498                     args[0]->get_data_ptr<const T>(),
499                     args[1]->get_data_ptr<const size_t>(),
500                     args[2]->get_data_ptr<const size_t>(),
501                     args.size() > 3 ? args[3]->get_data_ptr<const size_t>() : nullptr,
502                     args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
503                     out[0]->get_data_ptr<T>(),
504                     indices_num,
505                     embed->get_shape());
506             }
507             else if (indicesType == element::u32 || indicesType == element::i32)
508             {
509                 reference::embeddingBagOffsetsSum<T, unsigned>(
510                     args[0]->get_data_ptr<const T>(),
511                     args[1]->get_data_ptr<const unsigned>(),
512                     args[2]->get_data_ptr<const unsigned>(),
513                     args.size() > 3 ? args[3]->get_data_ptr<const unsigned>() : nullptr,
514                     args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
515                     out[0]->get_data_ptr<T>(),
516                     indices_num,
517                     embed->get_shape());
518             }
519             else
520             {
521                 throw ngraph_error(std::string("Unsupported index type ") +
522                                    indicesType.c_type_string() +
523                                    std::string(" in EmbeddingBagOffsetsSum"));
524             }
525             break;
526         }
527         case OP_TYPEID::EmbeddingBagPackedSum_v3:
528         {
529             const op::EmbeddingBagPackedSum* embed =
530                 static_cast<const op::EmbeddingBagPackedSum*>(&node);
531             auto indicesType = embed->input(1).get_element_type();
532
533             if (indicesType == element::u64 || indicesType == element::i64)
534             {
535                 reference::embeddingBagPackedSum<T, size_t>(
536                     args[0]->get_data_ptr<const T>(),
537                     args[1]->get_data_ptr<const size_t>(),
538                     args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
539                     out[0]->get_data_ptr<T>(),
540                     embed->get_input_shape(1),
541                     embed->get_shape());
542             }
543             else if (indicesType == element::u32 || indicesType == element::i32)
544             {
545                 reference::embeddingBagPackedSum<T, unsigned>(
546                     args[0]->get_data_ptr<const T>(),
547                     args[1]->get_data_ptr<const unsigned>(),
548                     args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
549                     out[0]->get_data_ptr<T>(),
550                     embed->get_input_shape(1),
551                     embed->get_shape());
552             }
553             else
554             {
555                 throw ngraph_error(std::string("Unsupported index type ") +
556                                    indicesType.c_type_string() +
557                                    std::string(" in EmbeddingBagPackedSum"));
558             }
559             break;
560         }
561         case OP_TYPEID::EmbeddingSegmentsSum_v3:
562         {
563             const op::EmbeddingSegmentsSum* embed =
564                 static_cast<const op::EmbeddingSegmentsSum*>(&node);
565             auto indicesType = embed->input(1).get_element_type();
566             size_t indices_num = shape_size(embed->get_input_shape(1));
567
568             if (indicesType == element::u64 || indicesType == element::i64)
569             {
570                 reference::embeddingSegmentsSum<T, size_t>(
571                     args[0]->get_data_ptr<const T>(),
572                     args[1]->get_data_ptr<const size_t>(),
573                     args[2]->get_data_ptr<const size_t>(),
574                     args.size() > 4 ? args[4]->get_data_ptr<const size_t>() : nullptr,
575                     args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
576                     out[0]->get_data_ptr<T>(),
577                     embed->get_input_shape(0),
578                     embed->get_input_shape(1),
579                     embed->get_shape());
580             }
581             else if (indicesType == element::u32 || indicesType == element::i32)
582             {
583                 reference::embeddingSegmentsSum<T, unsigned>(
584                     args[0]->get_data_ptr<const T>(),
585                     args[1]->get_data_ptr<const unsigned>(),
586                     args[2]->get_data_ptr<const unsigned>(),
587                     args.size() > 4 ? args[4]->get_data_ptr<const unsigned>() : nullptr,
588                     args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
589                     out[0]->get_data_ptr<T>(),
590                     embed->get_input_shape(0),
591                     embed->get_input_shape(1),
592                     embed->get_shape());
593             }
594             else
595             {
596                 throw ngraph_error(std::string("Unsupported index type ") +
597                                    indicesType.c_type_string() +
598                                    std::string(" in EmbeddingSegmentsSum"));
599             }
600             break;
601         }
602         case OP_TYPEID::Erf:
603         {
604             size_t element_count = shape_size(node.get_output_shape(0));
605             reference::erf<T>(
606                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
607             break;
608         }
609         case OP_TYPEID::ExtractImagePatches_v3:
610         {
611             const op::ExtractImagePatches* extImgPatches =
612                 static_cast<const op::ExtractImagePatches*>(&node);
613             reference::extractImagePatches<T, size_t>(extImgPatches,
614                                                       args[0]->get_data_ptr<const T>(),
615                                                       out[0]->get_data_ptr<T>(),
616                                                       extImgPatches->get_input_shape(0),
617                                                       extImgPatches->get_shape());
618             break;
619         }
620         case OP_TYPEID::Exp:
621         {
622             size_t element_count = shape_size(node.get_output_shape(0));
623             reference::exp<T>(
624                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
625             break;
626         }
#ifdef INTERPRETER_USE_HYBRID
        case OP_TYPEID::FunctionCall:
        {
            // Hand execution over to the backend/executable stored on the FunctionCall node.
            // Every host tensor is re-exposed to that backend through create_tensor(), using
            // the host tensor's element type, shape and data pointer.
            const auto* call_op = static_cast<const runtime::hybrid::op::FunctionCall*>(&node);
            auto target_backend = call_op->get_backend();
            auto target_executable = call_op->get_executable();

            const auto wrap = [&target_backend](const std::shared_ptr<HostTensor>& host) {
                return target_backend->create_tensor(
                    host->get_element_type(), host->get_shape(), host->get_data_ptr());
            };

            // Outputs are wrapped before inputs, matching the order the executable is called with.
            std::vector<std::shared_ptr<Tensor>> wrapped_outputs;
            for (const std::shared_ptr<HostTensor>& host : out)
            {
                wrapped_outputs.push_back(wrap(host));
            }
            std::vector<std::shared_ptr<Tensor>> wrapped_inputs;
            for (const std::shared_ptr<HostTensor>& host : args)
            {
                wrapped_inputs.push_back(wrap(host));
            }
            target_executable->call(wrapped_outputs, wrapped_inputs);
            break;
        }
#endif
652         case OP_TYPEID::Floor:
653         {
654             size_t element_count = shape_size(node.get_output_shape(0));
655             reference::floor<T>(
656                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
657             break;
658         }
659         case OP_TYPEID::GatherND_v5:
660         {
661             const op::v5::GatherND* gatherNDNode = static_cast<const op::v5::GatherND*>(&node);
662             if (node.get_input_element_type(1) == element::i64)
663             {
664                 reference::gather_nd<T, int64_t>(args[0]->get_data_ptr<T>(),
665                                                  args[1]->get_data_ptr<int64_t>(),
666                                                  out[0]->get_data_ptr<T>(),
667                                                  node.get_input_shape(0),
668                                                  node.get_input_shape(1),
669                                                  node.get_output_shape(0),
670                                                  gatherNDNode->get_batch_dims());
671             }
672             else if (node.get_input_element_type(1) == element::i32)
673             {
674                 reference::gather_nd<T, int32_t>(args[0]->get_data_ptr<T>(),
675                                                  args[1]->get_data_ptr<int32_t>(),
676                                                  out[0]->get_data_ptr<T>(),
677                                                  node.get_input_shape(0),
678                                                  node.get_input_shape(1),
679                                                  node.get_output_shape(0),
680                                                  gatherNDNode->get_batch_dims());
681             }
682             else
683             {
684                 throw ngraph_error("Unexpected type");
685             }
686             break;
687         }
688         case OP_TYPEID::GRUCell_v3:
689         {
690             const op::v3::GRUCell* gru_cell = static_cast<const op::v3::GRUCell*>(&node);
691             runtime::reference::gru_cell(args[0]->get_data_ptr<T>(),
692                                          args[0]->get_shape(),
693                                          args[1]->get_data_ptr<T>(),
694                                          args[1]->get_shape(),
695                                          args[2]->get_data_ptr<T>(),
696                                          args[2]->get_shape(),
697                                          args[3]->get_data_ptr<T>(),
698                                          args[3]->get_shape(),
699                                          args[4]->get_data_ptr<T>(),
700                                          args[4]->get_shape(),
701                                          out[0]->get_data_ptr<T>(),
702                                          gru_cell->get_activations()[0],
703                                          gru_cell->get_activations()[1],
704                                          gru_cell->get_clip(),
705                                          gru_cell->get_linear_before_reset());
706             break;
707         }
708         case OP_TYPEID::LSTMCell_v4:
709         {
710             const op::v4::LSTMCell* lstm_cell = static_cast<const op::v4::LSTMCell*>(&node);
711             runtime::reference::lstm_cell(args[0]->get_data_ptr<T>(),
712                                           args[0]->get_shape(),
713                                           args[1]->get_data_ptr<T>(),
714                                           args[1]->get_shape(),
715                                           args[2]->get_data_ptr<T>(),
716                                           args[2]->get_shape(),
717                                           args[3]->get_data_ptr<T>(),
718                                           args[3]->get_shape(),
719                                           args[4]->get_data_ptr<T>(),
720                                           args[4]->get_shape(),
721                                           args[5]->get_data_ptr<T>(),
722                                           args[5]->get_shape(),
723                                           out[0]->get_data_ptr<T>(),
724                                           out[1]->get_data_ptr<T>(),
725                                           lstm_cell->get_activations()[0],
726                                           lstm_cell->get_activations()[1],
727                                           lstm_cell->get_activations()[2],
728                                           lstm_cell->get_clip());
729             break;
730         }
731         case OP_TYPEID::RNNCell_v0:
732         {
733             const op::v0::RNNCell* rnn_cell = static_cast<const op::v0::RNNCell*>(&node);
734             runtime::reference::rnn_cell(args[0]->get_data_ptr<T>(),
735                                          args[0]->get_shape(),
736                                          args[1]->get_data_ptr<T>(),
737                                          args[1]->get_shape(),
738                                          args[2]->get_data_ptr<T>(),
739                                          args[2]->get_shape(),
740                                          args[3]->get_data_ptr<T>(),
741                                          args[3]->get_shape(),
742                                          args[4]->get_data_ptr<T>(),
743                                          args[4]->get_shape(),
744                                          out[0]->get_data_ptr<T>(),
745                                          rnn_cell->get_activations()[0],
746                                          rnn_cell->get_clip());
747             break;
748         }
749         case OP_TYPEID::LSTMSequence:
750         case OP_TYPEID::LSTMSequence_v5:
751         {
752             auto lstm_seq = static_cast<const op::v5::LSTMSequence*>(&node);
753             runtime::reference::lstm_sequence<T>(args[0]->get_data_ptr<char>(),
754                                                  args[0]->get_shape(),
755                                                  args[1]->get_data_ptr<char>(),
756                                                  args[1]->get_shape(),
757                                                  args[2]->get_data_ptr<char>(),
758                                                  args[2]->get_shape(),
759                                                  args[3]->get_data_ptr<char>(),
760                                                  args[3]->get_shape(),
761                                                  args[4]->get_data_ptr<char>(),
762                                                  args[4]->get_shape(),
763                                                  args[5]->get_data_ptr<char>(),
764                                                  args[5]->get_shape(),
765                                                  args[6]->get_data_ptr<char>(),
766                                                  args[6]->get_shape(),
767                                                  out[0]->get_data_ptr<char>(),
768                                                  out[1]->get_data_ptr<char>(),
769                                                  out[2]->get_data_ptr<char>(),
770                                                  lstm_seq->get_activations()[0],
771                                                  lstm_seq->get_activations()[1],
772                                                  lstm_seq->get_activations()[2],
773                                                  lstm_seq->get_clip(),
774                                                  lstm_seq->get_direction());
775             break;
776         }
777         case OP_TYPEID::GRUSequence_v5:
778         {
779             auto gru_seq = static_cast<const op::v5::GRUSequence*>(&node);
780             runtime::reference::gru_sequence<T>(args[0]->get_data_ptr<char>(),
781                                                 args[0]->get_shape(),
782                                                 args[1]->get_data_ptr<char>(),
783                                                 args[1]->get_shape(),
784                                                 args[2]->get_data_ptr<char>(),
785                                                 args[2]->get_shape(),
786                                                 args[3]->get_data_ptr<char>(),
787                                                 args[3]->get_shape(),
788                                                 args[4]->get_data_ptr<char>(),
789                                                 args[4]->get_shape(),
790                                                 args[5]->get_data_ptr<char>(),
791                                                 args[5]->get_shape(),
792                                                 out[0]->get_data_ptr<char>(),
793                                                 out[1]->get_data_ptr<char>(),
794                                                 gru_seq->get_activations()[0],
795                                                 gru_seq->get_activations()[1],
796                                                 gru_seq->get_clip(),
797                                                 gru_seq->get_direction(),
798                                                 gru_seq->get_linear_before_reset());
799             break;
800         }
801         case OP_TYPEID::RNNSequence_v5:
802         {
803             auto rnn_seq = static_cast<const op::v5::RNNSequence*>(&node);
804             runtime::reference::rnn_sequence<T>(args[0]->get_data_ptr<char>(),
805                                                 args[0]->get_shape(),
806                                                 args[1]->get_data_ptr<char>(),
807                                                 args[1]->get_shape(),
808                                                 args[2]->get_data_ptr<char>(),
809                                                 args[2]->get_shape(),
810                                                 args[3]->get_data_ptr<char>(),
811                                                 args[3]->get_shape(),
812                                                 args[4]->get_data_ptr<char>(),
813                                                 args[4]->get_shape(),
814                                                 args[5]->get_data_ptr<char>(),
815                                                 args[5]->get_shape(),
816                                                 out[0]->get_data_ptr<char>(),
817                                                 out[1]->get_data_ptr<char>(),
818                                                 rnn_seq->get_activations()[0],
819                                                 rnn_seq->get_clip(),
820                                                 rnn_seq->get_direction());
821             break;
822         }
823         case OP_TYPEID::Log:
824         {
825             size_t element_count = shape_size(node.get_output_shape(0));
826             reference::log<T>(
827                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
828             break;
829         }
830         case OP_TYPEID::LogSoftmax_v5:
831         {
832             const op::v5::LogSoftmax* log_softmax = static_cast<const op::v5::LogSoftmax*>(&node);
833             int64_t i_axis = log_softmax->get_axis();
834             if (i_axis < 0)
835             {
836                 i_axis += args[0]->get_partial_shape().rank().get_length();
837             }
838             reference::log_softmax<T>(args[0]->get_data_ptr<const T>(),
839                                       out[0]->get_data_ptr<T>(),
840                                       node.get_output_shape(0),
841                                       AxisSet{(size_t)i_axis});
842             break;
843         }
844         case OP_TYPEID::LRN:
845         {
846             const op::LRN* lrn = static_cast<const op::LRN*>(&node);
847             reference::lrn<T>(args[0]->get_data_ptr<const T>(),
848                               lrn->get_reduction_axes(),
849                               out[0]->get_data_ptr<T>(),
850                               node.get_input_shape(0),
851                               lrn->get_alpha(),
852                               lrn->get_beta(),
853                               lrn->get_bias(),
854                               lrn->get_nsize());
855             break;
856         }
857         case OP_TYPEID::Negative:
858         {
859             size_t element_count = shape_size(node.get_output_shape(0));
860             reference::negate<T>(
861                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
862             break;
863         }
864         case OP_TYPEID::LogicalNot_v1:
865         case OP_TYPEID::Not:
866         {
867             size_t element_count = shape_size(node.get_output_shape(0));
868             reference::logical_not(
869                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
870             break;
871         }
872         case OP_TYPEID::OneHot_v1:
873         {
874             const op::v1::OneHot* oh = static_cast<const op::v1::OneHot*>(&node);
875             T on_value = args[2]->get_data_ptr<T>()[0];
876             T off_value = args[3]->get_data_ptr<T>()[0];
877
878             switch (args[0]->get_element_type())
879             {
880             case element::Type_t::i8:
881                 reference::one_hot(args[0]->get_data_ptr<const int8_t>(),
882                                    out[0]->get_data_ptr<T>(),
883                                    node.get_input_shape(0),
884                                    node.get_output_shape(0),
885                                    oh->get_axis(),
886                                    on_value,
887                                    off_value);
888                 break;
889             case element::Type_t::i16:
890                 reference::one_hot(args[0]->get_data_ptr<const int16_t>(),
891                                    out[0]->get_data_ptr<T>(),
892                                    node.get_input_shape(0),
893                                    node.get_output_shape(0),
894                                    oh->get_axis(),
895                                    on_value,
896                                    off_value);
897                 break;
898             case element::Type_t::i32:
899                 reference::one_hot(args[0]->get_data_ptr<const int32_t>(),
900                                    out[0]->get_data_ptr<T>(),
901                                    node.get_input_shape(0),
902                                    node.get_output_shape(0),
903                                    oh->get_axis(),
904                                    on_value,
905                                    off_value);
906                 break;
907             case element::Type_t::i64:
908                 reference::one_hot(args[0]->get_data_ptr<const int64_t>(),
909                                    out[0]->get_data_ptr<T>(),
910                                    node.get_input_shape(0),
911                                    node.get_output_shape(0),
912                                    oh->get_axis(),
913                                    on_value,
914                                    off_value);
915                 break;
916             case element::Type_t::u8:
917                 reference::one_hot(args[0]->get_data_ptr<const uint8_t>(),
918                                    out[0]->get_data_ptr<T>(),
919                                    node.get_input_shape(0),
920                                    node.get_output_shape(0),
921                                    oh->get_axis(),
922                                    on_value,
923                                    off_value);
924                 break;
925             case element::Type_t::u16:
926                 reference::one_hot(args[0]->get_data_ptr<const uint16_t>(),
927                                    out[0]->get_data_ptr<T>(),
928                                    node.get_input_shape(0),
929                                    node.get_output_shape(0),
930                                    oh->get_axis(),
931                                    on_value,
932                                    off_value);
933                 break;
934             case element::Type_t::u32:
935                 reference::one_hot(args[0]->get_data_ptr<const uint32_t>(),
936                                    out[0]->get_data_ptr<T>(),
937                                    node.get_input_shape(0),
938                                    node.get_output_shape(0),
939                                    oh->get_axis(),
940                                    on_value,
941                                    off_value);
942                 break;
943             case element::Type_t::u64:
944                 reference::one_hot(args[0]->get_data_ptr<const uint64_t>(),
945                                    out[0]->get_data_ptr<T>(),
946                                    node.get_input_shape(0),
947                                    node.get_output_shape(0),
948                                    oh->get_axis(),
949                                    on_value,
950                                    off_value);
951                 break;
952             case element::Type_t::undefined:
953             case element::Type_t::dynamic:
954             case element::Type_t::u1:
955             case element::Type_t::boolean:
956             case element::Type_t::bf16:
957             case element::Type_t::f16:
958             case element::Type_t::f32:
959             case element::Type_t::f64:
960             default: NGRAPH_CHECK(false, "Indices input element type must be integer");
961             }
962
963             break;
964         }
965         case OP_TYPEID::Parameter: break;
966         case OP_TYPEID::PriorBox:
967         {
968             const op::PriorBox* pbox = static_cast<const op::PriorBox*>(&node);
969             runtime::reference::prior_box<T>(args[0]->get_data_ptr<T>(),
970                                              args[1]->get_data_ptr<T>(),
971                                              out[0]->get_data_ptr<float>(),
972                                              out[0]->get_shape(),
973                                              pbox->get_attrs());
974             break;
975         }
976         case OP_TYPEID::ReorgYolo_v0:
977         {
978             const op::v0::ReorgYolo* reorg_yolo = static_cast<const op::v0::ReorgYolo*>(&node);
979             runtime::reference::reorg_yolo(args[0]->get_data_ptr<char>(),
980                                            out[0]->get_data_ptr<char>(),
981                                            args[0]->get_shape(),
982                                            reorg_yolo->get_strides().at(0),
983                                            args[0]->get_element_type().size());
984             break;
985         }
986         case OP_TYPEID::Quantize:
987         {
988             const op::Quantize* quantize = static_cast<const op::Quantize*>(&node);
989             auto type = quantize->get_element_type();
990
991             if (type == element::u8)
992             {
993                 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
994                                        args[1]->get_data_ptr<const T>(),
995                                        args[2]->get_data_ptr<const uint8_t>(),
996                                        out[0]->get_data_ptr<uint8_t>(),
997                                        node.get_input_shape(0),
998                                        node.get_input_shape(1),
999                                        quantize->get_axes(),
1000                                        quantize->get_round_mode());
1001             }
1002             else if (type == element::i8)
1003             {
1004                 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
1005                                        args[1]->get_data_ptr<const T>(),
1006                                        args[2]->get_data_ptr<const int8_t>(),
1007                                        out[0]->get_data_ptr<int8_t>(),
1008                                        node.get_input_shape(0),
1009                                        node.get_input_shape(1),
1010                                        quantize->get_axes(),
1011                                        quantize->get_round_mode());
1012             }
1013             else if (type == element::i32)
1014             {
1015                 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
1016                                        args[1]->get_data_ptr<const T>(),
1017                                        args[2]->get_data_ptr<const int32_t>(),
1018                                        out[0]->get_data_ptr<int32_t>(),
1019                                        node.get_input_shape(0),
1020                                        node.get_input_shape(1),
1021                                        quantize->get_axes(),
1022                                        quantize->get_round_mode());
1023             }
1024             else
1025             {
1026                 std::stringstream ss;
1027                 ss << "unsupported element type " << type << " op Quantize";
1028                 throw std::runtime_error(ss.str());
1029             }
1030
1031             break;
1032         }
1033
1034         case OP_TYPEID::QuantizedConvolution:
1035         {
1036             const op::QuantizedConvolution* qc =
1037                 static_cast<const op::QuantizedConvolution*>(&node);
1038
1039             auto input_element_type = qc->get_input_element_type(0);
1040             auto filter_element_type = qc->get_input_element_type(1);
1041             auto output_element_type = qc->get_output_element_type(0);
1042
1043             if (input_element_type == element::u8 && filter_element_type == element::i8 &&
1044                 output_element_type == element::i8)
1045             {
1046                 reference::convolution<uint8_t, int8_t, int8_t, int32_t>(
1047                     args[0]->get_data_ptr<const uint8_t>(),
1048                     args[1]->get_data_ptr<const int8_t>(),
1049                     out[0]->get_data_ptr<int8_t>(),
1050                     node.get_input_shape(0),
1051                     node.get_input_shape(1),
1052                     node.get_output_shape(0),
1053                     qc->get_window_movement_strides(),
1054                     qc->get_window_dilation_strides(),
1055                     qc->get_padding_below(),
1056                     qc->get_padding_above(),
1057                     qc->get_data_dilation_strides(),
1058                     args[2]->get_data_ptr<const float>(),
1059                     args[3]->get_data_ptr<const uint8_t>(),
1060                     args[4]->get_data_ptr<const float>(),
1061                     args[5]->get_data_ptr<const int8_t>(),
1062                     args[6]->get_data_ptr<const float>(),
1063                     args[7]->get_data_ptr<const int8_t>());
1064             }
1065             else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
1066                      output_element_type == element::u8)
1067             {
1068                 reference::convolution<uint8_t, uint8_t, uint8_t, int32_t>(
1069                     args[0]->get_data_ptr<const uint8_t>(),
1070                     args[1]->get_data_ptr<const uint8_t>(),
1071                     out[0]->get_data_ptr<uint8_t>(),
1072                     node.get_input_shape(0),
1073                     node.get_input_shape(1),
1074                     node.get_output_shape(0),
1075                     qc->get_window_movement_strides(),
1076                     qc->get_window_dilation_strides(),
1077                     qc->get_padding_below(),
1078                     qc->get_padding_above(),
1079                     qc->get_data_dilation_strides(),
1080                     args[2]->get_data_ptr<const float>(),
1081                     args[3]->get_data_ptr<const uint8_t>(),
1082                     args[4]->get_data_ptr<const float>(),
1083                     args[5]->get_data_ptr<const uint8_t>(),
1084                     args[6]->get_data_ptr<const float>(),
1085                     args[7]->get_data_ptr<const uint8_t>());
1086             }
1087             else if (input_element_type == element::u8 && filter_element_type == element::i8 &&
1088                      output_element_type == element::i32)
1089             {
1090                 reference::convolution<uint8_t, int8_t, int32_t, int32_t>(
1091                     args[0]->get_data_ptr<const uint8_t>(),
1092                     args[1]->get_data_ptr<const int8_t>(),
1093                     out[0]->get_data_ptr<int32_t>(),
1094                     node.get_input_shape(0),
1095                     node.get_input_shape(1),
1096                     node.get_output_shape(0),
1097                     qc->get_window_movement_strides(),
1098                     qc->get_window_dilation_strides(),
1099                     qc->get_padding_below(),
1100                     qc->get_padding_above(),
1101                     qc->get_data_dilation_strides(),
1102                     args[2]->get_data_ptr<const float>(),
1103                     args[3]->get_data_ptr<const uint8_t>(),
1104                     args[4]->get_data_ptr<const float>(),
1105                     args[5]->get_data_ptr<const int8_t>(),
1106                     args[6]->get_data_ptr<const float>(),
1107                     args[7]->get_data_ptr<const int32_t>());
1108             }
1109             else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
1110                      output_element_type == element::i32)
1111             {
1112                 reference::convolution<uint8_t, uint8_t, int32_t, int32_t>(
1113                     args[0]->get_data_ptr<const uint8_t>(),
1114                     args[1]->get_data_ptr<const uint8_t>(),
1115                     out[0]->get_data_ptr<int32_t>(),
1116                     node.get_input_shape(0),
1117                     node.get_input_shape(1),
1118                     node.get_output_shape(0),
1119                     qc->get_window_movement_strides(),
1120                     qc->get_window_dilation_strides(),
1121                     qc->get_padding_below(),
1122                     qc->get_padding_above(),
1123                     qc->get_data_dilation_strides(),
1124                     args[2]->get_data_ptr<const float>(),
1125                     args[3]->get_data_ptr<const uint8_t>(),
1126                     args[4]->get_data_ptr<const float>(),
1127                     args[5]->get_data_ptr<const uint8_t>(),
1128                     args[6]->get_data_ptr<const float>(),
1129                     args[7]->get_data_ptr<const int32_t>());
1130             }
1131             else
1132             {
1133                 std::stringstream ss;
1134                 ss << "unsupported element type";
1135                 throw std::runtime_error(ss.str());
1136             }
1137
1138             break;
1139         }
1140
1141         case OP_TYPEID::QuantizedDot:
1142         {
1143             const op::QuantizedDot* qd = static_cast<const op::QuantizedDot*>(&node);
1144
1145             auto input0_element_type = qd->get_input_element_type(0);
1146             auto input1_element_type = qd->get_input_element_type(1);
1147             auto output_element_type = qd->get_output_element_type(0);
1148
1149             if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
1150                 output_element_type == element::i8)
1151             {
1152                 reference::dot<uint8_t, int8_t, int8_t, int32_t>(
1153                     args[0]->get_data_ptr<const uint8_t>(),
1154                     args[1]->get_data_ptr<const int8_t>(),
1155                     out[0]->get_data_ptr<int8_t>(),
1156                     node.get_input_shape(0),
1157                     node.get_input_shape(1),
1158                     node.get_output_shape(0),
1159                     1,
1160                     args[2]->get_data_ptr<const float>(),
1161                     args[3]->get_data_ptr<const uint8_t>(),
1162                     args[4]->get_data_ptr<const float>(),
1163                     args[5]->get_data_ptr<const int8_t>(),
1164                     args[6]->get_data_ptr<const float>(),
1165                     args[7]->get_data_ptr<const int8_t>());
1166             }
1167             else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
1168                      output_element_type == element::u8)
1169             {
1170                 reference::dot<uint8_t, uint8_t, uint8_t, int32_t>(
1171                     args[0]->get_data_ptr<const uint8_t>(),
1172                     args[1]->get_data_ptr<const uint8_t>(),
1173                     out[0]->get_data_ptr<uint8_t>(),
1174                     node.get_input_shape(0),
1175                     node.get_input_shape(1),
1176                     node.get_output_shape(0),
1177                     1,
1178                     args[2]->get_data_ptr<const float>(),
1179                     args[3]->get_data_ptr<const uint8_t>(),
1180                     args[4]->get_data_ptr<const float>(),
1181                     args[5]->get_data_ptr<const uint8_t>(),
1182                     args[6]->get_data_ptr<const float>(),
1183                     args[7]->get_data_ptr<const uint8_t>());
1184             }
1185             else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
1186                      output_element_type == element::i32)
1187             {
1188                 reference::dot<uint8_t, uint8_t, int32_t, int32_t>(
1189                     args[0]->get_data_ptr<const uint8_t>(),
1190                     args[1]->get_data_ptr<const uint8_t>(),
1191                     out[0]->get_data_ptr<int32_t>(),
1192                     node.get_input_shape(0),
1193                     node.get_input_shape(1),
1194                     node.get_output_shape(0),
1195                     1,
1196                     args[2]->get_data_ptr<const float>(),
1197                     args[3]->get_data_ptr<const uint8_t>(),
1198                     args[4]->get_data_ptr<const float>(),
1199                     args[5]->get_data_ptr<const uint8_t>(),
1200                     args[6]->get_data_ptr<const float>(),
1201                     args[7]->get_data_ptr<const int32_t>());
1202             }
1203             else if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
1204                      output_element_type == element::i32)
1205             {
1206                 reference::dot<uint8_t, int8_t, int32_t, int32_t>(
1207                     args[0]->get_data_ptr<const uint8_t>(),
1208                     args[1]->get_data_ptr<const int8_t>(),
1209                     out[0]->get_data_ptr<int32_t>(),
1210                     node.get_input_shape(0),
1211                     node.get_input_shape(1),
1212                     node.get_output_shape(0),
1213                     1,
1214                     args[2]->get_data_ptr<const float>(),
1215                     args[3]->get_data_ptr<const uint8_t>(),
1216                     args[4]->get_data_ptr<const float>(),
1217                     args[5]->get_data_ptr<const int8_t>(),
1218                     args[6]->get_data_ptr<const float>(),
1219                     args[7]->get_data_ptr<const int32_t>());
1220             }
1221             else
1222             {
1223                 std::stringstream ss;
1224                 ss << "unsupported element type";
1225                 throw std::runtime_error(ss.str());
1226             }
1227
1228             break;
1229         }
1230         case OP_TYPEID::RegionYolo_v0:
1231         {
1232             const op::RegionYolo* region_yolo = static_cast<const op::RegionYolo*>(&node);
1233             reference::region_yolo<T>(args[0]->get_data_ptr<const T>(),
1234                                       out[0]->get_data_ptr<T>(),
1235                                       args[0]->get_shape(),
1236                                       region_yolo->get_num_coords(),
1237                                       region_yolo->get_num_classes(),
1238                                       region_yolo->get_num_regions(),
1239                                       region_yolo->get_do_softmax(),
1240                                       region_yolo->get_mask());
1241             break;
1242         }
1243         case OP_TYPEID::Relu:
1244         {
1245             size_t element_count = shape_size(node.get_output_shape(0));
1246             reference::relu<T>(
1247                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1248             break;
1249         }
1250         case OP_TYPEID::ReplaceSlice:
1251         {
1252             const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
1253             reference::replace_slice<T>(args[0]->get_data_ptr<const T>(),
1254                                         args[1]->get_data_ptr<const T>(),
1255                                         out[0]->get_data_ptr<T>(),
1256                                         node.get_input_shape(1),
1257                                         slice->get_lower_bounds(),
1258                                         slice->get_upper_bounds(),
1259                                         slice->get_strides(),
1260                                         node.get_output_shape(0));
1261             break;
1262         }
1263         case OP_TYPEID::Reverse:
1264         {
1265             const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
1266             reference::reverse(args[0]->get_data_ptr<const char>(),
1267                                out[0]->get_data_ptr<char>(),
1268                                node.get_input_shape(0),
1269                                node.get_output_shape(0),
1270                                reverse->get_reversed_axes(),
1271                                args[0]->get_element_type().size());
1272             break;
1273         }
1274         case OP_TYPEID::ReverseSequence:
1275         {
1276             const op::ReverseSequence* reverse = static_cast<const op::ReverseSequence*>(&node);
1277
1278             if (node.get_input_element_type(1) == element::i32)
1279             {
1280                 reference::reverse_sequence<T, int32_t>(args[0]->get_data_ptr<const T>(),
1281                                                         out[0]->get_data_ptr<T>(),
1282                                                         node.get_input_shape(0),
1283                                                         reverse->get_batch_axis(),
1284                                                         reverse->get_sequence_axis(),
1285                                                         args[1]->get_data_ptr<const int32_t>());
1286             }
1287             else
1288             {
1289                 throw ngraph_error("only int32 indices are supported");
1290             }
1291             break;
1292         }
1293         case OP_TYPEID::Round:
1294         {
1295             size_t element_count = shape_size(node.get_output_shape(0));
1296             reference::round<T>(args[0]->get_data_ptr<const T>(),
1297                                 out[0]->get_data_ptr<T>(),
1298                                 element_count,
1299                                 op::v5::Round::RoundMode::HALF_TO_EVEN);
1300             break;
1301         }
1302         case OP_TYPEID::Select:
1303         {
1304             size_t element_count = shape_size(node.get_output_shape(0));
1305             reference::select<T>(args[0]->get_data_ptr<const char>(),
1306                                  args[1]->get_data_ptr<const T>(),
1307                                  args[2]->get_data_ptr<const T>(),
1308                                  out[0]->get_data_ptr<T>(),
1309                                  element_count);
1310             break;
1311         }
1312         case OP_TYPEID::Sigmoid:
1313         {
1314             size_t element_count = shape_size(node.get_output_shape(0));
1315             reference::sigmoid<T>(
1316                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1317             break;
1318         }
1319         case OP_TYPEID::Sign:
1320         {
1321             size_t element_count = shape_size(node.get_output_shape(0));
1322             reference::sign<T>(
1323                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1324             break;
1325         }
1326         case OP_TYPEID::Sin:
1327         {
1328             size_t element_count = shape_size(node.get_output_shape(0));
1329             reference::sin<T>(
1330                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1331             break;
1332         }
1333         case OP_TYPEID::Sinh:
1334         {
1335             size_t element_count = shape_size(node.get_output_shape(0));
1336             reference::sinh<T>(
1337                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1338             break;
1339         }
1340         case OP_TYPEID::Sqrt:
1341         {
1342             size_t element_count = shape_size(node.get_output_shape(0));
1343             reference::sqrt<T>(
1344                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1345             break;
1346         }
1347         case OP_TYPEID::Tan:
1348         {
1349             size_t element_count = shape_size(node.get_output_shape(0));
1350             reference::tan<T>(
1351                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1352             break;
1353         }
1354         case OP_TYPEID::Tanh:
1355         {
1356             size_t element_count = shape_size(node.get_output_shape(0));
1357             reference::tanh<T>(
1358                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1359             break;
1360         }
1361         case OP_TYPEID::TopK:
1362         {
1363             const op::TopK* topk = static_cast<const op::TopK*>(&node);
1364             if (node.get_output_element_type(0) == element::i64)
1365             {
1366                 reference::topk<T, int64_t>(args[0]->get_data_ptr<const T>(),
1367                                             out[0]->get_data_ptr<int64_t>(),
1368                                             out[1]->get_data_ptr<T>(),
1369                                             node.get_input_shape(0),
1370                                             node.get_output_shape(0),
1371                                             topk->get_top_k_axis(),
1372                                             topk->get_k(),
1373                                             topk->get_compute_max(),
1374                                             topk->get_sort());
1375             }
1376             else if (node.get_output_element_type(0) == element::i32)
1377             {
1378                 reference::topk<T, int32_t>(args[0]->get_data_ptr<const T>(),
1379                                             out[0]->get_data_ptr<int32_t>(),
1380                                             out[1]->get_data_ptr<T>(),
1381                                             node.get_input_shape(0),
1382                                             node.get_output_shape(0),
1383                                             topk->get_top_k_axis(),
1384                                             topk->get_k(),
1385                                             topk->get_compute_max(),
1386                                             topk->get_sort());
1387             }
1388             else
1389             {
1390                 throw ngraph_error("Unexpected type");
1391             }
1392             break;
1393         }
1394         case OP_TYPEID::DetectionOutput_v0:
1395         {
1396             const op::DetectionOutput* detOut = static_cast<const op::DetectionOutput*>(&node);
1397             reference::referenceDetectionOutput<T> refDetOut(
1398                 detOut->get_attrs(), node.get_input_shape(0), node.get_input_shape(2));
1399             if (node.get_input_size() == 3)
1400             {
1401                 refDetOut.run(args[0]->get_data_ptr<const T>(),
1402                               args[1]->get_data_ptr<const T>(),
1403                               args[2]->get_data_ptr<const T>(),
1404                               nullptr,
1405                               nullptr,
1406                               out[0]->get_data_ptr<T>());
1407             }
1408             else if (node.get_input_size() == 5)
1409             {
1410                 refDetOut.run(args[0]->get_data_ptr<const T>(),
1411                               args[1]->get_data_ptr<const T>(),
1412                               args[2]->get_data_ptr<const T>(),
1413                               args[3]->get_data_ptr<const T>(),
1414                               args[4]->get_data_ptr<const T>(),
1415                               out[0]->get_data_ptr<T>());
1416             }
1417             else
1418             {
1419                 throw ngraph_error("DetectionOutput layer supports only 3 or 5 inputs");
1420             }
1421
1422             break;
1423         }
1424         case OP_TYPEID::ScatterNDUpdate_v3:
1425         {
1426             const op::ScatterNDUpdate* scatterNDUpd =
1427                 static_cast<const op::v3::ScatterNDUpdate*>(&node);
1428             auto idxType = scatterNDUpd->get_input_element_type(1);
1429             if (idxType == element::i32)
1430             {
1431                 reference::scatterNdUpdate<T, int32_t>(args[0]->get_data_ptr<const T>(),
1432                                                        args[1]->get_data_ptr<const int32_t>(),
1433                                                        args[2]->get_data_ptr<const T>(),
1434                                                        out[0]->get_data_ptr<T>(),
1435                                                        node.get_input_shape(0),
1436                                                        node.get_input_shape(1),
1437                                                        node.get_input_shape(2));
1438             }
1439             else if (idxType == element::i64)
1440             {
1441                 reference::scatterNdUpdate<T, int64_t>(args[0]->get_data_ptr<const T>(),
1442                                                        args[1]->get_data_ptr<const int64_t>(),
1443                                                        args[2]->get_data_ptr<const T>(),
1444                                                        out[0]->get_data_ptr<T>(),
1445                                                        node.get_input_shape(0),
1446                                                        node.get_input_shape(1),
1447                                                        node.get_input_shape(2));
1448             }
1449             else
1450             {
1451                 throw ngraph_error(
1452                     "ScatterNDUpdate layer support only i32 and i64 'indices' input precision!");
1453             }
1454
1455             break;
1456         }
1457         case OP_TYPEID::GatherTree_v1:
1458         {
1459             reference::gather_tree(args[0]->get_data_ptr<const char>(),
1460                                    args[1]->get_data_ptr<const char>(),
1461                                    args[2]->get_data_ptr<const char>(),
1462                                    args[3]->get_data_ptr<const char>(),
1463                                    out[0]->get_data_ptr<char>(),
1464                                    node.get_input_shape(0),
1465                                    node.get_input_shape(1),
1466                                    node.get_input_shape(2),
1467                                    node.get_input_shape(3),
1468                                    args[1]->get_element_type());
1469             break;
1470         }
1471         case OP_TYPEID::NormalizeL2:
1472         {
1473             const op::NormalizeL2* norm = static_cast<const op::NormalizeL2*>(&node);
1474             reference::normalize_l2<T>(args[0]->get_data_ptr<const T>(),
1475                                        out[0]->get_data_ptr<T>(),
1476                                        node.get_input_shape(0),
1477                                        norm->get_reduction_axes(),
1478                                        norm->get_eps(),
1479                                        norm->get_eps_mode());
1480             break;
1481         }
1482
1483         // Fused ops are not supported in the interpreter; they must be decomposed before execution.
1484         case OP_TYPEID::DepthToSpace:
1485         case OP_TYPEID::FakeQuantize:
1486         case OP_TYPEID::Gather:
1487         case OP_TYPEID::Gelu:
1488         case OP_TYPEID::GRN:
1489         case OP_TYPEID::GroupConvolution:
1490         case OP_TYPEID::GroupConvolutionBackpropData:
1491         case OP_TYPEID::HardSigmoid:
1492         case OP_TYPEID::Interpolate:
1493         case OP_TYPEID::MVN:
1494         case OP_TYPEID::PRelu:
1495         case OP_TYPEID::ScatterUpdate_v3:
1496         case OP_TYPEID::Selu:
1497         case OP_TYPEID::ShuffleChannels:
1498         case OP_TYPEID::SpaceToDepth:
1499         case OP_TYPEID::Split:
1500         case OP_TYPEID::SquaredDifference:
1501         case OP_TYPEID::StopGradient:
1502         case OP_TYPEID::TensorIterator:
1503         case OP_TYPEID::Tile:
1504         case OP_TYPEID::UnknownOp:
1505             throw unsupported_op("Unsupported op '" + node.description() + "'");
1506         case OP_TYPEID::Add:
1507         case OP_TYPEID::Broadcast:
1508         case OP_TYPEID::Clamp:
1509         case OP_TYPEID::Concat:
1510         case OP_TYPEID::Constant:
1511         case OP_TYPEID::Divide:
1512         case OP_TYPEID::Equal:
1513         case OP_TYPEID::Greater:
1514         case OP_TYPEID::GreaterEq:
1515         case OP_TYPEID::Less:
1516         case OP_TYPEID::LessEq:
1517         case OP_TYPEID::LessEqual_v1:
1518         case OP_TYPEID::LogicalAnd_v1:
1519         case OP_TYPEID::LogicalOr_v1:
1520         case OP_TYPEID::LogicalXor_v1:
1521         case OP_TYPEID::MatMul:
1522         case OP_TYPEID::Max:
1523         case OP_TYPEID::Maximum:
1524         case OP_TYPEID::Min:
1525         case OP_TYPEID::Minimum:
1526         case OP_TYPEID::Multiply:
1527         case OP_TYPEID::NonZero_v3:
1528         case OP_TYPEID::NotEqual:
1529         case OP_TYPEID::Or:
1530         case OP_TYPEID::Power:
1531         case OP_TYPEID::Range:
1532         case OP_TYPEID::Reshape:
1533         case OP_TYPEID::Result:
1534         case OP_TYPEID::Round_v5:
1535         case OP_TYPEID::ShapeOf_v3:
1536         case OP_TYPEID::ShapeOf:
1537         case OP_TYPEID::Softmax:
1538         case OP_TYPEID::Squeeze:
1539         case OP_TYPEID::Sum:
1540         case OP_TYPEID::Subtract:
1541         case OP_TYPEID::Unsqueeze:
1542         case OP_TYPEID::Xor:
1543         case OP_TYPEID::Slice:
1544             // These ops are handled by op evaluators so nothing to do
1545             break;
1546 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
1547 #pragma GCC diagnostic pop
1548 #endif
1549         }
1550     }
1551 };
1552
1553 NGRAPH_SUPPRESS_DEPRECATED_END