1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
20 #include "ngraph/op/convert.hpp"
21 #include "ngraph/runtime/reference/convert.hpp"
24 using namespace ngraph;
26 constexpr NodeTypeInfo op::Convert::type_info;
28 op::Convert::Convert(const Output<Node>& arg, const element::Type& destination_type)
30 , m_destination_type(destination_type)
32 constructor_validate_and_infer_types();
35 void op::Convert::validate_and_infer_types()
37 set_output_type(0, m_destination_type, get_input_partial_shape(0));
40 bool op::Convert::visit_attributes(AttributeVisitor& visitor)
42 visitor.on_attribute("destination_type", m_destination_type);
46 shared_ptr<Node> op::Convert::clone_with_new_inputs(const OutputVector& new_args) const
48 check_new_args_count(this, new_args);
49 return make_shared<Convert>(new_args.at(0), m_destination_type);
54 template <element::Type_t INPUT_ET, element::Type_t OUTPUT_ET>
55 bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& out)
58 out->set_shape(arg->get_shape());
59 size_t element_count = shape_size(out->get_shape());
60 return (INPUT_ET == arg->get_element_type()) && OUTPUT_ET == out->get_element_type() &&
61 (runtime::reference::convert(
62 arg->get_data_ptr<INPUT_ET>(), out->get_data_ptr<OUTPUT_ET>(), element_count),
66 #define TYPE_OUT_CASE(a) \
67 case element::Type_t::a: rc = evaluate<INPUT_ET, element::Type_t::a>
69 template <element::Type_t INPUT_ET>
70 bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& out)
74 switch (out->get_element_type())
76 TYPE_OUT_CASE(i8)(arg, out);
78 TYPE_OUT_CASE(i16)(arg, out);
80 TYPE_OUT_CASE(i32)(arg, out);
82 TYPE_OUT_CASE(i64)(arg, out);
84 TYPE_OUT_CASE(u8)(arg, out);
86 TYPE_OUT_CASE(u16)(arg, out);
88 TYPE_OUT_CASE(u32)(arg, out);
90 TYPE_OUT_CASE(u64)(arg, out);
92 TYPE_OUT_CASE(bf16)(arg, out);
94 TYPE_OUT_CASE(f16)(arg, out);
96 TYPE_OUT_CASE(f32)(arg, out);
98 TYPE_OUT_CASE(f64)(arg, out);
100 default: rc = false; break;
105 bool evaluate_convert(const HostTensorPtr& arg, const HostTensorPtr& out)
109 switch (arg->get_element_type())
111 TYPE_CASE(i32)(arg, out);
113 TYPE_CASE(i64)(arg, out);
115 TYPE_CASE(u32)(arg, out);
117 TYPE_CASE(u64)(arg, out);
119 TYPE_CASE(f16)(arg, out);
121 TYPE_CASE(f32)(arg, out);
123 default: rc = false; break;
128 bool op::v0::Convert::evaluate(const HostTensorVector& output_values,
129 const HostTensorVector& input_values) const
131 OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v0::Convert::evaluate");
132 return evaluate_convert(input_values[0], output_values[0]);