2 // Copyright (c) 2016-2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
22 #include "common_types.h"
23 #include "tensor_type.h"
namespace kernel_selector
// Convenience aliases for the tensor types used throughout the kernel
// selector. MultiDataTensor bundles the tensors of a multi-input primitive.
using DataTensor = Tensor::DataTensor;
using WeightsTensor = Tensor::WeightsTensor;
using DataLayout = Tensor::DataLayout;
using WeightsLayout = Tensor::WeightsLayout;
using MultiDataTensor = std::vector<DataTensor>;
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// NOTE(review): the assignments below are a fragment of the ParamsKey
// constructor/reset (its signature is outside this view). They zero every
// capability mask so a fresh key starts with no capabilities enabled.
key.machineInfo.raw = 0;
key.inputType.raw = 0;
key.outputType.raw = 0;
key.inputWeightsType.raw = 0;
key.outputWeightsType.raw = 0;
key.weightsInputLayout = 0;
key.weightsOutputLayout = 0;
// Single-bit capability flags packed into the ParamsKey restrict/machineInfo
// unions. The enclosing struct/union declarations are not visible in this
// view; field groups below belong to per-primitive "dedicated" sub-structs.
uint32_t different_types : 1;
uint32_t different_input_weights_types : 1;
uint32_t batching : 1;
uint32_t biasPerFeatureMap : 1;
uint32_t biasPerOutput : 1;
uint32_t activationAdditionalParamsAsInput : 1;
uint32_t FP16Emulation : 1;
uint32_t gradient : 1;
uint32_t momentum : 1;
// Axis / index-format flags; the same field names repeat across several
// per-primitive structs (argmax/min, lookup table, etc.).
uint32_t axisFeature : 1;
uint32_t axisBatch : 1;
uint32_t indicesF32 : 1;
uint32_t indicesOther : 1;
uint32_t axisFeature : 1;
uint32_t axisBatch : 1;
uint32_t axisFeature : 1;
uint32_t axisBatch : 1;
// NOTE(review): "Kenrel" is a long-standing typo for "Kernel". It is not
// renamed here because the out-of-line Enable*KernelDividerMode definitions
// (and any other .cpp users) reference these exact field names.
uint32_t fixedKenrelDivider : 1;
uint32_t dynamicKenrelDivider : 1;
uint32_t normalize_variance : 1;
uint32_t max_with_argmax : 1;
uint32_t bilinear : 1;
uint32_t fixedKenrelDivider : 1;
uint32_t dynamicKenrelDivider : 1;
uint32_t dynamicKenrelDividerWithPadding : 1;
uint32_t position_sensitive : 1;
// Convolution capabilities.
uint32_t dilation : 1;
uint32_t depthwise_separable_opt : 1;
uint32_t transposed : 1;
uint32_t quantization : 1;
uint32_t calibration : 1;
uint32_t grouped : 1;
uint32_t dimFeature : 1;
uint32_t dimFeature : 1;
uint32_t classes : 1;
uint32_t dimFeature : 1;
uint32_t axisFeature : 1;
uint32_t axisBatch : 1;
uint32_t kernelPerInput : 1;
uint32_t oneKernel : 1;
uint32_t nearest : 1;
uint32_t bilinear : 1;
uint32_t winograd : 1;
uint32_t broadcast : 1;
// Fused convolution + eltwise capabilities (mirrors the conv flags above).
struct fused_conv_eltw_t {
uint32_t dilation : 1;
uint32_t depthwise_separable_opt : 1;
uint32_t transposed : 1;
uint32_t quantization : 1;
uint32_t calibration : 1;
uint32_t grouped : 1;
uint32_t rw_out_opt : 1;
// Machine capabilities (OpenCL subgroup extensions).
uint32_t subgroup : 1;
uint32_t subgroupShort : 1;
// The restrict_t union must stay exactly 64 bits wide so keys can be
// compared/merged as raw integers.
static_assert(sizeof(restrict_t) == sizeof(uint64_t), "problem with union");
// Bit-mask storage behind ParamsKey: each Enable* call sets bits here, and
// a kernel's key is matched against a primitive's key bit-by-bit.
typedef union DataTypesKey_t
uint32_t enableTuning;
DataTypesKey inputType;
DataTypesKey outputType;
DataTypesKey inputWeightsType;
DataTypesKey outputWeightsType;
// Layout masks: bit i corresponds to DataLayout/WeightsLayout enum value i
// (see the (1 << l) shifts in the Enable*Layout methods below).
uint32_t inputLayout;
uint32_t outputLayout;
uint32_t weightsInputLayout;
uint32_t weightsOutputLayout;
// --- Data-type capabilities -----------------------------------------------
// Mark which input/output data and weights types are supported/required.
// The non-inline setters are defined out of line (in the .cpp).
void EnableInputDataType(Datatype dt);
void EnableAllInputDataType();
void EnableOutputDataType(Datatype dt);
void EnableAllOutputDataType();
void EnableInputWeightsType(WeightsType wt);
void EnableAllInputWeightsType();
void EnableOutputWeightsType(WeightsType wt);
void EnableAllOutputWeightsType();
void EnableFP16Emulation() { key.restrict.val.FP16Emulation = 1; }
void EnableDifferentTypes() { key.restrict.val.different_types = 1; }
void EnableDifferentInputWeightsTypes() {
key.restrict.val.different_input_weights_types = 1; }
// --- Layout capabilities --------------------------------------------------
// One bit per DataLayout/WeightsLayout enum value; 0xffffffff = all layouts.
void EnableInputLayout(DataLayout l) { key.inputLayout |= (1 << l); }
void EnableAllInputLayout() { key.inputLayout = 0xffffffff; }
void EnableOutputLayout(DataLayout l) { key.outputLayout |= (1 << l); }
void EnableAllOutputLayout() { key.outputLayout = 0xffffffff; }
void EnableInputWeightsLayout(WeightsLayout l) { key.weightsInputLayout |= (1 << l); }
void EnableAllInputWeightsLayout() { key.weightsInputLayout = 0xffffffff; }
void EnableOutputWeightsLayout(WeightsLayout l) { key.weightsOutputLayout |= (1 << l); }
void EnableAllOutputWeightsLayout() { key.weightsOutputLayout = 0xffffffff; }
// --- Generic restriction / machine flags ----------------------------------
void EnableTensorOffset() { key.restrict.val.offset = 1; }
void EnableTensorPitches() { key.restrict.val.pitches = 1; }
void EnableBatching() { key.restrict.val.batching = 1; }
void EnableGradient() { key.restrict.val.gradient = 1; }
void EnableSubGroup() { key.machineInfo.val.subgroup = 1; }
void EnableSubGroupShort() { key.machineInfo.val.subgroupShort = 1; }
void EnableNonBiasTerm() { key.restrict.val.nonBias = 1; }
void EnableBiasPerFeature() { key.restrict.val.biasPerFeatureMap = 1; }
void EnableBiasPerOutput() { key.restrict.val.biasPerOutput = 1; }
void EnableActivationAdditionalParamsAsInput() { key.restrict.val.activationAdditionalParamsAsInput = 1; }
void EnableMomentum() { key.restrict.val.momentum = 1; }
// --- Per-primitive capabilities -------------------------------------------
// Out-of-line setters translate primitive-specific enums into the dedicated
// bitfield flags; the inline ones set a single dedicated flag directly.
void EnableLRNMode(LRNMode m);
void EnableLookUpTableAxis(LookUpTableAxis m);
void EnableNormalizeMode(NormalizeMode m);
void EnableMVNMode(MVNMode m);
void EnableMVNNormalizeVariance();
void EnableLRNKernelDividerMode(KernelDividerMode m);
void EnablePoolKernelDividerMode(KernelDividerMode m);
void EnablePoolType(PoolType t);
void EnablePoolRemainder(PoolRemainder r);
void EnablePositionSensitivePooling() { key.restrict.val.dedicated.pooling.position_sensitive = 1; }
// Convolution capabilities.
void EnableSplitSupport() { key.restrict.val.dedicated.conv.split = 1; }
void EnableDilation() { key.restrict.val.dedicated.conv.dilation = 1; }
void EnableDepthwiseSeparableOpt() { key.restrict.val.dedicated.conv.depthwise_separable_opt = 1; }
void EnableLocalConvolution() { key.restrict.val.dedicated.conv.local = 1; }
void EnableGroupedConvolution() { key.restrict.val.dedicated.conv.grouped = 1; }
void EnableTranspose() { key.restrict.val.dedicated.conv.transposed = 1; }
void EnableInt8Quantization() { key.restrict.val.dedicated.conv.quantization = 1; }
void EnableOutputCalibration() { key.restrict.val.dedicated.conv.calibration = 1; }
// Fused convolution + eltwise capabilities (mirror the conv setters above).
void EnableFusedConvEltwSplitSupport() { key.restrict.val.dedicated.fused_conv_eltw.split = 1; }
void EnableFusedConvEltwDilation() { key.restrict.val.dedicated.fused_conv_eltw.dilation = 1; }
void EnableFusedConvEltwDepthwiseSeparableOpt() { key.restrict.val.dedicated.fused_conv_eltw.depthwise_separable_opt = 1; }
void EnableFusedConvEltwLocalConvolution() { key.restrict.val.dedicated.fused_conv_eltw.local = 1; }
void EnableFusedConvEltwGroupedConvolution() { key.restrict.val.dedicated.fused_conv_eltw.grouped = 1; }
void EnableFusedConvEltwTranspose() { key.restrict.val.dedicated.fused_conv_eltw.transposed = 1; }
void EnableFusedConvEltwInt8Quantization() { key.restrict.val.dedicated.fused_conv_eltw.quantization = 1; }
void EnableFusedConvEltwOutputCalibration() { key.restrict.val.dedicated.fused_conv_eltw.calibration = 1; }
void EnableFusedConvEltwEltwiseStride();
// Reorder / softmax / concat / upsampling / eltwise / LSTM capabilities.
void EnableWinogradReorder() { key.restrict.val.dedicated.reorder.winograd = 1; }
void EnableSoftmaxDim(SoftmaxDim d);
void EnableConcatAxis(ConcatAxis a);
void EnableUpSamplingSampleType(SampleType a);
void EnableEltwiseStride();
void EnableEltwiseBroadcast() { key.restrict.val.dedicated.eltwise.broadcast = 1; }
void EnableLSTMGEMMBias() { key.restrict.val.dedicated.lstm_gemm.bias = 1; }
void EnableLSTMGEMMHidden() { key.restrict.val.dedicated.lstm_gemm.hidden = 1; }
void EnableLSTMEltCell() { key.restrict.val.dedicated.lstm_elt.cell = 1; }
void EnableConcatKernelPerInput() { key.restrict.val.dedicated.concat.kernelPerInput = 1; }
// Opts this key out of auto-tuning (enableTuning defaults on elsewhere).
void DisableTuning() { key.enableTuning = 0; }
void EnableConcatOneKernel() { key.restrict.val.dedicated.concat.oneKernel = 1; }
void EnableArgMaxMinAxis(ArgMaxMinAxis a);
void EnableLookUpTableIndicesFormat(Datatype a);
void EnableIndexSelectAxis(IndexSelectAxis a);
void EnableFusedConvEltwiseRWOutOpt();
// --- Queries ---------------------------------------------------------------
// Support(): presumably matches this key's capability masks against the
// requirement key `k` — definition is out of line; confirm in the .cpp.
bool Support(const ParamsKey& k) const;
// NOTE(review): the bodies of TuningSupport() and
// isEnabledDifferentInputWeightsTypes() are only partially visible here;
// the missing lines are left untouched.
bool TuningSupport() const
if (key.enableTuning == 1)
bool isEnabledDifferentInputWeightsTypes() const {
return key.restrict.val.different_input_weights_types ? true : false;
// Combines two keys; definition is out of line in the .cpp.
ParamsKey Merge(const ParamsKey& k) const;
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// EngineInfo — capabilities and limits of the target device/driver.
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Feature flags (subgroup extensions, fp16/fp64 arithmetic, image support,
// IMAD/IMMAD integer dot-product instructions).
bool bSubGroupSupport = false;
bool bSubGroupShortSupport = false;
bool bFP16Support = false;
bool bFP64Support = false;
bool bImageSupport = false;
bool bIMADSupport = false;
bool bIMMADSupport = false;
// Device limits.
uint32_t computeUnitsCount = 0;
uint64_t maxWorkGroupSize = 0;
uint64_t maxLocalMemSize = 0;
uint64_t maxImage2dWidth = 0;
uint64_t maxImage2dHeight = 0;
// Identification strings plus a cached per-device tuning document
// (rapidjson DOM) — presumably loaded from a tuning cache file; confirm
// where deviceCache is populated.
std::string deviceId = "";
std::string driverVersion = "";
std::string hostVersion = "";
std::shared_ptr<rapidjson::Document> deviceCache;
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Params — common base for all kernel parameter structs.
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Returns the primitive kind this parameter set describes.
KernelType GetType() const { return kType; }
virtual ParamsKey GetParamsKey() const;
// kt: primitive kind; id: layer identifier stored in layerID.
Params(KernelType kt, const std::string& id) : kType(kt), layerID(id) {}
// Capabilities of the device the kernel will run on.
EngineInfo engineInfo;
virtual std::string to_string() const;
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// base_activation_params — activation function fused into a kernel plus its
// two scalar coefficients m and n (meaning depends on the function).
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct base_activation_params
// NONE means no activation is applied.
ActivationFunction function = ActivationFunction::NONE;
base_activation_params() = default;
// m, n: coefficients forwarded to the activation (e.g. slope/bounds —
// exact semantics depend on `function`).
base_activation_params(const float m, const float n) : m(m), n(n) {}
virtual std::string to_string() const;
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// base_params — parameters shared by every primitive.
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct base_params : public Params
virtual ~base_params() {}
// Fused activation applied to the primitive's output.
base_activation_params activation;
// Input tensors; the constructor sizes this to a single default input.
MultiDataTensor inputs;
// presumably marks a backward/gradient-computing variant — confirm with
// callers of EnableGradient()/ParamsKey.
bool gradient = false;
virtual std::string to_string() const;
virtual ParamsKey GetParamsKey() const;
// Layer id starts empty; callers fill it in later.
base_params(KernelType kt) : Params(kt, ""), inputs(1){}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Auto tuner parameters
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class KernelRunnerInterface;
// Path of the tuning-cache file and the runner used to time candidate
// kernels during auto-tuning.
std::string cacheFilePath;
std::shared_ptr<KernelRunnerInterface> runner;
// Tuning is off by default: no cache file, no runner.
TuningParams() : mode(TuningMode::TUNING_DISABLED), cacheFilePath(""), runner(nullptr) {}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// optional_params — per-invocation options and permissions granted to the
// kernel by the graph compiler.
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct optional_params
virtual ~optional_params() {}
// Returns the primitive kind these options apply to.
KernelType GetType() const { return kType; }
// Layouts the surrounding graph can supply/accept for this primitive.
std::vector<DataLayout> inputLayouts;
std::vector<DataLayout> outputLayouts;
bool meaningfulKernelsNames = false; // use layer name instead of internal kernel name
bool allowStaticInputReordering = true; // allow kernel to provide a kernel which reorder static data like weights/bias/tables...
bool allowInputReordering = false; // allow kernel to ask graph compiler to reorder the input data before executing it
bool allowOutputReordering = false; // allow kernel to ask graph compiler to reorder the output data before executing the next kernel
// Auto-tuner configuration (disabled by default — see TuningParams).
TuningParams tuningParams;
virtual ParamsKey GetSupportedKey() const;
optional_params(KernelType kt) : kType(kt) {}