13e28071440e31311ec8c5e02a5d85bedd8e4360
[platform/upstream/dldt.git] / ngraph / core / src / op / convert.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <memory>
18
19 #include "itt.hpp"
20 #include "ngraph/op/convert.hpp"
21 #include "ngraph/runtime/reference/convert.hpp"
22
23 using namespace std;
24 using namespace ngraph;
25
26 constexpr NodeTypeInfo op::Convert::type_info;
27
28 op::Convert::Convert(const Output<Node>& arg, const element::Type& destination_type)
29     : Op({arg})
30     , m_destination_type(destination_type)
31 {
32     constructor_validate_and_infer_types();
33 }
34
35 void op::Convert::validate_and_infer_types()
36 {
37     set_output_type(0, m_destination_type, get_input_partial_shape(0));
38 }
39
40 bool op::Convert::visit_attributes(AttributeVisitor& visitor)
41 {
42     visitor.on_attribute("destination_type", m_destination_type);
43     return true;
44 }
45
46 shared_ptr<Node> op::Convert::clone_with_new_inputs(const OutputVector& new_args) const
47 {
48     check_new_args_count(this, new_args);
49     return make_shared<Convert>(new_args.at(0), m_destination_type);
50 }
51
52 namespace
53 {
54     template <element::Type_t INPUT_ET, element::Type_t OUTPUT_ET>
55     bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& out)
56
57     {
58         out->set_shape(arg->get_shape());
59         size_t element_count = shape_size(out->get_shape());
60         return (INPUT_ET == arg->get_element_type()) && OUTPUT_ET == out->get_element_type() &&
61                (runtime::reference::convert(
62                     arg->get_data_ptr<INPUT_ET>(), out->get_data_ptr<OUTPUT_ET>(), element_count),
63                 true);
64     }
65
66 #define TYPE_OUT_CASE(a)                                                                           \
67     case element::Type_t::a: rc = evaluate<INPUT_ET, element::Type_t::a>
68
69     template <element::Type_t INPUT_ET>
70     bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& out)
71     {
72         bool rc = true;
73
74         switch (out->get_element_type())
75         {
76             TYPE_OUT_CASE(i8)(arg, out);
77             break;
78             TYPE_OUT_CASE(i16)(arg, out);
79             break;
80             TYPE_OUT_CASE(i32)(arg, out);
81             break;
82             TYPE_OUT_CASE(i64)(arg, out);
83             break;
84             TYPE_OUT_CASE(u8)(arg, out);
85             break;
86             TYPE_OUT_CASE(u16)(arg, out);
87             break;
88             TYPE_OUT_CASE(u32)(arg, out);
89             break;
90             TYPE_OUT_CASE(u64)(arg, out);
91             break;
92             TYPE_OUT_CASE(bf16)(arg, out);
93             break;
94             TYPE_OUT_CASE(f16)(arg, out);
95             break;
96             TYPE_OUT_CASE(f32)(arg, out);
97             break;
98             TYPE_OUT_CASE(f64)(arg, out);
99             break;
100         default: rc = false; break;
101         }
102         return rc;
103     }
104
105     bool evaluate_convert(const HostTensorPtr& arg, const HostTensorPtr& out)
106     {
107         bool rc = true;
108
109         switch (arg->get_element_type())
110         {
111             TYPE_CASE(i32)(arg, out);
112             break;
113             TYPE_CASE(i64)(arg, out);
114             break;
115             TYPE_CASE(u32)(arg, out);
116             break;
117             TYPE_CASE(u64)(arg, out);
118             break;
119             TYPE_CASE(f16)(arg, out);
120             break;
121             TYPE_CASE(f32)(arg, out);
122             break;
123         default: rc = false; break;
124         }
125         return rc;
126     }
127 }
128 bool op::v0::Convert::evaluate(const HostTensorVector& output_values,
129                                const HostTensorVector& input_values) const
130 {
131     OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v0::Convert::evaluate");
132     return evaluate_convert(input_values[0], output_values[0]);
133 }