f701a14b381dd5a964747c186c11825a35e4db1c
[platform/upstream/dldt.git] / ngraph / core / include / ngraph / op / broadcast.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include "ngraph/axis_set.hpp"
20 #include "ngraph/op/op.hpp"
21 #include "ngraph/op/util/attr_types.hpp"
22 #include "ngraph/op/util/broadcast_base.hpp"
23
24 namespace ngraph
25 {
26     namespace op
27     {
28         namespace v3
29         {
30             /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
31             ///        input as needed along the new axes.
32             class NGRAPH_API Broadcast : public util::BroadcastBase
33             {
34             public:
35                 static constexpr NodeTypeInfo type_info{"Broadcast", 3};
36                 const NodeTypeInfo& get_type_info() const override { return type_info; }
37                 /// \brief Constructs a broadcast operation.
38                 Broadcast() = default;
39                 /// \brief Constructs a broadcast operation.
40                 ///
41                 /// \param arg            The input tensor to be broadcast.
42                 /// \param target_shape   The shape of the output tensor.
43                 /// \param axes_mapping   The axis positions (0-based) in the result that correspond
44                 ///                       to input axes. 'Arg' tensor is broadcast along the
45                 ///                       remaining axes.
46                 ///                       E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
47                 ///                       axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
48                 ///                       axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
49                 /// \param broadcast_spec Broadcast specification to use for determining broadcast
50                 ///                       axes. 'axes_mapping' should not be provided if mode other
51                 ///                       than explicit (none) is used.
52                 Broadcast(const Output<Node>& arg,
53                           const Output<Node>& target_shape,
54                           const Output<Node>& axes_mapping,
55                           const BroadcastModeSpec& broadcast_spec = BroadcastType::EXPLICIT);
56
57                 /// \brief Constructs a broadcast operation.
58                 ///
59                 /// \param arg            The input tensor to be broadcast.
60                 /// \param target_shape   The shape of the output tensor.
61                 /// \param broadcast_spec Broadcast specification to use for determining broadcast
62                 ///                       axes
63                 Broadcast(const Output<Node>& arg,
64                           const Output<Node>& target_shape,
65                           const BroadcastModeSpec& broadcast_spec = BroadcastType::NUMPY);
66
67                 bool visit_attributes(AttributeVisitor& visitor) override;
68
69                 std::shared_ptr<Node>
70                     clone_with_new_inputs(const OutputVector& new_args) const override;
71
72                 // \return Broadcast Specification.
73                 const BroadcastModeSpec& get_broadcast_spec() const { return m_mode; }
74                 void set_broadcast_spec(const BroadcastModeSpec& broadcast_spec)
75                 {
76                     m_mode = broadcast_spec;
77                 }
78
79                 void validate_and_infer_types() override;
80
81                 /// \return true and the AxisSet if broadcast axes can be fully determined.
82                 std::pair<bool, AxisSet> get_broadcast_axes() const override;
83                 bool evaluate(const HostTensorVector& outputs,
84                               const HostTensorVector& inputs) const override;
85             };
86         } // namespace v3
87
88         namespace v1
89         {
90             /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
91             ///        input as needed along the new axes.
92             class NGRAPH_API Broadcast : public util::BroadcastBase
93             {
94             public:
95                 static constexpr NodeTypeInfo type_info{"Broadcast", 1};
96                 const NodeTypeInfo& get_type_info() const override { return type_info; }
97                 /// \brief Constructs a broadcast operation.
98                 Broadcast() = default;
99                 /// \brief Constructs a broadcast operation.
100                 ///
101                 /// \param arg            The input tensor to be broadcast.
102                 /// \param target_shape   The shape of the output tensor.
103                 /// \param axes_mapping   The axis positions (0-based) in the result that correspond
104                 ///                       to input axes. 'Arg' tensor is broadcast along the
105                 ///                       remaining axes.
106                 ///                       E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
107                 ///                       axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
108                 ///                       axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
109                 /// \param broadcast_spec Broadcast specification to use for determining broadcast
110                 ///                       axes. 'axes_mapping' is ignored if broadcast_spec is not
111                 ///                       NONE
112                 Broadcast(const Output<Node>& arg,
113                           const Output<Node>& target_shape,
114                           const Output<Node>& axes_mapping,
115                           const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec());
116
117                 /// \brief Constructs a broadcast operation.
118                 ///
119                 /// \param arg            The input tensor to be broadcast.
120                 /// \param target_shape   The shape of the output tensor.
121                 /// \param broadcast_spec Broadcast specification to use for determining broadcast
122                 ///                       axes
123                 Broadcast(const Output<Node>& arg,
124                           const Output<Node>& target_shape,
125                           const AutoBroadcastSpec& broadcast_spec =
126                               AutoBroadcastSpec(AutoBroadcastType::NUMPY));
127
128                 bool visit_attributes(AttributeVisitor& visitor) override;
129
130                 std::shared_ptr<Node>
131                     clone_with_new_inputs(const OutputVector& new_args) const override;
132
133                 /// \return Broadcast Specification.
134                 const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; }
135                 void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec)
136                 {
137                     m_broadcast_spec = broadcast_spec;
138                 }
139
140                 void validate_and_infer_types() override;
141                 bool evaluate(const HostTensorVector& outputs,
142                               const HostTensorVector& inputs) const override;
143
144             protected:
145                 AutoBroadcastSpec m_broadcast_spec;
146             };
147         } // namespace v1
148
149         namespace v0
150         {
151             NGRAPH_SUPPRESS_DEPRECATED_START
152             /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
153             ///        input as needed along the new axes.
154             class NGRAPH_DEPRECATED(
155                 "This operation is deprecated and will be removed soon. "
156                 "Use v1::Broadcast instead of it.") NGRAPH_API Broadcast : public Op
157             {
158             public:
159                 static constexpr NodeTypeInfo type_info{"Broadcast", 0};
160                 const NodeTypeInfo& get_type_info() const override { return type_info; }
161                 /// \brief Constructs a broadcast operation.
162                 Broadcast() = default;
163                 /// \brief Constructs a broadcast operation.
164                 ///
165                 /// \param arg            The input tensor to be broadcast.
166                 /// \param shape          The shape of the output tensor.
167                 /// \param broadcast_axes The axis positions (0-based) in the result that are being
168                 ///                       broadcast. The remaining axes in shape must be the same as
169                 ///                       the shape of arg.
170                 Broadcast(const Output<Node>& arg,
171                           const Shape& shape,
172                           const AxisSet& broadcast_axes);
173                 bool visit_attributes(AttributeVisitor& visitor) override;
174                 void validate_and_infer_types() override;
175
176                 std::shared_ptr<Node>
177                     clone_with_new_inputs(const OutputVector& new_args) const override;
178
179                 /// \return A set containing the indices of the broadcast axes (0-based).
180                 const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
181                 void set_broadcast_axes(const AxisSet& broadcast_axes)
182                 {
183                     m_broadcast_axes = broadcast_axes;
184                 }
185                 const Shape& get_broadcast_shape() const { return m_shape; }
186                 void set_broadcast_shape(const Shape& shape) { m_shape = shape; }
187                 bool evaluate(const HostTensorVector& outputs,
188                               const HostTensorVector& inputs) const override;
189
190             protected:
191                 Broadcast(const OutputVector& args,
192                           const Shape& shape,
193                           const AxisSet& broadcast_axes);
194
195                 virtual void infer_shape() {}
196                 Shape m_shape;
197                 AxisSet m_broadcast_axes;
198             };
199
200             /// \brief Broadcast arg to the same shape as like_arg.
201             class NGRAPH_DEPRECATED(
202                 "This operation is deprecated and will be removed soon. Please don't use it.")
203                 NGRAPH_API BroadcastLike : public v0::Broadcast
204             {
205             public:
206                 static constexpr NodeTypeInfo type_info{"BroadcastLike", 0};
207                 const NodeTypeInfo& get_type_info() const override { return type_info; }
208                 /// \brief Broadcast arg to the same shape as like_arg.
209                 BroadcastLike() = default;
210                 /// \brief Broadcast arg to the same shape as like_arg.
211                 ///
212                 /// Once the shape of like_arg is known, this op will be replaced with an equivalent
213                 /// Broadcast op.
214                 ///
215                 /// \param arg The argument to be broadcast.
216                 /// \param like_arg Provides the shape for the result.
217                 /// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
218                 ///        arg must be scalar and all axes are broadcast.
219                 BroadcastLike(const Output<Node>& arg,
220                               const Output<Node>& like_arg,
221                               const AxisSet& initial_broadcast_axes);
222                 bool visit_attributes(AttributeVisitor& visitor) override;
223                 std::shared_ptr<Node>
224                     clone_with_new_inputs(const OutputVector& new_args) const override;
225
226                 void infer_shape() override;
227                 const AxisSet& get_initial_broadcast_axes() const
228                 {
229                     return m_initial_broadcast_axes;
230                 }
231                 void set_initial_broadcast_axes(const AxisSet& initial_broadcast_axes)
232                 {
233                     m_initial_broadcast_axes = initial_broadcast_axes;
234                 }
235
236             protected:
237                 AxisSet m_initial_broadcast_axes;
238             };
239             NGRAPH_SUPPRESS_DEPRECATED_END
240         } // namespace v0
241
242         NGRAPH_SUPPRESS_DEPRECATED_START
243         using v0::Broadcast;
244         using v0::BroadcastLike;
245         NGRAPH_SUPPRESS_DEPRECATED_END
246     }
247 }