2 // Copyright (c) 2019 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "convolution_kernel_bfzyx_f16.h"
18 #include "kernel_selector_utils.h"
21 namespace kernel_selector {
// Sub-group width used by this kernel. It is used as the local work size in
// dimension 0 (SetDefault) and exported to the kernel source as SUB_GROUP_SIZE.
23 static const size_t sub_group_size = 16;
// Number of feature channels per block. Validate requires both feature counts
// and feature "pad.before" offsets to be multiples of this value.
24 static const size_t feature_block_size = 16;
// Reports the parameter combinations this kernel supports:
//  - F32 and F16 input/output data types and F32/F16 weights;
//  - the bfzyx_f16 layout for both input and output;
//  - tensor offsets and pitches;
//  - per-feature bias, or no bias term;
//  - split support;
//  - short sub-group operations.
26 ParamsKey ConvolutionKernel_bfzyx_f16::GetSupportedKey() const {
28 k.EnableInputDataType(Datatype::F32);
29 k.EnableOutputDataType(Datatype::F32);
30 k.EnableInputDataType(Datatype::F16);
31 k.EnableOutputDataType(Datatype::F16);
32 k.EnableInputWeightsType(WeightsType::F32);
33 k.EnableInputWeightsType(WeightsType::F16);
// Blocked 5D layout on both sides; Validate additionally requires the feature
// dimension to be a multiple of feature_block_size.
34 k.EnableInputLayout(DataLayout::bfzyx_f16);
35 k.EnableOutputLayout(DataLayout::bfzyx_f16);
36 k.EnableTensorOffset();
37 k.EnableTensorPitches();
38 k.EnableBiasPerFeature();
39 k.EnableNonBiasTerm();
// The split value is passed to the kernel as the G constant in GetJitConstants.
40 k.EnableSplitSupport();
43 k.EnableSubGroupShort();
// Computes the dispatch configuration for this kernel, starting from the
// base-class defaults for `params` / `autoTuneIndex`:
//  - cldnnStyle.blockWidth records the output-width block (ow_block, never
//    smaller than 8); GetJitConstants reads it back as OW_BLOCK.
//  - gws1 covers the output Y/X plane in (oh_block x ow_block) tiles, times Z.
//  - gws2 covers batch times the number of ocb-sized output-feature groups.
//  - lws0 is the sub-group size.
// The kernel priority (`effiency`) is assigned either FORCE_PRIORITY_2 or
// FORCE_PRIORITY_7. NOTE(review): confirm the branch condition selecting
// between them against the full function.
47 ConvolutionKernelBase::DispatchData ConvolutionKernel_bfzyx_f16::SetDefault(const convolution_params& params,
48 int autoTuneIndex) const {
49 DispatchData kd = ConvolutionKernelBase::SetDefault(params, autoTuneIndex);
51 const auto& out = params.output;
56 auto f = out.Feature().v;
57 auto b = out.Batch().v;
// Output-width block size: std::max clamps it to at least 8 output columns.
67 auto ow_block = std::max(8, div);
// Stored so GetJitConstants can emit OW_BLOCK / OW_LAST / OWB from the same value.
77 kd.cldnnStyle.blockWidth = ow_block;
// Number of spatial tiles: ceil(Y / oh_block) * ceil(X / ow_block) * Z.
80 kd.gws1 = CeilDiv(y, oh_block) * CeilDiv(x, ow_block) * z;
// NOTE(review): f / ocb is integer division; this assumes the output feature
// count is a multiple of ocb, otherwise trailing features get no work items.
// Confirm how ocb is chosen.
81 kd.gws2 = b * (f / ocb);
83 kd.lws0 = sub_group_size;
88 kd.effiency = FORCE_PRIORITY_2;
90 kd.effiency = FORCE_PRIORITY_7;
// Checks whether the convolution described by `p` / `o` can run on this kernel.
// Besides the base-class validation and CovolutionCheckInput, it requires:
//  - the output data type to equal use_data_type;
//  - input and output feature counts to be multiples of feature_block_size;
//  - input and output feature "pad.before" offsets to be multiples of
//    feature_block_size, so padding does not shift the feature blocks.
// NOTE(review): each failing check presumably makes Validate return false — confirm.
95 bool ConvolutionKernel_bfzyx_f16::Validate(const Params& p, const optional_params& o) const {
96 if (!ConvolutionKernelBase::Validate(p, o) || !CovolutionCheckInput(p, o)) {
100 const auto& params = static_cast<const convolution_params&>(p);
102 const auto& input = params.inputs[0];
103 const auto& output = params.output;
105 if (output.GetDType() != use_data_type)
108 if (output.Feature().v % feature_block_size != 0)
111 if (input.Feature().v % feature_block_size != 0)
114 // Check that padding before the feature dimension doesn't misalign the feature blocks
115 if (input.Feature().pad.before % feature_block_size != 0 || output.Feature().pad.before % feature_block_size != 0) {
// Builds the JIT constants for the kernel source. It starts from the parent
// class's constants and then adds:
//  - fixed variant switches (VER_8OW16C, OC_BLOCK = 16, NCHW, CASE_3D);
//  - local work sizes (LWS_0..2) and the output-channel block OCB (= gws0);
//  - blocking parameters derived from runInfo.cldnnStyle.blockWidth;
//  - short alias macros (MB, OC, OD, ...) that map onto the standard
//    OUTPUT_* / INPUT0_* / FILTER_* / STRIDE_* / PADDING_* names.
122 JitConstants ConvolutionKernel_bfzyx_f16::GetJitConstants(const convolution_params& params,
123 const DispatchData& runInfo) const {
124 auto input = params.inputs[0];
125 auto output = params.output;
126 auto jit = Parent::GetJitConstants(params, runInfo);
128 jit.AddConstant(MakeJitConstant("VER_8OW16C", 1));
129 jit.AddConstant(MakeJitConstant("OC_BLOCK", 16));
130 jit.AddConstant(MakeJitConstant("NCHW", 1));
131 jit.AddConstant(MakeJitConstant("CASE_3D", 1));
133 jit.AddConstant(MakeJitConstant("LWS_0", runInfo.lws0));
134 jit.AddConstant(MakeJitConstant("LWS_1", runInfo.lws1));
135 jit.AddConstant(MakeJitConstant("LWS_2", runInfo.lws2));
// The output-channel block is taken from global work size dimension 0.
137 jit.AddConstant(MakeJitConstant("OCB", runInfo.gws0));
139 jit.AddConstant(MakeJitConstant("SUM_SCALE", 1));
// Output-width block chosen in SetDefault.
141 auto blockWidth = runInfo.cldnnStyle.blockWidth;
142 // The conditional code below was replaced with fixed values to fix a security issue.
143 // auto is_1stconv = false;
144 // auto mb_block =(is_1stconv && output.Batch().v % 16 == 0) ? 16 : 1;
145 // auto ic_block = (is_1stconv) ? 1 : 16;
// With is_1stconv == false, the removed code above would give mb_block = 1
// and ic_block = 16. NOTE(review): confirm the replacement values match.
149 jit.AddConstant(MakeJitConstant("MB_BLOCK", mb_block));
// Batch size rounded down to a multiple of 16.
150 jit.AddConstant(MakeJitConstant("MB_LAST", (output.Batch().v / 16) * 16));
151 jit.AddConstant(MakeJitConstant("IC_BLOCK", ic_block));
152 jit.AddConstant(MakeJitConstant("OH_BLOCK", 1));
153 jit.AddConstant(MakeJitConstant("OW_BLOCK", blockWidth));
// Output width rounded down to a multiple of blockWidth (start of the tail block).
154 jit.AddConstant(MakeJitConstant("OW_LAST", (output.X().v / blockWidth) * blockWidth));
// Block counts along X and Y; since OH_BLOCK is 1, OHB equals the output height.
155 jit.AddConstant(MakeJitConstant("OWB", CeilDiv(output.X().v, blockWidth)));
156 jit.AddConstant(MakeJitConstant("OHB", CeilDiv(output.Y().v, 1)));
157 jit.AddConstant(MakeJitConstant("G", params.split));
// Dilation is exported minus one, so 0 means an undilated filter.
158 jit.AddConstant(MakeJitConstant("DD", params.dilation.z - 1));
159 jit.AddConstant(MakeJitConstant("DH", params.dilation.y - 1));
160 jit.AddConstant(MakeJitConstant("DW", params.dilation.x - 1));
161 jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", sub_group_size));
162 jit.AddConstant(MakeJitConstant("FWD_DATA", 1));
163 jit.AddConstant(MakeJitConstant("IS_DW", "DEPTHWISE_SEPARABLE_OPT"));
164 jit.AddConstant(MakeJitConstant("WITH_BIAS", "BIAS_TERM"));
// Short aliases for the standard tensor/filter/stride/padding JIT names.
166 jit.AddConstant(MakeJitConstant("MB", "OUTPUT_BATCH_NUM"));
167 jit.AddConstant(MakeJitConstant("OC", "OUTPUT_FEATURE_NUM"));
168 jit.AddConstant(MakeJitConstant("OD", "OUTPUT_SIZE_Z"));
169 jit.AddConstant(MakeJitConstant("OH", "OUTPUT_SIZE_Y"));
170 jit.AddConstant(MakeJitConstant("OW", "OUTPUT_SIZE_X"));
171 jit.AddConstant(MakeJitConstant("IC", "INPUT0_FEATURE_NUM"));
172 jit.AddConstant(MakeJitConstant("ID", "INPUT0_SIZE_Z"));
173 jit.AddConstant(MakeJitConstant("IH", "INPUT0_SIZE_Y"));
174 jit.AddConstant(MakeJitConstant("IW", "INPUT0_SIZE_X"));
175 jit.AddConstant(MakeJitConstant("KD", "FILTER_SIZE_Z"));
176 jit.AddConstant(MakeJitConstant("KH", "FILTER_SIZE_Y"));
177 jit.AddConstant(MakeJitConstant("KW", "(FILTER_SIZE_X)"));
178 jit.AddConstant(MakeJitConstant("SD", "STRIDE_SIZE_Z"));
179 jit.AddConstant(MakeJitConstant("SH", "STRIDE_SIZE_Y"));
180 jit.AddConstant(MakeJitConstant("SW", "STRIDE_SIZE_X"));
181 jit.AddConstant(MakeJitConstant("PD", "PADDING_SIZE_Z"));
182 jit.AddConstant(MakeJitConstant("PH", "PADDING_SIZE_Y"));
183 jit.AddConstant(MakeJitConstant("PW", "PADDING_SIZE_X"));
// NOTE(review): the right/back padding aliases reuse the same PADDING_SIZE_*
// macros as the left/front padding. Asymmetric padding is therefore not
// expressed here — confirm this is intended.
184 jit.AddConstant(MakeJitConstant("PD_R", "PADDING_SIZE_Z"));
185 jit.AddConstant(MakeJitConstant("PH_R", "PADDING_SIZE_Y"));
186 jit.AddConstant(MakeJitConstant("PW_R", "PADDING_SIZE_X"));
// Produces the kernel data for `params` / `options` by delegating to the
// tuned-by-index helper. No explicit autotune index is passed here;
// presumably the helper's default index applies — confirm.
191 KernelsData ConvolutionKernel_bfzyx_f16::GetKernelsData(const Params& params, const optional_params& options) const {
192 return GetTunedKernelsDataByIndex(params, options);
194 } // namespace kernel_selector