289ab167f9a685d9f8623718900ed0ebb6222b09
[platform/core/ml/nnfw.git] / compute / ARMComputeEx / arm_compute / runtime / CL / functions / CLFullyConnectedReshapingLayer.h
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * @file        CLFullyConnectedReshapingLayer.h
19  * @brief       This file contains CLFullyConnectedReshapingLayer class
20  * @ingroup     COM_AI_RUNTIME
21  */
22
23 #ifndef __ARM_COMPUTE_CL_FULLY_CONNECTED_RESHAPING_LAYER_H__
24 #define __ARM_COMPUTE_CL_FULLY_CONNECTED_RESHAPING_LAYER_H__
25
26 #include <arm_compute/runtime/CL/CLTensor.h>
27 #include <arm_compute/runtime/CL/functions/CLReshapeLayer.h>
28 #include <arm_compute/runtime/IMemoryManager.h>
29
30 namespace arm_compute
31 {
32 /**
33  * @brief Class to run FullyConnected Layer after reshaping input tensor
34  */
35 class CLFullyConnectedReshapingLayer : public arm_compute::IFunction
36 {
37 public:
38   enum class KernelType
39   {
40     GENERAL,             //< General FC
41     PREPROCESSED_WEIGHTS //< Weights are constants so it can be preprocessed
42   };
43
44 public:
45   CLFullyConnectedReshapingLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr)
46       : _input(nullptr), _weights(nullptr), _biases(nullptr), _output(nullptr), _cl_buffer{},
47         _memory_manager{memory_manager}, _cl_fc{nullptr}, _cl_reshape{}, _needs_reshape(false)
48   {
49     // DO NOTHING
50   }
51
52 public:
53   /**
54    * @brief Configure the layer
55    * @param[in] input The source tensor
56    * @param[in] weights The tensor that is filled with weight values
57    * @param[in] biases The tensor that is filled with biase values
58    * @param[in] output The destination tensor
59    * @param[in] needs_reshape Whether it needs to be reshaped or not
60    * @param[in] reshape The tensor shape to be reshaped. Only valid when needs_reshape is true.
61    * @return N/A
62    */
63   void configure(const arm_compute::ICLTensor *input, const arm_compute::ICLTensor *weights,
64                  const arm_compute::ICLTensor *biases, arm_compute::ICLTensor *output,
65                  bool needs_reshape, const arm_compute::TensorShape &reshape,
66                  KernelType kernel_type);
67
68 public:
69   /**
70    * @brief Run the operation. Must be called after configure().
71    * @return N/A
72    */
73   void run(void) override;
74   /**
75    * @brief Prepare the operation
76    * @return N/A
77    */
78   void prepare(void) override;
79
80 private:
81   const arm_compute::ICLTensor *_input;
82   const arm_compute::ICLTensor *_weights;
83   const arm_compute::ICLTensor *_biases;
84   arm_compute::ICLTensor *_output;
85
86   // buffer for reshaping input tensor
87   arm_compute::CLTensor _cl_buffer;
88
89 private:
90   std::shared_ptr<IMemoryManager> _memory_manager;
91   std::unique_ptr<arm_compute::IFunction> _cl_fc;
92   CLReshapeLayer _cl_reshape;
93   bool _needs_reshape;
94 };
95 } // namespace arm_compute
96
97 #endif // __ARM_COMPUTE_CL_FULLY_CONNECTED_RESHAPING_LAYER_H__