//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <memory>
#include <vector>

#include "ngraph/axis_vector.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
27 NGRAPH_SUPPRESS_DEPRECATED_START
30 using namespace ngraph;
32 constexpr NodeTypeInfo op::Dot::type_info;
34 op::Dot::Dot(const Output<Node>& arg0, const Output<Node>& arg1)
35 : Dot(arg0, arg1, 0, false)
39 op::Dot::Dot(const Output<Node>& arg0,
40 const Output<Node>& arg1,
41 size_t reduction_axes_count,
42 bool has_reduction_axes_count)
44 , m_reduction_axes_count(reduction_axes_count)
45 , m_has_reduction_axes_count(has_reduction_axes_count)
47 constructor_validate_and_infer_types();
50 void op::Dot::validate_and_infer_types()
52 element::Type result_et;
54 NODE_VALIDATION_CHECK(
56 element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
57 "Arguments do not have the same element type (arg0 element type: ",
58 get_input_element_type(0),
59 ", arg1 element type: ",
60 get_input_element_type(1),
63 const PartialShape& arg0_shape = get_input_partial_shape(0);
64 const PartialShape& arg1_shape = get_input_partial_shape(1);
66 // If an explicit value was not passed for reduction axis count at construction time, we have
67 // some extra work to do.
69 // - If one of the arguments is known to be scalar, the count is 0.
70 // - If both of the arguments are known to be nonscalar, the count is 1.
71 // - Otherwise, the count is unknown.
72 bool reduction_axes_ambiguous = !m_has_reduction_axes_count;
74 if (reduction_axes_ambiguous)
76 if (arg0_shape.rank().same_scheme(0) || arg1_shape.rank().same_scheme(0))
78 m_reduction_axes_count = 0;
79 reduction_axes_ambiguous = false;
81 else if (arg0_shape.rank().is_static() && arg1_shape.rank().is_static())
83 m_reduction_axes_count = 1;
84 reduction_axes_ambiguous = false;
88 PartialShape result_shape;
90 NODE_VALIDATION_CHECK(this,
91 reduction_axes_ambiguous || arg0_shape.rank().is_dynamic() ||
92 m_reduction_axes_count <= arg0_shape.rank().get_length(),
93 "Reduction axes count (",
94 m_reduction_axes_count,
95 ") is too large (arg0 shape: ",
101 NODE_VALIDATION_CHECK(this,
102 reduction_axes_ambiguous || arg1_shape.rank().is_dynamic() ||
103 m_reduction_axes_count <= arg1_shape.rank().get_length(),
104 "Reduction axes count (",
105 m_reduction_axes_count,
106 ") is too large (arg0 shape: ",
112 if (!reduction_axes_ambiguous && arg0_shape.rank().is_static() && arg1_shape.rank().is_static())
114 for (size_t i = 0; i < m_reduction_axes_count; i++)
116 size_t axis_index_arg0 = arg0_shape.rank().get_length() - m_reduction_axes_count + i;
117 size_t axis_index_arg1 = i;
119 NODE_VALIDATION_CHECK(
121 arg0_shape[axis_index_arg0].compatible(arg1_shape[axis_index_arg1]),
122 "Paired axes (axis ",
126 " from arg1) do not have same length (arg0 shape: ",
130 ", reduction axes count: ",
131 m_reduction_axes_count,
135 std::vector<Dimension> result_dims(arg0_shape.rank().get_length() +
136 arg1_shape.rank().get_length() -
137 2 * m_reduction_axes_count);
141 for (size_t j = 0; j < arg0_shape.rank().get_length() - m_reduction_axes_count; j++)
143 result_dims[i++] = arg0_shape[j];
145 for (size_t j = m_reduction_axes_count; j < arg1_shape.rank().get_length(); j++)
147 result_dims[i++] = arg1_shape[j];
150 result_shape = PartialShape(result_dims);
154 result_shape = PartialShape::dynamic();
157 set_output_type(0, result_et, result_shape);
160 shared_ptr<op::Reshape> make_reshape_axes_to_front(const Output<Node>& n,
161 const Shape& front_shape,
162 const Shape& back_shape)
164 AxisVector input_order;
167 for (size_t i = 0; i < back_shape.size(); i++)
169 input_order.push_back(front_shape.size() + i);
170 output_shape.push_back(back_shape[i]);
173 for (size_t i = 0; i < front_shape.size(); i++)
175 input_order.push_back(i);
176 output_shape.push_back(front_shape[i]);
179 return make_shared<op::Reshape>(n, input_order, output_shape);
182 shared_ptr<Node> op::Dot::get_default_value() const
184 return ngraph::make_constant_from_string("0", get_element_type(), get_shape());