2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "kernel_selector_utils.h"
18 #include "reorder/reorder_weights_kernel_selector.h"
19 #include "reorder/reorder_kernel_base.h"
20 #include "convolution/convolution_params.h"
22 namespace kernel_selector {
// Translates a tensor Datatype into the corresponding WeightsType.
// UINT8, INT8, F16 and F32 each map to the WeightsType of the same name.
// NOTE(review): the line numbers embedded in this listing skip values, so the
// switch header and the braces are not visible here.
24 static WeightsType DataTypeToWeightsType(Datatype t)
28 case Datatype::UINT8: return WeightsType::UINT8;
29 case Datatype::INT8: return WeightsType::INT8;
30 case Datatype::F16: return WeightsType::F16;
31 case Datatype::F32: return WeightsType::F32;
// Presumably reached for any Datatype not matched by a case above - TODO confirm.
33 return WeightsType::UNSUPPORTED;
// Checks whether `tensor` already has an acceptable data type and one of the
// layouts listed in `reqLayouts`. Returns the layout verdict computed below.
// NOTE(review): reqLayouts is taken by value, so the vector is copied per call.
37 static bool CheckWeights(const WeightsTensor& tensor, WeightsType reqType, std::vector<WeightsLayout> reqLayouts, const ParamsKey& paramsKey)
// Type mismatch that paramsKey does not allow. The statement guarded by this
// condition is not visible in this listing - presumably an early `return false`;
// TODO confirm.
39 if ((reqType != tensor.GetDType()) &&
40 !(paramsKey.isEnabledDifferentInputWeightsTypes()))
// First try: the tensor's layout is literally one of the requested layouts.
45 bool bProperWeightsLayout = std::find(reqLayouts.begin(), reqLayouts.end(), tensor.GetLayout()) != reqLayouts.end();
// Fallback, only when PitchesDifferFromLogicalDims() reports false:
// iyxo is accepted where io was requested, and oiyx where oi was requested.
// Presumably these pairs share the same memory order when the tensor is dense -
// TODO confirm.
46 if (!bProperWeightsLayout && tensor.PitchesDifferFromLogicalDims() == false)
48 bProperWeightsLayout =
49 (std::find(reqLayouts.begin(), reqLayouts.end(), WeightsLayout::io) != reqLayouts.end() && tensor.GetLayout() == WeightsLayout::iyxo) ||
50 (std::find(reqLayouts.begin(), reqLayouts.end(), WeightsLayout::oi) != reqLayouts.end() && tensor.GetLayout() == WeightsLayout::oiyx);
53 return bProperWeightsLayout;
// Computes the two image extents needed to hold `dimensions` in the given
// image weights `layout`. The result is a two-element vector; CheckImageSize
// compares element 0 against maxImage2dWidth and element 1 against
// maxImage2dHeight.
// NOTE(review): the switch header and whatever is returned for layouts not
// listed below are not visible in this listing.
56 std::vector<size_t> GetImageSizes(const kernel_selector::WeightsTensor& dimensions, const WeightsLayout layout)
58 auto ofm = dimensions.OFM().v;
59 auto ifm = dimensions.IFM().v;
60 auto x = dimensions.X().v;
61 auto y = dimensions.Y().v;
// Both plain image layouts: { ofm, ifm * x * y }.
65 case WeightsLayout::image_2d_weights_c1_b_fyx:
66 case WeightsLayout::image_2d_weights_c4_fyx_b:
67 return { ofm, ifm * x * y };
// Winograd layouts scale one extent by 8 / 3. The expression is evaluated left
// to right, so the multiplication by 8 happens before the division by 3.
68 case WeightsLayout::image_2d_weights_winograd_6x3_s1_fbxyb:
69 return { ofm * x * y * 8 / 3, ifm };
70 case WeightsLayout::image_2d_weights_winograd_6x3_s1_xfbyb:
71 return { ofm * y, ifm * x * 8 / 3 };
// Tests whether the weights of `newParams` can be stored as a 2D image in
// `layout` on the target engine.
// NOTE(review): none of the return statements are visible in this listing; the
// comments below describe the visible conditions only.
77 bool CheckImageSize(const weight_bias_params& newParams, const WeightsLayout layout)
// Engine reports no image support - presumably returns false here; TODO confirm.
79 if (!newParams.engineInfo.bImageSupport)
82 auto image_sizes = GetImageSizes(newParams.weights, layout);
// A zero extent, or an extent above the engine's 2D image limits, is flagged.
// Presumably this also returns false, with true as the fall-through result -
// TODO confirm.
83 if (image_sizes[0] == 0 ||
84 image_sizes[1] == 0 ||
85 image_sizes[0] > newParams.engineInfo.maxImage2dWidth ||
86 image_sizes[1] > newParams.engineInfo.maxImage2dHeight)
// Validates the weights of `newParams` against the requested `layouts` and, on
// the reorder path below, selects a weights-reorder kernel, fills
// `weightsReorderParams` and replaces newParams.weights with the reordered
// tensor description.
// NOTE(review): several guards and every return statement are missing from this
// listing, so the exact conditions for each path cannot be confirmed here.
// NOTE(review): `layouts` is taken by value, so the vector is copied per call.
92 bool UpdateWeightsParams(weight_bias_params& newParams, const optional_params& options, std::vector<WeightsLayout> layouts, WeightsReorderParams& weightsReorderParams, const ParamsKey& paramsKey)
94 //validate if weights type is image and if device supports requested sizes
95 for (auto& requested_layout : layouts)
97 if (Tensor::IsImageType(requested_layout))
// Statement guarded by this check is not visible - presumably `return false`.
99 if (!CheckImageSize(newParams, requested_layout))
// Unchecked downcast: the caller must pass a weight_bias_optional_params.
103 const weight_bias_optional_params& optParams = static_cast<const weight_bias_optional_params&>(options);
// The required weights type follows the data type of the first input.
// NOTE(review): inputs[0] is read without a visible emptiness check.
105 const auto dtype = DataTypeToWeightsType(newParams.inputs[0].GetDType());
106 bool bProperWeights = CheckWeights(
107 newParams.weights, dtype, layouts, paramsKey);
// Presumably everything from here on runs only when bProperWeights is false;
// the guarding line is not visible - TODO confirm.
// Reordering not allowed by the optional params - presumably `return false`.
110 if (!optParams.allowStaticInputReordering)
115 auto& reorderKS = ReorderWeightsKernelSelctor::Instance();
116 reorder_weights_params r_params;
118 r_params.layerID = newParams.layerID + "_reorder_";
119 r_params.input = newParams.weights;
// The reorder target is always the first requested layout.
// NOTE(review): layouts[0] is read without a visible emptiness check.
120 r_params.output = newParams.weights.TransformIgnorePadding(layouts[0], dtype);
122 reorder_optional_params op;
123 KernelsData kernels_data = reorderKS.GetBestKernels(r_params, op);
// No reorder kernel available - presumably `return false`; TODO confirm.
125 if (kernels_data.empty())
// Record how to perform the reorder, using the first kernel of the first result.
// NOTE(review): kernels_data[0].kernels is indexed without a visible size check.
130 weightsReorderParams.engine = WeightsReorderParams::Engine::GPU;
131 weightsReorderParams.clKernel = std::make_shared<clKernelData>(kernels_data[0].kernels[0]);
132 weightsReorderParams.newBufferSize = r_params.output.PhysicalSizeInBytes();
133 weightsReorderParams.dtype = dtype;
134 weightsReorderParams.destLayout = r_params.output.GetLayout();
135 weightsReorderParams.toImageType = Tensor::IsImageType(r_params.output.GetLayout());
// From here on the params describe the weights in their reordered form.
137 newParams.weights = r_params.output;
// Builds the JIT constants GWS_BATCH, GWS_FEATURE and GWS_YX from the channel
// positions of `t`'s layout.
// NOTE(review): the lines between the index lookups and the adjustment, and the
// lines that open and return the JitConstants object, are missing from this
// listing.
143 JitConstants GetTensorFriendlyWorkGroupsJit(const DataTensor& t)
145 auto b = DataTensor::Channelndex(t.GetLayout(), Tensor::DataChannelName::BATCH);
146 auto f = DataTensor::Channelndex(t.GetLayout(), Tensor::DataChannelName::FEATURE);
147 auto x = DataTensor::Channelndex(t.GetLayout(), Tensor::DataChannelName::X);
// An index that is not below x is decremented by one. Presumably this accounts
// for X and Y sharing one work-group dimension (see GWS_YX), and it may be
// conditional on code not visible here - TODO confirm.
155 b = (b < x) ? b : b - 1;
156 f = (f < x) ? f : f - 1;
160 MakeJitConstant("GWS_BATCH", b),
161 MakeJitConstant("GWS_FEATURE", f),
162 MakeJitConstant("GWS_YX", x),
// Derives a list of work-group sizes from the dimensions of `t`.
// NOTE(review): parts of both loop bodies and the return are missing from this
// listing.
168 std::vector<size_t> GetTensorFriendlyWorkGroups(const DataTensor& t)
170 std::vector<size_t> sizes;
// Position of the Y channel in t's layout. Its use is not visible here;
// presumably the Y dimension is handled differently from the others -
// TODO confirm.
171 auto y = DataTensor::Channelndex(t.GetLayout(), Tensor::DataChannelName::Y);
172 for (size_t i = 0; i < t.GetDims().size(); i++)
174 const auto& o = t.GetDims()[i];
// Appends this dimension's extent (may be conditional on code not visible here).
181 sizes.push_back(o.v);
// Starts at the current entry count and runs while the counter is below 3.
// The body is not visible - presumably it pads `sizes` up to three entries;
// TODO confirm.
185 for (size_t i = sizes.size(); i < 3; i++)
// Chooses one local work-group size per entry of `gws`, taken from a fixed
// table of candidate sizes, while keeping the product of all chosen sizes
// within lws_max.
// NOTE(review): the declaration of lws_idx and the return statement are missing
// from this listing; presumably lws_idx restarts at 0 for every dimension and
// `lws` is returned - TODO confirm.
193 std::vector<size_t> GetOptimalLocalWorkGroupSizes(std::vector<size_t> gws)
// Upper bound for the product of the chosen local sizes.
195 const size_t lws_max = 256;
// Candidates in descending order. The table ends in 1, which divides every
// value, so both searches below stop inside the table as long as lws_idx
// starts inside it.
196 const size_t optimal_lws_values[] = { 256, 227, 224, 192, 160, 128, 96, 64, 32, 16, 8, 7, 6, 5, 4, 2, 1 };
// Product of the local sizes chosen so far.
197 size_t total_lws = 1;
198 std::vector<size_t> lws;
199 for (size_t i = 0; i < gws.size(); ++i)
// Budget left for this dimension (integer division).
201 auto rest_lws = lws_max / total_lws;
// Skip candidates larger than the remaining budget.
203 while (rest_lws < optimal_lws_values[lws_idx]) lws_idx++;
// Then skip candidates that do not divide this global size evenly.
205 while (gws[i] % optimal_lws_values[lws_idx]) lws_idx++;
207 lws.push_back(optimal_lws_values[lws_idx]);
208 total_lws *= optimal_lws_values[lws_idx];
// Reports whether the first input has no pitches differing from its logical
// dimensions and every other input, as well as the output, compares equal to
// that first input.
// NOTE(review): braces are missing from this listing, so the nesting of the
// statements below is inferred. The loop and the output comparison read
// inputs[0], so presumably both sit inside the `if` - TODO confirm.
214 bool CheckInputsOutputNoPitchSameDims(const base_params& params)
// Initial value; it is returned unchanged if none of the assignments below run.
216 bool no_pitch_same_dims = true;
218 if (params.inputs.size())
220 no_pitch_same_dims = !params.inputs[0].PitchesDifferFromLogicalDims();
// Every remaining input must compare equal (operator==) to the first one.
222 for (size_t i = 1; i < params.inputs.size(); i++)
224 no_pitch_same_dims = no_pitch_same_dims && (params.inputs[0] == params.inputs[i]);
// The output must compare equal to the first input as well.
227 no_pitch_same_dims = no_pitch_same_dims && (params.inputs[0] == params.output);
230 return no_pitch_same_dims;