2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "fused_conv_eltwise_kernel_mmad_32x32sg_224x128wg_slm_int8.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector {
// Sub-group tiling parameters shared by SetDefault() and GetJitConstants().
// NOTE(review): identifiers that start with an underscore followed by an
// uppercase letter are reserved in C++; rename together with every user.
constexpr size_t _SG_TILE_M = 32;     // GEMM rows (M) covered by one sub-group tile
constexpr size_t _SG_TILE_N = 32;     // GEMM columns (N) covered by one sub-group tile
constexpr size_t _SG_SIZE = 8;        // sub group size
constexpr size_t _TILES_PER_SG_X = 1; // Persistent threads: tiles per sub-group along X
constexpr size_t _TILES_PER_SG_Y = 1; // Persistent threads: tiles per sub-group along Y
// Builds the ParamsKey describing the configurations this kernel supports:
// INT8 input, output and weights; fs_bs_yx_bsv4_fsv32 input and output
// layouts; tensor offsets and pitches; per-feature bias; and the fused
// conv+eltwise Int8Quantization, OutputCalibration and RWOutOpt features.
28 ParamsKey fused_conv_eltwise_kernel_mmad_32x32sg_224x128wg_slm_int8::GetSupportedKey() const
31 k.EnableInputDataType(Datatype::INT8);
32 k.EnableOutputDataType(Datatype::INT8);
33 k.EnableInputWeightsType(WeightsType::INT8);
34 k.EnableInputLayout(DataLayout::fs_bs_yx_bsv4_fsv32);
35 k.EnableOutputLayout(DataLayout::fs_bs_yx_bsv4_fsv32);
36 k.EnableTensorOffset();
37 k.EnableTensorPitches();
38 k.EnableBiasPerFeature();
// Capabilities specific to the fused convolution + eltwise primitive.
40 k.EnableFusedConvEltwInt8Quantization();
41 k.EnableFusedConvEltwOutputCalibration();
43 k.EnableFusedConvEltwiseRWOutOpt();
47 bool fused_conv_eltwise_kernel_mmad_32x32sg_224x128wg_slm_int8::Validate(const Params& p, const optional_params& o) const
49 if (!fused_conv_eltwise_kernel_base::Validate(p, o) ||
50 !FusedConvolutionEltwiseCheckInput(p, o))
55 const convolution_params& cp = static_cast<const convolution_params&>(p);
57 // make sure it's 1x1 conv
58 if (cp.filterSize.x != 1 || cp.filterSize.y != 1)
61 // make sure stride is 1x1
62 if (cp.stride.x != 1 || cp.stride.y != 1)
65 // input padding not supported
66 if (cp.inputs[0].X().pad.Total() != 0 ||
67 cp.inputs[0].Y().pad.Total() != 0 ||
68 cp.inputs[0].Feature().pad.Total() != 0 ||
69 cp.inputs[0].Batch().pad.Total() != 0)
72 // input and output spatial sizes must match
73 if (!(cp.output.X().v == cp.inputs[0].X().v) || !(cp.output.Y().v == cp.inputs[0].Y().v))
76 const auto m = cp.output.X().v * cp.output.Y().v * cp.output.Batch().v ;
77 const auto k = cp.inputs[0].Feature().v;
78 const auto n = cp.output.Feature().v ;
80 if (m % 32 != 0 && m % 224 != 0) // Matrix size M, Must be mutliple of 32 and multiple of WG_TILE_M=128
83 if (k % 32 != 0) // Matrix size K, Must be mutliple of 32
86 if (n % 32 != 0 && n % 128 != 0) // Matrix size N, Must be mutliple of 32 and multiple of WG_TILE_N=128
93 fused_conv_eltwise_kernel_base::DispatchData fused_conv_eltwise_kernel_mmad_32x32sg_224x128wg_slm_int8::SetDefault(const fused_conv_eltwise_params& arg, int) const
95 DispatchData runInfo = fused_conv_eltwise_kernel_base::SetDefault(arg);
97 runInfo.effiency = FORCE_PRIORITY_1;
99 size_t mat_m = arg.output.X().v * arg.output.Y().v * arg.output.Batch().v;
100 size_t mat_n = arg.output.Feature().v;
102 size_t _MATRIX_M = mat_m;
103 size_t _MATRIX_N = mat_n;
105 size_t _WG_TILE_M = 224;
106 size_t _WG_TILE_N = 128;
108 // Calculate number of threads needed
109 const size_t threadsX = (_MATRIX_N / (_SG_TILE_N / _SG_SIZE)) / _TILES_PER_SG_X;
110 const size_t threadsY = (_MATRIX_M / _SG_TILE_M) / _TILES_PER_SG_Y ;
112 // Define execution setup for kernel:
113 size_t globalWorkSize[3] = { threadsX, threadsY, 1 };
114 size_t localWorkSize[3] = { _SG_SIZE * _WG_TILE_N / _SG_TILE_N, _WG_TILE_M / _SG_TILE_M, 1 };
116 runInfo.gws0 = globalWorkSize[0];
117 runInfo.gws1 = globalWorkSize[1];
118 runInfo.gws2 = globalWorkSize[2];
120 runInfo.lws0 = localWorkSize[0];
121 runInfo.lws1 = localWorkSize[1];
122 runInfo.lws2 = localWorkSize[2];
// Extends the parent's JIT constants with what this GEMM-style kernel needs:
// - work-group and sub-group tiling;
// - the matrix sizes M/K/N;
// - output block pitches and offset;
// - second-input block pitches and offset, when second_input_in_output is false;
// - padding flags and the eltwise stride.
127 JitConstants fused_conv_eltwise_kernel_mmad_32x32sg_224x128wg_slm_int8::GetJitConstants(const fused_conv_eltwise_params& params, const DispatchData& runInfo) const
129 auto jit = Parent::GetJitConstants(params, runInfo);
// Tiling configuration. WG_TILE_M/WG_TILE_N (224/128) must agree with the local work size chosen in SetDefault.
131 jit.AddConstant(MakeJitConstant("WG_TILE_M", 224)); // Work-group tile size M, must be a multiple of 32
132 jit.AddConstant(MakeJitConstant("WG_TILE_N", 128)); // Work-group tile size N, must be a multiple of 32
133 jit.AddConstant(MakeJitConstant("TILES_PER_SG_X", _TILES_PER_SG_X));
134 jit.AddConstant(MakeJitConstant("TILES_PER_SG_Y", _TILES_PER_SG_Y));
136 // Do not change values below
137 jit.AddConstant(MakeJitConstant("DIM_X", 0));
138 jit.AddConstant(MakeJitConstant("DIM_Y", 1));
139 jit.AddConstant(MakeJitConstant("MATRIX_SMALL_K", 32));
140 jit.AddConstant(MakeJitConstant("MATRIX_SMALL_K_BFLOAT", 16));
141 jit.AddConstant(MakeJitConstant("SG_TILE_M", _SG_TILE_M));
142 jit.AddConstant(MakeJitConstant("SG_TILE_N", _SG_TILE_N));
143 jit.AddConstant(MakeJitConstant("SG_SIZE", _SG_SIZE));
144 jit.AddConstant(MakeJitConstant("SIMD_LANE_M", "SG_TILE_M"));
145 jit.AddConstant(MakeJitConstant("SIMD_LANE_N", "(SG_TILE_N / SG_SIZE)"));
146 jit.AddConstant(MakeJitConstant("WG_SIZE", "(SG_SIZE * WG_TILE_N / SG_TILE_N) * (WG_TILE_M / SG_TILE_M)"));
// Flags defined with an empty value; presumably only their presence is tested by the kernel source.
148 jit.AddConstant(MakeJitConstant("COMPILE_KERNELS", ""));
149 jit.AddConstant(MakeJitConstant("TILED_GLOBAL_LAYOUT", ""));
150 jit.AddConstant(MakeJitConstant("OUTPUT_TILED_GLOBAL_LAYOUT", ""));
// GEMM view of the convolution: M = output X*Y*batch, K = input features, N = output features.
152 const auto& input = params.inputs[0];
153 const auto& output = params.output;
155 auto m = output.X().v * output.Y().v * output.Batch().v;
156 auto k = input.Feature().v;
157 auto n = output.Feature().v;
159 jit.AddConstant(MakeJitConstant("MATRIX_M", m)); // Matrix size M, must be a multiple of 32 and of WG_TILE_M
160 jit.AddConstant(MakeJitConstant("MATRIX_K", k)); // Matrix size K, must be a multiple of 32
161 jit.AddConstant(MakeJitConstant("MATRIX_N", n)); // Matrix size N, must be a multiple of 32 and of WG_TILE_N
// Output addressing.
// - Each X step advances 32 * 4 elements; presumably one 32-feature x 4-batch
//   block of the fs_bs_yx_bsv4_fsv32 layout (confirm).
// - Batch is rounded up to blocks of 4 when computing the feature-block pitch.
// - Only the X/Y "before" padding contributes to the offset.
163 const size_t out_x_pitch = 32 * 4;
164 const size_t out_y_pitch = 32 * 4 * params.output.X().LogicalDimPadded();
165 const size_t out_b_block_pitch = out_y_pitch * params.output.Y().LogicalDimPadded();
166 const size_t out_f_block_pitch = out_b_block_pitch * ((params.output.Batch().v + 3) / 4);
167 const size_t out_offset = out_x_pitch * params.output.X().pad.before + out_y_pitch * params.output.Y().pad.before;
169 jit.AddConstant(MakeJitConstant("OUT_X_PITCH", out_x_pitch));
170 jit.AddConstant(MakeJitConstant("OUT_Y_PITCH", out_y_pitch));
171 jit.AddConstant(MakeJitConstant("OUT_B_BLOCK_PITCH", out_b_block_pitch));
172 jit.AddConstant(MakeJitConstant("OUT_F_BLOCK_PITCH", out_f_block_pitch));
173 jit.AddConstant(MakeJitConstant("OUT_OFFSET", out_offset));
// OUT_WITH_PADDING is set when the output has any X or Y padding.
175 bool out_padding = output.X().pad.Total() != 0 || output.Y().pad.Total() != 0;
176 jit.AddConstant(MakeJitConstant("OUT_WITH_PADDING", out_padding));
// Second eltwise operand.
// - When second_input_in_output is false, inputs[1] gets its own IN2_* pitches
//   and offset (computed the same way as for the output), and
//   ELTW_WITH_PADDING reflects inputs[1]'s X/Y padding.
// - Otherwise ELTW_WITH_PADDING mirrors the output's padding.
178 bool eltw_padding = false;
179 if (!params.second_input_in_output)
182 const size_t in2_x_pitch = 32 * 4;
183 const size_t in2_y_pitch = 32 * 4 * params.inputs[1].X().LogicalDimPadded();
184 const size_t in2_b_block_pitch = in2_y_pitch * params.inputs[1].Y().LogicalDimPadded();
185 const size_t in2_f_block_pitch = in2_b_block_pitch * ((params.inputs[1].Batch().v + 3) / 4);
186 const size_t in2_offset = in2_x_pitch * params.inputs[1].X().pad.before + in2_y_pitch * params.inputs[1].Y().pad.before;
188 jit.AddConstant(MakeJitConstant("IN2_X_PITCH", in2_x_pitch));
189 jit.AddConstant(MakeJitConstant("IN2_Y_PITCH", in2_y_pitch));
190 jit.AddConstant(MakeJitConstant("IN2_B_BLOCK_PITCH", in2_b_block_pitch));
191 jit.AddConstant(MakeJitConstant("IN2_F_BLOCK_PITCH", in2_f_block_pitch));
192 jit.AddConstant(MakeJitConstant("IN2_OFFSET", in2_offset));
// NOTE(review): the next line ends with a stray extra ';' (empty statement) — remove it.
194 eltw_padding = params.inputs[1].X().pad.Total() != 0 || params.inputs[1].Y().pad.Total() != 0;;
198 eltw_padding = out_padding;
201 jit.AddConstant(MakeJitConstant("ELTW_WITH_PADDING", eltw_padding));
// Eltwise stride comes from the first entry of params.eltw.stride. It defaults to 1x1 when none is given.
203 if (!params.eltw.stride.empty())
205 jit.AddConstant(MakeJitConstant("ELTW_STRIDE_X", params.eltw.stride[0].x));
206 jit.AddConstant(MakeJitConstant("ELTW_STRIDE_Y", params.eltw.stride[0].y));
210 jit.AddConstant(MakeJitConstant("ELTW_STRIDE_X", 1));
211 jit.AddConstant(MakeJitConstant("ELTW_STRIDE_Y", 1));
// Produces this kernel's KernelsData through the shared GetCommonKernelsData() path.
// Then overrides the first entry's estimatedTime with FORCE_PRIORITY_1.
// NOTE(review): `kd[0]` assumes GetCommonKernelsData returned at least one entry
// (presumably it can return none when Validate fails). Confirm an emptiness
// guard precedes this access.
217 KernelsData fused_conv_eltwise_kernel_mmad_32x32sg_224x128wg_slm_int8::GetKernelsData(const Params& params, const optional_params& options) const
219 KernelsData kd = GetCommonKernelsData(params, options);
// NOTE(review): the trailing `//_3` looks like a leftover of an earlier FORCE_PRIORITY_3 value — confirm and remove.
221 kd[0].estimatedTime = FORCE_PRIORITY_1; //_3