Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Fill.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "Builders.h"
19 #include "TISOKernel.h"
20 #include "kernels/Utils.h"
21
22 namespace luci_interpreter
23 {
24 namespace
25 {
26
27 template <typename T> void fillImpl(const size_t flat_size, const T *value_data, T *output_data)
28 {
29   for (int i = 0; i < flat_size; ++i)
30   {
31     output_data[i] = *value_data;
32   }
33 }
34
35 } // namespace
36
37 void configure_kernel_CircleFill(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
38 {
39   kernels::TISOKernel kernel(cur_op, runtime_graph);
40   // value tensor must be a scalar or has one element
41   LUCI_INTERPRETER_CHECK(Tensor::num_dims(kernel.input2()) == 0 or
42                          Tensor::num_elements(kernel.input2()) == 1);
43   // value and output type must match
44   LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input2()) ==
45                          Tensor::element_type(kernel.output()));
46 }
47
48 void execute_kernel_CircleFill(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
49 {
50   kernels::TISOKernel kernel(cur_op, runtime_graph);
51
52   const circle::Tensor *value = kernel.input2();
53   const circle::Tensor *output = kernel.output();
54
55   kernels::TISOData tiso_data = kernel.readData();
56   const uint8_t *value_data = tiso_data.input2_data;
57   uint8_t *output_data = tiso_data.output_data;
58
59   const size_t flat_size = Tensor::num_elements(output);
60
61   switch (Tensor::element_type(value))
62   {
63 #ifndef DIS_FLOAT
64     case DataType::FLOAT32:
65       fillImpl<float>(flat_size, kernels::getTensorData<float>(value_data),
66                       kernels::getTensorData<float>(output_data));
67       break;
68 #endif // DIS_FLOAT
69     case DataType::S32:
70       fillImpl<int32_t>(flat_size, kernels::getTensorData<int32_t>(value_data),
71                         kernels::getTensorData<int32_t>(output_data));
72       break;
73 #ifndef DIS_QUANT
74     case DataType::U8:
75       fillImpl<uint8_t>(flat_size, kernels::getTensorData<uint8_t>(value_data),
76                         kernels::getTensorData<uint8_t>(output_data));
77       break;
78 #endif // DIS_QUANT
79     default:
80       assert(false && "Not impl yet");
81   }
82 }
83
84 } // namespace luci_interpreter