Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Add.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "Builders.h"
19 #include "kernels/Utils.h"
20
21 #include "kernels/BinaryOpCommon.h"
22
23 #include "PALAdd.h"
24
25 namespace luci_interpreter
26 {
27
28 void configure_kernel_CircleAdd(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
29 {
30   kernels::TISOKernel kernel(cur_op, runtime_graph);
31
32   LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
33                          Tensor::element_type(kernel.input2()));
34   LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
35                          Tensor::element_type(kernel.input2()));
36
37 #ifndef DIS_QUANT
38   if (Tensor::element_type(kernel.input1()) == DataType::S16)
39   {
40     LUCI_INTERPRETER_CHECK(Tensor::zero_points(kernel.input1()).size() == 1 &&
41                            Tensor::zero_points(kernel.input2()).size() == 1);
42     LUCI_INTERPRETER_CHECK(Tensor::zero_point(kernel.input1()) == 0 &&
43                            Tensor::zero_point(kernel.input2()) == 0 &&
44                            Tensor::zero_point(kernel.output()) == 0);
45   }
46 #endif // DIS_QUANT
47 }
48
49 void execute_kernel_CircleAdd(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
50 {
51   kernels::TISOKernel kernel(cur_op, runtime_graph);
52
53   const auto *options = cur_op->builtin_options_as_AddOptions();
54
55   luci_interpreter::RuntimeShape input_shape1 =
56     kernels::getTensorRuntimeShape(kernel.input1(), runtime_graph);
57   luci_interpreter::RuntimeShape input_shape2 =
58     kernels::getTensorRuntimeShape(kernel.input2(), runtime_graph);
59
60   bool is_inplace = runtime_graph->is_inplace_op(cur_op);
61
62   switch (Tensor::element_type(kernel.input1()))
63   {
64 #ifndef DIS_FLOAT
65     case DataType::FLOAT32:
66     {
67       auto tiso_func = luci_interpreter_pal::Add<float>;
68       auto broadcast_tiso_func = luci_interpreter_pal::BroadcastAdd4DSlow<float>;
69       if (is_inplace)
70       {
71         kernels::evalTISOInplaceKernel<float>(tiso_func, broadcast_tiso_func, &kernel, options,
72                                               std::move(input_shape1), std::move(input_shape2));
73       }
74       else
75       {
76         kernels::TISOData kernel_data = kernel.readData();
77         kernels::evalTISOKernel<float>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
78                                        options, std::move(input_shape1), std::move(input_shape2));
79       }
80     }
81     break;
82 #endif // DIS_FLOAT
83     case DataType::S64:
84     {
85       auto tiso_func = luci_interpreter_pal::Add<int64_t>;
86       auto broadcast_tiso_func = luci_interpreter_pal::BroadcastAdd4DSlow<int64_t>;
87       if (is_inplace)
88       {
89         kernels::evalTISOInplaceKernel<int64_t>(tiso_func, broadcast_tiso_func, &kernel, options,
90                                                 std::move(input_shape1), std::move(input_shape2));
91       }
92       else
93       {
94         kernels::TISOData kernel_data = kernel.readData();
95         kernels::evalTISOKernel<int64_t>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
96                                          options, std::move(input_shape1), std::move(input_shape2));
97       }
98     }
99     break;
100     case DataType::S32:
101     {
102       auto tiso_func = luci_interpreter_pal::Add<int32_t>;
103       auto broadcast_tiso_func = luci_interpreter_pal::BroadcastAdd4DSlow<int32_t>;
104       if (is_inplace)
105       {
106         kernels::evalTISOInplaceKernel<int32_t>(tiso_func, broadcast_tiso_func, &kernel, options,
107                                                 std::move(input_shape1), std::move(input_shape2));
108       }
109       else
110       {
111         kernels::TISOData kernel_data = kernel.readData();
112         kernels::evalTISOKernel<int32_t>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
113                                          options, std::move(input_shape1), std::move(input_shape2));
114       }
115     }
116     break;
117     default:
118       assert(false && "Unsupported type.");
119   }
120 }
121
122 } // namespace luci_interpreter