* Reference implementation for ScatterUpdate and use of it in evaluate.
* Review comments. Clarify comments.
* Update file directory.
* Replace scatter_update reference implementation in ngraph/core/reference/
* Remove template code from ScatterUpdate reference implementation
* Apply review requests
Co-authored-by: mitruska <katarzyna.mitrus@intel.com>
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/scatter_base.hpp"
+#include "ngraph/runtime/host_tensor.hpp"
namespace ngraph
{
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& inputs) const override;
+
+ bool evaluate(const HostTensorVector& outputs,
+ const HostTensorVector& inputs) const override;
};
}
}
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
#pragma once
-#include <string>
+#include "ngraph/check.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape.hpp"
-using namespace ngraph;
-
namespace ngraph
{
namespace runtime
{
namespace reference
{
- template <typename dataType, typename indicesType, typename axisType>
- void scatterUpdate(const dataType* inputData,
- const indicesType* indices,
- const dataType* updates,
- const axisType* _axis,
- dataType* outBuf,
- const Shape& dataShape,
- const Shape& indicesShape,
- const Shape& updatesShape)
+ void scatter_update(const char* input_data,
+ const int64_t* indices,
+ const char* updates,
+ const int64_t axis,
+ char* out_buf,
+ const size_t elem_size,
+ const Shape& data_shape,
+ const Shape& indices_shape,
+ const Shape& updates_shape)
{
- int rank = static_cast<int>(dataShape.size());
- if (_axis[0] < -rank || _axis[0] > rank - 1)
- {
- std::string error =
- std::string("ScatterUpdate layer has out of bounds axis value: ") +
- std::to_string(_axis[0]);
- throw ngraph_error(error);
- }
- size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0];
- CoordinateTransform indicesTransform{indicesShape};
+ // Copy inputs to out
+ std::memcpy(out_buf, input_data, elem_size * shape_size(data_shape));
+
+ // Algorithm overview
+ // data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...]
+ // where first ... in the data corresponds to first axis dimensions,
+ // last ... in the data corresponds to the rank(data) - (axis + 1) dimensions.
+
+ //
+ // for i_coord in indices[m, n, ..., p]:
+ // # get linear index
+ // i_idx = index(i_coord)
+ // # simultaneously iterate over two slices of data with same elements count
+ // for d_coord in slice data[..., i_idx, ...],
+ // u_coord in slice updates[..., i_coord, ...]
+ // data[index(d_coord)] = updates[index(u_coord)]
- Shape dataShapeIter = dataShape;
- dataShapeIter.erase(dataShapeIter.begin() + axis);
- CoordinateTransform dataTransfIter{dataShapeIter};
+ CoordinateTransform indices_transform{indices_shape};
+ CoordinateTransform data_transform{data_shape};
- CoordinateTransform updateTransform{updatesShape};
- CoordinateTransform dataTransform{dataShape};
+ size_t indices_ndim = indices_shape.size();
+ size_t updates_ndim = updates_shape.size();
- std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
+ // Create an outer CoordinateTransform for "update", which would allow to
+ // iterate only over "indices" dimensions:
+ // set to "1" all non-indices dimensions
+ // updates[1, ..., 1, m, n, ..., p, 1, 1,..., 1]
+ Coordinate updates_indices_start_corner(updates_ndim, 0);
+ Coordinate updates_indices_end_corner(updates_ndim, 1);
+ for (size_t i = 0; i < indices_ndim; ++i)
+ {
+ updates_indices_end_corner[axis + i] = updates_shape[axis + i];
+ }
+ CoordinateTransform updates_indices_transform(
+ updates_shape, updates_indices_start_corner, updates_indices_end_corner);
+ // Is needed to simultaneously iterate over updates coordinates while
+ // iterating over indices.
+ auto updates_indices_coord_iter = updates_indices_transform.begin();
- for (const Coordinate& indicesCoordIt : indicesTransform)
+ for (const Coordinate& indices_cord : indices_transform)
{
- const size_t indicesIdx = indicesTransform.index(indicesCoordIt);
+ const size_t indices_idx = indices_transform.index(indices_cord);
+ int64_t slice_index = indices[indices_idx];
- if (indices[indicesIdx] < 0)
- {
- std::string error =
- std::string("ScatterUpdate layer has negative index value: ") +
- std::to_string(indices[indicesIdx]);
- throw ngraph_error(error);
- }
- const size_t idx = static_cast<size_t>(indices[indicesIdx]);
- if (dataShape[axis] <= idx)
+ // Define the extent of coordinates which will be updated.
+ Coordinate out_start_corner(data_shape.size(), 0);
+ Coordinate out_end_corner(data_shape);
+ out_start_corner[axis] = static_cast<size_t>(slice_index);
+ out_end_corner[axis] = out_start_corner[axis] + 1;
+ CoordinateTransform out_transform(data_shape, out_start_corner, out_end_corner);
+
+ // Define the CoordinateTransform for updates coordinates.
+ // All except indices-dimensions.
+ Coordinate updates_update_start_corner = *updates_indices_coord_iter;
+ Coordinate updates_update_end_corner(updates_shape);
+ for (size_t i = 0; i < indices_ndim; ++i)
{
- std::string error =
- std::string("ScatterUpdate layer has out of bounds coordinate: ") +
- std::to_string(idx) + " on 'data' input on " + std::to_string(axis) +
- "th axis";
- throw ngraph_error(error);
+ updates_update_end_corner[axis + i] =
+ updates_update_start_corner[axis + i] + 1;
}
-
- for (const Coordinate& dataCoordIt : dataTransfIter)
+ // The m, n, .., p symbols stand for values at those axes.
+ // The m+1 means value at axis m plus 1.
+ // udpates_shape (start): [ 0, ..., 0, m , n , ... p , 0, ..., 0]
+ // updates_shape (end): [-1, ..., -1, m+1, n+1, ... p+1, -1, ..., -1]
+ CoordinateTransform updates_update_transform(
+ updates_shape, updates_update_start_corner, updates_update_end_corner);
+ auto updates_update_coord_iter = updates_update_transform.begin();
+ for (const Coordinate& out_cord : out_transform)
{
- Coordinate dataCoord = dataCoordIt;
- dataCoord.insert(dataCoord.begin() + axis, idx);
- const size_t startIndices = dataTransform.index(dataCoord);
-
- auto updCoord = dataCoordIt;
- updCoord.insert(
- updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end());
- const size_t startUpd = updateTransform.index(updCoord);
- outBuf[startIndices] = updates[startUpd];
+ const auto src_idx =
+ updates_update_transform.index(*updates_update_coord_iter) * elem_size;
+ std::copy(updates + src_idx,
+ updates + (src_idx + elem_size),
+ out_buf + out_transform.index(out_cord) * elem_size);
+ updates_update_coord_iter++;
}
+ updates_indices_coord_iter++;
}
}
- } // namespace reference
- } // namespace runtime
-} // namespace ngraph
+ }
+ }
+}
//*****************************************************************************
#include "ngraph/op/scatter_update.hpp"
+#include "ngraph/runtime/reference/scatter_update.hpp"
#include "ngraph/shape.hpp"
+#include "ngraph/type/element_type.hpp"
+#include "ngraph/type/element_type_traits.hpp"
+#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
return make_shared<v3::ScatterUpdate>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
}
+
+bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs,
+ const HostTensorVector& inputs) const
+{
+ const auto& data = inputs[0];
+ const auto& indices = inputs[1];
+ const auto& updates = inputs[2];
+ const auto& axis = inputs[3];
+ const auto& out = outputs[0];
+
+ const auto elem_size = data->get_element_type().size();
+ out->set_shape(data->get_shape());
+
+ int64_t axis_val = 0;
+ switch (axis->get_element_type())
+ {
+ case element::Type_t::i8: axis_val = axis->get_data_ptr<element::Type_t::i8>()[0]; break;
+ case element::Type_t::i16: axis_val = axis->get_data_ptr<element::Type_t::i16>()[0]; break;
+ case element::Type_t::i32: axis_val = axis->get_data_ptr<element::Type_t::i32>()[0]; break;
+ case element::Type_t::i64: axis_val = axis->get_data_ptr<element::Type_t::i64>()[0]; break;
+ case element::Type_t::u8: axis_val = axis->get_data_ptr<element::Type_t::u8>()[0]; break;
+ case element::Type_t::u16: axis_val = axis->get_data_ptr<element::Type_t::u16>()[0]; break;
+ case element::Type_t::u32: axis_val = axis->get_data_ptr<element::Type_t::u32>()[0]; break;
+ case element::Type_t::u64: axis_val = axis->get_data_ptr<element::Type_t::u64>()[0]; break;
+ default: throw ngraph_error("axis element type is not integral data type");
+ }
+
+ if (axis_val < 0)
+ {
+ axis_val =
+ ngraph::normalize_axis(this, axis_val, static_cast<int64_t>(data->get_shape().size()));
+ }
+
+ std::vector<int64_t> indices_casted_vector;
+ switch (indices->get_element_type())
+ {
+ case element::Type_t::i8:
+ {
+ auto indices_ptr = indices->get_data_ptr<element::Type_t::i8>();
+ indices_casted_vector =
+ std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+ break;
+ }
+ case element::Type_t::i16:
+ {
+ auto indices_ptr = indices->get_data_ptr<element::Type_t::i16>();
+ indices_casted_vector =
+ std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+ break;
+ }
+ case element::Type_t::i32:
+ {
+ auto indices_ptr = indices->get_data_ptr<element::Type_t::i32>();
+ indices_casted_vector =
+ std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+ break;
+ }
+ case element::Type_t::i64:
+ {
+ auto indices_ptr = indices->get_data_ptr<element::Type_t::i64>();
+ indices_casted_vector =
+ std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+ break;
+ }
+ case element::Type_t::u8:
+ {
+ auto indices_ptr = indices->get_data_ptr<element::Type_t::u8>();
+ indices_casted_vector =
+ std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+ break;
+ }
+ case element::Type_t::u16:
+ {
+ auto indices_ptr = indices->get_data_ptr<element::Type_t::u16>();
+ indices_casted_vector =
+ std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+ break;
+ }
+ case element::Type_t::u32:
+ {
+ auto indices_ptr = indices->get_data_ptr<element::Type_t::u32>();
+ indices_casted_vector =
+ std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+ break;
+ }
+ case element::Type_t::u64:
+ {
+ auto indices_ptr = indices->get_data_ptr<element::Type_t::u64>();
+ indices_casted_vector =
+ std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+ break;
+ }
+ default: throw ngraph_error("indices element type is not integral data type");
+ }
+
+ runtime::reference::scatter_update(data->get_data_ptr<char>(),
+ indices_casted_vector.data(),
+ updates->get_data_ptr<char>(),
+ axis_val,
+ out->get_data_ptr<char>(),
+ elem_size,
+ data->get_shape(),
+ indices->get_shape(),
+ updates->get_shape());
+
+ return true;
+}
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/round.hpp"
#include "ngraph/op/scatter_elements_update.hpp"
+#include "ngraph/op/scatter_update.hpp"
#include "ngraph/op/shape_of.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
}),
ngraph::ngraph_error);
}
+
+TEST(eval, evaluate_static_scatter_update_basic_axes_indices_i32)
+{
+ const Shape data_shape{3, 3};
+ const Shape indices_shape{1, 2};
+ const Shape updates_shape{1, 2, 3};
+
+ auto arg1 = make_shared<op::Parameter>(element::f32, data_shape);
+ auto arg2 = make_shared<op::Parameter>(element::i32, indices_shape);
+ auto arg3 = make_shared<op::Parameter>(element::f32, updates_shape);
+ auto arg4 = make_shared<op::Parameter>(element::i32, Shape{});
+ auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+ auto fun = make_shared<Function>(OutputVector{scatter_update},
+ ParameterVector{arg1, arg2, arg3, arg4});
+ auto result_tensor = make_shared<HostTensor>();
+ ASSERT_TRUE(fun->evaluate({result_tensor},
+ {make_host_tensor<element::Type_t::f32>(
+ data_shape, std::vector<float>(shape_size(data_shape))),
+ make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
+ make_host_tensor<element::Type_t::f32>(
+ updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+ make_host_tensor<element::Type_t::i32>({}, {0})}));
+ EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+ EXPECT_EQ(result_tensor->get_shape(), (Shape{3, 3}));
+ auto cval = read_vector<float>(result_tensor);
+ vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
+ ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_static_scatter_update_basic_axes_indices_i64)
+{
+ const Shape data_shape{3, 3};
+ const Shape indices_shape{1, 2};
+ const Shape updates_shape{1, 2, 3};
+
+ auto arg1 = make_shared<op::Parameter>(element::f32, data_shape);
+ auto arg2 = make_shared<op::Parameter>(element::i64, indices_shape);
+ auto arg3 = make_shared<op::Parameter>(element::f32, updates_shape);
+ auto arg4 = make_shared<op::Parameter>(element::i64, Shape{});
+ auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+ auto fun = make_shared<Function>(OutputVector{scatter_update},
+ ParameterVector{arg1, arg2, arg3, arg4});
+ auto result_tensor = make_shared<HostTensor>();
+ ASSERT_TRUE(fun->evaluate({result_tensor},
+ {make_host_tensor<element::Type_t::f32>(
+ data_shape, std::vector<float>(shape_size(data_shape))),
+ make_host_tensor<element::Type_t::i64>(indices_shape, {1, 2}),
+ make_host_tensor<element::Type_t::f32>(
+ updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+ make_host_tensor<element::Type_t::i64>({}, {0})}));
+ EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+ EXPECT_EQ(result_tensor->get_shape(), (Shape{3, 3}));
+ auto cval = read_vector<float>(result_tensor);
+ vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
+ ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_dynamic_scatter_update_basic)
+{
+ const Shape data_shape{3, 3};
+ const Shape indices_shape{1, 2};
+ const Shape updates_shape{1, 2, 3};
+
+ auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+ auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+ auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+ auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
+
+ auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+ auto fun = make_shared<Function>(OutputVector{scatter_update},
+ ParameterVector{arg1, arg2, arg3, arg4});
+ auto result_tensor = make_shared<HostTensor>();
+ ASSERT_TRUE(fun->evaluate({result_tensor},
+ {make_host_tensor<element::Type_t::f32>(
+ data_shape, std::vector<float>(shape_size(data_shape))),
+ make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
+ make_host_tensor<element::Type_t::f32>(
+ updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+ make_host_tensor<element::Type_t::i64>({}, {0})}));
+
+ EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+ EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
+ auto cval = read_vector<float>(result_tensor);
+ vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
+ ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_dynamic_scatter_update_negative_axis)
+{
+ const Shape data_shape{3, 3};
+ const Shape indices_shape{1, 2};
+ const Shape updates_shape{3, 1, 2};
+ const Shape axis_shape{};
+
+ auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+ auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+ auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+ auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
+
+ auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+ auto fun = make_shared<Function>(OutputVector{scatter_update},
+ ParameterVector{arg1, arg2, arg3, arg4});
+ auto result_tensor = make_shared<HostTensor>();
+ ASSERT_TRUE(fun->evaluate({result_tensor},
+ {make_host_tensor<element::Type_t::f32>(
+ data_shape, std::vector<float>(shape_size(data_shape))),
+ make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
+ make_host_tensor<element::Type_t::f32>(
+ updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+ make_host_tensor<element::Type_t::i64>(axis_shape, {-1})}));
+
+ EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+ EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
+ auto cval = read_vector<float>(result_tensor);
+ vector<float> out{0.f, 1.0f, 1.1f, 0.0f, 1.2f, 2.0f, 0.0f, 2.1f, 2.2f};
+ ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_dynamic_scatter_update_1d_axis)
+{
+ const Shape data_shape{3, 3};
+ const Shape indices_shape{1, 2};
+ const Shape updates_shape{3, 1, 2};
+
+ auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+ auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+ auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+ auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
+
+ auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+ auto fun = make_shared<Function>(OutputVector{scatter_update},
+ ParameterVector{arg1, arg2, arg3, arg4});
+ auto result_tensor = make_shared<HostTensor>();
+ ASSERT_TRUE(fun->evaluate({result_tensor},
+ {make_host_tensor<element::Type_t::f32>(
+ data_shape, std::vector<float>(shape_size(data_shape))),
+ make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
+ make_host_tensor<element::Type_t::f32>(
+ updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+ make_host_tensor<element::Type_t::i64>({1}, {1})}));
+
+ EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+ EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
+ auto cval = read_vector<float>(result_tensor);
+ vector<float> out{0.f, 1.0f, 1.1f, 0.0f, 1.2f, 2.0f, 0.0f, 2.1f, 2.2f};
+ ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_dynamic_scatter_update_one_elem_i32)
+{
+ const Shape data_shape{3, 3, 2};
+ const Shape indices_shape{1, 1};
+ const Shape updates_shape{1, 1, 3, 2};
+
+ auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+ auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+ auto arg3 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+ auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
+
+ auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+ auto fun = make_shared<Function>(OutputVector{scatter_update},
+ ParameterVector{arg1, arg2, arg3, arg4});
+ auto result_tensor = make_shared<HostTensor>();
+ ASSERT_TRUE(
+ fun->evaluate({result_tensor},
+ {make_host_tensor<element::Type_t::i32>(
+ data_shape, std::vector<int32_t>(shape_size(data_shape))),
+ make_host_tensor<element::Type_t::i32>(indices_shape, {1}),
+ make_host_tensor<element::Type_t::i32>(updates_shape, {1, 2, 3, 4, 5, 6}),
+ make_host_tensor<element::Type_t::i64>({}, {0})}));
+
+ EXPECT_EQ(result_tensor->get_element_type(), element::i32);
+ EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3, 2}));
+ auto cval = read_vector<int32_t>(result_tensor);
+ vector<int32_t> out{0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0};
+ ASSERT_EQ(cval, out);
+}
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/round.hpp"
#include "ngraph/runtime/reference/scatter_nd_update.hpp"
-#include "ngraph/runtime/reference/scatter_update.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp"
break;
}
- case OP_TYPEID::ScatterUpdate_v3:
- {
- const op::v3::ScatterUpdate* scatterUpd =
- static_cast<const op::v3::ScatterUpdate*>(&node);
-
- if (scatterUpd->get_input_element_type(3) != element::i64)
- throw ngraph_error(
- "ScatterNDUpdate layer support only i64 'axis' input precision!");
-
- auto idxType = scatterUpd->get_input_element_type(1);
- if (idxType == element::i32)
- {
- reference::scatterUpdate<T, int32_t, int64_t>(
- args[0]->get_data_ptr<const T>(),
- args[1]->get_data_ptr<const int32_t>(),
- args[2]->get_data_ptr<const T>(),
- args[3]->get_data_ptr<const int64_t>(),
- out[0]->get_data_ptr<T>(),
- node.get_input_shape(0),
- node.get_input_shape(1),
- node.get_input_shape(2));
- }
- else if (idxType == element::i64)
- {
- reference::scatterUpdate<T, int64_t, int64_t>(
- args[0]->get_data_ptr<const T>(),
- args[1]->get_data_ptr<const int64_t>(),
- args[2]->get_data_ptr<const T>(),
- args[3]->get_data_ptr<const int64_t>(),
- out[0]->get_data_ptr<T>(),
- node.get_input_shape(0),
- node.get_input_shape(1),
- node.get_input_shape(2));
- }
- else
- {
- throw ngraph_error(
- "ScatterUpdate layer support only i32 and i64 'indices' input precision!");
- }
-
- break;
- }
// Fused Ops are not supported in interpreter. They need to be decomposed before execution
case OP_TYPEID::DepthToSpace:
case OP_TYPEID::NormalizeL2:
case OP_TYPEID::PRelu:
case OP_TYPEID::RNNCell:
+ case OP_TYPEID::ScatterUpdate_v3:
case OP_TYPEID::Selu:
case OP_TYPEID::ShuffleChannels:
case OP_TYPEID::SpaceToDepth: