1 // Copyright (c) 2019 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "contract_kernel_base.h"
18 #include "kernel_selector_utils.h"
21 namespace kernel_selector
// Builds the JIT constants for the contract (reduction) kernel:
//   * the base-parameter constants from MakeBaseParamsJitConstants,
//   * DIM_<axis> for every input axis that survives the reduction,
//   * REDUCE_<axis> flags marking which input axes are reduced,
//   * REDUCE_SEED / REDUCE_OPERATION(a, b) chosen by the ContractMode.
// Axis indices used below: 0 = B, 1 = F, 2 = Y, 3 = X.
// NOTE(review): this chunk is missing lines (e.g. the out_dim declaration,
// the AddConstants wrappers, the switch header and the return); the comments
// only describe the statements that are visible.
23 JitConstants ContractKernelBase::GetJitConstants(const contract_params& params)
25 JitConstants jit = MakeBaseParamsJitConstants(params);
// Sentinel meaning "this input axis has no output dimension", i.e. it is
// listed in reduction_axes. Assumes out_dim never takes this value — confirm.
27 const size_t no_dim_flag = 6;
28 std::vector<size_t> output_dims(4, no_dim_flag);
// Walk the axes from X (3) down to B (0). Each axis that is NOT in
// reduction_axes receives the current out_dim, and out_dim is decremented.
// The innermost surviving axis therefore gets the largest value.
// (The initial value of out_dim is set on a line not shown here.)
30 for (int i = 3; i >= 0; --i)
32 if (std::find(params.reduction_axes.begin(), params.reduction_axes.end(), i) == params.reduction_axes.end())
33 output_dims.at(i) = out_dim--;
// Emit a DIM_* constant only for axes that were not reduced.
// Presumably these values index the kernel's global work dimensions set up in
// SetDefault — TODO confirm against the .cl kernel.
36 if (output_dims[3] != no_dim_flag)
38 MakeJitConstant("DIM_X", output_dims.at(3))
40 if (output_dims[2] != no_dim_flag)
42 MakeJitConstant("DIM_Y", output_dims.at(2))
44 if (output_dims[1] != no_dim_flag)
46 MakeJitConstant("DIM_F", output_dims.at(1))
48 if (output_dims[0] != no_dim_flag)
50 MakeJitConstant("DIM_B", output_dims.at(0))
// REDUCE_<axis> is true exactly when that axis kept the sentinel, i.e. it
// appears in params.reduction_axes.
54 MakeJitConstant("REDUCE_X", output_dims.at(3) == no_dim_flag),
55 MakeJitConstant("REDUCE_Y", output_dims.at(2) == no_dim_flag),
56 MakeJitConstant("REDUCE_F", output_dims.at(1) == no_dim_flag),
57 MakeJitConstant("REDUCE_B", output_dims.at(0) == no_dim_flag)
// Per-mode reduction: REDUCE_SEED is the starting accumulator value and
// REDUCE_OPERATION(a, b) combines it with each element. Each seed is the
// identity of its operation (0 for +, 1 for *, 1/true for &&, 0/false for ||).
// The switch expression itself is not visible in this chunk.
62 case ContractMode::SUM:
64 MakeJitConstant("REDUCE_SEED", "0"),
65 MakeJitConstant("REDUCE_OPERATION(a, b)", "a + b")
68 case ContractMode::PRODUCT:
70 MakeJitConstant("REDUCE_SEED", "1"),
71 MakeJitConstant("REDUCE_OPERATION(a, b)", "a * b")
74 case ContractMode::ALL:
76 MakeJitConstant("REDUCE_SEED", "1"),
77 MakeJitConstant("REDUCE_OPERATION(a, b)", "a && b")
80 case ContractMode::ANY:
82 MakeJitConstant("REDUCE_SEED", "0"),
83 MakeJitConstant("REDUCE_OPERATION(a, b)", "a || b")
86 case ContractMode::MAX:
// UNIT_VAL_MIN and UNIT_MAX_FUNC are kernel-side macros defined elsewhere;
// presumably the minimum value and max() for the unit type — confirm.
88 MakeJitConstant("REDUCE_SEED", "UNIT_VAL_MIN"),
89 MakeJitConstant("REDUCE_OPERATION(a, b)", "UNIT_MAX_FUNC(a,b)")
// Computes the default dispatch data for a contract kernel.
// NOTE(review): this chunk omits lines (the declaration of kd, the assignment
// of global/local into kd and the return), so only the visible setup is
// described here.
97 ContractKernelBase::DispatchData ContractKernelBase::SetDefault(const contract_params& params)
99 const auto& output = params.output;
// The fp16 flag is derived from the data type of the first input only.
103 kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
// The global range covers the OUTPUT tensor's feature, Y and X extents, in
// that order; the batch extent is not part of it.
105 std::vector<size_t> global{ output.Feature().v, output.Y().v, output.X().v };
// Local sizes are chosen by the shared helper from the global range. If the
// helper returns by value, binding to const& extends the temporary's lifetime.
106 const auto& local = GetOptimalLocalWorkGroupSizes(global);
// Shared kernel-data builder for contract kernels. It:
//   1. computes the dispatch data (SetDefault),
//   2. creates default KernelData,
//   3. builds JIT constants and an entry point,
//   4. fills the first kernel slot,
//   5. records the caller-supplied estimated time.
// NOTE(review): the end of this function (its return) is not visible in this
// chunk. No Validate() call appears in the visible lines either; confirm that
// callers validate params/options before calling.
119 KernelsData ContractKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options, float estimated_time) const
// assert is compiled out under NDEBUG, so in release builds the downcast below
// relies on callers passing a contract_params.
121 assert(params.GetType() == KernelType::CONTRACT);
123 const auto& prim_params = static_cast<const contract_params&>(params); // NOLINT(cppcoreguidelines-pro-type-static-cast-downcast)
125 auto run_info = SetDefault(prim_params);
126 KernelData k_data = KernelData::Default<contract_params>(params);
128 auto cldnn_jit = GetJitConstants(prim_params);
129 auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options);
130 auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
// Assumes KernelData::Default allocates at least one kernel slot — confirm.
132 auto& kernel = k_data.kernels[0];
133 FillCLKernelData(kernel, run_info, params.engineInfo, kernelName, jit, entry_point);
134 k_data.estimatedTime = estimated_time;