Reference implementation for ScatterUpdate (#1678)
authorAdam Osewski <adam.osewski@intel.com>
Tue, 25 Aug 2020 03:12:39 +0000 (05:12 +0200)
committerGitHub <noreply@github.com>
Tue, 25 Aug 2020 03:12:39 +0000 (06:12 +0300)
* Reference implementation for ScatterUpdate and use of it in evaluate.

* Review comments. Clarify comments.

* Update file directory.

* Replace scatter_update reference implementation in ngraph/core/reference/

* Remove template code from ScatterUpdate reference implementation

* Apply review requests

Co-authored-by: mitruska <katarzyna.mitrus@intel.com>
ngraph/core/include/ngraph/op/scatter_update.hpp
ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp
ngraph/core/src/op/scatter_update.cpp
ngraph/test/eval.cpp
ngraph/test/runtime/interpreter/int_executable.hpp

index f42fb96..25a4b94 100644 (file)
@@ -18,6 +18,7 @@
 
 #include "ngraph/op/op.hpp"
 #include "ngraph/op/util/scatter_base.hpp"
+#include "ngraph/runtime/host_tensor.hpp"
 
 namespace ngraph
 {
@@ -49,6 +50,9 @@ namespace ngraph
 
                 virtual std::shared_ptr<Node>
                     clone_with_new_inputs(const OutputVector& inputs) const override;
+
+                bool evaluate(const HostTensorVector& outputs,
+                              const HostTensorVector& inputs) const override;
             };
         }
     }
index e3cae8c..f8d00b1 100644 (file)
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
 //
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
 
 #pragma once
 
-#include <string>
+#include "ngraph/check.hpp"
 #include "ngraph/coordinate_transform.hpp"
 #include "ngraph/shape.hpp"
 
