[IE CLDNN] Fix linear_onnx Interpolate selection (#2769)
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / kernel_selector_params.cpp
1 // Copyright (c) 2016-2020 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15
16 #include "kernel_selector_params.h"
17 #include "kernel_selector_common.h"
18 #include <sstream>
19 #include <string>
20
21 #include <activation/activation_kernel_base.h>
22 #include "jitter.h"
23
24 namespace kernel_selector {
25
26 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
27 // ParamsKey
28 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
29 void ParamsKey::EnableInputDataType(Datatype dt) {
30     switch (dt) {
31         case Datatype::INT8:
32             key.inputType.val.int8 = 1;
33             break;
34         case Datatype::UINT8:
35             key.inputType.val.uint8 = 1;
36             break;
37         case Datatype::INT16:
38             key.inputType.val.int16 = 1;
39             break;
40         case Datatype::UINT16:
41             key.inputType.val.uint16 = 1;
42             break;
43         case Datatype::INT32:
44             key.inputType.val.int32 = 1;
45             break;
46         case Datatype::UINT32:
47             key.inputType.val.uint32 = 1;
48             break;
49         case Datatype::INT64:
50             key.inputType.val.int64 = 1;
51             break;
52         case Datatype::F16:
53             key.inputType.val.F16 = 1;
54             break;
55         case Datatype::F32:
56             key.inputType.val.F32 = 1;
57             break;
58         case Datatype::BINARY:
59             key.inputType.val.binary = 1;
60             break;
61         default:
62             break;
63     }
64 }
65
66 void ParamsKey::EnableAllInputDataType() { key.inputType.raw = 0xffffffff; }
67
68 void ParamsKey::EnableOutputDataType(Datatype dt) {
69     switch (dt) {
70         case Datatype::INT8:
71             key.outputType.val.int8 = 1;
72             break;
73         case Datatype::UINT8:
74             key.outputType.val.uint8 = 1;
75             break;
76         case Datatype::INT16:
77             key.outputType.val.int16 = 1;
78             break;
79         case Datatype::UINT16:
80             key.outputType.val.uint16 = 1;
81             break;
82         case Datatype::INT32:
83             key.outputType.val.int32 = 1;
84             break;
85         case Datatype::UINT32:
86             key.outputType.val.uint32 = 1;
87             break;
88         case Datatype::INT64:
89             key.outputType.val.int64 = 1;
90             break;
91         case Datatype::F16:
92             key.outputType.val.F16 = 1;
93             break;
94         case Datatype::F32:
95             key.outputType.val.F32 = 1;
96             break;
97         case Datatype::BINARY:
98             key.outputType.val.binary = 1;
99             break;
100         default:
101             break;
102     }
103 }
104
105 void ParamsKey::EnableAllOutputDataType() { key.outputType.raw = 0xffffffff; }
106
107 void ParamsKey::EnableInputWeightsType(WeightsType wt) {
108     switch (wt) {
109         case WeightsType::F16:
110             key.inputWeightsType.val.F16 = 1;
111             break;
112         case WeightsType::F32:
113             key.inputWeightsType.val.F32 = 1;
114             break;
115         case WeightsType::INT8:
116             key.inputWeightsType.val.int8 = 1;
117             break;
118         case WeightsType::BINARY:
119             key.inputWeightsType.val.binary = 1;
120             break;
121         default:
122             break;
123     }
124 }
125
126 void ParamsKey::EnableAllInputWeightsType() { key.inputWeightsType.raw = 0xffffffff; }
127
128 void ParamsKey::EnableOutputWeightsType(WeightsType wt) {
129     switch (wt) {
130         case WeightsType::F16:
131             key.outputWeightsType.val.F16 = 1;
132             break;
133         case WeightsType::F32:
134             key.outputWeightsType.val.F32 = 1;
135             break;
136         case WeightsType::INT8:
137             key.outputWeightsType.val.int8 = 1;
138             break;
139         case WeightsType::BINARY:
140             key.outputWeightsType.val.binary = 1;
141             break;
142         default:
143             break;
144     }
145 }
146
147 void ParamsKey::EnableAllOutputWeightsType() { key.outputWeightsType.raw = 0xffffffff; }
148
149 void ParamsKey::EnableLRNMode(LRNMode m) {
150     switch (m) {
151         case LRNMode::ACROSS_CHANNEL:
152             key.restrict.val.dedicated.norm.across = 1;
153             break;
154         case LRNMode::WITHIN_CHANNEL:
155             key.restrict.val.dedicated.norm.within = 1;
156             break;
157         default:
158             break;
159     }
160 }
161
162 void ParamsKey::EnableLookUpTableAxis(LookUpTableAxis m) {
163     switch (m) {
164         case kernel_selector::LookUpTableAxis::BATCH:
165             key.restrict.val.dedicated.lookt.axisBatch = 1;
166             break;
167         case kernel_selector::LookUpTableAxis::FEATURE:
168             key.restrict.val.dedicated.lookt.axisFeature = 1;
169             break;
170         case kernel_selector::LookUpTableAxis::X:
171             key.restrict.val.dedicated.lookt.axisX = 1;
172             break;
173         case kernel_selector::LookUpTableAxis::Y:
174             key.restrict.val.dedicated.lookt.axisY = 1;
175             break;
176         case kernel_selector::LookUpTableAxis::XYF:
177             key.restrict.val.dedicated.lookt.axisXYF = 1;
178             break;
179         default:
180             break;
181     }
182 }
183
184 void ParamsKey::EnableNormalizeMode(NormalizeMode m) {
185     switch (m) {
186         case NormalizeMode::ACROSS_SPATIAL:
187             key.restrict.val.dedicated.norm.across = 1;
188             break;
189         case NormalizeMode::WITHIN_SPATIAL:
190             key.restrict.val.dedicated.norm.within = 1;
191             break;
192         default:
193             break;
194     }
195 }
196
197 void ParamsKey::EnableMVNMode(MVNMode m) {
198     switch (m) {
199         case MVNMode::ACROSS_CHANNELS:
200             key.restrict.val.dedicated.mvn.across = 1;
201             break;
202         case MVNMode::WITHIN_CHANNELS:
203             key.restrict.val.dedicated.mvn.within = 1;
204             break;
205         default:
206             break;
207     }
208 }
209
210 void ParamsKey::EnableMVNNormalizeVariance() { key.restrict.val.dedicated.mvn.normalize_variance = 1; }
211
212 void ParamsKey::EnableLRNKernelDividerMode(KernelDividerMode m) {
213     switch (m) {
214         case KernelDividerMode::FIXED:
215             key.restrict.val.dedicated.norm.fixedKenrelDivider = 1;
216             break;
217         case KernelDividerMode::DYNAMIC:
218             key.restrict.val.dedicated.norm.dynamicKenrelDivider = 1;
219             break;
220         default:
221             break;
222     }
223 }
224
225 void ParamsKey::EnablePoolKernelDividerMode(KernelDividerMode m) {
226     switch (m) {
227         case KernelDividerMode::FIXED:
228             key.restrict.val.dedicated.pooling.fixedKenrelDivider = 1;
229             break;
230         case KernelDividerMode::DYNAMIC:
231             key.restrict.val.dedicated.pooling.dynamicKenrelDivider = 1;
232             break;
233         case KernelDividerMode::DYNAMIC_WITH_PADDING:
234             key.restrict.val.dedicated.pooling.dynamicKenrelDividerWithPadding = 1;
235             break;
236         default:
237             break;
238     }
239 }
240
241 void ParamsKey::EnablePoolType(PoolType t) {
242     switch (t) {
243         case PoolType::MAX:
244             key.restrict.val.dedicated.pooling.max = 1;
245             break;
246         case PoolType::AVG:
247             key.restrict.val.dedicated.pooling.avg = 1;
248             break;
249         case PoolType::MAX_WITH_ARGMAX:
250             key.restrict.val.dedicated.pooling.max_with_argmax = 1;
251             break;
252         case PoolType::BILINEAR:
253             key.restrict.val.dedicated.pooling.bilinear = 1;
254             break;
255         case PoolType::DEFORMABLE_BILINEAR:
256             key.restrict.val.dedicated.pooling.deformable_bilinear = 1;
257             break;
258         default:
259             break;
260     }
261 }
262
263 void ParamsKey::EnablePoolRemainder(PoolRemainder r) {
264     switch (r) {
265         case PoolRemainder::FLOOR:
266             key.restrict.val.dedicated.pooling.floor = 1;
267             break;
268         case PoolRemainder::CEIL:
269             key.restrict.val.dedicated.pooling.ceil = 1;
270             break;
271         default:
272             break;
273     }
274 }
275
276 void ParamsKey::EnableSoftmaxDim(SoftmaxDim d) {
277     switch (d) {
278         case SoftmaxDim::X:
279             key.restrict.val.dedicated.softmax.dimX = 1;
280             break;
281         case SoftmaxDim::Y:
282             key.restrict.val.dedicated.softmax.dimY = 1;
283             break;
284         case SoftmaxDim::FEATURE:
285             key.restrict.val.dedicated.softmax.dimFeature = 1;
286             break;
287         default:
288             break;
289     }
290 }
291
292 void ParamsKey::EnableConcatAxis(ConcatAxis a) {
293     switch (a) {
294         case ConcatAxis::X:
295             key.restrict.val.dedicated.concat.axisX = 1;
296             break;
297         case ConcatAxis::Y:
298             key.restrict.val.dedicated.concat.axisY = 1;
299             break;
300         case ConcatAxis::Z:
301             key.restrict.val.dedicated.concat.axisZ = 1;
302             break;
303         case ConcatAxis::W:
304             key.restrict.val.dedicated.concat.axisW = 1;
305             break;
306         case ConcatAxis::FEATURE:
307             key.restrict.val.dedicated.concat.axisFeature = 1;
308             break;
309         case ConcatAxis::BATCH:
310             key.restrict.val.dedicated.concat.axisBatch = 1;
311             break;
312         default:
313             break;
314     }
315 }
316
317 void ParamsKey::EnableReampleType(ResampleType a) {
318     switch (a) {
319         case ResampleType::NEAREST_NEIGHBOR:
320             key.restrict.val.dedicated.resample.nearest_neighbor = 1;
321             break;
322         case ResampleType::CAFFE_BILINEAR_INTERP:
323             key.restrict.val.dedicated.resample.caffe_bilinear_interp = 1;
324             break;
325         case ResampleType::BILINEAR_INTERP:
326             key.restrict.val.dedicated.resample.bilinear_interp = 1;
327             break;
328         case ResampleType::CUBIC:
329             key.restrict.val.dedicated.resample.cubic = 1;
330             break;
331         case ResampleType::LINEAR_ONNX:
332             key.restrict.val.dedicated.resample.linear_onnx = 1;
333             break;
334         default:
335             break;
336     }
337 }
338
339 void ParamsKey::EnableFusedConvEltwEltwiseStride() { key.restrict.val.dedicated.fused_conv_eltw.stride = 1; }
340
341 void ParamsKey::EnableEltwiseStride() { key.restrict.val.dedicated.eltwise.stride = 1; }
342
343 void ParamsKey::EnableArgMaxMinAxis(ArgMaxMinAxis a) {
344     switch (a) {
345         case ArgMaxMinAxis::X:
346             key.restrict.val.dedicated.argm.axisX = 1;
347             break;
348         case ArgMaxMinAxis::Y:
349             key.restrict.val.dedicated.argm.axisY = 1;
350             break;
351         case ArgMaxMinAxis::Z:
352             key.restrict.val.dedicated.argm.axisZ = 1;
353             break;
354         case ArgMaxMinAxis::FEATURE:
355             key.restrict.val.dedicated.argm.axisFeature = 1;
356             break;
357         case ArgMaxMinAxis::BATCH:
358             key.restrict.val.dedicated.argm.axisBatch = 1;
359             break;
360         case ArgMaxMinAxis::XYF:
361             key.restrict.val.dedicated.argm.axisXYF = 1;
362             break;
363         default:
364             break;
365     }
366 }
367
368 void ParamsKey::EnableIndexSelectAxis(IndexSelectAxis a) {
369     switch (a) {
370         case IndexSelectAxis::X:
371             key.restrict.val.dedicated.idxsel.axisX = 1;
372             break;
373         case IndexSelectAxis::Y:
374             key.restrict.val.dedicated.idxsel.axisY = 1;
375             break;
376         case IndexSelectAxis::FEATURE:
377             key.restrict.val.dedicated.idxsel.axisFeature = 1;
378             break;
379         case IndexSelectAxis::BATCH:
380             key.restrict.val.dedicated.idxsel.axisBatch = 1;
381             break;
382         default:
383             break;
384     }
385 }
386
387 void ParamsKey::EnableLookUpTableIndicesFormat(Datatype a) {
388     if (a == Datatype::F32)
389         key.restrict.val.dedicated.lookt.indicesF32 = 1;
390     else
391         key.restrict.val.dedicated.lookt.indicesOther = 1;
392 }
393
394 void ParamsKey::EnableFusedConvEltwiseRWOutOpt() { key.restrict.val.dedicated.fused_conv_eltw.rw_out_opt = 1; }
395 void ParamsKey::EnableFusedConvEltwDepthToSpaceFusing() { key.restrict.val.dedicated.fused_conv_eltw.depth_to_space_fused = 1; }
396
397
398 void ParamsKey::EnableQuantization(QuantizationType q) {
399     switch (q) {
400         case QuantizationType::NONE:
401             break;
402         case QuantizationType::SYMMETRIC:
403             key.restrict.val.sym_quantization = 1;
404             break;
405         case QuantizationType::ASYMMETRIC_DATA:
406             key.restrict.val.asym_d_quantization = 1;
407             break;
408         case QuantizationType::ASYMMETRIC_WEIGHTS:
409             key.restrict.val.asym_w_quantization = 1;
410             break;
411         case QuantizationType::ASYMMETRIC_DATA_AND_WEIGHTS:
412             key.restrict.val.asym_d_quantization = 1;
413             key.restrict.val.asym_w_quantization = 1;
414             break;
415         default:
416             break;
417     }
418 }
419
420 bool ParamsKey::Support(const ParamsKey& k) const {
421     if (!((key.restrict.raw & k.key.restrict.raw) == k.key.restrict.raw))  // check if this kernel supports this params
422         return false;
423     if (!((key.machineInfo.raw & k.key.machineInfo.raw) ==
424           key.machineInfo.raw))  // check if machine supports this kernel
425         return false;
426     if (!((key.inputType.raw & k.key.inputType.raw) == k.key.inputType.raw))
427         return false;
428     if (!((key.outputType.raw & k.key.outputType.raw) == k.key.outputType.raw))
429         return false;
430     if (!((key.inputWeightsType.raw & k.key.inputWeightsType.raw) == k.key.inputWeightsType.raw))
431         return false;
432     if (!((key.outputWeightsType.raw & k.key.outputWeightsType.raw) == k.key.outputWeightsType.raw))
433         return false;
434     if (!((key.inputLayout & k.key.inputLayout) != 0 || key.inputLayout == k.key.inputLayout))
435         return false;
436     if (!((key.outputLayout & k.key.outputLayout) != 0 || key.outputLayout == k.key.outputLayout))
437         return false;
438     if (!((key.weightsInputLayout & k.key.weightsInputLayout) != 0 ||
439           key.weightsInputLayout == k.key.weightsInputLayout))
440         return false;
441     if (!((key.weightsOutputLayout & k.key.weightsOutputLayout) != 0 ||
442           key.weightsOutputLayout == k.key.weightsOutputLayout))
443         return false;
444
445     return true;
446 }
447
448 ParamsKey ParamsKey::Merge(const ParamsKey& k) const {
449     ParamsKey ret;
450     ret.key.restrict.raw = key.restrict.raw | k.key.restrict.raw;
451     ret.key.machineInfo.raw = key.machineInfo.raw | k.key.machineInfo.raw;
452     ret.key.inputType.raw = key.inputType.raw | k.key.inputType.raw;
453     ret.key.outputType.raw = key.outputType.raw | k.key.outputType.raw;
454     ret.key.inputWeightsType.raw = key.inputWeightsType.raw | k.key.inputWeightsType.raw;
455     ret.key.outputWeightsType.raw = key.outputWeightsType.raw | k.key.outputWeightsType.raw;
456     ret.key.inputLayout = key.inputLayout | k.key.inputLayout;
457     ret.key.outputLayout = key.outputLayout | k.key.outputLayout;
458     ret.key.weightsInputLayout = key.weightsInputLayout | k.key.weightsInputLayout;
459     ret.key.weightsOutputLayout = key.weightsOutputLayout | k.key.weightsOutputLayout;
460     return ret;
461 }
462 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
463 // Params
464 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
465 ParamsKey Params::GetParamsKey() const {
466     ParamsKey k;
467
468     if (engineInfo.bSubGroupSupport) {
469         k.EnableSubGroup();
470     }
471
472     if (engineInfo.bSubGroupShortSupport) {
473         k.EnableSubGroupShort();
474     }
475
476     if (engineInfo.bSubGroupCharSupport) {
477         k.EnableSubGroupChar();
478     }
479
480     return k;
481 }
482
483 std::string Params::to_string() const {
484     std::stringstream s;
485     s << toString(kType);
486     return s.str();
487 }
488
489 std::string Params::to_cache_string_v2() const {
490     return "";
491 }
492
493 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
494 // optional_params
495 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
496 ParamsKey optional_params::GetSupportedKey() const {
497     ParamsKey k;
498
499     for (auto l : inputLayouts) {
500         k.EnableInputLayout(l);
501     }
502
503     for (auto l : outputLayouts) {
504         k.EnableOutputLayout(l);
505     }
506
507     return k;
508 }
509
510 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
511 // base_params
512 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
513 ParamsKey base_params::GetParamsKey() const {
514     ParamsKey k = Params::GetParamsKey();
515
516     bool bBatching = false;
517     bool bPitches = false;
518     bool bOffests = false;
519     bool bDifferentTypes = false;
520     bool bFP16Used = (output.GetDType() == Datatype::F16);
521
522     for (const auto& i : inputs) {
523         k.EnableInputDataType(i.GetDType());
524         k.EnableInputLayout(i.GetLayout());
525
526         bBatching |= (i.Batch().v > 1);
527         bPitches |= (i.PitchesDifferFromLogicalDims());
528         bOffests |= (i.GetFirstElementOffset() != 0);
529         bDifferentTypes |= (i.GetDType() != output.GetDType());
530         bFP16Used |= (i.GetDType() == Datatype::F16);
531     }
532
533     k.EnableOutputDataType(output.GetDType());
534     k.EnableOutputLayout(output.GetLayout());
535
536     if (bBatching) {
537         k.EnableBatching();
538     }
539
540     if (bPitches || output.PitchesDifferFromLogicalDims()) {
541         k.EnableTensorPitches();
542     }
543
544     if (bDifferentTypes) {
545         k.EnableDifferentTypes();
546     }
547
548     if (bOffests || output.GetFirstElementOffset() != 0) {
549         k.EnableTensorOffset();
550     }
551
552     if (!engineInfo.bFP16Support && bFP16Used) {
553         // I'm not sure it's the best idea, but we can live with it right now
554         k.EnableFP16Emulation();
555     }
556
557     return k;
558 }
559
560 std::string base_activation_params::to_string() const {
561     std::stringstream s;
562     s << "m" << m << "_n" << n << "_" << toString(function);
563     return s.str();
564 }
565
566 std::string base_params::to_string() const {
567     std::stringstream s;
568
569     // WA to reuse old tuning cache. Code below must be replace with the following line once new cache file is merged.
570     // s << Params::to_string() << "_";
571     auto type_string = toString(kType);
572     if (kType == KernelType::FUSED_CONV_ELTWISE) {
573         type_string = "";
574     }
575     s << type_string << "_";
576
577     // TODO: Remove activation from the string and recollect cache file
578     bool found_fused_activation = false;
579     if (!activations.empty()) {
580         s << activations[0].to_string() << "_";
581         found_fused_activation = true;
582     }
583
584     if (activations.empty() && !fused_ops.empty()) {
585         if (fused_ops[0].GetType() == KernelType::ACTIVATION) {
586             auto activation_params = fused_ops[0].GetOpParams<activation_fuse_params>()->param;
587             s << activation_params.to_string() << "_";
588             found_fused_activation = true;
589         }
590     }
591
592     if (!found_fused_activation) {
593         s << "m" << 0.f << "_n" << 0.f << "_" << toString(ActivationFunction::NONE) << "_";
594     }
595
596     for (auto input : inputs) {
597         s << toString(input) << "_";
598     }
599     s << toString(output);
600
601     return s.str();
602 }
603
604 std::string base_params::to_cache_string_v2() const {
605     std::stringstream s;
606
607     for (auto input : inputs) {
608         s << toString_v2(input) << ";";
609     }
610     s << toString_v2(output);
611
612     return s.str();
613 }
614
615 }  // namespace kernel_selector