/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*
 * Copyright (c) 2017-2020 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
40 #include "arm_compute/runtime/CL/functions/CLTransposeConvLayer.h"
42 #include "arm_compute/core/Utils.h"
43 #include "arm_compute/core/Validate.h"
44 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
45 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
46 #include "arm_compute/runtime/CL/CLScheduler.h"
52 using namespace arm_compute;
53 using namespace arm_compute::misc::shape_calculator;
55 CLTransposeConvLayer::CLTransposeConvLayer(std::shared_ptr<IMemoryManager> memory_manager)
56 : _memory_manager(std::move(memory_manager)), _function()
60 void CLTransposeConvLayer::configure(ICLTensor *input, ICLTensor *weights, const ICLTensor *bias,
61 ICLTensor *output, const PadStrideInfo &deconv_info,
62 unsigned int invalid_right, unsigned int invalid_bottom,
63 const WeightsInfo &weights_info)
65 configure(CLKernelLibrary::get().get_compile_context(), input, weights, bias, output, deconv_info,
66 invalid_right, invalid_bottom, weights_info);
69 void CLTransposeConvLayer::configure(const CLCompileContext &compile_context, ICLTensor *input,
70 ICLTensor *weights, const ICLTensor *bias, ICLTensor *output,
71 const PadStrideInfo &deconv_info, unsigned int invalid_right,
72 unsigned int invalid_bottom, const WeightsInfo &weights_info)
74 ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
76 switch (CLTransposeConvLayer::get_deconvolution_method(input->info(), weights->info(), nullptr,
77 output->info(), deconv_info, invalid_right,
78 invalid_bottom, weights_info))
80 case DeconvolutionMethod::DIRECT:
82 auto f = arm_compute::support::cpp14::make_unique<CLDirectTransposeConvLayer>();
83 f->configure(compile_context, input, weights, bias, output, deconv_info, invalid_right,
84 invalid_bottom, weights_info);
85 _function = std::move(f);
88 case DeconvolutionMethod::GEMM:
90 auto f = arm_compute::support::cpp14::make_unique<CLGEMMDeconvolutionLayer>(_memory_manager);
91 f->configure(compile_context, input, weights, bias, output, deconv_info);
92 _function = std::move(f);
96 ARM_COMPUTE_ERROR("Not supported.");
101 Status CLTransposeConvLayer::validate(const ITensorInfo *input, const ITensorInfo *weights,
102 const ITensorInfo *bias, ITensorInfo *output,
103 const PadStrideInfo &deconv_info, unsigned int invalid_right,
104 unsigned int invalid_bottom, const WeightsInfo &weights_info)
106 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
107 switch (CLTransposeConvLayer::get_deconvolution_method(
108 input, weights, bias, output, deconv_info, invalid_right, invalid_bottom, weights_info))
110 case DeconvolutionMethod::DIRECT:
112 // Validate direct convolution layer
113 ARM_COMPUTE_RETURN_ON_ERROR(CLDirectTransposeConvLayer::validate(
114 input, weights, bias, output, deconv_info, invalid_right, invalid_bottom, weights_info));
117 case DeconvolutionMethod::GEMM:
119 // Validate gemm-based convolution layer
120 ARM_COMPUTE_RETURN_ON_ERROR(
121 CLGEMMDeconvolutionLayer::validate(input, weights, bias, output, deconv_info));
125 ARM_COMPUTE_ERROR("Not supported.");
132 DeconvolutionMethod CLTransposeConvLayer::get_deconvolution_method(
133 const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias,
134 ITensorInfo *output, const PadStrideInfo &deconv_info, unsigned int invalid_right,
135 unsigned int invalid_bottom, const WeightsInfo &weights_info)
137 ARM_COMPUTE_UNUSED(output, bias, weights_info);
139 const DataLayout data_layout = input->data_layout();
141 const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
142 const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
144 if (weights->dimension(idx_w) != deconv_info.stride().first ||
145 weights->dimension(idx_h) != deconv_info.stride().second || invalid_right != 0 ||
148 return DeconvolutionMethod::DIRECT;
151 return DeconvolutionMethod::GEMM;
// Entry point that executes the layer.
// NOTE(review): the body of this definition is not present in this listing (the embedded
// numbering jumps from 154 to 160), so what it executes cannot be confirmed from here.
// Restore the body from the upstream file before building.
154 void CLTransposeConvLayer::run()
160 void CLTransposeConvLayer::prepare() { _function->prepare(); }