867af01bb8b7535f6ca0c1523886944c4b524fa5
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / ExpandDims.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "Builders.h"
19 #include "kernels/Utils.h"
20
21 namespace luci_interpreter
22 {
23
24 void configure_kernel_CircleExpandDims(const circle::Operator *cur_op,
25                                        BaseRuntimeGraph *runtime_graph)
26 {
27   const auto input_index = cur_op->inputs()->operator[](0);
28   const auto axis_index = cur_op->inputs()->operator[](1);
29   const auto output_index = cur_op->outputs()->operator[](0);
30
31   assert(input_index != -1);
32   assert(axis_index != -1);
33   assert(output_index != -1);
34
35   const auto input = runtime_graph->getCircleTensorByIndex(input_index);
36   const auto axis = runtime_graph->getCircleTensorByIndex(axis_index);
37   auto output = runtime_graph->getCircleTensorByIndex(output_index);
38
39   assert(input != nullptr);
40   assert(axis != nullptr);
41   assert(output != nullptr);
42
43   auto axis_data = runtime_graph->getConstDataByTensor(axis);
44
45   int32_t axis_value;
46
47   switch (Tensor::element_type(axis))
48   {
49     case DataType::S32:
50       axis_value = *reinterpret_cast<int32_t *>(axis_data);
51       break;
52     case DataType::S64:
53       axis_value = static_cast<int32_t>(*reinterpret_cast<int64_t *>(axis_data));
54       break;
55     default:
56       assert(false && "Unsupported type.");
57   }
58
59   if (axis_value < 0)
60   {
61     axis_value += Tensor::num_dims(input) + 1;
62   }
63
64   LUCI_INTERPRETER_CHECK(axis_value <= Tensor::num_dims(input) and axis_value >= 0);
65 }
66
67 void execute_kernel_CircleExpandDims(const circle::Operator *cur_op,
68                                      BaseRuntimeGraph *runtime_graph, bool is_inplace)
69 {
70   const auto input_index = cur_op->inputs()->operator[](0);
71   const auto output_index = cur_op->outputs()->operator[](0);
72
73   assert(input_index != -1);
74   assert(output_index != -1);
75
76   const auto input = runtime_graph->getCircleTensorByIndex(input_index);
77   const auto output = runtime_graph->getCircleTensorByIndex(output_index);
78
79   if (is_inplace)
80   {
81     runtime_graph->makeInplaceOperation(input, output);
82     return;
83   }
84
85   // Just copy input to output
86   const auto input_data = runtime_graph->getDataByTensor(input);
87   auto output_data = runtime_graph->getDataByTensor(output);
88
89   assert(input_data != nullptr);
90   assert(output_data != nullptr);
91
92   const size_t element_size = getDataTypeSize(Tensor::element_type(input));
93   const int32_t num_elements = Tensor::num_elements(input);
94   std::memcpy(output_data, input_data, num_elements * element_size);
95 }
96
97 } // namespace luci_interpreter