-using namespace ngraph;
-
 namespace ngraph
 {
     namespace runtime
     {
         namespace reference
         {
-            template <typename dataType, typename indicesType, typename axisType>
-            void scatterUpdate(const dataType* inputData,
-                               const indicesType* indices,
-                               const dataType* updates,
-                               const axisType* _axis,
-                               dataType* outBuf,
-                               const Shape& dataShape,
-                               const Shape& indicesShape,
-                               const Shape& updatesShape)
+            void scatter_update(const char* input_data,
+                                const int64_t* indices,
+                                const char* updates,
+                                const int64_t axis,
+                                char* out_buf,
+                                const size_t elem_size,
+                                const Shape& data_shape,
+                                const Shape& indices_shape,
+                                const Shape& updates_shape)
             {
-                int rank = static_cast<int>(dataShape.size());
-                if (_axis[0] < -rank || _axis[0] > rank - 1)
-                {
-                    std::string error =
-                        std::string("ScatterUpdate layer has out of bounds axis value: ") +
-                        std::to_string(_axis[0]);
-                    throw ngraph_error(error);
-                }
-                size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0];
-                CoordinateTransform indicesTransform{indicesShape};
+                // Copy inputs to out
+                std::memcpy(out_buf, input_data, elem_size * shape_size(data_shape));
+
+                // Algorithm overview
+                // data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...]
+                // where first ... in the data corresponds to first axis dimensions,
+                // last ... in the data corresponds to the rank(data) - (axis + 1) dimensions.
+
+                //
+                // for i_coord in indices[m, n, ..., p]:
+                //     # get linear index
+                //     i_idx = index(i_coord)
+                //     # simultaneously iterate over two slices of data with same elements count
+                //     for d_coord in slice data[..., i_idx, ...],
+                //         u_coord in slice updates[..., i_coord, ...]
+                //          data[index(d_coord)] = updates[index(u_coord)]
 
-                Shape dataShapeIter = dataShape;
-                dataShapeIter.erase(dataShapeIter.begin() + axis);
-                CoordinateTransform dataTransfIter{dataShapeIter};
+                CoordinateTransform indices_transform{indices_shape};
+                CoordinateTransform data_transform{data_shape};
 
-                CoordinateTransform updateTransform{updatesShape};
-                CoordinateTransform dataTransform{dataShape};
+                size_t indices_ndim = indices_shape.size();
+                size_t updates_ndim = updates_shape.size();
 
-                std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
+                // Create an outer CoordinateTransform for "update", which would allow to
+                // iterate only over "indices" dimensions:
+                // set to "1" all non-indices dimensions
+                // updates[1, ..., 1, m, n, ..., p, 1, 1,..., 1]
+                Coordinate updates_indices_start_corner(updates_ndim, 0);
+                Coordinate updates_indices_end_corner(updates_ndim, 1);
+                for (size_t i = 0; i < indices_ndim; ++i)
+                {
+                    updates_indices_end_corner[axis + i] = updates_shape[axis + i];
+                }
+                CoordinateTransform updates_indices_transform(
+                    updates_shape, updates_indices_start_corner, updates_indices_end_corner);
+                // Is needed to simultaneously iterate over updates coordinates while
+                // iterating over indices.
+                auto updates_indices_coord_iter = updates_indices_transform.begin();
 
-                for (const Coordinate& indicesCoordIt : indicesTransform)
+                for (const Coordinate& indices_cord : indices_transform)
                 {
-                    const size_t indicesIdx = indicesTransform.index(indicesCoordIt);
+                    const size_t indices_idx = indices_transform.index(indices_cord);
+                    int64_t slice_index = indices[indices_idx];
 
-                    if (indices[indicesIdx] < 0)
-                    {
-                        std::string error =
-                            std::string("ScatterUpdate layer has negative index value: ") +
-                            std::to_string(indices[indicesIdx]);
-                        throw ngraph_error(error);
-                    }
-                    const size_t idx = static_cast<size_t>(indices[indicesIdx]);
-                    if (dataShape[axis] <= idx)
+                    // Define the extent of coordinates which will be updated.
+                    Coordinate out_start_corner(data_shape.size(), 0);
+                    Coordinate out_end_corner(data_shape);
+                    out_start_corner[axis] = static_cast<size_t>(slice_index);
+                    out_end_corner[axis] = out_start_corner[axis] + 1;
+                    CoordinateTransform out_transform(data_shape, out_start_corner, out_end_corner);
+
+                    // Define the CoordinateTransform for updates coordinates.
+                    // All except indices-dimensions.
+                    Coordinate updates_update_start_corner = *updates_indices_coord_iter;
+                    Coordinate updates_update_end_corner(updates_shape);
+                    for (size_t i = 0; i < indices_ndim; ++i)
                     {
-                        std::string error =
-                            std::string("ScatterUpdate layer has out of bounds coordinate: ") +
-                            std::to_string(idx) + " on 'data' input on " + std::to_string(axis) +
-                            "th axis";
-                        throw ngraph_error(error);
+                        updates_update_end_corner[axis + i] =
+                            updates_update_start_corner[axis + i] + 1;
                     }
-
-                    for (const Coordinate& dataCoordIt : dataTransfIter)
+                    // The m, n, .., p symbols stand for values at those axes.
+                    // The m+1 means value at axis m plus 1.
+                    // udpates_shape (start): [ 0, ...,  0, m  , n  , ... p  ,  0, ...,  0]
+                    // updates_shape (end):   [-1, ..., -1, m+1, n+1, ... p+1, -1, ..., -1]
+                    CoordinateTransform updates_update_transform(
+                        updates_shape, updates_update_start_corner, updates_update_end_corner);
+                    auto updates_update_coord_iter = updates_update_transform.begin();
+                    for (const Coordinate& out_cord : out_transform)
                     {
-                        Coordinate dataCoord = dataCoordIt;
-                        dataCoord.insert(dataCoord.begin() + axis, idx);
-                        const size_t startIndices = dataTransform.index(dataCoord);
-
-                        auto updCoord = dataCoordIt;
-                        updCoord.insert(
-                            updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end());
-                        const size_t startUpd = updateTransform.index(updCoord);
-                        outBuf[startIndices] = updates[startUpd];
+                        const auto src_idx =
+                            updates_update_transform.index(*updates_update_coord_iter) * elem_size;
+                        std::copy(updates + src_idx,
+                                  updates + (src_idx + elem_size),
+                                  out_buf + out_transform.index(out_cord) * elem_size);
+                        updates_update_coord_iter++;
                     }
+                    updates_indices_coord_iter++;
                 }
             }
-        } // namespace reference
-    }     // namespace runtime
-} // namespace ngraph
+        }
+    }
+}
index 1600c8f..1ebf07e 100644 (file)
 //*****************************************************************************
 
 #include "ngraph/op/scatter_update.hpp"
+#include "ngraph/runtime/reference/scatter_update.hpp"
 #include "ngraph/shape.hpp"
