2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "kernel_selector_common.h"
18 #include "reorder_kernel_base.h"
19 #include "common_tools.h"
20 #include "kernel_selector_utils.h"
22 namespace kernel_selector
// Maps a weights layout to the value emitted as the SUB_GROUP_SIZE JIT constant
// by ReorderKernelBase::GetJitConstants(const reorder_weights_params&).
// NOTE(review): the switch header, return statements and default branch are not
// visible in this excerpt. The case labels below form two groups (layout names
// containing "osv16" and names containing "osv8"); presumably they return 16 and
// 8 respectively, with a fallback for all other layouts -- confirm against the
// full source.
24 inline uint32_t SubGroupSize(WeightsLayout l)
// Group 1: layouts whose names contain "osv16".
28 case WeightsLayout::os_iyx_osv16:
29 case WeightsLayout::os_iyx_osv16_rotate_180:
30 case WeightsLayout::os_i_osv16:
31 case WeightsLayout::os_i_osv16__ai8:
32 case WeightsLayout::i_yxs_os_yxsv2_osv16:
33 case WeightsLayout::iy_xs_os_xsv2_osv16__ao32:
// Group 2: layouts whose names contain "osv8".
35 case WeightsLayout::os_i_osv8__ai8:
36 case WeightsLayout::iy_xs_os_xsv2_osv8__ao32:
// Maps a data layout to the value emitted as the SUB_GROUP_SIZE JIT constant
// by ReorderKernelBase::GetJitConstants(const reorder_params&).
// NOTE(review): only two case labels are visible here; the return values and
// default branch lie outside this excerpt. Presumably bs_f_bsv16__af8 -> 16 and
// bs_f_bsv8__af8 -> 8, mirroring the WeightsLayout overload -- confirm.
43 inline uint32_t SubGroupSize(DataLayout l)
47 case DataLayout::bs_f_bsv16__af8:
49 case DataLayout::bs_f_bsv8__af8:
56 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
57 // MakeReorderWeightsJitConstants
58 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds the base JIT constants for a weights reorder. Used by
// ReorderKernelBase::GetJitConstants(const reorder_weights_params&), which adds
// SUB_GROUP_SIZE on top. Emits FP16_SUPPORTED / FP16_UNIT_USED plus the tensor
// descriptions of the input weights (INPUT0) and output weights (OUTPUT).
59 inline JitConstants MakeReorderWeightsJitConstants(const reorder_weights_params& params)
61 const auto& input = params.input;
62 const auto& output = params.output;
// FP16 is enabled when either side of the reorder holds F16 weights.
63 const bool fp16Supported = output.GetDType() == WeightsType::F16 || input.GetDType() == WeightsType::F16;
// NOTE(review): the expression that collects the constants below, and the
// return statement, are not visible in this excerpt.
66 MakeJitConstant("FP16_SUPPORTED", fp16Supported), // TODO: use engine
67 MakeJitConstant("FP16_UNIT_USED", fp16Supported),
68 MakeJitConstant("INPUT0", input),
69 MakeJitConstant("OUTPUT", output),
// JIT constants for a weights reorder: the base set from
// MakeReorderWeightsJitConstants plus SUB_GROUP_SIZE, which is derived from the
// output weights layout.
// NOTE(review): the return statement is not visible here; presumably mem_consts
// is returned.
75 JitConstants ReorderKernelBase::GetJitConstants(const reorder_weights_params& params) const
77 JitConstants mem_consts = MakeReorderWeightsJitConstants(params);
79 mem_consts.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", SubGroupSize(params.output.GetLayout())));
84 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
85 // MakeReorderJitConstants
86 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds the JIT constants for a data reorder. The constants cover:
//  - the base-params constants;
//  - a MEAN_SUBTRACT_<mode> flag;
//  - the mode-specific mean source and its converter (TO_MEAN_TYPE);
//  - the calculation / load / store types;
//  - the MEAN_OP(val,mean_val) macro taken from getMeanOpString(params.mean_op).
// Used by ReorderKernelBase::GetJitConstants(const reorder_params&).
87 inline JitConstants MakeReorderJitConstants(const reorder_params& params)
89 JitConstants jit = MakeBaseParamsJitConstants(params);
// Defines MEAN_SUBTRACT_<toString(mode)> = 1 for the active mode
// (presumably tested with #ifdef in the OpenCL kernel -- confirm).
91 jit.AddConstant(MakeJitConstant("MEAN_SUBTRACT_" + toString(params.mode), 1));
// Mean values supplied directly in the params: emitted as VALUE_TO_SUBTRACT
// and converted with convert_float.
93 if (params.mode == MeanSubtractMode::INSIDE_PARAMS)
95 jit.AddConstant(MakeJitConstant("VALUE_TO_SUBTRACT", params.meanValues));
96 jit.AddConstant(MakeJitConstant("TO_MEAN_TYPE", "convert_float"));
// Mean supplied as a separate tensor: described as MEAN_SUBTRACT, with a
// converter built from that tensor's own data type.
98 else if (params.mode == MeanSubtractMode::IN_BUFFER)
100 jit.AddConstant(MakeJitConstant("MEAN_SUBTRACT", params.mean));
101 jit.AddConstant(MakeJitConstant("TO_MEAN_TYPE", "convert_" + toCLType(params.mean.GetDType())));
104 //A plain F16->F16 copy (no mean subtraction) only moves raw bits, so it can be done on ushort values without requiring device fp16 support
105 bool useUshort = (params.inputs[0].GetDType() == Datatype::F16 && params.output.GetDType() == Datatype::F16 &&
106 params.mode == MeanSubtractMode::NONE);
// Arithmetic is done in the input's data type, or as UINT16 for the raw-bit copy.
108 Datatype calc_type = useUshort ? Datatype::UINT16 : params.inputs[0].GetDType();
// NOTE(review): the call that merges the constants below into `jit`, and the
// return statement, are not visible in this excerpt.
111 MakeJitConstant("CALC_TYPE", toCLType(calc_type)),
112 MakeJitConstant("TO_CALC_TYPE", "convert_" + toCLType(calc_type)),
113 MakeJitConstant("INPUT_REORDER_TYPE", useUshort ? toCLType(Datatype::UINT16) : "INPUT0_TYPE"),
114 MakeJitConstant("OUTPUT_REORDER_TYPE", useUshort ? toCLType(Datatype::UINT16) : "OUTPUT_TYPE"),
// In the raw-bit case no output conversion is applied, so the macro expands to nothing.
115 MakeJitConstant("TO_OUTPUT_REORDER_TYPE", useUshort ? "" : "TO_OUTPUT_TYPE"),
116 MakeJitConstant("MEAN_OP(val,mean_val)", getMeanOpString(params.mean_op))
// JIT constants for a data reorder: the set from MakeReorderJitConstants plus
// SUB_GROUP_SIZE, which is derived from the output data layout.
// NOTE(review): the return statement is not visible here; presumably mem_consts
// is returned.
122 JitConstants ReorderKernelBase::GetJitConstants(const reorder_params& params) const
124 JitConstants mem_consts = MakeReorderJitConstants(params);
126 mem_consts.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", SubGroupSize(params.output.GetLayout())));
// Default dispatch for a weights reorder. The global work size is
// {OFM, IFM, X*Y} of the output weights; the local size is chosen by
// GetOptimalLocalWorkGroupSizes for that global size.
// NOTE(review): creation of the returned DispatchData and the copying of
// global/local into it are not visible in this excerpt.
131 ReorderKernelBase::DispatchData ReorderKernelBase::SetDefault(const reorder_weights_params& params) const
133 const auto& out = params.output;
// NOTE(review): the 3-element construction is immediately overwritten by the
// assignment below; direct initialization would suffice.
137 std::vector<size_t> global(3);
// One work item per (output feature, input feature, spatial X*Y position).
139 global = { out.OFM().v, out.IFM().v, out.X().v*out.Y().v };
140 auto local = GetOptimalLocalWorkGroupSizes(global);
// Default dispatch for a data reorder. The global work size comes from
// GetTensorFriendlyWorkGroups over the first input tensor; the local size comes
// from GetOptimalLocalWorkGroupSizes.
// NOTE(review): creation of the returned DispatchData and the copying of
// global/local into it are not visible in this excerpt.
153 ReorderKernelBase::DispatchData ReorderKernelBase::SetDefault(const reorder_params& params) const
157 auto global = GetTensorFriendlyWorkGroups(params.inputs[0]);
158 auto local = GetOptimalLocalWorkGroupSizes(global);
// Builds the KernelData for a weights reorder. The steps are:
//  1. copy params into a default KernelData;
//  2. compute the dispatch sizes;
//  3. generate the JIT source for kernelName;
//  4. fill the first kernel entry, with arguments from GetArgsDesc(1, false, false);
//  5. record the caller-supplied estimated time.
// NOTE(review): the reorder_params overload calls Validate() first; no such
// check is visible here -- confirm whether one is intended. The return statement
// is also not visible in this excerpt.
171 KernelsData ReorderKernelBase::GetCommonKernelsData(const reorder_weights_params& params, const optional_params& options, float estimated_time) const
173 assert(params.GetType() == KernelType::REORDER);
175 KernelData kd = KernelData::Default<reorder_weights_params>(params);
// Work on the params copy owned by kd rather than the caller's object.
176 reorder_weights_params& newParams = *static_cast<reorder_weights_params*>(kd.params.get());
178 DispatchData runInfo;
180 runInfo = SetDefault(newParams);
182 auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
183 auto cldnn_jit = GetJitConstants(newParams);
184 std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
// Assumes KernelData::Default allocated at least one kernel slot -- confirm.
186 auto& kernel = kd.kernels[0];
188 FillCLKernelData(kernel, runInfo, kernelName, jit, entry_point);
190 kernel.arguments = GetArgsDesc(1, false, false);
192 kd.estimatedTime = estimated_time;
// Builds the KernelData for a data reorder. The steps are:
//  1. validate params;
//  2. copy them into a default KernelData;
//  3. compute the dispatch sizes;
//  4. generate the JIT source for kernelName;
//  5. fill the first kernel entry, with arguments from GetArgsDesc(1, false, false);
//  6. in IN_BUFFER mode, append one extra BIAS argument;
//  7. record the caller-supplied estimated time.
// NOTE(review): the body of the Validate() failure branch and the final return
// are not visible in this excerpt.
197 KernelsData ReorderKernelBase::GetCommonKernelsData(const reorder_params& params, const optional_params& options, float estimated_time) const
199 if (!Validate(params, options))
203 assert(params.GetType() == KernelType::REORDER);
205 KernelData kd = KernelData::Default<reorder_params>(params);
// Work on the params copy owned by kd rather than the caller's object.
206 reorder_params& newParams = *static_cast<reorder_params*>(kd.params.get());
208 DispatchData runInfo;
210 runInfo = SetDefault(newParams);
212 auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
213 auto cldnn_jit = GetJitConstants(newParams);
214 std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
// Assumes KernelData::Default allocated at least one kernel slot -- confirm.
216 auto& kernel = kd.kernels[0];
218 FillCLKernelData(kernel, runInfo, kernelName, jit, entry_point);
220 kernel.arguments = GetArgsDesc(1, false, false);
// IN_BUFFER mode reads the mean from a separate tensor (see MEAN_SUBTRACT in
// MakeReorderJitConstants); presumably that buffer is bound through this BIAS
// argument slot -- confirm against the kernel signature.
221 if (newParams.mode == MeanSubtractMode::IN_BUFFER)
223 kernel.arguments.push_back({ ArgumentDescriptor::Types::BIAS, 0 });
226 kd.estimatedTime = estimated_time;