Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / ElementwiseUnaryLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ElementwiseUnaryLayer.h"
18
19 #include "OperationUtils.h"
20
21 #include <cker/operation/Dequantize.h>
22 #include <cker/operation/Elementwise.h>
23 #include <cker/operation/Erf.h>
24 #include <cker/operation/Exp.h>
25 #include <cker/operation/LogicalNot.h>
26 #include <cker/operation/Quantize.h>
27 #include <cker/operation/Round.h>
28
29 namespace onert
30 {
31 namespace backend
32 {
33 namespace cpu
34 {
35 namespace ops
36 {
37
38 namespace
39 {
40 void absFloat32(const IPortableTensor *input, IPortableTensor *output)
41 {
42   nnfw::cker::Abs(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
43                   getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
44 }
45
46 template <typename FromT>
47 void castPtr(const FromT *in, DataPtr out, int num_elements, ir::DataType data_type_out)
48 {
49   switch (data_type_out)
50   {
51     case ir::DataType::FLOAT32:
52       std::transform(in, in + num_elements, out.f, [](FromT a) { return static_cast<float>(a); });
53       return;
54     case ir::DataType::INT32:
55       std::transform(in, in + num_elements, out.i32,
56                      [](FromT a) { return static_cast<int32_t>(a); });
57       return;
58     case ir::DataType::UINT32:
59       std::transform(in, in + num_elements, out.u32,
60                      [](FromT a) { return static_cast<uint32_t>(a); });
61       return;
62     case ir::DataType::UINT8:
63       std::transform(in, in + num_elements, out.u8,
64                      [](FromT a) { return static_cast<uint8_t>(a); });
65       return;
66     case ir::DataType::BOOL8:
67       std::transform(in, in + num_elements, out.b, [](FromT a) { return static_cast<bool>(a); });
68       return;
69     case ir::DataType::INT64:
70       std::transform(in, in + num_elements, out.i64,
71                      [](FromT a) { return static_cast<int64_t>(a); });
72       return;
73     default:
74       throw std::runtime_error("Cast: Not supported output type" +
75                                std::to_string((int)data_type_out));
76   }
77 }
78
79 void cast(const IPortableTensor *input, IPortableTensor *output)
80 {
81   auto input_buf = input->buffer();
82   auto output_buf = output->buffer();
83   const auto in = *reinterpret_cast<const DataPtr *>(&input_buf);
84   auto out = *reinterpret_cast<DataPtr *>(&output_buf);
85
86   auto input_shape = getTensorShape(input);
87   auto output_shape = getTensorShape(output);
88   const auto num_elements = MatchingFlatSize(input_shape, output_shape);
89
90   switch (input->data_type())
91   {
92     case ir::DataType::FLOAT32:
93       castPtr(in.f, out, num_elements, output->data_type());
94       return;
95     case ir::DataType::INT32:
96       castPtr(in.i32, out, num_elements, output->data_type());
97       return;
98     case ir::DataType::UINT32:
99       castPtr(in.u32, out, num_elements, output->data_type());
100       return;
101     case ir::DataType::UINT8:
102       castPtr(in.u8, out, num_elements, output->data_type());
103       return;
104     case ir::DataType::BOOL8:
105       castPtr(in.b, out, num_elements, output->data_type());
106       return;
107     case ir::DataType::INT64:
108       castPtr(in.i64, out, num_elements, output->data_type());
109       return;
110     default:
111       throw std::runtime_error("Cast: unsupported data type" +
112                                std::to_string((int)input->data_type()));
113   }
114 }
115
116 void cosFloat32(const IPortableTensor *input, IPortableTensor *output)
117 {
118   nnfw::cker::Cos(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
119                   getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
120 }
121
122 void dequantizeInt8(const IPortableTensor *input, IPortableTensor *output)
123 {
124   nnfw::cker::Dequantize(getTensorShape(input), reinterpret_cast<const int8_t *>(input->buffer()),
125                          getTensorShape(output), reinterpret_cast<float *>(output->buffer()),
126                          input->data_scale(), input->data_offset());
127 }
128
129 void dequantizeUint8(const IPortableTensor *input, IPortableTensor *output)
130 {
131   nnfw::cker::Dequantize(getTensorShape(input), reinterpret_cast<const uint8_t *>(input->buffer()),
132                          getTensorShape(output), reinterpret_cast<float *>(output->buffer()),
133                          input->data_scale(), input->data_offset());
134 }
135
136 void expFloat32(const IPortableTensor *input, IPortableTensor *output)
137 {
138   nnfw::cker::Exp(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
139                   getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
140 }
141
142 void erfFloat32(const IPortableTensor *input, IPortableTensor *output)
143 {
144   nnfw::cker::Erf(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
145                   getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
146 }
147
148 void floorFloat32(const IPortableTensor *input, IPortableTensor *output)
149 {
150   nnfw::cker::Floor(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
151                     getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
152 }
153
154 void logFloat32(const IPortableTensor *input, IPortableTensor *output)
155 {
156   nnfw::cker::Log(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
157                   getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
158 }
159
160 void logicalNot(const IPortableTensor *input, IPortableTensor *output)
161 {
162   nnfw::cker::LogicalNot(getTensorShape(input), reinterpret_cast<const bool *>(input->buffer()),
163                          getTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
164 }
165
166 template <typename T> void neg(const IPortableTensor *input, IPortableTensor *output)
167 {
168   nnfw::cker::Neg<T>(getTensorShape(input), reinterpret_cast<const T *>(input->buffer()),
169                      getTensorShape(output), reinterpret_cast<T *>(output->buffer()));
170 }
171
172 template <typename InputT, typename OutputT>
173 void affineQuantize(const IPortableTensor *input, IPortableTensor *output)
174 {
175   nnfw::cker::Quantize(getTensorShape(input), reinterpret_cast<const InputT *>(input->buffer()),
176                        getTensorShape(output), reinterpret_cast<OutputT *>(output->buffer()),
177                        output->data_scale(), output->data_offset());
178 }
179
180 void roundFloat32(const IPortableTensor *input, IPortableTensor *output)
181 {
182   nnfw::cker::Round(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
183                     getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
184 }
185
186 void rsqrtFloat32(const IPortableTensor *input, IPortableTensor *output)
187 {
188   nnfw::cker::Rsqrt(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
189                     getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
190 }
191
192 void sinFloat32(const IPortableTensor *input, IPortableTensor *output)
193 {
194   nnfw::cker::Sin(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
195                   getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
196 }
197
198 void sqrtFloat32(const IPortableTensor *input, IPortableTensor *output)
199 {
200   nnfw::cker::Sqrt(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
201                    getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
202 }
203
204 void squareFloat32(const IPortableTensor *input, IPortableTensor *output)
205 {
206   nnfw::cker::Square(getTensorShape(input), reinterpret_cast<const float *>(input->buffer()),
207                      getTensorShape(output), reinterpret_cast<float *>(output->buffer()));
208 }
209
210 template <typename T> void zerosLikeFloat32(const IPortableTensor *input, IPortableTensor *output)
211 {
212   if (!HaveSameShapes(input, output))
213     throw std::runtime_error{"ZerosLike: input and output shape don't match."};
214
215   auto element_size = getTensorShape(input).FlatSize();
216
217   memset(reinterpret_cast<T *>(output->buffer()), 0, element_size * sizeof(T));
218 }
219 } // namespace
220
221 void ElementwiseUnaryLayer::configure(const IPortableTensor *input, IPortableTensor *output,
222                                       const ElementwiseUnaryType op_type)
223 {
224   assert(input != nullptr);
225   assert(output != nullptr);
226
227   _input = input;
228   _output = output;
229
230   switch (op_type)
231   {
232     case ElementwiseUnaryType::kAbs:
233       if ((input->data_type() == OperandType::FLOAT32))
234       {
235         _kernel = absFloat32;
236       }
237       else
238       {
239         throw std::runtime_error{"Abs: Unsupported data type"};
240       }
241       break;
242     case ElementwiseUnaryType::kCast:
243       _kernel = cast;
244       break;
245     case ElementwiseUnaryType::kCos:
246       if ((input->data_type() == OperandType::FLOAT32))
247       {
248         _kernel = cosFloat32;
249       }
250       else
251       {
252         throw std::runtime_error{"Cos: Unsupported data type"};
253       }
254       break;
255     case ElementwiseUnaryType::kDequantize:
256       if ((input->data_type() == OperandType::QUANT_UINT8_ASYMM))
257       {
258         _kernel = dequantizeUint8;
259       }
260       else if ((input->data_type() == OperandType::QUANT_INT8_ASYMM) ||
261                (input->data_type() == OperandType::QUANT_INT8_SYMM))
262       {
263         _kernel = dequantizeInt8;
264       }
265       else
266       {
267         throw std::runtime_error{"Dequantize: Unsupported data type"};
268       }
269       break;
270     case ElementwiseUnaryType::kExp:
271       if ((input->data_type() == OperandType::FLOAT32))
272       {
273         _kernel = expFloat32;
274       }
275       else
276       {
277         throw std::runtime_error{"Exp: Unsupported data type"};
278       }
279       break;
280     case ElementwiseUnaryType::kErf:
281       if ((input->data_type() == OperandType::FLOAT32))
282       {
283         _kernel = erfFloat32;
284       }
285       else
286       {
287         throw std::runtime_error{"Exp: Unsupported data type"};
288       }
289       break;
290     case ElementwiseUnaryType::kFloor:
291       if ((input->data_type() == OperandType::FLOAT32))
292       {
293         _kernel = floorFloat32;
294       }
295       else
296       {
297         throw std::runtime_error{"Floor: Unsupported data type"};
298       }
299       break;
300     case ElementwiseUnaryType::kLog:
301       if ((input->data_type() == OperandType::FLOAT32))
302       {
303         _kernel = logFloat32;
304       }
305       else
306       {
307         throw std::runtime_error{"Log: Unsupported  data type"};
308       }
309       break;
310     case ElementwiseUnaryType::kLogicalNot:
311       if ((input->data_type() == OperandType::BOOL8))
312       {
313         _kernel = logicalNot;
314       }
315       else
316       {
317         throw std::runtime_error{"LogicalNot: Unsupported  data type"};
318       }
319       break;
320     case ElementwiseUnaryType::kNeg:
321       if ((input->data_type() == OperandType::FLOAT32))
322       {
323         _kernel = neg<float>;
324       }
325       else if ((input->data_type() == OperandType::INT64))
326       {
327         _kernel = neg<int64_t>;
328       }
329       else if ((input->data_type() == OperandType::INT32))
330       {
331         _kernel = neg<int32_t>;
332       }
333       else
334       {
335         throw std::runtime_error{"Neg: Unsupported  data type"};
336       }
337       break;
338     case ElementwiseUnaryType::kQuantize:
339       if ((input->data_type() == OperandType::FLOAT32))
340       {
341         _kernel = affineQuantize<float, uint8_t>;
342       }
343       else
344       {
345         throw std::runtime_error{"Quantize: Unsupported  data type"};
346       }
347       break;
348     case ElementwiseUnaryType::kRound:
349       if ((input->data_type() == OperandType::FLOAT32))
350       {
351         _kernel = roundFloat32;
352       }
353       else
354       {
355         throw std::runtime_error{"Round: Unsupported  data type"};
356       }
357       break;
358     case ElementwiseUnaryType::kRSqrt:
359       if ((input->data_type() == OperandType::FLOAT32))
360       {
361         _kernel = rsqrtFloat32;
362       }
363       else
364       {
365         throw std::runtime_error{"RSqrt: Unsupported  data type"};
366       }
367       break;
368     case ElementwiseUnaryType::kSin:
369       if ((input->data_type() == OperandType::FLOAT32))
370       {
371         _kernel = sinFloat32;
372       }
373       else
374       {
375         throw std::runtime_error{"Sin: Unsupported  data type"};
376       }
377       break;
378     case ElementwiseUnaryType::kSqrt:
379       if ((input->data_type() == OperandType::FLOAT32))
380       {
381         _kernel = sqrtFloat32;
382       }
383       else
384       {
385         throw std::runtime_error{"Sqrt: Unsupported  data type"};
386       }
387       break;
388     case ElementwiseUnaryType::kSquare:
389       if ((input->data_type() == OperandType::FLOAT32))
390       {
391         _kernel = squareFloat32;
392       }
393       else
394       {
395         throw std::runtime_error{"Square: Unsupported  data type"};
396       }
397       break;
398     case ElementwiseUnaryType::kZerosLike:
399       if (input->data_type() == OperandType::FLOAT32)
400       {
401         _kernel = zerosLikeFloat32<float>;
402       }
403       else if (input->data_type() == OperandType::INT32)
404       {
405         _kernel = zerosLikeFloat32<int32_t>;
406       }
407       else
408       {
409         throw std::runtime_error{"ZerosLike: Unsupported data type"};
410       }
411       break;
412     default:
413       throw std::runtime_error{"ElementwiseBinary: Unsupported ElementwiseBinary type"};
414   }
415 }
416
417 void ElementwiseUnaryLayer::run() { _kernel(_input, _output); }
418
419 } // namespace ops
420 } // namespace cpu
421 } // namespace backend
422 } // namespace onert