2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "kernel_selector_params.h"
18 #include "kernel_selector_common.h"
21 namespace kernel_selector {
23 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
25 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
26 void ParamsKey::EnableInputDataType(Datatype dt)
31 key.inputType.val.int8 = 1;
34 key.inputType.val.uint8 = 1;
37 key.inputType.val.int16 = 1;
39 case Datatype::UINT16:
40 key.inputType.val.uint16 = 1;
43 key.inputType.val.int32 = 1;
45 case Datatype::UINT32:
46 key.inputType.val.uint32 = 1;
49 key.inputType.val.int64 = 1;
52 key.inputType.val.F16 = 1;
55 key.inputType.val.F32 = 1;
62 void ParamsKey::EnableAllInputDataType()
64 key.inputType.raw = 0xffffffff;
67 void ParamsKey::EnableOutputDataType(Datatype dt)
72 key.outputType.val.int8 = 1;
75 key.outputType.val.uint8 = 1;
78 key.outputType.val.int16 = 1;
80 case Datatype::UINT16:
81 key.outputType.val.uint16 = 1;
84 key.outputType.val.int32 = 1;
86 case Datatype::UINT32:
87 key.outputType.val.uint32 = 1;
90 key.outputType.val.int64 = 1;
93 key.outputType.val.F16 = 1;
96 key.outputType.val.F32 = 1;
103 void ParamsKey::EnableAllOutputDataType()
105 key.outputType.raw = 0xffffffff;
108 void ParamsKey::EnableInputWeightsType(WeightsType wt)
112 case WeightsType::F16:
113 key.inputWeightsType.val.F16 = 1;
115 case WeightsType::F32:
116 key.inputWeightsType.val.F32 = 1;
118 case WeightsType::INT8:
119 key.inputWeightsType.val.int8 = 1;
126 void ParamsKey::EnableAllInputWeightsType()
128 key.inputWeightsType.raw = 0xffffffff;
131 void ParamsKey::EnableOutputWeightsType(WeightsType wt)
135 case WeightsType::F16:
136 key.outputWeightsType.val.F16 = 1;
138 case WeightsType::F32:
139 key.outputWeightsType.val.F32 = 1;
141 case WeightsType::INT8:
142 key.outputWeightsType.val.int8 = 1;
149 void ParamsKey::EnableAllOutputWeightsType()
151 key.outputWeightsType.raw = 0xffffffff;
154 void ParamsKey::EnableLRNMode(LRNMode m)
158 case LRNMode::ACROSS_CHANNEL:
159 key.restrict.val.dedicated.norm.across = 1;
161 case LRNMode::WITHIN_CHANNEL:
162 key.restrict.val.dedicated.norm.within = 1;
169 void ParamsKey::EnableLookUpTableAxis(LookUpTableAxis m)
173 case kernel_selector::LookUpTableAxis::BATCH:
174 key.restrict.val.dedicated.lookt.axisBatch = 1;
176 case kernel_selector::LookUpTableAxis::FEATURE:
177 key.restrict.val.dedicated.lookt.axisFeature = 1;
179 case kernel_selector::LookUpTableAxis::X:
180 key.restrict.val.dedicated.lookt.axisX = 1;
182 case kernel_selector::LookUpTableAxis::Y:
183 key.restrict.val.dedicated.lookt.axisY = 1;
185 case kernel_selector::LookUpTableAxis::XYF:
186 key.restrict.val.dedicated.lookt.axisXYF = 1;
193 void ParamsKey::EnableNormalizeMode(NormalizeMode m)
197 case NormalizeMode::ACROSS_SPATIAL:
198 key.restrict.val.dedicated.norm.across = 1;
200 case NormalizeMode::WITHIN_SPATIAL:
201 key.restrict.val.dedicated.norm.within = 1;
208 void ParamsKey::EnableMVNMode(MVNMode m)
212 case MVNMode::ACROSS_CHANNELS:
213 key.restrict.val.dedicated.mvn.across = 1;
215 case MVNMode::WITHIN_CHANNELS:
216 key.restrict.val.dedicated.mvn.within = 1;
223 void ParamsKey::EnableMVNNormalizeVariance()
225 key.restrict.val.dedicated.mvn.normalize_variance = 1;
228 void ParamsKey::EnableLRNKernelDividerMode(KernelDividerMode m)
232 case KernelDividerMode::FIXED:
233 key.restrict.val.dedicated.norm.fixedKenrelDivider = 1;
235 case KernelDividerMode::DYNAMIC:
236 key.restrict.val.dedicated.norm.dynamicKenrelDivider = 1;
243 void ParamsKey::EnablePoolKernelDividerMode(KernelDividerMode m)
247 case KernelDividerMode::FIXED:
248 key.restrict.val.dedicated.pooling.fixedKenrelDivider = 1;
250 case KernelDividerMode::DYNAMIC:
251 key.restrict.val.dedicated.pooling.dynamicKenrelDivider = 1;
253 case KernelDividerMode::DYNAMIC_WITH_PADDING:
254 key.restrict.val.dedicated.pooling.dynamicKenrelDividerWithPadding = 1;
261 void ParamsKey::EnablePoolType(PoolType t)
266 key.restrict.val.dedicated.pooling.max = 1;
269 key.restrict.val.dedicated.pooling.avg = 1;
271 case PoolType::MAX_WITH_ARGMAX:
272 key.restrict.val.dedicated.pooling.max_with_argmax = 1;
274 case PoolType::BILINEAR:
275 key.restrict.val.dedicated.pooling.bilinear = 1;
281 void ParamsKey::EnablePoolRemainder(PoolRemainder r)
285 case PoolRemainder::FLOOR:
286 key.restrict.val.dedicated.pooling.floor = 1;
288 case PoolRemainder::CEIL:
289 key.restrict.val.dedicated.pooling.ceil = 1;
296 void ParamsKey::EnableSoftmaxDim(SoftmaxDim d)
301 key.restrict.val.dedicated.softmax.dimX = 1;
304 key.restrict.val.dedicated.softmax.dimY = 1;
306 case SoftmaxDim::FEATURE:
307 key.restrict.val.dedicated.softmax.dimFeature = 1;
314 void ParamsKey::EnableConcatAxis(ConcatAxis a)
319 key.restrict.val.dedicated.concat.axisX = 1;
322 key.restrict.val.dedicated.concat.axisY = 1;
324 case ConcatAxis::FEATURE:
325 key.restrict.val.dedicated.concat.axisFeature = 1;
327 case ConcatAxis::BATCH:
328 key.restrict.val.dedicated.concat.axisBatch = 1;
335 void ParamsKey::EnableUpSamplingSampleType(SampleType a)
339 case SampleType::NEAREST:
340 key.restrict.val.dedicated.upsample.nearest = 1;
342 case SampleType::BILINEAR:
343 key.restrict.val.dedicated.upsample.bilinear = 1;
350 void ParamsKey::EnableFusedConvEltwEltwiseStride()
352 key.restrict.val.dedicated.fused_conv_eltw.stride = 1;
355 void ParamsKey::EnableEltwiseStride()
357 key.restrict.val.dedicated.eltwise.stride = 1;
360 void ParamsKey::EnableArgMaxMinAxis(ArgMaxMinAxis a)
364 case ArgMaxMinAxis::X:
365 key.restrict.val.dedicated.argm.axisX = 1;
367 case ArgMaxMinAxis::Y:
368 key.restrict.val.dedicated.argm.axisY = 1;
370 case ArgMaxMinAxis::FEATURE:
371 key.restrict.val.dedicated.argm.axisFeature = 1;
373 case ArgMaxMinAxis::BATCH:
374 key.restrict.val.dedicated.argm.axisBatch = 1;
376 case ArgMaxMinAxis::XYF:
377 key.restrict.val.dedicated.argm.axisXYF = 1;
384 void ParamsKey::EnableIndexSelectAxis(IndexSelectAxis a)
388 case IndexSelectAxis::X:
389 key.restrict.val.dedicated.idxsel.axisX = 1;
391 case IndexSelectAxis::Y:
392 key.restrict.val.dedicated.idxsel.axisY = 1;
394 case IndexSelectAxis::FEATURE:
395 key.restrict.val.dedicated.idxsel.axisFeature = 1;
397 case IndexSelectAxis::BATCH:
398 key.restrict.val.dedicated.idxsel.axisBatch = 1;
405 void ParamsKey::EnableLookUpTableIndicesFormat(Datatype a)
407 if (a == Datatype::F32)
408 key.restrict.val.dedicated.lookt.indicesF32 = 1;
410 key.restrict.val.dedicated.lookt.indicesOther = 1;
413 void ParamsKey::EnableFusedConvEltwiseRWOutOpt()
415 key.restrict.val.dedicated.fused_conv_eltw.rw_out_opt = 1;
418 bool ParamsKey::Support(const ParamsKey& k) const
420 if (!((key.restrict.raw & k.key.restrict.raw) == k.key.restrict.raw)) // check if this kernel supports this params
422 if (!((key.machineInfo.raw & k.key.machineInfo.raw) == key.machineInfo.raw)) // check if machine supports this kernel
424 if (!((key.inputType.raw & k.key.inputType.raw) == k.key.inputType.raw))
426 if (!((key.outputType.raw & k.key.outputType.raw) == k.key.outputType.raw))
428 if (!((key.inputWeightsType.raw & k.key.inputWeightsType.raw) == k.key.inputWeightsType.raw))
430 if (!((key.outputWeightsType.raw & k.key.outputWeightsType.raw) == k.key.outputWeightsType.raw))
432 if (!((key.inputLayout & k.key.inputLayout) != 0 || key.inputLayout == k.key.inputLayout))
434 if (!((key.outputLayout & k.key.outputLayout) != 0 || key.outputLayout == k.key.outputLayout))
436 if (!((key.weightsInputLayout & k.key.weightsInputLayout) != 0 || key.weightsInputLayout == k.key.weightsInputLayout))
438 if (!((key.weightsOutputLayout & k.key.weightsOutputLayout) != 0 || key.weightsOutputLayout == k.key.weightsOutputLayout))
444 ParamsKey ParamsKey::Merge(const ParamsKey& k) const
447 ret.key.restrict.raw = key.restrict.raw | k.key.restrict.raw;
448 ret.key.machineInfo.raw = key.machineInfo.raw | k.key.machineInfo.raw;
449 ret.key.inputType.raw = key.inputType.raw | k.key.inputType.raw;
450 ret.key.outputType.raw = key.outputType.raw | k.key.outputType.raw;
451 ret.key.inputWeightsType.raw = key.inputWeightsType.raw | k.key.inputWeightsType.raw;
452 ret.key.outputWeightsType.raw = key.outputWeightsType.raw | k.key.outputWeightsType.raw;
453 ret.key.inputLayout = key.inputLayout | k.key.inputLayout;
454 ret.key.outputLayout = key.outputLayout | k.key.outputLayout;
455 ret.key.weightsInputLayout = key.weightsInputLayout | k.key.weightsInputLayout;
456 ret.key.weightsOutputLayout = key.weightsOutputLayout | k.key.weightsOutputLayout;
459 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
461 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
462 ParamsKey Params::GetParamsKey() const
466 if (engineInfo.bSubGroupSupport)
471 if (engineInfo.bSubGroupShortSupport)
473 k.EnableSubGroupShort();
479 std::string Params::to_string() const
482 s << toString(kType);
486 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
488 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
489 ParamsKey optional_params::GetSupportedKey() const
493 for (auto l : inputLayouts)
495 k.EnableInputLayout(l);
498 for (auto l : outputLayouts)
500 k.EnableOutputLayout(l);
506 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
508 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
509 ParamsKey base_params::GetParamsKey() const
511 ParamsKey k = Params::GetParamsKey();
513 bool bBatching = false;
514 bool bPitches = false;
515 bool bOffests = false;
516 bool bDifferentTypes = false;
517 bool bFP16Used = (output.GetDType() == Datatype::F16);
519 for (const auto& i : inputs)
521 k.EnableInputDataType(i.GetDType());
522 k.EnableInputLayout(i.GetLayout());
524 bBatching |= (i.Batch().v > 1);
525 bPitches |= (i.PitchesDifferFromLogicalDims());
526 bOffests |= (i.GetFirstElementOffset() != 0);
527 bDifferentTypes |= (i.GetDType() != output.GetDType());
528 bFP16Used |= (i.GetDType() == Datatype::F16);
531 k.EnableOutputDataType(output.GetDType());
532 k.EnableOutputLayout(output.GetLayout());
540 output.PitchesDifferFromLogicalDims())
542 k.EnableTensorPitches();
547 k.EnableDifferentTypes();
551 output.GetFirstElementOffset() != 0)
553 k.EnableTensorOffset();
556 if (!engineInfo.bFP16Support &&
559 // I'm not sure it's the best idea, but we can live with it right now
560 k.EnableFP16Emulation();
571 std::string base_activation_params::to_string() const
574 s << "m" << m << "_n" << n << "_" << toString(function);
578 std::string base_params::to_string() const
581 s << Params::to_string() << "_";
582 s << activation.to_string() << "_";
584 for (auto input : inputs)
586 s << toString(input) << "_";
588 s << toString(output);