Imported Upstream version 1.4.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / kernel / CastLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "CastLayer.h"
18
19 namespace onert
20 {
21 namespace backend
22 {
23 namespace cpu
24 {
25 namespace kernel
26 {
27
28 CastLayer::CastLayer() : _input(nullptr), _output(nullptr)
29 {
30   // DO NOTHING
31 }
32
33 void CastLayer::configure(const operand::Tensor *input, operand::Tensor *output)
34 {
35   _input = input;
36   _output = output;
37 }
38
39 template <typename FromT, typename ToT> void CastLayer::castTensor(const FromT *in, ToT *out)
40 {
41   auto input_shape = convertTensorToCkerShape(_input);
42   auto output_shape = convertTensorToCkerShape(_output);
43   const auto num_elements = MatchingFlatSize(input_shape, output_shape);
44
45   std::transform(in, in + num_elements, out, [](FromT a) { return static_cast<ToT>(a); });
46 }
47
48 template <typename FromT> void CastLayer::castPtr(const FromT *in, DataPtr out)
49 {
50   switch (_output->data_type())
51   {
52     case ir::DataType::FLOAT32:
53       castTensor(in, out.f);
54       return;
55     case ir::DataType::INT32:
56       castTensor(in, out.i32);
57       return;
58     case ir::DataType::UINT32:
59       castTensor(in, out.u32);
60       return;
61     case ir::DataType::UINT8:
62       castTensor(in, out.u8);
63       return;
64     case ir::DataType::BOOL8:
65       castTensor(in, out.b);
66       return;
67     case ir::DataType::QUANT8_ASYMM:
68     case ir::DataType::QUANT8_SYMM:
69       throw std::runtime_error("Not supported output type" +
70                                std::to_string((int)_output->data_type()));
71   }
72 }
73
74 void CastLayer::run()
75 {
76   auto input_buf = _input->buffer();
77   auto output_buf = _output->buffer();
78   const auto in = *reinterpret_cast<const DataPtr *>(&input_buf);
79   auto out = *reinterpret_cast<DataPtr *>(&output_buf);
80
81   switch (_input->data_type())
82   {
83     case ir::DataType::FLOAT32:
84       castPtr(in.f, out);
85       return;
86     case ir::DataType::INT32:
87       castPtr(in.i32, out);
88       return;
89     case ir::DataType::UINT32:
90       castPtr(in.u32, out);
91       return;
92     case ir::DataType::UINT8:
93       castPtr(in.u8, out);
94       return;
95     case ir::DataType::BOOL8:
96       castPtr(in.b, out);
97       return;
98     case ir::DataType::QUANT8_ASYMM:
99     case ir::DataType::QUANT8_SYMM:
100       throw std::runtime_error("Not supported input type" +
101                                std::to_string((int)_input->data_type()));
102   }
103 }
104
105 } // namespace kernel
106 } // namespace cpu
107 } // namespace backend
108 } // namespace onert