2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <cker/operation/Concatenation.h>
19 #include "OperationUtil.h"
21 #include "interp/Registration.h"
22 #include "ir/operation/Concat.h"
23 #include "misc/polymorphic_downcast.h"
32 void prepareConcat(ExecEnv *env, const ir::Operation &node)
34 const auto &concat_node = nnfw::misc::polymorphic_downcast<const ir::operation::Concat &>(node);
36 const auto first_index = node.getInputs().at(0);
37 const auto out_index = node.getOutputs().at(0);
39 const auto first_tensor = env->tensorAt(first_index);
40 uint32_t out_axis_dimension = 0;
41 const int32_t axis_raw = concat_node.param().axis;
42 const uint32_t axis = (axis_raw < 0) ? (axis_raw + first_tensor->num_dimensions()) : axis_raw;
44 // All inputs shape should be same except axis dimension
45 // All inputs type should be same
46 for (auto input : node.getInputs())
48 assert(first_tensor->num_dimensions() == env->tensorAt(input)->num_dimensions());
49 assert(first_tensor->data_type() == env->tensorAt(input)->data_type());
50 for (uint32_t i = 0; i < first_tensor->num_dimensions(); i++)
54 out_axis_dimension += env->tensorAt(input)->dimension(i);
57 assert(first_tensor->dimension(i) == env->tensorAt(input)->dimension(i));
61 // Make output tensor info using first input tensor info, and accumulated axis dimension value
62 auto out_shape = first_tensor->tensorInfo().shape();
63 out_shape.dim(axis) = out_axis_dimension;
64 env->allocateIfNeeded(out_index, ir::OperandInfo::createStaticInfo(
65 out_shape, first_tensor->tensorInfo().typeInfo()));
67 auto out_tensor = env->tensorAt(out_index);
68 UNUSED_RELEASE(out_tensor);
70 // Output shape should be same with input except axis dimension
71 // Output type should be same with input
72 assert(first_tensor->data_type() == out_tensor->data_type());
73 for (uint32_t i = 0; i < first_tensor->num_dimensions(); i++)
79 assert(first_tensor->dimension(i) == out_tensor->dimension(i));
83 void invoke(const std::vector<const ITensor *> in_tensors, const ITensor *out_tensor, uint32_t axis)
85 const uint32_t count = in_tensors.size();
88 nnfw::cker::ConcatenationParams cker_param;
89 cker_param.axis = (int8_t)axis;
90 cker_param.inputs_count = count;
92 const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
94 std::vector<nnfw::cker::Shape> in_shapes;
95 std::vector<const nnfw::cker::Shape *> in_shape_ptrs;
96 in_shapes.reserve(count);
97 in_shape_ptrs.reserve(count);
98 std::vector<const float *> in_ptrs;
99 for (uint32_t i = 0; i < count; i++)
101 in_shapes.push_back(convertShape(in_tensors[i]->tensorInfo().shape()));
102 in_shape_ptrs.push_back(&in_shapes[i]);
103 in_ptrs.push_back(reinterpret_cast<const float *>(in_tensors[i]->bufferRO()));
106 auto out_buffer = out_tensor->buffer();
107 float *out_ptr = reinterpret_cast<float *>(out_buffer);
109 nnfw::cker::Concatenation<float>(cker_param, in_shape_ptrs.data(), in_ptrs.data(), out_shape,
113 void invokeConcat(const ExecEnv *env, const ir::Operation &node)
115 const auto &concat_node = nnfw::misc::polymorphic_downcast<const ir::operation::Concat &>(node);
116 const int32_t axis_raw = concat_node.param().axis;
118 std::vector<const ITensor *> in_tensors;
119 for (const auto &e : concat_node.getInputs())
121 in_tensors.emplace_back(env->tensorAt(e));
124 const auto out_index = node.getOutputs().at(0);
125 const auto out_tensor = env->tensorAt(out_index);
126 const uint32_t axis = (axis_raw < 0) ? (axis_raw + out_tensor->num_dimensions()) : axis_raw;
128 const auto data_type = in_tensors[0]->data_type();
129 if (data_type == ir::DataType::FLOAT32)
131 invoke(in_tensors, out_tensor, axis);
135 throw std::runtime_error{"NYI: Support float32 only"};
138 } // namespace concat
140 OpKernel *getConcat()
142 static OpKernel kernel = {concat::prepareConcat, concat::invokeConcat};
146 } // namespace interp