Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / src / op / dot.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <functional>
18 #include <memory>
19
20 #include "ngraph/axis_vector.hpp"
21 #include "ngraph/graph_util.hpp"
22 #include "ngraph/op/broadcast.hpp"
23 #include "ngraph/op/dot.hpp"
24 #include "ngraph/op/reshape.hpp"
25 #include "ngraph/shape.hpp"
26
27 NGRAPH_SUPPRESS_DEPRECATED_START
28
29 using namespace std;
30 using namespace ngraph;
31
32 constexpr NodeTypeInfo op::Dot::type_info;
33
34 op::Dot::Dot(const Output<Node>& arg0, const Output<Node>& arg1)
35     : Dot(arg0, arg1, 0, false)
36 {
37 }
38
39 op::Dot::Dot(const Output<Node>& arg0,
40              const Output<Node>& arg1,
41              size_t reduction_axes_count,
42              bool has_reduction_axes_count)
43     : Op({arg0, arg1})
44     , m_reduction_axes_count(reduction_axes_count)
45     , m_has_reduction_axes_count(has_reduction_axes_count)
46 {
47     constructor_validate_and_infer_types();
48 }
49
50 void op::Dot::validate_and_infer_types()
51 {
52     element::Type result_et;
53
54     NODE_VALIDATION_CHECK(
55         this,
56         element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
57         "Arguments do not have the same element type (arg0 element type: ",
58         get_input_element_type(0),
59         ", arg1 element type: ",
60         get_input_element_type(1),
61         ").");
62
63     const PartialShape& arg0_shape = get_input_partial_shape(0);
64     const PartialShape& arg1_shape = get_input_partial_shape(1);
65
66     // If an explicit value was not passed for reduction axis count at construction time, we have
67     // some extra work to do.
68     //
69     // - If one of the arguments is known to be scalar, the count is 0.
70     // - If both of the arguments are known to be nonscalar, the count is 1.
71     // - Otherwise, the count is unknown.
72     bool reduction_axes_ambiguous = !m_has_reduction_axes_count;
73
74     if (reduction_axes_ambiguous)
75     {
76         if (arg0_shape.rank().same_scheme(0) || arg1_shape.rank().same_scheme(0))
77         {
78             m_reduction_axes_count = 0;
79             reduction_axes_ambiguous = false;
80         }
81         else if (arg0_shape.rank().is_static() && arg1_shape.rank().is_static())
82         {
83             m_reduction_axes_count = 1;
84             reduction_axes_ambiguous = false;
85         }
86     }
87
88     PartialShape result_shape;
89
90     NODE_VALIDATION_CHECK(this,
91                           reduction_axes_ambiguous || arg0_shape.rank().is_dynamic() ||
92                               m_reduction_axes_count <= arg0_shape.rank().get_length(),
93                           "Reduction axes count (",
94                           m_reduction_axes_count,
95                           ") is too large (arg0 shape: ",
96                           arg0_shape,
97                           ", arg1 shape: ",
98                           arg1_shape,
99                           ").");
100
101     NODE_VALIDATION_CHECK(this,
102                           reduction_axes_ambiguous || arg1_shape.rank().is_dynamic() ||
103                               m_reduction_axes_count <= arg1_shape.rank().get_length(),
104                           "Reduction axes count (",
105                           m_reduction_axes_count,
106                           ") is too large (arg0 shape: ",
107                           arg0_shape,
108                           ", arg1 shape: ",
109                           arg1_shape,
110                           ").");
111
112     if (!reduction_axes_ambiguous && arg0_shape.rank().is_static() && arg1_shape.rank().is_static())
113     {
114         for (size_t i = 0; i < m_reduction_axes_count; i++)
115         {
116             size_t axis_index_arg0 = arg0_shape.rank().get_length() - m_reduction_axes_count + i;
117             size_t axis_index_arg1 = i;
118
119             NODE_VALIDATION_CHECK(
120                 this,
121                 arg0_shape[axis_index_arg0].compatible(arg1_shape[axis_index_arg1]),
122                 "Paired axes (axis ",
123                 axis_index_arg0,
124                 " from arg0, axis ",
125                 axis_index_arg1,
126                 " from arg1) do not have same length (arg0 shape: ",
127                 arg0_shape,
128                 ", arg1 shape: ",
129                 arg1_shape,
130                 ", reduction axes count: ",
131                 m_reduction_axes_count,
132                 ").");
133         }
134
135         std::vector<Dimension> result_dims(arg0_shape.rank().get_length() +
136                                            arg1_shape.rank().get_length() -
137                                            2 * m_reduction_axes_count);
138
139         size_t i = 0;
140
141         for (size_t j = 0; j < arg0_shape.rank().get_length() - m_reduction_axes_count; j++)
142         {
143             result_dims[i++] = arg0_shape[j];
144         }
145         for (size_t j = m_reduction_axes_count; j < arg1_shape.rank().get_length(); j++)
146         {
147             result_dims[i++] = arg1_shape[j];
148         }
149
150         result_shape = PartialShape(result_dims);
151     }
152     else
153     {
154         result_shape = PartialShape::dynamic();
155     }
156
157     set_output_type(0, result_et, result_shape);
158 }
159
160 shared_ptr<op::Reshape> make_reshape_axes_to_front(const Output<Node>& n,
161                                                    const Shape& front_shape,
162                                                    const Shape& back_shape)
163 {
164     AxisVector input_order;
165     Shape output_shape;
166
167     for (size_t i = 0; i < back_shape.size(); i++)
168     {
169         input_order.push_back(front_shape.size() + i);
170         output_shape.push_back(back_shape[i]);
171     }
172
173     for (size_t i = 0; i < front_shape.size(); i++)
174     {
175         input_order.push_back(i);
176         output_shape.push_back(front_shape[i]);
177     }
178
179     return make_shared<op::Reshape>(n, input_order, output_shape);
180 }
181
182 shared_ptr<Node> op::Dot::get_default_value() const
183 {
184     return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
185 }