+#include "ngraph/type/element_type.hpp"
+#include "ngraph/type/element_type_traits.hpp"
+#include "ngraph/validation_util.hpp"
 
 using namespace std;
 using namespace ngraph;
@@ -36,3 +40,110 @@ shared_ptr<Node> op::v3::ScatterUpdate::clone_with_new_inputs(const OutputVector
     return make_shared<v3::ScatterUpdate>(
         new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
 }
+
+bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs,
+                                     const HostTensorVector& inputs) const
+{
+    const auto& data = inputs[0];
+    const auto& indices = inputs[1];
+    const auto& updates = inputs[2];
+    const auto& axis = inputs[3];
+    const auto& out = outputs[0];
+
+    const auto elem_size = data->get_element_type().size();
+    out->set_shape(data->get_shape());
+
+    int64_t axis_val = 0;
+    switch (axis->get_element_type())
+    {
+    case element::Type_t::i8: axis_val = axis->get_data_ptr<element::Type_t::i8>()[0]; break;
+    case element::Type_t::i16: axis_val = axis->get_data_ptr<element::Type_t::i16>()[0]; break;
+    case element::Type_t::i32: axis_val = axis->get_data_ptr<element::Type_t::i32>()[0]; break;
+    case element::Type_t::i64: axis_val = axis->get_data_ptr<element::Type_t::i64>()[0]; break;
+    case element::Type_t::u8: axis_val = axis->get_data_ptr<element::Type_t::u8>()[0]; break;
+    case element::Type_t::u16: axis_val = axis->get_data_ptr<element::Type_t::u16>()[0]; break;
+    case element::Type_t::u32: axis_val = axis->get_data_ptr<element::Type_t::u32>()[0]; break;
+    case element::Type_t::u64: axis_val = axis->get_data_ptr<element::Type_t::u64>()[0]; break;
+    default: throw ngraph_error("axis element type is not integral data type");
+    }
+
+    if (axis_val < 0)
+    {
+        axis_val =
+            ngraph::normalize_axis(this, axis_val, static_cast<int64_t>(data->get_shape().size()));
+    }
+
+    std::vector<int64_t> indices_casted_vector;
+    switch (indices->get_element_type())
+    {
+    case element::Type_t::i8:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::i8>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::i16:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::i16>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::i32:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::i32>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::i64:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::i64>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::u8:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::u8>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::u16:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::u16>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::u32:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::u32>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::u64:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::u64>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    default: throw ngraph_error("indices element type is not integral data type");
+    }
+
+    runtime::reference::scatter_update(data->get_data_ptr<char>(),
+                                       indices_casted_vector.data(),
+                                       updates->get_data_ptr<char>(),
+                                       axis_val,
+                                       out->get_data_ptr<char>(),
+                                       elem_size,
+                                       data->get_shape(),
+                                       indices->get_shape(),
+                                       updates->get_shape());
+
+    return true;
+}
index e32fdcb..792ee38 100644 (file)
@@ -54,6 +54,7 @@
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/round.hpp"
 #include "ngraph/op/scatter_elements_update.hpp"
+#include "ngraph/op/scatter_update.hpp"
 #include "ngraph/op/shape_of.hpp"
 #include "ngraph/op/sigmoid.hpp"
 #include "ngraph/op/sign.hpp"
@@ -1937,3 +1938,180 @@ TEST(eval, reduce_logical_and__neg_axis)
                       }),
         ngraph::ngraph_error);
 }
