/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "ElementwiseUnaryLayer.h"
19 #include "OperationUtils.h"
21 #include <cker/operation/Elementwise.h>
22 #include <cker/operation/Erf.h>
23 #include <cker/operation/Exp.h>
24 #include <cker/operation/LogicalNot.h>
25 #include <cker/operation/Quantize.h>
26 #include <cker/operation/Round.h>
39 void absFloat32(const IPortableTensor *input, IPortableTensor *output)
41 nnfw::cker::Abs(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
42 getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
45 template <typename FromT>
46 void castPtr(const FromT *in, DataPtr out, int num_elements, ir::DataType data_type_out)
48 switch (data_type_out)
50 case ir::DataType::FLOAT32:
51 std::transform(in, in + num_elements, out.f, [](FromT a) { return static_cast<float>(a); });
53 case ir::DataType::INT32:
54 std::transform(in, in + num_elements, out.i32,
55 [](FromT a) { return static_cast<int32_t>(a); });
57 case ir::DataType::UINT32:
58 std::transform(in, in + num_elements, out.u32,
59 [](FromT a) { return static_cast<uint32_t>(a); });
61 case ir::DataType::UINT8:
62 std::transform(in, in + num_elements, out.u8,
63 [](FromT a) { return static_cast<uint8_t>(a); });
65 case ir::DataType::BOOL8:
66 std::transform(in, in + num_elements, out.b, [](FromT a) { return static_cast<bool>(a); });
68 case ir::DataType::INT64:
69 std::transform(in, in + num_elements, out.i64,
70 [](FromT a) { return static_cast<int64_t>(a); });
73 throw std::runtime_error("Cast: Not supported output type" +
74 std::to_string((int)data_type_out));
78 void cast(const IPortableTensor *input, IPortableTensor *output)
80 auto input_buf = input->buffer();
81 auto output_buf = output->buffer();
82 const auto in = *reinterpret_cast<const DataPtr *>(&input_buf);
83 auto out = *reinterpret_cast<DataPtr *>(&output_buf);
85 auto input_shape = getTensorShape(input);
86 auto output_shape = getTensorShape(output);
87 const auto num_elements = MatchingFlatSize(input_shape, output_shape);
89 switch (input->data_type())
91 case ir::DataType::FLOAT32:
92 castPtr(in.f, out, num_elements, output->data_type());
94 case ir::DataType::INT32:
95 castPtr(in.i32, out, num_elements, output->data_type());
97 case ir::DataType::UINT32:
98 castPtr(in.u32, out, num_elements, output->data_type());
100 case ir::DataType::UINT8:
101 castPtr(in.u8, out, num_elements, output->data_type());
103 case ir::DataType::BOOL8:
104 castPtr(in.b, out, num_elements, output->data_type());
106 case ir::DataType::INT64:
107 castPtr(in.i64, out, num_elements, output->data_type());
110 throw std::runtime_error("Cast: unsupported data type" +
111 std::to_string((int)input->data_type()));
115 void cosFloat32(const IPortableTensor *input, IPortableTensor *output)
117 nnfw::cker::Cos(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
118 getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
121 void expFloat32(const IPortableTensor *input, IPortableTensor *output)
123 nnfw::cker::Exp(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
124 getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
127 void erfFloat32(const IPortableTensor *input, IPortableTensor *output)
129 nnfw::cker::Erf(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
130 getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
133 void logFloat32(const IPortableTensor *input, IPortableTensor *output)
135 nnfw::cker::Log(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
136 getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
139 void logicalNot(const IPortableTensor *input, IPortableTensor *output)
141 nnfw::cker::LogicalNot(getTensorShape(input), reinterpret_cast<const bool *>(input->buffer()),
142 getTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
145 void negFloat32(const IPortableTensor *input, IPortableTensor *output)
147 nnfw::cker::Neg(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
148 getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
151 template <typename InputT, typename OutputT>
152 void affineQuantize(const IPortableTensor *input, IPortableTensor *output)
154 nnfw::cker::Quantize(getTensorShape(input), reinterpret_cast<const InputT *>(input->buffer()),
155 getTensorShape(output), reinterpret_cast<OutputT *>(output->buffer()),
156 output->data_scale(), output->data_offset());
159 void roundFloat32(const IPortableTensor *input, IPortableTensor *output)
161 nnfw::cker::Round(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
162 getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
165 void rsqrtFloat32(const IPortableTensor *input, IPortableTensor *output)
167 nnfw::cker::Rsqrt(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
168 getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
171 void sinFloat32(const IPortableTensor *input, IPortableTensor *output)
173 nnfw::cker::Sin(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
174 getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
177 template <typename T> void zerosLikeFloat32(const IPortableTensor *input, IPortableTensor *output)
179 if (!HaveSameShapes(input, output))
180 throw std::runtime_error{"ZerosLike: input and output shape don't match."};
182 auto element_size = getTensorShape(input).FlatSize();
184 memset(reinterpret_cast<T *>(output->buffer()), 0, element_size * sizeof(T));
188 void ElementwiseUnaryLayer::configure(const IPortableTensor *input, IPortableTensor *output,
189 const ElementwiseUnaryType op_type)
191 assert(input != nullptr);
192 assert(output != nullptr);
199 case ElementwiseUnaryType::kAbs:
200 if ((input->data_type() == OperandType::FLOAT32))
202 _kernel = absFloat32;
206 throw std::runtime_error{"Abs: Unsupported data type"};
209 case ElementwiseUnaryType::kCast:
212 case ElementwiseUnaryType::kCos:
213 if ((input->data_type() == OperandType::FLOAT32))
215 _kernel = cosFloat32;
219 throw std::runtime_error{"Cos: Unsupported data type"};
222 case ElementwiseUnaryType::kExp:
223 if ((input->data_type() == OperandType::FLOAT32))
225 _kernel = expFloat32;
229 throw std::runtime_error{"Exp: Unsupported data type"};
232 case ElementwiseUnaryType::kErf:
233 if ((input->data_type() == OperandType::FLOAT32))
235 _kernel = erfFloat32;
239 throw std::runtime_error{"Exp: Unsupported data type"};
242 case ElementwiseUnaryType::kLog:
243 if ((input->data_type() == OperandType::FLOAT32))
245 _kernel = logFloat32;
249 throw std::runtime_error{"Log: Unsupported data type"};
252 case ElementwiseUnaryType::kLogicalNot:
253 if ((input->data_type() == OperandType::BOOL8))
255 _kernel = logicalNot;
259 throw std::runtime_error{"LogicalNot: Unsupported data type"};
262 case ElementwiseUnaryType::kNeg:
263 if ((input->data_type() == OperandType::FLOAT32))
265 _kernel = negFloat32;
269 throw std::runtime_error{"Neg: Unsupported data type"};
272 case ElementwiseUnaryType::kQuantize:
273 if ((input->data_type() == OperandType::FLOAT32))
275 _kernel = affineQuantize<float, uint8_t>;
279 throw std::runtime_error{"Quantize: Unsupported data type"};
282 case ElementwiseUnaryType::kRound:
283 if ((input->data_type() == OperandType::FLOAT32))
285 _kernel = roundFloat32;
289 throw std::runtime_error{"Round: Unsupported data type"};
292 case ElementwiseUnaryType::kRSqrt:
293 if ((input->data_type() == OperandType::FLOAT32))
295 _kernel = rsqrtFloat32;
299 throw std::runtime_error{"RSqrt: Unsupported data type"};
302 case ElementwiseUnaryType::kSin:
303 if ((input->data_type() == OperandType::FLOAT32))
305 _kernel = sinFloat32;
309 throw std::runtime_error{"Sin: Unsupported data type"};
312 case ElementwiseUnaryType::kZerosLike:
313 if (input->data_type() == OperandType::FLOAT32)
315 _kernel = zerosLikeFloat32<float>;
317 else if (input->data_type() == OperandType::INT32)
319 _kernel = zerosLikeFloat32<int32_t>;
323 throw std::runtime_error{"ZerosLike: Unsupported data type"};
327 throw std::runtime_error{"ElementwiseBinary: Unsupported ElementwiseBinary type"};
331 void ElementwiseUnaryLayer::run() { _kernel(_input, _output); }
335 } // namespace backend