2 // Copyright (c) 2019 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "strided_slice_kernel_ref.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector
// Describes the parameter combinations handled by the strided-slice reference
// kernel by enabling the following on the key `k`:
//   - F16 and F32 as input data types, and F16 and F32 as output data types;
//   - every input layout and every output layout;
//   - tensor offsets and tensor pitches.
// NOTE(review): the declaration of `k` and the return statement are not visible
// in this excerpt; presumably `k` is a local ParamsKey that is returned to the
// caller - confirm against the full file.
22 ParamsKey StridedSliceKernelRef::GetSupportedKey() const
// Data types accepted on the input and produced on the output.
25 k.EnableInputDataType(Datatype::F16);
26 k.EnableInputDataType(Datatype::F32);
27 k.EnableOutputDataType(Datatype::F16);
28 k.EnableOutputDataType(Datatype::F32);
// No restriction on memory layout for either side.
29 k.EnableAllInputLayout();
30 k.EnableAllOutputLayout();
// Tensors may carry an offset and pitches.
31 k.EnableTensorOffset();
32 k.EnableTensorPitches();
37 CommonDispatchData StridedSliceKernelRef::SetDefault(const strided_slice_params& params, const optional_params&) const
39 CommonDispatchData runInfo;
40 std::vector<size_t> gws;
42 // If the new_axis_mask is set, then begin, end, and stride are ignored
43 // and a new length 1 dimension is adding. Input data just copying to output
44 // TODO: remove data copying in case where only shape size changing
45 if (params.new_axis_mask.size() != 0)
46 gws = { params.inputs[0].Batch().v, params.inputs[0].Feature().v, params.inputs[0].Y().v * params.inputs[0].X().v };
48 gws = { params.output.Batch().v, params.output.Feature().v, params.output.Y().v * params.output.X().v };
50 auto lws = GetOptimalLocalWorkGroupSizes(gws);
52 runInfo.gws0 = gws[0];
53 runInfo.gws1 = gws[1];
54 runInfo.gws2 = gws[2];
56 runInfo.lws0 = lws[0];
57 runInfo.lws1 = lws[1];
58 runInfo.lws2 = lws[2];
63 JitConstants StridedSliceKernelRef::GetJitConstants(const strided_slice_params& params) const
65 JitConstants jit = MakeBaseParamsJitConstants(params);
67 auto makeJitConstForParam = [](JitConstants& jit, const std::string name, const std::vector<int32_t> vec) {
68 jit.AddConstant(MakeJitConstant(name + "_SIZES", vec));
69 jit.AddConstant(MakeJitConstant(name + "_BATCH", vec[0]));
70 jit.AddConstant(MakeJitConstant(name + "_FEATURE", vec[1]));
71 jit.AddConstant(MakeJitConstant(name + "_Y", vec[2]));
72 jit.AddConstant(MakeJitConstant(name + "_X", vec[3]));
75 makeJitConstForParam(jit, "SLICE_BEGIN", params.striding_params[0]);
76 makeJitConstForParam(jit, "SLICE_END", params.striding_params[1]);
77 makeJitConstForParam(jit, "SLICE_STEPS", params.striding_params[2]);
79 jit.AddConstant(MakeJitConstant("NEW_AXIS_MODE", std::find(params.new_axis_mask.begin(), params.new_axis_mask.end(), 1) != params.new_axis_mask.end()));
84 KernelsData StridedSliceKernelRef::GetKernelsData(const Params& params, const optional_params& options) const
86 KernelData kd = KernelData::Default<strided_slice_params>(params);
87 strided_slice_params& newParams = *static_cast<strided_slice_params*>(kd.params.get());
89 assert(params.GetType() == KernelType::STRIDED_SLICE);
91 auto runInfo = SetDefault(newParams, options);
92 auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
93 auto cldnn_jit = GetJitConstants(newParams);
94 std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
96 auto& kernel = kd.kernels[0];
98 FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point);
100 kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;