/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cker/operation/FullyConnected.h>

#include <stdexcept>

#include "OperationUtil.h"

#include "interp/Registration.h"
#include "ir/operation/FullyConnected.h"
#include "misc/polymorphic_downcast.h"
32 void prepareFC(ExecEnv *env, const ir::Operation &node)
34 const auto in_index = node.getInputs().at(ir::operation::FullyConnected::INPUT);
35 const auto kernel_index = node.getInputs().at(ir::operation::FullyConnected::WEIGHT);
36 const auto bias_index = node.getInputs().at(ir::operation::FullyConnected::BIAS);
37 const auto out_index = node.getOutputs().at(0);
39 const auto in_tensor = env->tensorAt(in_index);
40 const auto kernel_tensor = env->tensorAt(kernel_index);
41 const auto bias_tensor = env->tensorAt(bias_index);
43 UNUSED_RELEASE(in_tensor);
44 UNUSED_RELEASE(kernel_tensor);
45 UNUSED_RELEASE(bias_tensor);
47 assert(in_tensor->num_dimensions() >= 2);
48 assert(kernel_tensor->num_dimensions() == 2);
49 assert(bias_tensor->num_dimensions() == 1);
51 const auto input_size_with_batch = in_tensor->num_elements();
52 const auto num_units = kernel_tensor->dimension(0);
53 const auto input_size = kernel_tensor->dimension(1);
54 const auto batch_size = input_size_with_batch / input_size;
55 assert(input_size_with_batch % input_size == 0);
56 assert(num_units == bias_tensor->dimension(0));
58 // Make output tensor info
59 ir::Shape output_shape(2);
60 output_shape.dim(0) = batch_size;
61 output_shape.dim(1) = num_units;
63 ir::OperandInfo::createStaticInfo(output_shape, in_tensor->tensorInfo().typeInfo());
64 env->allocateIfNeeded(out_index, out_info);
66 auto out_tensor = env->tensorAt(out_index);
67 UNUSED_RELEASE(out_tensor);
69 // Handle same ifm & ofm data type only
70 assert(in_tensor->data_type() == out_tensor->data_type());
71 assert(out_tensor->num_dimensions() == 2);
72 assert(out_tensor->dimension(0) == batch_size);
73 assert(out_tensor->dimension(1) == num_units);
76 void invoke(const ITensor *ifm_tensor, const ITensor *ker_tensor, const ITensor *bias_tensor,
77 const ITensor *ofm_tensor, const ir::operation::FullyConnected::Param ¶m)
79 const auto ifm_buffer = ifm_tensor->bufferRO();
80 const auto ker_buffer = ker_tensor->bufferRO();
81 const auto bias_buffer = bias_tensor->bufferRO();
82 auto ofm_buffer = ofm_tensor->buffer();
85 nnfw::cker::FullyConnectedParams cker_param;
86 cker_param.activation = convertActivationType(param.activation);
87 calculateActivationRange(param.activation, &cker_param.float_activation_min,
88 &cker_param.float_activation_max);
89 const auto cker_ifm_shape = convertShape(ifm_tensor->tensorInfo().shape());
90 const auto cker_ker_shape = convertShape(ker_tensor->tensorInfo().shape());
91 const auto cker_bias_shape = convertShape(bias_tensor->tensorInfo().shape());
92 const auto cker_ofm_shape = convertShape(ofm_tensor->tensorInfo().shape());
93 const float *ifm_ptr = reinterpret_cast<const float *>(ifm_buffer);
94 const float *ker_ptr = reinterpret_cast<const float *>(ker_buffer);
95 const float *bias_ptr = reinterpret_cast<const float *>(bias_buffer);
96 float *ofm_ptr = reinterpret_cast<float *>(ofm_buffer);
98 nnfw::cker::FullyConnected(cker_param, cker_ifm_shape, ifm_ptr, cker_ker_shape, ker_ptr,
99 cker_bias_shape, bias_ptr, cker_ofm_shape, ofm_ptr);
102 void invokeFC(const ExecEnv *env, const ir::Operation &node)
104 const auto &conv_node =
105 nnfw::misc::polymorphic_downcast<const ir::operation::FullyConnected &>(node);
107 const auto ifm_index = node.getInputs().at(ir::operation::FullyConnected::INPUT);
108 const auto ker_index = node.getInputs().at(ir::operation::FullyConnected::WEIGHT);
109 const auto bias_index = node.getInputs().at(ir::operation::FullyConnected::BIAS);
110 const auto ofm_index = node.getOutputs().at(0);
112 const auto ifm_tensor = env->tensorAt(ifm_index);
113 const auto ker_tensor = env->tensorAt(ker_index);
114 const auto bias_tensor = env->tensorAt(bias_index);
115 const auto ofm_tensor = env->tensorAt(ofm_index);
117 const auto data_type = ifm_tensor->data_type();
118 if (data_type == ir::DataType::FLOAT32)
120 invoke(ifm_tensor, ker_tensor, bias_tensor, ofm_tensor, conv_node.param());
124 throw std::runtime_error{"NYI: Support float only"};
129 OpKernel *getFullyConnected()
131 static OpKernel kernel = {fc::prepareFC, fc::invokeFC};
135 } // namespace interp