2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
18 #include "convolution_kernel_bfyx_gemm_like.h"
19 #include "kernel_selector_utils.h"
20 #include "common_tools.h"
22 namespace kernel_selector
// Builds the ParamsKey advertising which parameter combinations this kernel accepts:
//   - F16 and F32 for input data, weights and output data,
//   - bfyx for both the input and the output layout,
//   - tensor offsets and tensor pitches,
//   - both the bias-per-feature and the no-bias variants,
//   - split support.
// NOTE(review): this listing appears to be missing short lines (braces, the declaration
// of `k`, the `return`, and possibly further Enable* calls) - confirm against the
// original file before relying on this list being complete.
25 ParamsKey ConvolutionKernel_bfyx_GEMMLike::GetSupportedKey() const
28 k.EnableInputDataType(Datatype::F16);
29 k.EnableInputDataType(Datatype::F32);
30 k.EnableInputWeightsType(WeightsType::F16);
31 k.EnableInputWeightsType(WeightsType::F32);
32 k.EnableOutputDataType(Datatype::F16);
33 k.EnableOutputDataType(Datatype::F32);
34 k.EnableInputLayout(DataLayout::bfyx);
35 k.EnableOutputLayout(DataLayout::bfyx);
36 k.EnableTensorOffset();
37 k.EnableTensorPitches();
// Sub-group short support is deliberately not part of the key; the FP16-only
// requirement is checked in Validate() via engineInfo.bSubGroupShortSupport.
39 //k.EnableSubGroupShort(); // we need it for FP16 only. we check it on the Validate phase
40 k.EnableBiasPerFeature();
41 k.EnableNonBiasTerm();
43 k.EnableSplitSupport();
// Returns the kernel name for the given convolution parameters: `kernelName` with a
// precision suffix chosen from the data type of the first input tensor.
47 std::string ConvolutionKernel_bfyx_GEMMLike::GetKernelName(const convolution_params& params) const
// F32 input selects the "_fp32" variant.
49 if (params.inputs[0].GetDType() == Datatype::F32)
51 return kernelName + "_fp32";
// Presumably the `else` branch (the `else` line is not visible in this listing):
// any other input type selects the "_fp16" variant.
55 return kernelName + "_fp16";
// Produces the JIT constants for this kernel: starts from the constants supplied by
// Parent::GetJitConstants and adds the GEMM-like specific ones derived from `params`
// and from the dispatch configuration in `runInfo`.
// NOTE(review): the call that wraps the MakeJitConstant list and the final `return`
// are not visible in this listing - confirm against the original file.
59 JitConstants ConvolutionKernel_bfyx_GEMMLike::GetJitConstants(const convolution_params& params, const DispatchData& runInfo) const
61 JitConstants jit = Parent::GetJitConstants(params, runInfo);
// Output feature count rounded up to a multiple of gemmStyle.subBlockDimN.
64 MakeJitConstant("ALIGNED_OFM", RoundUp(params.output.Feature().v, runInfo.gemmStyle.subBlockDimN)),
65 MakeJitConstant("DX", runInfo.gemmStyle.globalWorkSizeDX),
66 MakeJitConstant("DY", runInfo.gemmStyle.globalWorkSizeDY),
// Integer division: half of the filter width, rounded down.
67 MakeJitConstant("FILTER_SIZE_X_DIV2", params.filterSize.x / 2),
// Both *_PADDED constants are defined with an empty value.
68 MakeJitConstant("INPUT_BUFFER_WIDTH_PADDED", ""), // TODO: enable non padding path again
69 MakeJitConstant("INPUT_BUFFER_HEIGHT_PADDED", ""),
// LEFTOVERS is defined only when the work-item count along dimension 1 - output X*Y
// rounded up to subBlockDimM, then ceil-divided by globalWorkSizeDY - is not a
// multiple of lws1. This is the same quantity SetDefault rounds up to obtain gws1.
72 if (CeilDiv(RoundUp(params.output.X().v * params.output.Y().v, runInfo.gemmStyle.subBlockDimM), runInfo.gemmStyle.globalWorkSizeDY) % runInfo.lws1 != 0)
73 jit.AddConstant(MakeJitConstant("LEFTOVERS", 1));
// Fills in the dispatch configuration for this kernel. Starts from the result of
// Parent::SetDefault, then chooses the GEMM blocking parameters and the priority from
// the input data type, and finally derives the global work sizes from the output shape.
// NOTE(review): the lws0/lws1 values used below are not assigned in any line visible
// here (short lines, the `else` and the `return` appear to be missing from this
// listing) - confirm against the original file.
78 ConvolutionKernel_bfyx_GEMMLike::Parent::DispatchData ConvolutionKernel_bfyx_GEMMLike::SetDefault(const convolution_params& arg, int autoTuneIndex) const
80 DispatchData runInfo = Parent::SetDefault(arg, autoTuneIndex);
// F16 input: one set of gemmStyle values and FORCE_PRIORITY_6.
85 if (arg.inputs[0].GetDType() == Datatype::F16)
87 runInfo.gemmStyle = { 1, arg.filterSize.x, 32, 32, 1, 1 };
89 runInfo.effiency = FORCE_PRIORITY_6;
// Presumably the `else` branch (non-F16 input): different gemmStyle values and
// FORCE_PRIORITY_8. The field order of the gemmStyle initializer is declared
// elsewhere - TODO confirm which positions map to subBlockDimM/N and globalWorkSizeDX/DY.
93 runInfo.gemmStyle = { 2, arg.filterSize.x, 32, 32, 2, 1 };
95 runInfo.effiency = FORCE_PRIORITY_8;
// sgemm_m: output X*Y rounded up to a multiple of subBlockDimM.
// sgemm_n: output feature count rounded up to a multiple of subBlockDimN.
98 size_t sgemm_m = RoundUp(arg.output.X().v * arg.output.Y().v, runInfo.gemmStyle.subBlockDimM);
99 size_t sgemm_n = RoundUp(arg.output.Feature().v, runInfo.gemmStyle.subBlockDimN);
// Each global size is the ceil-divided extent rounded up to a multiple of the
// corresponding local size; dimension 2 is the output batch count.
101 runInfo.gws0 = RoundUp(CeilDiv(sgemm_n, runInfo.gemmStyle.globalWorkSizeDX), runInfo.lws0);
102 runInfo.gws1 = RoundUp(CeilDiv(sgemm_m, runInfo.gemmStyle.globalWorkSizeDY), runInfo.lws1);
103 runInfo.gws2 = arg.output.Batch().v;
// Checks whether this kernel can handle the given parameters.
// NOTE(review): the `return` statements are not visible in this listing; presumably
// each condition below leads to `return false` - confirm against the original file.
108 bool ConvolutionKernel_bfyx_GEMMLike::Validate(const Params& p, const optional_params& o) const
// Generic checks first: the parent's validation and the shared convolution input check.
110 if (!Parent::Validate(p, o) ||
111 !CovolutionCheckInput(p, o))
// `p` is treated as convolution_params from here on (unchecked static_cast).
116 const auto& params = static_cast<const convolution_params&>(p);
// F16 input combined with an engine that lacks sub-group short support. This is the
// check referred to by the commented-out EnableSubGroupShort() in GetSupportedKey.
118 if (!params.engineInfo.bSubGroupShortSupport && params.inputs[0].GetDType() == Datatype::F16)
// Returns the weights layout(s) this kernel accepts for the given parameters. The
// choice depends only on the data type of the first input tensor, and exactly one
// layout is returned in each case.
126 std::vector<WeightsLayout> ConvolutionKernel_bfyx_GEMMLike::GetSupportedWeightLayouts(const convolution_params& params) const
// F16 input uses the osv16 variant of the layout.
128 if (params.inputs[0].GetDType() == Datatype::F16)
130 return{ WeightsLayout::iy_xs_os_xsv2_osv16__ao32 };
// Presumably the `else` branch (the `else` line is not visible in this listing):
// any other input type uses the osv8 variant.
134 return{ WeightsLayout::iy_xs_os_xsv2_osv8__ao32 };
// Entry point that produces the kernel data for the given parameters. All of the work
// is delegated to GetCommonKernelsData, with AGE_BASED passed as its third argument.
// NOTE(review): what AGE_BASED selects is defined elsewhere - confirm before changing it.
138 KernelsData ConvolutionKernel_bfyx_GEMMLike::GetKernelsData(const Params& params, const optional_params& options) const
140 return GetCommonKernelsData(params, options, AGE_BASED);