2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "arm_compute/runtime/CL/functions/CLFullyConnectedReshapingLayer.h"
19 #include <arm_compute/runtime/CL/functions/CLFullyConnectedHybridLayer.h>
20 #include <arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h>
21 #include <arm_compute/runtime/CL/functions/CLFullyConnectedLayerEx.h>
23 using namespace arm_compute;
25 void CLFullyConnectedReshapingLayer::configure(const arm_compute::ICLTensor *input,
26 const arm_compute::ICLTensor *weights,
27 const arm_compute::ICLTensor *biases,
28 arm_compute::ICLTensor *output, bool needs_reshape,
29 const arm_compute::TensorShape &reshape,
30 KernelType kernel_type)
36 _needs_reshape = needs_reshape;
38 const ICLTensor *input_to_use = input;
42 auto_init_if_empty(*_cl_buffer.info(),
43 _input->info()->clone()->set_tensor_shape(reshape).set_data_layout(
44 _input->info()->data_layout()));
45 _cl_reshape.configure(_input, &_cl_buffer);
46 input_to_use = &_cl_buffer;
50 if (kernel_type == KernelType::GENERAL)
52 auto fc = new arm_compute::CLFullyConnectedLayerEx{_memory_manager};
53 fc->configure(input_to_use, _weights, _biases, _output);
54 return std::unique_ptr<arm_compute::IFunction>(fc);
56 else if (kernel_type == KernelType::PREPROCESSED_WEIGHTS)
58 bool is_hybrid = (input->info()->data_type() == DataType::F32 ||
59 input->info()->data_type() == DataType::F16) &&
60 (weights->info()->data_type() == DataType::S8 ||
61 weights->info()->data_type() == DataType::QASYMM8_SIGNED);
65 auto fc = new arm_compute::CLFullyConnectedHybridLayer{_memory_manager};
66 ITensorInfo *weights_info = const_cast<ITensorInfo *>(_weights->info());
67 const auto orgin_weights_data_type = weights_info->data_type();
68 weights_info->set_data_type(DataType::QASYMM8_SIGNED);
69 fc->configure(input_to_use, _weights, _biases, _output);
70 weights_info->set_data_type(orgin_weights_data_type);
71 return std::unique_ptr<arm_compute::IFunction>(fc);
75 auto fc = new arm_compute::CLFullyConnectedLayer{_memory_manager};
76 fc->configure(input_to_use, _weights, _biases, _output);
77 return std::unique_ptr<arm_compute::IFunction>(fc);
82 throw std::runtime_error("CLFullyConnectedReshapingLayer: Unsupported kernel type");
89 // NOTE _cl_buffer is inaccessible from outside, and thus it is safe to invoke allocate here.
90 _cl_buffer.allocator()->allocate();
// Runs the layer set up by configure().
// NOTE(review): the body of run() is not visible in this chunk; presumably it runs _cl_reshape
// (when _needs_reshape is set) before _cl_fc -- confirm against the full file.
94 void CLFullyConnectedReshapingLayer::run(void)
102 void CLFullyConnectedReshapingLayer::prepare(void) { _cl_fc->prepare(); }