// File: inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_bfyx_gemm_like.cpp
/*
// Copyright (c) 2016 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

#include <cmath>
#include "convolution_kernel_bfyx_gemm_like.h"
#include "kernel_selector_utils.h"
#include "common_tools.h"

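// "GEMM-like" means the convolution is dispatched as an implicit matrix
// multiply: sgemm_m below spans the output pixels (X*Y) and sgemm_n the
// output features, tiled by the sub-block dimensions held in gemmStyle.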
namespace kernel_selector
{

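    // Declares the data types, tensor layouts, and features this kernel supports,
    // so the kernel selector can match it against a convolution's parameters.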
    ParamsKey ConvolutionKernel_bfyx_GEMMLike::GetSupportedKey() const
    {
        ParamsKey k;
        k.EnableInputDataType(Datatype::F16);
        k.EnableInputDataType(Datatype::F32);
        k.EnableInputWeightsType(WeightsType::F16);
        k.EnableInputWeightsType(WeightsType::F32);
        k.EnableOutputDataType(Datatype::F16);
        k.EnableOutputDataType(Datatype::F32);
        k.EnableInputLayout(DataLayout::bfyx);
        k.EnableOutputLayout(DataLayout::bfyx);
        k.EnableTensorOffset();
        k.EnableTensorPitches();
        k.EnableSubGroup();
        // k.EnableSubGroupShort(); // needed for FP16 only; checked in the Validate phase instead
        k.EnableBiasPerFeature();
        k.EnableNonBiasTerm();
        k.EnableBatching();
        k.EnableSplitSupport();
        return k;
    }

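    // Picks the FP32 or FP16 variant of the OpenCL kernel based on the input data type.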
    std::string ConvolutionKernel_bfyx_GEMMLike::GetKernelName(const convolution_params& params) const
    {
        if (params.inputs[0].GetDType() == Datatype::F32)
        {
            return kernelName + "_fp32";
        }
        else
        {
            return kernelName + "_fp16";
        }
    }

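    // Emits the JIT constants that parameterize the GEMM-style OpenCL kernel:
    // the output feature count aligned to the sub-block size, the per-work-item
    // output tile (DX/DY), and a LEFTOVERS flag when the M dimension does not
    // divide evenly across the local work size.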
    JitConstants ConvolutionKernel_bfyx_GEMMLike::GetJitConstants(const convolution_params& params, const DispatchData& runInfo) const
    {
        JitConstants jit = Parent::GetJitConstants(params, runInfo);

        jit.AddConstants({
            MakeJitConstant("ALIGNED_OFM",                  RoundUp(params.output.Feature().v, runInfo.gemmStyle.subBlockDimN)),
            MakeJitConstant("DX",                           runInfo.gemmStyle.globalWorkSizeDX),
            MakeJitConstant("DY",                           runInfo.gemmStyle.globalWorkSizeDY),
            MakeJitConstant("FILTER_SIZE_X_DIV2",           params.filterSize.x / 2),
            MakeJitConstant("INPUT_BUFFER_WIDTH_PADDED",    ""),    // TODO: enable the non-padded path again
            MakeJitConstant("INPUT_BUFFER_HEIGHT_PADDED",   ""),
        });

        // A tail-handling path is needed when the rounded-up M dimension is not
        // a multiple of the local work-group size.
        if (CeilDiv(RoundUp(params.output.X().v * params.output.Y().v, runInfo.gemmStyle.subBlockDimM), runInfo.gemmStyle.globalWorkSizeDY) % runInfo.lws1 != 0)
            jit.AddConstant(MakeJitConstant("LEFTOVERS", 1));

        return jit;
    }

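    // Chooses the dispatch configuration. FP16 and FP32 use different GEMM
    // tilings (sub-block sizes and per-work-item output), so the local work
    // size and the kernel's priority differ per data type.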
    ConvolutionKernel_bfyx_GEMMLike::Parent::DispatchData ConvolutionKernel_bfyx_GEMMLike::SetDefault(const convolution_params& arg, int autoTuneIndex) const
    {
        DispatchData runInfo = Parent::SetDefault(arg, autoTuneIndex);

        runInfo.lws0 = 1;
        runInfo.lws2 = 1;

        if (arg.inputs[0].GetDType() == Datatype::F16)
        {
            runInfo.gemmStyle = { 1, arg.filterSize.x, 32, 32, 1, 1 };
            runInfo.lws1 = 16;
            runInfo.effiency = FORCE_PRIORITY_6;
        }
        else
        {
            runInfo.gemmStyle = { 2, arg.filterSize.x, 32, 32, 2, 1 };
            runInfo.lws1 = 8;
            runInfo.effiency = FORCE_PRIORITY_8;
        }

        // M = output pixels (X*Y), N = output features, each rounded up to the
        // GEMM sub-block granularity.
        size_t sgemm_m = RoundUp(arg.output.X().v * arg.output.Y().v, runInfo.gemmStyle.subBlockDimM);
        size_t sgemm_n = RoundUp(arg.output.Feature().v, runInfo.gemmStyle.subBlockDimN);

        runInfo.gws0 = RoundUp(CeilDiv(sgemm_n, runInfo.gemmStyle.globalWorkSizeDX), runInfo.lws0);
        runInfo.gws1 = RoundUp(CeilDiv(sgemm_m, runInfo.gemmStyle.globalWorkSizeDY), runInfo.lws1);
        runInfo.gws2 = arg.output.Batch().v;

        return runInfo;
    }

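    // Rejects parameter combinations the kernel cannot handle; in particular,
    // the FP16 variant requires sub-group short support from the device.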
    bool ConvolutionKernel_bfyx_GEMMLike::Validate(const Params& p, const optional_params& o) const
    {
        if (!Parent::Validate(p, o) ||
            !CovolutionCheckInput(p, o))
        {
            return false;
        }

        const auto& params = static_cast<const convolution_params&>(p);

        // The FP16 kernel relies on short sub-group operations (see the note in
        // GetSupportedKey), so bail out when the device does not report them.
        if (!params.engineInfo.bSubGroupShortSupport && params.inputs[0].GetDType() == Datatype::F16)
        {
            return false;
        }

        return true;
    }

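    // The weights must be reordered to a GEMM-friendly layout; the sub-block
    // width differs between the FP16 (osv16) and FP32 (osv8) variants.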
    std::vector<WeightsLayout> ConvolutionKernel_bfyx_GEMMLike::GetSupportedWeightLayouts(const convolution_params& params) const
    {
        if (params.inputs[0].GetDType() == Datatype::F16)
        {
            return { WeightsLayout::iy_xs_os_xsv2_osv16__ao32 };
        }
        else
        {
            return { WeightsLayout::iy_xs_os_xsv2_osv8__ao32 };
        }
    }

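    // Standard entry point: builds the kernel data through the common
    // convolution path (AGE_BASED priority estimation).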
    KernelsData ConvolutionKernel_bfyx_GEMMLike::GetKernelsData(const Params& params, const optional_params& options) const
    {
        return GetCommonKernelsData(params, options, AGE_BASED);
    }
}