1 // Copyright (c) 2016-2020 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 #include "fully_connected_kernel_mmad.h"
16 #include "kernel_selector_utils.h"
18 namespace kernel_selector {
// Sub-group width used throughout this file: it is exported to the kernel as the
// SUB_GROUP_SIZE JIT constant and is the alignment passed to Align() for the first
// (output-feature) dimension of the global work size.
21 static const size_t sub_group_size = 8;
24 ParamsKey FullyConnectedKernelMMAD::GetSupportedKey() const {
26 k.EnableInputDataType(Datatype::INT8);
27 k.EnableInputDataType(Datatype::UINT8);
29 k.EnableOutputDataType(Datatype::INT8);
30 k.EnableOutputDataType(Datatype::UINT8);
31 k.EnableOutputDataType(Datatype::F32);
32 k.EnableOutputDataType(Datatype::F16);
34 k.EnableInputWeightsType(WeightsType::INT8);
36 k.EnableDifferentInputWeightsTypes();
37 k.EnableDifferentTypes();
39 k.EnableInputLayout(DataLayout::bfyx);
40 k.EnableInputLayout(DataLayout::b_fs_yx_fsv32);
41 k.EnableInputLayout(DataLayout::b_fs_zyx_fsv32);
42 k.EnableOutputLayout(DataLayout::bf);
44 k.EnableBiasPerOutput();
45 k.EnableBiasPerFeature();
46 k.EnableNonBiasTerm();
47 k.EnableTensorOffset();
48 k.EnableTensorPitches();
50 k.EnableQuantization(QuantizationType::SYMMETRIC);
54 bool FullyConnectedKernelMMAD::Validate(const Params& params, const optional_params& options) const {
55 if (!Parent::Validate(params, options))
58 auto fc_params = static_cast<const fully_connected_params&>(params);
59 auto input = fc_params.inputs[0];
60 if (input.GetLayout() == DataLayout::bfyx &&
61 (input.X().LogicalDimPadded() != 1 || input.Y().LogicalDimPadded() != 1 || input.Z().LogicalDimPadded() != 1)) {
// Produces the dispatch data for this kernel: starts from the parent's defaults and then
// overrides the global and local work sizes.
//
// NOTE(review): the listing numbers embedded in this block skip 69 and 84-87, so the rest
// of the parameter list, the return statement and the closing brace are not visible here.
// Confirm them against the full file.
68 FullyConnectedKernelMMAD::DispatchData FullyConnectedKernelMMAD::SetDefault(const fully_connected_params& params,
// Only `params` is forwarded to the parent implementation.
70 auto runInfo = Parent::SetDefault(params);
72 const auto& out = params.output;
// Global size: output feature count passed through Align() with sub_group_size, then the
// output batch count, then 1 for the third dimension.
74 std::vector<size_t> global = { Align(out.Feature().v, sub_group_size), out.Batch().v, 1 };
// Local size is chosen by the shared helper from the global size and the engine info.
75 auto local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
77 runInfo.gws0 = global[0];
78 runInfo.gws1 = global[1];
79 runInfo.gws2 = global[2];
81 runInfo.lws0 = local[0];
82 runInfo.lws1 = local[1];
83 runInfo.lws2 = local[2];
// Builds the JIT constants for this kernel on top of the parent's constants: sub-group
// size, filter offset macro, packed input/filter types, filter and input pitches, feature
// block count, spatial flags and (when present) fused-op definitions.
//
// NOTE(review): the listing numbers embedded in this block skip several lines (for example
// 98, 100, 109, 115, 134-136, 138, 143, 145 and 163-166). Closing braces, `else` branches,
// the return statement and possibly short statements are therefore not visible here.
// Confirm the control flow against the full file before relying on the notes below.
88 JitConstants FullyConnectedKernelMMAD::GetJitConstants(const fully_connected_params& params,
89 const DispatchData& runInfo) const {
90 auto jit = Parent::GetJitConstants(params, runInfo);
92 auto& input = params.inputs[0];
93 auto& weights = params.weights;
95 jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", sub_group_size));
// FILTER_GET_OFFSET(f) expands to a filter index macro with every index other than `f`
// fixed at 0. Which macro is used depends on the number of input dimensions.
// NOTE(review): a 5-dim input selects the YX macro here, while GetKernelsData selects the
// ZYX weights layout for a 5-dim input. Confirm this pairing is intended.
96 if (input.GetDims().size() == 5) {
97 jit.AddConstant(MakeJitConstant("FILTER_GET_OFFSET(f)", "GET_FILTER_OS_IS_YX_ISA8_OSV8_ISV4_INDEX(FILTER, f, 0, 0, 0)"));
99 jit.AddConstant(MakeJitConstant("FILTER_GET_OFFSET(f)", "GET_FILTER_OS_IS_ZYX_ISA8_OSV8_ISV4_INDEX(FILTER, f, 0, 0, 0, 0)"));
// Packed types default to INT32; UINT8 data switches the packed type to UINT32. The INT8
// branches assign the same value as the default.
102 Datatype input_packed_type = Datatype::INT32;
103 Datatype filter_packed_type = Datatype::INT32;
105 if (input.GetDType() == Datatype::UINT8) {
106 input_packed_type = Datatype::UINT32;
107 } else if (input.GetDType() == Datatype::INT8) {
108 input_packed_type = Datatype::INT32;
111 if (weights.GetDType() == WeightsType::UINT8) {
112 filter_packed_type = Datatype::UINT32;
113 } else if (weights.GetDType() == WeightsType::INT8) {
114 filter_packed_type = Datatype::INT32;
117 jit.Merge(MakeTypeJitConstants(input_packed_type, "INPUT_PACKED"));
118 jit.Merge(MakeTypeJitConstants(filter_packed_type, "FILTER_PACKED"));
// Filter spatial size is the product of the weights' X, Y and Z extents.
120 auto filter_spatial_size = weights.X().v * weights.Y().v * weights.Z().v;
// NOTE(review): 4 * 8 * 8 presumably mirrors the isv4 / osv8 / isa8 blocking named in the
// weights layout - confirm.
121 int filter_spatial_pitch = 4 * 8 * 8;
123 jit.AddConstant(MakeJitConstant("FILTER_SPATIAL_SIZE", filter_spatial_size));
124 jit.AddConstant(MakeJitConstant("MMAD_FILTER_SPATIAL_PITCH", filter_spatial_pitch));
125 jit.AddConstant(MakeJitConstant("MMAD_FILTER_FBLOCK_PITCH", filter_spatial_size * filter_spatial_pitch));
// Input pitches are copied into mutable locals.
// NOTE(review): no visible line modifies them, but listing lines 134-136 are missing from
// the fsv32 branch below, so an adjustment may exist there - confirm.
127 size_t input_x_pitch = input.X().pitch;
128 size_t input_y_pitch = input.Y().pitch;
129 size_t input_z_pitch = input.Z().pitch;
// Feature-block pitch: 32 for bfyx, feature pitch times 32 for the fsv32 layouts.
131 if (input.GetLayout() == DataLayout::bfyx) {
132 jit.AddConstant(MakeJitConstant("MMAD_INPUT_FBLOCK_PITCH", 32));
133 } else if (input.GetLayout() == DataLayout::b_fs_yx_fsv32 || input.GetLayout() == DataLayout::b_fs_zyx_fsv32) {
137 jit.AddConstant(MakeJitConstant("MMAD_INPUT_FBLOCK_PITCH", input.Feature().pitch * 32));
// For bfyx input whose feature count is not a multiple of 32, HAS_FEATURE_LEFTOVERS is set
// and FEATURE_BLOCKS_COUNT is the truncated quotient (full blocks only). The other visible
// definition of FEATURE_BLOCKS_COUNT uses CeilDiv instead.
140 if (input.GetLayout() == DataLayout::bfyx && input.Feature().v % 32 != 0) {
141 jit.AddConstant(MakeJitConstant("HAS_FEATURE_LEFTOVERS", true));
142 jit.AddConstant(MakeJitConstant("FEATURE_BLOCKS_COUNT", input.Feature().v / 32));
144 jit.AddConstant(MakeJitConstant("FEATURE_BLOCKS_COUNT", CeilDiv(input.Feature().v, 32)));
// MMAD_INPUT_SPATIAL_PITCH and MMAD_INPUT_X_PITCH are both set from input_x_pitch.
147 jit.AddConstant(MakeJitConstant("MMAD_INPUT_SPATIAL_PITCH", input_x_pitch));
148 jit.AddConstant(MakeJitConstant("MMAD_INPUT_X_PITCH", input_x_pitch));
149 jit.AddConstant(MakeJitConstant("MMAD_INPUT_Y_PITCH", input_y_pitch));
150 jit.AddConstant(MakeJitConstant("MMAD_INPUT_Z_PITCH", input_z_pitch));
// SPLIT_SPATIAL: true when any of the X, Y or Z dimensions has non-zero total padding.
152 bool split_spatial = input.X().pad.Total() != 0 || input.Y().pad.Total() != 0 || input.Z().pad.Total() != 0;
// SPATIAL_MAJOR: true when the layout's X channel index is lower than its FEATURE index.
153 bool spatial_major = DataTensor::Channelndex(input.GetLayout(), Tensor::DataChannelName::X) <
154 DataTensor::Channelndex(input.GetLayout(), Tensor::DataChannelName::FEATURE);
156 jit.AddConstant(MakeJitConstant("SPLIT_SPATIAL", split_spatial));
157 jit.AddConstant(MakeJitConstant("SPATIAL_MAJOR", spatial_major));
// Fused-op constants are generated only when fused ops are present. The configuration uses
// an empty suffix, indices {"b", "f", "0", "0"}, the variable name "dequantized", the
// activation type as the data type and a vector size of 1.
159 if (!params.fused_ops.empty()) {
160 auto input_dt = GetActivationType(params);
161 FusedOpsConfiguration conf = { "", {"b", "f", "0", "0"}, "dequantized", input_dt, 1 };
162 jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
// Collects kernel data for every auto-tune option: picks the weights layout from the
// number of input dimensions, then queries GetTunedKernelsDataByIndex once per index of
// autoTuneOptions and appends the first returned entry to the result.
//
// NOTE(review): the listing numbers embedded in this block skip 175, 180-183, 185 and
// 187-190. Most of the arguments to GetTunedKernelsDataByIndex, any guard around `kd[0]`,
// the return statement and the closing braces are therefore not visible here. Confirm
// them against the full file.
168 KernelsData FullyConnectedKernelMMAD::GetKernelsData(const Params& params, const optional_params& options) const {
// NOTE(review): `auto` deduces a value type here, so this copy-initializes a whole
// fully_connected_params object; `const auto&` would avoid the copy.
169 auto fc_params = static_cast<const fully_connected_params&>(params);
170 auto& input = fc_params.inputs[0];
// Default weights layout is the YX variant; a 5-dim input switches to the ZYX variant.
172 auto w_layout = WeightsLayout::os_is_yx_isa8_osv8_isv4;
173 if (input.GetDims().size() == 5) {
174 w_layout = WeightsLayout::os_is_zyx_isa8_osv8_isv4;
177 KernelsData res = {};
178 for (size_t i = 0; i < autoTuneOptions.size(); i++) {
179 KernelsData kd = GetTunedKernelsDataByIndex(params,
// The loop index is passed as the last argument, narrowed to int.
184 static_cast<int>(i));
186 res.emplace_back(kd[0]);
191 } // namespace kernel_selector