[IE CLDNN] Fix linear_onnx Interpolate selection (#2769)
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / kernel_selector_params.h
1 /*
2 // Copyright (c) 2016-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #pragma once
18
19 #include <string>
20 #include <memory>
21 #include <cstddef>
22 #include <limits>
23 #include "common_types.h"
24 #include "tensor_type.h"
25 #include "document.h"
26 #include <vector>
27 #include <utility>
28 #include <bitset>
29
30 namespace kernel_selector {
31 using DataTensor = Tensor::DataTensor;
32 using WeightsTensor = Tensor::WeightsTensor;
33 using DataLayout = Tensor::DataLayout;
34 using WeightsLayout = Tensor::WeightsLayout;
35 using MultiDataTensor = std::vector<DataTensor>;
36 using DataBitField = std::bitset<DataLayout::DataLayoutCount>;
37 using WightsBitField = std::bitset<WeightsLayout::WeightsLayoutCount>;
38
39 class JitConstants;
40 class TuningCache;
41
42 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
43 // fuse_params
44 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
45 struct fuse_params {
46     virtual ~fuse_params() {}
47
48     KernelType GetType() const { return kType; }
49 protected:
50     explicit fuse_params(KernelType kt) : kType(kt) {}
51     KernelType kType;
52 };
53
54 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
55 // ParamsKey
56 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
57 class ParamsKey {
58 public:
59     ParamsKey() {
60         key.restrict.raw = 0;
61         key.enableTuning = 1;
62         key.machineInfo.raw = 0;
63         key.inputType.raw = 0;
64         key.outputType.raw = 0;
65         key.inputWeightsType.raw = 0;
66         key.outputWeightsType.raw = 0;
67         key.inputLayout = 0;
68         key.outputLayout = 0;
69         key.weightsInputLayout = 0;
70         key.weightsOutputLayout = 0;
71     }
72
73     struct Key {
74         union restrict_t {
75             struct val_t {
76                 uint32_t different_types : 1;
77                 uint32_t different_input_weights_types : 1;
78                 uint32_t offset : 1;
79                 uint32_t pitches : 1;
80                 uint32_t batching : 1;
81                 uint32_t biasPerFeatureMap : 1;
82                 uint32_t biasPerOutput : 1;
83                 uint32_t nonBias : 1;
84                 uint32_t activationAdditionalParamsAsInput : 1;
85                 uint32_t FP16Emulation : 1;
86                 uint32_t momentum : 1;
87                 uint32_t quantization : 1;
88                 uint32_t sym_quantization : 1;
89                 uint32_t asym_w_quantization : 1;
90                 uint32_t asym_d_quantization : 1;
91
92                 union dedicated_t {
93                     struct lookt_t {
94                         uint32_t axisX : 1;
95                         uint32_t axisY : 1;
96                         uint32_t axisFeature : 1;
97                         uint32_t axisBatch : 1;
98                         uint32_t axisXYF : 1;
99                         uint32_t indicesF32 : 1;
100                         uint32_t indicesOther : 1;
101                     } lookt;
102                     struct argm_t {
103                         uint32_t axisX : 1;
104                         uint32_t axisY : 1;
105                         uint32_t axisZ : 1;
106                         uint32_t axisFeature : 1;
107                         uint32_t axisBatch : 1;
108                         uint32_t axisXYF : 1;
109                     } argm;
110                     struct idxsel_t {
111                         uint32_t axisX : 1;
112                         uint32_t axisY : 1;
113                         uint32_t axisFeature : 1;
114                         uint32_t axisBatch : 1;
115                     } idxsel;
116                     struct norm_t {
117                         uint32_t across : 1;
118                         uint32_t within : 1;
119                         uint32_t fixedKenrelDivider : 1;
120                         uint32_t dynamicKenrelDivider : 1;
121                     } norm;
122                     struct mvn_t {
123                         uint32_t across : 1;
124                         uint32_t within : 1;
125                         uint32_t normalize_variance : 1;
126                     } mvn;
127                     struct pooling_t {
128                         uint32_t max : 1;
129                         uint32_t avg : 1;
130                         uint32_t floor : 1;
131                         uint32_t max_with_argmax : 1;
132                         uint32_t ceil : 1;
133                         uint32_t bilinear : 1;
134                         uint32_t deformable_bilinear : 1;
135                         uint32_t fixedKenrelDivider : 1;
136                         uint32_t dynamicKenrelDivider : 1;
137                         uint32_t dynamicKenrelDividerWithPadding : 1;
138                         uint32_t position_sensitive : 1;
139                     } pooling;
140                     struct conv_t {
141                         uint32_t split : 1;
142                         uint32_t dilation : 1;
143                         uint32_t depthwise_separable_opt : 1;
144                         uint32_t local : 1;
145                         uint32_t grouped : 1;
146                         uint32_t deformable : 1;
147                     } conv;
148                     struct fc_t {
149                     } fc;
150                     struct softmax_t {
151                         uint32_t dimX : 1;
152                         uint32_t dimY : 1;
153                         uint32_t dimFeature : 1;
154                     } softmax;
155                     struct region_yolo_t {
156                         uint32_t dimX : 1;
157                         uint32_t dimY : 1;
158                         uint32_t dimFeature : 1;
159                         uint32_t coords : 1;
160                         uint32_t classes : 1;
161                         uint32_t num : 1;
162                     } region_yolo;
163                     struct reorg_yolo_t {
164                         uint32_t dimX : 1;
165                         uint32_t dimY : 1;
166                         uint32_t dimFeature : 1;
167                         uint32_t stride : 1;
168                     } reorg_yolo;
169                     struct concat_t {
170                         uint32_t axisX : 1;
171                         uint32_t axisY : 1;
172                         uint32_t axisZ : 1;
173                         uint32_t axisW : 1;
174                         uint32_t axisFeature : 1;
175                         uint32_t axisBatch : 1;
176                         uint32_t kernelPerInput : 1;
177                         uint32_t oneKernel : 1;
178                     } concat;
179                     struct upsample_t {
180                         uint32_t nearest_neighbor : 1;
181                         uint32_t caffe_bilinear_interp : 1;
182                         uint32_t bilinear_interp : 1;
183                         uint32_t cubic : 1;
184                         uint32_t linear_onnx : 1;
185                     } resample;
186                     struct reorder_t {
187                         uint32_t winograd : 1;
188                         uint32_t rotate : 1;
189                     } reorder;
190                     struct eltwise_t {
191                         uint32_t stride : 1;
192                         uint32_t broadcast : 1;
193                     } eltwise;
194                     struct lstm_gemm_t {
195                         uint32_t bias : 1;
196                         uint32_t hidden : 1;
197                     } lstm_gemm;
198                     struct lstm_dynamic_t {
199                         uint32_t last_hidden : 1;
200                         uint32_t last_cell : 1;
201                     } lstm_dynamic;
202                     struct lstm_elt_t {
203                         uint32_t cell : 1;
204                     } lstm_elt;
205                     struct fused_conv_eltw_t {
206                         // conv
207                         uint32_t split : 1;
208                         uint32_t dilation : 1;
209                         uint32_t depthwise_separable_opt : 1;
210                         uint32_t transposed : 1;
211                         uint32_t local : 1;
212                         uint32_t grouped : 1;
213                         // eltw
214                         uint32_t stride : 1;
215                         // fused conv eltw
216                         uint32_t rw_out_opt : 1;
217                         uint32_t depth_to_space_fused : 1;
218                     } fused_conv_eltw;
219                     struct quantize_t {
220                         uint32_t packed_binary_output : 1;
221                         uint32_t scale_shift_opt : 1;
222                     } quantize;
223                 } dedicated;
224             } val;
225             uint64_t raw;
226         } restrict;
227
228         union machine_info_t {
229             struct val_t {
230                 uint32_t subgroup : 1;
231                 uint32_t subgroupShort : 1;
232                 uint32_t subgroupChar : 1;
233             } val;
234             uint32_t raw;
235         } machineInfo;
236
237         static_assert(sizeof(restrict_t) == sizeof(uint64_t), "problem with union");
238
239         typedef union DataTypesKey_t {
240             struct val_t {
241                 uint32_t int8 : 1;
242                 uint32_t uint8 : 1;
243                 uint32_t int16 : 1;
244                 uint32_t uint16 : 1;
245                 uint32_t int32 : 1;
246                 uint32_t uint32 : 1;
247                 uint32_t int64 : 1;
248                 uint32_t F16 : 1;
249                 uint32_t F32 : 1;
250                 uint32_t binary : 1;
251             } val;
252             uint32_t raw;
253         } DataTypesKey;
254
255         uint32_t enableTuning;
256         DataTypesKey inputType;
257         DataTypesKey outputType;
258         DataTypesKey inputWeightsType;
259         DataTypesKey outputWeightsType;
260         DataBitField inputLayout;
261         DataBitField outputLayout;
262         WightsBitField weightsInputLayout;
263         WightsBitField weightsOutputLayout;
264     };
265
266     void EnableInputDataType(Datatype dt);
267     void EnableAllInputDataType();
268     void EnableOutputDataType(Datatype dt);
269     void EnableAllOutputDataType();
270     void EnableInputWeightsType(WeightsType wt);
271     void EnableAllInputWeightsType();
272     void EnableOutputWeightsType(WeightsType wt);
273     void EnableAllOutputWeightsType();
274     void EnableFP16Emulation() { key.restrict.val.FP16Emulation = 1; }
275     void EnableDifferentTypes() { key.restrict.val.different_types = 1; }
276     void EnableDifferentInputWeightsTypes() { key.restrict.val.different_input_weights_types = 1; }
277     void EnableInputLayout(DataLayout l) { key.inputLayout.set(static_cast<size_t>(l)); }
278     void EnableAllInputLayout() { key.inputLayout.set(); }
279     void EnableOutputLayout(DataLayout l) { key.outputLayout.set(static_cast<size_t>(l)); }
280     void EnableAllOutputLayout() { key.outputLayout.set(); }
281     void EnableInputWeightsLayout(WeightsLayout l) {
282         key.weightsInputLayout.set(static_cast<size_t>(l));
283     }
284     void EnableAllInputWeightsLayout() { key.weightsInputLayout.set(); }
285     void EnableOutputWeightsLayout(WeightsLayout l) {
286         key.weightsOutputLayout.set(static_cast<size_t>(l));
287     }
288     void EnableAllOutputWeightsLayout() { key.weightsOutputLayout.set(); }
289     void EnableTensorOffset() { key.restrict.val.offset = 1; }
290     void EnableTensorPitches() { key.restrict.val.pitches = 1; }
291     void EnableBatching() { key.restrict.val.batching = 1; }
292     void EnableSubGroup() { key.machineInfo.val.subgroup = 1; }
293     void EnableSubGroupShort() { key.machineInfo.val.subgroupShort = 1; }
294     void EnableSubGroupChar() { key.machineInfo.val.subgroupChar = 1; }
295     void EnableNonBiasTerm() { key.restrict.val.nonBias = 1; }
296     void EnableBiasPerFeature() { key.restrict.val.biasPerFeatureMap = 1; }
297     void EnableBiasPerOutput() { key.restrict.val.biasPerOutput = 1; }
298     void EnableActivationAdditionalParamsAsInput() { key.restrict.val.activationAdditionalParamsAsInput = 1; }
299     void EnableMomentum() { key.restrict.val.momentum = 1; }
300     void EnableLRNMode(LRNMode m);
301     void EnableLookUpTableAxis(LookUpTableAxis m);
302     void EnableNormalizeMode(NormalizeMode m);
303     void EnableMVNMode(MVNMode m);
304     void EnableMVNNormalizeVariance();
305     void EnableLRNKernelDividerMode(KernelDividerMode m);
306     void EnablePoolKernelDividerMode(KernelDividerMode m);
307     void EnablePoolType(PoolType t);
308     void EnablePoolRemainder(PoolRemainder r);
309     void EnableQuantization(QuantizationType q);
310     void EnablePositionSensitivePooling() { key.restrict.val.dedicated.pooling.position_sensitive = 1; }
311     void EnableSplitSupport() { key.restrict.val.dedicated.conv.split = 1; }
312     void EnableDilation() { key.restrict.val.dedicated.conv.dilation = 1; }
313     void EnableDepthwiseSeparableOpt() { key.restrict.val.dedicated.conv.depthwise_separable_opt = 1; }
314     void EnableLocalConvolution() { key.restrict.val.dedicated.conv.local = 1; }
315     void EnableGroupedConvolution() { key.restrict.val.dedicated.conv.grouped = 1; }
316     void EnableDeformableMode() { key.restrict.val.dedicated.conv.deformable = 1; }
317
318     void EnableFusedConvEltwSplitSupport() { key.restrict.val.dedicated.fused_conv_eltw.split = 1; }
319     void EnableFusedConvEltwDilation() { key.restrict.val.dedicated.fused_conv_eltw.dilation = 1; }
320     void EnableFusedConvEltwDepthwiseSeparableOpt() {
321         key.restrict.val.dedicated.fused_conv_eltw.depthwise_separable_opt = 1;
322     }
323     void EnableFusedConvEltwLocalConvolution() { key.restrict.val.dedicated.fused_conv_eltw.local = 1; }
324     void EnableFusedConvEltwGroupedConvolution() { key.restrict.val.dedicated.fused_conv_eltw.grouped = 1; }
325     void EnableFusedConvEltwTranspose() { key.restrict.val.dedicated.fused_conv_eltw.transposed = 1; }
326     void EnableFusedConvEltwEltwiseStride();
327     void EnableFusedConvEltwDepthToSpaceFusing();
328
329     void EnableQuantizePackedBinaryOutput() { key.restrict.val.dedicated.quantize.packed_binary_output = 1; }
330     void EnableQuantizeScaleShiftOpt() { key.restrict.val.dedicated.quantize.scale_shift_opt = 1; }
331
332     void EnableWinogradReorder() { key.restrict.val.dedicated.reorder.winograd = 1; }
333     void EnableRotateReorder() { key.restrict.val.dedicated.reorder.rotate = 1; }
334     void EnableSoftmaxDim(SoftmaxDim d);
335     void EnableConcatAxis(ConcatAxis a);
336     void EnableReampleType(ResampleType a);
337     void EnableEltwiseStride();
338     void EnableEltwiseBroadcast() { key.restrict.val.dedicated.eltwise.broadcast = 1; }
339
340     void EnableLSTMGEMMBias() { key.restrict.val.dedicated.lstm_gemm.bias = 1; }
341     void EnableLSTMGEMMHidden() { key.restrict.val.dedicated.lstm_gemm.hidden = 1; }
342     void EnableLSTMEltCell() { key.restrict.val.dedicated.lstm_elt.cell = 1; }
343     void EnableLSTMDyanmicOptionalHiddenOutput() { key.restrict.val.dedicated.lstm_dynamic.last_hidden = 1; }
344     void EnableLSTMDyanmicOptionalCellOutput() { key.restrict.val.dedicated.lstm_dynamic.last_cell = 1; }
345     void EnableConcatKernelPerInput() { key.restrict.val.dedicated.concat.kernelPerInput = 1; }
346     void DisableTuning() { key.enableTuning = 0; }
347     void EnableConcatOneKernel() { key.restrict.val.dedicated.concat.oneKernel = 1; }
348     void EnableArgMaxMinAxis(ArgMaxMinAxis a);
349     void EnableLookUpTableIndicesFormat(Datatype a);
350     void EnableIndexSelectAxis(IndexSelectAxis a);
351     void EnableFusedConvEltwiseRWOutOpt();
352     bool Support(const ParamsKey& k) const;
353     bool TuningSupport() const {
354         if (key.enableTuning == 1)
355             return true;
356         return false;
357     }
358     bool isEnabledDifferentInputWeightsTypes() const {
359         return key.restrict.val.different_input_weights_types ? true : false;
360     }
361     ParamsKey Merge(const ParamsKey& k) const;
362
363 private:
364     Key key;
365 };
366
367 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
368 // EngineInfo
369 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
370 struct EngineInfo {
371     bool bSubGroupSupport = false;
372     bool bSubGroupShortSupport = false;
373     bool bSubGroupCharSupport = false;
374     bool bFP16Support = false;
375     bool bFP64Support = false;
376     bool bImageSupport = false;
377     bool bIMADSupport = false;
378     bool bIMMADSupport = false;
379     bool bOptHintsSupport = false;
380     bool bLocalBlockIOSupport = false;
381     uint32_t computeUnitsCount = 0;
382     uint64_t maxWorkGroupSize = 0;
383     uint64_t maxLocalMemSize = 0;
384     uint64_t maxImage2dWidth = 0;
385     uint64_t maxImage2dHeight = 0;
386     std::string deviceId = "";
387     std::string driverVersion = "";
388     std::string hostVersion = "";
389     std::shared_ptr<TuningCache> deviceCache;
390 };
391
392 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
393 // Params
394 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
395 struct Params {
396     virtual ~Params() {}
397
398     KernelType GetType() const { return kType; }
399     virtual ParamsKey GetParamsKey() const;
400
401 protected:
402     Params(KernelType kt, const std::string& id) : kType(kt), layerID(id) {}
403     KernelType kType;
404
405 public:
406     std::string layerID;
407     std::string forceImplementation;
408     EngineInfo engineInfo;
409
410     virtual std::string to_string() const;
411     virtual std::string to_cache_string_v2() const;
412 };
413
414 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
415 // base_activation_params
416 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
417 struct base_activation_params {
418     ActivationFunction function = ActivationFunction::NONE;
419     float m = 1.f;
420     float n = 0.f;
421
422     base_activation_params() = default;
423     base_activation_params(const float m, const float n) : m(m), n(n) {}
424     base_activation_params(const ActivationFunction f, const float m, const float n) : function(f),
425                                                                                        m(m),
426                                                                                        n(n) {}
427
428     virtual std::string to_string() const;
429 };
430
431 struct FusedOpsConfiguration {
432     enum class LoadType {
433         LT_UNALIGNED = 0,
434         LT_ALIGNED_READ = 1,
435         FEATURE_SHUFFLE = 2
436     };
437
438     enum class BoundaryCheck {
439         DISABLED = 0,
440         ENABLED = 1
441     };
442
443     enum class IndexType {
444         TENSOR_COORD = 0,
445         LINEAR_OFFSET = 1
446     };
447
448     // Optional suffix that is added to each macro in the configuration.
449     std::string suffix;
450     // Indices to load additional data for a fused op.
451     std::vector<std::string> bfzyx_idx_order;
452     // Name of the input variable for the first fused op.
453     std::string input_var_name;
454     // Data type of the input
455     Datatype input_dt;
456     // Data type vector size of the input
457     size_t vec_size;
458     // Represents a channel in the input tensor that is loaded to the input variable
459     Tensor::DataChannelName vec_axis;
460     // Sets used load type - aligned or unaligned. Aligned load requires specific extensions and adjusted indices.
461     LoadType load_type;
462     // Defines if safe index function should be used for offset calculation
463     BoundaryCheck boundary_check;
464     // Defines how to treat indices array
465     IndexType index_type;
466     // Defines outer loops channels where fused op is called.
467     std::vector<Tensor::DataChannelName> loop_axes;
468     // If allow_for_partial_preload is false, then it's required that all fused_ops can be preloaded.
469     // If allow_for_partial_preload is true, then not preloaded fused_ops will be loaded in FUSED_OPS_CALC.
470     bool allow_for_partial_preload;
471     // Load index for shuffle fused op
472     std::string shuffle_var_name;
473
474     FusedOpsConfiguration(std::string suffix,
475                           std::vector<std::string> bfzyx_idx_order,
476                           std::string input_var_name,
477                           Datatype input_dt,
478                           size_t vec_size = 1,
479                           LoadType load_type = LoadType::LT_UNALIGNED,
480                           BoundaryCheck boundary_check = BoundaryCheck::ENABLED,
481                           IndexType index_type = IndexType::TENSOR_COORD,
482                           Tensor::DataChannelName vec_axis = Tensor::DataChannelName::COUNT,
483                           std::vector<Tensor::DataChannelName> loop_axes = {},
484                           bool allow_for_partial_preload = false,
485                           std::string shuffle_var_name = "")
486       : suffix(suffix)
487       , bfzyx_idx_order(bfzyx_idx_order)
488       , input_var_name(input_var_name)
489       , input_dt(input_dt)
490       , vec_size(vec_size)
491       , vec_axis(vec_axis)
492       , load_type(load_type)
493       , boundary_check(boundary_check)
494       , index_type(index_type)
495       , loop_axes(loop_axes)
496       , allow_for_partial_preload(allow_for_partial_preload)
497       , shuffle_var_name(shuffle_var_name) { }
498
499     FusedOpsConfiguration& SetVectorSize(size_t val) { vec_size = val; return *this; }
500     FusedOpsConfiguration& SetLoadType(LoadType val) { load_type = val; return *this; }
501     FusedOpsConfiguration& SetBoundaryCheck(BoundaryCheck val) { boundary_check = val; return *this; }
502     FusedOpsConfiguration& SetIndexType(IndexType val) { index_type = val; return *this; }
503     FusedOpsConfiguration& SetVectorAxis(Tensor::DataChannelName val) { vec_axis = val; return *this; }
504     FusedOpsConfiguration& SetLoopAxes(std::vector<Tensor::DataChannelName> val, bool partial_preload = false) {
505         loop_axes = std::move(val);
506         allow_for_partial_preload = partial_preload;
507         return *this; }
508     FusedOpsConfiguration& SetShuffleVarName(std::string val) { shuffle_var_name = val; return *this; }
509 };
510
511 // Instance of fused_operation_desc is added to fused_ops vector if a node has been fused to current one using program_impl::fuse_nodes
512 // method. In order to process fused ops following modifications should be done in a kernel:
513 // option 1 - using common generator:
514 //     - create FusedOpsConfiguration object that contains configuration for common code generator.
515 //       Multiple objects can be created if a kernel uses different data types at the same time. E.g. kernels that contains scalar and
516 //       vector branches that are chosen in runtime. To handle this case, create 2 configurations with different suffixes, like
517 //       "_SCALAR" and "_VEC" and then use generated macros accordingly.
518 //     - add jit constants returned by KernelBase::MakeFusedOpsJitConstants method to the kernel's constants.
519 //     - insert generated macros in the ocl code:
520 //       in kernel declaration:
521 //         #if HAS_FUSED_OPS_DECLS
522 //           FUSED_OPS_DECLS,
523 //         #endif
524 //       in kernel body:
525 //         #if HAS_FUSED_OPS
526 //           FUSED_OPS<OPTIONAL_SUFFIX>;
527 //           <SOME_VARIABLE> = FUSED_OPS_RESULT<OPTIONAL_SUFFIX>;
528 //         #endif
529 //   In this case common generator creates set of definitions for each op which are called sequentially in FUSED_OP<OPTIONAL_SUFFIX>
530 //   macro. Example:
531 //     #define FUSED_OPS
532 //       FUSED_OP0_LOAD_VEC
533 //       FUSED_OP0_ACTION_VEC
534 //       FUSED_OP1_LOAD_VEC
535 //       FUSED_OP1_ACTION_VEC
536 //     #define FUSED_OP0_LOAD_VEC
537 //       MAKE_VECTOR_TYPE(FUSED_OP_0_INPUT0_TYPE,2) activation0_data0 = UNIT_BLOCK_READ(activation0_input0,
538 //                                                                      FUSED_OP_0_INPUT0_GET_INDEX_SAFE(0,(f_block*16),0,0));
539 //     #define FUSED_OP0_ACTION_VEC
540 //       float2 dst_0 = dst;
541 //       dst_0 = ACTIVATION_FUSED_OP0_VEC(dst_0, ACTIVATION_PARAMS_FUSED_OP0_VEC);
542 //     #define FUSED_OP1_LOAD_VEC
543 //       MAKE_VECTOR_TYPE(FUSED_OP_1_INPUT0_TYPE,2) eltwise1_data0 = UNIT_BLOCK_READ2(eltwise1_input0,
544 //                                                                   FUSED_OP_1_INPUT0_GET_INDEX_SAFE(0,(f_block*16),y,x));
545 //     #define FUSED_OP1_ACTION_VEC
546 //       float2 dst_0_2 = convert_float2(eltwise1_data0) + convert_float2(dst_0);
547 //     #define FUSED_OPS_RESULT_VEC dst_0_2
548 // option 2 - using custom generator in a kernel. It can be used if performance is not optimal in the common one or to handle
549 //            some difficult cases that can't be unified. Custom processing of fused ops can be written absolutely independently
550 //            in a kernel, but to make it easier set of helper functions exist:
551 //     - KernelBase::MakeFusedOpsDeclsJitConstants that creates arguments for kernel declaration and macro for all tensors used in
552 //       a fused op (requires FusedOpsConfiguration instance).
553 //     - fused_operation_desc contains a bunch of methods to generate variable/pointer names, type conversions, data loads
554 //  If you need an example of custom code generation for fused ops, check BinaryConvolutionKernelGeneric::GetFusedPrimitivesJitConstants
555 //  method in binary_convolution_kernel_generic.cpp.
556 struct fused_operation_desc {
557     std::shared_ptr<fuse_params> op_params;
558     size_t dep_idx_start;
559     size_t dep_size;
560     MultiDataTensor tensors;
561     DataTensor output_tensor;
562     size_t op_id;
563
564     // Helper functions for operation generation
565     KernelType GetType() const { return op_params->GetType(); }
566     template<typename T>
567     std::shared_ptr<T> GetOpParams() const {
568         auto p = std::dynamic_pointer_cast<T>(op_params);
569         if (!p)
570             throw std::runtime_error("Invalid dynamic cast of fused operation parameters");
571
572         return p;
573     }
574 };
575
576 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
577 // base_params
578 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
579 struct base_params : public Params {
580     virtual ~base_params() {}
581
582     std::vector<base_activation_params> activations;
583     std::vector<fused_operation_desc> fused_ops = {};
584     MultiDataTensor inputs;
585     DataTensor output;
586
587     std::string to_string() const override;
588     std::string to_cache_string_v2() const override;
589     ParamsKey GetParamsKey() const override;
590
591 protected:
592     explicit base_params(KernelType kt) : Params(kt, ""), inputs(1) {}
593 };
594
595 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
596 // Auto tuner parameters
597 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
598 class KernelRunnerInterface;
599 struct TuningParams {
600     TuningMode mode;
601     std::string cacheFilePath;
602     std::shared_ptr<KernelRunnerInterface> runner;
603
604     TuningParams() : mode(TuningMode::TUNING_DISABLED), cacheFilePath(""), runner(nullptr) {}
605 };
606
607 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
608 // optional_params
609 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
610 struct optional_params {
611     virtual ~optional_params() {}
612
613     KernelType GetType() const { return kType; }
614
615     std::vector<DataLayout> inputLayouts;
616     std::vector<DataLayout> outputLayouts;
617
618     bool meaningfulKernelsNames = false;  // use layer name instead of internal kernel name
619     bool allowStaticInputReordering =
620         true;  // allow kernel to provide a kernel which reorder static data like weights/bias/tables...
621     bool allowInputReordering =
622         false;  // allow kernel to ask graph compiler to reorder the input data before executing its
623     bool allowOutputReordering =
624         false;  // allow kernel to ask graph compiler to reorder the output data before executing the next kernel
625
626     TuningParams tuningParams;
627
628     virtual ParamsKey GetSupportedKey() const;
629
630 protected:
631     explicit optional_params(KernelType kt) : kType(kt) {}
632     KernelType kType;
633 };
634
635 }  // namespace kernel_selector