497515606b9c04bf19b2b9c689102d0ef0b705d3
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / CastLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "CastLayer.h"
18
19 namespace onert
20 {
21 namespace backend
22 {
23 namespace cpu
24 {
25 namespace ops
26 {
27
28 CastLayer::CastLayer() : _input(nullptr), _output(nullptr)
29 {
30   // DO NOTHING
31 }
32
33 void CastLayer::configure(const IPortableTensor *input, IPortableTensor *output)
34 {
35   _input = input;
36   _output = output;
37 }
38
39 template <typename FromT, typename ToT> void CastLayer::castTensor(const FromT *in, ToT *out)
40 {
41   auto input_shape = getTensorShape(_input);
42   auto output_shape = getTensorShape(_output);
43   const auto num_elements = MatchingFlatSize(input_shape, output_shape);
44
45   std::transform(in, in + num_elements, out, [](FromT a) { return static_cast<ToT>(a); });
46 }
47
48 template <typename FromT> void CastLayer::castPtr(const FromT *in, DataPtr out)
49 {
50   switch (_output->data_type())
51   {
52     case ir::DataType::FLOAT32:
53       castTensor(in, out.f);
54       return;
55     case ir::DataType::INT32:
56       castTensor(in, out.i32);
57       return;
58     case ir::DataType::UINT32:
59       castTensor(in, out.u32);
60       return;
61     case ir::DataType::UINT8:
62       castTensor(in, out.u8);
63       return;
64     case ir::DataType::BOOL8:
65       castTensor(in, out.b);
66       return;
67     case ir::DataType::INT64:
68       castTensor(in, out.i64);
69       return;
70     default:
71       throw std::runtime_error("Not supported output type" +
72                                std::to_string((int)_output->data_type()));
73   }
74 }
75
76 void CastLayer::run()
77 {
78   auto input_buf = _input->buffer();
79   auto output_buf = _output->buffer();
80   const auto in = *reinterpret_cast<const DataPtr *>(&input_buf);
81   auto out = *reinterpret_cast<DataPtr *>(&output_buf);
82
83   switch (_input->data_type())
84   {
85     case ir::DataType::FLOAT32:
86       castPtr(in.f, out);
87       return;
88     case ir::DataType::INT32:
89       castPtr(in.i32, out);
90       return;
91     case ir::DataType::UINT32:
92       castPtr(in.u32, out);
93       return;
94     case ir::DataType::UINT8:
95       castPtr(in.u8, out);
96       return;
97     case ir::DataType::BOOL8:
98       castPtr(in.b, out);
99       return;
100     case ir::DataType::INT64:
101       castPtr(in.i64, out);
102       return;
103     default:
104       throw std::runtime_error("Cast: unsupported data type" +
105                                std::to_string((int)_input->data_type()));
106   }
107 }
108
109 } // namespace ops
110 } // namespace cpu
111 } // namespace backend
112 } // namespace onert