Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / test / runtime / interpreter / int_executable.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include <initializer_list>
20 #include <iostream>
21 #include <memory>
22 #include <sstream>
23 #include <string>
24 #include <vector>
25
26 #include <ngraph/runtime/host_tensor.hpp>
27 #include "backend.hpp"
28 #include "int_backend_visibility.hpp"
29 #include "ngraph/ops.hpp"
30 #include "ngraph/runtime/aligned_buffer.hpp"
31 #include "ngraph/runtime/reference/abs.hpp"
32 #include "ngraph/runtime/reference/acos.hpp"
33 #include "ngraph/runtime/reference/any.hpp"
34 #include "ngraph/runtime/reference/asin.hpp"
35 #include "ngraph/runtime/reference/atan.hpp"
36 #include "ngraph/runtime/reference/atan2.hpp"
37 #include "ngraph/runtime/reference/avg_pool.hpp"
38 #include "ngraph/runtime/reference/batch_norm.hpp"
39 #include "ngraph/runtime/reference/broadcast.hpp"
40 #include "ngraph/runtime/reference/ceiling.hpp"
41 #include "ngraph/runtime/reference/concat.hpp"
42 #include "ngraph/runtime/reference/constant.hpp"
43 #include "ngraph/runtime/reference/convert.hpp"
44 #include "ngraph/runtime/reference/convolution.hpp"
45 #include "ngraph/runtime/reference/cos.hpp"
46 #include "ngraph/runtime/reference/cosh.hpp"
47 #include "ngraph/runtime/reference/ctc_loss.hpp"
48 #include "ngraph/runtime/reference/cum_sum.hpp"
49 #include "ngraph/runtime/reference/dequantize.hpp"
50 #include "ngraph/runtime/reference/detection_output.hpp"
51 #include "ngraph/runtime/reference/dot.hpp"
52 #include "ngraph/runtime/reference/elu.hpp"
53 #include "ngraph/runtime/reference/embedding_bag_offsets_sum.hpp"
54 #include "ngraph/runtime/reference/embedding_bag_packed_sum.hpp"
55 #include "ngraph/runtime/reference/embedding_segments_sum.hpp"
56 #include "ngraph/runtime/reference/erf.hpp"
57 #include "ngraph/runtime/reference/exp.hpp"
58 #include "ngraph/runtime/reference/extract_image_patches.hpp"
59 #include "ngraph/runtime/reference/floor.hpp"
60 #include "ngraph/runtime/reference/gather.hpp"
61 #include "ngraph/runtime/reference/gather_nd.hpp"
62 #include "ngraph/runtime/reference/log.hpp"
63 #include "ngraph/runtime/reference/lrn.hpp"
64 #include "ngraph/runtime/reference/matmul.hpp"
65 #include "ngraph/runtime/reference/max.hpp"
66 #include "ngraph/runtime/reference/max_pool.hpp"
67 #include "ngraph/runtime/reference/min.hpp"
68 #include "ngraph/runtime/reference/negate.hpp"
69 #include "ngraph/runtime/reference/not.hpp"
70 #include "ngraph/runtime/reference/one_hot.hpp"
71 #include "ngraph/runtime/reference/pad.hpp"
72 #include "ngraph/runtime/reference/product.hpp"
73 #include "ngraph/runtime/reference/quantize.hpp"
74 #include "ngraph/runtime/reference/relu.hpp"
75 #include "ngraph/runtime/reference/replace_slice.hpp"
76 #include "ngraph/runtime/reference/reshape.hpp"
77 #include "ngraph/runtime/reference/result.hpp"
78 #include "ngraph/runtime/reference/reverse.hpp"
79 #include "ngraph/runtime/reference/reverse_sequence.hpp"
80 #include "ngraph/runtime/reference/round.hpp"
81 #include "ngraph/runtime/reference/scatter_nd_update.hpp"
82 #include "ngraph/runtime/reference/scatter_update.hpp"
83 #include "ngraph/runtime/reference/select.hpp"
84 #include "ngraph/runtime/reference/sigmoid.hpp"
85 #include "ngraph/runtime/reference/sign.hpp"
86 #include "ngraph/runtime/reference/sin.hpp"
87 #include "ngraph/runtime/reference/sinh.hpp"
88 #include "ngraph/runtime/reference/softmax.hpp"
89 #include "ngraph/runtime/reference/sqrt.hpp"
90 #include "ngraph/runtime/reference/sum.hpp"
91 #include "ngraph/runtime/reference/tan.hpp"
92 #include "ngraph/runtime/reference/tanh.hpp"
93 #include "ngraph/runtime/reference/topk.hpp"
94 #include "ngraph/runtime/tensor.hpp"
95 #include "op/avg_pool.hpp"
96 #include "op/convolution.hpp"
97 #include "op/group_conv.hpp"
98
99 NGRAPH_SUPPRESS_DEPRECATED_START
100
101 namespace ngraph
102 {
103     namespace runtime
104     {
105         namespace interpreter
106         {
107             class INTBackend;
108             class INTExecutable;
109
110             // This expands the op list in op_tbl.hpp into a list of enumerations that look like
111             // this:
112             // Abs,
113             // Acos,
114             // ...
115             enum class OP_TYPEID
116             {
117 #define NGRAPH_OP(NAME, NAMESPACE) ID_SUFFIX(NAME),
118 #include "opset_int_tbl.hpp"
119 #undef NGRAPH_OP
120                 UnknownOp
121             };
122         } // namespace interpreter
123     }     // namespace runtime
124 } // namespace ngraph
125
126 class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : public Executable
127 {
128     friend class INTBackend;
129
130 public:
131     INTExecutable(const std::shared_ptr<Function>& function,
132                   bool enable_performance_collection = false);
133
134     bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
135               const std::vector<std::shared_ptr<Tensor>>& inputs) override;
136
137     void set_nan_check(bool enable);
138
139     std::vector<PerformanceCounter> get_performance_data() const override;
140
141     std::shared_ptr<runtime::Tensor> create_input_tensor(size_t input_index) override;
142
143     std::shared_ptr<runtime::Tensor> create_output_tensor(size_t output_index) override;
144
145     std::vector<std::shared_ptr<runtime::Tensor>>
146         create_input_tensor(size_t input_index, size_t pipeline_depth) override;
147
148     std::vector<std::shared_ptr<runtime::Tensor>>
149         create_output_tensor(size_t output_index, size_t pipeline_depth) override;
150
151 protected:
152     std::shared_ptr<ngraph::op::Parameter> get_parameter(size_t index) const;
153     std::shared_ptr<ngraph::op::Result> get_result(size_t index) const;
154     int get_alignment() const { return 64; }
155     bool m_is_compiled = false;
156     bool m_nan_check_enabled = false;
157     bool m_performance_counters_enabled = false;
158     std::shared_ptr<Function> m_function;
159     std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
160     std::vector<std::shared_ptr<Node>> m_nodes;
161     std::set<std::string> m_unsupported_op_name_list;
162
163     static OP_TYPEID get_typeid(const Node& node);
164
165     static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
166                                   const Node* op = nullptr);
167
168     virtual void generate_calls(const element::Type& type,
169                                 const Node& op,
170                                 const std::vector<std::shared_ptr<HostTensor>>& outputs,
171                                 const std::vector<std::shared_ptr<HostTensor>>& inputs);
172
173     template <typename T>
174     void op_engine(const Node& node,
175                    const std::vector<std::shared_ptr<HostTensor>>& out,
176                    const std::vector<std::shared_ptr<HostTensor>>& args)
177     {
178 // We want to check that every OP_TYPEID enumeration is included in the list.
179 // These GCC flags enable compile-time checking so that if an enumeration
180 // is not in the list an error is generated.
181 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
182 #pragma GCC diagnostic push
183 #pragma GCC diagnostic error "-Wswitch"
184 #pragma GCC diagnostic error "-Wswitch-enum"
185 #endif
186         switch (get_typeid(node))
187         {
188         case OP_TYPEID::Abs:
189         {
190             size_t element_count = shape_size(node.get_output_shape(0));
191             reference::abs<T>(
192                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
193             break;
194         }
195         case OP_TYPEID::Acos:
196         {
197             size_t element_count = shape_size(node.get_output_shape(0));
198             reference::acos<T>(
199                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
200             break;
201         }
202         case OP_TYPEID::Any:
203         {
204             const op::Any* any = static_cast<const op::Any*>(&node);
205             reference::any(args[0]->get_data_ptr<const char>(),
206                            out[0]->get_data_ptr<char>(),
207                            node.get_input_shape(0),
208                            any->get_reduction_axes(),
209                            false);
210             break;
211         }
212         case OP_TYPEID::Asin:
213         {
214             size_t element_count = shape_size(node.get_output_shape(0));
215             reference::asin<T>(
216                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
217             break;
218         }
219         case OP_TYPEID::Atan:
220         {
221             size_t element_count = shape_size(node.get_output_shape(0));
222             reference::atan<T>(
223                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
224             break;
225         }
226         case OP_TYPEID::Elu:
227         {
228             const op::Elu* elu_node = static_cast<const op::Elu*>(&node);
229
230             size_t element_count = shape_size(node.get_output_shape(0));
231             reference::elu<T>(args[0]->get_data_ptr<const T>(),
232                               out[0]->get_data_ptr<T>(),
233                               element_count,
234                               elu_node->get_alpha());
235             break;
236         }
237         case OP_TYPEID::AvgPool:
238         {
239             const op::v0::AvgPool* avg_pool = static_cast<const op::v0::AvgPool*>(&node);
240
241             reference::avg_pool<T>(args[0]->get_data_ptr<const T>(),
242                                    out[0]->get_data_ptr<T>(),
243                                    node.get_input_shape(0),
244                                    node.get_output_shape(0),
245                                    avg_pool->get_window_shape(),
246                                    avg_pool->get_window_movement_strides(),
247                                    avg_pool->get_padding_below(),
248                                    avg_pool->get_padding_above(),
249                                    avg_pool->get_include_padding_in_avg_computation());
250             break;
251         }
252         case OP_TYPEID::BatchNormInference:
253         {
254             const ngraph::op::BatchNormInference* bn =
255                 static_cast<const ngraph::op::BatchNormInference*>(&node);
256             reference::batch_norm_inference<T>(bn->get_eps_value(),
257                                                args[0]->get_data_ptr<const T>(),
258                                                args[1]->get_data_ptr<const T>(),
259                                                args[2]->get_data_ptr<const T>(),
260                                                args[3]->get_data_ptr<const T>(),
261                                                args[4]->get_data_ptr<const T>(),
262                                                out[0]->get_data_ptr<T>(),
263                                                node.get_input_shape(2));
264             break;
265         }
266         case OP_TYPEID::BroadcastLike: break;
267         case OP_TYPEID::Ceiling:
268         {
269             size_t element_count = shape_size(node.get_output_shape(0));
270             reference::ceiling<T>(
271                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
272             break;
273         }
274         case OP_TYPEID::Convert:
275         {
276             // const op::Convert* c = static_cast<const op::Convert*>(&node);
277             element::Type type = node.get_element_type();
278             std::stringstream ss;
279             size_t element_count = shape_size(node.get_output_shape(0));
280             switch (type)
281             {
282             case element::Type_t::boolean:
283                 reference::convert_to_bool<T>(
284                     args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<char>(), element_count);
285                 break;
286             case element::Type_t::f32:
287                 reference::convert<T>(
288                     args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<float>(), element_count);
289                 break;
290             case element::Type_t::f64:
291                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
292                                       out[0]->get_data_ptr<double>(),
293                                       element_count);
294                 break;
295             case element::Type_t::i8:
296                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
297                                       out[0]->get_data_ptr<int8_t>(),
298                                       element_count);
299                 break;
300             case element::Type_t::i16:
301                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
302                                       out[0]->get_data_ptr<int16_t>(),
303                                       element_count);
304                 break;
305             case element::Type_t::i32:
306                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
307                                       out[0]->get_data_ptr<int32_t>(),
308                                       element_count);
309                 break;
310             case element::Type_t::i64:
311                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
312                                       out[0]->get_data_ptr<int64_t>(),
313                                       element_count);
314                 break;
315             case element::Type_t::u8:
316                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
317                                       out[0]->get_data_ptr<uint8_t>(),
318                                       element_count);
319                 break;
320             case element::Type_t::u16:
321                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
322                                       out[0]->get_data_ptr<uint16_t>(),
323                                       element_count);
324                 break;
325             case element::Type_t::u32:
326                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
327                                       out[0]->get_data_ptr<uint32_t>(),
328                                       element_count);
329                 break;
330             case element::Type_t::u64:
331                 reference::convert<T>(args[0]->get_data_ptr<const T>(),
332                                       out[0]->get_data_ptr<uint64_t>(),
333                                       element_count);
334                 break;
335             case element::Type_t::undefined:
336             case element::Type_t::dynamic:
337             case element::Type_t::u1:
338             case element::Type_t::bf16:
339             case element::Type_t::f16:
340                 ss << "unsupported element type " << type << " op Convert";
341                 throw std::runtime_error(ss.str());
342             }
343             break;
344         }
345         case OP_TYPEID::Convolution:
346         {
347             const op::v0::Convolution* c = static_cast<const op::v0::Convolution*>(&node);
348             reference::convolution<T>(args[0]->get_data_ptr<const T>(),
349                                       args[1]->get_data_ptr<const T>(),
350                                       out[0]->get_data_ptr<T>(),
351                                       node.get_input_shape(0),
352                                       node.get_input_shape(1),
353                                       node.get_output_shape(0),
354                                       c->get_window_movement_strides(),
355                                       c->get_window_dilation_strides(),
356                                       c->get_padding_below(),
357                                       c->get_padding_above(),
358                                       c->get_data_dilation_strides());
359
360             break;
361         }
362         case OP_TYPEID::ConvolutionBackpropData:
363         {
364             // Note that args[1] and args[0] are switched here from the usual order.
365             const op::v0::ConvolutionBackpropData* c =
366                 static_cast<const op::v0::ConvolutionBackpropData*>(&node);
367             reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
368                                                   args[0]->get_data_ptr<const T>(),
369                                                   out[0]->get_data_ptr<T>(),
370                                                   c->get_input_shape(1),
371                                                   c->get_input_shape(0),
372                                                   c->get_data_batch_shape(),
373                                                   c->get_data_dilation_strides_forward(),
374                                                   c->get_window_dilation_strides_forward(),
375                                                   c->compute_backward_delta_out_pad_below(),
376                                                   c->compute_backward_delta_out_pad_above(),
377                                                   c->get_window_movement_strides_forward());
378             break;
379         }
380         case OP_TYPEID::Cos:
381         {
382             size_t element_count = shape_size(node.get_output_shape(0));
383             reference::cos<T>(
384                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
385             break;
386         }
387         case OP_TYPEID::Cosh:
388         {
389             size_t element_count = shape_size(node.get_output_shape(0));
390             reference::cosh<T>(
391                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
392             break;
393         }
394         case OP_TYPEID::CTCLoss_v4:
395         {
396             const op::v4::CTCLoss* ctc_loss = static_cast<const op::v4::CTCLoss*>(&node);
397             auto t_int = node.get_input_element_type(1);
398             if (t_int == element::i32)
399             {
400                 reference::CTCLoss<T, int32_t>(
401                     args[0]->get_data_ptr<const T>(),
402                     ctc_loss->get_input_shape(0),
403                     args[1]->get_data_ptr<const int32_t>(),
404                     args[2]->get_data_ptr<const int32_t>(),
405                     args[3]->get_data_ptr<const int32_t>(),
406                     args.size() > 4 ? args[4]->get_data_ptr<const int32_t>() : nullptr,
407                     ctc_loss->get_preprocess_collapse_repeated(),
408                     ctc_loss->get_ctc_merge_repeated(),
409                     ctc_loss->get_unique(),
410                     out[0]->get_data_ptr<T>());
411             }
412             else if (t_int == element::i64)
413             {
414                 reference::CTCLoss<T, int64_t>(
415                     args[0]->get_data_ptr<const T>(),
416                     ctc_loss->get_input_shape(0),
417                     args[1]->get_data_ptr<const int64_t>(),
418                     args[2]->get_data_ptr<const int64_t>(),
419                     args[3]->get_data_ptr<const int64_t>(),
420                     args.size() > 4 ? args[4]->get_data_ptr<const int64_t>() : nullptr,
421                     ctc_loss->get_preprocess_collapse_repeated(),
422                     ctc_loss->get_ctc_merge_repeated(),
423                     ctc_loss->get_unique(),
424                     out[0]->get_data_ptr<T>());
425             }
426             break;
427         }
428         case OP_TYPEID::CumSum:
429         {
430             const op::CumSum* cumsum = static_cast<const op::CumSum*>(&node);
431             auto axis_et = node.get_input_element_type(1);
432             if (axis_et == element::i32)
433             {
434                 reference::cumsum<T, int32_t>(args[0]->get_data_ptr<const T>(),
435                                               args[1]->get_data_ptr<const int32_t>(),
436                                               out[0]->get_data_ptr<T>(),
437                                               node.get_input_shape(0),
438                                               cumsum->is_exclusive(),
439                                               cumsum->is_reverse());
440             }
441             else if (axis_et == element::i64)
442             {
443                 reference::cumsum<T, int64_t>(args[0]->get_data_ptr<const T>(),
444                                               args[1]->get_data_ptr<const int64_t>(),
445                                               out[0]->get_data_ptr<T>(),
446                                               node.get_input_shape(0),
447                                               cumsum->is_exclusive(),
448                                               cumsum->is_reverse());
449             }
450             break;
451         }
452         case OP_TYPEID::Dequantize:
453         {
454             const op::Dequantize* dequantize = static_cast<const op::Dequantize*>(&node);
455             auto type = dequantize->get_element_type();
456
457             if (type == element::f32)
458             {
459                 reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
460                                          args[1]->get_data_ptr<const float>(),
461                                          args[2]->get_data_ptr<const T>(),
462                                          out[0]->get_data_ptr<float>(),
463                                          node.get_input_shape(0),
464                                          node.get_input_shape(1),
465                                          dequantize->get_axes());
466             }
467             else if (type == element::f64)
468             {
469                 reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
470                                          args[1]->get_data_ptr<const double>(),
471                                          args[2]->get_data_ptr<const T>(),
472                                          out[0]->get_data_ptr<double>(),
473                                          node.get_input_shape(0),
474                                          node.get_input_shape(1),
475                                          dequantize->get_axes());
476             }
477             else
478             {
479                 std::stringstream ss;
480                 ss << "unsupported element type " << type << " op Dequantize";
481                 throw std::runtime_error(ss.str());
482             }
483
484             break;
485         }
486         case OP_TYPEID::Dot:
487         {
488             const op::Dot* dot = static_cast<const op::Dot*>(&node);
489
490             reference::dot(args[0]->get_data_ptr<const T>(),
491                            args[1]->get_data_ptr<const T>(),
492                            out[0]->get_data_ptr<T>(),
493                            node.get_input_shape(0),
494                            node.get_input_shape(1),
495                            node.get_output_shape(0),
496                            dot->get_reduction_axes_count());
497             break;
498         }
499         case OP_TYPEID::EmbeddingBagOffsetsSum_v3:
500         {
501             const op::EmbeddingBagOffsetsSum* embed =
502                 static_cast<const op::EmbeddingBagOffsetsSum*>(&node);
503             auto indicesType = embed->input(1).get_element_type();
504             size_t indices_num = shape_size(embed->get_input_shape(1));
505
506             if (indicesType == element::u64 || indicesType == element::i64)
507             {
508                 reference::embeddingBagOffsetsSum<T, size_t>(
509                     args[0]->get_data_ptr<const T>(),
510                     args[1]->get_data_ptr<const size_t>(),
511                     args[2]->get_data_ptr<const size_t>(),
512                     args.size() > 3 ? args[3]->get_data_ptr<const size_t>() : nullptr,
513                     args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
514                     out[0]->get_data_ptr<T>(),
515                     indices_num,
516                     embed->get_shape());
517             }
518             else if (indicesType == element::u32 || indicesType == element::i32)
519             {
520                 reference::embeddingBagOffsetsSum<T, unsigned>(
521                     args[0]->get_data_ptr<const T>(),
522                     args[1]->get_data_ptr<const unsigned>(),
523                     args[2]->get_data_ptr<const unsigned>(),
524                     args.size() > 3 ? args[3]->get_data_ptr<const unsigned>() : nullptr,
525                     args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
526                     out[0]->get_data_ptr<T>(),
527                     indices_num,
528                     embed->get_shape());
529             }
530             else
531             {
532                 throw ngraph_error(std::string("Unsupported index type ") +
533                                    indicesType.c_type_string() +
534                                    std::string(" in EmbeddingBagOffsetsSum"));
535             }
536             break;
537         }
538         case OP_TYPEID::EmbeddingBagPackedSum_v3:
539         {
540             const op::EmbeddingBagPackedSum* embed =
541                 static_cast<const op::EmbeddingBagPackedSum*>(&node);
542             auto indicesType = embed->input(1).get_element_type();
543
544             if (indicesType == element::u64 || indicesType == element::i64)
545             {
546                 reference::embeddingBagPackedSum<T, size_t>(
547                     args[0]->get_data_ptr<const T>(),
548                     args[1]->get_data_ptr<const size_t>(),
549                     args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
550                     out[0]->get_data_ptr<T>(),
551                     embed->get_input_shape(1),
552                     embed->get_shape());
553             }
554             else if (indicesType == element::u32 || indicesType == element::i32)
555             {
556                 reference::embeddingBagPackedSum<T, unsigned>(
557                     args[0]->get_data_ptr<const T>(),
558                     args[1]->get_data_ptr<const unsigned>(),
559                     args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
560                     out[0]->get_data_ptr<T>(),
561                     embed->get_input_shape(1),
562                     embed->get_shape());
563             }
564             else
565             {
566                 throw ngraph_error(std::string("Unsupported index type ") +
567                                    indicesType.c_type_string() +
568                                    std::string(" in EmbeddingBagPackedSum"));
569             }
570             break;
571         }
572         case OP_TYPEID::EmbeddingSegmentsSum_v3:
573         {
574             const op::EmbeddingSegmentsSum* embed =
575                 static_cast<const op::EmbeddingSegmentsSum*>(&node);
576             auto indicesType = embed->input(1).get_element_type();
577             size_t indices_num = shape_size(embed->get_input_shape(1));
578
579             if (indicesType == element::u64 || indicesType == element::i64)
580             {
581                 reference::embeddingSegmentsSum<T, size_t>(
582                     args[0]->get_data_ptr<const T>(),
583                     args[1]->get_data_ptr<const size_t>(),
584                     args[2]->get_data_ptr<const size_t>(),
585                     args.size() > 4 ? args[4]->get_data_ptr<const size_t>() : nullptr,
586                     args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
587                     out[0]->get_data_ptr<T>(),
588                     embed->get_input_shape(0),
589                     embed->get_input_shape(1),
590                     embed->get_shape());
591             }
592             else if (indicesType == element::u32 || indicesType == element::i32)
593             {
594                 reference::embeddingSegmentsSum<T, unsigned>(
595                     args[0]->get_data_ptr<const T>(),
596                     args[1]->get_data_ptr<const unsigned>(),
597                     args[2]->get_data_ptr<const unsigned>(),
598                     args.size() > 4 ? args[4]->get_data_ptr<const unsigned>() : nullptr,
599                     args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
600                     out[0]->get_data_ptr<T>(),
601                     embed->get_input_shape(0),
602                     embed->get_input_shape(1),
603                     embed->get_shape());
604             }
605             else
606             {
607                 throw ngraph_error(std::string("Unsupported index type ") +
608                                    indicesType.c_type_string() +
609                                    std::string(" in EmbeddingSegmentsSum"));
610             }
611             break;
612         }
613         case OP_TYPEID::Erf:
614         {
615             size_t element_count = shape_size(node.get_output_shape(0));
616             reference::erf<T>(
617                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
618             break;
619         }
620         case OP_TYPEID::ExtractImagePatches_v3:
621         {
622             const op::ExtractImagePatches* extImgPatches =
623                 static_cast<const op::ExtractImagePatches*>(&node);
624             reference::extractImagePatches<T, size_t>(extImgPatches,
625                                                       args[0]->get_data_ptr<const T>(),
626                                                       out[0]->get_data_ptr<T>(),
627                                                       extImgPatches->get_input_shape(0),
628                                                       extImgPatches->get_shape());
629             break;
630         }
631         case OP_TYPEID::Exp:
632         {
633             size_t element_count = shape_size(node.get_output_shape(0));
634             reference::exp<T>(
635                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
636             break;
637         }
638 #ifdef INTERPRETER_USE_HYBRID
639         case OP_TYPEID::FunctionCall:
640         {
641             auto f = static_cast<const runtime::hybrid::op::FunctionCall*>(&node);
642             auto backend = f->get_backend();
643             auto executable = f->get_executable();
644
645             std::vector<std::shared_ptr<Tensor>> outputs;
646             std::vector<std::shared_ptr<Tensor>> inputs;
647             for (const std::shared_ptr<HostTensor>& t : out)
648             {
649                 auto backend_tensor = backend->create_tensor(
650                     t->get_element_type(), t->get_shape(), t->get_data_ptr());
651                 outputs.push_back(backend_tensor);
652             }
653             for (const std::shared_ptr<HostTensor>& t : args)
654             {
655                 auto backend_tensor = backend->create_tensor(
656                     t->get_element_type(), t->get_shape(), t->get_data_ptr());
657                 inputs.push_back(backend_tensor);
658             }
659             executable->call(outputs, inputs);
660             break;
661         }
662 #endif
663         case OP_TYPEID::Floor:
664         {
665             size_t element_count = shape_size(node.get_output_shape(0));
666             reference::floor<T>(
667                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
668             break;
669         }
670         case OP_TYPEID::GatherND:
671         {
672             if (node.get_input_element_type(1) == element::i64)
673             {
674                 reference::gather_nd<T, int64_t>(args[0]->get_data_ptr<T>(),
675                                                  args[1]->get_data_ptr<int64_t>(),
676                                                  out[0]->get_data_ptr<T>(),
677                                                  node.get_input_shape(0),
678                                                  node.get_input_shape(1),
679                                                  node.get_output_shape(0));
680             }
681             else if (node.get_input_element_type(1) == element::i32)
682             {
683                 reference::gather_nd<T, int32_t>(args[0]->get_data_ptr<T>(),
684                                                  args[1]->get_data_ptr<int32_t>(),
685                                                  out[0]->get_data_ptr<T>(),
686                                                  node.get_input_shape(0),
687                                                  node.get_input_shape(1),
688                                                  node.get_output_shape(0));
689             }
690             else
691             {
692                 throw ngraph_error("Unexpected type");
693             }
694             break;
695         }
696         case OP_TYPEID::Log:
697         {
698             size_t element_count = shape_size(node.get_output_shape(0));
699             reference::log<T>(
700                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
701             break;
702         }
703         case OP_TYPEID::LRN:
704         {
705             const op::LRN* lrn = static_cast<const op::LRN*>(&node);
706             reference::lrn<T>(args[0]->get_data_ptr<const T>(),
707                               lrn->get_reduction_axes(),
708                               out[0]->get_data_ptr<T>(),
709                               node.get_input_shape(0),
710                               lrn->get_alpha(),
711                               lrn->get_beta(),
712                               lrn->get_bias(),
713                               lrn->get_nsize());
714             break;
715         }
716         case OP_TYPEID::Negative:
717         {
718             size_t element_count = shape_size(node.get_output_shape(0));
719             reference::negate<T>(
720                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
721             break;
722         }
723         case OP_TYPEID::LogicalNot_v1:
724         case OP_TYPEID::Not:
725         {
726             size_t element_count = shape_size(node.get_output_shape(0));
727             reference::logical_not(
728                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
729             break;
730         }
731         case OP_TYPEID::OneHot:
732         {
733             const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
734             reference::one_hot<T>(args[0]->get_data_ptr<const T>(),
735                                   out[0]->get_data_ptr<T>(),
736                                   node.get_input_shape(0),
737                                   node.get_output_shape(0),
738                                   oh->get_one_hot_axis());
739             break;
740         }
741         case OP_TYPEID::Parameter: break;
742         case OP_TYPEID::Quantize:
743         {
744             const op::Quantize* quantize = static_cast<const op::Quantize*>(&node);
745             auto type = quantize->get_element_type();
746
747             if (type == element::u8)
748             {
749                 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
750                                        args[1]->get_data_ptr<const T>(),
751                                        args[2]->get_data_ptr<const uint8_t>(),
752                                        out[0]->get_data_ptr<uint8_t>(),
753                                        node.get_input_shape(0),
754                                        node.get_input_shape(1),
755                                        quantize->get_axes(),
756                                        quantize->get_round_mode());
757             }
758             else if (type == element::i8)
759             {
760                 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
761                                        args[1]->get_data_ptr<const T>(),
762                                        args[2]->get_data_ptr<const int8_t>(),
763                                        out[0]->get_data_ptr<int8_t>(),
764                                        node.get_input_shape(0),
765                                        node.get_input_shape(1),
766                                        quantize->get_axes(),
767                                        quantize->get_round_mode());
768             }
769             else if (type == element::i32)
770             {
771                 reference::quantize<T>(args[0]->get_data_ptr<const T>(),
772                                        args[1]->get_data_ptr<const T>(),
773                                        args[2]->get_data_ptr<const int32_t>(),
774                                        out[0]->get_data_ptr<int32_t>(),
775                                        node.get_input_shape(0),
776                                        node.get_input_shape(1),
777                                        quantize->get_axes(),
778                                        quantize->get_round_mode());
779             }
780             else
781             {
782                 std::stringstream ss;
783                 ss << "unsupported element type " << type << " op Quantize";
784                 throw std::runtime_error(ss.str());
785             }
786
787             break;
788         }
789
790         case OP_TYPEID::QuantizedConvolution:
791         {
792             const op::QuantizedConvolution* qc =
793                 static_cast<const op::QuantizedConvolution*>(&node);
794
795             auto input_element_type = qc->get_input_element_type(0);
796             auto filter_element_type = qc->get_input_element_type(1);
797             auto output_element_type = qc->get_output_element_type(0);
798
799             if (input_element_type == element::u8 && filter_element_type == element::i8 &&
800                 output_element_type == element::i8)
801             {
802                 reference::convolution<uint8_t, int8_t, int8_t, int32_t>(
803                     args[0]->get_data_ptr<const uint8_t>(),
804                     args[1]->get_data_ptr<const int8_t>(),
805                     out[0]->get_data_ptr<int8_t>(),
806                     node.get_input_shape(0),
807                     node.get_input_shape(1),
808                     node.get_output_shape(0),
809                     qc->get_window_movement_strides(),
810                     qc->get_window_dilation_strides(),
811                     qc->get_padding_below(),
812                     qc->get_padding_above(),
813                     qc->get_data_dilation_strides(),
814                     args[2]->get_data_ptr<const float>(),
815                     args[3]->get_data_ptr<const uint8_t>(),
816                     args[4]->get_data_ptr<const float>(),
817                     args[5]->get_data_ptr<const int8_t>(),
818                     args[6]->get_data_ptr<const float>(),
819                     args[7]->get_data_ptr<const int8_t>());
820             }
821             else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
822                      output_element_type == element::u8)
823             {
824                 reference::convolution<uint8_t, uint8_t, uint8_t, int32_t>(
825                     args[0]->get_data_ptr<const uint8_t>(),
826                     args[1]->get_data_ptr<const uint8_t>(),
827                     out[0]->get_data_ptr<uint8_t>(),
828                     node.get_input_shape(0),
829                     node.get_input_shape(1),
830                     node.get_output_shape(0),
831                     qc->get_window_movement_strides(),
832                     qc->get_window_dilation_strides(),
833                     qc->get_padding_below(),
834                     qc->get_padding_above(),
835                     qc->get_data_dilation_strides(),
836                     args[2]->get_data_ptr<const float>(),
837                     args[3]->get_data_ptr<const uint8_t>(),
838                     args[4]->get_data_ptr<const float>(),
839                     args[5]->get_data_ptr<const uint8_t>(),
840                     args[6]->get_data_ptr<const float>(),
841                     args[7]->get_data_ptr<const uint8_t>());
842             }
843             else if (input_element_type == element::u8 && filter_element_type == element::i8 &&
844                      output_element_type == element::i32)
845             {
846                 reference::convolution<uint8_t, int8_t, int32_t, int32_t>(
847                     args[0]->get_data_ptr<const uint8_t>(),
848                     args[1]->get_data_ptr<const int8_t>(),
849                     out[0]->get_data_ptr<int32_t>(),
850                     node.get_input_shape(0),
851                     node.get_input_shape(1),
852                     node.get_output_shape(0),
853                     qc->get_window_movement_strides(),
854                     qc->get_window_dilation_strides(),
855                     qc->get_padding_below(),
856                     qc->get_padding_above(),
857                     qc->get_data_dilation_strides(),
858                     args[2]->get_data_ptr<const float>(),
859                     args[3]->get_data_ptr<const uint8_t>(),
860                     args[4]->get_data_ptr<const float>(),
861                     args[5]->get_data_ptr<const int8_t>(),
862                     args[6]->get_data_ptr<const float>(),
863                     args[7]->get_data_ptr<const int32_t>());
864             }
865             else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
866                      output_element_type == element::i32)
867             {
868                 reference::convolution<uint8_t, uint8_t, int32_t, int32_t>(
869                     args[0]->get_data_ptr<const uint8_t>(),
870                     args[1]->get_data_ptr<const uint8_t>(),
871                     out[0]->get_data_ptr<int32_t>(),
872                     node.get_input_shape(0),
873                     node.get_input_shape(1),
874                     node.get_output_shape(0),
875                     qc->get_window_movement_strides(),
876                     qc->get_window_dilation_strides(),
877                     qc->get_padding_below(),
878                     qc->get_padding_above(),
879                     qc->get_data_dilation_strides(),
880                     args[2]->get_data_ptr<const float>(),
881                     args[3]->get_data_ptr<const uint8_t>(),
882                     args[4]->get_data_ptr<const float>(),
883                     args[5]->get_data_ptr<const uint8_t>(),
884                     args[6]->get_data_ptr<const float>(),
885                     args[7]->get_data_ptr<const int32_t>());
886             }
887             else
888             {
889                 std::stringstream ss;
890                 ss << "unsupported element type";
891                 throw std::runtime_error(ss.str());
892             }
893
894             break;
895         }
896
897         case OP_TYPEID::QuantizedDot:
898         {
899             const op::QuantizedDot* qd = static_cast<const op::QuantizedDot*>(&node);
900
901             auto input0_element_type = qd->get_input_element_type(0);
902             auto input1_element_type = qd->get_input_element_type(1);
903             auto output_element_type = qd->get_output_element_type(0);
904
905             if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
906                 output_element_type == element::i8)
907             {
908                 reference::dot<uint8_t, int8_t, int8_t, int32_t>(
909                     args[0]->get_data_ptr<const uint8_t>(),
910                     args[1]->get_data_ptr<const int8_t>(),
911                     out[0]->get_data_ptr<int8_t>(),
912                     node.get_input_shape(0),
913                     node.get_input_shape(1),
914                     node.get_output_shape(0),
915                     1,
916                     args[2]->get_data_ptr<const float>(),
917                     args[3]->get_data_ptr<const uint8_t>(),
918                     args[4]->get_data_ptr<const float>(),
919                     args[5]->get_data_ptr<const int8_t>(),
920                     args[6]->get_data_ptr<const float>(),
921                     args[7]->get_data_ptr<const int8_t>());
922             }
923             else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
924                      output_element_type == element::u8)
925             {
926                 reference::dot<uint8_t, uint8_t, uint8_t, int32_t>(
927                     args[0]->get_data_ptr<const uint8_t>(),
928                     args[1]->get_data_ptr<const uint8_t>(),
929                     out[0]->get_data_ptr<uint8_t>(),
930                     node.get_input_shape(0),
931                     node.get_input_shape(1),
932                     node.get_output_shape(0),
933                     1,
934                     args[2]->get_data_ptr<const float>(),
935                     args[3]->get_data_ptr<const uint8_t>(),
936                     args[4]->get_data_ptr<const float>(),
937                     args[5]->get_data_ptr<const uint8_t>(),
938                     args[6]->get_data_ptr<const float>(),
939                     args[7]->get_data_ptr<const uint8_t>());
940             }
941             else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
942                      output_element_type == element::i32)
943             {
944                 reference::dot<uint8_t, uint8_t, int32_t, int32_t>(
945                     args[0]->get_data_ptr<const uint8_t>(),
946                     args[1]->get_data_ptr<const uint8_t>(),
947                     out[0]->get_data_ptr<int32_t>(),
948                     node.get_input_shape(0),
949                     node.get_input_shape(1),
950                     node.get_output_shape(0),
951                     1,
952                     args[2]->get_data_ptr<const float>(),
953                     args[3]->get_data_ptr<const uint8_t>(),
954                     args[4]->get_data_ptr<const float>(),
955                     args[5]->get_data_ptr<const uint8_t>(),
956                     args[6]->get_data_ptr<const float>(),
957                     args[7]->get_data_ptr<const int32_t>());
958             }
959             else if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
960                      output_element_type == element::i32)
961             {
962                 reference::dot<uint8_t, int8_t, int32_t, int32_t>(
963                     args[0]->get_data_ptr<const uint8_t>(),
964                     args[1]->get_data_ptr<const int8_t>(),
965                     out[0]->get_data_ptr<int32_t>(),
966                     node.get_input_shape(0),
967                     node.get_input_shape(1),
968                     node.get_output_shape(0),
969                     1,
970                     args[2]->get_data_ptr<const float>(),
971                     args[3]->get_data_ptr<const uint8_t>(),
972                     args[4]->get_data_ptr<const float>(),
973                     args[5]->get_data_ptr<const int8_t>(),
974                     args[6]->get_data_ptr<const float>(),
975                     args[7]->get_data_ptr<const int32_t>());
976             }
977             else
978             {
979                 std::stringstream ss;
980                 ss << "unsupported element type";
981                 throw std::runtime_error(ss.str());
982             }
983
984             break;
985         }
986         case OP_TYPEID::Relu:
987         {
988             size_t element_count = shape_size(node.get_output_shape(0));
989             reference::relu<T>(
990                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
991             break;
992         }
993         case OP_TYPEID::ReplaceSlice:
994         {
995             const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
996             reference::replace_slice<T>(args[0]->get_data_ptr<const T>(),
997                                         args[1]->get_data_ptr<const T>(),
998                                         out[0]->get_data_ptr<T>(),
999                                         node.get_input_shape(1),
1000                                         slice->get_lower_bounds(),
1001                                         slice->get_upper_bounds(),
1002                                         slice->get_strides(),
1003                                         node.get_output_shape(0));
1004             break;
1005         }
1006         case OP_TYPEID::Reverse:
1007         {
1008             const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
1009             reference::reverse(args[0]->get_data_ptr<const char>(),
1010                                out[0]->get_data_ptr<char>(),
1011                                node.get_input_shape(0),
1012                                node.get_output_shape(0),
1013                                reverse->get_reversed_axes(),
1014                                args[0]->get_element_type().size());
1015             break;
1016         }
1017         case OP_TYPEID::ReverseSequence:
1018         {
1019             const op::ReverseSequence* reverse = static_cast<const op::ReverseSequence*>(&node);
1020
1021             if (node.get_input_element_type(1) == element::i32)
1022             {
1023                 reference::reverse_sequence<T, int32_t>(args[0]->get_data_ptr<const T>(),
1024                                                         out[0]->get_data_ptr<T>(),
1025                                                         node.get_input_shape(0),
1026                                                         reverse->get_batch_axis(),
1027                                                         reverse->get_sequence_axis(),
1028                                                         args[1]->get_data_ptr<const int32_t>());
1029             }
1030             else
1031             {
1032                 throw ngraph_error("only int32 indices are supported");
1033             }
1034             break;
1035         }
1036         case OP_TYPEID::Round:
1037         {
1038             size_t element_count = shape_size(node.get_output_shape(0));
1039             reference::round<T>(
1040                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1041             break;
1042         }
1043         case OP_TYPEID::Select:
1044         {
1045             size_t element_count = shape_size(node.get_output_shape(0));
1046             reference::select<T>(args[0]->get_data_ptr<const char>(),
1047                                  args[1]->get_data_ptr<const T>(),
1048                                  args[2]->get_data_ptr<const T>(),
1049                                  out[0]->get_data_ptr<T>(),
1050                                  element_count);
1051             break;
1052         }
1053         case OP_TYPEID::Sigmoid:
1054         {
1055             size_t element_count = shape_size(node.get_output_shape(0));
1056             reference::sigmoid<T>(
1057                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1058             break;
1059         }
1060         case OP_TYPEID::Sign:
1061         {
1062             size_t element_count = shape_size(node.get_output_shape(0));
1063             reference::sign<T>(
1064                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1065             break;
1066         }
1067         case OP_TYPEID::Sin:
1068         {
1069             size_t element_count = shape_size(node.get_output_shape(0));
1070             reference::sin<T>(
1071                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1072             break;
1073         }
1074         case OP_TYPEID::Sinh:
1075         {
1076             size_t element_count = shape_size(node.get_output_shape(0));
1077             reference::sinh<T>(
1078                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1079             break;
1080         }
1081         case OP_TYPEID::Sqrt:
1082         {
1083             size_t element_count = shape_size(node.get_output_shape(0));
1084             reference::sqrt<T>(
1085                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1086             break;
1087         }
1088         case OP_TYPEID::Tan:
1089         {
1090             size_t element_count = shape_size(node.get_output_shape(0));
1091             reference::tan<T>(
1092                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1093             break;
1094         }
1095         case OP_TYPEID::Tanh:
1096         {
1097             size_t element_count = shape_size(node.get_output_shape(0));
1098             reference::tanh<T>(
1099                 args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
1100             break;
1101         }
1102         case OP_TYPEID::TopK:
1103         {
1104             const op::TopK* topk = static_cast<const op::TopK*>(&node);
1105             if (node.get_output_element_type(0) == element::i64)
1106             {
1107                 reference::topk<T, int64_t>(args[0]->get_data_ptr<const T>(),
1108                                             out[0]->get_data_ptr<int64_t>(),
1109                                             out[1]->get_data_ptr<T>(),
1110                                             node.get_input_shape(0),
1111                                             node.get_output_shape(0),
1112                                             topk->get_top_k_axis(),
1113                                             topk->get_k(),
1114                                             topk->get_compute_max(),
1115                                             topk->get_sort());
1116             }
1117             else if (node.get_output_element_type(0) == element::i32)
1118             {
1119                 reference::topk<T, int32_t>(args[0]->get_data_ptr<const T>(),
1120                                             out[0]->get_data_ptr<int32_t>(),
1121                                             out[1]->get_data_ptr<T>(),
1122                                             node.get_input_shape(0),
1123                                             node.get_output_shape(0),
1124                                             topk->get_top_k_axis(),
1125                                             topk->get_k(),
1126                                             topk->get_compute_max(),
1127                                             topk->get_sort());
1128             }
1129             else
1130             {
1131                 throw ngraph_error("Unexpected type");
1132             }
1133             break;
1134         }
1135         case OP_TYPEID::DetectionOutput_v0:
1136         {
1137             const op::DetectionOutput* detOut = static_cast<const op::DetectionOutput*>(&node);
1138             reference::referenceDetectionOutput<T> refDetOut(
1139                 detOut->get_attrs(), node.get_input_shape(0), node.get_input_shape(2));
1140             if (node.get_input_size() == 3)
1141             {
1142                 refDetOut.run(args[0]->get_data_ptr<const T>(),
1143                               args[1]->get_data_ptr<const T>(),
1144                               args[2]->get_data_ptr<const T>(),
1145                               nullptr,
1146                               nullptr,
1147                               out[0]->get_data_ptr<T>());
1148             }
1149             else if (node.get_input_size() == 5)
1150             {
1151                 refDetOut.run(args[0]->get_data_ptr<const T>(),
1152                               args[1]->get_data_ptr<const T>(),
1153                               args[2]->get_data_ptr<const T>(),
1154                               args[3]->get_data_ptr<const T>(),
1155                               args[4]->get_data_ptr<const T>(),
1156                               out[0]->get_data_ptr<T>());
1157             }
1158             else
1159             {
1160                 throw ngraph_error("DetectionOutput layer supports only 3 or 5 inputs");
1161             }
1162
1163             break;
1164         }
1165         case OP_TYPEID::ScatterNDUpdate_v3:
1166         {
1167             const op::ScatterNDUpdate* scatterNDUpd =
1168                 static_cast<const op::v3::ScatterNDUpdate*>(&node);
1169             auto idxType = scatterNDUpd->get_input_element_type(1);
1170             if (idxType == element::i32)
1171             {
1172                 reference::scatterNdUpdate<T, int32_t>(args[0]->get_data_ptr<const T>(),
1173                                                        args[1]->get_data_ptr<const int32_t>(),
1174                                                        args[2]->get_data_ptr<const T>(),
1175                                                        out[0]->get_data_ptr<T>(),
1176                                                        node.get_input_shape(0),
1177                                                        node.get_input_shape(1),
1178                                                        node.get_input_shape(2));
1179             }
1180             else if (idxType == element::i64)
1181             {
1182                 reference::scatterNdUpdate<T, int64_t>(args[0]->get_data_ptr<const T>(),
1183                                                        args[1]->get_data_ptr<const int64_t>(),
1184                                                        args[2]->get_data_ptr<const T>(),
1185                                                        out[0]->get_data_ptr<T>(),
1186                                                        node.get_input_shape(0),
1187                                                        node.get_input_shape(1),
1188                                                        node.get_input_shape(2));
1189             }
1190             else
1191             {
1192                 throw ngraph_error(
1193                     "ScatterNDUpdate layer support only i32 and i64 'indices' input precision!");
1194             }
1195
1196             break;
1197         }
1198         case OP_TYPEID::ScatterUpdate_v3:
1199         {
1200             const op::v3::ScatterUpdate* scatterUpd =
1201                 static_cast<const op::v3::ScatterUpdate*>(&node);
1202
1203             if (scatterUpd->get_input_element_type(3) != element::i64)
1204                 throw ngraph_error(
1205                     "ScatterNDUpdate layer support only i64 'axis' input precision!");
1206
1207             auto idxType = scatterUpd->get_input_element_type(1);
1208             if (idxType == element::i32)
1209             {
1210                 reference::scatterUpdate<T, int32_t, int64_t>(
1211                     args[0]->get_data_ptr<const T>(),
1212                     args[1]->get_data_ptr<const int32_t>(),
1213                     args[2]->get_data_ptr<const T>(),
1214                     args[3]->get_data_ptr<const int64_t>(),
1215                     out[0]->get_data_ptr<T>(),
1216                     node.get_input_shape(0),
1217                     node.get_input_shape(1),
1218                     node.get_input_shape(2));
1219             }
1220             else if (idxType == element::i64)
1221             {
1222                 reference::scatterUpdate<T, int64_t, int64_t>(
1223                     args[0]->get_data_ptr<const T>(),
1224                     args[1]->get_data_ptr<const int64_t>(),
1225                     args[2]->get_data_ptr<const T>(),
1226                     args[3]->get_data_ptr<const int64_t>(),
1227                     out[0]->get_data_ptr<T>(),
1228                     node.get_input_shape(0),
1229                     node.get_input_shape(1),
1230                     node.get_input_shape(2));
1231             }
1232             else
1233             {
1234                 throw ngraph_error(
1235                     "ScatterUpdate layer support only i32 and i64 'indices' input precision!");
1236             }
1237
1238             break;
1239         }
1240
1241         // Fused Ops are not supported in interpreter. They need to be decomposed before execution
1242         case OP_TYPEID::DepthToSpace:
1243         case OP_TYPEID::FakeQuantize:
1244         case OP_TYPEID::Gather:
1245         case OP_TYPEID::Gelu:
1246         case OP_TYPEID::GRN:
1247         case OP_TYPEID::GroupConvolution:
1248         case OP_TYPEID::GroupConvolutionBackpropData:
1249         case OP_TYPEID::GRUCell:
1250         case OP_TYPEID::HardSigmoid:
1251         case OP_TYPEID::Interpolate:
1252         case OP_TYPEID::LSTMCell:
1253         case OP_TYPEID::LSTMSequence:
1254         case OP_TYPEID::MVN:
1255         case OP_TYPEID::NormalizeL2:
1256         case OP_TYPEID::PRelu:
1257         case OP_TYPEID::RNNCell:
1258         case OP_TYPEID::Selu:
1259         case OP_TYPEID::ShuffleChannels:
1260         case OP_TYPEID::SpaceToDepth:
1261         case OP_TYPEID::Split:
1262         case OP_TYPEID::SquaredDifference:
1263         case OP_TYPEID::StopGradient:
1264         case OP_TYPEID::TensorIterator:
1265         case OP_TYPEID::Tile:
1266         case OP_TYPEID::UnknownOp:
1267             throw unsupported_op("Unsupported op '" + node.description() + "'");
1268         case OP_TYPEID::Add:
1269         case OP_TYPEID::Broadcast:
1270         case OP_TYPEID::Clamp:
1271         case OP_TYPEID::Concat:
1272         case OP_TYPEID::Constant:
1273         case OP_TYPEID::Divide:
1274         case OP_TYPEID::Equal:
1275         case OP_TYPEID::Greater:
1276         case OP_TYPEID::GreaterEq:
1277         case OP_TYPEID::Less:
1278         case OP_TYPEID::LessEq:
1279         case OP_TYPEID::LessEqual_v1:
1280         case OP_TYPEID::LogicalAnd_v1:
1281         case OP_TYPEID::LogicalOr_v1:
1282         case OP_TYPEID::LogicalXor_v1:
1283         case OP_TYPEID::MatMul:
1284         case OP_TYPEID::Max:
1285         case OP_TYPEID::Maximum:
1286         case OP_TYPEID::Min:
1287         case OP_TYPEID::Minimum:
1288         case OP_TYPEID::Multiply:
1289         case OP_TYPEID::NonZero_v3:
1290         case OP_TYPEID::NotEqual:
1291         case OP_TYPEID::Or:
1292         case OP_TYPEID::Pad:
1293         case OP_TYPEID::Power:
1294         case OP_TYPEID::Product:
1295         case OP_TYPEID::Range:
1296         case OP_TYPEID::Reshape:
1297         case OP_TYPEID::Result:
1298         case OP_TYPEID::ShapeOf_v3:
1299         case OP_TYPEID::ShapeOf:
1300         case OP_TYPEID::Softmax:
1301         case OP_TYPEID::Squeeze:
1302         case OP_TYPEID::Sum:
1303         case OP_TYPEID::Subtract:
1304         case OP_TYPEID::Unsqueeze:
1305         case OP_TYPEID::Xor:
1306         case OP_TYPEID::Slice:
1307             // These ops are handled by op evaluators so nothing to do
1308             break;
1309 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
1310 #pragma GCC diagnostic pop
1311 #endif
1312         }
1313     }
1314 };
1315
1316 NGRAPH_SUPPRESS_DEPRECATED_END