2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "convolution_kernel_bfyx_gemm_like.h"
19 namespace kernel_selector
// Builds the ParamsKey advertised by this kernel:
//   - F16/F32 input, weights and output types;
//   - bfyx input and output layouts;
//   - tensor offsets and pitches;
//   - per-feature bias or no bias term;
//   - split support.
// NOTE(review): this listing is incomplete. The embedded original line
// numbers skip (23-24, 35, 39, 41-43). `k` is used below but is never
// declared or returned in the visible lines. So the declaration of `k`, the
// return statement, and possibly further Enable* calls are missing here.
// Restore them from version control before editing this function.
22 ParamsKey ConvolutionKernel_bfyx_GEMMLike::GetSupportedKey() const
25 k.EnableInputDataType(Datatype::F16);
26 k.EnableInputDataType(Datatype::F32);
27 k.EnableInputWeightsType(WeightsType::F16);
28 k.EnableInputWeightsType(WeightsType::F32);
29 k.EnableOutputDataType(Datatype::F16);
30 k.EnableOutputDataType(Datatype::F32);
31 k.EnableInputLayout(DataLayout::bfyx);
32 k.EnableOutputLayout(DataLayout::bfyx);
33 k.EnableTensorOffset();
34 k.EnableTensorPitches();
// Sub-group-short support is deliberately not required by the key.
// Only the FP16 path needs it, so Validate() rejects F16 input at runtime
// when the engine lacks it. This keeps the F32 path usable on such engines.
36 //k.EnableSubGroupShort(); // we need it for FP16 only. we check it on the Validate phase
37 k.EnableBiasPerFeature();
38 k.EnableNonBiasTerm();
40 k.EnableSplitSupport();
// Returns the kernel entry name with a precision suffix:
//   - kernelName + "_fp32" when the first input is F32;
//   - kernelName + "_fp16" otherwise. GetSupportedKey only enables F16 and
//     F32 inputs, so in practice "otherwise" means F16.
44 std::string ConvolutionKernel_bfyx_GEMMLike::GetKernelName(const convolution_params& params) const
46 if (params.inputs[0].GetDType() == Datatype::F32)
48 return kernelName + "_fp32";
52 return kernelName + "_fp16";
// Extends the parent's JIT constants with the GEMM tiling parameters taken
// from runInfo.gemmStyle:
//   ALIGNED_OFM          - output feature count rounded up to subBlockDimN
//   DX / DY              - gemmStyle.globalWorkSizeDX / globalWorkSizeDY
//   FILTER_SIZE_X_DIV2   - filterSize.x / 2 (integer division)
//   INPUT_BUFFER_WIDTH_PADDED / INPUT_BUFFER_HEIGHT_PADDED
//                        - defined with an empty value (the non-padding
//                          path is disabled, see the TODO below)
// LEFTOVERS is also defined conditionally; see the comment on that check.
// NOTE(review): the statement that the MakeJitConstant list below belongs to
// is not present in this listing, and neither is the function's return.
// The embedded line numbers skip 59-60, 67-68 and 71-74. Restore those lines
// from version control before editing.
56 JitConstants ConvolutionKernel_bfyx_GEMMLike::GetJitConstants(const convolution_params& params, const DispatchData& runInfo) const
58 JitConstants jit = Parent::GetJitConstants(params, runInfo);
61 MakeJitConstant("ALIGNED_OFM", RoundUp(params.output.Feature().v, runInfo.gemmStyle.subBlockDimN)),
62 MakeJitConstant("DX", runInfo.gemmStyle.globalWorkSizeDX),
63 MakeJitConstant("DY", runInfo.gemmStyle.globalWorkSizeDY),
64 MakeJitConstant("FILTER_SIZE_X_DIV2", params.filterSize.x / 2),
65 MakeJitConstant("INPUT_BUFFER_WIDTH_PADDED", ""), // TODO: enable non padding path again
66 MakeJitConstant("INPUT_BUFFER_HEIGHT_PADDED", ""),
// Defines LEFTOVERS=1 when ceil(M / globalWorkSizeDY) is not a multiple of
// lws1. Here M is output X*Y rounded up to subBlockDimM, the same quantity
// SetDefault() calls sgemm_m. SetDefault() sets gws1 by rounding that same
// count up to lws1. So, assuming runInfo came from SetDefault(), LEFTOVERS
// marks launches whose dimension-1 size was padded past the needed count.
// Presumably the kernel uses it to guard those extra work items. TODO confirm
// against the .cl source.
69 if (CeilDiv(RoundUp(params.output.X().v * params.output.Y().v, runInfo.gemmStyle.subBlockDimM), runInfo.gemmStyle.globalWorkSizeDY) % runInfo.lws1 != 0)
70 jit.AddConstant(MakeJitConstant("LEFTOVERS", 1));
// Builds dispatch data for the GEMM-like kernel, starting from the parent's
// defaults for this autotune index.
//
// Tiling is chosen by input precision. F16 input selects the first gemmStyle
// initializer and FORCE_PRIORITY_6. The second assignment selects another
// initializer and FORCE_PRIORITY_8; presumably it is the else/F32 branch, but
// its `else` line is not in this listing. The meaning of each positional
// gemmStyle field is defined by DispatchData, which is not shown here.
//
// The launch is then sized as for a GEMM with
//   M = output X*Y rounded up to gemmStyle.subBlockDimM     (sgemm_m)
//   N = output features rounded up to gemmStyle.subBlockDimN (sgemm_n)
// and
//   gws0 = ceil(N / globalWorkSizeDX), rounded up to lws0
//   gws1 = ceil(M / globalWorkSizeDY), rounded up to lws1
//   gws2 = batch size
// NOTE(review): lws0 and lws1 are read below but never assigned in the
// visible lines, and the function's return is not shown. The embedded line
// numbers skip 76, 78-81, 83, 85, 87-89, 91, 93-94, 97 and everything after
// 100, so those lines are missing from this listing (they may set lws0/lws1
// per precision). Restore them from version control before editing.
75 ConvolutionKernel_bfyx_GEMMLike::Parent::DispatchData ConvolutionKernel_bfyx_GEMMLike::SetDefault(const convolution_params& arg, int autoTuneIndex) const
77 DispatchData runInfo = Parent::SetDefault(arg, autoTuneIndex);
82 if (arg.inputs[0].GetDType() == Datatype::F16)
84 runInfo.gemmStyle = { 1, arg.filterSize.x, 32, 32, 1, 1 };
86 runInfo.effiency = FORCE_PRIORITY_6;
90 runInfo.gemmStyle = { 2, arg.filterSize.x, 32, 32, 2, 1 };
92 runInfo.effiency = FORCE_PRIORITY_8;
95 size_t sgemm_m = RoundUp(arg.output.X().v * arg.output.Y().v, runInfo.gemmStyle.subBlockDimM);
96 size_t sgemm_n = RoundUp(arg.output.Feature().v, runInfo.gemmStyle.subBlockDimN);
98 runInfo.gws0 = RoundUp(CeilDiv(sgemm_n, runInfo.gemmStyle.globalWorkSizeDX), runInfo.lws0);
99 runInfo.gws1 = RoundUp(CeilDiv(sgemm_m, runInfo.gemmStyle.globalWorkSizeDY), runInfo.lws1);
100 runInfo.gws2 = arg.output.Batch().v;
// Checks whether this kernel can handle the given params. Presumably it
// returns false when either of these holds:
//   - the parent validation or CovolutionCheckInput fails;
//   - the input is F16 and the engine does not report sub-group-short
//     support. The FP16 path needs it, and GetSupportedKey does not require
//     it, so it is checked here instead.
// NOTE(review): the return statements are not present in this listing. The
// embedded line numbers skip 106, 109-112 and 114, and the function's end is
// not shown. Restore them from version control before editing.
105 bool ConvolutionKernel_bfyx_GEMMLike::Validate(const Params& p, const optional_params& o) const
107 if (!Parent::Validate(p, o) ||
108 !CovolutionCheckInput(p, o))
// The downcast is made only after the parent and input checks above.
113 const auto& params = static_cast<const convolution_params&>(p);
115 if (!params.engineInfo.bSubGroupShortSupport && params.inputs[0].GetDType() == Datatype::F16)
// Returns the single weights layout this kernel accepts, chosen by input
// precision:
//   - iy_xs_os_xsv2_osv16__ao32 for F16 input;
//   - iy_xs_os_xsv2_osv8__ao32 otherwise (F32, given GetSupportedKey).
123 std::vector<WeightsLayout> ConvolutionKernel_bfyx_GEMMLike::GetSupportedWeightLayouts(const convolution_params& params) const
125 if (params.inputs[0].GetDType() == Datatype::F16)
127 return{ WeightsLayout::iy_xs_os_xsv2_osv16__ao32 };
131 return{ WeightsLayout::iy_xs_os_xsv2_osv8__ao32 };
// Produces this kernel's KernelsData by delegating to
// GetTunedKernelsDataByIndex with the given params and options. No explicit
// tuning index is passed here.
135 KernelsData ConvolutionKernel_bfyx_GEMMLike::GetKernelsData(const Params& params, const optional_params& options) const
137 return GetTunedKernelsDataByIndex(params, options);