/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #include "kernels/Utils.h"
19 #include "TISOKernel.h"
21 #include "PALReduceCommon.h"
25 namespace luci_interpreter
31 void reduceProdGeneric(kernels::TISOData *tiso_data, const circle::Tensor *input,
32 const circle::Tensor *axis, const circle::Tensor *output, bool keep_dims)
34 const int input_rank = Tensor::num_dims(input);
35 const int num_axis = Tensor::num_elements(axis);
37 auto const input_dims = wrap(input->shape());
38 const auto output_shape = kernels::getTensorShape(output);
40 luci_interpreter_pal::ReduceGeneric<T>(
41 kernels::getTensorData<T>(tiso_data->input1_data),
42 reinterpret_cast<const int *>(input_dims.data()), input_rank,
43 kernels::getTensorData<T>(tiso_data->output_data),
44 kernels::getTensorData<int>(tiso_data->input2_data), num_axis,
45 /*init_value=*/T(1), output_shape.flatSize(),
46 [](const T current, const T in) -> T { return in * current; });
51 void configure_kernel_CircleReduceCommon(const circle::Operator *cur_op,
52 BaseRuntimeGraph *runtime_graph)
54 kernels::TISOKernel kernel(cur_op, runtime_graph);
56 LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) == DataType::S32 or
57 Tensor::element_type(kernel.input1()) == DataType::FLOAT32 or
58 Tensor::element_type(kernel.input1()) == DataType::S64);
59 LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input2()) == DataType::S32);
62 void execute_kernel_CircleReduceCommon(const circle::Operator *cur_op,
63 BaseRuntimeGraph *runtime_graph)
65 kernels::TISOKernel kernel(cur_op, runtime_graph);
66 kernels::TISOData tiso_data = kernel.readData();
68 const auto *input = kernel.input1();
69 const auto *axis = kernel.input2();
70 const auto *output = kernel.output();
72 const auto *options = cur_op->builtin_options_as_ReducerOptions();
74 switch (Tensor::element_type(kernel.input1()))
77 case DataType::FLOAT32:
78 reduceProdGeneric<float>(&tiso_data, input, axis, output, options->keep_dims());
82 reduceProdGeneric<int32_t>(&tiso_data, input, axis, output, options->keep_dims());
85 reduceProdGeneric<int64_t>(&tiso_data, input, axis, output, options->keep_dims());
88 assert(false && "Unsupported type");
92 } // namespace luci_interpreter