1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include "ngraph/op/op.hpp"
20 #include "ngraph/op/util/scatter_base.hpp"
28 /// \brief Adds `updates` to the slices of `inputs` addressed by `indices`.
///
/// Version-0 ScatterAdd: registered with type_info {"ScatterAdd", 0} and built
/// from exactly three inputs (inputs, indices, updates); there is no separate
/// axis input.
29 class NGRAPH_API ScatterAdd : public Op
/// Node type info for this op: name "ScatterAdd", version 0.
32 static constexpr NodeTypeInfo type_info{"ScatterAdd", 0};
33 const NodeTypeInfo& get_type_info() const override { return type_info; }
/// Defaulted constructor. Unlike the three-input constructor below, it attaches
/// no inputs and does not call constructor_validate_and_infer_types().
34 ScatterAdd() = default;
/// \brief Constructs a ScatterAdd node; inputs are registered in the order
///        (inputs, indices, updates).
35 /// \param inputs Tensor whose slices, addressed by `indices`, receive `updates`
36 /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
37 /// \param updates Tensor: Must have same type as inputs
38 ScatterAdd(const Output<Node>& inputs,
39 const Output<Node>& indices,
40 const Output<Node>& updates)
41 : Op({inputs, indices, updates})
// Run validation / type inference as soon as the three inputs are attached.
43 constructor_validate_and_infer_types();
/// Overrides the base-class validate_and_infer_types(); defined out of line.
// NOTE(review): presumably enforces the constraints documented on the
// constructor (indices i32/i64, updates same type as inputs) — confirm in the .cpp.
46 void validate_and_infer_types() override;
/// Declared only; defined out of line.
// NOTE(review): by ngraph convention this presumably returns a copy of this node
// wired to new_args, expected in constructor order (inputs, indices, updates) — confirm.
47 virtual std::shared_ptr<Node>
48 clone_with_new_inputs(const OutputVector& new_args) const override;
55 /// \brief Add updates to the elements of data addressed by indices along the given axis
57 class NGRAPH_API ScatterAdd : public util::ScatterBase
/// Node type info: name "ScatterAdd", version 3. The name alone does not
/// distinguish this op from the version-0 ScatterAdd; only the version differs.
60 static constexpr NodeTypeInfo type_info{"ScatterAdd", 3};
61 const NodeTypeInfo& get_type_info() const override { return type_info; }
/// Defaulted constructor; takes no inputs.
62 ScatterAdd() = default;
65 /// \brief Constructs a version-3 ScatterAdd node (declared here, defined out of line).
67 /// \param data The input tensor to be updated.
68 /// \param indices The tensor of indices selecting the elements of `data` to update.
69 /// \param updates The tensor with update values.
70 /// \param axis The axis at which elements will be updated.
///             Supplied as a graph input (Output<Node>), not as a compile-time integer.
72 ScatterAdd(const Output<Node>& data,
73 const Output<Node>& indices,
74 const Output<Node>& updates,
75 const Output<Node>& axis);
/// Declared only; defined out of line.
// NOTE(review): by ngraph convention this presumably returns a copy of this node
// wired to `inputs`, expected in constructor order (data, indices, updates, axis) — confirm.
77 virtual std::shared_ptr<Node>
78 clone_with_new_inputs(const OutputVector& inputs) const override;