2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 * @file CLFullyConnectedReshapingLayer.h
19 * @brief This file contains CLFullyConnectedReshapingLayer class
20 * @ingroup COM_AI_RUNTIME
23 #ifndef __ARM_COMPUTE_CL_FULLY_CONNECTED_RESHAPING_LAYER_H__
24 #define __ARM_COMPUTE_CL_FULLY_CONNECTED_RESHAPING_LAYER_H__
26 #include <arm_compute/runtime/CL/CLTensor.h>
27 #include <arm_compute/runtime/CL/functions/CLReshapeLayer.h>
28 #include <arm_compute/runtime/IMemoryManager.h>
33 * @brief Class to run FullyConnected Layer after reshaping input tensor
// An arm_compute::IFunction that holds a fully-connected function (_cl_fc, stored
// through the generic IFunction interface) plus a CLReshapeLayer (_cl_reshape) and an
// intermediate CLTensor (_cl_buffer) used as the reshape target for the input.
// NOTE(review): this listing appears to be missing lines from the original header
// (comment delimiters, braces, public/private sections, the KernelType enum header,
// the constructor body, the `_needs_reshape` member declaration and the closing
// `};`) -- confirm against the source repository before editing this class.
// NOTE(review): std::shared_ptr / std::unique_ptr are used below, but <memory> is not
// among the visible includes; it presumably arrives transitively -- consider
// including it directly.
35 class CLFullyConnectedReshapingLayer : public arm_compute::IFunction
// Kernel selection passed to configure() (these look like the enumerators of the
// KernelType type named in configure()'s signature -- verify against the full file).
40 GENERAL, //< General FC
41 PREPROCESSED_WEIGHTS //< Weights are constants so it can be preprocessed
// Constructor: stores the optional memory manager (shared with the caller), sets all
// tensor pointers and _cl_fc to nullptr, and starts with _needs_reshape == false.
// No work is done until configure() is called.
// NOTE(review): this constructor is callable with a single argument and is not
// marked `explicit`, so a std::shared_ptr<IMemoryManager> converts implicitly to
// this type; consider adding `explicit`.
45 CLFullyConnectedReshapingLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr)
46 : _input(nullptr), _weights(nullptr), _biases(nullptr), _output(nullptr), _cl_buffer{},
47 _memory_manager{memory_manager}, _cl_fc{nullptr}, _cl_reshape{}, _needs_reshape(false)
54 * @brief Configure the layer
55 * @param[in] input The source tensor
56 * @param[in] weights The tensor that is filled with weight values
57 * @param[in] biases The tensor that is filled with biase values
58 * @param[in] output The destination tensor
59 * @param[in] needs_reshape Whether it needs to be reshaped or not
60 * @param[in] reshape The tensor shape to be reshaped. Only valid when needs_reshape is true.
// kernel_type: which fully-connected variant to set up (GENERAL or
// PREPROCESSED_WEIGHTS); it is not documented in the visible @param lines above.
// NOTE(review): "biase values" in the biases description should read "bias values".
// NOTE(review): output is the only non-const tensor pointer, so it is presumably the
// only tensor written by run() -- confirm in the implementation.
63 void configure(const arm_compute::ICLTensor *input, const arm_compute::ICLTensor *weights,
64 const arm_compute::ICLTensor *biases, arm_compute::ICLTensor *output,
65 bool needs_reshape, const arm_compute::TensorShape &reshape,
66 KernelType kernel_type);
70 * @brief Run the operation. Must be called after configure().
// Overrides arm_compute::IFunction::run(); the implementation is not in this header.
73 void run(void) override;
75 * @brief Prepare the operation
// Overrides arm_compute::IFunction::prepare(). What is prepared (presumably one-time
// weight processing for PREPROCESSED_WEIGHTS) is defined in the .cpp -- verify.
78 void prepare(void) override;
// Tensors recorded by configure(). They are held as raw pointers and no destructor is
// visible here, so presumably they are non-owning and must outlive this object --
// verify.
81 const arm_compute::ICLTensor *_input;
82 const arm_compute::ICLTensor *_weights;
83 const arm_compute::ICLTensor *_biases;
84 arm_compute::ICLTensor *_output;
86 // buffer for reshaping input tensor
87 arm_compute::CLTensor _cl_buffer;
// Memory manager handed to the constructor (may be nullptr); presumably forwarded to
// the fully-connected function created in configure() -- verify.
90 std::shared_ptr<IMemoryManager> _memory_manager;
// Fully-connected function; nullptr until configure(). It is held via IFunction,
// presumably so the concrete type can depend on kernel_type -- verify.
91 std::unique_ptr<arm_compute::IFunction> _cl_fc;
// Reshape step applied to the input when needs_reshape is true.
// NOTE(review): _needs_reshape is initialised in the constructor, but its declaration
// is not among the visible member lines.
92 CLReshapeLayer _cl_reshape;
95 } // namespace arm_compute
97 #endif // __ARM_COMPUTE_CL_FULLY_CONNECTED_RESHAPING_LAYER_H__