Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / eltwise / eltwise_kernel_base.h
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #pragma once
18
19 #include "common_kernel_base.h"
20
21 namespace kernel_selector
22 {
23     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
24     // eltwise_params
25     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
26     struct eltwise_params : public base_params
27     {
28         eltwise_params() : base_params(KernelType::ELTWISE) {}
29
30         struct InputType
31         {
32             EltwiseInputMode mode = EltwiseInputMode::INPUT_BUFFER;
33             uint32_t         index = 0;    // for inputs results;
34             uint32_t         tmpIndex = 0;    // for temp results;
35             float            scalar = 0.f;
36
37             static InputType Buffer(uint32_t index)
38             {
39                 eltwise_params::InputType input;
40                 input.mode = EltwiseInputMode::INPUT_BUFFER;
41                 input.index = index;
42                 return input;
43             }
44
45             static InputType UnorderedAccessBuffer(uint32_t index, uint32_t tmpIndex)
46             {
47                 eltwise_params::InputType input;
48                 input.mode = EltwiseInputMode::UNORDERED_ACCESS_INPUT_BUFFER;
49                 input.index = index;
50                 input.tmpIndex = tmpIndex;
51                 return input;
52             }
53
54             static InputType Intermediate(uint32_t tmpIndex)
55             {
56                 eltwise_params::InputType input;
57                 input.mode = EltwiseInputMode::INTERMEDIATE_RESULTS_INDEX;
58                 input.tmpIndex = tmpIndex;
59                 return input;
60             }
61
62             static InputType Scalar(float s)
63             {
64                 eltwise_params::InputType input;
65                 input.mode = EltwiseInputMode::SCALAR;
66                 input.scalar = s;
67                 return input;
68             }
69
70             static InputType OutBuffer()
71             {
72                 eltwise_params::InputType output;
73                 output.mode = EltwiseInputMode::OUTPUT_BUFFER;
74                 return output;
75             }
76         };
77
78         struct Node
79         {
80             std::vector<InputType> inputs;
81             EltwiseMode mode;
82         };
83
84         struct UpdateInputData
85         {
86             uint32_t inputId;
87             uint32_t tmpId;
88         };
89
90         std::vector<eltwise_params::Node> operations;
91         std::vector<float> coefficients;
92         std::vector<UpdateInputData> updateInputIds;
93         std::vector<uSize> stride;
94
95         bool  layoutBased = false;
96         bool  int8_quantization = false;
97         bool  output_calibration = false;
98         float output_quantization_factor = 1.0f;
99         bool  broadcast = false;
100
101         MultiDataTensor output_calibration_factors;
102         virtual ParamsKey GetParamsKey() const;
103     };
104
105     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
106     // eltwise_optional_params
107     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
108     struct eltwise_optional_params : optional_params
109     {
110         eltwise_optional_params() : optional_params(KernelType::ELTWISE) {}
111     };
112
113     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
114     // EltwiseKernelBase
115     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
116     class EltwiseKernelBase : public common_kernel_base
117     {
118     public:
119         using common_kernel_base::common_kernel_base;
120         virtual ~EltwiseKernelBase() {}
121
122         using DispatchData = CommonDispatchData;
123         JitConstants GetJitConstantsCommon(const eltwise_params& params, bool useVload8) const;
124
125     protected:
126         virtual bool Validate(const Params& p, const optional_params& o) const override;
127         virtual JitConstants GetJitConstants(const eltwise_params& params) const;
128         virtual DispatchData SetDefault(const eltwise_params& params) const;
129         KernelsData GetCommonKernelsData(const Params& params, const optional_params& options) const;
130     };
131 }