Remove obsoleted v0::Broadcast and BroadcastLike operators (#2779)
[platform/upstream/dldt.git] / ngraph / core / include / ngraph / op / broadcast.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include "ngraph/axis_set.hpp"
20 #include "ngraph/op/op.hpp"
21 #include "ngraph/op/util/attr_types.hpp"
22 #include "ngraph/op/util/broadcast_base.hpp"
23
24 namespace ngraph
25 {
26     namespace op
27     {
28         namespace v3
29         {
30             /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
31             ///        input as needed along the new axes.
32             class NGRAPH_API Broadcast : public util::BroadcastBase
33             {
34             public:
35                 static constexpr NodeTypeInfo type_info{"Broadcast", 3};
36                 const NodeTypeInfo& get_type_info() const override { return type_info; }
37                 /// \brief Constructs a broadcast operation.
38                 Broadcast() = default;
39                 /// \brief Constructs a broadcast operation.
40                 ///
41                 /// \param arg            The input tensor to be broadcast.
42                 /// \param target_shape   The shape of the output tensor.
43                 /// \param axes_mapping   The axis positions (0-based) in the result that correspond
44                 ///                       to input axes. 'Arg' tensor is broadcast along the
45                 ///                       remaining axes.
46                 ///                       E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
47                 ///                       axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
48                 ///                       axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
49                 /// \param broadcast_spec Broadcast specification to use for determining broadcast
50                 ///                       axes. 'axes_mapping' should not be provided if mode other
51                 ///                       than explicit (none) is used.
52                 Broadcast(const Output<Node>& arg,
53                           const Output<Node>& target_shape,
54                           const Output<Node>& axes_mapping,
55                           const BroadcastModeSpec& broadcast_spec = BroadcastType::EXPLICIT);
56
57                 /// \brief Constructs a broadcast operation.
58                 ///
59                 /// \param arg            The input tensor to be broadcast.
60                 /// \param target_shape   The shape of the output tensor.
61                 /// \param broadcast_spec Broadcast specification to use for determining broadcast
62                 ///                       axes
63                 Broadcast(const Output<Node>& arg,
64                           const Output<Node>& target_shape,
65                           const BroadcastModeSpec& broadcast_spec = BroadcastType::NUMPY);
66
67                 bool visit_attributes(AttributeVisitor& visitor) override;
68
69                 std::shared_ptr<Node>
70                     clone_with_new_inputs(const OutputVector& new_args) const override;
71
72                 // \return Broadcast Specification.
73                 const BroadcastModeSpec& get_broadcast_spec() const { return m_mode; }
74                 void set_broadcast_spec(const BroadcastModeSpec& broadcast_spec)
75                 {
76                     m_mode = broadcast_spec;
77                 }
78
79                 void validate_and_infer_types() override;
80
81                 /// \return true and the AxisSet if broadcast axes can be fully determined.
82                 std::pair<bool, AxisSet> get_broadcast_axes() const override;
83                 bool evaluate(const HostTensorVector& outputs,
84                               const HostTensorVector& inputs) const override;
85             };
86         } // namespace v3
87
88         namespace v1
89         {
90             /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
91             ///        input as needed along the new axes.
92             class NGRAPH_API Broadcast : public util::BroadcastBase
93             {
94             public:
95                 static constexpr NodeTypeInfo type_info{"Broadcast", 1};
96                 const NodeTypeInfo& get_type_info() const override { return type_info; }
97                 /// \brief Constructs a broadcast operation.
98                 Broadcast() = default;
99                 /// \brief Constructs a broadcast operation.
100                 ///
101                 /// \param arg            The input tensor to be broadcast.
102                 /// \param target_shape   The shape of the output tensor.
103                 /// \param axes_mapping   The axis positions (0-based) in the result that correspond
104                 ///                       to input axes. 'Arg' tensor is broadcast along the
105                 ///                       remaining axes.
106                 ///                       E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
107                 ///                       axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
108                 ///                       axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
109                 /// \param broadcast_spec Broadcast specification to use for determining broadcast
110                 ///                       axes. 'axes_mapping' is ignored if broadcast_spec is not
111                 ///                       NONE
112                 Broadcast(const Output<Node>& arg,
113                           const Output<Node>& target_shape,
114                           const Output<Node>& axes_mapping,
115                           const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec());
116
117                 /// \brief Constructs a broadcast operation.
118                 ///
119                 /// \param arg            The input tensor to be broadcast.
120                 /// \param target_shape   The shape of the output tensor.
121                 /// \param broadcast_spec Broadcast specification to use for determining broadcast
122                 ///                       axes
123                 Broadcast(const Output<Node>& arg,
124                           const Output<Node>& target_shape,
125                           const AutoBroadcastSpec& broadcast_spec =
126                               AutoBroadcastSpec(AutoBroadcastType::NUMPY));
127
128                 bool visit_attributes(AttributeVisitor& visitor) override;
129
130                 std::shared_ptr<Node>
131                     clone_with_new_inputs(const OutputVector& new_args) const override;
132
133                 /// \return Broadcast Specification.
134                 const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; }
135                 void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec)
136                 {
137                     m_broadcast_spec = broadcast_spec;
138                 }
139
140                 void validate_and_infer_types() override;
141                 bool evaluate(const HostTensorVector& outputs,
142                               const HostTensorVector& inputs) const override;
143
144             protected:
145                 AutoBroadcastSpec m_broadcast_spec;
146             };
147         } // namespace v1
148     }
149 }