2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "ElementwiseUnaryLayer.h"
19 #include "OperationUtils.h"
21 #include <cker/operation/Dequantize.h>
22 #include <cker/operation/Elementwise.h>
23 #include <cker/operation/Erf.h>
24 #include <cker/operation/Exp.h>
25 #include <cker/operation/LogicalNot.h>
26 #include <cker/operation/Round.h>
39 void absFloat32(const IPortableTensor *input, IPortableTensor *output)
41 nnfw::cker::Abs(getShape(input), getBuffer<float>(input), getShape(output),
42 getBuffer<float>(output));
45 template <typename FromT>
46 void castPtr(const FromT *in, DataPtr out, int num_elements, ir::DataType data_type_out)
48 switch (data_type_out)
50 case ir::DataType::FLOAT32:
51 std::transform(in, in + num_elements, out.f, [](FromT a) { return static_cast<float>(a); });
53 case ir::DataType::INT32:
54 std::transform(in, in + num_elements, out.i32,
55 [](FromT a) { return static_cast<int32_t>(a); });
57 case ir::DataType::UINT32:
58 std::transform(in, in + num_elements, out.u32,
59 [](FromT a) { return static_cast<uint32_t>(a); });
61 case ir::DataType::UINT8:
62 std::transform(in, in + num_elements, out.u8,
63 [](FromT a) { return static_cast<uint8_t>(a); });
65 case ir::DataType::BOOL8:
66 std::transform(in, in + num_elements, out.b, [](FromT a) { return static_cast<bool>(a); });
68 case ir::DataType::INT64:
69 std::transform(in, in + num_elements, out.i64,
70 [](FromT a) { return static_cast<int64_t>(a); });
73 throw std::runtime_error("Cast: Not supported output type" +
74 std::to_string((int)data_type_out));
78 void cast(const IPortableTensor *input, IPortableTensor *output)
80 auto input_buf = input->buffer();
81 auto output_buf = output->buffer();
82 const auto in = *reinterpret_cast<const DataPtr *>(&input_buf);
83 auto out = *reinterpret_cast<DataPtr *>(&output_buf);
85 auto input_shape = getShape(input);
86 auto output_shape = getShape(output);
87 const auto num_elements = MatchingFlatSize(input_shape, output_shape);
89 switch (input->data_type())
91 case ir::DataType::FLOAT32:
92 castPtr(in.f, out, num_elements, output->data_type());
94 case ir::DataType::INT32:
95 castPtr(in.i32, out, num_elements, output->data_type());
97 case ir::DataType::UINT32:
98 castPtr(in.u32, out, num_elements, output->data_type());
100 case ir::DataType::UINT8:
101 castPtr(in.u8, out, num_elements, output->data_type());
103 case ir::DataType::BOOL8:
104 castPtr(in.b, out, num_elements, output->data_type());
106 case ir::DataType::INT64:
107 castPtr(in.i64, out, num_elements, output->data_type());
110 throw std::runtime_error("Cast: unsupported data type" +
111 std::to_string((int)input->data_type()));
115 void cosFloat32(const IPortableTensor *input, IPortableTensor *output)
117 nnfw::cker::Cos(getShape(input), getBuffer<float>(input), getShape(output),
118 getBuffer<float>(output));
121 void dequantizeInt8(const IPortableTensor *input, IPortableTensor *output)
123 nnfw::cker::Dequantize(getShape(input), getBuffer<int8_t>(input), getShape(output),
124 getBuffer<float>(output), input->data_scale(), input->data_zero_point());
127 void dequantizeUint8(const IPortableTensor *input, IPortableTensor *output)
129 nnfw::cker::Dequantize(getShape(input), getBuffer<uint8_t>(input), getShape(output),
130 getBuffer<float>(output), input->data_scale(), input->data_zero_point());
133 void expFloat32(const IPortableTensor *input, IPortableTensor *output)
135 nnfw::cker::Exp(getShape(input), getBuffer<float>(input), getShape(output),
136 getBuffer<float>(output));
139 void erfFloat32(const IPortableTensor *input, IPortableTensor *output)
141 nnfw::cker::Erf(getShape(input), getBuffer<float>(input), getShape(output),
142 getBuffer<float>(output));
145 void floorFloat32(const IPortableTensor *input, IPortableTensor *output)
147 nnfw::cker::Floor(getShape(input), getBuffer<float>(input), getShape(output),
148 getBuffer<float>(output));
151 void logFloat32(const IPortableTensor *input, IPortableTensor *output)
153 nnfw::cker::Log(getShape(input), getBuffer<float>(input), getShape(output),
154 getBuffer<float>(output));
157 void logicalNot(const IPortableTensor *input, IPortableTensor *output)
159 nnfw::cker::LogicalNot(getShape(input), getBuffer<bool>(input), getShape(output),
160 getBuffer<bool>(output));
163 template <typename T> void neg(const IPortableTensor *input, IPortableTensor *output)
165 nnfw::cker::Neg<T>(getShape(input), getBuffer<T>(input), getShape(output), getBuffer<T>(output));
168 void roundFloat32(const IPortableTensor *input, IPortableTensor *output)
170 nnfw::cker::Round(getShape(input), getBuffer<float>(input), getShape(output),
171 getBuffer<float>(output));
174 void rsqrtFloat32(const IPortableTensor *input, IPortableTensor *output)
176 nnfw::cker::Rsqrt(getShape(input), getBuffer<float>(input), getShape(output),
177 getBuffer<float>(output));
180 void sinFloat32(const IPortableTensor *input, IPortableTensor *output)
182 nnfw::cker::Sin(getShape(input), getBuffer<float>(input), getShape(output),
183 getBuffer<float>(output));
186 void sqrtFloat32(const IPortableTensor *input, IPortableTensor *output)
188 nnfw::cker::Sqrt(getShape(input), getBuffer<float>(input), getShape(output),
189 getBuffer<float>(output));
192 void squareFloat32(const IPortableTensor *input, IPortableTensor *output)
194 nnfw::cker::Square(getShape(input), getBuffer<float>(input), getShape(output),
195 getBuffer<float>(output));
198 template <typename T> void zerosLikeFloat32(const IPortableTensor *input, IPortableTensor *output)
200 if (!HaveSameShapes(input, output))
201 throw std::runtime_error{"ZerosLike: input and output shape don't match."};
203 auto element_size = getShape(input).FlatSize();
205 memset(getBuffer<T>(output), 0, element_size * sizeof(T));
209 void ElementwiseUnaryLayer::configure(const IPortableTensor *input, IPortableTensor *output,
210 const ElementwiseUnaryType op_type)
212 assert(input != nullptr);
213 assert(output != nullptr);
220 case ElementwiseUnaryType::kAbs:
221 if ((input->data_type() == OperandType::FLOAT32))
223 _kernel = absFloat32;
227 throw std::runtime_error{"Abs: Unsupported data type"};
230 case ElementwiseUnaryType::kCast:
233 case ElementwiseUnaryType::kCos:
234 if ((input->data_type() == OperandType::FLOAT32))
236 _kernel = cosFloat32;
240 throw std::runtime_error{"Cos: Unsupported data type"};
243 case ElementwiseUnaryType::kDequantize:
244 if ((input->data_type() == OperandType::QUANT_UINT8_ASYMM))
246 _kernel = dequantizeUint8;
248 else if ((input->data_type() == OperandType::QUANT_INT8_ASYMM) ||
249 (input->data_type() == OperandType::QUANT_INT8_SYMM))
251 _kernel = dequantizeInt8;
255 throw std::runtime_error{"Dequantize: Unsupported data type"};
258 case ElementwiseUnaryType::kExp:
259 if ((input->data_type() == OperandType::FLOAT32))
261 _kernel = expFloat32;
265 throw std::runtime_error{"Exp: Unsupported data type"};
268 case ElementwiseUnaryType::kErf:
269 if ((input->data_type() == OperandType::FLOAT32))
271 _kernel = erfFloat32;
275 throw std::runtime_error{"Exp: Unsupported data type"};
278 case ElementwiseUnaryType::kFloor:
279 if ((input->data_type() == OperandType::FLOAT32))
281 _kernel = floorFloat32;
285 throw std::runtime_error{"Floor: Unsupported data type"};
288 case ElementwiseUnaryType::kLog:
289 if ((input->data_type() == OperandType::FLOAT32))
291 _kernel = logFloat32;
295 throw std::runtime_error{"Log: Unsupported data type"};
298 case ElementwiseUnaryType::kLogicalNot:
299 if ((input->data_type() == OperandType::BOOL8))
301 _kernel = logicalNot;
305 throw std::runtime_error{"LogicalNot: Unsupported data type"};
308 case ElementwiseUnaryType::kNeg:
309 if ((input->data_type() == OperandType::FLOAT32))
311 _kernel = neg<float>;
313 else if ((input->data_type() == OperandType::INT64))
315 _kernel = neg<int64_t>;
317 else if ((input->data_type() == OperandType::INT32))
319 _kernel = neg<int32_t>;
323 throw std::runtime_error{"Neg: Unsupported data type"};
326 case ElementwiseUnaryType::kRound:
327 if ((input->data_type() == OperandType::FLOAT32))
329 _kernel = roundFloat32;
333 throw std::runtime_error{"Round: Unsupported data type"};
336 case ElementwiseUnaryType::kRSqrt:
337 if ((input->data_type() == OperandType::FLOAT32))
339 _kernel = rsqrtFloat32;
343 throw std::runtime_error{"RSqrt: Unsupported data type"};
346 case ElementwiseUnaryType::kSin:
347 if ((input->data_type() == OperandType::FLOAT32))
349 _kernel = sinFloat32;
353 throw std::runtime_error{"Sin: Unsupported data type"};
356 case ElementwiseUnaryType::kSqrt:
357 if ((input->data_type() == OperandType::FLOAT32))
359 _kernel = sqrtFloat32;
363 throw std::runtime_error{"Sqrt: Unsupported data type"};
366 case ElementwiseUnaryType::kSquare:
367 if ((input->data_type() == OperandType::FLOAT32))
369 _kernel = squareFloat32;
373 throw std::runtime_error{"Square: Unsupported data type"};
376 case ElementwiseUnaryType::kZerosLike:
377 if (input->data_type() == OperandType::FLOAT32)
379 _kernel = zerosLikeFloat32<float>;
381 else if (input->data_type() == OperandType::INT32)
383 _kernel = zerosLikeFloat32<int32_t>;
387 throw std::runtime_error{"ZerosLike: Unsupported data type"};
391 throw std::runtime_error{"ElementwiseUnary: Unsupported ElementwiseUnary type"};
395 void ElementwiseUnaryLayer::run() { _kernel(_input, _output); }
399 } // namespace backend