Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / src / op / one_hot.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "ngraph/op/one_hot.hpp"
18 #include "ngraph/attribute_visitor.hpp"
19 #include "ngraph/op/util/op_types.hpp"
20 #include "ngraph/validation_util.hpp"
21
22 NGRAPH_SUPPRESS_DEPRECATED_START
23
24 using namespace std;
25 using namespace ngraph;
26
27 constexpr NodeTypeInfo op::v0::OneHot::type_info;
28
29 op::v0::OneHot::OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
30     : Op({arg})
31     , m_shape(shape)
32     , m_one_hot_axis(one_hot_axis)
33 {
34     constructor_validate_and_infer_types();
35 }
36
37 void op::v0::OneHot::validate_and_infer_types()
38 {
39     element::Type arg_et = get_input_element_type(0);
40     PartialShape arg_shape = get_input_partial_shape(0);
41     Rank arg_rank = arg_shape.rank();
42
43     NODE_VALIDATION_CHECK(this,
44                           arg_et.is_dynamic() || arg_et.is_integral(),
45                           "Argument does not have integral element type.");
46
47     NODE_VALIDATION_CHECK(
48         this, m_shape.rank().is_static(), "Requested result shape has dynamic rank.");
49
50     NODE_VALIDATION_CHECK(this,
51                           m_one_hot_axis < m_shape.rank().get_length(),
52                           "One-hot axis (",
53                           m_one_hot_axis,
54                           ") is out of bounds (requested result shape: ",
55                           m_shape,
56                           ").");
57
58     NODE_VALIDATION_CHECK(this,
59                           m_shape[m_one_hot_axis].is_static(),
60                           "Requested result shape (",
61                           m_shape,
62                           ") has dynamic dimension at the one-hot axis ",
63                           "(",
64                           m_one_hot_axis,
65                           ").");
66
67     PartialShape result_shape{m_shape};
68
69     if (arg_rank.is_static())
70     {
71         std::vector<Dimension> expected_input_dims(m_shape.rank().get_length());
72         for (size_t i = 0; i < m_shape.rank().get_length(); i++)
73         {
74             expected_input_dims[i] = m_shape[i];
75         }
76         expected_input_dims.erase(expected_input_dims.begin() + m_one_hot_axis);
77         PartialShape expected_input_shape{expected_input_dims};
78
79         PartialShape merged_input_shape{expected_input_shape};
80         NODE_VALIDATION_CHECK(this,
81                               PartialShape::merge_into(merged_input_shape, arg_shape),
82                               "Argument shape ",
83                               arg_shape,
84                               " does not match the expected shape of ",
85                               expected_input_shape,
86                               ".");
87
88         std::vector<Dimension> output_dims(merged_input_shape.rank().get_length());
89         for (size_t i = 0; i < merged_input_shape.rank().get_length(); i++)
90         {
91             output_dims[i] = merged_input_shape[i];
92         }
93         output_dims.insert(output_dims.begin() + m_one_hot_axis, m_shape[m_one_hot_axis]);
94         result_shape = PartialShape{output_dims};
95     }
96
97     set_output_type(0, arg_et, result_shape);
98 }
99
100 shared_ptr<Node> op::v0::OneHot::clone_with_new_inputs(const OutputVector& new_args) const
101 {
102     check_new_args_count(this, new_args);
103     return make_shared<v0::OneHot>(new_args.at(0), m_shape, m_one_hot_axis);
104 }
105
106 constexpr NodeTypeInfo op::v1::OneHot::type_info;
107
108 op::v1::OneHot::OneHot(const Output<Node>& indices,
109                        const Output<Node>& depth,
110                        const Output<Node>& on_value,
111                        const Output<Node>& off_value,
112                        int64_t axis)
113     : Op({indices, depth, on_value, off_value})
114     , m_axis(axis)
115 {
116     constructor_validate_and_infer_types();
117 }
118
119 void op::v1::OneHot::validate_and_infer_types()
120 {
121     const auto& indices_et = get_input_element_type(0);
122     const auto& depth_et = get_input_element_type(1);
123     const auto& on_value_et = get_input_element_type(2);
124     const auto& off_value_et = get_input_element_type(3);
125
126     NODE_VALIDATION_CHECK(this,
127                           indices_et.is_dynamic() || indices_et.is_integral(),
128                           "Indices must be integral element type.");
129
130     NODE_VALIDATION_CHECK(this,
131                           depth_et.is_dynamic() || depth_et.is_integral(),
132                           "Depth must be integral element type.");
133
134     NODE_VALIDATION_CHECK(this,
135                           on_value_et.compatible(off_value_et),
136                           "on_value element type must be compatible with off_value element type.");
137
138     const auto& indices_shape = get_input_partial_shape(0);
139     const auto& depth_shape = get_input_partial_shape(1);
140     const auto& on_value_shape = get_input_partial_shape(2);
141     const auto& off_value_shape = get_input_partial_shape(3);
142
143     NODE_VALIDATION_CHECK(this,
144                           depth_shape.is_dynamic() || is_scalar(depth_shape.to_shape()),
145                           "depth input must be scalar.");
146
147     NODE_VALIDATION_CHECK(this,
148                           on_value_shape.is_dynamic() || is_scalar(on_value_shape.to_shape()),
149                           "on_value input must be scalar.");
150
151     NODE_VALIDATION_CHECK(this,
152                           off_value_shape.is_dynamic() || is_scalar(off_value_shape.to_shape()),
153                           "off_value input must be scalar.");
154
155     const auto& depth = input_value(1).get_node_shared_ptr();
156     PartialShape result_shape{PartialShape::dynamic()};
157
158     if (indices_shape.is_static() && indices_shape.rank().is_static() && op::is_constant(depth))
159     {
160         const auto indices_rank = indices_shape.rank().get_length();
161
162         std::vector<Dimension> out_dims(indices_rank);
163         for (auto i = 0; i < indices_rank; i++)
164         {
165             out_dims[i] = indices_shape[i];
166         }
167         m_axis =
168             ngraph::normalize_axis(this, m_axis, indices_rank + 1, -indices_rank - 1, indices_rank);
169
170         auto depth_element_type = depth->get_output_element_type(0);
171         NODE_VALIDATION_CHECK(this,
172                               depth_element_type.is_integral(),
173                               "'depth' input element type must be an integer (got ",
174                               depth_element_type,
175                               ").");
176
177         NODE_VALIDATION_CHECK(this,
178                               is_scalar(depth->get_shape()),
179                               "A scalar input should be provided as 'depth' to OneHot",
180                               " (got ",
181                               depth->get_shape(),
182                               " elements).");
183
184         const auto depth_constant = as_type_ptr<op::Constant>(depth);
185         int64_t depth_val = depth_constant->cast_vector<int64_t>()[0];
186
187         NODE_VALIDATION_CHECK(this,
188                               depth_val > 0,
189                               "The value of 'depth' must be a positive number.",
190                               " (got ",
191                               depth_val,
192                               ").");
193
194         out_dims.insert(out_dims.begin() + m_axis, Dimension(depth_val));
195         result_shape = out_dims;
196     }
197
198     set_output_type(0, on_value_et, result_shape);
199 }
200
201 bool ngraph::op::v1::OneHot::visit_attributes(AttributeVisitor& visitor)
202 {
203     visitor.on_attribute("axis", m_axis);
204     return true;
205 }
206
207 shared_ptr<Node> op::v1::OneHot::clone_with_new_inputs(const OutputVector& new_args) const
208 {
209     check_new_args_count(this, new_args);
210     return make_shared<v1::OneHot>(
211         new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_axis);
212 }