Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / convolution / convolution_kernel_bfyx_1x1_gemm_buf.cpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "convolution_kernel_bfyx_1x1_gemm_buf.h"
18
19 namespace kernel_selector {
20     
21     ParamsKey ConvolutionKernel_bfyx_1x1_gemm_buf::GetSupportedKey() const
22     {
23         ParamsKey k;
24         k.EnableInputDataType(Datatype::F16);
25         k.EnableOutputDataType(Datatype::F16);
26         k.EnableInputWeightsType(WeightsType::F16);
27         k.EnableInputWeightsType(WeightsType::F32);
28         k.EnableInputLayout(DataLayout::byxf);
29         k.EnableOutputLayout(DataLayout::byxf);
30         k.EnableBiasPerFeature();
31         k.EnableNonBiasTerm();
32         k.EnableBatching();
33         return k;
34     }
35
36     ConvolutionKernelBase::DispatchData ConvolutionKernel_bfyx_1x1_gemm_buf::SetDefault(const convolution_params& params, int) const
37     {
38         DispatchData kd = ConvolutionKernelBase::SetDefault(params);
39
40         const auto& out = params.output;
41
42         auto x = out.X().v;
43         auto y = out.Y().v;
44         auto f = out.Feature().v;
45         auto b = out.Batch().v;
46
47         kd.gws0 = Align(f, 16);
48         kd.gws1 = static_cast<size_t>(std::ceil(x*y / 16.0f));
49         kd.gws2 = b;
50
51         kd.lws0 = 16;
52         kd.lws1 = 1;
53         kd.lws2 = 1;
54
55         kd.effiency = FORCE_PRIORITY_1;
56
57         return kd;
58     }
59
60     bool ConvolutionKernel_bfyx_1x1_gemm_buf::Validate(const Params& p, const optional_params& o) const
61     {
62         if (!ConvolutionKernelBase::Validate(p, o))
63         {
64             return false;
65         }
66
67         const auto& params = static_cast<const convolution_params&>(p);
68
69         const auto &input = params.inputs[0];
70
71         const bool bPad = input.X().pad.Total() != 0 || input.Y().pad.Total() != 0 || input.Feature().pad.Total() != 0 || input.Batch().pad.Total() != 0;
72         const bool bFilterSize = params.filterSize.x != 1 || params.filterSize.y != 1;
73         const bool bStride = params.stride.x != 1 || params.stride.y != 1;
74
75         if(bPad || bFilterSize || bStride)
76         {
77             return false;
78         }
79
80         return true;
81     }
82
83     JitConstants ConvolutionKernel_bfyx_1x1_gemm_buf::GetJitConstants(const convolution_params& params, const DispatchData& runInfo) const
84     {
85         auto jit = Parent::GetJitConstants(params, runInfo);
86
87         const auto& out = params.output;
88         const auto& input = params.inputs[0];
89
90         auto x = out.X().v;
91         auto y = out.Y().v;
92
93         auto num_whole_groups_y = x*y / (16);
94         auto num_whole_subgroups_y = (x*y - num_whole_groups_y*16) / 16;
95         auto last_local_y = x*y - (num_whole_groups_y + num_whole_subgroups_y)*16;
96
97         jit.AddConstant(MakeJitConstant("TX", 16));
98         jit.AddConstant(MakeJitConstant("TY", 1));
99         jit.AddConstant(MakeJitConstant("M", x*y));
100         jit.AddConstant(MakeJitConstant("K", input.Feature().v));
101         jit.AddConstant(MakeJitConstant("N", out.Feature().v));
102         jit.AddConstant(MakeJitConstant("TILE_M", 16));
103         jit.AddConstant(MakeJitConstant("TILE_N", 16));
104         jit.AddConstant(MakeJitConstant("K8", (input.Feature().v >> 3)));
105         jit.AddConstant(MakeJitConstant("NUM_WHOLE_GROUPS_Y", num_whole_groups_y));
106         jit.AddConstant(MakeJitConstant("NUM_WHOLE_SUBGROUPS_Y", num_whole_subgroups_y));
107         jit.AddConstant(MakeJitConstant("LAST_LOCAL_Y", last_local_y));
108
109         return jit;
110     }
111
112     KernelsData ConvolutionKernel_bfyx_1x1_gemm_buf::GetKernelsData(const Params& params, const optional_params& options) const
113     {
114         return GetTunedKernelsDataByIndex(params, options);
115     }
116 }