Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / ElementwiseUnaryLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ElementwiseUnaryLayer.h"
18
19 #include "OperationUtils.h"
20
21 #include <cker/operation/Dequantize.h>
22 #include <cker/operation/Elementwise.h>
23 #include <cker/operation/Erf.h>
24 #include <cker/operation/Exp.h>
25 #include <cker/operation/LogicalNot.h>
26 #include <cker/operation/Round.h>
27
28 namespace onert
29 {
30 namespace backend
31 {
32 namespace cpu
33 {
34 namespace ops
35 {
36
37 namespace
38 {
39 void absFloat32(const IPortableTensor *input, IPortableTensor *output)
40 {
41   nnfw::cker::Abs(getShape(input), getBuffer<float>(input), getShape(output),
42                   getBuffer<float>(output));
43 }
44
45 template <typename FromT>
46 void castPtr(const FromT *in, DataPtr out, int num_elements, ir::DataType data_type_out)
47 {
48   switch (data_type_out)
49   {
50     case ir::DataType::FLOAT32:
51       std::transform(in, in + num_elements, out.f, [](FromT a) { return static_cast<float>(a); });
52       return;
53     case ir::DataType::INT32:
54       std::transform(in, in + num_elements, out.i32,
55                      [](FromT a) { return static_cast<int32_t>(a); });
56       return;
57     case ir::DataType::UINT32:
58       std::transform(in, in + num_elements, out.u32,
59                      [](FromT a) { return static_cast<uint32_t>(a); });
60       return;
61     case ir::DataType::UINT8:
62       std::transform(in, in + num_elements, out.u8,
63                      [](FromT a) { return static_cast<uint8_t>(a); });
64       return;
65     case ir::DataType::BOOL8:
66       std::transform(in, in + num_elements, out.b, [](FromT a) { return static_cast<bool>(a); });
67       return;
68     case ir::DataType::INT64:
69       std::transform(in, in + num_elements, out.i64,
70                      [](FromT a) { return static_cast<int64_t>(a); });
71       return;
72     default:
73       throw std::runtime_error("Cast: Not supported output type" +
74                                std::to_string((int)data_type_out));
75   }
76 }
77
78 void cast(const IPortableTensor *input, IPortableTensor *output)
79 {
80   auto input_buf = input->buffer();
81   auto output_buf = output->buffer();
82   const auto in = *reinterpret_cast<const DataPtr *>(&input_buf);
83   auto out = *reinterpret_cast<DataPtr *>(&output_buf);
84
85   auto input_shape = getShape(input);
86   auto output_shape = getShape(output);
87   const auto num_elements = MatchingFlatSize(input_shape, output_shape);
88
89   switch (input->data_type())
90   {
91     case ir::DataType::FLOAT32:
92       castPtr(in.f, out, num_elements, output->data_type());
93       return;
94     case ir::DataType::INT32:
95       castPtr(in.i32, out, num_elements, output->data_type());
96       return;
97     case ir::DataType::UINT32:
98       castPtr(in.u32, out, num_elements, output->data_type());
99       return;
100     case ir::DataType::UINT8:
101       castPtr(in.u8, out, num_elements, output->data_type());
102       return;
103     case ir::DataType::BOOL8:
104       castPtr(in.b, out, num_elements, output->data_type());
105       return;
106     case ir::DataType::INT64:
107       castPtr(in.i64, out, num_elements, output->data_type());
108       return;
109     default:
110       throw std::runtime_error("Cast: unsupported data type" +
111                                std::to_string((int)input->data_type()));
112   }
113 }
114
115 void cosFloat32(const IPortableTensor *input, IPortableTensor *output)
116 {
117   nnfw::cker::Cos(getShape(input), getBuffer<float>(input), getShape(output),
118                   getBuffer<float>(output));
119 }
120
121 void dequantizeInt8(const IPortableTensor *input, IPortableTensor *output)
122 {
123   nnfw::cker::Dequantize(getShape(input), getBuffer<int8_t>(input), getShape(output),
124                          getBuffer<float>(output), input->data_scale(), input->data_zero_point());
125 }
126
127 void dequantizeUint8(const IPortableTensor *input, IPortableTensor *output)
128 {
129   nnfw::cker::Dequantize(getShape(input), getBuffer<uint8_t>(input), getShape(output),
130                          getBuffer<float>(output), input->data_scale(), input->data_zero_point());
131 }
132
133 void expFloat32(const IPortableTensor *input, IPortableTensor *output)
134 {
135   nnfw::cker::Exp(getShape(input), getBuffer<float>(input), getShape(output),
136                   getBuffer<float>(output));
137 }
138
139 void erfFloat32(const IPortableTensor *input, IPortableTensor *output)
140 {
141   nnfw::cker::Erf(getShape(input), getBuffer<float>(input), getShape(output),
142                   getBuffer<float>(output));
143 }
144
145 void floorFloat32(const IPortableTensor *input, IPortableTensor *output)
146 {
147   nnfw::cker::Floor(getShape(input), getBuffer<float>(input), getShape(output),
148                     getBuffer<float>(output));
149 }
150
151 void logFloat32(const IPortableTensor *input, IPortableTensor *output)
152 {
153   nnfw::cker::Log(getShape(input), getBuffer<float>(input), getShape(output),
154                   getBuffer<float>(output));
155 }
156
157 void logicalNot(const IPortableTensor *input, IPortableTensor *output)
158 {
159   nnfw::cker::LogicalNot(getShape(input), getBuffer<bool>(input), getShape(output),
160                          getBuffer<bool>(output));
161 }
162
163 template <typename T> void neg(const IPortableTensor *input, IPortableTensor *output)
164 {
165   nnfw::cker::Neg<T>(getShape(input), getBuffer<T>(input), getShape(output), getBuffer<T>(output));
166 }
167
168 void roundFloat32(const IPortableTensor *input, IPortableTensor *output)
169 {
170   nnfw::cker::Round(getShape(input), getBuffer<float>(input), getShape(output),
171                     getBuffer<float>(output));
172 }
173
174 void rsqrtFloat32(const IPortableTensor *input, IPortableTensor *output)
175 {
176   nnfw::cker::Rsqrt(getShape(input), getBuffer<float>(input), getShape(output),
177                     getBuffer<float>(output));
178 }
179
180 void sinFloat32(const IPortableTensor *input, IPortableTensor *output)
181 {
182   nnfw::cker::Sin(getShape(input), getBuffer<float>(input), getShape(output),
183                   getBuffer<float>(output));
184 }
185
186 void sqrtFloat32(const IPortableTensor *input, IPortableTensor *output)
187 {
188   nnfw::cker::Sqrt(getShape(input), getBuffer<float>(input), getShape(output),
189                    getBuffer<float>(output));
190 }
191
192 void squareFloat32(const IPortableTensor *input, IPortableTensor *output)
193 {
194   nnfw::cker::Square(getShape(input), getBuffer<float>(input), getShape(output),
195                      getBuffer<float>(output));
196 }
197
198 template <typename T> void zerosLikeFloat32(const IPortableTensor *input, IPortableTensor *output)
199 {
200   if (!HaveSameShapes(input, output))
201     throw std::runtime_error{"ZerosLike: input and output shape don't match."};
202
203   auto element_size = getShape(input).FlatSize();
204
205   memset(getBuffer<T>(output), 0, element_size * sizeof(T));
206 }
207 } // namespace
208
209 void ElementwiseUnaryLayer::configure(const IPortableTensor *input, IPortableTensor *output,
210                                       const ElementwiseUnaryType op_type)
211 {
212   assert(input != nullptr);
213   assert(output != nullptr);
214
215   _input = input;
216   _output = output;
217
218   switch (op_type)
219   {
220     case ElementwiseUnaryType::kAbs:
221       if ((input->data_type() == OperandType::FLOAT32))
222       {
223         _kernel = absFloat32;
224       }
225       else
226       {
227         throw std::runtime_error{"Abs: Unsupported data type"};
228       }
229       break;
230     case ElementwiseUnaryType::kCast:
231       _kernel = cast;
232       break;
233     case ElementwiseUnaryType::kCos:
234       if ((input->data_type() == OperandType::FLOAT32))
235       {
236         _kernel = cosFloat32;
237       }
238       else
239       {
240         throw std::runtime_error{"Cos: Unsupported data type"};
241       }
242       break;
243     case ElementwiseUnaryType::kDequantize:
244       if ((input->data_type() == OperandType::QUANT_UINT8_ASYMM))
245       {
246         _kernel = dequantizeUint8;
247       }
248       else if ((input->data_type() == OperandType::QUANT_INT8_ASYMM) ||
249                (input->data_type() == OperandType::QUANT_INT8_SYMM))
250       {
251         _kernel = dequantizeInt8;
252       }
253       else
254       {
255         throw std::runtime_error{"Dequantize: Unsupported data type"};
256       }
257       break;
258     case ElementwiseUnaryType::kExp:
259       if ((input->data_type() == OperandType::FLOAT32))
260       {
261         _kernel = expFloat32;
262       }
263       else
264       {
265         throw std::runtime_error{"Exp: Unsupported data type"};
266       }
267       break;
268     case ElementwiseUnaryType::kErf:
269       if ((input->data_type() == OperandType::FLOAT32))
270       {
271         _kernel = erfFloat32;
272       }
273       else
274       {
275         throw std::runtime_error{"Exp: Unsupported data type"};
276       }
277       break;
278     case ElementwiseUnaryType::kFloor:
279       if ((input->data_type() == OperandType::FLOAT32))
280       {
281         _kernel = floorFloat32;
282       }
283       else
284       {
285         throw std::runtime_error{"Floor: Unsupported data type"};
286       }
287       break;
288     case ElementwiseUnaryType::kLog:
289       if ((input->data_type() == OperandType::FLOAT32))
290       {
291         _kernel = logFloat32;
292       }
293       else
294       {
295         throw std::runtime_error{"Log: Unsupported  data type"};
296       }
297       break;
298     case ElementwiseUnaryType::kLogicalNot:
299       if ((input->data_type() == OperandType::BOOL8))
300       {
301         _kernel = logicalNot;
302       }
303       else
304       {
305         throw std::runtime_error{"LogicalNot: Unsupported  data type"};
306       }
307       break;
308     case ElementwiseUnaryType::kNeg:
309       if ((input->data_type() == OperandType::FLOAT32))
310       {
311         _kernel = neg<float>;
312       }
313       else if ((input->data_type() == OperandType::INT64))
314       {
315         _kernel = neg<int64_t>;
316       }
317       else if ((input->data_type() == OperandType::INT32))
318       {
319         _kernel = neg<int32_t>;
320       }
321       else
322       {
323         throw std::runtime_error{"Neg: Unsupported  data type"};
324       }
325       break;
326     case ElementwiseUnaryType::kRound:
327       if ((input->data_type() == OperandType::FLOAT32))
328       {
329         _kernel = roundFloat32;
330       }
331       else
332       {
333         throw std::runtime_error{"Round: Unsupported  data type"};
334       }
335       break;
336     case ElementwiseUnaryType::kRSqrt:
337       if ((input->data_type() == OperandType::FLOAT32))
338       {
339         _kernel = rsqrtFloat32;
340       }
341       else
342       {
343         throw std::runtime_error{"RSqrt: Unsupported  data type"};
344       }
345       break;
346     case ElementwiseUnaryType::kSin:
347       if ((input->data_type() == OperandType::FLOAT32))
348       {
349         _kernel = sinFloat32;
350       }
351       else
352       {
353         throw std::runtime_error{"Sin: Unsupported  data type"};
354       }
355       break;
356     case ElementwiseUnaryType::kSqrt:
357       if ((input->data_type() == OperandType::FLOAT32))
358       {
359         _kernel = sqrtFloat32;
360       }
361       else
362       {
363         throw std::runtime_error{"Sqrt: Unsupported  data type"};
364       }
365       break;
366     case ElementwiseUnaryType::kSquare:
367       if ((input->data_type() == OperandType::FLOAT32))
368       {
369         _kernel = squareFloat32;
370       }
371       else
372       {
373         throw std::runtime_error{"Square: Unsupported  data type"};
374       }
375       break;
376     case ElementwiseUnaryType::kZerosLike:
377       if (input->data_type() == OperandType::FLOAT32)
378       {
379         _kernel = zerosLikeFloat32<float>;
380       }
381       else if (input->data_type() == OperandType::INT32)
382       {
383         _kernel = zerosLikeFloat32<int32_t>;
384       }
385       else
386       {
387         throw std::runtime_error{"ZerosLike: Unsupported data type"};
388       }
389       break;
390     default:
391       throw std::runtime_error{"ElementwiseUnary: Unsupported ElementwiseUnary type"};
392   }
393 }
394
395 void ElementwiseUnaryLayer::run() { _kernel(_input, _output); }
396
397 } // namespace ops
398 } // namespace cpu
399 } // namespace backend
400 } // namespace onert