[ MO ] KSO ON by default (#1730)
[platform/upstream/dldt.git] / ngraph / core / include / ngraph / op / scatter_add.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include "ngraph/op/op.hpp"
20 #include "ngraph/op/util/scatter_base.hpp"
21
22 namespace ngraph
23 {
24     namespace op
25     {
26         namespace v0
27         {
28             /// \brief Add updates to slices from inputs addressed by indices
29             class NGRAPH_API ScatterAdd : public Op
30             {
31             public:
32                 static constexpr NodeTypeInfo type_info{"ScatterAdd", 0};
33                 const NodeTypeInfo& get_type_info() const override { return type_info; }
34                 ScatterAdd() = default;
35                 /// \param inputs Tensor
36                 /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
37                 /// \param updates Tensor: Must have same type as inputs
38                 ScatterAdd(const Output<Node>& inputs,
39                            const Output<Node>& indices,
40                            const Output<Node>& updates)
41                     : Op({inputs, indices, updates})
42                 {
43                     constructor_validate_and_infer_types();
44                 }
45
46                 void validate_and_infer_types() override;
47                 virtual std::shared_ptr<Node>
48                     clone_with_new_inputs(const OutputVector& new_args) const override;
49             };
50         }
51
52         namespace v3
53         {
54             ///
55             /// \brief      Add updates to slices from inputs addressed by indices
56             ///
57             class NGRAPH_API ScatterAdd : public util::ScatterBase
58             {
59             public:
60                 static constexpr NodeTypeInfo type_info{"ScatterAdd", 3};
61                 const NodeTypeInfo& get_type_info() const override { return type_info; }
62                 ScatterAdd() = default;
63
64                 ///
65                 /// \brief      Constructs ScatterAdd object.
66                 ///
67                 /// \param      data     The input tensor to be updated.
68                 /// \param      indices  The tensor with indexes which will be updated.
69                 /// \param      updates  The tensor with update values.
70                 /// \param[in]  axis     The axis at which elements will be updated.
71                 ///
72                 ScatterAdd(const Output<Node>& data,
73                            const Output<Node>& indices,
74                            const Output<Node>& updates,
75                            const Output<Node>& axis);
76
77                 virtual std::shared_ptr<Node>
78                     clone_with_new_inputs(const OutputVector& inputs) const override;
79             };
80         }
81         using v0::ScatterAdd;
82     }
83 }