1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "ngraph/op/round.hpp"
19 #include "ngraph/op/util/eval_copy.hpp"
20 #include "ngraph/runtime/host_tensor.hpp"
21 #include "ngraph/runtime/reference/copy.hpp"
22 #include "ngraph/runtime/reference/round.hpp"
// File-level setup for the Round op implementation.
// NOTE(review): NGRAPH_SUPPRESS_DEPRECATED_START presumably silences deprecation
// warnings for the rest of this translation unit; no matching *_END is visible in
// this excerpt -- confirm against the macro definition.
24 NGRAPH_SUPPRESS_DEPRECATED_START
27 using namespace ngraph;
// Out-of-class definition of the static constexpr member declared in round.hpp.
29 constexpr NodeTypeInfo op::Round::type_info;
// Constructs a Round node from a single input.
// \param arg Output of the node that feeds this op; forwarded unchanged to the
//            UnaryElementwiseArithmetic base-class constructor.
31 op::Round::Round(const Output<Node>& arg)
32 : UnaryElementwiseArithmetic(arg)
// Runs validation / type inference right after construction.
34 constructor_validate_and_infer_types();
// Creates a new Round node wired to a replacement set of inputs.
// \param new_args Replacement inputs; only element 0 is used.
// \return A freshly allocated Round built from new_args.at(0).
37 shared_ptr<Node> op::Round::clone_with_new_inputs(const OutputVector& new_args) const
// Checks new_args against this node before it is used below.
39 check_new_args_count(this, new_args);
40 return make_shared<Round>(new_args.at(0));
// Typed rounding kernel wrapper for element type ET.
// Calls runtime::reference::round on the raw data pointers of arg0 and out,
// passing count through as the third argument.
// \param arg0  Input tensor.
// \param out   Output tensor.
// \param count Forwarded unchanged to the reference kernel.
// NOTE(review): the function is declared to return bool, but no return statement
// is visible in this excerpt -- lines appear to be missing; confirm against the file.
45 // function used by TYPE_CASE
46 template <element::Type_t ET>
47 inline bool evaluate(const HostTensorPtr& arg0, const HostTensorPtr& out, const size_t count)
// T is the C++ value type associated with the element type ET.
49 using T = typename element_type_traits<ET>::value_type;
50 runtime::reference::round<T>(arg0->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
// Typed copy wrapper for element type ET.
// Calls runtime::reference::copy on the raw data pointers of arg0 and out,
// passing count through as the third argument.
// \param arg0  Source tensor.
// \param out   Destination tensor.
// \param count Forwarded unchanged to the reference copy.
// NOTE(review): the function is declared to return bool, but no return statement
// is visible in this excerpt -- lines appear to be missing; confirm against the file.
54 // function used by COPY_TENSOR
55 template <element::Type_t ET>
56 inline bool copy_tensor(const HostTensorPtr& arg0, const HostTensorPtr& out, const size_t count)
58 runtime::reference::copy(arg0->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
// Dispatches on the element type of arg0 and invokes the matching typed helper.
// \param arg0  Input tensor; its element type selects the case below.
// \param out   Output tensor.
// \param count Passed through to the selected helper.
// NOTE(review): the declaration of `rc`, the separators between cases and the final
// return are not visible in this excerpt -- confirm against the full file.
62 bool evaluate_round(const HostTensorPtr& arg0, const HostTensorPtr& out, const size_t count)
67 switch (arg0->get_element_type())
// Boolean and integer element types go through COPY_TENSOR
// (per the comment on copy_tensor, that macro uses copy_tensor<ET>).
69 COPY_TENSOR(boolean)(arg0, out, count);
71 COPY_TENSOR(i8)(arg0, out, count);
73 COPY_TENSOR(i16)(arg0, out, count);
75 COPY_TENSOR(i32)(arg0, out, count);
77 COPY_TENSOR(i64)(arg0, out, count);
79 COPY_TENSOR(u8)(arg0, out, count);
81 COPY_TENSOR(u16)(arg0, out, count);
83 COPY_TENSOR(u32)(arg0, out, count);
85 COPY_TENSOR(u64)(arg0, out, count);
// f16 and f32 go through TYPE_CASE
// (per the comment on evaluate, that macro uses evaluate<ET>).
87 TYPE_CASE(f16)(arg0, out, count);
89 TYPE_CASE(f32)(arg0, out, count);
// Any element type not listed above is reported as unhandled via rc = false.
91 default: rc = false; break;
// Evaluates this op on host tensors by delegating to evaluate_round.
// \param outputs Output tensors; only outputs[0] is used.
// \param inputs  Input tensors; only inputs[0] is used.
// \return Whatever evaluate_round returns.
// Both vectors are indexed with operator[] here, with no size check in this function.
97 bool op::Round::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
// NOTE(review): presumably an ITT profiling scope tagged with this method's name -- confirm.
99 OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::Round::evaluate");
// The count passed down is shape_size(get_output_shape(0)).
100 return evaluate_round(inputs[0], outputs[0], shape_size(get_output_shape(0)));