2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "fused_conv_bn_scale_kernel_base.h"
18 #include "kernel_selector_utils.h"
19 #include "common_tools.h"
21 namespace kernel_selector
// Validates that the given params/optional-params pair targets this kernel
// (FUSED_CONV_BN_SCALE) and that the weights tensor arrives in a layout the
// kernel can consume — or can be reordered into one.
// NOTE(review): some lines (braces / early returns) are elided in this view
// of the file; comments describe only the visible logic.
23 bool fused_conv_bn_scale_kernel_base::Validate(const Params& p, const optional_params& o) const
// Both the kernel params and the optional params must be of the fused
// conv+BN+scale kernel type; presumably a mismatch returns false here.
25 if (p.GetType() != KernelType::FUSED_CONV_BN_SCALE ||
26 o.GetType() != KernelType::FUSED_CONV_BN_SCALE)
// Safe downcasts: the GetType() checks above guarantee the dynamic types.
31 const fused_conv_bn_scale_params& params = static_cast<const fused_conv_bn_scale_params&>(p);
32 const fused_conv_bn_scale_optional_params& optParams = static_cast<const fused_conv_bn_scale_optional_params&>(o);
// Scan the derived class's supported weight layouts; OR-accumulate whether
// the actual weights layout matches any of them.
34 bool bSupportedWeightsLayout = false;
36 for (WeightsLayout l : GetSupportedWeightLayouts(params))
38 bSupportedWeightsLayout |= params.weights.GetLayout() == l;
// Weights are acceptable either directly, or because the caller allows a
// static (compile-time) reorder of the weights into a supported layout.
41 const bool bWeightsOK = bSupportedWeightsLayout || optParams.allowStaticInputReordering;
// Builds the JIT-constant set handed to the OpenCL kernel source generator:
// base weight/bias constants plus convolution geometry (stride, padding,
// split, dilation), the padded input offset, and the batch-norm epsilon.
// NOTE(review): the trailing return of mem_consts is elided in this view.
46 JitConstants fused_conv_bn_scale_kernel_base::GetJitConstants(const fused_conv_bn_scale_params& params, const DispatchData&) const
// Start from the common weight/bias constants provided by the base class.
48 JitConstants mem_consts = WeightBiasKernelBase::GetJitConstants(params);
49 const auto& padding = params.padding;
50 const auto& input = params.inputs[0];
// Offset of the first input element after stepping back by the X/Y padding
// (in elements, using the per-axis pitches). Clamped to 0 so the kernel
// never receives a negative offset when padding exceeds the buffer start.
52 int64_t input_offset_with_padding = (int64_t)input.GetFirstElementOffset() - padding.x*input.X().pitch - input.Y().pitch*padding.y;
53 input_offset_with_padding = std::max(input_offset_with_padding, (int64_t)0);
55 mem_consts.AddConstants({
56 MakeJitConstant("STRIDE", params.stride),
57 MakeJitConstant("PADDING", params.padding),
// "split" = number of filter groups (grouped convolution).
58 MakeJitConstant("FILTER_ARRAY_NUM", params.split),
59 MakeJitConstant("DILATION", params.dilation),
60 MakeJitConstant("INPUT0_OFFSET_WITH_PADDING", input_offset_with_padding),
// Batch-norm numerical-stability epsilon.
61 MakeJitConstant("EPSILON", params.epsilon)
// Optional defines: enable training-mode outputs and the extra scale-bias
// term only when the params request them, keeping the generated kernel lean.
64 if (params.fused_in_training)
65 mem_consts.AddConstant(MakeJitConstant("FUSED_TRAINING", 1));
66 if (params.scale_bias)
67 mem_consts.AddConstant(MakeJitConstant("SCALE_BIAS_TERM", 1));
// Sanity-checks the computed dispatch sizes: OpenCL requires each global
// work size to be an exact multiple of the corresponding local work size,
// so any remainder means the work-group calculation is invalid.
// NOTE(review): the return statements are elided in this view; presumably
// the divisibility failure returns false and the fall-through returns true.
72 bool fused_conv_bn_scale_kernel_base::CheckWorkGroups(const DispatchData& kd)
84 if ((kd.gws0 % kd.lws0) != 0 ||
85 (kd.gws1 % kd.lws1) != 0 ||
86 (kd.gws2 % kd.lws2) != 0)
// Computes the default dispatch configuration (global/local work sizes and
// related flags) for this kernel, based on the output tensor's layout.
// NOTE(review): the DispatchData initialization, the else-branch braces, and
// the assignment of global/local into kd are elided in this view.
94 fused_conv_bn_scale_kernel_base::DispatchData fused_conv_bn_scale_kernel_base::SetDefault(const fused_conv_bn_scale_params& params) const
98 const auto& out = params.output;
// Track whether the FP16 unit is used, keyed off the output data type.
99 kd.fp16UnitUsed = out.GetDType() == Datatype::F16;
100 std::vector<size_t> global;
// For bfyx/byxf layouts, map X and Y to the first two dispatch dimensions
// and fold feature*batch into the third; other layouts use the reverse
// ordering (feature*batch first) to match their memory traversal.
101 if (params.output.GetLayout() == DataLayout::bfyx || params.output.GetLayout() == DataLayout::byxf)
103 global = { out.X().v, out.Y().v, out.Feature().v*out.Batch().v };
107 global = { out.Feature().v*out.Batch().v, out.X().v, out.Y().v };
// Derive local work-group sizes that evenly divide the global sizes.
110 auto local = GetOptimalLocalWorkGroupSizes(global);
// Low default efficiency so more specialized kernels win selection.
// ("effiency" is the field's actual (misspelled) name in this API.)
120 kd.effiency = DONT_USE_IF_HAVE_SOMETHING_ELSE;
// Shared kernel-data construction for all fused conv+BN+scale kernels:
// validates params, computes dispatch sizes, reorders weights if needed,
// generates the JIT'd kernel source, and wires up the argument list.
// NOTE(review): this function continues past the visible end of the chunk
// (the final return and several early-return bodies are elided), so
// comments below cover only the visible statements.
124 KernelsData fused_conv_bn_scale_kernel_base::GetCommonKernelsData(const Params& params, const optional_params& options, float estimated_time) const
// Bail out (presumably with an empty KernelsData) if validation fails.
126 if (!Validate(params, options))
131 KernelData kd = KernelData::Default<fused_conv_bn_scale_params>(params);
// Work on the copy owned by kd so weight-reorder updates stick to it.
132 fused_conv_bn_scale_params& newParams = *static_cast<fused_conv_bn_scale_params*>(kd.params.get());
134 DispatchData runInfo = SetDefault(newParams);
136 if (!CheckWorkGroups(runInfo))
138 // Internal Error - wrong calculation of global/local work group sizes
// Set up a weights-reorder step if the current layout isn't supported.
142 bool succeed = UpdateWeightsParams(
145 GetSupportedWeightLayouts(newParams),
146 kd.weightsReorderParams);
// Generate the OpenCL source: kernel name -> JIT constants -> entry point
// -> final JIT'd source string.
153 auto finalKernelName = GetKernelName(newParams);
154 auto cldnnJit = GetJitConstants(newParams, runInfo);
155 auto entryPoint = GetEntryPoint(finalKernelName, newParams.layerID, options);
156 auto jit = CreateJit(finalKernelName, cldnnJit, entryPoint);
158 auto& kernel = kd.kernels[0];
// Standard args (input/output/weights, bias if present), then the split
// index and the extra BN/scale inputs appended below.
159 FillCLKernelData(kernel, runInfo, params.engineInfo, finalKernelName, jit, entryPoint, "", true, !newParams.bias.empty(), 1);
160 kernel.arguments.push_back({ ArgumentDescriptor::Types::SPLIT, 0 });
// Extra inputs: scale tensor; optional scale-bias; and in training mode,
// three more inputs (presumably BN statistics/gradient buffers — confirm
// against the kernel source).
162 kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, idx++ });
163 if (newParams.scale_bias)
164 kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, idx++ });
165 if (newParams.fused_in_training)
167 kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, idx++ });
168 kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, idx++ });
169 kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, idx });
// Propagate the caller-estimated execution time into the kernel data.
172 kd.estimatedTime = estimated_time;