/*
// Copyright (c) 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

#include "fused_conv_eltwise_kernel_gemm.h"
#include "kernel_selector_utils.h"

namespace kernel_selector {

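    // Declares the data types, layouts and kernel features this GEMM-style
    // fused conv+eltwise implementation supports.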
    ParamsKey fused_conv_eltwise_kernel_gemm::GetSupportedKey() const
    {
        ParamsKey k;
        k.EnableInputDataType(Datatype::F16);
        k.EnableInputDataType(Datatype::F32);
        k.EnableInputWeightsType(WeightsType::F16);
        k.EnableInputWeightsType(WeightsType::F32);
        k.EnableOutputDataType(Datatype::F16);
        k.EnableOutputDataType(Datatype::F32);
        k.EnableInputLayout(DataLayout::bfyx);
        k.EnableOutputLayout(DataLayout::bfyx);
        k.EnableTensorOffset();
        k.EnableTensorPitches();
        k.EnableSubGroup();
        // k.EnableSubGroupShort(); // needed for FP16 only; this is checked in the Validate phase
        k.EnableBiasPerFeature();
        k.EnableNonBiasTerm();
        k.EnableBatching();
        k.EnableFusedConvEltwSplitSupport();
        return k;
    }

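    // The GEMM kernel comes in separate FP32 and FP16 variants; pick the one
    // matching the input precision.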
    std::string fused_conv_eltwise_kernel_gemm::GetKernelName(const fused_conv_eltwise_params& params) const
    {
        if (params.inputs[0].GetDType() == Datatype::F32)
        {
            return kernelName + "_fp32";
        }
        else
        {
            return kernelName + "_fp16";
        }
    }

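    // The GEMM path only handles 1x1 convolutions with unit stride, no input
    // padding and matching input/output spatial sizes; anything else is rejected.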
    bool fused_conv_eltwise_kernel_gemm::Validate(const Params& p, const optional_params& o) const
    {
        if (!fused_conv_eltwise_kernel_base::Validate(p, o) ||
            !FusedConvolutionEltwiseCheckInput(p, o))
        {
            return false;
        }

        const fused_conv_eltwise_params& cp = static_cast<const fused_conv_eltwise_params&>(p);

        // make sure it's 1x1 conv
        if (cp.conv.filterSize.x != 1 || cp.conv.filterSize.y != 1)
            return false;

        // make sure stride is 1x1
        if (cp.conv.stride.x != 1 || cp.conv.stride.y != 1)
            return false;

        // input padding not supported
        if (cp.inputs[0].X().pad.Total() != 0 ||
            cp.inputs[0].Y().pad.Total() != 0 ||
            cp.inputs[0].Feature().pad.Total() != 0 ||
            cp.inputs[0].Batch().pad.Total() != 0)
            return false;

        // input and output spatial sizes must match
        if (cp.output.X().v != cp.inputs[0].X().v || cp.output.Y().v != cp.inputs[0].Y().v)
            return false;

        return true;
    }

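    // Weight layout expected by the GEMM kernels: osv16 blocking for FP16 inputs,
    // osv8 for FP32.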
    std::vector<WeightsLayout> fused_conv_eltwise_kernel_gemm::GetSupportedWeightLayouts(const fused_conv_eltwise_params& params) const
    {
        if (params.inputs[0].GetDType() == Datatype::F16)
        {
            return { WeightsLayout::iy_xs_os_xsv2_osv16__ao32 };
        }
        else
        {
            return { WeightsLayout::iy_xs_os_xsv2_osv8__ao32 };
        }
    }

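    // Dispatch setup: the local size is 1 x lws1 x 1, and the global size is derived
    // from the output spatial area (sgemm_m), the output feature count (sgemm_n) and
    // the batch size.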
    fused_conv_eltwise_kernel_base::DispatchData fused_conv_eltwise_kernel_gemm::SetDefault(const fused_conv_eltwise_params& arg, int) const
    {
        DispatchData runInfo = Parent::SetDefault(arg);

        runInfo.lws0 = 1;
        runInfo.lws2 = 1;

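        // Per-precision GEMM tiling: FP16 runs with a 16-wide local size and priority 6,
        // FP32 with an 8-wide one and priority 8 (effiency carries the selection priority).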
        if (arg.inputs[0].GetDType() == Datatype::F16)
        {
            runInfo.gemmStyle = { 1, arg.conv.filterSize.x, 32, 32, 1, 1 };
            runInfo.lws1 = 16;
            runInfo.effiency = FORCE_PRIORITY_6;
        }
        else
        {
            runInfo.gemmStyle = { 2, arg.conv.filterSize.x, 32, 32, 2, 1 };
            runInfo.lws1 = 8;
            runInfo.effiency = FORCE_PRIORITY_8;
        }

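        // Round the GEMM M (spatial) and N (feature) dimensions up to the sub-block
        // sizes before deriving the global work sizes.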
        size_t sgemm_m = RoundUp(arg.output.X().v * arg.output.Y().v, runInfo.gemmStyle.subBlockDimM);
        size_t sgemm_n = RoundUp(arg.output.Feature().v, runInfo.gemmStyle.subBlockDimN);

        runInfo.gws0 = RoundUp(CeilDiv(sgemm_n, runInfo.gemmStyle.globalWorkSizeDX), runInfo.lws0);
        runInfo.gws1 = RoundUp(CeilDiv(sgemm_m, runInfo.gemmStyle.globalWorkSizeDY), runInfo.lws1);
        runInfo.gws2 = arg.output.Batch().v;

        return runInfo;
    }

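    // Adds the GEMM-specific JIT constants on top of the common fused conv+eltwise
    // ones: the aligned output feature count, the DX/DY work sizes and the switches
    // that mark the input buffer as padded (see the TODO below).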
    JitConstants fused_conv_eltwise_kernel_gemm::GetJitConstants(const fused_conv_eltwise_params& params, const DispatchData& runInfo) const
    {
        auto jit = Parent::GetJitConstants(params, runInfo);

        jit.AddConstants({
            MakeJitConstant("ALIGNED_OFM",                  RoundUp(params.output.Feature().v, runInfo.gemmStyle.subBlockDimN)),
            MakeJitConstant("DX",                           runInfo.gemmStyle.globalWorkSizeDX),
            MakeJitConstant("DY",                           runInfo.gemmStyle.globalWorkSizeDY),
            MakeJitConstant("FILTER_SIZE_X_DIV2",           params.conv.filterSize.x / 2),
            MakeJitConstant("INPUT_BUFFER_WIDTH_PADDED",    ""),    // TODO: enable non padding path again
            MakeJitConstant("INPUT_BUFFER_HEIGHT_PADDED",   ""),
        });

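        // Define LEFTOVERS when the rounded M dimension, divided by DY, does not
        // split evenly into the local work size.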
        if (CeilDiv(RoundUp(params.output.X().v * params.output.Y().v, runInfo.gemmStyle.subBlockDimM), runInfo.gemmStyle.globalWorkSizeDY) % runInfo.lws1 != 0)
            jit.AddConstant(MakeJitConstant("LEFTOVERS", 1));

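        // Stride of the eltwise (second) input; defaults to 1 when not provided.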
        if (!params.eltw.stride.empty())
        {
            jit.AddConstant(MakeJitConstant("ELTW_STRIDE_X", params.eltw.stride[0].x));
            jit.AddConstant(MakeJitConstant("ELTW_STRIDE_Y", params.eltw.stride[0].y));
        }
        else
        {
            jit.AddConstant(MakeJitConstant("ELTW_STRIDE_X", 1));
            jit.AddConstant(MakeJitConstant("ELTW_STRIDE_Y", 1));
        }

        return jit;
    }

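    // Kernel data comes from the common tuned-kernels path of the base class.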
    KernelsData fused_conv_eltwise_kernel_gemm::GetKernelsData(const Params& params, const optional_params& options) const
    {
        return GetTunedKernelsDataByIndex(params, options);
    }
}