1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include "ngraph/axis_set.hpp"
20 #include "ngraph/op/op.hpp"
21 #include "ngraph/op/util/attr_types.hpp"
22 #include "ngraph/op/util/broadcast_base.hpp"
30 /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
31 /// input as needed along the new axes.
///
/// This variant describes its broadcast mode with a BroadcastModeSpec and derives from
/// util::BroadcastBase. Its type_info is {"Broadcast", 3}; the 3 is presumably the
/// operation-set version - confirm against NodeTypeInfo.
32 class NGRAPH_API Broadcast : public util::BroadcastBase
/// Static type descriptor for this operation.
35 static constexpr NodeTypeInfo type_info{"Broadcast", 3};
/// \return The static type_info of this class.
36 const NodeTypeInfo& get_type_info() const override { return type_info; }
37 /// \brief Default constructor (compiler-generated).
38 Broadcast() = default;
39 /// \brief Constructs a broadcast operation with an explicit axes mapping input.
///
41 /// \param arg The input tensor to be broadcast.
42 /// \param target_shape The shape of the output tensor.
43 /// \param axes_mapping The axis positions (0-based) in the result that correspond
44 /// to input axes. 'Arg' tensor is broadcast along the
/// remaining axes.
46 /// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
47 /// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
48 /// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
49 /// \param broadcast_spec Broadcast specification to use for determining broadcast
50 /// axes. 'axes_mapping' should not be provided if mode other
51 /// than explicit (none) is used. Defaults to
/// BroadcastType::EXPLICIT.
52 Broadcast(const Output<Node>& arg,
53 const Output<Node>& target_shape,
54 const Output<Node>& axes_mapping,
55 const BroadcastModeSpec& broadcast_spec = BroadcastType::EXPLICIT);
57 /// \brief Constructs a broadcast operation without an axes mapping input.
///
59 /// \param arg The input tensor to be broadcast.
60 /// \param target_shape The shape of the output tensor.
61 /// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes. Defaults to BroadcastType::NUMPY.
63 Broadcast(const Output<Node>& arg,
64 const Output<Node>& target_shape,
65 const BroadcastModeSpec& broadcast_spec = BroadcastType::NUMPY);
/// \brief Attribute-visitor hook; declared here, defined outside this header section.
67 bool visit_attributes(AttributeVisitor& visitor) override;
/// \brief Clone hook taking the replacement inputs; declared here, defined elsewhere.
/// NOTE(review): the return-type line of this declaration is not visible in this
/// listing - confirm against the upstream header.
70 clone_with_new_inputs(const OutputVector& new_args) const override;
72 /// \return The broadcast specification currently stored in m_mode.
73 const BroadcastModeSpec& get_broadcast_spec() const { return m_mode; }
/// \brief Replaces the stored broadcast specification (m_mode).
/// \param broadcast_spec The new broadcast specification.
74 void set_broadcast_spec(const BroadcastModeSpec& broadcast_spec)
76 m_mode = broadcast_spec;
/// \brief Validation / type-inference hook; declared here, defined elsewhere.
79 void validate_and_infer_types() override;
81 /// \return true and the AxisSet if broadcast axes can be fully determined.
82 std::pair<bool, AxisSet> get_broadcast_axes() const override;
/// \brief Evaluation hook operating on host tensors; declared here, defined elsewhere.
/// \param outputs Host tensors for the results.
/// \param inputs Host tensors holding the input values.
/// \return A bool; presumably whether evaluation succeeded - confirm against the
/// implementation.
83 bool evaluate(const HostTensorVector& outputs,
84 const HostTensorVector& inputs) const override;
90 /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
91 /// input as needed along the new axes.
///
/// This variant describes its broadcast mode with an AutoBroadcastSpec, which it stores
/// in its own m_broadcast_spec member. Its type_info is {"Broadcast", 1}; the 1 is
/// presumably the operation-set version - confirm against NodeTypeInfo.
92 class NGRAPH_API Broadcast : public util::BroadcastBase
/// Static type descriptor for this operation.
95 static constexpr NodeTypeInfo type_info{"Broadcast", 1};
/// \return The static type_info of this class.
96 const NodeTypeInfo& get_type_info() const override { return type_info; }
97 /// \brief Default constructor (compiler-generated).
98 Broadcast() = default;
99 /// \brief Constructs a broadcast operation with an explicit axes mapping input.
///
101 /// \param arg The input tensor to be broadcast.
102 /// \param target_shape The shape of the output tensor.
103 /// \param axes_mapping The axis positions (0-based) in the result that correspond
104 /// to input axes. 'Arg' tensor is broadcast along the
/// remaining axes.
106 /// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
107 /// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
108 /// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
109 /// \param broadcast_spec Broadcast specification to use for determining broadcast
110 /// axes. 'axes_mapping' is ignored if broadcast_spec is not
/// the explicit (NONE) mode. NOTE(review): the end of this
/// sentence is missing from this listing; "NONE" is assumed
/// to be what a default-constructed AutoBroadcastSpec
/// selects - confirm against attr_types.hpp.
112 Broadcast(const Output<Node>& arg,
113 const Output<Node>& target_shape,
114 const Output<Node>& axes_mapping,
115 const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec());
117 /// \brief Constructs a broadcast operation without an axes mapping input.
///
119 /// \param arg The input tensor to be broadcast.
120 /// \param target_shape The shape of the output tensor.
121 /// \param broadcast_spec Broadcast specification to use for determining broadcast
/// axes. Defaults to AutoBroadcastSpec(AutoBroadcastType::NUMPY).
123 Broadcast(const Output<Node>& arg,
124 const Output<Node>& target_shape,
125 const AutoBroadcastSpec& broadcast_spec =
126 AutoBroadcastSpec(AutoBroadcastType::NUMPY));
/// \brief Attribute-visitor hook; declared here, defined outside this header section.
128 bool visit_attributes(AttributeVisitor& visitor) override;
/// \brief Clone hook taking the replacement inputs; declared here, defined elsewhere.
/// \return A shared pointer to a Node.
130 std::shared_ptr<Node>
131 clone_with_new_inputs(const OutputVector& new_args) const override;
133 /// \return The broadcast specification currently stored in m_broadcast_spec.
134 const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; }
/// \brief Replaces the stored broadcast specification (m_broadcast_spec).
/// \param broadcast_spec The new broadcast specification.
135 void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec)
137 m_broadcast_spec = broadcast_spec;
/// \brief Validation / type-inference hook; declared here, defined elsewhere.
140 void validate_and_infer_types() override;
/// \brief Evaluation hook operating on host tensors; declared here, defined elsewhere.
/// \param outputs Host tensors for the results.
/// \param inputs Host tensors holding the input values.
/// \return A bool; presumably whether evaluation succeeded - confirm against the
/// implementation.
141 bool evaluate(const HostTensorVector& outputs,
142 const HostTensorVector& inputs) const override;
/// Broadcast specification read by get_broadcast_spec() and written by
/// set_broadcast_spec().
145 AutoBroadcastSpec m_broadcast_spec;