2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
19 #include "common_kernel_base.h"
21 namespace kernel_selector
23 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
25 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Parameter bundle for element-wise kernels. Derives from base_params and
// registers itself under KernelType::ELTWISE.
// NOTE(review): this view of the file appears to be missing lines (braces,
// nested type headers, some assignments and return statements). The comments
// below describe only what is visible; anything else is marked as an assumption.
26 struct eltwise_params : public base_params
// Default constructor: forwards KernelType::ELTWISE to the base_params constructor.
28 eltwise_params() : base_params(KernelType::ELTWISE) {}
// Fields of the operand descriptor that the helpers below call InputType
// (its declaration line is not visible here — TODO confirm the enclosing type).
// `mode` selects where the operand comes from; it defaults to INPUT_BUFFER.
32 EltwiseInputMode mode = EltwiseInputMode::INPUT_BUFFER;
33 uint32_t index = 0; // index used for input operands (original note: "for inputs results"); defaults to 0
34 uint32_t tmpIndex = 0; // index used for temporary (intermediate) results; defaults to 0
// Factory: creates an InputType with mode INPUT_BUFFER.
// NOTE(review): the statements that presumably store `index` and return the
// descriptor are not visible in this view — confirm against the full file.
37 static InputType Buffer(uint32_t index)
39 eltwise_params::InputType input;
40 input.mode = EltwiseInputMode::INPUT_BUFFER;
// Factory: creates an InputType with mode UNORDERED_ACCESS_INPUT_BUFFER and
// stores `tmpIndex` in it. Presumably `index` is stored as well (that line is
// not visible here — TODO confirm).
45 static InputType UnorderedAccessBuffer(uint32_t index, uint32_t tmpIndex)
47 eltwise_params::InputType input;
48 input.mode = EltwiseInputMode::UNORDERED_ACCESS_INPUT_BUFFER;
50 input.tmpIndex = tmpIndex;
// Factory: creates an InputType with mode INTERMEDIATE_RESULTS_INDEX whose
// tmpIndex is set to the given `tmpIndex`.
54 static InputType Intermediate(uint32_t tmpIndex)
56 eltwise_params::InputType input;
57 input.mode = EltwiseInputMode::INTERMEDIATE_RESULTS_INDEX;
58 input.tmpIndex = tmpIndex;
// Factory: creates an InputType with mode SCALAR.
// NOTE(review): the line that presumably stores `s` in the descriptor is not
// visible in this view — confirm.
62 static InputType Scalar(float s)
64 eltwise_params::InputType input;
65 input.mode = EltwiseInputMode::SCALAR;
// Factory: creates an InputType with mode OUTPUT_BUFFER; takes no arguments.
70 static InputType OutBuffer()
72 eltwise_params::InputType output;
73 output.mode = EltwiseInputMode::OUTPUT_BUFFER;
// List of InputType operands.
// NOTE(review): the enclosing declaration is not visible; it is likely the
// eltwise_params::Node type referenced by `operations` below — confirm.
80 std::vector<InputType> inputs;
// Nested record type used by `updateInputIds` below; its members are not
// visible in this view.
84 struct UpdateInputData
// Sequence of eltwise_params::Node entries.
90 std::vector<eltwise_params::Node> operations;
// Float coefficients; how they map onto inputs is not shown here.
91 std::vector<float> coefficients;
// Collection of UpdateInputData records.
92 std::vector<UpdateInputData> updateInputIds;
// Stride values, stored as uSize elements.
93 std::vector<uSize> stride;
// Feature flags and quantization settings; every flag defaults to false and
// the output quantization factor defaults to 1.0f.
95 bool layoutBased = false;
96 bool int8_quantization = false;
97 bool output_calibration = false;
98 float output_quantization_factor = 1.0f;
99 bool broadcast = false;
// Calibration factors held as a MultiDataTensor; presumably only meaningful
// when `output_calibration` is set — TODO confirm.
101 MultiDataTensor output_calibration_factors;
// Returns the ParamsKey for this parameter set; declared virtual and const,
// defined outside this header section.
102 virtual ParamsKey GetParamsKey() const;
105 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
106 // eltwise_optional_params
107 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Optional-parameter bundle for element-wise kernels. Derives from
// optional_params and adds no members in the lines visible here.
108 struct eltwise_optional_params : optional_params
// Default constructor: forwards KernelType::ELTWISE to the optional_params constructor.
110 eltwise_optional_params() : optional_params(KernelType::ELTWISE) {}
113 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
115 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Base class for element-wise kernel implementations; derives from
// common_kernel_base. Only declarations appear here, with definitions elsewhere.
// NOTE(review): access specifiers and braces are not visible in this view, so
// the public/protected split of the members below cannot be stated — confirm
// against the full file.
116 class EltwiseKernelBase : public common_kernel_base
// Inherit the constructors of common_kernel_base.
119 using common_kernel_base::common_kernel_base;
// Virtual destructor with an empty body.
120 virtual ~EltwiseKernelBase() {}
// Alias: kernels of this family use CommonDispatchData as their dispatch data type.
122 using DispatchData = CommonDispatchData;
// Builds JIT constants from the eltwise params; `useVload8` is a boolean switch
// whose exact effect is defined in the implementation (not visible here).
123 JitConstants GetJitConstantsCommon(const eltwise_params& params, bool useVload8) const;
// Overrides the base-class Validate; takes the generic Params and
// optional_params and returns bool.
126 virtual bool Validate(const Params& p, const optional_params& o) const override;
// Virtual hook returning the JitConstants for the given eltwise params.
127 virtual JitConstants GetJitConstants(const eltwise_params& params) const;
// Virtual hook returning the default DispatchData for the given eltwise params.
128 virtual DispatchData SetDefault(const eltwise_params& params) const;
// Non-virtual helper that produces KernelsData from generic Params and
// optional_params; presumably shared by derived kernels — TODO confirm.
129 KernelsData GetCommonKernelsData(const Params& params, const optional_params& options) const;