Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / kernel_selector_params.cpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "kernel_selector_params.h"
18 #include "kernel_selector_common.h"
19 #include <sstream>
20
21 namespace kernel_selector {
22
23     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
24     // ParamsKey
25     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
26     void ParamsKey::EnableInputDataType(Datatype dt)
27     {
28         switch (dt)
29         {
30         case Datatype::INT8:
31             key.inputType.val.int8 = 1;
32             break;
33         case Datatype::UINT8:
34             key.inputType.val.uint8 = 1;
35             break;
36         case Datatype::INT16:
37             key.inputType.val.int16 = 1;
38             break;
39         case Datatype::UINT16:
40             key.inputType.val.uint16 = 1;
41             break;
42         case Datatype::INT32:
43             key.inputType.val.int32 = 1;
44             break;
45         case Datatype::UINT32:
46             key.inputType.val.uint32 = 1;
47             break;
48         case Datatype::INT64:
49             key.inputType.val.int64 = 1;
50             break;
51         case Datatype::F16:
52             key.inputType.val.F16 = 1;
53             break;
54         case Datatype::F32:
55             key.inputType.val.F32 = 1;
56             break;
57         default:
58             break;
59         }
60     }
61
62     void ParamsKey::EnableAllInputDataType()
63     {
64         key.inputType.raw = 0xffffffff;
65     }
66
67     void ParamsKey::EnableOutputDataType(Datatype dt)
68     {
69         switch (dt)
70         {
71         case Datatype::INT8:
72             key.outputType.val.int8 = 1;
73             break;
74         case Datatype::UINT8:
75             key.outputType.val.uint8 = 1;
76             break;
77         case Datatype::INT16:
78             key.outputType.val.int16 = 1;
79             break;
80         case Datatype::UINT16:
81             key.outputType.val.uint16 = 1;
82             break;
83         case Datatype::INT32:
84             key.outputType.val.int32 = 1;
85             break;
86         case Datatype::UINT32:
87             key.outputType.val.uint32 = 1;
88             break;
89         case Datatype::INT64:
90             key.outputType.val.int64 = 1;
91             break;
92         case Datatype::F16:
93             key.outputType.val.F16 = 1;
94             break;
95         case Datatype::F32:
96             key.outputType.val.F32 = 1;
97             break;
98         default:
99             break;
100         }
101     }
102
103     void ParamsKey::EnableAllOutputDataType()
104     {
105         key.outputType.raw = 0xffffffff;
106     }
107
108     void ParamsKey::EnableInputWeightsType(WeightsType wt)
109     {
110         switch (wt)
111         {
112         case WeightsType::F16:
113             key.inputWeightsType.val.F16 = 1;
114             break;
115         case WeightsType::F32:
116             key.inputWeightsType.val.F32 = 1;
117             break;
118         case WeightsType::INT8:
119             key.inputWeightsType.val.int8 = 1;
120             break;
121         default:
122             break;
123         }
124     }
125
126     void ParamsKey::EnableAllInputWeightsType()
127     {
128         key.inputWeightsType.raw = 0xffffffff;
129     }
130
131     void ParamsKey::EnableOutputWeightsType(WeightsType wt)
132     {
133         switch (wt)
134         {
135         case WeightsType::F16:
136             key.outputWeightsType.val.F16 = 1;
137             break;
138         case WeightsType::F32:
139             key.outputWeightsType.val.F32 = 1;
140             break;
141         case WeightsType::INT8:
142             key.outputWeightsType.val.int8 = 1;
143             break;
144         default:
145             break;
146         }
147     }
148
149     void ParamsKey::EnableAllOutputWeightsType()
150     {
151         key.outputWeightsType.raw = 0xffffffff;
152     }
153
154     void ParamsKey::EnableLRNMode(LRNMode m)
155     {
156         switch (m)
157         {
158         case LRNMode::ACROSS_CHANNEL:
159             key.restrict.val.dedicated.norm.across = 1;
160             break;
161         case LRNMode::WITHIN_CHANNEL:
162             key.restrict.val.dedicated.norm.within = 1;
163             break;
164         default:
165             break;
166         }
167     }
168
169     void ParamsKey::EnableLookUpTableAxis(LookUpTableAxis m)
170     {
171         switch (m)
172         {
173         case kernel_selector::LookUpTableAxis::BATCH:
174             key.restrict.val.dedicated.lookt.axisBatch = 1;
175             break;
176         case kernel_selector::LookUpTableAxis::FEATURE:
177             key.restrict.val.dedicated.lookt.axisFeature = 1;
178             break;
179         case kernel_selector::LookUpTableAxis::X:
180             key.restrict.val.dedicated.lookt.axisX = 1;
181             break;
182         case kernel_selector::LookUpTableAxis::Y:
183             key.restrict.val.dedicated.lookt.axisY = 1;
184             break;
185         case kernel_selector::LookUpTableAxis::XYF:
186             key.restrict.val.dedicated.lookt.axisXYF = 1;
187             break;
188         default:
189             break;
190         }
191     }
192
193     void ParamsKey::EnableNormalizeMode(NormalizeMode m)
194     {
195         switch (m)
196         {
197         case NormalizeMode::ACROSS_SPATIAL:
198             key.restrict.val.dedicated.norm.across = 1;
199             break;
200         case NormalizeMode::WITHIN_SPATIAL:
201             key.restrict.val.dedicated.norm.within = 1;
202             break;
203         default:
204             break;
205         }
206     }
207
208     void ParamsKey::EnableMVNMode(MVNMode m)
209     {
210         switch (m)
211         {
212         case MVNMode::ACROSS_CHANNELS:
213             key.restrict.val.dedicated.mvn.across = 1;
214             break;
215         case MVNMode::WITHIN_CHANNELS:
216             key.restrict.val.dedicated.mvn.within = 1;
217             break;
218         default:
219             break;
220         }
221     }
222
223     void ParamsKey::EnableMVNNormalizeVariance()
224     {
225         key.restrict.val.dedicated.mvn.normalize_variance = 1;
226     }
227
228     void ParamsKey::EnableLRNKernelDividerMode(KernelDividerMode m)
229     {
230         switch (m)
231         {
232         case KernelDividerMode::FIXED:
233             key.restrict.val.dedicated.norm.fixedKenrelDivider = 1;
234             break;
235         case KernelDividerMode::DYNAMIC:
236             key.restrict.val.dedicated.norm.dynamicKenrelDivider = 1;
237             break;
238         default:
239             break;
240         }
241     }
242
243     void ParamsKey::EnablePoolKernelDividerMode(KernelDividerMode m)
244     {
245         switch (m)
246         {
247         case KernelDividerMode::FIXED:
248             key.restrict.val.dedicated.pooling.fixedKenrelDivider = 1;
249             break;
250         case KernelDividerMode::DYNAMIC:
251             key.restrict.val.dedicated.pooling.dynamicKenrelDivider = 1;
252             break;
253         case KernelDividerMode::DYNAMIC_WITH_PADDING:
254             key.restrict.val.dedicated.pooling.dynamicKenrelDividerWithPadding = 1;
255             break;
256         default:
257             break;
258         }
259     }
260
261     void ParamsKey::EnablePoolType(PoolType t)
262     {
263         switch (t)
264         {
265         case PoolType::MAX:
266             key.restrict.val.dedicated.pooling.max = 1;
267             break;
268         case PoolType::AVG:
269             key.restrict.val.dedicated.pooling.avg = 1;
270             break;
271         case PoolType::MAX_WITH_ARGMAX:
272             key.restrict.val.dedicated.pooling.max_with_argmax = 1;
273             break;
274         case PoolType::BILINEAR:
275             key.restrict.val.dedicated.pooling.bilinear = 1;
276         default:
277             break;
278         }
279     }
280
281     void ParamsKey::EnablePoolRemainder(PoolRemainder r)
282     {
283         switch (r)
284         {
285         case PoolRemainder::FLOOR:
286             key.restrict.val.dedicated.pooling.floor = 1;
287             break;
288         case PoolRemainder::CEIL:
289             key.restrict.val.dedicated.pooling.ceil = 1;
290             break;
291         default:
292             break;
293         }
294     }
295
296     void ParamsKey::EnableSoftmaxDim(SoftmaxDim d)
297     {
298         switch (d)
299         {
300         case SoftmaxDim::X:
301             key.restrict.val.dedicated.softmax.dimX = 1;
302             break;
303         case SoftmaxDim::Y:
304             key.restrict.val.dedicated.softmax.dimY = 1;
305             break;
306         case SoftmaxDim::FEATURE:
307             key.restrict.val.dedicated.softmax.dimFeature = 1;
308             break;
309         default:
310             break;
311         }
312     }
313
314     void ParamsKey::EnableConcatAxis(ConcatAxis a)
315     {
316         switch (a)
317         {
318         case ConcatAxis::X:
319             key.restrict.val.dedicated.concat.axisX = 1;
320             break;
321         case ConcatAxis::Y:
322             key.restrict.val.dedicated.concat.axisY = 1;
323             break;
324         case ConcatAxis::FEATURE:
325             key.restrict.val.dedicated.concat.axisFeature = 1;
326             break;
327         case ConcatAxis::BATCH:
328             key.restrict.val.dedicated.concat.axisBatch = 1;
329             break;
330         default:
331             break;
332         }
333     }
334
335     void ParamsKey::EnableUpSamplingSampleType(SampleType a)
336     {
337         switch (a)
338         {
339         case SampleType::NEAREST:
340             key.restrict.val.dedicated.upsample.nearest = 1;
341             break;
342         case SampleType::BILINEAR:
343             key.restrict.val.dedicated.upsample.bilinear = 1;
344             break;
345         default:
346             break;
347         }
348     }
349
350     void ParamsKey::EnableFusedConvEltwEltwiseStride()
351     {
352         key.restrict.val.dedicated.fused_conv_eltw.stride = 1;
353     }
354
355     void ParamsKey::EnableEltwiseStride()
356     {
357         key.restrict.val.dedicated.eltwise.stride = 1;
358     }
359
360     void ParamsKey::EnableArgMaxMinAxis(ArgMaxMinAxis a)
361     {
362         switch (a)
363         {
364         case ArgMaxMinAxis::X:
365             key.restrict.val.dedicated.argm.axisX = 1;
366             break;
367         case ArgMaxMinAxis::Y:
368             key.restrict.val.dedicated.argm.axisY = 1;
369             break;
370         case ArgMaxMinAxis::FEATURE:
371             key.restrict.val.dedicated.argm.axisFeature = 1;
372             break;
373         case ArgMaxMinAxis::BATCH:
374             key.restrict.val.dedicated.argm.axisBatch = 1;
375             break;
376         case ArgMaxMinAxis::XYF:
377             key.restrict.val.dedicated.argm.axisXYF = 1;
378             break;
379         default:
380             break;
381         }
382     }
383
384     void ParamsKey::EnableIndexSelectAxis(IndexSelectAxis a)
385     {
386         switch (a)
387         {
388         case IndexSelectAxis::X:
389             key.restrict.val.dedicated.idxsel.axisX = 1;
390             break;
391         case IndexSelectAxis::Y:
392             key.restrict.val.dedicated.idxsel.axisY = 1;
393             break;
394         case IndexSelectAxis::FEATURE:
395             key.restrict.val.dedicated.idxsel.axisFeature = 1;
396             break;
397         case IndexSelectAxis::BATCH:
398             key.restrict.val.dedicated.idxsel.axisBatch = 1;
399             break;
400         default:
401             break;
402         }
403     }
404
405     void ParamsKey::EnableLookUpTableIndicesFormat(Datatype a)
406     {
407         if (a == Datatype::F32)
408             key.restrict.val.dedicated.lookt.indicesF32 = 1;
409         else
410             key.restrict.val.dedicated.lookt.indicesOther = 1;
411     }
412
413     void ParamsKey::EnableFusedConvEltwiseRWOutOpt()
414     {
415         key.restrict.val.dedicated.fused_conv_eltw.rw_out_opt = 1;
416     }
417
418     bool ParamsKey::Support(const ParamsKey& k) const
419     {
420         if (!((key.restrict.raw & k.key.restrict.raw) == k.key.restrict.raw)) // check if this kernel supports this params
421             return false;
422         if (!((key.machineInfo.raw & k.key.machineInfo.raw) == key.machineInfo.raw)) // check if machine supports this kernel
423             return false;
424         if (!((key.inputType.raw & k.key.inputType.raw) == k.key.inputType.raw))
425             return false;
426         if (!((key.outputType.raw & k.key.outputType.raw) == k.key.outputType.raw))
427             return false;
428         if (!((key.inputWeightsType.raw & k.key.inputWeightsType.raw) == k.key.inputWeightsType.raw))
429             return false;
430         if (!((key.outputWeightsType.raw & k.key.outputWeightsType.raw) == k.key.outputWeightsType.raw))
431             return false;
432         if (!((key.inputLayout & k.key.inputLayout) != 0 || key.inputLayout == k.key.inputLayout))
433             return false;
434         if (!((key.outputLayout & k.key.outputLayout) != 0 || key.outputLayout == k.key.outputLayout))
435             return false;
436         if (!((key.weightsInputLayout & k.key.weightsInputLayout) != 0 || key.weightsInputLayout == k.key.weightsInputLayout))
437             return false;
438         if (!((key.weightsOutputLayout & k.key.weightsOutputLayout) != 0 || key.weightsOutputLayout == k.key.weightsOutputLayout))
439             return false;
440
441         return true;
442     }
443
444     ParamsKey ParamsKey::Merge(const ParamsKey& k) const
445     {
446         ParamsKey ret;
447         ret.key.restrict.raw = key.restrict.raw | k.key.restrict.raw;
448         ret.key.machineInfo.raw = key.machineInfo.raw | k.key.machineInfo.raw;
449         ret.key.inputType.raw = key.inputType.raw | k.key.inputType.raw;
450         ret.key.outputType.raw = key.outputType.raw | k.key.outputType.raw;
451         ret.key.inputWeightsType.raw = key.inputWeightsType.raw | k.key.inputWeightsType.raw;
452         ret.key.outputWeightsType.raw = key.outputWeightsType.raw | k.key.outputWeightsType.raw;
453         ret.key.inputLayout = key.inputLayout | k.key.inputLayout;
454         ret.key.outputLayout = key.outputLayout | k.key.outputLayout;
455         ret.key.weightsInputLayout = key.weightsInputLayout | k.key.weightsInputLayout;
456         ret.key.weightsOutputLayout = key.weightsOutputLayout | k.key.weightsOutputLayout;
457         return ret;
458     }
459     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
460     // Params
461     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
462     ParamsKey Params::GetParamsKey() const
463     {
464         ParamsKey k;
465
466         if (engineInfo.bSubGroupSupport)
467         {
468             k.EnableSubGroup();
469         }
470
471         if (engineInfo.bSubGroupShortSupport)
472         {
473             k.EnableSubGroupShort();
474         }
475
476         return k;
477     }
478
479     std::string Params::to_string() const
480     {
481         std::stringstream s;
482         s << toString(kType);
483         return s.str();
484     }
485
486     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
487     // optional_params
488     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
489     ParamsKey optional_params::GetSupportedKey() const
490     {
491         ParamsKey k;
492
493         for (auto l : inputLayouts)
494         {
495             k.EnableInputLayout(l);
496         }
497
498         for (auto l : outputLayouts)
499         {
500             k.EnableOutputLayout(l);
501         }
502
503         return k;
504     }
505
506     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
507     // base_params
508     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
509     ParamsKey base_params::GetParamsKey() const
510     {
511         ParamsKey k = Params::GetParamsKey();
512
513         bool bBatching = false;
514         bool bPitches = false;
515         bool bOffests = false;
516         bool bDifferentTypes = false;
517         bool bFP16Used = (output.GetDType() == Datatype::F16);
518
519         for (const auto& i : inputs)
520         {
521             k.EnableInputDataType(i.GetDType());
522             k.EnableInputLayout(i.GetLayout());
523
524             bBatching |= (i.Batch().v > 1);
525             bPitches |= (i.PitchesDifferFromLogicalDims());
526             bOffests |= (i.GetFirstElementOffset() != 0);
527             bDifferentTypes |= (i.GetDType() != output.GetDType());
528             bFP16Used |= (i.GetDType() == Datatype::F16);
529         }
530
531         k.EnableOutputDataType(output.GetDType());
532         k.EnableOutputLayout(output.GetLayout());
533
534         if (bBatching)
535         {
536             k.EnableBatching();
537         }
538
539         if (bPitches ||
540             output.PitchesDifferFromLogicalDims())
541         {
542             k.EnableTensorPitches();
543         }
544
545         if (bDifferentTypes)
546         {
547             k.EnableDifferentTypes();
548         }
549
550         if (bOffests ||
551             output.GetFirstElementOffset() != 0)
552         {
553             k.EnableTensorOffset();
554         }
555
556         if (!engineInfo.bFP16Support &&
557             bFP16Used)
558         {
559             // I'm not sure it's the best idea, but we can live with it right now
560             k.EnableFP16Emulation();
561         }
562
563         if (gradient)
564         {
565             k.EnableGradient();
566         }
567
568         return k;
569     }
570
571     std::string base_activation_params::to_string() const
572     {
573         std::stringstream s;
574         s << "m" << m << "_n" << n << "_" << toString(function);
575         return s.str();
576     }
577
578     std::string base_params::to_string() const
579     {
580         std::stringstream s;
581         s << Params::to_string() << "_";
582         s << activation.to_string() << "_";
583
584         for (auto input : inputs)
585         {
586             s << toString(input) << "_";
587         }
588         s << toString(output);
589
590         return s.str();
591     }
592 }