/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "kernels/Utils.h"

#include <cassert>
#include <cstdint>
#include <cstring>
21 namespace luci_interpreter
24 void configure_kernel_CircleExpandDims(const circle::Operator *cur_op,
25 BaseRuntimeGraph *runtime_graph)
27 const auto input_index = cur_op->inputs()->operator[](0);
28 const auto axis_index = cur_op->inputs()->operator[](1);
29 const auto output_index = cur_op->outputs()->operator[](0);
31 assert(input_index != -1);
32 assert(axis_index != -1);
33 assert(output_index != -1);
35 const auto input = runtime_graph->getCircleTensorByIndex(input_index);
36 const auto axis = runtime_graph->getCircleTensorByIndex(axis_index);
37 auto output = runtime_graph->getCircleTensorByIndex(output_index);
39 assert(input != nullptr);
40 assert(axis != nullptr);
41 assert(output != nullptr);
43 auto axis_data = runtime_graph->getConstDataByTensor(axis);
47 switch (Tensor::element_type(axis))
50 axis_value = *reinterpret_cast<int32_t *>(axis_data);
53 axis_value = static_cast<int32_t>(*reinterpret_cast<int64_t *>(axis_data));
56 assert(false && "Unsupported type.");
61 axis_value += Tensor::num_dims(input) + 1;
64 LUCI_INTERPRETER_CHECK(axis_value <= Tensor::num_dims(input) and axis_value >= 0);
67 void execute_kernel_CircleExpandDims(const circle::Operator *cur_op,
68 BaseRuntimeGraph *runtime_graph, bool is_inplace)
70 const auto input_index = cur_op->inputs()->operator[](0);
71 const auto output_index = cur_op->outputs()->operator[](0);
73 assert(input_index != -1);
74 assert(output_index != -1);
76 const auto input = runtime_graph->getCircleTensorByIndex(input_index);
77 const auto output = runtime_graph->getCircleTensorByIndex(output_index);
81 runtime_graph->makeInplaceOperation(input, output);
85 // Just copy input to output
86 const auto input_data = runtime_graph->getDataByTensor(input);
87 auto output_data = runtime_graph->getDataByTensor(output);
89 assert(input_data != nullptr);
90 assert(output_data != nullptr);
92 const size_t element_size = getDataTypeSize(Tensor::element_type(input));
93 const int32_t num_elements = Tensor::num_elements(input);
94 std::memcpy(output_data, input_data, num_elements * element_size);
97 } // namespace luci_interpreter