[IE CLDNN] ScatterUpdate layer added (#1839)
author RomanZm <67912289+RomanZm@users.noreply.github.com>
Tue, 1 Sep 2020 13:07:12 +0000 (16:07 +0300)
committer GitHub <noreply@github.com>
Tue, 1 Sep 2020 13:07:12 +0000 (16:07 +0300)
19 files changed:
inference-engine/thirdparty/clDNN/api/scatter_update.hpp [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_update_kernel_ref.cpp [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_update_kernel_ref.h [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_update_kernel_selector.cpp [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_update_kernel_selector.h [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/scatter_update_ref.cl [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp
inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.h
inference-engine/thirdparty/clDNN/src/gpu/register_gpu.cpp
inference-engine/thirdparty/clDNN/src/gpu/register_gpu.hpp
inference-engine/thirdparty/clDNN/src/gpu/scatter_update_gpu.cpp [new file with mode: 0644]
inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp
inference-engine/thirdparty/clDNN/src/include/kernel_selector_helper.h
inference-engine/thirdparty/clDNN/src/include/scatter_update_inst.h [new file with mode: 0644]
inference-engine/thirdparty/clDNN/src/scatter_update.cpp [new file with mode: 0644]
inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp
inference-engine/thirdparty/clDNN/tests/test_cases/scatter_update_gpu_test.cpp [new file with mode: 0644]
inference-engine/thirdparty/clDNN/tests/test_utils/test_utils.h

diff --git a/inference-engine/thirdparty/clDNN/api/scatter_update.hpp b/inference-engine/thirdparty/clDNN/api/scatter_update.hpp
new file mode 100644 (file)
index 0000000..d5f3684
--- /dev/null
@@ -0,0 +1,63 @@
+/*
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+*/
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+#pragma once
+#include "primitive.hpp"
+
+namespace cldnn {
+/// @addtogroup cpp_api C++ API
+/// @{
+/// @addtogroup cpp_topology Network Topology
+/// @{
+/// @addtogroup cpp_primitives Primitives
+/// @{
+
+/// @brief ScatterUpdate primitive.
+/// @details Produces a copy of the first (dictionary) input in which the slices
+///          selected along @p axis by the second (indices) input are replaced
+///          with the corresponding slices of the third (updates) input.
+struct scatter_update : public primitive_base<scatter_update> {
+    CLDNN_DECLARE_PRIMITIVE(scatter_update)
+
+    /// @brief Axis of the dictionary tensor that the indices input addresses.
+    enum scatter_update_axis {
+        along_b,
+        along_f,
+        along_x,
+        along_y,
+        along_z,
+        along_w
+    };
+
+    /// @brief Constructs scatter_update primitive.
+    /// @param id This primitive id.
+    /// @param dict Input dictionary primitive id.
+    /// @param idx Input indexes primitive id.
+    /// @param idupd Input updates primitive id.
+    /// @param axis Scattering axis (dimension of @p dict addressed by @p idx).
+    scatter_update(const primitive_id& id,
+                   const primitive_id& dict,
+                   const primitive_id& idx,
+                   const primitive_id& idupd,
+                   const scatter_update_axis axis,
+                   const padding& output_padding = padding())
+        : primitive_base(id, {dict, idx, idupd}, output_padding), axis(axis) {}
+
+    /// @brief ScatterUpdate axis
+    scatter_update_axis axis;
+};
+/// @}
+/// @}
+/// @}
+}  // namespace cldnn
index 6797532..5308e41 100644 (file)
@@ -58,6 +58,7 @@ enum class KernelType {
     ONE_HOT,
     DETECTION_OUTPUT,
     GATHER,
+    SCATTER_UPDATE,
     DEPTH_TO_SPACE,
     BATCH_TO_SPACE,
     SHUFFLE_CHANNELS,
@@ -488,6 +489,18 @@ enum class GatherAxis {
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// ScatterUpdateAxis
+////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// Axis of the dictionary/output tensor that ScatterUpdate indices address;
+// kernel-selector-side counterpart of cldnn::scatter_update::scatter_update_axis.
+enum class ScatterUpdateAxis {
+    X,
+    Y,
+    Z,
+    W,
+    FEATURE,
+    BATCH,
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 // ReduceMode
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 enum class ReduceMode {
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_update_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_update_kernel_ref.cpp
new file mode 100644 (file)
index 0000000..352db1e
--- /dev/null
@@ -0,0 +1,301 @@
+/*
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+*/
+
+#include "scatter_update_kernel_ref.h"
+#include "kernel_selector_utils.h"
+#include <string>
+#include <vector>
+
+namespace kernel_selector {
+// Translates the logical scatter axis into a dimension index in b,f,(w,z),y,x
+// order of the dictionary input (0 == batch). X/Y/Z are counted from the back
+// so the result holds for 4-, 5- and 6-D layouts; W is hard-coded to 2 because
+// it only exists in the 6-D (bfwzyx) layout.
+static size_t GetScatterUpdateChannelIndex(const scatter_update_params& params) {
+    Tensor::DataChannelName name = Tensor::DataChannelName::X;
+
+    const size_t dict_size = params.inputs[0].GetDims().size();
+    switch (params.axis) {
+        case ScatterUpdateAxis::X:
+            return dict_size - 1;
+        case ScatterUpdateAxis::Y:
+            return dict_size - 2;
+        case ScatterUpdateAxis::Z:
+            return dict_size - 3;
+        case ScatterUpdateAxis::W:
+            return 2;
+        case ScatterUpdateAxis::FEATURE:
+            return 1;
+        case ScatterUpdateAxis::BATCH:
+            return 0;
+        default:
+            break;
+    }
+
+    // Fallback for an unknown axis value: channel index of X in the output
+    // layout (unreachable for the enum values handled above).
+    return DataTensor::Channelndex(params.output.GetLayout(), name);
+}
+
+// Capability key of the reference ScatterUpdate kernel: plain bfyx/bfzyx/bfwzyx
+// layouts only, fp16/fp32/int32 inputs, with mixed input/output precision
+// allowed (EnableDifferentTypes).
+ParamsKey ScatterUpdateKernelRef::GetSupportedKey() const {
+    ParamsKey k;
+    k.EnableInputDataType(Datatype::F16);
+    k.EnableInputDataType(Datatype::F32);
+    k.EnableInputDataType(Datatype::INT32);
+    k.EnableOutputDataType(Datatype::F16);
+    k.EnableOutputDataType(Datatype::F32);
+    k.EnableOutputDataType(Datatype::INT32);
+    k.EnableOutputDataType(Datatype::INT8);
+    k.EnableOutputDataType(Datatype::UINT8);
+    k.EnableInputLayout(DataLayout::bfyx);
+    k.EnableOutputLayout(DataLayout::bfyx);
+    k.EnableInputLayout(DataLayout::bfzyx);
+    k.EnableOutputLayout(DataLayout::bfzyx);
+    k.EnableInputLayout(DataLayout::bfwzyx);
+    k.EnableOutputLayout(DataLayout::bfwzyx);
+    k.EnableTensorOffset();
+    k.EnableTensorPitches();
+    k.EnableBatching();
+    k.EnableDifferentTypes();
+    return k;
+}
+
+// Returns the number of "meaningful" dimensions of the tensor: its rank minus
+// the run of size-1 dimensions at the start of GetDims() (which here appears
+// to be X-first — TODO confirm against DataTensor). A tensor with logical
+// size 1 is treated as one-dimensional.
+static size_t GetNonEmptyDimsNumber(const DataTensor& data_tensor) {
+    if (data_tensor.LogicalSize() != 1) {
+        // Count the number of "one size" dimensions starting with X to Batch
+        size_t one_size_dims = 0;
+        for (auto& i : data_tensor.GetDims()) {
+            if (i.v == 1)
+                one_size_dims++;
+            else
+                break;
+        }
+        return data_tensor.Dimentions() - one_size_dims;
+    } else {
+        return 1;
+    }
+}
+
+// Joins index components into a comma-separated argument list for the JIT
+// index macros, e.g. {"b","f","y","x"} -> "b, f, y, x". Assumes a non-empty
+// vector (all callers pass orders produced by GetDefaultOrder).
+static inline std::string GetOrderString(std::vector<std::string>& order) {
+    std::string order_str = order[0];
+    for (size_t i = 1; i < order.size(); i++)
+        order_str += ", " + order[i];
+
+    return order_str;
+}
+
+// Default per-dimension index names in bfyx / bfzyx / bfwzyx order for the
+// given rank; ranks above 6 yield an empty vector.
+static inline std::vector<std::string> GetDefaultOrder(size_t size) {
+    std::vector<std::string> default_order;
+    if (size <= 4) {
+        default_order = {"b", "f", "y", "x"};
+    } else if (size == 5) {
+        default_order = {"b", "f", "z", "y", "x"};
+    } else if (size == 6) {
+        default_order = {"b", "f", "w", "z", "y", "x"};
+    }
+
+    return default_order;
+}
+
+// Builds the index-order string used to address the "updates" input (INPUT2)
+// from the output coordinates: output coordinates after the scatter axis are
+// shifted right, and the axis position is expanded into the (possibly
+// multi-dimensional) coordinates of the indices input, decoded from the
+// OUTPUT_INDEX_ON_AXIS JIT macro by div/mod over INPUT1 sizes.
+static std::string GetUpdatesIndexOrder(const scatter_update_params& params, size_t axis) {
+    std::vector<std::string> default_order = GetDefaultOrder(params.output.GetDims().size());
+
+    // Pad with zeros when the updates tensor has more dims than the output.
+    for (unsigned int i = 0; i < params.inputs[2].GetDims().size() - params.output.GetDims().size(); i++)
+        default_order.push_back("0");
+
+    size_t indices_non_empty_dims = GetNonEmptyDimsNumber(params.inputs[1]);
+    std::string FYX_indices_size = "(INPUT1_FEATURE_NUM * INPUT1_SIZE_Y * INPUT1_SIZE_X)";
+    std::string YX_indices_size = "(INPUT1_SIZE_Y * INPUT1_SIZE_X)";
+    std::string X_indices_size = "(INPUT1_SIZE_X)";
+    
+    // Shift indices of ScatterUpdate updates input related to Indices dims
+    for (size_t i = default_order.size() - 1; i > (axis + indices_non_empty_dims - 1); i--)
+        default_order[i] = default_order[i - indices_non_empty_dims + 1];
+
+    // Insert Indices indexes in axis dimension in the Update index order
+    for (size_t i = axis; i < (axis + indices_non_empty_dims) && i < default_order.size(); i++) {
+        switch(i - axis) {
+            case 0:
+                default_order[i] = "(OUTPUT_INDEX_ON_AXIS /" + FYX_indices_size + ")";
+                break;
+            case 1:
+                default_order[i] = "((OUTPUT_INDEX_ON_AXIS %" + FYX_indices_size + ")/" + YX_indices_size + ")";
+                break;
+            case 2:
+                default_order[i] = "(((OUTPUT_INDEX_ON_AXIS %" + FYX_indices_size + ")%" + YX_indices_size + ")/" + X_indices_size + ")";
+                break;
+            case 3:
+                default_order[i] = "(((OUTPUT_INDEX_ON_AXIS %" + FYX_indices_size + ")%" + YX_indices_size + ")%" + X_indices_size + ")";
+                break;
+        }
+    }
+
+    return GetOrderString(default_order);
+}
+
+// Computes global/local work sizes. The first kernel (is_second == false)
+// copies the whole dictionary, so GWS covers the full output. The second
+// kernel only writes the scattered slices, so the GWS component that holds
+// the scatter axis is replaced by the number of indices (times whatever other
+// dims share that component).
+CommonDispatchData ScatterUpdateKernelRef::SetDefault(const scatter_update_params& params, const optional_params&, bool is_second) const {
+    CommonDispatchData runInfo;
+    const auto& output = params.output;
+
+    std::vector<size_t> global(3);
+    const size_t indices_size = params.inputs[1].LogicalSize();
+
+    switch (params.inputs[0].GetLayout()) {
+    case DataLayout::bfyx:
+        // GWS packing: {x, y, f*b}
+        global = {output.X().v, output.Y().v, output.Feature().v * output.Batch().v};
+        if (is_second) {
+            if (params.axis == ScatterUpdateAxis::BATCH)
+                global[2] = indices_size * output.Feature().v;
+            else if (params.axis == ScatterUpdateAxis::FEATURE)
+                global[2] = indices_size * output.Batch().v;
+            else if (params.axis == ScatterUpdateAxis::Y)
+                global[1] = indices_size;
+            else
+                global[0] = indices_size;
+        }
+        break;
+
+    case DataLayout::bfzyx:
+        // GWS packing: {x*y, z, f*b}
+        global = {output.X().v * output.Y().v, output.Z().v, output.Feature().v * output.Batch().v};
+        if (is_second) {
+            if (params.axis == ScatterUpdateAxis::BATCH)
+                global[2] = indices_size * output.Feature().v;
+            else if (params.axis == ScatterUpdateAxis::FEATURE)
+                global[2] = indices_size * output.Batch().v;
+            else if (params.axis == ScatterUpdateAxis::Z)
+                global[1] = indices_size;
+            else if (params.axis == ScatterUpdateAxis::Y)
+                global[0] = indices_size * output.X().v;
+            else
+                global[0] = indices_size * output.Y().v;
+        }
+        break;
+
+    case DataLayout::bfwzyx:
+        // GWS packing: {x*y, z*w, f*b}
+        global = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
+        if (is_second) {
+            if (params.axis == ScatterUpdateAxis::BATCH)
+                global[2] = indices_size * output.Feature().v;
+            else if (params.axis == ScatterUpdateAxis::FEATURE)
+                global[2] = indices_size * output.Batch().v;
+            else if (params.axis == ScatterUpdateAxis::Z)
+                global[1] = indices_size * output.W().v;
+            else if (params.axis == ScatterUpdateAxis::W)
+                global[1] = indices_size * output.Z().v;
+            else if (params.axis == ScatterUpdateAxis::Y)
+                global[0] = indices_size * output.X().v;
+            else
+                global[0] = indices_size * output.Y().v;
+        }
+        break;
+    default: break;
+    }
+    
+    std::vector<size_t> local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
+
+    runInfo.gws0 = global[0];
+    runInfo.gws1 = global[1];
+    runInfo.gws2 = global[2];
+    
+    runInfo.lws0 = local[0];
+    runInfo.lws1 = local[1];
+    runInfo.lws2 = local[2];
+
+    runInfo.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
+
+    return runInfo;
+}
+
+// Name of the output coordinate that runs along the scatter axis
+// (becomes the OUTPUT_INDEX_ON_AXIS JIT macro).
+static std::string GetOutputIndexOnAxis(const scatter_update_params& params, size_t axis) {
+    std::vector<std::string> default_order = GetDefaultOrder(params.output.GetDims().size());
+    return default_order[axis];
+}
+
+// Output index order for the second kernel with the scatter-axis coordinate
+// replaced by the value looked up in the indices buffer (vector form, used
+// for the fused-ops configuration).
+static std::vector<std::string> GetVectorSecondOutputIndexOrder(const scatter_update_params& params, size_t axis) {
+    std::vector<std::string> default_order = GetDefaultOrder(params.output.GetDims().size());
+    default_order[axis] = "convert_int(indices[OUTPUT_INDEX_ON_AXIS])";
+    return default_order;
+}
+
+// Same substitution as above, joined into a single comma-separated string
+// (becomes the SECOND_ITER_OUTPUT_INDEX_ORDER JIT macro).
+static std::string GetSecondIterOutputIndexOrder(const scatter_update_params& params, size_t axis) {
+    std::vector<std::string> default_order = GetDefaultOrder(params.output.GetDims().size());
+    default_order[axis] = "convert_int(indices[OUTPUT_INDEX_ON_AXIS])";
+    return GetOrderString(default_order);
+}
+
+// Emits the JIT constants consumed by scatter_update_ref.cl: the updates/output
+// index orders, the output coordinate on the scatter axis, and the numeric
+// axis position (AXIS_VALUE). Fused ops get separate configurations for the
+// copy kernel (_FIRST_KERNEL) and the scatter kernel (_SECOND_KERNEL).
+JitConstants ScatterUpdateKernelRef::GetJitConstants(const scatter_update_params& params) const {
+    JitConstants jit = MakeBaseParamsJitConstants(params);
+
+    jit.AddConstant(MakeJitConstant("UPDATES_INDEX_ORDER", GetUpdatesIndexOrder(params, GetScatterUpdateChannelIndex(params))));
+    jit.AddConstant(MakeJitConstant("SECOND_ITER_OUTPUT_INDEX_ORDER", GetSecondIterOutputIndexOrder(params, GetScatterUpdateChannelIndex(params))));
+    jit.AddConstant(MakeJitConstant("OUTPUT_INDEX_ON_AXIS", GetOutputIndexOnAxis(params, GetScatterUpdateChannelIndex(params))));
+    jit.AddConstant(MakeJitConstant("AXIS_VALUE", GetScatterUpdateChannelIndex(params)));
+
+    if (!params.fused_ops.empty()) {
+        FusedOpsConfiguration conf1 = { "_FIRST_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
+        FusedOpsConfiguration conf2 = { "_SECOND_KERNEL", GetVectorSecondOutputIndexOrder(params, GetScatterUpdateChannelIndex(params)), "val", params.inputs[0].GetDType() };
+        jit.Merge(MakeFusedOpsJitConstants(params, {conf1, conf2}));
+    }
+
+    return jit;
+}
+
+bool ScatterUpdateKernelRef::Validate(const Params& p, const optional_params& o) const {
+    if (p.GetType() != KernelType:: SCATTER_UPDATE || o.GetType() != KernelType::SCATTER_UPDATE) {
+        return false;
+    }
+
+    const scatter_update_params& params = static_cast<const scatter_update_params&>(p);
+
+    for (auto& fused_op : params.fused_ops) {
+        if (!IsFusedPrimitiveSupported(fused_op))
+            return false;
+    }
+
+    return true;
+}
+
+KernelsData ScatterUpdateKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
+    if (!Validate(params, options)) {
+        return {};
+    }
+
+    const scatter_update_params& orgParams = static_cast<const scatter_update_params&>(params);
+    const size_t indices_size = orgParams.inputs[1].LogicalSize();
+    int start_with_iteration = 0;
+    
+    // if dim of output along axis is equal to logical size of indices, we miss copying kernel
+    if (orgParams.inputs[0].Extract(orgParams.inputs[0].GetLayout(), Tensor::DataChannelName(orgParams.axis), orgParams.inputs[0].GetDims()).v == indices_size) {
+        start_with_iteration = 1;
+    }
+
+    KernelData kd = KernelData::Default<scatter_update_params>(params, (2 - start_with_iteration));
+    scatter_update_params& newParams = *static_cast<scatter_update_params*>(kd.params.get());
+    auto cldnn_jit = GetJitConstants(newParams);
+
+    for (int i = start_with_iteration; i < 2; i++) {
+        auto runInfo = SetDefault(newParams, options, (i == 1));
+        auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
+
+        if (i == 1){
+            cldnn_jit.AddConstant(MakeJitConstant("IS_SECOND_ITER", "true"));
+        }
+        std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
+
+        clKernelData& kernel = kd.kernels[i - start_with_iteration];
+
+        FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point, "", false, false, 3, GetFusedPrimitiveInputsCount(params));
+    }
+
+    kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;
+    
+    return {kd};
+}
+}  // namespace kernel_selector
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_update_kernel_ref.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_update_kernel_ref.h
new file mode 100644 (file)
index 0000000..8db8199
--- /dev/null
@@ -0,0 +1,57 @@
+/*
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+*/
+
+#pragma once
+
+#include "common_kernel_base.h"
+
+namespace kernel_selector {
+////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// scatter_update_params
+////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// Parameters of the ScatterUpdate kernel: base tensor params plus the axis
+// the indices input addresses.
+struct scatter_update_params : public base_params {
+    scatter_update_params() : base_params(KernelType::SCATTER_UPDATE), axis(ScatterUpdateAxis::BATCH) {}
+
+    // Dictionary-tensor axis that the indices input addresses.
+    ScatterUpdateAxis axis;
+
+    virtual ParamsKey GetParamsKey() const { return base_params::GetParamsKey(); }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// scatter_update_optional_params
+////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+struct scatter_update_optional_params : optional_params {
+    scatter_update_optional_params() : optional_params(KernelType::SCATTER_UPDATE) {}
+};
+
+// Reference ScatterUpdate implementation: an optional dictionary-copy kernel
+// followed by a scatter kernel (see scatter_update_ref.cl).
+class ScatterUpdateKernelRef : public common_kernel_base {
+public:
+    ScatterUpdateKernelRef() : common_kernel_base("scatter_update_ref") {}
+    virtual ~ScatterUpdateKernelRef() {}
+    virtual JitConstants GetJitConstants(const scatter_update_params& params) const;
+    virtual CommonDispatchData SetDefault(const scatter_update_params& params, const optional_params&, bool is_second) const;
+    KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
+    ParamsKey GetSupportedKey() const override;
+    std::vector<FusedOpType> GetSupportedFusedOps() const override {
+        return { FusedOpType::QUANTIZE,
+                 FusedOpType::SCALE,
+                 FusedOpType::ACTIVATION };
+    }
+
+protected:
+    bool Validate(const Params& p, const optional_params& o) const override;
+};
+}  // namespace kernel_selector
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_update_kernel_selector.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_update_kernel_selector.cpp
new file mode 100644 (file)
index 0000000..e5b565f
--- /dev/null
@@ -0,0 +1,27 @@
+/*
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+*/
+
+#include "scatter_update_kernel_selector.h"
+#include "scatter_update_kernel_ref.h"
+
+namespace kernel_selector {
+
+// Registers all ScatterUpdate kernel implementations (currently only the
+// reference one).
+scatter_update_kernel_selector::scatter_update_kernel_selector() { Attach<ScatterUpdateKernelRef>(); }
+
+// Picks the best attached kernel for the given params via the naive strategy.
+KernelsData scatter_update_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
+    return GetNaiveBestKernel(params, options, KernelType::SCATTER_UPDATE);
+}
+}  // namespace kernel_selector
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_update_kernel_selector.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_update_kernel_selector.h
new file mode 100644 (file)
index 0000000..366bc87
--- /dev/null
@@ -0,0 +1,35 @@
+/*
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+*/
+
+#pragma once
+
+#include "kernel_selector.h"
+
+namespace kernel_selector {
+// Singleton selector that chooses a ScatterUpdate kernel implementation.
+class scatter_update_kernel_selector : public kernel_selector_base {
+public:
+    static scatter_update_kernel_selector& Instance() {
+        static scatter_update_kernel_selector instance_;
+        return instance_;
+    }
+
+    scatter_update_kernel_selector();
+
+    virtual ~scatter_update_kernel_selector() {}
+
+    KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
+};
+}  // namespace kernel_selector
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/scatter_update_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/scatter_update_ref.cl
new file mode 100644 (file)
index 0000000..298b703
--- /dev/null
@@ -0,0 +1,146 @@
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "include/include_all.cl"
+
+// Two-pass ScatterUpdate reference kernel. The file is compiled twice:
+//  - without IS_SECOND_ITER: plain copy of the dictionary into the output;
+//  - with IS_SECOND_ITER: writes the updates into the slices selected by the
+//    indices. UPDATES_INDEX_ORDER, SECOND_ITER_OUTPUT_INDEX_ORDER,
+//    OUTPUT_INDEX_ON_AXIS and AXIS_VALUE are generated on the host side
+//    (see scatter_update_kernel_ref.cpp).
+#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
+#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order)
+#if OUTPUT_DIMS == 4
+    #define ORDER b,f,y,x
+#elif OUTPUT_DIMS == 5
+    #define ORDER b,f,z,y,x
+#elif OUTPUT_DIMS == 6
+    #define ORDER b,f,w,z,y,x
+#endif
+
+KERNEL(scatter_update_ref)(const __global INPUT0_TYPE* dictionary,
+                   const __global INPUT1_TYPE* indices,
+                   const __global INPUT2_TYPE* updates, 
+                   __global OUTPUT_TYPE* output
+#if HAS_FUSED_OPS_DECLS
+                   , FUSED_OPS_DECLS
+#endif
+)
+{
+    const uint dim0 = get_global_id(0);
+    const uint dim1 = get_global_id(1);
+    const uint dim2 = get_global_id(2);
+
+#ifndef IS_SECOND_ITER // First kernel
+    // Decode b,f,(w,z),y,x from the 3-D NDRange that covers the whole output
+    // (see ScatterUpdateKernelRef::SetDefault with is_second == false).
+    #if OUTPUT_DIMS == 4
+        const uint x = dim0;
+        const uint y = dim1;
+        const uint f = dim2 % OUTPUT_FEATURE_NUM;
+        const uint b = dim2 / OUTPUT_FEATURE_NUM;
+    #elif OUTPUT_DIMS == 5
+        const uint x = dim0 % OUTPUT_SIZE_X;
+        const uint y = dim0 / OUTPUT_SIZE_X;
+        const uint z = dim1;
+        const uint f = dim2 % OUTPUT_FEATURE_NUM;
+        const uint b = dim2 / OUTPUT_FEATURE_NUM;
+    #elif OUTPUT_DIMS == 6
+        const uint x = dim0 % OUTPUT_SIZE_X;
+        const uint y = dim0 / OUTPUT_SIZE_X;
+        const uint z = dim1 % OUTPUT_SIZE_Z;
+        const uint w = dim1 / OUTPUT_SIZE_Z;
+        const uint f = dim2 % OUTPUT_FEATURE_NUM;
+        const uint b = dim2 / OUTPUT_FEATURE_NUM;
+    #endif
+    
+    // Pass 1: copy the dictionary element through (with optional fused ops).
+    const uint output_idx = GET_OUTPUT_INDEX(ORDER);
+    INPUT0_TYPE val = dictionary[output_idx];
+    #if HAS_FUSED_OPS
+        FUSED_OPS_FIRST_KERNEL;
+        output[output_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_FIRST_KERNEL);
+    #else
+        output[output_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
+    #endif
+
+#else // Second kernel
+    // Here one NDRange component enumerates the indices instead of the output
+    // extent along the scatter axis, so the coordinate decode depends on
+    // AXIS_VALUE (mirrors SetDefault with is_second == true).
+    #if OUTPUT_DIMS == 4
+        const uint x = dim0;
+        const uint y = dim1;
+        #if AXIS_VALUE == 0
+            const uint f = dim2 % OUTPUT_FEATURE_NUM;
+            const uint b = dim2 / OUTPUT_FEATURE_NUM;
+        #else
+            const uint f = dim2 / OUTPUT_BATCH_NUM;
+            const uint b = dim2 % OUTPUT_BATCH_NUM;
+        #endif
+    #elif OUTPUT_DIMS == 5
+        const uint z = dim1;
+        #if AXIS_VALUE == 1
+            const uint f = dim2 / OUTPUT_BATCH_NUM;
+            const uint b = dim2 % OUTPUT_BATCH_NUM;
+            const uint x = dim0 % OUTPUT_SIZE_X;
+            const uint y = dim0 / OUTPUT_SIZE_X;
+        #elif AXIS_VALUE == 4
+            const uint f = dim2 % OUTPUT_FEATURE_NUM;
+            const uint b = dim2 / OUTPUT_FEATURE_NUM;
+            const uint x = dim0 / OUTPUT_SIZE_Y;
+            const uint y = dim0 % OUTPUT_SIZE_Y;
+        #else
+            const uint f = dim2 % OUTPUT_FEATURE_NUM;
+            const uint b = dim2 / OUTPUT_FEATURE_NUM;
+            const uint x = dim0 % OUTPUT_SIZE_X;
+            const uint y = dim0 / OUTPUT_SIZE_X;
+        #endif
+    #elif OUTPUT_DIMS == 6
+        #if AXIS_VALUE == 1
+            const uint f = dim2 / OUTPUT_BATCH_NUM;
+            const uint b = dim2 % OUTPUT_BATCH_NUM;
+            const uint x = dim0 % OUTPUT_SIZE_X;
+            const uint y = dim0 / OUTPUT_SIZE_X;
+            const uint z = dim1 % OUTPUT_SIZE_Z;
+            const uint w = dim1 / OUTPUT_SIZE_Z;
+        #elif AXIS_VALUE == 3
+            const uint f = dim2 % OUTPUT_FEATURE_NUM;
+            const uint b = dim2 / OUTPUT_FEATURE_NUM;
+            const uint x = dim0 % OUTPUT_SIZE_X;
+            const uint y = dim0 / OUTPUT_SIZE_X;
+            const uint z = dim1 / OUTPUT_SIZE_W;
+            const uint w = dim1 % OUTPUT_SIZE_W;
+        #elif AXIS_VALUE == 5
+            const uint f = dim2 % OUTPUT_FEATURE_NUM;
+            const uint b = dim2 / OUTPUT_FEATURE_NUM;
+            const uint x = dim0 / OUTPUT_SIZE_Y;
+            const uint y = dim0 % OUTPUT_SIZE_Y;
+            const uint z = dim1 % OUTPUT_SIZE_Z;
+            const uint w = dim1 / OUTPUT_SIZE_Z;
+        #else
+            const uint f = dim2 % OUTPUT_FEATURE_NUM;
+            const uint b = dim2 / OUTPUT_FEATURE_NUM;
+            const uint x = dim0 % OUTPUT_SIZE_X;
+            const uint y = dim0 / OUTPUT_SIZE_X;
+            const uint z = dim1 % OUTPUT_SIZE_Z;
+            const uint w = dim1 / OUTPUT_SIZE_Z;
+        #endif
+    #endif
+
+    // Pass 2: write the update value to the position selected by the index
+    // looked up along the scatter axis (with optional fused ops).
+    const uint output_idx = GET_OUTPUT_INDEX(SECOND_ITER_OUTPUT_INDEX_ORDER);
+    const uint updates_idx = GET_UPDATES_INDEX(INPUT2, UPDATES_INDEX_ORDER);
+
+    INPUT2_TYPE val = updates[updates_idx];
+    #if HAS_FUSED_OPS
+        FUSED_OPS_SECOND_KERNEL;
+        output[output_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL);
+    #else
+        output[output_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
+    #endif
+#endif
+}
+
+#undef GET_UPDATES_INDEX
+#undef GET_OUTPUT_INDEX
index d8aa3c5..1376433 100644 (file)
@@ -411,6 +411,18 @@ std::string toString(GatherAxis a) {
     }
 }
 
+// Human-readable name of the ScatterUpdate axis; unknown values map to an
+// empty string.
+std::string toString(ScatterUpdateAxis a) {
+    switch (a) {
+        case ScatterUpdateAxis::X:       return "X";
+        case ScatterUpdateAxis::Y:       return "Y";
+        case ScatterUpdateAxis::Z:       return "Z";
+        case ScatterUpdateAxis::W:       return "W";
+        case ScatterUpdateAxis::FEATURE: return "FEATURE";
+        case ScatterUpdateAxis::BATCH:   return "BATCH";
+        default: return "";
+    }
+}
+
 std::string toString(ResampleType type) {
     switch (type) {
         case ResampleType::NEAREST_NEIGHBOR:  return "SAMPLE_TYPE_NEAREST";
index 815cbdf..8c97e0e 100644 (file)
@@ -1,5 +1,5 @@
 /*
-// Copyright (c) 2016-2019 Intel Corporation
+// Copyright (c) 2016-2020 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -249,6 +249,7 @@ std::string toString(WeightsLayout layout);
 std::string toString(ConcatAxis a);
 std::string toString(TileAxis a);
 std::string toString(GatherAxis a);
+std::string toString(ScatterUpdateAxis a);
 std::string toString(ResampleType type);
 std::string toString(const BorderType type);
 std::string toString(const Tensor::Dim& dim);
index 1efa95d..1876757 100644 (file)
@@ -68,6 +68,7 @@ void register_implementations_gpu() {
     REGISTER_GPU(reverse_sequence);
     REGISTER_GPU(roi_pooling);
     REGISTER_GPU(scale);
+    REGISTER_GPU(scatter_update);
     REGISTER_GPU(select);
     REGISTER_GPU(shuffle_channels);
     REGISTER_GPU(softmax);
index 7c756f8..2b609e8 100644 (file)
@@ -60,6 +60,7 @@
 #include "api/reverse_sequence.hpp"
 #include "api/roi_pooling.hpp"
 #include "api/scale.hpp"
+#include "api/scatter_update.hpp"
 #include "api/select.hpp"
 #include "api/shuffle_channels.hpp"
 #include "api/softmax.hpp"
@@ -134,6 +135,7 @@ REGISTER_GPU(reshape);
 REGISTER_GPU(reverse_sequence);
 REGISTER_GPU(roi_pooling);
 REGISTER_GPU(scale);
+REGISTER_GPU(scatter_update);
 REGISTER_GPU(select);
 REGISTER_GPU(shuffle_channels);
 REGISTER_GPU(softmax);
diff --git a/inference-engine/thirdparty/clDNN/src/gpu/scatter_update_gpu.cpp b/inference-engine/thirdparty/clDNN/src/gpu/scatter_update_gpu.cpp
new file mode 100644 (file)
index 0000000..416eda9
--- /dev/null
@@ -0,0 +1,97 @@
+/*
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+*/
+
+#include "scatter_update_inst.h"
+#include "primitive_gpu_base.h"
+#include "implementation_map.h"
+#include "kernel_selector_helper.h"
+#include "scatter_update/scatter_update_kernel_selector.h"
+#include "scatter_update/scatter_update_kernel_ref.h"
+#include "error_handler.h"
+
+using namespace cldnn;
+
+namespace cldnn {
+namespace gpu {
+kernel_selector::scatter_update_axis convert_axis(scatter_update::scatter_update_axis axis, const scatter_update_node& arg) {
+    switch (axis) {
+        case scatter_update::along_x:
+            return kernel_selector::scatter_update_axis::X;
+        case scatter_update::along_y:
+            return kernel_selector::scatter_update_axis::Y;
+        case scatter_update::along_z:
+            return kernel_selector::scatter_update_axis::Z;
+        case scatter_update::along_w:
+            return kernel_selector::scatter_update_axis::W;
+        case scatter_update::along_f:
+            return kernel_selector::scatter_update_axis::FEATURE;
+        case scatter_update::along_b:
+            return kernel_selector::scatter_update_axis::BATCH;
+        default:
+            CLDNN_ERROR_MESSAGE(arg.id(), "Unsupported Axis");
+    }
+    return kernel_selector::scatter_update_axis::X;
+}
+
+struct scatter_update_gpu : typed_primitive_gpu_impl<scatter_update> {
+    using parent = typed_primitive_gpu_impl<scatter_update>;
+    using parent::parent;
+
+public:
+    static primitive_impl* create(const scatter_update_node& arg) {
+        auto scatter_update_params = get_default_params<kernel_selector::scatter_update_params>(arg);
+        auto scatter_update_optional_params =
+            get_default_optional_params<kernel_selector::scatter_update_optional_params>(arg.get_program());
+
+        scatter_update_params.axis = convert_axis(arg.get_primitive()->axis, arg);
+
+        scatter_update_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
+        scatter_update_params.inputs.push_back(convert_data_tensor(arg.input(2).get_output_layout()));
+
+        auto& kernel_selector = kernel_selector::scatter_update_kernel_selector::Instance();
+        auto best_kernels = kernel_selector.GetBestKernels(scatter_update_params, scatter_update_optional_params);
+
+        CLDNN_ERROR_BOOL(arg.id(),
+                         "Best_kernel.empty()",
+                         best_kernels.empty(),
+                         "Cannot find a proper kernel with this arguments");
+
+        auto scatter_update = new scatter_update_gpu(arg, best_kernels[0]);
+
+        return scatter_update;
+    }
+};
+
+namespace detail {
+
+attach_scatter_update_gpu::attach_scatter_update_gpu() {
+    auto val_fw = scatter_update_gpu::create;
+    implementation_map<scatter_update>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
+    implementation_map<scatter_update>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
+    implementation_map<scatter_update>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw);
+
+    implementation_map<scatter_update>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw);
+    implementation_map<scatter_update>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw);
+    implementation_map<scatter_update>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw);
+
+    implementation_map<scatter_update>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw);
+    implementation_map<scatter_update>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw);
+    implementation_map<scatter_update>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw);
+}
+
+}  // namespace detail
+}  // namespace gpu
+}  // namespace cldnn
index 6b2e6bd..5555d6a 100644 (file)
@@ -44,6 +44,7 @@
 #include "depth_to_space_inst.h"
 #include "space_to_depth_inst.h"
 #include "gather_inst.h"
+#include "scatter_update_inst.h"
 #include "reverse_sequence_inst.h"
 #include "shuffle_channels_inst.h"
 #include "space_to_batch_inst.h"
@@ -200,7 +201,7 @@ void prepare_primitive_fusing::fuse_activations(program_impl &p) {
                  !input.is_type<reshape>() && !input.is_type<roi_pooling>() && !input.is_type<scale>() &&
                  !input.is_type<softmax>() && !input.is_type<resample>() && !input.is_type<mvn>() &&
                  !input.is_type<depth_to_space>() && !input.is_type<batch_to_space>() &&
-                 !input.is_type<space_to_batch>() && !input.is_type<gather>() && !input.is_type<shuffle_channels>() &&
+                 !input.is_type<space_to_batch>() && !input.is_type<gather>() && !input.is_type<scatter_update>() && !input.is_type<shuffle_channels>() &&
                  !input.is_type<strided_slice>() && !input.is_type<cum_sum>() && !input.is_type<reverse_sequence>() &&
                  !input.is_type<embedding_bag>() && !input.is_type<extract_image_patches>() &&
                  !input.is_type<fused_conv_eltwise>() && !input.is_type<activation>()))
@@ -400,6 +401,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
 
             should_fuse |= input_data.is_type<gather>();
 
+            should_fuse |= input_data.is_type<scatter_update>();
+
             should_fuse |= input_data.is_type<depth_to_space>();
 
             should_fuse |= input_data.is_type<space_to_depth>();
@@ -451,6 +454,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
 
             should_fuse |= input_data.is_type<gather>();
 
+            should_fuse |= input_data.is_type<scatter_update>();
+
             should_fuse |= input_data.is_type<depth_to_space>();
 
             should_fuse |= input_data.is_type<space_to_depth>();
@@ -531,6 +536,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
 
             should_fuse |= input_data.is_type<gather>() && quantize_node.get_scale_shift_opt();
 
+            should_fuse |= input_data.is_type<scatter_update>() && quantize_node.get_scale_shift_opt();
+
             should_fuse |= input_data.is_type<permute>() && quantize_node.get_scale_shift_opt();
 
             should_fuse |= input_data.is_type<depth_to_space>() && quantize_node.get_scale_shift_opt();
index c3d6af7..d4fa358 100644 (file)
@@ -78,6 +78,7 @@ using tuning_mode = kernel_selector::TuningMode;
 using sample_type = kernel_selector::ResampleType;
 using border_type = kernel_selector::BorderType;
 using gather_axis = kernel_selector::GatherAxis;
+using scatter_update_axis = kernel_selector::ScatterUpdateAxis;
 using reduce_mode = kernel_selector::ReduceMode;
 using cum_sum_axis = kernel_selector::CumSumAxis;
 using depth_to_space_mode = kernel_selector::DepthToSpaceMode;
diff --git a/inference-engine/thirdparty/clDNN/src/include/scatter_update_inst.h b/inference-engine/thirdparty/clDNN/src/include/scatter_update_inst.h
new file mode 100644 (file)
index 0000000..f6b5379
--- /dev/null
@@ -0,0 +1,49 @@
+/*
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+*/
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+#pragma once
+#include "api/scatter_update.hpp"
+#include "primitive_inst.h"
+#include <string>
+
+namespace cldnn {
// Program-graph node specialization for the scatter_update primitive.
// Dependency layout (as consumed by calc_output_layout and the GPU impl):
// 0 - data (dictionary), 1 - indices, 2 - updates.
template <>
struct typed_program_node<scatter_update> : public typed_program_node_base<scatter_update> {
    using parent = typed_program_node_base<scatter_update>;

public:
    using parent::parent;

    // Returns the dependency at the given index (defaults to the data input).
    program_node& input(size_t index = 0) const { return get_dependency(index); }
};
+
+using scatter_update_node = typed_program_node<scatter_update>;
+
// Runtime instance specialization for the scatter_update primitive.
template <>
class typed_primitive_inst<scatter_update> : public typed_primitive_inst_base<scatter_update> {
    using parent = typed_primitive_inst_base<scatter_update>;

public:
    // Output layout mirrors the data (dictionary) input's shape and format;
    // the data type may be overridden by fused primitives.
    static layout calc_output_layout(scatter_update_node const& node);
    // Dumps the primitive description (input id, axis, output shape) as JSON text.
    static std::string to_string(scatter_update_node const& node);

public:
    typed_primitive_inst(network_impl& network, scatter_update_node const& desc);
};
+
+using scatter_update_inst = typed_primitive_inst<scatter_update>;
+}  // namespace cldnn
diff --git a/inference-engine/thirdparty/clDNN/src/scatter_update.cpp b/inference-engine/thirdparty/clDNN/src/scatter_update.cpp
new file mode 100644 (file)
index 0000000..2c4dca9
--- /dev/null
@@ -0,0 +1,108 @@
+/*
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+*/
+
+#include "scatter_update_inst.h"
+
+#include "primitive_type_base.h"
+#include "error_handler.h"
+#include "json_object.h"
+#include <string>
+
+namespace cldnn {
// Returns the process-wide singleton type descriptor used to identify
// scatter_update primitives inside the clDNN program/network machinery.
primitive_type_id scatter_update::type_id() {
    static primitive_type_base<scatter_update> instance;
    return &instance;
}
+
+static size_t GetNonEmptyDimsNumber(const layout& layout) {
+    if (layout.size.count() != 1) {
+        // Count the number of "one size" dimensions starting with X to Batch
+        size_t one_size_dims = 0;
+        std::vector<int32_t> dims;
+        if (layout.format == cldnn::format::bfwzyx)
+            dims = layout.size.sizes(format::bfwzyx);
+        else if (layout.format == cldnn::format::bfzyx)
+            dims = layout.size.sizes(format::bfzyx);
+        else
+            dims = layout.size.sizes(format::bfyx);
+        for (size_t i = 0; i < dims.size(); i++) {
+            if (dims[dims.size() - 1 - i] == 1)
+                one_size_dims++;
+            else
+                break;
+        }
+        return dims.size() - one_size_dims;
+    } else {
+        return 1;
+    }
+}
+
+layout scatter_update_inst::calc_output_layout(scatter_update_node const& node) {
+    auto desc = node.get_primitive();
+
+    const int32_t axis = desc->axis;
+    const size_t indices_size = node.input(1).get_output_layout().size.count();
+    const size_t input_number_of_dims = node.input(0).get_output_layout().size.sizes().size();
+    const size_t updates_number_of_dims = node.input(2).get_output_layout().size.sizes().size();
+    const size_t nonempty_indices_dims = GetNonEmptyDimsNumber(node.input(1).get_output_layout());
+
+    auto input_layout = node.input(0).get_output_layout();
+    
+    auto output_shape = input_layout.size;
+    auto input_format = input_layout.format;
+    auto output_type = input_layout.data_type;
+
+    if (node.has_fused_primitives()) {
+        output_type = node.get_fused_output_layout().data_type;
+    }
+
+    if (indices_size > static_cast<size_t>(output_shape.sizes()[axis])) {
+        CLDNN_ERROR_MESSAGE(node.id(),
+            "Undefined behavior ScatterUpdate: indices size must not be larger than the output size along the Axis.");
+    }
+    
+    if (nonempty_indices_dims + static_cast<size_t>(axis) > updates_number_of_dims) {
+        CLDNN_ERROR_MESSAGE(node.id(),
+            "Undefined behavior ScatterUpdate: indices dimention must not be larger than the updates[:Axis] dimentional size.");
+    }
+    
+    if (static_cast<size_t>(axis) < 0 || static_cast<size_t>(axis) >= input_number_of_dims)
+        CLDNN_ERROR_MESSAGE(node.id(), "Incorrect axis value for ScatterUpdate: Axis must be positive and less than the input tensor dimension.");
+
+    return layout{output_type, input_format, output_shape};
+}
+
+std::string scatter_update_inst::to_string(scatter_update_node const& node) {
+    auto desc = node.get_primitive();
+    auto node_info = node.desc_to_json();
+    auto& input = node.input();
+
+    std::stringstream primitive_description;
+
+    json_composite scatter_update_info;
+    scatter_update_info.add("input id", input.id());
+    scatter_update_info.add("axis", desc->axis);
+    scatter_update_info.add("output shape", node.input(0).get_output_layout().size.to_string());
+
+    node_info->add("scatter_update info", scatter_update_info);
+    node_info->dump(primitive_description);
+
+    return primitive_description.str();
+}
+
// No state beyond the generic primitive instance is required.
scatter_update_inst::typed_primitive_inst(network_impl& network, scatter_update_node const& node) : parent(network, node) {}
+
+}  // namespace cldnn
index 1a62fab..1cb500e 100644 (file)
@@ -34,6 +34,7 @@
 #include "api/deconvolution.hpp"
 #include "api/permute.hpp"
 #include "api/gather.hpp"
+#include "api/scatter_update.hpp"
 #include "api/depth_to_space.hpp"
 #include "api/space_to_depth.hpp"
 #include "api/batch_to_space.hpp"
@@ -236,6 +237,27 @@ public:
         return prim;
     }
 
+    cldnn::memory get_repeatless_mem(cldnn::layout l, int min, int max) {
+        auto prim = memory::allocate(engine, l);
+        tensor s = l.size;
+        if (l.data_type == data_types::f32) {
+            VF<float> rnd_vec = generate_random_norepetitions_1d<float>(s.count(), min, max);
+            set_values(prim, rnd_vec);
+        } else if (l.data_type == data_types::f16) {
+            VF<FLOAT16> rnd_vec = generate_random_norepetitions_1d<FLOAT16>(s.count(), min, max);
+            set_values(prim, rnd_vec);
+        } else if (l.data_type == data_types::i8) {
+            VF<int8_t> rnd_vec = generate_random_norepetitions_1d<int8_t>(s.count(), min, max);
+            set_values(prim, rnd_vec);
+        }
+        else if (l.data_type == data_types::bin) {
+            VF<int32_t> rnd_vec = generate_random_norepetitions_1d<int32_t>(s.count(), min, max);
+            set_values(prim, rnd_vec);
+        }
+
+        return prim;
+    }
+
     cldnn::memory get_mem(cldnn::layout l, int min, int max) {
         auto prim = memory::allocate(engine, l);
         tensor s = l.size;
@@ -5209,6 +5231,184 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_scale_activation,
                         gather_test_params{ CASE_GATHER_5D_FP16_5, 2, 4 },
 }), );
 
+/* ----------------------------------------------------------------------------------------------------- */
+/* ------------------------------------------ ScatterUpdate cases --------------------------------------------- */
+/* ----------------------------------------------------------------------------------------------------- */
+
// Parameters of one fused-scatter_update test case.
struct scatter_update_test_params {
    tensor dictionary_shape;  // shape of the data (dictionary) input
    tensor indices_shape;     // shape of the indices input
    tensor updates_shape;     // shape of the updates input
    cldnn::scatter_update::scatter_update_axis axis;  // axis the indices address
    data_types data_type;     // data type of the scatter_update inputs
    format input_format;      // layout of the scatter_update inputs
    data_types default_type;  // data type of the auxiliary (scale/quantize) inputs
    format default_format;    // layout of the auxiliary inputs and final reorder
    size_t expected_fused_primitives;      // primitive count when fusing succeeds
    size_t expected_not_fused_primitives;  // primitive count without fusing
};
+
+#define CASE_SCATTER_UPDATE_FP32_1 {2, 4, 1, 1}, {2, 1, 1, 1}, {2, 4, 1, 1}, cldnn::scatter_update::scatter_update_axis::along_b, data_types::f32, format::bfyx, data_types::f32, format::bfyx
+#define CASE_SCATTER_UPDATE_FP32_2 {8, 1, 1, 1}, {4, 1, 1, 1}, {4, 1, 1, 1}, cldnn::scatter_update::scatter_update_axis::along_b, data_types::f32, format::bfyx, data_types::f32, format::bfyx
+#define CASE_SCATTER_UPDATE_FP32_3 {4, 3, 1, 1}, {2, 2, 1, 1}, {2, 2, 1, 3}, cldnn::scatter_update::scatter_update_axis::along_b, data_types::f32, format::bfyx, data_types::f32, format::bfyx
+#define CASE_SCATTER_UPDATE_FP32_4 {2, 5, 1, 2}, {2, 2, 1, 1}, {2, 2, 2, 2}, cldnn::scatter_update::scatter_update_axis::along_f, data_types::f32, format::bfyx, data_types::f32, format::bfyx
+#define CASE_SCATTER_UPDATE_FP32_5 {2, 2, 1, 4}, {2, 2, 1, 1}, {2, 2, 2, 2}, cldnn::scatter_update::scatter_update_axis::along_y, data_types::f32, format::bfyx, data_types::f32, format::bfyx
+
+#define CASE_SCATTER_UPDATE_FP16_1 {2, 4, 1, 1}, {1, 1, 1, 2}, {2, 1, 2, 1}, cldnn::scatter_update::scatter_update_axis::along_f, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+#define CASE_SCATTER_UPDATE_FP16_2 {8, 2, 1, 20}, {2, 3, 1, 1}, {2, 3, 20, 2}, cldnn::scatter_update::scatter_update_axis::along_b, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+#define CASE_SCATTER_UPDATE_FP16_3 {2, 2, 4, 1}, {3, 1, 1, 1}, {2, 2, 3, 1}, cldnn::scatter_update::scatter_update_axis::along_x, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+#define CASE_SCATTER_UPDATE_FP16_4 {6, 2, 1, 1}, {1, 2, 1, 2}, {1, 2, 2, 2}, cldnn::scatter_update::scatter_update_axis::along_b, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+#define CASE_SCATTER_UPDATE_FP16_5 {3, 1, 1, 5}, {2, 2, 1, 1}, {3, 1, 2, 2}, cldnn::scatter_update::scatter_update_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+
+#define CASE_SCATTER_UPDATE_5D_FP32_1 {4, 3, 1, 4, 1}, {4, 1, 1, 1}, {4, 3, 1, 4, 1}, cldnn::scatter_update::scatter_update_axis::along_b, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_SCATTER_UPDATE_5D_FP32_2 {2, 3, 2, 2, 2}, {2, 1, 1, 1}, {2, 2, 2, 2, 2}, cldnn::scatter_update::scatter_update_axis::along_f, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_SCATTER_UPDATE_5D_FP32_3 {5, 3, 2, 4, 2}, {3, 1, 1, 1}, {5, 3, 2, 3, 2}, cldnn::scatter_update::scatter_update_axis::along_y, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_SCATTER_UPDATE_5D_FP32_4 {2, 3, 1, 4, 4}, {2, 1, 1, 1}, {2, 3, 1, 4, 2}, cldnn::scatter_update::scatter_update_axis::along_z, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_SCATTER_UPDATE_5D_FP32_5 {3, 1, 5, 2, 1}, {2, 1, 1, 1}, {3, 1, 2, 2, 1}, cldnn::scatter_update::scatter_update_axis::along_x, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+
+#define CASE_SCATTER_UPDATE_5D_FP16_1 {3, 2, 1, 2, 1}, {2, 1, 1, 1}, {2, 2, 2, 2, 1}, cldnn::scatter_update::scatter_update_axis::along_b, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_SCATTER_UPDATE_5D_FP16_2 {1, 3, 1, 2, 1}, {2, 1, 1, 1}, {1, 2, 1, 2, 1}, cldnn::scatter_update::scatter_update_axis::along_f, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_SCATTER_UPDATE_5D_FP16_3 {2, 3, 1, 3, 3}, {1, 2, 1, 1}, {2, 3, 1, 2, 3}, cldnn::scatter_update::scatter_update_axis::along_y, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_SCATTER_UPDATE_5D_FP16_4 {3, 2, 2, 2, 2}, {2, 1, 1, 1}, {3, 2, 2, 2, 2}, cldnn::scatter_update::scatter_update_axis::along_z, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_SCATTER_UPDATE_5D_FP16_5 {1, 1, 4, 1, 1}, {3, 1, 1, 1}, {1, 1, 3, 1, 1}, cldnn::scatter_update::scatter_update_axis::along_x, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+
// Fixture for scatter_update fusing tests: runs the fused and non-fused
// topologies (built by the individual TEST_P bodies via create_topologies)
// on the same random input and compares their outputs.
class ScatterUpdatePrimitiveFusingTest : public ::BaseFusingTest<scatter_update_test_params> {
public:
    // Executes both networks on identical input data and compares results.
    void execute(scatter_update_test_params& p) {
        auto input_prim = get_mem(get_input_layout(p));
        network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
        network network_fused(this->engine, this->topology_fused, bo_fused);
        network_fused.set_input_data("input", input_prim);
        network_not_fused.set_input_data("input", input_prim);
        compare(network_not_fused, network_fused, p);
    }

    // Layout of the data (dictionary) input.
    layout get_input_layout(scatter_update_test_params& p) {
        return layout{ p.data_type, p.input_format, p.dictionary_shape };
    }

    // Layout of the indices input (always bfyx).
    layout get_indices_layout(scatter_update_test_params& p) {
        return layout{ p.data_type, format::bfyx, p.indices_shape };
    }

    // Layout of the updates input.
    layout get_updates_layout(scatter_update_test_params& p) {
        return layout{ p.data_type, p.input_format, p.updates_shape };
    }

    // Size of the dictionary along the scatter axis; used as an exclusive
    // upper bound for the randomly generated indices.
    size_t get_axis_dim(scatter_update_test_params& p) {
        switch (p.axis) {
            case cldnn::scatter_update::scatter_update_axis::along_x:
                return p.dictionary_shape.spatial[0];
            case cldnn::scatter_update::scatter_update_axis::along_y:
                return p.dictionary_shape.spatial[1];
            case cldnn::scatter_update::scatter_update_axis::along_z:
                return p.dictionary_shape.spatial[2];
            case cldnn::scatter_update::scatter_update_axis::along_w:
                return p.dictionary_shape.spatial[3];
            case cldnn::scatter_update::scatter_update_axis::along_f:
                return p.dictionary_shape.feature[0];
            case cldnn::scatter_update::scatter_update_axis::along_b:
                return p.dictionary_shape.batch[0];
            default:
                return 1;
        }
    }

    // 1 x feature x 1 x 1 layout for per-channel scale/quantize data.
    layout get_per_channel_layout(scatter_update_test_params& p) {
        return layout{ p.default_type, p.default_format, tensor{1, p.dictionary_shape.feature[0], 1, 1} };
    }
};
+
// scatter_update -> quantize fusing: the quantize primitive is expected to be
// fused into scatter_update (see expected_fused_primitives vs
// expected_not_fused_primitives in the parameter tables).
class scatter_update_quantize : public ScatterUpdatePrimitiveFusingTest {};
TEST_P(scatter_update_quantize, basic) {
    auto p = GetParam();
    create_topologies(input_layout("input", get_input_layout(p)),
        // Indices must be unique and within [0, axis_dim - 1] for a well-defined result.
        data("scatter_update_indices", get_repeatless_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p)) - 1)),
        data("scatter_update_updates", get_mem(get_updates_layout(p), 0, 1000)),
        data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
        data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
        data("out_lo", get_mem(get_single_element_layout(p), -127)),
        data("out_hi", get_mem(get_single_element_layout(p), 127)),
        scatter_update("scatter_update_prim", "input", "scatter_update_indices", "scatter_update_updates", p.axis),
        quantize("quantize", "scatter_update_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
        reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
    );
    // Quantized output: allow off-by-one level differences.
    tolerance = 1.f;
    execute(p);
}
+
+INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_update_quantize,
+    ::testing::ValuesIn(std::vector<scatter_update_test_params>{
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_1, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_2, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_3, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_4, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_5, 2, 3 },
+
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_1, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_2, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_3, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_4, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_5, 2, 3 },
+
+                        
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_1, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_2, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_3, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_4, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_5, 2, 3 },
+
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_1, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_2, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_3, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_4, 2, 3 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_5, 2, 3 },
+}), );
+
// scatter_update -> activation -> scale fusing: both post-ops are expected to
// be fused into scatter_update (see the expected primitive counts in the
// parameter tables).
class scatter_update_scale_activation : public ScatterUpdatePrimitiveFusingTest {};
TEST_P(scatter_update_scale_activation, basic) {
    auto p = GetParam();
    create_topologies(input_layout("input", get_input_layout(p)),
        // Indices must be unique and within [0, axis_dim - 1] for a well-defined result.
        data("scatter_update_indices", get_repeatless_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p)) - 1)),
        data("scatter_update_updates", get_mem(get_updates_layout(p), 0, 1000)),
        data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
        scatter_update("scatter_update_prim", "input", "scatter_update_indices", "scatter_update_updates", p.axis),
        activation("activation", "scatter_update_prim", activation_func::abs),
        scale("scale", "activation", "scale_data"),
        reorder("reorder_bfyx", "scale", p.default_format, data_types::f32)
    );
    tolerance = 1e-5f;
    execute(p);
}
+
+INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_update_scale_activation,
+    ::testing::ValuesIn(std::vector<scatter_update_test_params>{
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_1, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_2, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_3, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_4, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP32_5, 2, 4 },
+
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_1, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_2, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_3, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_4, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_FP16_5, 2, 4 },
+
+                        
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_1, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_2, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_3, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_4, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP32_5, 2, 4 },
+
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_1, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_2, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_3, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_4, 2, 4 },
+                        scatter_update_test_params{ CASE_SCATTER_UPDATE_5D_FP16_5, 2, 4 },
+}), );
+
 /* ------------------------------------------------------------------------------------------------------------ */
 /* ---------------------------------------- PERMUTE FUSE cases -------------------------------------------------- */
 /* ------------------------------------------------------------------------------------------------------------ */
diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/scatter_update_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/scatter_update_gpu_test.cpp
new file mode 100644 (file)
index 0000000..c5132ec
--- /dev/null
@@ -0,0 +1,1389 @@
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+#include <gtest/gtest.h>
+
+#include <api/input_layout.hpp>
+#include <api/memory.hpp>
+#include <api/scatter_update.hpp>
+#include <api/topology.hpp>
+#include <api/network.hpp>
+
+#include <cstddef>
+#include <tests/test_utils/test_utils.h>
+
+using namespace cldnn;
+using namespace ::tests;
+
+
+TEST(scatter_update_gpu_fp16, d2411_axisB) {
+    //  Dictionary : 2x4x1x1
+    //  Indexes : 2x1x1x1
+    //  Updates : 2x4x1x1
+    //  Axis : 0
+    //  Output : 2x4x1x1
+    //  Input values in fp16
+
+    //  Indexes:
+    //  1.f, 0.f
+    //
+    //  Updates:
+    //  1.f, 7.f, 2.f, 9.f,
+    //  3.f, 6.f, 5.f, 4.f
+    //
+    //  Dictionary:
+    //  0.f, 0.f, 0.f, 0.f,
+    //  0.f, 0.f, 0.f, 0.f
+    //
+    //  Output:
+    //  3.f, 6.f, 5.f, 4.f, 
+    //  1.f, 7.f, 2.f, 9.f
+
+    engine engine;
+
+    // NOTE(review): tensor sizes appear to be given as { b, f, x, y }; indexes are supplied as f32
+    // even though the data is fp16 — confirm against the scatter_update primitive contract.
+    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 4, 1, 1 } }); // Dictionary
+    auto input2 = memory::allocate(engine, { data_types::f32, format::bfyx, { 2, 1, 1, 1 } }); // Indexes
+    auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 4, 1, 1 } }); // Updates
+    auto axis = cldnn::scatter_update::scatter_update_axis::along_b;
+
+    set_values(input1, {
+        FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),
+        FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f)
+    });
+
+    set_values(input2, {
+        1.f, 0.f
+    });
+
+    set_values(input3, {
+        FLOAT16(1.0f), FLOAT16(7.0f), FLOAT16(2.0f), FLOAT16(9.0f),
+        FLOAT16(3.0f), FLOAT16(6.0f), FLOAT16(5.0f), FLOAT16(4.0f)
+    });
+
+    topology topology;
+    topology.add(input_layout("InputDictionary", input1.get_layout()));
+    topology.add(input_layout("InputText", input2.get_layout()));
+    topology.add(input_layout("InputUpdates", input3.get_layout()));
+    topology.add(
+        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
+    );
+    
+    network network(engine, topology); 
+    
+    
+    network.set_input_data("InputDictionary", input1);
+    network.set_input_data("InputText", input2);
+    network.set_input_data("InputUpdates", input3);
+    
+    auto outputs = network.execute();
+    
+
+    auto output = outputs.at("scatter_update").get_memory();
+    auto output_ptr = output.pointer<uint16_t>();
+
+    std::vector<float> expected_results = {
+        3.f, 6.f, 5.f, 4.f, 
+        1.f, 7.f, 2.f, 9.f
+    };
+
+    // fp16 output bits are converted to fp32 before exact comparison.
+    for (size_t i = 0; i < expected_results.size(); ++i) {
+        EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i]));
+    }    
+}
+
+TEST(scatter_update_gpu_fp32, d8111_axisB) {
+    //  Dictionary : 8x1x1x1
+    //  Indexes : 4x1x1x1
+    //  Updates : 4x1x1x1
+    //  Axis : 0
+    //  Output : 8x1x1x1
+    //  Input values in fp32
+
+    //  Indexes:
+    //  4.f, 3.f, 1.f, 7.f
+    //
+    //  Updates:
+    //  9.f, 10.f, 11.f, 12.f
+    //
+    //  Dictionary:
+    //  1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f
+    //
+    //  Output:
+    //  1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f
+
+
+    engine engine;
+
+    // Only batch positions 4, 3, 1, 7 are overwritten; the rest keep dictionary values.
+    auto input1 = memory::allocate(engine, { data_types::f32, format::bfyx, { 8, 1, 1, 1 } }); // Dictionary
+    auto input2 = memory::allocate(engine, { data_types::f32, format::bfyx, { 4, 1, 1, 1 } }); // Indexes
+    auto input3 = memory::allocate(engine, { data_types::f32, format::bfyx, { 4, 1, 1, 1 } }); // Updates
+    auto axis = cldnn::scatter_update::scatter_update_axis::along_b;
+
+    set_values(input1, {
+        1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f
+    });
+
+    set_values(input2, {
+        4.f, 3.f, 1.f, 7.f
+    });
+
+    set_values(input3, {
+        9.0f, 10.0f, 11.0f, 12.0f
+    });
+
+    topology topology;
+    topology.add(input_layout("InputDictionary", input1.get_layout()));
+    topology.add(input_layout("InputText", input2.get_layout()));
+    topology.add(input_layout("InputUpdates", input3.get_layout()));
+    topology.add(
+        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
+    );
+    
+    network network(engine, topology); 
+    
+    
+    network.set_input_data("InputDictionary", input1);
+    network.set_input_data("InputText", input2);
+    network.set_input_data("InputUpdates", input3);
+    
+    auto outputs = network.execute();    
+
+    auto output = outputs.at("scatter_update").get_memory();
+    auto output_ptr = output.pointer<float>();
+
+    std::vector<float> expected_results = {
+        1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f
+    };
+    
+    for (size_t i = 0; i < expected_results.size(); ++i) {
+        EXPECT_EQ(expected_results[i], output_ptr[i]);
+    }    
+}
+
+TEST(scatter_update_gpu_fp16, d4311_axisB) {
+    //  Dictionary : 4x3x1x1
+    //  Indexes : 2x2x1x1
+    //  Updates : 2x2x3x1
+    //  Axis : 0
+    //  Output : 4x3x1x1
+    //  Input values in fp16
+
+    //  Indexes:
+    //  3.f, 1.f,
+    //  2.f, 0.f
+    //
+    //  Updates:
+    //  7.f, 7.f, 7.f,
+    //  8.f, 8.f, 8.f,
+    //
+    //  6.f, 6.f, 6.f,
+    //  9.f, 10.f, 11.f
+    //
+    //  Dictionary:
+    //  1.f, 1.f, 1.f,
+    //  2.f, 2.f, 2.f,
+    //  0.f, 0.f, 0.f,
+    //  3.f, 3.f, 3.f
+    //
+    //  Output:
+    //  9.f, 10.f, 11.f,
+    //  8.f, 8.f, 8.f,
+    //  6.f, 6.f, 6.f,
+    //  7.f, 7.f, 7.f
+
+    engine engine;
+
+    // 2D index tensor: every batch slice of the dictionary is overwritten (indexes cover 0..3).
+    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 4, 3, 1, 1 } }); // Dictionary
+    auto input2 = memory::allocate(engine, { data_types::f32, format::bfyx, { 2, 2, 1, 1 } }); // Indexes
+    auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 3 } }); // Updates
+    auto axis = cldnn::scatter_update::scatter_update_axis::along_b;
+
+    set_values(input1, {
+        FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),
+        FLOAT16(2.0f), FLOAT16(2.0f), FLOAT16(2.0f),
+        FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),
+        FLOAT16(3.0f), FLOAT16(3.0f), FLOAT16(3.0f)
+    });
+
+    set_values(input2, {
+        3.f, 1.f,
+        2.f, 0.f
+    });
+
+    set_values(input3, {
+        FLOAT16(7.0f), FLOAT16(7.0f), FLOAT16(7.0f),
+        FLOAT16(8.0f), FLOAT16(8.0f), FLOAT16(8.0f),
+
+        FLOAT16(6.0f), FLOAT16(6.0f), FLOAT16(6.0f),
+        FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f)
+    });
+
+    topology topology;
+    topology.add(input_layout("InputDictionary", input1.get_layout()));
+    topology.add(input_layout("InputText", input2.get_layout()));
+    topology.add(input_layout("InputUpdates", input3.get_layout()));
+    topology.add(
+        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
+    );
+    
+    network network(engine, topology);     
+    
+    network.set_input_data("InputDictionary", input1);
+    network.set_input_data("InputText", input2);
+    network.set_input_data("InputUpdates", input3);
+    
+    auto outputs = network.execute();   
+
+    auto output = outputs.at("scatter_update").get_memory();
+    auto output_ptr = output.pointer<uint16_t>();
+
+    std::vector<float> expected_results = {
+        9.f, 10.f, 11.f, 
+        8.f, 8.f, 8.f,
+        6.f, 6.f, 6.f, 
+        7.f, 7.f, 7.f
+    };
+    
+    for (size_t i = 0; i < expected_results.size(); ++i) {
+        EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i]));
+    } 
+}
+
+TEST(scatter_update_gpu_fp16, d2521_axisF) {
+    //  Dictionary : 2x5x2x1
+    //  Indexes : 2x2x1x1
+    //  Updates : 2x2x2x2
+    //  Axis : 1
+    //  Output : 2x5x2x1
+    //  Input values in fp16
+
+    //  Indexes:
+    //  0.f, 2.f,
+    //  4.f, 1.f
+    //
+    //  Updates:
+    //  21.f, 31.f,
+    //  41.f, 51.f,
+    //
+    //  61.f, 71.f,
+    //  81.f, 91.f,
+    //
+    //  101.f, 111.f,
+    //  121.f, 131.f,
+    //
+    //  141.f, 151.f,
+    //  161.f, 171.f
+    //
+    //  Dictionary:
+    //  0.f, 1.f,
+    //  2.f, 3.f,
+    //  4.f, 5.f,
+    //  6.f, 7.f,
+    //  8.f, 9.f,
+    //
+    //  10.f, 11.f,
+    //  12.f, 13.f,
+    //  14.f, 15.f,
+    //  16.f, 17.f,
+    //  18.f, 19.f
+    //
+    //  Output:
+    //  21.f, 31.f,
+    //  81.f, 91.f,
+    //  41.f, 51.f,
+    //  6.f, 7.f,
+    //  61.f, 71.f,
+    //
+    //  101.f, 111.f,
+    //  161.f, 171.f,
+    //  121.f, 131.f,
+    //  16.f, 17.f,
+    //  141.f, 151.f
+
+    engine engine;
+
+    // Scatter along the feature axis: features 0, 2, 4, 1 are replaced, feature 3 is untouched.
+    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 5, 1, 2 } }); // Dictionary
+    auto input2 = memory::allocate(engine, { data_types::f32, format::bfyx, { 2, 2, 1, 1 } }); // Indexes
+    auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 2, 2 } }); // Updates
+    auto axis = cldnn::scatter_update::scatter_update_axis::along_f;
+
+    set_values(input1, {
+        FLOAT16(0.0f), FLOAT16(1.0f), 
+        FLOAT16(2.0f), FLOAT16(3.0f), 
+        FLOAT16(4.0f), FLOAT16(5.0f),
+        FLOAT16(6.0f), FLOAT16(7.0f), 
+        FLOAT16(8.0f), FLOAT16(9.0f),
+
+        FLOAT16(10.0f), FLOAT16(11.0f),
+        FLOAT16(12.0f), FLOAT16(13.0f),
+        FLOAT16(14.0f), FLOAT16(15.0f), 
+        FLOAT16(16.0f), FLOAT16(17.0f), 
+        FLOAT16(18.0f), FLOAT16(19.0f)
+    });
+
+    set_values(input2, {
+        0.f, 2.f,
+        4.f, 1.f
+    });
+
+    set_values(input3, {
+        FLOAT16(21.0f), FLOAT16(31.0f), 
+        FLOAT16(41.0f), FLOAT16(51.0f),
+        FLOAT16(61.0f), FLOAT16(71.0f), 
+        FLOAT16(81.0f), FLOAT16(91.0f),
+
+        FLOAT16(101.0f), FLOAT16(111.0f),
+        FLOAT16(121.0f), FLOAT16(131.0f),
+        FLOAT16(141.0f), FLOAT16(151.0f), 
+        FLOAT16(161.0f), FLOAT16(171.0f)
+    });
+
+    topology topology;
+    topology.add(input_layout("InputDictionary", input1.get_layout()));
+    topology.add(input_layout("InputText", input2.get_layout()));
+    topology.add(input_layout("InputUpdates", input3.get_layout()));
+    topology.add(
+        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
+    );
+    
+    network network(engine, topology); 
+      
+    network.set_input_data("InputDictionary", input1);
+    network.set_input_data("InputText", input2);
+    network.set_input_data("InputUpdates", input3);
+    
+    auto outputs = network.execute();
+
+    auto output = outputs.at("scatter_update").get_memory();
+    auto output_ptr = output.pointer<uint16_t>();
+
+    std::vector<float> expected_results = {
+        21.f, 31.f,
+        81.f, 91.f,
+        41.f, 51.f,
+        6.f, 7.f,
+        61.f, 71.f,
+
+        101.f, 111.f,
+        161.f, 171.f,
+        121.f, 131.f,
+        16.f, 17.f,
+        141.f, 151.f
+    };
+  
+    for (size_t i = 0; i < expected_results.size(); ++i) {
+        EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i]));
+    }
+}
+
+TEST(scatter_update_gpu_fp16, d2241_axisY) {
+    //  Dictionary : 2x2x4x1
+    //  Indexes : 2x2x1x1
+    //  Updates : 2x2x2x2
+    //  Axis : 2
+    //  Output : 2x2x4x1
+    //  Input values in fp16
+
+    //  Indexes:
+    //  0.f, 2.f,
+    //  3.f, 1.f
+    //
+    //  Updates:
+    //  0.f, 20.f,
+    //  30.f, 40.f,
+    //
+    //  50.f, 60.f,
+    //  70.f, 80.f,
+    //
+    //  90.f, 100.f,
+    //  110.f, 120.f,
+    //
+    //  130.f, 140.f,
+    //  150.f, 160.f
+    //
+    //  Dictionary:
+    //  1.f, 2.f, 3.f, 4.f,
+    //  5.f, 6.f, 7.f, 8.f,
+    //  9.f, 10.f, 11.f, 12.f,
+    //  13.f, 14.f, 15.f, 16.f
+    //
+    //  Output:
+    //  0.f, 40.f, 20.f, 30.f,
+    //  50.f, 80.f, 60.f, 70.f,
+    //  90.f, 120.f, 100.f, 110.f,
+    //  130.f, 160.f, 140.f, 150.f
+
+    engine engine;
+
+    // Indexes 0, 2, 3, 1 cover every y position, so all dictionary values are overwritten.
+    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 4 } }); // Dictionary
+    auto input2 = memory::allocate(engine, { data_types::f32, format::bfyx, { 2, 2, 1, 1 } }); // Indexes
+    auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 2, 2 } }); // Updates
+    auto axis = cldnn::scatter_update::scatter_update_axis::along_y;
+
+    set_values(input1, {
+        FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f),
+        FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
+        FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f),
+        FLOAT16(13.0f), FLOAT16(14.0f), FLOAT16(15.0f), FLOAT16(16.0f)
+    });
+
+    set_values(input2, {
+        0.f, 2.f,
+        3.f, 1.f
+    });
+
+    set_values(input3, {
+        FLOAT16(0.0f), FLOAT16(20.0f), 
+        FLOAT16(30.0f), FLOAT16(40.0f),
+        FLOAT16(50.0f), FLOAT16(60.0f), 
+        FLOAT16(70.0f), FLOAT16(80.0f),
+
+        FLOAT16(90.0f), FLOAT16(100.0f),
+        FLOAT16(110.0f), FLOAT16(120.0f),
+        FLOAT16(130.0f), FLOAT16(140.0f), 
+        FLOAT16(150.0f), FLOAT16(160.0f)
+    });
+
+    topology topology;
+    topology.add(input_layout("InputDictionary", input1.get_layout()));
+    topology.add(input_layout("InputText", input2.get_layout()));
+    topology.add(input_layout("InputUpdates", input3.get_layout()));
+    topology.add(
+        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
+    );
+    
+    network network(engine, topology); 
+    
+    network.set_input_data("InputDictionary", input1);
+    network.set_input_data("InputText", input2);
+    network.set_input_data("InputUpdates", input3);
+    
+    auto outputs = network.execute();
+
+    auto output = outputs.at("scatter_update").get_memory();
+    auto output_ptr = output.pointer<uint16_t>();
+
+    std::vector<float> expected_results = {
+        0.f, 40.f, 20.f, 30.f,
+        50.f, 80.f, 60.f, 70.f,
+        90.f, 120.f, 100.f, 110.f,
+        130.f, 160.f, 140.f, 150.f
+    };
+
+    
+    for (size_t i = 0; i < expected_results.size(); ++i) {
+        EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i]));
+    }
+}
+
+TEST(scatter_update_gpu_fp16, d8x2x20x1_axisB) {
+    //  Dictionary : 8x2x20x1
+    //  Indexes : 2x3x1x1
+    //  Updates : 2x3x2x20
+    //  Axis : 0
+    //  Output : 8x2x20x1
+    //  Input values in fp16
+
+    engine engine;
+
+    // Batches 3, 1, 6, 2, 7, 4 are replaced by the six 2x20 update slices; batches 0 and 5
+    // keep their original 0/1 fill pattern.
+    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 8, 2, 1, 20 } }); // Dictionary
+    auto input2 = memory::allocate(engine, { data_types::f32, format::bfyx, { 2, 3, 1, 1 } });  // Indexes
+    auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 20, 2 } }); // Updates
+    auto axis = cldnn::scatter_update::scatter_update_axis::along_b;
+
+    set_values(input1, {
+        FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),
+        FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),
+
+        FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),
+        FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),
+
+        FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),
+        FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),
+
+        FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),
+        FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),
+
+        FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),
+        FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),
+
+        FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),
+        FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f), 
+        
+        FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),
+        FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),
+
+        FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),FLOAT16(0.0f), FLOAT16(0.0f),
+        FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f),FLOAT16(1.0f), FLOAT16(1.0f)
+    });
+
+    set_values(input2, {
+        3.f, 1.f, 6.f,
+        2.f, 7.f, 4.f
+    });
+
+    set_values(input3, {
+        FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(5), FLOAT16(6), FLOAT16(7), FLOAT16(8), FLOAT16(9), FLOAT16(10), FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18), FLOAT16(19), 
+        FLOAT16(20), FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(30), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(37), FLOAT16(38), FLOAT16(39),
+        
+        FLOAT16(40), FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46), FLOAT16(47), FLOAT16(48), FLOAT16(49), FLOAT16(50), FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56), FLOAT16(57), FLOAT16(58), FLOAT16(59),
+        FLOAT16(60), FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66), FLOAT16(67), FLOAT16(68), FLOAT16(69), FLOAT16(70), FLOAT16(71), FLOAT16(72), FLOAT16(73), FLOAT16(74), FLOAT16(75), FLOAT16(76), FLOAT16(77), FLOAT16(78), FLOAT16(79),
+        
+        FLOAT16(80), FLOAT16(81), FLOAT16(82), FLOAT16(83), FLOAT16(84), FLOAT16(85), FLOAT16(86), FLOAT16(87), FLOAT16(88), FLOAT16(89), FLOAT16(90), FLOAT16(91), FLOAT16(92), FLOAT16(93), FLOAT16(94), FLOAT16(95), FLOAT16(96), FLOAT16(97), FLOAT16(98), FLOAT16(99),
+        FLOAT16(100), FLOAT16(101), FLOAT16(102), FLOAT16(103), FLOAT16(104), FLOAT16(105), FLOAT16(106), FLOAT16(107), FLOAT16(108), FLOAT16(109), FLOAT16(110), FLOAT16(111), FLOAT16(112), FLOAT16(113), FLOAT16(114), FLOAT16(115), FLOAT16(116), FLOAT16(117), FLOAT16(118), FLOAT16(119),
+        
+        FLOAT16(120), FLOAT16(121), FLOAT16(122), FLOAT16(123), FLOAT16(124), FLOAT16(125), FLOAT16(126), FLOAT16(127), FLOAT16(128), FLOAT16(129), FLOAT16(130), FLOAT16(131), FLOAT16(132), FLOAT16(133), FLOAT16(134), FLOAT16(135), FLOAT16(136), FLOAT16(137), FLOAT16(138), FLOAT16(139),
+        FLOAT16(140), FLOAT16(141), FLOAT16(142), FLOAT16(143), FLOAT16(144), FLOAT16(145), FLOAT16(146), FLOAT16(147), FLOAT16(148), FLOAT16(149), FLOAT16(150), FLOAT16(151), FLOAT16(152), FLOAT16(153), FLOAT16(154), FLOAT16(155), FLOAT16(156), FLOAT16(157), FLOAT16(158), FLOAT16(159),
+        FLOAT16(160), FLOAT16(161), FLOAT16(162), FLOAT16(163), FLOAT16(164), FLOAT16(165), FLOAT16(166), FLOAT16(167), FLOAT16(168), FLOAT16(169), FLOAT16(170), FLOAT16(171), FLOAT16(172), FLOAT16(173), FLOAT16(174), FLOAT16(175), FLOAT16(176), FLOAT16(177), FLOAT16(178), FLOAT16(179),
+        FLOAT16(180), FLOAT16(181), FLOAT16(182), FLOAT16(183), FLOAT16(184), FLOAT16(185), FLOAT16(186), FLOAT16(187), FLOAT16(188), FLOAT16(189), FLOAT16(190), FLOAT16(191), FLOAT16(192), FLOAT16(193), FLOAT16(194), FLOAT16(195), FLOAT16(196), FLOAT16(197), FLOAT16(198), FLOAT16(199), 
+
+        FLOAT16(200), FLOAT16(201), FLOAT16(202), FLOAT16(203), FLOAT16(204), FLOAT16(205), FLOAT16(206), FLOAT16(207), FLOAT16(208), FLOAT16(209), FLOAT16(210), FLOAT16(211), FLOAT16(212), FLOAT16(213), FLOAT16(214), FLOAT16(215), FLOAT16(216), FLOAT16(217), FLOAT16(218), FLOAT16(219),
+        FLOAT16(220), FLOAT16(221), FLOAT16(222), FLOAT16(223), FLOAT16(224), FLOAT16(225), FLOAT16(226), FLOAT16(227), FLOAT16(228), FLOAT16(229), FLOAT16(230), FLOAT16(231), FLOAT16(232), FLOAT16(233), FLOAT16(234), FLOAT16(235), FLOAT16(236), FLOAT16(237), FLOAT16(238), FLOAT16(239)
+    });
+
+    topology topology;
+    topology.add(input_layout("InputDictionary", input1.get_layout()));
+    topology.add(input_layout("InputText", input2.get_layout()));
+    topology.add(input_layout("InputUpdates", input3.get_layout()));
+    topology.add(
+        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
+    );
+    
+    network network(engine, topology);
+    
+    network.set_input_data("InputDictionary", input1);
+    network.set_input_data("InputText", input2);
+    network.set_input_data("InputUpdates", input3);
+    
+    auto outputs = network.execute();
+    
+    auto output = outputs.at("scatter_update").get_memory();
+    auto output_ptr = output.pointer<uint16_t>();
+
+    std::vector<float> expected_results = {
+        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
+        1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f , 1.f, 1.f, 1.f, 1.f, 1.f,
+
+        40.f, 41.f,  42.f,  43.f,  44.f,  45.f,  46.f,  47.f,  48.f,  49.f,  50.f,  51.f,  52.f,  53.f,  54.f,  55.f,  56.f, 57.f,  58.f,  59.f,
+        60.f,  61.f,  62.f,  63.f,  64.f,  65.f,  66.f, 67.f,  68.f,  69.f, 70.f,  71.f,  72.f,  73.f,  74.f,  75.f,  76.f,  77.f,  78.f,  79.f,
+        
+        120.f, 121.f, 122.f, 123.f, 124.f, 125.f, 126.f, 127.f, 128.f, 129.f,130.f, 131.f, 132.f, 133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 139.f,
+        140.f, 141.f, 142.f, 143.f, 144.f, 145.f, 146.f, 147.f, 148.f, 149.f,150.f, 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 157.f, 158.f, 159.f,
+
+        0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f,
+        20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f , 35.f, 36.f, 37.f, 38.f, 39.f,
+        
+        200.f, 201.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 208.f, 209.f,210.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 217.f, 218.f, 219.f, 
+        220.f, 221.f, 222.f, 223.f, 224.f, 225.f, 226.f,227.f, 228.f, 229.f, 230.f, 231.f, 232.f, 233.f, 234.f, 235.f, 236.f, 237.f, 238.f, 239.f,
+
+        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
+        1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f , 1.f, 1.f, 1.f, 1.f, 1.f,
+
+        80.f,  81.f,  82.f,  83.f,  84.f,  85.f,  86.f, 87.f,  88.f,  89.f, 90.f,  91.f,  92.f,  93.f,  94.f,  95.f,  96.f,  97.f,  98.f,  99.f, 
+        100.f, 101.f, 102.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 109.f,110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f,
+
+        160.f, 161.f, 162.f, 163.f, 164.f, 165.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 175.f, 176.f, 177.f, 178.f, 179.f, 
+        180.f, 181.f, 182.f, 183.f, 184.f, 185.f, 186.f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 199.f
+    };
+    
+    for (size_t i = 0; i < expected_results.size(); ++i) {
+        EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i]));
+    }
+}
+
<br>
+TEST(scatter_update_gpu_fp32, d2214_axisX) {
+    //  Dictionary : 2x2x1x4
+    //  Indexes : 3x1x1x1
+    //  Updates : 2x2x1x3
+    //  Axis : 3
+    //  Output : 2x2x1x4
+    //  Input values in fp32
+
+    //  Indexes:
+    //  2.f, 0.f, 3.f
+    //
+    //  Updates:
+    //  20.f, 30.f, 40.f,
+    //  50.f, 60.f, 70.f,
+    //
+    //  80.f, 90.f, 100.f,
+    //  110.f, 120.f, 130.f
+    //
+    //  Dictionary:
+    //  0.f, 1.f, 2.f, 3.f,
+    //  4.f, 5.f, 6.f, 7.f,
+    //
+    //  8.f, 9.f, 10.f, 11.f,
+    //  12.f, 13.f, 14.f, 15.f
+    //
+    //  Output:
+    //  30.f, 1.f, 20.f, 40.f,
+    //  60.f, 5.f, 50.f, 70.f,
+    //
+    //  90.f, 9.f, 80.f, 100.f,
+    //  120.f, 13.f, 110.f, 130.f
+
+    engine engine;
+
+    // Scatter along x: positions 2, 0, 3 are overwritten in every b/f slice; x == 1 is untouched.
+    auto input1 = memory::allocate(engine, { data_types::f32, format::bfyx, { 2, 2, 4, 1 } }); // Dictionary
+    auto input2 = memory::allocate(engine, { data_types::f32, format::bfyx, { 3, 1, 1, 1 } }); // Indexes
+    auto input3 = memory::allocate(engine, { data_types::f32, format::bfyx, { 2, 2, 3, 1 } }); // Updates
+    auto axis = cldnn::scatter_update::scatter_update_axis::along_x;
+
+    set_values(input1, {
+        0.f, 1.f, 2.f, 3.f,
+        4.f, 5.f, 6.f, 7.f,
+        8.f, 9.f, 10.f, 11.f,
+        12.f, 13.f, 14.f, 15.f
+    });
+
+    set_values(input2, {
+        2.f, 0.f, 3.f
+    });
+
+    set_values(input3, {
+        20.f, 30.f, 40.f,
+        50.f, 60.f, 70.f,
+        80.f, 90.f, 100.f,
+        110.f, 120.f, 130.f
+    });
+
+    topology topology;
+    topology.add(input_layout("InputDictionary", input1.get_layout()));
+    topology.add(input_layout("InputText", input2.get_layout()));
+    topology.add(input_layout("InputUpdates", input3.get_layout()));
+    topology.add(
+        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
+    );
+    
+    network network(engine, topology); 
+    
+    network.set_input_data("InputDictionary", input1);
+    network.set_input_data("InputText", input2);
+    network.set_input_data("InputUpdates", input3);
+    
+    auto outputs = network.execute();
+
+    auto output = outputs.at("scatter_update").get_memory();
+    auto output_ptr = output.pointer<float>();
+
+    std::vector<float> expected_results = {
+        30.f, 1.f, 20.f, 40.f,
+        60.f, 5.f, 50.f, 70.f,
+        90.f, 9.f, 80.f, 100.f,
+        120.f, 13.f, 110.f, 130.f
+    };
+
+    
+    for (size_t i = 0; i < expected_results.size(); ++i) {
+        EXPECT_EQ(expected_results[i], output_ptr[i]);
+    }
+}
+
+TEST(scatter_update_gpu_int32, d6211_axisB) {
+    //  Dictionary : 6x2x1x1
+    //  Indexes : 1x2x2x1
+    //  Updates : 1x2x2x2
+    //  Axis : 0
+    //  Output : 6x2x1x1
+    //  Input values in int32
+
+    //  Indexes:
+    //  3,   1,
+    //  5,   2
+    //
+    //  Updates:
+    //  20,  30,
+    //  40,  50
+    //
+    //  60,  70,
+    //  80,  90
+    //
+    //  Dictionary:
+    //  1,   2,
+    //  3,   4,
+    //  5,   6,
+    //  7,   8,
+    //  9,   10,
+    //  11,  12
+    //
+    //  Output:
+    //   1,  2,
+    //  40,  50,
+    //  80,  90,
+    //  20,  30,
+    //   9,  10,
+    //  60,  70
+
+    engine engine;
+
+    // i32 variant: both data and indexes are i32; batches 0 and 4 keep dictionary values.
+    auto input1 = memory::allocate(engine, { data_types::i32, format::bfyx, { 6, 2, 1, 1 } }); // Dictionary
+    auto input2 = memory::allocate(engine, { data_types::i32, format::bfyx, { 1, 2, 1, 2 } }); // Indexes
+    auto input3 = memory::allocate(engine, { data_types::i32, format::bfyx, { 1, 2, 2, 2 } }); // Updates
+    auto axis = cldnn::scatter_update::scatter_update_axis::along_b;
+
+    set_values(input1, {
+        1, 2,
+        3, 4,
+        5, 6,
+        7, 8,
+        9, 10,
+        11, 12
+    });
+
+    set_values(input2, {
+        3, 1,
+        5, 2
+    });
+
+    set_values(input3, {
+        20, 30,
+        40, 50,
+        60, 70,
+        80, 90
+    });
+
+    topology topology;
+    topology.add(input_layout("InputDictionary", input1.get_layout()));
+    topology.add(input_layout("InputText", input2.get_layout()));
+    topology.add(input_layout("InputUpdates", input3.get_layout()));
+    topology.add(
+        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
+    );
+    
+    network network(engine, topology);  
+    
+    network.set_input_data("InputDictionary", input1);
+    network.set_input_data("InputText", input2);
+    network.set_input_data("InputUpdates", input3);
+    
+    auto outputs = network.execute();
+
+    auto output = outputs.at("scatter_update").get_memory();
+    auto output_ptr = output.pointer<int>();
+
+    std::vector<int> expected_results = {
+        1, 2,
+        40, 50,
+        80, 90,
+        20, 30,
+        9, 10,
+        60, 70
+    };
+
+    for (size_t i = 0; i < expected_results.size(); ++i) {
+        EXPECT_EQ(expected_results[i], output_ptr[i]);
+    }
+}
+
TEST(scatter_update_gpu_int32, d3151_axisY) {
    //  ScatterUpdate along axis y (axis index 2) of an int32 dictionary.
    //  Dictionary : 3x1x5x1
    //  Indexes : 2x2x1x1
    //  Updates : 3x1x2x2
    //  Axis : 2
    //  Output : 3x1x5x1
    //  Input values in int32

    //  Indexes (flattened y positions to overwrite):
    //  3,   2,
    //  0,   4
    //
    //  Updates (four values per batch, one per index):
    //  200,  20,
    //  30,  40
    //
    //  50,  60,
    //  70,  80
    //
    //  90,  100,
    //  110,  120
    //
    //  Dictionary (matches the set_values() call below):
    //  0,  1,  2,  3,  4,
    //  5,  6,  7,  8,  9,
    //  10, 11, 12, 13, 14
    //
    //  Output (y slots 3, 2, 0, 4 replaced; slot 1 keeps its dictionary value):
    //   30,  1,  20, 200, 40,
    //   70,  6,  60,  50, 80,
    //   110, 11, 100, 90, 120

    engine engine;

    auto input1 = memory::allocate(engine, { data_types::i32, format::bfyx, { 3, 1, 1, 5 } }); // Dictionary
    auto input2 = memory::allocate(engine, { data_types::i32, format::bfyx, { 2, 2, 1, 1 } }); // Indexes
    auto input3 = memory::allocate(engine, { data_types::i32, format::bfyx, { 3, 1, 2, 2 } }); // Updates
    auto axis = cldnn::scatter_update::scatter_update_axis::along_y;

    set_values(input1, {
        0, 1, 2, 3, 4,
        5, 6, 7, 8, 9,
        10, 11, 12, 13, 14
    });

    set_values(input2, {
        3, 2,
        0, 4
    });

    set_values(input3, {
        200,  20,
        30,   40,
        50,   60,
        70,   80,
        90,  100,
        110,  120
    });

    topology topology;
    topology.add(input_layout("InputDictionary", input1.get_layout()));
    topology.add(input_layout("InputText", input2.get_layout()));
    topology.add(input_layout("InputUpdates", input3.get_layout()));
    topology.add(
        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
    );

    network network(engine, topology);

    network.set_input_data("InputDictionary", input1);
    network.set_input_data("InputText", input2);
    network.set_input_data("InputUpdates", input3);

    auto outputs = network.execute();

    auto output = outputs.at("scatter_update").get_memory();
    auto output_ptr = output.pointer<int>();

    std::vector<int> expected_results = {
        30, 1, 20, 200, 40,
        70, 6, 60, 50, 80,
        110, 11, 100, 90, 120
    };

    for (size_t i = 0; i < expected_results.size(); ++i) {
        EXPECT_EQ(expected_results[i], output_ptr[i]);
    }
}
+
+TEST(scatter_update_gpu_fp32, d24111_axisF_bfzyx) {
+    //  Dictionary : 2x4x1x1
+    //  Indexes : 1x1x1x2
+    //  Updates : 2x1x1x1x2
+    //  Axis : 1
+    //  Output : 2x4x1x1x1
+    //  Input values in fp32
+
+    //  Indexes:
+    //  2.f, 0.f
+    //
+    //  Updates:
+    //  1.f, 2.f, 
+    //  3.f, 4.f
+    //
+    //  Dictionary:
+    //  0.f, 0.f, 0.f, 0.f,
+    //  0.f, 0.f, 0.f, 0.f
+    //
+    //  Output:
+    //  2.f, 0.f, 1.f, 0.f, 
+    //  4.f, 0.f, 3.f, 0.f
+
+    engine engine;
+
+    auto input1 = memory::allocate(engine, { data_types::f32, format::bfyx, { 2, 4, 1, 1 } });      // Dictionary
+    auto input2 = memory::allocate(engine, { data_types::f32, format::bfyx, { 1, 1, 2, 1 } });      // Indexes
+    auto input3 = memory::allocate(engine, { data_types::f32, format::bfzyx, { 2, 1, 1, 2, 1 } });  // Updates
+    auto axis = cldnn::scatter_update::scatter_update_axis::along_f;
+
+    set_values(input1, {
+        0.0f, 0.0f, 0.0f, 0.0f,
+        0.0f, 0.0f, 0.0f, 0.0f
+    });
+
+    set_values(input2, {
+        2.f, 0.f
+    });
+
+    set_values(input3, {
+        1.0f, 2.0f,
+        3.0f, 4.0f
+    });
+
+    topology topology;
+    topology.add(input_layout("InputDictionary", input1.get_layout()));
+    topology.add(input_layout("InputText", input2.get_layout()));
+    topology.add(input_layout("InputUpdates", input3.get_layout()));
+    topology.add(
+        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
+    );
+    
+    network network(engine, topology); 
+    
+    network.set_input_data("InputDictionary", input1);
+    network.set_input_data("InputText", input2);
+    network.set_input_data("InputUpdates", input3);
+    
+    auto outputs = network.execute(); 
+
+    auto output = outputs.at("scatter_update").get_memory();
+    auto output_ptr = output.pointer<float>();
+
+    std::vector<float> expected_results = {
+        2.f, 0.f, 1.f, 0.f, 
+        4.f, 0.f, 3.f, 0.f
+    };
+    
+    for (size_t i = 0; i < expected_results.size(); ++i) {
+        EXPECT_EQ(expected_results[i], output_ptr[i]);
+    }
+}
+
TEST(scatter_update_gpu_int32, d121251_bfwzyx_axisB) {
    //  NOTE(review): the test name says "axisB", but the axis actually used
    //  below is along_y (axis index 4 of the 6D bfwzyx dictionary), which is
    //  consistent with the "Axis : 4" line — the name looks like a
    //  copy-paste slip; confirm before relying on the name.
    //  Dictionary : 1x2x1x2x5x1
    //  Indexes : 2x2x1x1
    //  Updates : 1x2x1x2x2x2
    //  Axis : 4
    //  Output : 1x2x1x2x5x1
    //  Input values in int32

    //  Indexes (flattened y positions to overwrite):
    //  2,   1,
    //  0,   4
    //
    //  Updates (four values per (f, z) slice, one per index):
    //  20,  30,
    //  40,  50
    //
    //  60,  70,
    //  80,  90,
    //
    //  100,  110,
    //  120,  130,
    //
    //  140,  150,
    //  160,  170
    //
    //  Dictionary:
    //  0, 1, 2, 3, 4,
    //  5, 6, 7, 8, 9,
    //  10, 11, 12, 13, 14,
    //  15, 16, 17, 18, 19 
    //
    //  Output (y slots 2, 1, 0, 4 replaced; slot 3 keeps its dictionary value):
    //  40,  30,   20,  3, 50,
    //  80,  70,   60,  8, 90,
    //  120, 110, 100, 13, 130,
    //  160, 150, 140, 18, 170 

    engine engine;

    auto input1 = memory::allocate(engine, { data_types::i32, format::bfwzyx, tensor{ batch(1), feature(2), spatial(1, 5, 2, 1) }}); // Dictionary
    auto input2 = memory::allocate(engine, { data_types::i32, format::bfyx, { 2, 2, 1, 1 } });                                       // Indexes
    auto input3 = memory::allocate(engine, { data_types::i32, format::bfwzyx, tensor{ batch(1), feature(2), spatial(2, 2, 2, 1) }}); // Updates
    auto axis = cldnn::scatter_update::scatter_update_axis::along_y;

    set_values(input1, {
        0, 1, 2, 3, 4,
        5, 6, 7, 8, 9,
        10, 11, 12, 13, 14,
        15, 16, 17, 18, 19
    });

    set_values(input2, {
        2, 1,
        0, 4
    });

    set_values(input3, {
        20, 30,
        40, 50,
        60, 70,
        80, 90,
        100,  110,
        120,  130,
        140,  150,
        160,  170
    });

    topology topology;
    topology.add(input_layout("InputDictionary", input1.get_layout()));
    topology.add(input_layout("InputText", input2.get_layout()));
    topology.add(input_layout("InputUpdates", input3.get_layout()));
    topology.add(
        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
    );
    
    network network(engine, topology); 
    
    network.set_input_data("InputDictionary", input1);
    network.set_input_data("InputText", input2);
    network.set_input_data("InputUpdates", input3);
    
    auto outputs = network.execute();
    
    auto output = outputs.at("scatter_update").get_memory();
    auto output_ptr = output.pointer<int>();

    std::vector<int> expected_results = {
        40,  30,   20,  3, 50,
        80,  70,   60,  8, 90,
        120, 110, 100, 13, 130,
        160, 150, 140, 18, 170
    };

    for (size_t i = 0; i < expected_results.size(); ++i) {
        EXPECT_EQ(expected_results[i], output_ptr[i]);
    }
}
+
TEST(scatter_update_gpu_fp32, d21511_bfzyx_axisX) {
    //  NOTE(review): the test name says "axisX", but the axis actually used
    //  below is along_z (axis index 2 of the 5D bfzyx dictionary), which is
    //  consistent with the "Axis : 2" line — the name looks like a
    //  copy-paste slip; confirm before relying on the name.
    //  Dictionary : 2x1x5x1x1
    //  Indexes : 2x2x1x1
    //  Updates : 2x1x2x2x1
    //  Axis : 2
    //  Output : 2x1x5x1x1
    //  Input values in fp32

    //  Indexes (z positions to overwrite, given as fp32):
    //  3.f, 4.f
    //  0.f, 1.f
    //
    //  Updates (four values per batch, one per index):
    //  10.f, 20.f, 
    //  30.f, 40.f,
    //  50.f, 60.f,
    //  70.f, 80.f
    //
    //  Dictionary:
    //  0.f, 1.f, 2.f, 3.f, 4.f
    //  5.f, 6.f, 7.f, 8.f, 9.f
    //
    //  Output (z slots 3, 4, 0, 1 replaced; slot 2 keeps its dictionary value):
    //  30.f, 40.f, 2.f, 10.f, 20.f,
    //  70.f, 80.f, 7.f, 50.f, 60.f
    //

    engine engine;

    auto input1 = memory::allocate(engine, { data_types::f32, format::bfzyx, { 2, 1, 1, 1, 5 } }); // Dictionary
    auto input2 = memory::allocate(engine, { data_types::f32, format::bfyx, { 2, 2, 1, 1 } });     // Indices
    auto input3 = memory::allocate(engine, { data_types::f32, format::bfzyx, { 2, 1, 1, 2, 2 } }); // Updates
    auto axis = cldnn::scatter_update::scatter_update_axis::along_z;

    set_values(input1, {
        0.f, 1.f, 2.f, 3.f, 4.f,
        5.f, 6.f, 7.f, 8.f, 9.f
    });

    set_values(input2, {
        3.f, 4.f,
        0.f, 1.f
    });

    set_values(input3, {
        10.f, 20.f, 
        30.f, 40.f,
        50.f, 60.f,
        70.f, 80.f
    });

    topology topology;
    topology.add(input_layout("InputDictionary", input1.get_layout()));
    topology.add(input_layout("InputText", input2.get_layout()));
    topology.add(input_layout("InputUpdates", input3.get_layout()));
    topology.add(
        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
    );
    
    network network(engine, topology);
    
    
    network.set_input_data("InputDictionary", input1);
    network.set_input_data("InputText", input2);
    network.set_input_data("InputUpdates", input3);
    
    auto outputs = network.execute();

    auto output = outputs.at("scatter_update").get_memory();
    auto output_ptr = output.pointer<float>();

    std::vector<float> expected_results = {
        30.f, 40.f, 2.f, 10.f, 20.f,
        70.f, 80.f, 7.f, 50.f, 60.f
    };

    for (size_t i = 0; i < expected_results.size(); ++i) {
        EXPECT_EQ(expected_results[i], output_ptr[i]);
    }
}
+
+TEST(scatter_update_gpu_fp32, d1252_axisY_bfwzyx) {
+    //  Dictionary : 1x2x5x2
+    //  Indexes : 2x1x2x1
+    //  Updates : 1x2x2x1x2x2
+    //  Axis : 2
+    //  Output : 1x2x5x2
+    //  Input values in fp32
+
+    //  Indexes:
+    //  2.f, 0.f,
+    //  3.f, 4.f
+    //
+    //  Updates:
+    //  20.f, 30.f, 
+    //  40.f, 50.f
+    //
+    //  60.f, 70.f, 
+    //  80.f, 90.f
+    //
+    //  100.f, 110.f, 
+    //  120.f, 130.f
+    //
+    //  140.f, 150.f, 
+    //  160.f, 170.f
+    //
+    //  Dictionary:
+    //  0.f, 1.f,     2.f, 3.f,     4.f, 5.f,     6.f, 7.f,     8.f, 9.f,
+    //  10.f, 11.f,   12.f, 13.f,   14.f, 15.f,   16.f, 17.f,   18.f, 19.f
+    //
+    //  Output:
+    //  40.f, 50.f,     2.f, 3.f,     20.f, 30.f,     60.f, 70.f,     80.f, 90.f,
+    //  120.f, 130.f,   12.f, 13.f,   100.f, 110.f,   140.f, 150.f,   160.f, 170.f
+
+    engine engine;
+
+    auto input1 = memory::allocate(engine, { data_types::f32, format::bfyx, { 1, 2, 2, 5 } });                                         // Dictionary
+    auto input2 = memory::allocate(engine, { data_types::f32, format::bfyx, { 2, 1, 1, 2 } });                                         // Indices
+    auto input3 = memory::allocate(engine, { data_types::f32, format::bfwzyx, tensor{ batch(1), feature(2), spatial(2, 2, 1, 2) } });  // Updates
+    auto axis = cldnn::scatter_update::scatter_update_axis::along_y;
+
+    set_values(input1, {
+        0.f, 1.f,     2.f, 3.f,     4.f, 5.f,     6.f, 7.f,     8.f, 9.f,
+        10.f, 11.f,   12.f, 13.f,   14.f, 15.f,   16.f, 17.f,   18.f, 19.f
+    });
+
+    set_values(input2, {
+        2.f, 0.f,
+        3.f, 4.f
+    });
+
+    set_values(input3, {
+        20.f, 30.f, 
+        40.f, 50.f,
+
+        60.f, 70.f, 
+        80.f, 90.f,
+      
+        100.f, 110.f, 
+        120.f, 130.f,
+    
+        140.f, 150.f, 
+        160.f, 170.f
+    });
+
+    topology topology;
+    topology.add(input_layout("InputDictionary", input1.get_layout()));
+    topology.add(input_layout("InputText", input2.get_layout()));
+    topology.add(input_layout("InputUpdates", input3.get_layout()));
+    topology.add(
+        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
+    );
+    
+    network network(engine, topology); 
+    
+    network.set_input_data("InputDictionary", input1);
+    network.set_input_data("InputText", input2);
+    network.set_input_data("InputUpdates", input3);
+    
+    auto outputs = network.execute();
+
+    auto output = outputs.at("scatter_update").get_memory();
+    auto output_ptr = output.pointer<float>();
+
+    std::vector<float> expected_results = {
+        40.f, 50.f,     2.f, 3.f,     20.f, 30.f,     60.f, 70.f,     80.f, 90.f,
+        120.f, 130.f,   12.f, 13.f,   100.f, 110.f,   140.f, 150.f,   160.f, 170.f
+    };
+    
+    for (size_t i = 0; i < expected_results.size(); ++i) {
+        EXPECT_EQ(expected_results[i], output_ptr[i]);
+    }
+}
+
+TEST(scatter_update_gpu_int32, d2115_axisX_bfwzyx) {
+    //  Dictionary : 2x1x1x5
+    //  Indexes : 2x2x1x1
+    //  Updates : 2x1x1x2x2x1
+    //  Axis : 3
+    //  Output : 2x1x1x5
+    //  Input values in int32
+
+    //  Indexes:
+    //  2,   1,
+    //  4,   3
+    //
+    //  Updates:
+    //  20,  30,
+    //  40,  50
+    //
+    //  60,  70,
+    //  80,  90
+    //
+    //  Dictionary:
+    //  0, 1, 2, 3, 4,
+    //  5, 6, 7, 8, 9
+    //
+    //  Output:
+    //  0,  30,   20,  50, 40,
+    //  5,  70,   60,  90, 80
+
+    engine engine;
+
+    auto input1 = memory::allocate(engine, { data_types::i32, format::bfyx, { 2, 1, 5, 1 }});                                        // Dictionary
+    auto input2 = memory::allocate(engine, { data_types::i32, format::bfyx, { 2, 2, 1, 1 } });                                       // Indexes
+    auto input3 = memory::allocate(engine, { data_types::i32, format::bfwzyx, tensor{ batch(2), feature(1), spatial(1, 2, 2, 1) }}); // Updates
+    auto axis = cldnn::scatter_update::scatter_update_axis::along_x;
+
+    set_values(input1, {
+        0, 1, 2, 3, 4,
+        5, 6, 7, 8, 9
+    });
+
+    set_values(input2, {
+        2, 1,
+        4, 3
+    });
+
+    set_values(input3, {
+        20, 30,
+        40, 50,
+        60, 70,
+        80, 90
+    });
+
+    topology topology;
+    topology.add(input_layout("InputDictionary", input1.get_layout()));
+    topology.add(input_layout("InputText", input2.get_layout()));
+    topology.add(input_layout("InputUpdates", input3.get_layout()));
+    topology.add(
+        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
+    );
+    
+    network network(engine, topology); 
+    
+    network.set_input_data("InputDictionary", input1);
+    network.set_input_data("InputText", input2);
+    network.set_input_data("InputUpdates", input3);
+    
+    auto outputs = network.execute();
+
+    auto output = outputs.at("scatter_update").get_memory();
+    auto output_ptr = output.pointer<int>();
+
+    std::vector<int> expected_results = {
+        0,  30,  20,  50, 40,
+        5,  70,  60,  90, 80
+    };
+    
+    for (size_t i = 0; i < expected_results.size(); ++i) {
+        EXPECT_EQ(expected_results[i], output_ptr[i]);
+    }
+}
+
TEST(scatter_update_gpu_fp16, d21214_bfzyx_axisX_bfwzyx) {
    //  ScatterUpdate along axis x with fp16 dictionary/updates; note the
    //  indexes tensor is deliberately fp32 while the data is fp16 (mixed
    //  index/data precision path).
    //  Dictionary : 2x1x2x1x4
    //  Indexes : 1x3x1x1
    //  Updates : 2x1x2x1x1x3
    //  Axis : 4
    //  Output : 2x1x2x1x4
    //  Input values in fp16

    //  Indexes (x positions to overwrite):
    //  3.f, 2.f, 1.f
    //
    //  Updates (three values per (b, z) row, one per index):
    //  20.f, 30.f, 40.f,
    //  50.f, 60.f, 70.f,
    //  80.f, 90.f, 100.f,
    //  110.f, 120.f, 130.f
    //
    //  Dictionary:
    //  0.f, 1.f, 2.f, 3.f,
    //  4.f, 5.f, 6.f, 7.f,
    //  8.f, 9.f, 10.f, 11.f,
    //  12.f, 13.f, 14.f, 15.f
    //
    //  Output (x slots 3, 2, 1 replaced; slot 0 keeps its dictionary value):
    //  0.f, 40.f, 30.f, 20.f,
    //  4.f, 70.f, 60.f, 50.f,
    //  8.f, 100.f, 90.f, 80.f,
    //  12.f, 130.f, 120.f, 110.f

    engine engine;

    auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 1, 4, 1, 2 } });                                    // Dictionary
    auto input2 = memory::allocate(engine, { data_types::f32, format::bfyx, { 1, 3, 1, 1 } });                                        // Indexes
    auto input3 = memory::allocate(engine, { data_types::f16, format::bfwzyx, tensor{ batch(2), feature(1), spatial(3, 1, 1, 2) } }); // Updates
    auto axis = cldnn::scatter_update::scatter_update_axis::along_x;

    set_values(input1, {
        FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f),
        FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f),
        FLOAT16(8.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f),
        FLOAT16(12.0f), FLOAT16(13.0f), FLOAT16(14.0f), FLOAT16(15.0f)
    });

    set_values(input2, {
        3.f, 2.f, 1.f
    });

    set_values(input3, {
        FLOAT16(20.0f), FLOAT16(30.0f), FLOAT16(40.0f), 
        FLOAT16(50.0f), FLOAT16(60.0f), FLOAT16(70.0f),
        FLOAT16(80.0f), FLOAT16(90.0f), FLOAT16(100.0f), 
        FLOAT16(110.0f), FLOAT16(120.0f), FLOAT16(130.0f)
    });

    topology topology;
    topology.add(input_layout("InputDictionary", input1.get_layout()));
    topology.add(input_layout("InputText", input2.get_layout()));
    topology.add(input_layout("InputUpdates", input3.get_layout()));
    topology.add(
        scatter_update("scatter_update", "InputDictionary", "InputText", "InputUpdates", axis)
    );
    
    network network(engine, topology); 
    
    network.set_input_data("InputDictionary", input1);
    network.set_input_data("InputText", input2);
    network.set_input_data("InputUpdates", input3);
    
    auto outputs = network.execute();

    auto output = outputs.at("scatter_update").get_memory();
    // Raw fp16 bits are read back as uint16_t and converted for comparison.
    auto output_ptr = output.pointer<uint16_t>();

    std::vector<float> expected_results = {
        0.f, 40.f, 30.f, 20.f,
        4.f, 70.f, 60.f, 50.f,
        8.f, 100.f, 90.f, 80.f,
        12.f, 130.f, 120.f, 110.f
    };
    
    for (size_t i = 0; i < expected_results.size(); ++i) {
        EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i]));
    }
}
index b59af7c..0757409 100644 (file)
@@ -168,6 +168,36 @@ std::vector<T> generate_random_1d(size_t a, int min, int max, int k = 8) {
     return v;
 }
 
