Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / include / ngraph / op / softmax.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include "ngraph/op/op.hpp"
20
21 namespace ngraph
22 {
23     namespace op
24     {
25         namespace v0
26         {
27             /// \brief Softmax operation.
28             ///
29             class NGRAPH_DEPRECATED(
30                 "This operation is deprecated and will be removed soon. "
31                 "Use v1::Softmax instead of it.") NGRAPH_API Softmax : public Op
32             {
33                 NGRAPH_SUPPRESS_DEPRECATED_START
34             public:
35                 static constexpr NodeTypeInfo type_info{"Softmax", 0};
36                 const NodeTypeInfo& get_type_info() const override { return type_info; }
37                 Softmax() = default;
38                 /// \brief Constructs a softmax operation.
39                 ///
40                 /// \param arg Node that produces the first input tensor.<br>
41                 /// `[d0, ...]`
42                 /// \param axes The axis positions (0-based) on which to calculate the softmax.
43                 ///
44                 /// Output `[d0, ...]`
45                 ///
46                 Softmax(const Output<Node>& arg, const AxisSet& axes);
47                 /// \brief Constructs a softmax operation.
48                 ///
49                 /// \param arg Node that produces the first input tensor.<br>
50                 /// `[d0, ...]`
51                 /// \param axes node produces the axis positions (0-based) on which to calculate the
52                 /// softmax.
53                 ///
54                 /// Output `[d0, ...]`
55                 ///
56                 Softmax(const Output<Node>& arg, const Output<Node>& axes);
57
58                 void validate_and_infer_types() override;
59
60                 virtual std::shared_ptr<Node>
61                     clone_with_new_inputs(const OutputVector& new_args) const override;
62
63                 bool are_axes_constant() const;
64                 const AxisSet get_axes() const;
65                 void set_axes(const AxisSet& axes);
66
67                 bool evaluate(const HostTensorVector& outputs,
68                               const HostTensorVector& inputs) const override;
69                 NGRAPH_SUPPRESS_DEPRECATED_END
70             };
71         }
72
73         namespace v1
74         {
75             class NGRAPH_API Softmax : public Op
76             {
77             public:
78                 static constexpr NodeTypeInfo type_info{"Softmax", 1};
79                 const NodeTypeInfo& get_type_info() const override { return type_info; }
80                 Softmax()
81                     : m_axis(0)
82                 {
83                 }
84                 /// \brief Constructs a softmax operation.
85                 ///
86                 /// \param arg Node that produces the first input tensor.<br>
87                 /// `[d0, ...]`
88                 /// \param axis The axis position (0-based) on which to calculate the softmax.
89                 ///
90                 /// Output `[d0, ...]`
91                 ///
92                 Softmax(const Output<Node>& arg, const size_t axis);
93
94                 bool visit_attributes(AttributeVisitor& visitor) override;
95                 void validate_and_infer_types() override;
96
97                 size_t get_version() const override { return 1; }
98                 virtual std::shared_ptr<Node>
99                     clone_with_new_inputs(const OutputVector& new_args) const override;
100
101                 size_t get_axis() const { return m_axis; }
102                 void set_axis(const size_t axis) { m_axis = axis; }
103                 bool evaluate(const HostTensorVector& outputs,
104                               const HostTensorVector& inputs) const override;
105
106             private:
107                 size_t m_axis;
108             };
109         }
110
111         // default opset version
112         NGRAPH_SUPPRESS_DEPRECATED_START
113         using v0::Softmax;
114         NGRAPH_SUPPRESS_DEPRECATED_END
115     }
116 }