/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #include "kernels/Utils.h"
20 #include "PadCommon.h"
23 namespace luci_interpreter
25 void configure_kernel_CirclePadCommon(const circle::Operator *cur_op,
26 BaseRuntimeGraph *runtime_graph)
28 const auto num_inputs = cur_op->inputs()->size();
30 const auto input1_index = cur_op->inputs()->operator[](0);
31 const auto input2_index = cur_op->inputs()->operator[](1);
32 const auto input3_index = num_inputs == 3 ? cur_op->inputs()->operator[](2) : -1;
33 const auto output_index = cur_op->outputs()->operator[](0);
35 assert(input1_index != -1);
36 assert(input2_index != -1);
37 assert(input3_index != -1 or num_inputs == 2);
38 assert(output_index != -1);
40 const auto input1_tensor = runtime_graph->getCircleTensorByIndex(input1_index);
41 const auto input2_tensor = runtime_graph->getCircleTensorByIndex(input2_index);
42 const auto input3_tensor =
43 num_inputs == 3 ? runtime_graph->getCircleTensorByIndex(input3_index) : nullptr;
44 const auto output_tensor = runtime_graph->getCircleTensorByIndex(output_index);
46 assert(input1_tensor != nullptr);
47 assert(input2_tensor != nullptr);
48 assert(input3_tensor != nullptr or num_inputs == 2);
49 assert(output_tensor != nullptr);
51 LUCI_INTERPRETER_CHECK(Tensor::element_type(input2_tensor) == DataType::S32);
52 LUCI_INTERPRETER_CHECK(Tensor::element_type(input1_tensor) ==
53 Tensor::element_type(output_tensor));
54 if (input3_tensor != nullptr)
56 LUCI_INTERPRETER_CHECK(Tensor::element_type(input3_tensor) ==
57 Tensor::element_type(input1_tensor));
59 LUCI_INTERPRETER_CHECK(Tensor::num_elements(input3_tensor) == 1);
63 const int32_t *paddings_data =
64 kernels::getTensorData<int32_t>(runtime_graph->getConstDataByTensor(input2_tensor));
65 for (int i = 0; i < Tensor::num_dims(output_tensor); i++)
67 int output_dim = Tensor::dim(output_tensor, i);
69 Tensor::dim(input1_tensor, i) + paddings_data[i * 2] + paddings_data[i * 2 + 1];
70 LUCI_INTERPRETER_CHECK(output_dim == expected_dim);
74 void execute_kernel_CirclePadCommon(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
76 const auto num_inputs = cur_op->inputs()->size();
78 const auto input1_index = cur_op->inputs()->operator[](0);
79 const auto input2_index = cur_op->inputs()->operator[](1);
80 const auto input3_index = num_inputs == 3 ? cur_op->inputs()->operator[](2) : -1;
81 const auto output_index = cur_op->outputs()->operator[](0);
83 assert(input1_index != -1);
84 assert(input2_index != -1);
85 assert(input3_index != -1 or num_inputs == 2);
86 assert(output_index != -1);
88 const auto input1_tensor = runtime_graph->getCircleTensorByIndex(input1_index);
89 const auto input2_tensor = runtime_graph->getCircleTensorByIndex(input2_index);
90 const auto input3_tensor =
91 num_inputs == 3 ? runtime_graph->getCircleTensorByIndex(input3_index) : nullptr;
92 const auto output_tensor = runtime_graph->getCircleTensorByIndex(output_index);
94 assert(input1_tensor != nullptr);
95 assert(input2_tensor != nullptr);
96 assert(input3_tensor != nullptr or num_inputs == 2);
97 assert(output_tensor != nullptr);
99 luci_interpreter_pal::PadParams pad_params;
100 const int num_input_dimensions = Tensor::num_dims(input1_tensor);
101 pad_params.left_padding_count = num_input_dimensions;
102 pad_params.right_padding_count = num_input_dimensions;
104 const int32_t *paddings_data =
105 kernels::getTensorData<int32_t>(runtime_graph->getConstDataByTensor(input2_tensor));
106 for (int idx = num_input_dimensions - 1; idx >= 0; --idx)
108 pad_params.left_padding[idx] = paddings_data[idx * 2];
109 pad_params.right_padding[idx] = paddings_data[idx * 2 + 1];
112 auto *input1_data = runtime_graph->getDataByTensor(input1_tensor);
113 if (input1_data == nullptr)
114 input1_data = runtime_graph->getConstDataByTensor(input1_tensor);
117 auto *input2_data = runtime_graph->getConstDataByTensor(input2_tensor);
120 auto *output_data = runtime_graph->getDataByTensor(output_tensor);
123 switch (Tensor::element_type(input1_tensor))
126 case DataType::FLOAT32:
129 input3_tensor == nullptr
131 : *kernels::getTensorData<float>(runtime_graph->getConstDataByTensor(input3_tensor));
132 luci_interpreter_pal::Pad(pad_params, kernels::getTensorShape(input1_tensor),
133 kernels::getTensorData<float>(input1_data), &pad_value,
134 kernels::getTensorShape(output_tensor),
135 kernels::getTensorData<float>(output_data));
140 assert(false && "Unsupported type");
144 } // namespace luci_interpreter