2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "fused_conv_eltwise_kernel_gemm.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector {
// Declares the parameter combinations this kernel can be selected for, by
// enabling the following ParamsKey flags:
//   - F16 / F32 input data, input weights and output data types;
//   - bfyx input and output layouts;
//   - tensor offsets and tensor pitches;
//   - per-feature bias and the no-bias term;
//   - fused conv+eltwise split support.
// NOTE(review): this listing appears to have lines elided (e.g. the `k`
// ParamsKey declaration and the final `return k;`, plus some lines between the
// visible Enable* calls). Confirm against the complete file before treating the
// flags below as the full capability list.
22 ParamsKey fused_conv_eltwise_kernel_gemm::GetSupportedKey() const
25 k.EnableInputDataType(Datatype::F16);
26 k.EnableInputDataType(Datatype::F32);
27 k.EnableInputWeightsType(WeightsType::F16);
28 k.EnableInputWeightsType(WeightsType::F32);
29 k.EnableOutputDataType(Datatype::F16);
30 k.EnableOutputDataType(Datatype::F32);
31 k.EnableInputLayout(DataLayout::bfyx);
32 k.EnableOutputLayout(DataLayout::bfyx);
33 k.EnableTensorOffset();
34 k.EnableTensorPitches();
// NOTE(review): the visible Validate() does not check sub-group-short support,
// despite what the comment below says. Confirm where that check actually lives.
36 //k.EnableSubGroupShort(); // we need it for FP16 only. we check it on the Validate phase
37 k.EnableBiasPerFeature();
38 k.EnableNonBiasTerm();
40 k.EnableFusedConvEltwSplitSupport();
44 std::string fused_conv_eltwise_kernel_gemm::GetKernelName(const fused_conv_eltwise_params& params) const
46 if (params.inputs[0].GetDType() == Datatype::F32)
48 return kernelName + "_fp32";
52 return kernelName + "_fp16";
56 bool fused_conv_eltwise_kernel_gemm::Validate(const Params& p, const optional_params& o) const
58 if (!fused_conv_eltwise_kernel_base::Validate(p, o) ||
59 !FusedConvolutionEltwiseCheckInput(p, o))
64 const convolution_params& cp = static_cast<const convolution_params&>(p);
66 // make sure it's 1x1 conv
67 if (cp.filterSize.x != 1 || cp.filterSize.y != 1)
70 // make sure stride is 1x1
71 if (cp.stride.x != 1 || cp.stride.y != 1)
74 // input padding not supported
75 if (cp.inputs[0].X().pad.Total() != 0 ||
76 cp.inputs[0].Y().pad.Total() != 0 ||
77 cp.inputs[0].Feature().pad.Total() != 0 ||
78 cp.inputs[0].Batch().pad.Total() != 0)
81 // input and output spatial sizes must match
82 if (!(cp.output.X().v == cp.inputs[0].X().v) || !(cp.output.Y().v == cp.inputs[0].Y().v))
88 std::vector<WeightsLayout> fused_conv_eltwise_kernel_gemm::GetSupportedWeightLayouts(const fused_conv_eltwise_params& params) const
90 if (params.inputs[0].GetDType() == Datatype::F16)
92 return{ WeightsLayout::iy_xs_os_xsv2_osv16__ao32 };
96 return{ WeightsLayout::iy_xs_os_xsv2_osv8__ao32 };
// Builds the dispatch configuration for this GEMM-style kernel.
// It starts from Parent::SetDefault(arg). It then picks a gemmStyle tiling and a
// selection priority (`effiency`) based on the input precision. Finally it
// derives the global work sizes:
//   gws0 - output feature count rounded up to subBlockDimN, divided (ceil) by
//          globalWorkSizeDX, then rounded up to a multiple of lws0;
//   gws1 - output spatial area X*Y rounded up to subBlockDimM, divided (ceil)
//          by globalWorkSizeDY, then rounded up to a multiple of lws1;
//   gws2 - output batch size.
// The unnamed int parameter is not used in the visible lines.
// NOTE(review): lws0/lws1 are read below but not assigned in the visible lines,
// and no `return runInfo;` is visible. Lines appear to be elided from this
// listing; confirm against the complete file.
100 fused_conv_eltwise_kernel_base::DispatchData fused_conv_eltwise_kernel_gemm::SetDefault(const fused_conv_eltwise_params& arg, int) const
102 DispatchData runInfo = Parent::SetDefault(arg);
// F16 input: tiling { 1, filterSize.x, 32, 32, 1, 1 } with priority FORCE_PRIORITY_6.
107 if (arg.inputs[0].GetDType() == Datatype::F16)
109 runInfo.gemmStyle = { 1, arg.conv.filterSize.x, 32, 32, 1, 1 };
111 runInfo.effiency = FORCE_PRIORITY_6;
// Other input types (F32 is the only other one enabled in GetSupportedKey):
// tiling { 2, filterSize.x, 32, 32, 2, 1 } with priority FORCE_PRIORITY_8.
115 runInfo.gemmStyle = { 2, arg.conv.filterSize.x, 32, 32, 2, 1 };
117 runInfo.effiency = FORCE_PRIORITY_8;
// GEMM dimensions: M = output spatial area (X*Y), N = output features, each
// rounded up to the tile size of its sub-block.
120 size_t sgemm_m = RoundUp(arg.output.X().v * arg.output.Y().v, runInfo.gemmStyle.subBlockDimM);
121 size_t sgemm_n = RoundUp(arg.output.Feature().v, runInfo.gemmStyle.subBlockDimN);
123 runInfo.gws0 = RoundUp(CeilDiv(sgemm_n, runInfo.gemmStyle.globalWorkSizeDX), runInfo.lws0);
124 runInfo.gws1 = RoundUp(CeilDiv(sgemm_m, runInfo.gemmStyle.globalWorkSizeDY), runInfo.lws1);
125 runInfo.gws2 = arg.output.Batch().v;
// Extends the parent JIT constants with the GEMM tiling parameters:
//   ALIGNED_OFM        - output feature count rounded up to subBlockDimN;
//   DX / DY            - gemmStyle.globalWorkSizeDX / globalWorkSizeDY;
//   FILTER_SIZE_X_DIV2 - filter width divided by 2 (integer division);
//   INPUT_BUFFER_WIDTH_PADDED / INPUT_BUFFER_HEIGHT_PADDED - defined with an
//                        empty value (see the TODO below);
//   LEFTOVERS          - set to 1 only when the Y dispatch size was padded;
//   ELTW_STRIDE_X / ELTW_STRIDE_Y - taken from the first entry of
//                        params.eltw.stride, or 1 when that vector is empty.
// NOTE(review): the enclosing call that adds the MakeJitConstant list below and
// the final `return jit;` are not visible in this listing. Confirm against the
// complete file.
130 JitConstants fused_conv_eltwise_kernel_gemm::GetJitConstants(const fused_conv_eltwise_params& params, const DispatchData& runInfo) const
132 auto jit = Parent::GetJitConstants(params, runInfo);
135 MakeJitConstant("ALIGNED_OFM", RoundUp(params.output.Feature().v, runInfo.gemmStyle.subBlockDimN)),
136 MakeJitConstant("DX", runInfo.gemmStyle.globalWorkSizeDX),
137 MakeJitConstant("DY", runInfo.gemmStyle.globalWorkSizeDY),
138 MakeJitConstant("FILTER_SIZE_X_DIV2", params.conv.filterSize.x / 2),
139 MakeJitConstant("INPUT_BUFFER_WIDTH_PADDED", ""), // TODO: enable non padding path again
140 MakeJitConstant("INPUT_BUFFER_HEIGHT_PADDED", ""),
// Same expression as the gws1 computation in SetDefault, minus the final
// RoundUp to lws1. A nonzero remainder means gws1 includes extra padding
// work-items. Presumably LEFTOVERS lets the kernel guard them; confirm in the
// kernel source.
143 if (CeilDiv(RoundUp(params.output.X().v * params.output.Y().v, runInfo.gemmStyle.subBlockDimM), runInfo.gemmStyle.globalWorkSizeDY) % runInfo.lws1 != 0)
144 jit.AddConstant(MakeJitConstant("LEFTOVERS", 1));
// Only the first eltwise stride entry is used; default to unit stride.
146 if (!params.eltw.stride.empty())
148 jit.AddConstant(MakeJitConstant("ELTW_STRIDE_X", params.eltw.stride[0].x));
149 jit.AddConstant(MakeJitConstant("ELTW_STRIDE_Y", params.eltw.stride[0].y));
153 jit.AddConstant(MakeJitConstant("ELTW_STRIDE_X", 1));
154 jit.AddConstant(MakeJitConstant("ELTW_STRIDE_Y", 1));
// Produces the kernel data for `params` by delegating entirely to
// GetTunedKernelsDataByIndex(params, options).
// NOTE(review): presumably that base-class helper selects a tuned configuration
// by index; confirm in the base class. The closing brace of this function is
// outside the visible listing.
160 KernelsData fused_conv_eltwise_kernel_gemm::GetKernelsData(const Params& params, const optional_params& options) const
162 return GetTunedKernelsDataByIndex(params, options);