1 // Copyright (c) 2016-2020 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
#include "kernel_selector_params.h"
#include "kernel_selector_common.h"

#include <sstream>
#include <string>

#include <activation/activation_kernel_base.h>
24 namespace kernel_selector {
26 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
28 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
29 void ParamsKey::EnableInputDataType(Datatype dt) {
32 key.inputType.val.int8 = 1;
35 key.inputType.val.uint8 = 1;
38 key.inputType.val.int16 = 1;
40 case Datatype::UINT16:
41 key.inputType.val.uint16 = 1;
44 key.inputType.val.int32 = 1;
46 case Datatype::UINT32:
47 key.inputType.val.uint32 = 1;
50 key.inputType.val.int64 = 1;
53 key.inputType.val.F16 = 1;
56 key.inputType.val.F32 = 1;
58 case Datatype::BINARY:
59 key.inputType.val.binary = 1;
66 void ParamsKey::EnableAllInputDataType() { key.inputType.raw = 0xffffffff; }
68 void ParamsKey::EnableOutputDataType(Datatype dt) {
71 key.outputType.val.int8 = 1;
74 key.outputType.val.uint8 = 1;
77 key.outputType.val.int16 = 1;
79 case Datatype::UINT16:
80 key.outputType.val.uint16 = 1;
83 key.outputType.val.int32 = 1;
85 case Datatype::UINT32:
86 key.outputType.val.uint32 = 1;
89 key.outputType.val.int64 = 1;
92 key.outputType.val.F16 = 1;
95 key.outputType.val.F32 = 1;
97 case Datatype::BINARY:
98 key.outputType.val.binary = 1;
105 void ParamsKey::EnableAllOutputDataType() { key.outputType.raw = 0xffffffff; }
107 void ParamsKey::EnableInputWeightsType(WeightsType wt) {
109 case WeightsType::F16:
110 key.inputWeightsType.val.F16 = 1;
112 case WeightsType::F32:
113 key.inputWeightsType.val.F32 = 1;
115 case WeightsType::INT8:
116 key.inputWeightsType.val.int8 = 1;
118 case WeightsType::BINARY:
119 key.inputWeightsType.val.binary = 1;
126 void ParamsKey::EnableAllInputWeightsType() { key.inputWeightsType.raw = 0xffffffff; }
128 void ParamsKey::EnableOutputWeightsType(WeightsType wt) {
130 case WeightsType::F16:
131 key.outputWeightsType.val.F16 = 1;
133 case WeightsType::F32:
134 key.outputWeightsType.val.F32 = 1;
136 case WeightsType::INT8:
137 key.outputWeightsType.val.int8 = 1;
139 case WeightsType::BINARY:
140 key.outputWeightsType.val.binary = 1;
147 void ParamsKey::EnableAllOutputWeightsType() { key.outputWeightsType.raw = 0xffffffff; }
149 void ParamsKey::EnableLRNMode(LRNMode m) {
151 case LRNMode::ACROSS_CHANNEL:
152 key.restrict.val.dedicated.norm.across = 1;
154 case LRNMode::WITHIN_CHANNEL:
155 key.restrict.val.dedicated.norm.within = 1;
162 void ParamsKey::EnableLookUpTableAxis(LookUpTableAxis m) {
164 case kernel_selector::LookUpTableAxis::BATCH:
165 key.restrict.val.dedicated.lookt.axisBatch = 1;
167 case kernel_selector::LookUpTableAxis::FEATURE:
168 key.restrict.val.dedicated.lookt.axisFeature = 1;
170 case kernel_selector::LookUpTableAxis::X:
171 key.restrict.val.dedicated.lookt.axisX = 1;
173 case kernel_selector::LookUpTableAxis::Y:
174 key.restrict.val.dedicated.lookt.axisY = 1;
176 case kernel_selector::LookUpTableAxis::XYF:
177 key.restrict.val.dedicated.lookt.axisXYF = 1;
184 void ParamsKey::EnableNormalizeMode(NormalizeMode m) {
186 case NormalizeMode::ACROSS_SPATIAL:
187 key.restrict.val.dedicated.norm.across = 1;
189 case NormalizeMode::WITHIN_SPATIAL:
190 key.restrict.val.dedicated.norm.within = 1;
197 void ParamsKey::EnableMVNMode(MVNMode m) {
199 case MVNMode::ACROSS_CHANNELS:
200 key.restrict.val.dedicated.mvn.across = 1;
202 case MVNMode::WITHIN_CHANNELS:
203 key.restrict.val.dedicated.mvn.within = 1;
210 void ParamsKey::EnableMVNNormalizeVariance() { key.restrict.val.dedicated.mvn.normalize_variance = 1; }
212 void ParamsKey::EnableLRNKernelDividerMode(KernelDividerMode m) {
214 case KernelDividerMode::FIXED:
215 key.restrict.val.dedicated.norm.fixedKenrelDivider = 1;
217 case KernelDividerMode::DYNAMIC:
218 key.restrict.val.dedicated.norm.dynamicKenrelDivider = 1;
225 void ParamsKey::EnablePoolKernelDividerMode(KernelDividerMode m) {
227 case KernelDividerMode::FIXED:
228 key.restrict.val.dedicated.pooling.fixedKenrelDivider = 1;
230 case KernelDividerMode::DYNAMIC:
231 key.restrict.val.dedicated.pooling.dynamicKenrelDivider = 1;
233 case KernelDividerMode::DYNAMIC_WITH_PADDING:
234 key.restrict.val.dedicated.pooling.dynamicKenrelDividerWithPadding = 1;
241 void ParamsKey::EnablePoolType(PoolType t) {
244 key.restrict.val.dedicated.pooling.max = 1;
247 key.restrict.val.dedicated.pooling.avg = 1;
249 case PoolType::MAX_WITH_ARGMAX:
250 key.restrict.val.dedicated.pooling.max_with_argmax = 1;
252 case PoolType::BILINEAR:
253 key.restrict.val.dedicated.pooling.bilinear = 1;
255 case PoolType::DEFORMABLE_BILINEAR:
256 key.restrict.val.dedicated.pooling.deformable_bilinear = 1;
263 void ParamsKey::EnablePoolRemainder(PoolRemainder r) {
265 case PoolRemainder::FLOOR:
266 key.restrict.val.dedicated.pooling.floor = 1;
268 case PoolRemainder::CEIL:
269 key.restrict.val.dedicated.pooling.ceil = 1;
276 void ParamsKey::EnableSoftmaxDim(SoftmaxDim d) {
279 key.restrict.val.dedicated.softmax.dimX = 1;
282 key.restrict.val.dedicated.softmax.dimY = 1;
284 case SoftmaxDim::FEATURE:
285 key.restrict.val.dedicated.softmax.dimFeature = 1;
292 void ParamsKey::EnableConcatAxis(ConcatAxis a) {
295 key.restrict.val.dedicated.concat.axisX = 1;
298 key.restrict.val.dedicated.concat.axisY = 1;
301 key.restrict.val.dedicated.concat.axisZ = 1;
304 key.restrict.val.dedicated.concat.axisW = 1;
306 case ConcatAxis::FEATURE:
307 key.restrict.val.dedicated.concat.axisFeature = 1;
309 case ConcatAxis::BATCH:
310 key.restrict.val.dedicated.concat.axisBatch = 1;
317 void ParamsKey::EnableReampleType(ResampleType a) {
319 case ResampleType::NEAREST_NEIGHBOR:
320 key.restrict.val.dedicated.resample.nearest_neighbor = 1;
322 case ResampleType::CAFFE_BILINEAR_INTERP:
323 key.restrict.val.dedicated.resample.caffe_bilinear_interp = 1;
325 case ResampleType::BILINEAR_INTERP:
326 key.restrict.val.dedicated.resample.bilinear_interp = 1;
328 case ResampleType::CUBIC:
329 key.restrict.val.dedicated.resample.cubic = 1;
331 case ResampleType::LINEAR_ONNX:
332 key.restrict.val.dedicated.resample.linear_onnx = 1;
339 void ParamsKey::EnableFusedConvEltwEltwiseStride() { key.restrict.val.dedicated.fused_conv_eltw.stride = 1; }
341 void ParamsKey::EnableEltwiseStride() { key.restrict.val.dedicated.eltwise.stride = 1; }
343 void ParamsKey::EnableArgMaxMinAxis(ArgMaxMinAxis a) {
345 case ArgMaxMinAxis::X:
346 key.restrict.val.dedicated.argm.axisX = 1;
348 case ArgMaxMinAxis::Y:
349 key.restrict.val.dedicated.argm.axisY = 1;
351 case ArgMaxMinAxis::Z:
352 key.restrict.val.dedicated.argm.axisZ = 1;
354 case ArgMaxMinAxis::FEATURE:
355 key.restrict.val.dedicated.argm.axisFeature = 1;
357 case ArgMaxMinAxis::BATCH:
358 key.restrict.val.dedicated.argm.axisBatch = 1;
360 case ArgMaxMinAxis::XYF:
361 key.restrict.val.dedicated.argm.axisXYF = 1;
368 void ParamsKey::EnableIndexSelectAxis(IndexSelectAxis a) {
370 case IndexSelectAxis::X:
371 key.restrict.val.dedicated.idxsel.axisX = 1;
373 case IndexSelectAxis::Y:
374 key.restrict.val.dedicated.idxsel.axisY = 1;
376 case IndexSelectAxis::FEATURE:
377 key.restrict.val.dedicated.idxsel.axisFeature = 1;
379 case IndexSelectAxis::BATCH:
380 key.restrict.val.dedicated.idxsel.axisBatch = 1;
387 void ParamsKey::EnableLookUpTableIndicesFormat(Datatype a) {
388 if (a == Datatype::F32)
389 key.restrict.val.dedicated.lookt.indicesF32 = 1;
391 key.restrict.val.dedicated.lookt.indicesOther = 1;
394 void ParamsKey::EnableFusedConvEltwiseRWOutOpt() { key.restrict.val.dedicated.fused_conv_eltw.rw_out_opt = 1; }
395 void ParamsKey::EnableFusedConvEltwDepthToSpaceFusing() { key.restrict.val.dedicated.fused_conv_eltw.depth_to_space_fused = 1; }
398 void ParamsKey::EnableQuantization(QuantizationType q) {
400 case QuantizationType::NONE:
402 case QuantizationType::SYMMETRIC:
403 key.restrict.val.sym_quantization = 1;
405 case QuantizationType::ASYMMETRIC_DATA:
406 key.restrict.val.asym_d_quantization = 1;
408 case QuantizationType::ASYMMETRIC_WEIGHTS:
409 key.restrict.val.asym_w_quantization = 1;
411 case QuantizationType::ASYMMETRIC_DATA_AND_WEIGHTS:
412 key.restrict.val.asym_d_quantization = 1;
413 key.restrict.val.asym_w_quantization = 1;
420 bool ParamsKey::Support(const ParamsKey& k) const {
421 if (!((key.restrict.raw & k.key.restrict.raw) == k.key.restrict.raw)) // check if this kernel supports this params
423 if (!((key.machineInfo.raw & k.key.machineInfo.raw) ==
424 key.machineInfo.raw)) // check if machine supports this kernel
426 if (!((key.inputType.raw & k.key.inputType.raw) == k.key.inputType.raw))
428 if (!((key.outputType.raw & k.key.outputType.raw) == k.key.outputType.raw))
430 if (!((key.inputWeightsType.raw & k.key.inputWeightsType.raw) == k.key.inputWeightsType.raw))
432 if (!((key.outputWeightsType.raw & k.key.outputWeightsType.raw) == k.key.outputWeightsType.raw))
434 if (!((key.inputLayout & k.key.inputLayout) != 0 || key.inputLayout == k.key.inputLayout))
436 if (!((key.outputLayout & k.key.outputLayout) != 0 || key.outputLayout == k.key.outputLayout))
438 if (!((key.weightsInputLayout & k.key.weightsInputLayout) != 0 ||
439 key.weightsInputLayout == k.key.weightsInputLayout))
441 if (!((key.weightsOutputLayout & k.key.weightsOutputLayout) != 0 ||
442 key.weightsOutputLayout == k.key.weightsOutputLayout))
448 ParamsKey ParamsKey::Merge(const ParamsKey& k) const {
450 ret.key.restrict.raw = key.restrict.raw | k.key.restrict.raw;
451 ret.key.machineInfo.raw = key.machineInfo.raw | k.key.machineInfo.raw;
452 ret.key.inputType.raw = key.inputType.raw | k.key.inputType.raw;
453 ret.key.outputType.raw = key.outputType.raw | k.key.outputType.raw;
454 ret.key.inputWeightsType.raw = key.inputWeightsType.raw | k.key.inputWeightsType.raw;
455 ret.key.outputWeightsType.raw = key.outputWeightsType.raw | k.key.outputWeightsType.raw;
456 ret.key.inputLayout = key.inputLayout | k.key.inputLayout;
457 ret.key.outputLayout = key.outputLayout | k.key.outputLayout;
458 ret.key.weightsInputLayout = key.weightsInputLayout | k.key.weightsInputLayout;
459 ret.key.weightsOutputLayout = key.weightsOutputLayout | k.key.weightsOutputLayout;
462 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
464 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
465 ParamsKey Params::GetParamsKey() const {
468 if (engineInfo.bSubGroupSupport) {
472 if (engineInfo.bSubGroupShortSupport) {
473 k.EnableSubGroupShort();
476 if (engineInfo.bSubGroupCharSupport) {
477 k.EnableSubGroupChar();
483 std::string Params::to_string() const {
485 s << toString(kType);
489 std::string Params::to_cache_string_v2() const {
493 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
495 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
496 ParamsKey optional_params::GetSupportedKey() const {
499 for (auto l : inputLayouts) {
500 k.EnableInputLayout(l);
503 for (auto l : outputLayouts) {
504 k.EnableOutputLayout(l);
510 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
512 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
513 ParamsKey base_params::GetParamsKey() const {
514 ParamsKey k = Params::GetParamsKey();
516 bool bBatching = false;
517 bool bPitches = false;
518 bool bOffests = false;
519 bool bDifferentTypes = false;
520 bool bFP16Used = (output.GetDType() == Datatype::F16);
522 for (const auto& i : inputs) {
523 k.EnableInputDataType(i.GetDType());
524 k.EnableInputLayout(i.GetLayout());
526 bBatching |= (i.Batch().v > 1);
527 bPitches |= (i.PitchesDifferFromLogicalDims());
528 bOffests |= (i.GetFirstElementOffset() != 0);
529 bDifferentTypes |= (i.GetDType() != output.GetDType());
530 bFP16Used |= (i.GetDType() == Datatype::F16);
533 k.EnableOutputDataType(output.GetDType());
534 k.EnableOutputLayout(output.GetLayout());
540 if (bPitches || output.PitchesDifferFromLogicalDims()) {
541 k.EnableTensorPitches();
544 if (bDifferentTypes) {
545 k.EnableDifferentTypes();
548 if (bOffests || output.GetFirstElementOffset() != 0) {
549 k.EnableTensorOffset();
552 if (!engineInfo.bFP16Support && bFP16Used) {
553 // I'm not sure it's the best idea, but we can live with it right now
554 k.EnableFP16Emulation();
560 std::string base_activation_params::to_string() const {
562 s << "m" << m << "_n" << n << "_" << toString(function);
566 std::string base_params::to_string() const {
569 // WA to reuse old tuning cache. Code below must be replace with the following line once new cache file is merged.
570 // s << Params::to_string() << "_";
571 auto type_string = toString(kType);
572 if (kType == KernelType::FUSED_CONV_ELTWISE) {
575 s << type_string << "_";
577 // TODO: Remove activation from the string and recollect cache file
578 bool found_fused_activation = false;
579 if (!activations.empty()) {
580 s << activations[0].to_string() << "_";
581 found_fused_activation = true;
584 if (activations.empty() && !fused_ops.empty()) {
585 if (fused_ops[0].GetType() == KernelType::ACTIVATION) {
586 auto activation_params = fused_ops[0].GetOpParams<activation_fuse_params>()->param;
587 s << activation_params.to_string() << "_";
588 found_fused_activation = true;
592 if (!found_fused_activation) {
593 s << "m" << 0.f << "_n" << 0.f << "_" << toString(ActivationFunction::NONE) << "_";
596 for (auto input : inputs) {
597 s << toString(input) << "_";
599 s << toString(output);
604 std::string base_params::to_cache_string_v2() const {
607 for (auto input : inputs) {
608 s << toString_v2(input) << ";";
610 s << toString_v2(output);
615 } // namespace kernel_selector