2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "convolution_kernel_bfyx_1x1_gemm_buf.h"
19 namespace kernel_selector {
21 ParamsKey ConvolutionKernel_bfyx_1x1_gemm_buf::GetSupportedKey() const
24 k.EnableInputDataType(Datatype::F16);
25 k.EnableOutputDataType(Datatype::F16);
26 k.EnableInputWeightsType(WeightsType::F16);
27 k.EnableInputWeightsType(WeightsType::F32);
28 k.EnableInputLayout(DataLayout::byxf);
29 k.EnableOutputLayout(DataLayout::byxf);
30 k.EnableBiasPerFeature();
31 k.EnableNonBiasTerm();
36 ConvolutionKernelBase::DispatchData ConvolutionKernel_bfyx_1x1_gemm_buf::SetDefault(const convolution_params& params, int) const
38 DispatchData kd = ConvolutionKernelBase::SetDefault(params);
40 const auto& out = params.output;
44 auto f = out.Feature().v;
45 auto b = out.Batch().v;
47 kd.gws0 = Align(f, 16);
48 kd.gws1 = static_cast<size_t>(std::ceil(x*y / 16.0f));
55 kd.effiency = FORCE_PRIORITY_1;
60 bool ConvolutionKernel_bfyx_1x1_gemm_buf::Validate(const Params& p, const optional_params& o) const
62 if (!ConvolutionKernelBase::Validate(p, o))
67 const auto& params = static_cast<const convolution_params&>(p);
69 const auto &input = params.inputs[0];
71 const bool bPad = input.X().pad.Total() != 0 || input.Y().pad.Total() != 0 || input.Feature().pad.Total() != 0 || input.Batch().pad.Total() != 0;
72 const bool bFilterSize = params.filterSize.x != 1 || params.filterSize.y != 1;
73 const bool bStride = params.stride.x != 1 || params.stride.y != 1;
75 if(bPad || bFilterSize || bStride)
83 JitConstants ConvolutionKernel_bfyx_1x1_gemm_buf::GetJitConstants(const convolution_params& params, const DispatchData& runInfo) const
85 auto jit = Parent::GetJitConstants(params, runInfo);
87 const auto& out = params.output;
88 const auto& input = params.inputs[0];
93 auto num_whole_groups_y = x*y / (16);
94 auto num_whole_subgroups_y = (x*y - num_whole_groups_y*16) / 16;
95 auto last_local_y = x*y - (num_whole_groups_y + num_whole_subgroups_y)*16;
97 jit.AddConstant(MakeJitConstant("TX", 16));
98 jit.AddConstant(MakeJitConstant("TY", 1));
99 jit.AddConstant(MakeJitConstant("M", x*y));
100 jit.AddConstant(MakeJitConstant("K", input.Feature().v));
101 jit.AddConstant(MakeJitConstant("N", out.Feature().v));
102 jit.AddConstant(MakeJitConstant("TILE_M", 16));
103 jit.AddConstant(MakeJitConstant("TILE_N", 16));
104 jit.AddConstant(MakeJitConstant("K8", (input.Feature().v >> 3)));
105 jit.AddConstant(MakeJitConstant("NUM_WHOLE_GROUPS_Y", num_whole_groups_y));
106 jit.AddConstant(MakeJitConstant("NUM_WHOLE_SUBGROUPS_Y", num_whole_subgroups_y));
107 jit.AddConstant(MakeJitConstant("LAST_LOCAL_Y", last_local_y));
112 KernelsData ConvolutionKernel_bfyx_1x1_gemm_buf::GetKernelsData(const Params& params, const optional_params& options) const
114 return GetTunedKernelsDataByIndex(params, options);