/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
23 namespace luci_interpreter
29 void packImpl(const circle::Tensor *input0, const circle::Tensor *output,
30 const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph,
31 uint8_t *output_data_raw)
33 const auto *options = cur_op->builtin_options_as_PackOptions();
35 const int values_count = options->values_count();
36 int axis = options->axis();
37 const int dimensions = Tensor::num_dims(output);
39 const auto input_dims = wrap(input0->shape());
40 const auto output_dims = wrap(output->shape());
48 for (int i = 0; i < axis; ++i)
49 outer_size *= output_dims[i];
52 for (int i = axis + 1; i < dimensions; ++i)
53 copy_size *= output_dims[i];
56 for (int i = 0; i < input_dims.size(); ++i)
57 input_size *= input_dims[i];
59 assert(input_size == copy_size * outer_size);
61 T *output_data = kernels::getTensorData<T>(output_data_raw);
62 assert(output_data != nullptr);
64 for (int i = 0; i < values_count; ++i)
66 const auto input_index = cur_op->inputs()->operator[](i);
67 assert(input_index != -1);
68 const auto input = runtime_graph->getCircleTensorByIndex(input_index);
70 auto input_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(input));
71 assert(input_data != nullptr);
72 for (int k = 0; k < outer_size; ++k)
74 const T *input_ptr = input_data + copy_size * k;
75 int loc = k * values_count * copy_size + i * copy_size;
76 T *output_ptr = output_data + loc;
77 for (int j = 0; j < copy_size; ++j)
78 output_ptr[j] = input_ptr[j];
85 void configure_kernel_CirclePack(const circle::Operator *, BaseRuntimeGraph *)
90 void execute_kernel_CirclePack(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
92 const auto input_index = cur_op->inputs()->operator[](0);
93 const auto output_index = cur_op->outputs()->operator[](0);
94 assert(output_index != -1);
95 assert(input_index != -1);
96 const auto input = runtime_graph->getCircleTensorByIndex(input_index);
97 const auto output = runtime_graph->getCircleTensorByIndex(output_index);
99 auto output_data = runtime_graph->getDataByTensor(output);
100 assert(output_data != nullptr);
102 switch (Tensor::element_type(output))
105 case DataType::FLOAT32:
106 packImpl<float>(input, output, cur_op, runtime_graph, output_data);
111 packImpl<int8_t>(input, output, cur_op, runtime_graph, output_data);
114 packImpl<uint8_t>(input, output, cur_op, runtime_graph, output_data);
118 packImpl<int32_t>(input, output, cur_op, runtime_graph, output_data);
121 packImpl<int64_t>(input, output, cur_op, runtime_graph, output_data);
124 assert(false && "Unsupported types");
128 } // namespace luci_interpreter