compute/ARMComputeEx/src/runtime/CL/functions/CLFullyConnectedReshapingLayer.cpp
/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "arm_compute/runtime/CL/functions/CLFullyConnectedReshapingLayer.h"

#include <arm_compute/runtime/CL/functions/CLFullyConnectedHybridLayer.h>
#include <arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h>
#include <arm_compute/runtime/CL/functions/CLFullyConnectedLayerEx.h>

#include <memory>    // std::unique_ptr
#include <stdexcept> // std::runtime_error

using namespace arm_compute;

void CLFullyConnectedReshapingLayer::configure(const arm_compute::ICLTensor *input,
                                               const arm_compute::ICLTensor *weights,
                                               const arm_compute::ICLTensor *biases,
                                               arm_compute::ICLTensor *output, bool needs_reshape,
                                               const arm_compute::TensorShape &reshape,
                                               KernelType kernel_type)
{
  _input = input;
  _weights = weights;
  _biases = biases;
  _output = output;
  _needs_reshape = needs_reshape;

  const ICLTensor *input_to_use = input;
  if (_needs_reshape)
  {
    // Reshape the input into the requested shape before the fully connected layer, keeping the
    // original data layout. The intermediate result is held in _cl_buffer.
    auto_init_if_empty(*_cl_buffer.info(),
                       _input->info()->clone()->set_tensor_shape(reshape).set_data_layout(
                           _input->info()->data_layout()));
    _cl_reshape.configure(_input, &_cl_buffer);
    input_to_use = &_cl_buffer;
  }

  // Select the fully connected function to run. GENERAL uses the extended CL kernel, while
  // PREPROCESSED_WEIGHTS picks either a hybrid layer (float input with quantized weights) or
  // the stock CL fully connected layer, depending on the tensor data types.
  _cl_fc = [&]() {
    if (kernel_type == KernelType::GENERAL)
    {
      auto fc = new arm_compute::CLFullyConnectedLayerEx{_memory_manager};
      fc->configure(input_to_use, _weights, _biases, _output);
      return std::unique_ptr<arm_compute::IFunction>(fc);
    }
    else if (kernel_type == KernelType::PREPROCESSED_WEIGHTS)
    {
      // Hybrid mode: floating-point input combined with 8-bit signed quantized weights.
      bool is_hybrid = (input->info()->data_type() == DataType::F32 ||
                        input->info()->data_type() == DataType::F16) &&
                       (weights->info()->data_type() == DataType::S8 ||
                        weights->info()->data_type() == DataType::QASYMM8_SIGNED);

      if (is_hybrid)
      {
        auto fc = new arm_compute::CLFullyConnectedHybridLayer{_memory_manager};
        // The hybrid layer expects QASYMM8_SIGNED weights, so temporarily retag the weight
        // tensor info for configuration and restore the original data type afterwards.
        ITensorInfo *weights_info = const_cast<ITensorInfo *>(_weights->info());
        const auto origin_weights_data_type = weights_info->data_type();
        weights_info->set_data_type(DataType::QASYMM8_SIGNED);
        fc->configure(input_to_use, _weights, _biases, _output);
        weights_info->set_data_type(origin_weights_data_type);
        return std::unique_ptr<arm_compute::IFunction>(fc);
      }
      else
      {
        auto fc = new arm_compute::CLFullyConnectedLayer{_memory_manager};
        fc->configure(input_to_use, _weights, _biases, _output);
        return std::unique_ptr<arm_compute::IFunction>(fc);
      }
    }
    else
    {
      throw std::runtime_error("CLFullyConnectedReshapingLayer: Unsupported kernel type");
    }
  }();

  if (_needs_reshape)
  {
    // NOTE _cl_buffer is inaccessible from outside, and thus it is safe to invoke allocate here.
    _cl_buffer.allocator()->allocate();
  }
}

void CLFullyConnectedReshapingLayer::run(void)
{
  if (_needs_reshape)
    _cl_reshape.run();

  _cl_fc->run();
}

void CLFullyConnectedReshapingLayer::prepare(void) { _cl_fc->prepare(); }
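
// A minimal usage sketch (kept in a comment, not part of this translation unit's build) showing
// how the layer is typically wired up: configure once, prepare once, then run per inference.
// The tensor names, the example shape, and the default constructor taking no memory manager are
// assumptions for illustration; configure()/prepare()/run() and the KernelType values are the
// ones used above, with KernelType assumed to be a nested enum of the class as the unqualified
// use in configure() suggests.
//
//   #include "arm_compute/runtime/CL/CLScheduler.h"
//   #include "arm_compute/runtime/CL/CLTensor.h"
//
//   arm_compute::CLScheduler::get().default_init();
//
//   arm_compute::CLTensor input, weights, biases, output; // assumed already initialised/allocated
//
//   CLFullyConnectedReshapingLayer fc;
//   fc.configure(&input, &weights, &biases, &output,
//                /* needs_reshape = */ true,
//                arm_compute::TensorShape(16U, 4U), // target 2-D shape for the flattened input (illustrative)
//                CLFullyConnectedReshapingLayer::KernelType::GENERAL);
//
//   fc.prepare(); // one-time weight preparation of the underlying function
//   fc.run();     // enqueue the optional reshape, then the fully connected kernels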