1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include "ngraph/op/op.hpp"
// NOTE(review): this view of the header omits the enclosing namespace, the class braces
// and access specifiers; the comments below describe only the declarations visible here.
27 /// \brief Softmax operation, opset version 0 (deprecated in favour of v1::Softmax).
29 class NGRAPH_DEPRECATED(
30 "This operation is deprecated and will be removed soon. "
31 "Use v1::Softmax instead of it.") NGRAPH_API Softmax : public Op
// Deprecation warnings are suppressed for the member declarations below (presumably so the
// class can refer to its own deprecated type without warnings) and re-enabled at the end.
33 NGRAPH_SUPPRESS_DEPRECATED_START
// Type identity of this op: name "Softmax", version 0.
35 static constexpr NodeTypeInfo type_info{"Softmax", 0};
36 const NodeTypeInfo& get_type_info() const override { return type_info; }
38 /// \brief Constructs a softmax operation with a fixed set of axes.
40 /// \param arg Node that produces the first input tensor.
42 /// \param axes The axis positions (0-based) on which to calculate the softmax.
44 /// Output `[d0, ...]`
46 Softmax(const Output<Node>& arg, const AxisSet& axes);
47 /// \brief Constructs a softmax operation whose axes are supplied by another node.
49 /// \param arg Node that produces the first input tensor.
51 /// \param axes Node that produces the axis positions (0-based) on which to calculate the softmax.
54 /// Output `[d0, ...]`
56 Softmax(const Output<Node>& arg, const Output<Node>& axes);
// Override of the base-class validation / type-inference hook (implementation not shown).
58 void validate_and_infer_types() override;
// Override of the base-class cloning hook; `new_args` supplies the replacement inputs.
// NOTE(review): `virtual` is redundant when `override` is present.
60 virtual std::shared_ptr<Node>
61 clone_with_new_inputs(const OutputVector& new_args) const override;
// Presumably reports whether the axes input is a constant (which get_axes() would need)
// — TODO confirm against the implementation.
63 bool are_axes_constant() const;
// NOTE(review): returning `const AxisSet` by value prevents callers from moving the result
// (C++ Core Guidelines F.49); consider plain `AxisSet` — needs API review since the class
// is exported via NGRAPH_API.
64 const AxisSet get_axes() const;
65 void set_axes(const AxisSet& axes);
// Override of the base-class evaluate(): computes `outputs` from `inputs` on the host
// (implementation not shown; meaning of the returned bool to be confirmed in the .cpp).
67 bool evaluate(const HostTensorVector& outputs,
68 const HostTensorVector& inputs) const override;
69 NGRAPH_SUPPRESS_DEPRECATED_END
// Softmax operation, opset version 1: unlike v0, it is parameterized by a single axis
// passed as a plain integer rather than an axis set or an axes input node.
75 class NGRAPH_API Softmax : public Op
// Type identity of this op: name "Softmax", version 1.
78 static constexpr NodeTypeInfo type_info{"Softmax", 1};
79 const NodeTypeInfo& get_type_info() const override { return type_info; }
84 /// \brief Constructs a softmax operation over a single axis.
86 /// \param arg Node that produces the first input tensor.
88 /// \param axis The axis position (0-based) on which to calculate the softmax.
90 /// Output `[d0, ...]`
92 Softmax(const Output<Node>& arg, const size_t axis);
// Exposes this op's attributes to `visitor` (presumably the axis held in m_axis
// — TODO confirm in the implementation).
94 bool visit_attributes(AttributeVisitor& visitor) override;
// Override of the base-class validation / type-inference hook (implementation not shown).
95 void validate_and_infer_types() override;
// Hard-coded to 1, matching the version number in type_info above.
97 size_t get_version() const override { return 1; }
// Override of the base-class cloning hook; `new_args` supplies the replacement inputs.
// NOTE(review): `virtual` is redundant when `override` is present.
98 virtual std::shared_ptr<Node>
99 clone_with_new_inputs(const OutputVector& new_args) const override;
// Accessors for the softmax axis stored in m_axis (member declared outside this view).
// NOTE(review): set_axis performs no range check itself; presumably the axis is validated
// against the input rank in validate_and_infer_types — confirm.
101 size_t get_axis() const { return m_axis; }
102 void set_axis(const size_t axis) { m_axis = axis; }
// Override of the base-class evaluate(): computes `outputs` from `inputs` on the host
// (implementation not shown; meaning of the returned bool to be confirmed in the .cpp).
103 bool evaluate(const HostTensorVector& outputs,
104 const HostTensorVector& inputs) const override;
111 // Default opset version of Softmax.
// NOTE(review): the declaration between these suppression macros (original line 113) is not
// visible in this view; the deprecation suppression suggests it refers to the deprecated
// v0::Softmax — confirm before relying on which version is the default.
112 NGRAPH_SUPPRESS_DEPRECATED_START
114 NGRAPH_SUPPRESS_DEPRECATED_END