+template<typename Type>
+std::vector<Type> generate_random_norepetitions_1d(size_t size, int min, int max, float bound = 0.45) {
+    // Rerurn repeatless vector with size = size in range(min, max)
+    static std::default_random_engine generator(random_seed);
+    std::uniform_int_distribution<int> distribution(min, max);
+    std::uniform_real_distribution<float> to_bound_dist(0, bound);
+    std::set<int> repeatless;
+    std::vector<float> v(size, 0);
+    std::vector<Type> res(size);
+    int i = 0;
+    int temp;
+    if (max - min >= int(size) - 1){
+        while (repeatless.size() < size) {
+            temp = distribution(generator);
+            if (repeatless.find(temp) == repeatless.end()) {
+                repeatless.insert(temp);
+                v[i] = (float)temp;
+                i++;
+            }
+        }
+        for (size_t k = 0; k < v.size(); k++) {
+            v[k] += to_bound_dist(generator);
+            res[k] = static_cast<Type>(v[k]);
+        }
+    } else {
+        throw "Array size is bigger than size of range(min, max). Unable to generate array of unique integer numbers";
+    }
+    return res;
+}
+
 template<typename T>
 std::vector<std::vector<T>> generate_random_2d(size_t a, size_t b, int min, int max, int k = 8) {
     std::vector<std::vector<T>> v(a);