2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "eltwise_kernel_vload8.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector {
// Describes the parameter space this kernel can be selected for: F16 and F32
// input/output data types, and every input/output layout. Validate() in this
// file still rejects DataLayout::fs_bs_yx_bsv4_fsv32, so "all layouts" here is
// narrowed again at validation time.
22 ParamsKey EltwiseKernel_vload8::GetSupportedKey() const
// Supported element types: half- and single-precision floating point only.
25 k.EnableInputDataType(Datatype::F16);
26 k.EnableInputDataType(Datatype::F32);
27 k.EnableOutputDataType(Datatype::F16);
28 k.EnableOutputDataType(Datatype::F32);
// Layouts are accepted broadly here; the blocked-layout exclusion is enforced
// in Validate().
29 k.EnableAllInputLayout();
30 k.EnableAllOutputLayout();
// Builds the JIT constants for this kernel by delegating to the shared eltwise
// helper GetJitConstantsCommon.
// NOTE(review): the meaning of the `true` argument is defined by
// GetJitConstantsCommon in the eltwise base class (presumably it selects the
// vectorized / vload8 code path) — confirm there before relying on it.
35 JitConstants EltwiseKernel_vload8::GetJitConstants(const eltwise_params& params) const
37 return GetJitConstantsCommon(params, true);
// Decides whether the vload8 eltwise kernel can handle `params`.
// On top of EltwiseKernelBase::Validate, the checks performed are:
//  - neither any input nor the output may use DataLayout::fs_bs_yx_bsv4_fsv32;
//  - the output's physical element count must be a multiple of 8;
//  - every input must have pitches equal to its logical dims, and must either
//    equal both inputs[0] and the output, or be a single-element tensor;
//  - no input-value updates may be requested (updateInputIds must be empty);
//  - no operation may read from EltwiseInputMode::OUTPUT_BUFFER.
// @param params  must actually be an eltwise_params: it is static_cast (unchecked)
//                below.
// @param o       optional params, forwarded to the base-class check.
// @return true only if all checks pass.
// NOTE(review): each failed check is presumed to `return false` — confirm.
40 bool EltwiseKernel_vload8::Validate(const Params& params, const optional_params& o) const
// The generic eltwise validation must pass first.
42 if (!EltwiseKernelBase::Validate(params, o))
47 const auto& ewParams = static_cast<const eltwise_params&>(params);
// Reject the fs_bs_yx_bsv4_fsv32 blocked layout on any input and on the output.
49 for (size_t i = 0; i < ewParams.inputs.size(); i++)
51 if (ewParams.inputs[i].GetLayout() == DataLayout::fs_bs_yx_bsv4_fsv32)
54 if (ewParams.output.GetLayout() == DataLayout::fs_bs_yx_bsv4_fsv32)
// The output element count must divide evenly by 8. GetKernelsData sizes the
// global work size as LogicalSize()/8, so presumably each work item handles
// 8 elements.
57 const auto& output = ewParams.output;
58 const auto count = output.PhysicalSize();
60 const bool bSupportedCount = (count % 8) == 0;
62 bool bCheckSizes = true;
63 for (size_t i = 0; i < ewParams.inputs.size(); i++)
65 // Allow only inputs identical to inputs[0] and to the output, or single-element (scalar) inputs; no input may have padded pitches.
66 if (ewParams.inputs[i].PitchesDifferFromLogicalDims() ||
67 (!(ewParams.inputs[0] == ewParams.inputs[i] && ewParams.inputs[i] == ewParams.output) &&
68 ewParams.inputs[i].PhysicalSize() != 1))
72 //TODO: add support to this implementation when user requests input values updates
73 bool bCheckUpdateInput = true;
74 if (!ewParams.updateInputIds.empty())
75 bCheckUpdateInput = false;
77 //TODO: add support for reading from output buffer and using its values in computation
78 bool bCheckUseOutput = true;
// NOTE(review): these nested loops keep scanning after the first OUTPUT_BUFFER
// match. An early exit (or std::any_of) would give the same result and read
// more clearly.
79 for (size_t op = 0; op < ewParams.operations.size(); op++)
81 for (size_t input_idx = 0; input_idx < ewParams.operations[op].inputs.size(); input_idx++)
83 if (ewParams.operations[op].inputs[input_idx].mode == EltwiseInputMode::OUTPUT_BUFFER)
85 bCheckUseOutput = false;
// Every individual check must hold; otherwise the params are rejected.
91 if (!bCheckSizes || !bSupportedCount || !bCheckUpdateInput || !bCheckUseOutput)
// Produces the KernelData for the vload8 eltwise kernel, or an empty
// KernelsData when the kernel cannot be used.
// @param params   eltwise parameters; KernelData::Default<eltwise_params> builds
//                 kd from them, and kd.params is then accessed as newParams.
// @param options  optional params used for validation and entry-point naming.
// @return an empty KernelsData if JIT creation throws std::runtime_error.
// NOTE(review): presumably also returned when Validate() fails — confirm.
99 KernelsData EltwiseKernel_vload8::GetKernelsData(const Params& params, const optional_params& options) const
101 if (!Validate(params, options))
106 KernelData kd = KernelData::Default<eltwise_params>(params);
107 eltwise_params& newParams = *static_cast<eltwise_params*>(kd.params.get());
// The entry-point name is derived from the kernel name, the layer ID and the options.
111 auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
// JIT generation may throw std::runtime_error; in that case no kernel is offered.
115 auto cldnn_jit = GetJitConstants(newParams);
116 jit = CreateJit(kernelName, cldnn_jit, entry_point);
118 catch (const std::runtime_error&)
120 return KernelsData();
123 auto& kernel = kd.kernels[0];
// Global size: inputs[0].LogicalSize()/8, clamped to at least 1.
// NOTE(review): Validate() checks divisibility by 8 on output.PhysicalSize(),
// but this dispatch is sized from inputs[0]. The two agree when inputs[0] equals
// the output (no padding). If every input were a single-element tensor, this
// would clamp to one work item regardless of the output size. Consider sizing
// from newParams.output.LogicalSize() — confirm against the .cl kernel first.
124 kernel.workGroups.global = { std::max(newParams.inputs[0].LogicalSize()/8, (size_t)1), 1, 1 };
125 kernel.workGroups.local = GetOptimalLocalWorkGroupSizes(kernel.workGroups.global);
126 kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo, DEFAULT);
// Argument descriptors for newParams.inputs.size() inputs. NOTE(review): the two
// `false` flags are interpreted by GetArgsDesc in the base class — check there.
127 kernel.arguments = GetArgsDesc((uint32_t)newParams.inputs.size(), false, false);
// Fixed selection priority for this implementation.
129 kd.estimatedTime = FORCE_PRIORITY_8;