Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Div.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Builders.h"
18 #include "kernels/Utils.h"
19
20 #include "kernels/BinaryOpCommon.h"
21
22 #include "PALDiv.h"
23
24 namespace luci_interpreter
25 {
26
27 // TODO: reduce code duplication with Mul
28 void configure_kernel_CircleDiv(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
29 {
30   kernels::TISOKernel kernel(cur_op, runtime_graph);
31
32   LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
33                          Tensor::element_type(kernel.input2()));
34   LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
35                          Tensor::element_type(kernel.input2()));
36 }
37
38 void execute_kernel_CircleDiv(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
39 {
40   kernels::TISOKernel kernel(cur_op, runtime_graph);
41
42   const auto *options = cur_op->builtin_options_as_DivOptions();
43
44   luci_interpreter::RuntimeShape input_shape1 =
45     kernels::getTensorRuntimeShape(kernel.input1(), runtime_graph);
46   luci_interpreter::RuntimeShape input_shape2 =
47     kernels::getTensorRuntimeShape(kernel.input2(), runtime_graph);
48
49   bool is_inplace = runtime_graph->is_inplace_op(cur_op);
50
51   switch (Tensor::element_type(kernel.input1()))
52   {
53 #ifndef DIS_FLOAT
54     case DataType::FLOAT32:
55     {
56       auto tiso_func = luci_interpreter_pal::Div<float>;
57       auto broadcast_tiso_func = luci_interpreter_pal::BroadcastDiv4DSlow<float>;
58       if (is_inplace)
59       {
60         kernels::evalTISOInplaceKernel<float>(tiso_func, broadcast_tiso_func, &kernel, options,
61                                               std::move(input_shape1), std::move(input_shape2));
62       }
63       else
64       {
65         kernels::TISOData kernel_data = kernel.readData();
66         kernels::evalTISOKernel<float>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
67                                        options, std::move(input_shape1), std::move(input_shape2));
68       }
69     }
70     break;
71 #endif // DIS_FLOAT
72     default:
73       assert(false && "Unsupported type.");
74   }
75 }
76
77 } // namespace luci_interpreter