+
+TEST(eval, evaluate_static_scatter_update_basic_axes_indices_i32)
+{
+    const Shape data_shape{3, 3};
+    const Shape indices_shape{1, 2};
+    const Shape updates_shape{1, 2, 3};
+
+    auto arg1 = make_shared<op::Parameter>(element::f32, data_shape);
+    auto arg2 = make_shared<op::Parameter>(element::i32, indices_shape);
+    auto arg3 = make_shared<op::Parameter>(element::f32, updates_shape);
+    auto arg4 = make_shared<op::Parameter>(element::i32, Shape{});
+    auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+    auto fun = make_shared<Function>(OutputVector{scatter_update},
+                                     ParameterVector{arg1, arg2, arg3, arg4});
+    auto result_tensor = make_shared<HostTensor>();
+    ASSERT_TRUE(fun->evaluate({result_tensor},
+                              {make_host_tensor<element::Type_t::f32>(
+                                   data_shape, std::vector<float>(shape_size(data_shape))),
+                               make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
+                               make_host_tensor<element::Type_t::f32>(
+                                   updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+                               make_host_tensor<element::Type_t::i32>({}, {0})}));
+    EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+    EXPECT_EQ(result_tensor->get_shape(), (Shape{3, 3}));
+    auto cval = read_vector<float>(result_tensor);
+    vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
+    ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_static_scatter_update_basic_axes_indices_i64)
+{
+    const Shape data_shape{3, 3};
+    const Shape indices_shape{1, 2};
+    const Shape updates_shape{1, 2, 3};
+
+    auto arg1 = make_shared<op::Parameter>(element::f32, data_shape);
+    auto arg2 = make_shared<op::Parameter>(element::i64, indices_shape);
+    auto arg3 = make_shared<op::Parameter>(element::f32, updates_shape);
+    auto arg4 = make_shared<op::Parameter>(element::i64, Shape{});
+    auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+    auto fun = make_shared<Function>(OutputVector{scatter_update},
+                                     ParameterVector{arg1, arg2, arg3, arg4});
+    auto result_tensor = make_shared<HostTensor>();
+    ASSERT_TRUE(fun->evaluate({result_tensor},
+                              {make_host_tensor<element::Type_t::f32>(
+                                   data_shape, std::vector<float>(shape_size(data_shape))),
+                               make_host_tensor<element::Type_t::i64>(indices_shape, {1, 2}),
+                               make_host_tensor<element::Type_t::f32>(
+                                   updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+                               make_host_tensor<element::Type_t::i64>({}, {0})}));
+    EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+    EXPECT_EQ(result_tensor->get_shape(), (Shape{3, 3}));
+    auto cval = read_vector<float>(result_tensor);
+    vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
+    ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_dynamic_scatter_update_basic)
+{
+    const Shape data_shape{3, 3};
+    const Shape indices_shape{1, 2};
+    const Shape updates_shape{1, 2, 3};
+
+    auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+    auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
+
+    auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+    auto fun = make_shared<Function>(OutputVector{scatter_update},
+                                     ParameterVector{arg1, arg2, arg3, arg4});
+    auto result_tensor = make_shared<HostTensor>();
+    ASSERT_TRUE(fun->evaluate({result_tensor},
+                              {make_host_tensor<element::Type_t::f32>(
+                                   data_shape, std::vector<float>(shape_size(data_shape))),
+                               make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
+                               make_host_tensor<element::Type_t::f32>(
+                                   updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+                               make_host_tensor<element::Type_t::i64>({}, {0})}));
+
+    EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+    EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
+    auto cval = read_vector<float>(result_tensor);
+    vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
+    ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_dynamic_scatter_update_negative_axis)
+{
+    const Shape data_shape{3, 3};
+    const Shape indices_shape{1, 2};
+    const Shape updates_shape{3, 1, 2};
+    const Shape axis_shape{};
+
+    auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+    auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
+
+    auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+    auto fun = make_shared<Function>(OutputVector{scatter_update},
+                                     ParameterVector{arg1, arg2, arg3, arg4});
+    auto result_tensor = make_shared<HostTensor>();
+    ASSERT_TRUE(fun->evaluate({result_tensor},
+                              {make_host_tensor<element::Type_t::f32>(
+                                   data_shape, std::vector<float>(shape_size(data_shape))),
+                               make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
+                               make_host_tensor<element::Type_t::f32>(
+                                   updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+                               make_host_tensor<element::Type_t::i64>(axis_shape, {-1})}));
+
+    EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+    EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
+    auto cval = read_vector<float>(result_tensor);
+    vector<float> out{0.f, 1.0f, 1.1f, 0.0f, 1.2f, 2.0f, 0.0f, 2.1f, 2.2f};
+    ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_dynamic_scatter_update_1d_axis)
+{
+    const Shape data_shape{3, 3};
+    const Shape indices_shape{1, 2};
+    const Shape updates_shape{3, 1, 2};
+
+    auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+    auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
+
+    auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+    auto fun = make_shared<Function>(OutputVector{scatter_update},
+                                     ParameterVector{arg1, arg2, arg3, arg4});
+    auto result_tensor = make_shared<HostTensor>();
+    ASSERT_TRUE(fun->evaluate({result_tensor},
+                              {make_host_tensor<element::Type_t::f32>(
+                                   data_shape, std::vector<float>(shape_size(data_shape))),
+                               make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
+                               make_host_tensor<element::Type_t::f32>(
+                                   updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+                               make_host_tensor<element::Type_t::i64>({1}, {1})}));
+
+    EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+    EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
+    auto cval = read_vector<float>(result_tensor);
+    vector<float> out{0.f, 1.0f, 1.1f, 0.0f, 1.2f, 2.0f, 0.0f, 2.1f, 2.2f};
+    ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_dynamic_scatter_update_one_elem_i32)
+{
+    const Shape data_shape{3, 3, 2};
+    const Shape indices_shape{1, 1};
+    const Shape updates_shape{1, 1, 3, 2};
+
+    auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+    auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+    auto arg3 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+    auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
+
+    auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+    auto fun = make_shared<Function>(OutputVector{scatter_update},
+                                     ParameterVector{arg1, arg2, arg3, arg4});
+    auto result_tensor = make_shared<HostTensor>();
+    ASSERT_TRUE(
+        fun->evaluate({result_tensor},
+                      {make_host_tensor<element::Type_t::i32>(
+                           data_shape, std::vector<int32_t>(shape_size(data_shape))),
+                       make_host_tensor<element::Type_t::i32>(indices_shape, {1}),
+                       make_host_tensor<element::Type_t::i32>(updates_shape, {1, 2, 3, 4, 5, 6}),
+                       make_host_tensor<element::Type_t::i64>({}, {0})}));
+
+    EXPECT_EQ(result_tensor->get_element_type(), element::i32);
+    EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3, 2}));
+    auto cval = read_vector<int32_t>(result_tensor);
+    vector<int32_t> out{0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0};
+    ASSERT_EQ(cval, out);
+}
index e1230c8..20152a8 100644 (file)
@@ -79,7 +79,6 @@
 #include "ngraph/runtime/reference/reverse_sequence.hpp"
 #include "ngraph/runtime/reference/round.hpp"
 #include "ngraph/runtime/reference/scatter_nd_update.hpp"
-#include "ngraph/runtime/reference/scatter_update.hpp"
 #include "ngraph/runtime/reference/select.hpp"
 #include "ngraph/runtime/reference/sigmoid.hpp"
 #include "ngraph/runtime/reference/sign.hpp"
@@ -1195,48 +1194,6 @@ protected:
 
             break;
         }
-        case OP_TYPEID::ScatterUpdate_v3:
-        {
-            const op::v3::ScatterUpdate* scatterUpd =
-                static_cast<const op::v3::ScatterUpdate*>(&node);
-
-            if (scatterUpd->get_input_element_type(3) != element::i64)
-                throw ngraph_error(
-                    "ScatterNDUpdate layer support only i64 'axis' input precision!");
-
-            auto idxType = scatterUpd->get_input_element_type(1);
-            if (idxType == element::i32)
-            {
-                reference::scatterUpdate<T, int32_t, int64_t>(
-                    args[0]->get_data_ptr<const T>(),
-                    args[1]->get_data_ptr<const int32_t>(),
-                    args[2]->get_data_ptr<const T>(),
-                    args[3]->get_data_ptr<const int64_t>(),
-                    out[0]->get_data_ptr<T>(),
-                    node.get_input_shape(0),
-                    node.get_input_shape(1),
-                    node.get_input_shape(2));
-            }
-            else if (idxType == element::i64)
-            {
-                reference::scatterUpdate<T, int64_t, int64_t>(
-                    args[0]->get_data_ptr<const T>(),
-                    args[1]->get_data_ptr<const int64_t>(),
-                    args[2]->get_data_ptr<const T>(),
-                    args[3]->get_data_ptr<const int64_t>(),
-                    out[0]->get_data_ptr<T>(),
-                    node.get_input_shape(0),
-                    node.get_input_shape(1),
-                    node.get_input_shape(2));
-            }
-            else
-            {
-                throw ngraph_error(
-                    "ScatterUpdate layer support only i32 and i64 'indices' input precision!");
-            }
-
-            break;
-        }
 
         // Fused Ops are not supported in interpreter. They need to be decomposed before execution
         case OP_TYPEID::DepthToSpace:
@@ -1255,6 +1212,7 @@ protected:
         case OP_TYPEID::NormalizeL2:
         case OP_TYPEID::PRelu:
         case OP_TYPEID::RNNCell:
+        case OP_TYPEID::ScatterUpdate_v3:
         case OP_TYPEID::Selu:
         case OP_TYPEID::ShuffleChannels:
         case OP_TYPEID::SpaceToDepth: