auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options);
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
- uint32_t fused_deps_total = 0;
- for (auto& fused_dep : prim_params.fused_ops) {
- for (int i = 0; i < static_cast<int>(fused_dep.dep_size); i++) {
- fused_deps_total++;
- }
- }
-
auto& kernel = k_data.kernels[0];
FillCLKernelData(kernel,
run_info,
false,
false,
(uint32_t)prim_params.inputs.size(),
- fused_deps_total);
+ GetFusedPrimitiveInputsCount(params));
k_data.estimatedTime = estimated_time;
-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2018-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
protected:
virtual JitConstants GetJitConstants(const gemm_params& params) const;
- DispatchData SetDefault(const gemm_params& params) const;
+ virtual DispatchData SetDefault(const gemm_params& params) const;
KernelsData GetCommonKernelsData(const Params& params, const optional_params&, float estimated_time) const;
// Fused ops
virtual JitConstants GetFusedPrimitivesJitConstants(const gemm_params& params, const DispatchData& kd) const;
--- /dev/null
+/*
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+*/
+
+#include "gemm_kernel_mmad_int8.h"
+
+namespace kernel_selector {
+ParamsKey GemmKernelMMADint8::GetSupportedKey() const {
+ ParamsKey k;
+
+ k.EnableInputDataType(Datatype::INT8);
+ k.EnableInputDataType(Datatype::UINT8);
+ k.EnableInputDataType(Datatype::F32);
+ k.EnableOutputDataType(Datatype::F32);
+ k.EnableOutputDataType(Datatype::F16);
+ k.EnableOutputDataType(Datatype::INT8);
+ k.EnableOutputDataType(Datatype::UINT8);
+ k.EnableInputLayout(DataLayout::bfyx);
+ k.EnableOutputLayout(DataLayout::bfyx);
+ k.EnableInputLayout(DataLayout::bfzyx);
+ k.EnableOutputLayout(DataLayout::bfzyx);
+ k.EnableInputLayout(DataLayout::bfwzyx);
+ k.EnableOutputLayout(DataLayout::bfwzyx);
+
+ k.EnableBatching();
+ k.EnableDifferentTypes();
+ k.EnableTensorPitches();
+ k.EnableQuantization(QuantizationType::SYMMETRIC);
+
+ return k;
+}
+
+JitConstants GemmKernelMMADint8::GetJitConstants(const gemm_params& params) const {
+ JitConstants jit = Parent::GetJitConstants(params);
+ GemmTuningData td = SetTuningParams(params);
+
+ jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", td.simd_size));
+ jit.Merge(MakeTypeJitConstants(Datatype::INT32, "ACCUMULATOR"));
+ jit.Merge(MakeTypeJitConstants(Datatype::F32, "ACTIVATION"));
+ jit.Merge(MakeTypeJitConstants(params.inputs[0].GetDType() == Datatype::INT8 ? Datatype::INT32 : Datatype::UINT32, "PACKED_INPUT0"));
+ jit.Merge(MakeTypeJitConstants(params.inputs[1].GetDType() == Datatype::INT8 ? Datatype::INT32 : Datatype::UINT32, "PACKED_INPUT1"));
+ jit.AddConstant(MakeJitConstant("TILE_NUM", td.tile_num));
+ jit.AddConstant(MakeJitConstant("TILE_SIZE_M", td.simd_size * td.tile_num));
+ jit.AddConstant(MakeJitConstant("TILE_SIZE_N", td.simd_size));
+ jit.AddConstant(MakeJitConstant("TILE_SIZE_K", td.simd_size * td.pack_size));
+ jit.AddConstant(MakeJitConstant("OUTPUT_LEFTOVERS_M", td.size_m % (td.simd_size * td.tile_num)));
+ jit.AddConstant(MakeJitConstant("OUTPUT_LEFTOVERS_N", td.size_n % td.simd_size));
+ jit.AddConstant(MakeJitConstant("OUTPUT_LEFTOVERS_K", td.size_k % (td.simd_size * td.pack_size)));
+
+ if (!params.fused_ops.empty()) {
+ auto input_dt = GetActivationType(params);
+ FusedOpsConfiguration conf = { "", {"b", "f", "output_y", "output_x"}, "dequantized", input_dt, 1 };
+ conf.SetLoopAxes({ Tensor::DataChannelName::Y }, true);
+ jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
+ }
+
+ return jit;
+}
+
+GemmKernelBase::DispatchData GemmKernelMMADint8::SetDefault(const gemm_params& params) const {
+ const auto& output = params.output;
+ auto total_batches = output.LogicalSize() / (output.X().v * output.Y().v);
+
+ DispatchData kd;
+ GemmTuningData td = SetTuningParams(params);
+
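+    // Dispatch: one work-item per output column (X padded to the sub-group size), one sub-group per
+    // TILE_SIZE_M (= simd_size * tile_num) block of output rows, and one slot per combined batch dimension.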
+ std::vector<size_t> global = { Align(output.X().v, td.simd_size),
+ Align(output.Y().v, td.simd_size * td.tile_num) / (td.simd_size * td.tile_num),
+ total_batches };
+
+ std::vector<size_t> local = { td.simd_size, 1, 1 };
+
+ kd.gws0 = global[0];
+ kd.gws1 = global[1];
+ kd.gws2 = global[2];
+
+ kd.lws0 = local[0];
+ kd.lws1 = local[1];
+ kd.lws2 = local[2];
+
+ return kd;
+}
+
+GemmKernelMMADint8::GemmTuningData GemmKernelMMADint8::InitGemmTuningData(const gemm_params& params) const {
+ GemmTuningData tuning_data;
+
+ tuning_data.size_m = params.output.Y().v;
+ tuning_data.size_n = params.output.X().v;
+ tuning_data.size_k = params.transpose_input0 ? params.inputs[0].Y().v : params.inputs[0].X().v;
+
+ return tuning_data;
+}
+
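+// Total amount of work (M * N * K multiply-accumulates); used by the tuning heuristic and to pick the kernel priority.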
+inline size_t GemmKernelMMADint8::GetMmadOperationsNumber(const GemmTuningData& tuning_data) const {
+ return tuning_data.size_m * tuning_data.size_n * tuning_data.size_k;
+}
+
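+// Returns true when the matrix sizes are not multiples of the tile sizes implied by the given M tile
+// (32 -> SIMD16 with two M tiles, 16 -> SIMD16, 8 -> SIMD8), i.e. the leftovers kernel variant is required.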
+bool GemmKernelMMADint8::HasLeftovers(const GemmTuningData& tuning_data, int tile_size) const {
+ if (tile_size == 32) {
+ return tuning_data.size_m % 32 || tuning_data.size_n % 16 || tuning_data.size_k % 64;
+ } else if (tile_size == 16) {
+ return tuning_data.size_m % 16 || tuning_data.size_n % 16 || tuning_data.size_k % 64;
+ } else if (tile_size == 8) {
+ return tuning_data.size_m % 8 || tuning_data.size_n % 8 || tuning_data.size_k % 32;
+ } else {
+ return true;
+ }
+}
+
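+// Heuristic choice of sub-group size and M-tile count based on matrix sizes, leftovers and the presence of a third input.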
+GemmKernelMMADint8::GemmTuningData GemmKernelMMADint8::SetTuningParams(const gemm_params& params) const {
+ GemmTuningData tuning_data = InitGemmTuningData(params);
+ auto mmad_operations_number = GetMmadOperationsNumber(tuning_data);
+
+    bool leftovers_simd16x2 = HasLeftovers(tuning_data, 16 * 2);
+ bool leftovers_simd16 = HasLeftovers(tuning_data, 16);
+ bool leftovers_simd8 = HasLeftovers(tuning_data, 8);
+
+ bool small_matrices = mmad_operations_number <= 128 * 128 * 128;
+ bool average_matrices = mmad_operations_number <= 448 * 448 * 448;
+ bool very_big_matrices = mmad_operations_number >= 1024 * 1024 * 1024;
+    bool no_input2 = params.inputs.size() != 3;
+
+ size_t simd_size = 16;
+ size_t tile_num = 1;
+
+    if (!leftovers_simd16x2 && very_big_matrices && no_input2) {
+        simd_size = 16;
+        tile_num = 2;
+    } else if (leftovers_simd16 && !leftovers_simd8) {
+        simd_size = 8;
+    } else if ((!params.transpose_input0 && !params.transpose_input1) && average_matrices) {
+        simd_size = 8;
+    } else if (small_matrices) {
+        simd_size = 8;
+    } else {
+        simd_size = 16;
+    }
+
+ tuning_data.simd_size = simd_size;
+ tuning_data.tile_num = tile_num;
+
+ return tuning_data;
+}
+
+KernelsData GemmKernelMMADint8::GetKernelsData(const Params& params, const optional_params& options) const {
+ if (!Validate(params, options)) {
+ return KernelsData();
+ }
+
+ const auto& prim_params = static_cast<const gemm_params&>(params);
+
+ auto run_info = GemmKernelMMADint8::SetDefault(prim_params);
+ KernelData k_data = KernelData::Default<gemm_params>(params);
+
+ auto cldnn_jit = GetJitConstants(prim_params);
+ auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options);
+ auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
+
+ auto& kernel = k_data.kernels[0];
+ FillCLKernelData(kernel,
+ run_info,
+ params.engineInfo,
+ kernelName,
+ jit,
+ entry_point,
+ DEFAULT,
+ false,
+ false,
+ (uint32_t)prim_params.inputs.size(),
+ GetFusedPrimitiveInputsCount(params));
+
+ GemmTuningData tuning_data = InitGemmTuningData(prim_params);
+ auto mmad_operations_number = GetMmadOperationsNumber(tuning_data);
+
+ k_data.estimatedTime = mmad_operations_number < 4096 ? DONT_USE_IF_HAVE_SOMETHING_ELSE : FORCE_PRIORITY_3;
+
+ return {k_data};
+}
+
+bool GemmKernelMMADint8::Validate(const Params& params, const optional_params& options) const {
+ if (!Parent::Validate(params, options))
+ return false;
+
+ const auto& gmm_params = static_cast<const gemm_params&>(params);
+ auto input0_type = gmm_params.inputs[0].GetDType();
+ auto input1_type = gmm_params.inputs[1].GetDType();
+
+ if ((input0_type != Datatype::UINT8 && input0_type != Datatype::INT8) ||
+ (input1_type != Datatype::UINT8 && input1_type != Datatype::INT8))
+ return false;
+
+ return true;
+}
+} // namespace kernel_selector
--- /dev/null
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "gemm_kernel_base.h"
+#include <vector>
+
+namespace kernel_selector {
+class GemmKernelMMADint8 : public GemmKernelBase {
+public:
+ using Parent = GemmKernelBase;
+ using DispatchData = CommonDispatchData;
+ struct GemmTuningData {
+ size_t size_m;
+ size_t size_n;
+ size_t size_k;
+
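+        // simd_size: sub-group size (8 or 16); tile_num: number of SUB_GROUP_SIZE row blocks per sub-group (2 only with SIMD16);
+        // pack_size: int8 values packed into one 32-bit element.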
+ size_t simd_size = 16;
+ size_t tile_num = 1;
+ size_t pack_size = 4;
+ };
+
+ GemmKernelMMADint8() : GemmKernelBase("gemm_mmad_int8") {}
+
+ KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
+ ParamsKey GetSupportedKey() const override;
+
+protected:
+ std::vector<FusedOpType> GetSupportedFusedOps() const override {
+ return { FusedOpType::QUANTIZE,
+ FusedOpType::ACTIVATION,
+ FusedOpType::SCALE,
+ FusedOpType::ELTWISE };
+ }
+ bool Validate(const Params& params, const optional_params& options) const override;
+ JitConstants GetJitConstants(const gemm_params& params) const override;
+ DispatchData SetDefault(const gemm_params& params) const override;
+ GemmTuningData InitGemmTuningData(const gemm_params& params) const;
+ GemmTuningData SetTuningParams(const gemm_params& params) const;
+ size_t GetMmadOperationsNumber(const GemmTuningData& tuning_data) const;
+ bool HasLeftovers(const GemmTuningData& tuning_data, int tile_size) const;
+};
+} // namespace kernel_selector
/*
-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2018-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
#include "gemm_kernel_selector.h"
#include "gemm_kernel_ref.h"
+#include "gemm_kernel_mmad_int8.h"
namespace kernel_selector {
-gemm_kernel_selector::gemm_kernel_selector() { Attach<GemmKernelRef>(); }
+gemm_kernel_selector::gemm_kernel_selector() {
+ Attach<GemmKernelRef>();
+ Attach<GemmKernelMMADint8>();
+}
KernelsData gemm_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
return GetNaiveBestKernel(params, options, KernelType::GEMM);
}
-} // namespace kernel_selector
\ No newline at end of file
+} // namespace kernel_selector
--- /dev/null
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "include/include_all.cl"
+#include "include/mmad.cl"
+
+#define PACK_SIZE 4
+
+#define AS_TYPE(type, val) CAT(as_, type)(val)
+#define ACCUMULATOR_TYPE_VEC CAT(ACCUMULATOR_TYPE, SUB_GROUP_SIZE)
+#define ACTIVATION_TYPE_VEC CAT(ACTIVATION_TYPE, SUB_GROUP_SIZE)
+#define PACKED_INPUT0_TYPE_VEC CAT(PACKED_INPUT0_TYPE, SUB_GROUP_SIZE)
+#define PACKED_INPUT1_TYPE_VEC CAT(PACKED_INPUT1_TYPE, SUB_GROUP_SIZE)
+#define BLOCK_READ(ptr) intel_sub_group_block_read((const __global uint*)(ptr))
+#define BLOCK_SHUFFLE intel_sub_group_shuffle
+
+#if SUB_GROUP_SIZE == 8
+#define MMAD MMAD_8x8
+#else // SUB_GROUP_SIZE == 8
+#define MMAD MMAD_16x16
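+// Row offset of the second M tile when TILE_NUM == 2 (each accumulator covers TILE_SIZE_M / 2 output rows).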
+#define TILE_SIZE_M_DIV (TILE_SIZE_M / 2)
+#endif // SUB_GROUP_SIZE == 8
+
+inline uint FUNC(get_input0_batch_offset)(uint b, uint f, uint w, uint z) {
+#if INPUT0_SIMPLE
+ return GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, 0, 0);
+#else // INPUT0_SIMPLE
+# error gemm_mmad_int8.cl : Unsupported input 0 format
+#endif // INPUT0_SIMPLE
+}
+
+inline uint FUNC(get_input1_batch_offset)(uint b, uint f, uint w, uint z) {
+#if INPUT1_SIMPLE
+ return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, 0, 0);
+#else // INPUT1_SIMPLE
+# error gemm_mmad_int8.cl : Unsupported input 1 format
+#endif // INPUT1_SIMPLE
+}
+
+#ifdef INPUT2_TYPE
+inline uint FUNC(get_input2_batch_offset)(uint b, uint f, uint w, uint z) {
+#if INPUT2_SIMPLE
+ return GET_DATA_INDEX_6D_SAFE(INPUT2, b, f, w, z, 0, 0);
+#else // INPUT2_SIMPLE
+# error gemm_mmad_int8.cl : Unsupported input 2 format
+#endif // INPUT2_SIMPLE
+}
+#endif // INPUT2_TYPE
+
+inline uint FUNC(get_output_batch_offset)(uint b, uint f, uint w, uint z) {
+#if OUTPUT_SIMPLE
+ return GET_DATA_INDEX_6D(OUTPUT, b, f, w, z, 0, 0);
+#else // OUTPUT_SIMPLE
+# error gemm_mmad_int8.cl : Unsupported output format
+#endif // OUTPUT_SIMPLE
+}
+
+inline uint FUNC(get_common_input1_offset)(uint batch_offset_input1, uint k, uint i, uint output_x_tile, uint lid) {
+#if !TRANSPOSE_INPUT1
+ return batch_offset_input1 + (k * TILE_SIZE_K + i * PACK_SIZE) * INPUT1_SIZE_X + output_x_tile * TILE_SIZE_N;
+#else
+ return batch_offset_input1 + (output_x_tile * TILE_SIZE_N + lid) * INPUT1_SIZE_X + k * TILE_SIZE_K + i * PACK_SIZE;
+#endif
+}
+
+inline uint FUNC(get_current_input1_offset)(uint common_input1_offset, uint i, uint lid) {
+#if !TRANSPOSE_INPUT1
+ return common_input1_offset + INPUT1_SIZE_X * i + lid;
+#else
+ return common_input1_offset + i;
+#endif
+}
+
+inline uint FUNC(get_common_input0_offset)(uint batch_offset_input0, uint k, uint i, uint output_y_tile, uint lid) {
+#if !TRANSPOSE_INPUT0
+ return batch_offset_input0 + (output_y_tile * TILE_SIZE_M + i) * INPUT0_SIZE_X + k * TILE_SIZE_K;
+#else
+ return batch_offset_input0 + (k * TILE_SIZE_K + lid * PACK_SIZE) * INPUT0_SIZE_X + output_y_tile * TILE_SIZE_M + i;
+#endif
+}
+
+inline uint FUNC(get_current_input0_offset)(uint common_input0_offset, uint i, uint lid) {
+#if !TRANSPOSE_INPUT0
+ return common_input0_offset + lid * PACK_SIZE + i;
+#else
+ return common_input0_offset + INPUT0_SIZE_X * i;
+#endif
+}
+
+__attribute__((reqd_work_group_size(SUB_GROUP_SIZE, 1, 1)))
+__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE)))
+KERNEL(gemm_mmad_int8)(
+ const __global INPUT0_TYPE* input0,
+ const __global INPUT1_TYPE* input1,
+#ifdef INPUT2_TYPE
+ const __global INPUT2_TYPE* input2,
+#endif // INPUT2_TYPE
+ __global OUTPUT_TYPE* output
+#if HAS_FUSED_OPS_DECLS
+ , FUSED_OPS_DECLS
+#endif // HAS_FUSED_OPS_DECLS
+ )
+
+// ***************************************************************************************** //
+// Kernel with leftovers for all sizes of input matrices and all transposition combinations. //
+// ***************************************************************************************** //
+
+#if OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K
+{
+ const uint output_x = (uint)get_global_id(0);
+ const uint output_x_tile = output_x / TILE_SIZE_N;
+ const uint output_y_tile = (uint)get_global_id(1);
+#if HAS_FUSED_OPS
+ uint output_y = output_y_tile * TILE_SIZE_M;
+#endif // HAS_FUSED_OPS
+ uint batch = get_global_id(2);
+ const uint lid = (uint)get_local_id(0);
+
+ const uint z = batch % OUTPUT_SIZE_Z;
+ batch /= OUTPUT_SIZE_Z;
+ const uint w = batch % OUTPUT_SIZE_W;
+ batch /= OUTPUT_SIZE_W;
+ const uint f = batch % OUTPUT_FEATURE_NUM;
+ batch /= OUTPUT_FEATURE_NUM;
+ const uint b = batch % OUTPUT_BATCH_NUM;
+
+ const uint batch_offset_input0 = FUNC_CALL(get_input0_batch_offset)(b, f, w, z);
+ const uint batch_offset_input1 = FUNC_CALL(get_input1_batch_offset)(b, f, w, z);
+#ifdef INPUT2_TYPE
+ const uint batch_offset_input2 = FUNC_CALL(get_input2_batch_offset)(b, f, w, z);
+#endif // INPUT2_TYPE
+ const uint batch_offset_output = FUNC_CALL(get_output_batch_offset)(b, f, w, z);
+
+ PACKED_INPUT0_TYPE_VEC tile_input0;
+ PACKED_INPUT1_TYPE_VEC tile_input1;
+#ifdef INPUT2_TYPE
+ ACTIVATION_TYPE_VEC tile_input2;
+#if OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ if (output_y_tile * TILE_SIZE_M + i >= OUTPUT_SIZE_Y) continue;
+ if (output_x_tile * TILE_SIZE_N + lid >= OUTPUT_SIZE_X) continue;
+
+ tile_input2[i] = TO_ACTIVATION_TYPE(input2[batch_offset_input2 + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X +
+ output_x_tile * TILE_SIZE_N + lid]);
+ }
+#else // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ tile_input2[i] = TO_ACTIVATION_TYPE(input2[batch_offset_input2 + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X +
+ output_x_tile * TILE_SIZE_N + lid]);
+ }
+#endif // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N
+#endif // INPUT2_TYPE
+
+ ACCUMULATOR_TYPE_VEC tile_output = (ACCUMULATOR_TYPE_VEC)(ACCUMULATOR_VAL_ZERO);
+
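+    // K is processed in TILE_SIZE_K blocks, rounded up; out-of-range elements of the last block are zero-filled below.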
+#if !TRANSPOSE_INPUT0
+ const uint K_BLOCK_NUM = (INPUT0_SIZE_X - 1) / TILE_SIZE_K + 1;
+ const uint K_SIZE = INPUT0_SIZE_X;
+#else // !TRANSPOSE_INPUT0
+ const uint K_BLOCK_NUM = (INPUT0_SIZE_Y - 1) / TILE_SIZE_K + 1;
+ const uint K_SIZE = INPUT0_SIZE_Y;
+#endif // !TRANSPOSE_INPUT0
+
+ for (uint k = 0; k < K_BLOCK_NUM; k++) {
+ MAKE_VECTOR_TYPE(INPUT0_TYPE, PACK_SIZE) temp_input0[SUB_GROUP_SIZE];
+ MAKE_VECTOR_TYPE(INPUT1_TYPE, PACK_SIZE) temp_input1[SUB_GROUP_SIZE];
+
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ const uint common_input1_offset = FUNC_CALL(get_common_input1_offset)(batch_offset_input1, k, i, output_x_tile, lid);
+
+#if OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K
+ const uint cur_n = output_x_tile * TILE_SIZE_N + lid;
+ const uint cur_k = k * TILE_SIZE_K + i * PACK_SIZE;
+
+ temp_input1[i] = 0;
+
+ if (cur_n < OUTPUT_SIZE_X) {
+ if (cur_k + 3 < K_SIZE) {
+ temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)];
+ temp_input1[i].s1 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 1, lid)];
+ temp_input1[i].s2 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 2, lid)];
+ temp_input1[i].s3 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 3, lid)];
+ } else if (cur_k + 2 < K_SIZE) {
+ temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)];
+ temp_input1[i].s1 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 1, lid)];
+ temp_input1[i].s2 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 2, lid)];
+ } else if (cur_k + 1 < K_SIZE) {
+ temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)];
+ temp_input1[i].s1 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 1, lid)];
+ } else if (cur_k < K_SIZE) {
+ temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)];
+ }
+ }
+#else // OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K
+ temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)];
+ temp_input1[i].s1 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 1, lid)];
+ temp_input1[i].s2 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 2, lid)];
+ temp_input1[i].s3 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 3, lid)];
+#endif // OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K
+
+ tile_input1[i] = AS_TYPE(PACKED_INPUT1_TYPE, temp_input1[i]);
+ }
+
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ const uint common_input0_offset = FUNC_CALL(get_common_input0_offset)(batch_offset_input0, k, i, output_y_tile, lid);
+
+#if OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_K
+ const uint cur_m = output_y_tile * TILE_SIZE_M + i;
+ const uint cur_k = k * TILE_SIZE_K + lid * PACK_SIZE;
+
+ temp_input0[i] = 0;
+
+ if (cur_m < OUTPUT_SIZE_Y) {
+ if (cur_k + 3 < K_SIZE) {
+ temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)];
+ temp_input0[i].s1 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 1, lid)];
+ temp_input0[i].s2 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 2, lid)];
+ temp_input0[i].s3 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 3, lid)];
+ } else if (cur_k + 2 < K_SIZE) {
+ temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)];
+ temp_input0[i].s1 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 1, lid)];
+ temp_input0[i].s2 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 2, lid)];
+ } else if (cur_k + 1 < K_SIZE) {
+ temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)];
+ temp_input0[i].s1 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 1, lid)];
+ } else if (cur_k < K_SIZE) {
+ temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)];
+ }
+ }
+
+ tile_input0[i] = AS_TYPE(PACKED_INPUT0_TYPE, temp_input0[i]);
+#else // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_K
+
+#if !TRANSPOSE_INPUT0
+ tile_input0[i] = AS_TYPE(PACKED_INPUT0_TYPE, BLOCK_READ(input0 + common_input0_offset));
+#else // !TRANSPOSE_INPUT0
+ temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)];
+ temp_input0[i].s1 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 1, lid)];
+ temp_input0[i].s2 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 2, lid)];
+ temp_input0[i].s3 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 3, lid)];
+
+ tile_input0[i] = AS_TYPE(PACKED_INPUT0_TYPE, temp_input0[i]);
+#endif // !TRANSPOSE_INPUT0
+
+#endif // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_K
+ }
+
+ tile_output = MMAD(tile_input0, tile_input1, tile_output);
+ }
+
+#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
+ FUSED_OPS_PRELOAD;
+#endif // HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
+
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+#if OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N
+ if (output_y_tile * TILE_SIZE_M + i >= OUTPUT_SIZE_Y) continue;
+ if (output_x_tile * TILE_SIZE_N + lid >= OUTPUT_SIZE_X) continue;
+#endif // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N
+
+ ACTIVATION_TYPE dequantized = TO_ACTIVATION_TYPE(tile_output[i]);
+ dequantized *= TO_ACTIVATION_TYPE(ALPHA);
+
+#ifdef INPUT2_TYPE
+ dequantized += TO_ACTIVATION_TYPE(BETA) * tile_input2[i];
+#endif // INPUT2_TYPE
+
+#if HAS_FUSED_OPS
+#if FUSED_OPS_CAN_USE_PRELOAD
+ FUSED_OPS_CALC;
+#else // FUSED_OPS_CAN_USE_PRELOAD
+ FUSED_OPS;
+#endif // FUSED_OPS_CAN_USE_PRELOAD
+ OUTPUT_TYPE res = FUSED_OPS_RESULT;
+ output[batch_offset_output + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = res;
+ output_y++;
+#else // HAS_FUSED_OPS
+ output[batch_offset_output + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = dequantized;
+#endif // HAS_FUSED_OPS
+ }
+}
+
+// ******************************************************************************************************************************** //
+// Optimized kernel without leftovers (for tiling parameters M = 8, N = 8, K = 32; M = 16, N = 16, K = 64; M = 32, N = 16, K = 64). //
+// ******************************************************************************************************************************** //
+
+#else // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K
+{
+ const uint output_x = (uint)get_global_id(0);
+ const uint output_x_tile = output_x / TILE_SIZE_N;
+ const uint output_y_tile = (uint)get_global_id(1);
+#if HAS_FUSED_OPS
+ uint output_y = output_y_tile * TILE_SIZE_M;
+#endif // HAS_FUSED_OPS
+ uint batch = get_global_id(2);
+ const uint lid = (uint)get_local_id(0);
+
+ const uint z = batch % OUTPUT_SIZE_Z;
+ batch /= OUTPUT_SIZE_Z;
+ const uint w = batch % OUTPUT_SIZE_W;
+ batch /= OUTPUT_SIZE_W;
+ const uint f = batch % OUTPUT_FEATURE_NUM;
+ batch /= OUTPUT_FEATURE_NUM;
+ const uint b = batch % OUTPUT_BATCH_NUM;
+
+ const uint batch_offset_input0 = FUNC_CALL(get_input0_batch_offset)(b, f, w, z);
+ const uint batch_offset_input1 = FUNC_CALL(get_input1_batch_offset)(b, f, w, z);
+#ifdef INPUT2_TYPE
+ const uint batch_offset_input2 = FUNC_CALL(get_input2_batch_offset)(b, f, w, z);
+#endif // INPUT2_TYPE
+ const uint batch_offset_output = FUNC_CALL(get_output_batch_offset)(b, f, w, z);
+
+ PACKED_INPUT0_TYPE_VEC tile_input00;
+ PACKED_INPUT1_TYPE_VEC tile_input10;
+
+#ifdef INPUT2_TYPE
+ ACTIVATION_TYPE_VEC tile_input20;
+
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ tile_input20[i] = TO_ACTIVATION_TYPE(input2[batch_offset_input2 + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X +
+ output_x_tile * TILE_SIZE_N + lid]);
+ }
+#endif // INPUT2_TYPE
+
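+    // Accumulators for the first and (when TILE_NUM == 2) second block of SUB_GROUP_SIZE output rows.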
+ ACCUMULATOR_TYPE_VEC tile_output00 = (ACCUMULATOR_TYPE_VEC)(ACCUMULATOR_VAL_ZERO);
+#if TILE_NUM == 2
+ ACCUMULATOR_TYPE_VEC tile_output01 = (ACCUMULATOR_TYPE_VEC)(ACCUMULATOR_VAL_ZERO);
+#endif // TILE_NUM == 2
+
+#if !TRANSPOSE_INPUT0
+ const uint K_BLOCK_NUM = INPUT0_SIZE_X / TILE_SIZE_K;
+#else // !TRANSPOSE_INPUT0
+ const uint K_BLOCK_NUM = INPUT0_SIZE_Y / TILE_SIZE_K;
+#endif // !TRANSPOSE_INPUT0
+
+ for (uint k = 0; k < K_BLOCK_NUM; k++) {
+#if !TRANSPOSE_INPUT1
+ MAKE_VECTOR_TYPE(INPUT1_TYPE, PACK_SIZE) temp_input1[SUB_GROUP_SIZE];
+ const uint common_input1_offset = batch_offset_input1 + k * TILE_SIZE_K * INPUT1_SIZE_X + output_x_tile * TILE_SIZE_N;
+
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ temp_input1[i].s0 = input1[common_input1_offset + i * PACK_SIZE * INPUT1_SIZE_X + lid];
+ temp_input1[i].s1 = input1[common_input1_offset + i * PACK_SIZE * INPUT1_SIZE_X + INPUT1_SIZE_X + lid];
+ temp_input1[i].s2 = input1[common_input1_offset + i * PACK_SIZE * INPUT1_SIZE_X + 2 * INPUT1_SIZE_X + lid];
+ temp_input1[i].s3 = input1[common_input1_offset + i * PACK_SIZE * INPUT1_SIZE_X + 3 * INPUT1_SIZE_X + lid];
+
+ tile_input10[i] = AS_TYPE(PACKED_INPUT1_TYPE, temp_input1[i]);
+ }
+#else // !TRANSPOSE_INPUT1
+ const uint common_input1_offset = batch_offset_input1 + output_x_tile * TILE_SIZE_N * INPUT1_SIZE_X + k * TILE_SIZE_K;
+
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ tile_input10[i] = AS_TYPE(PACKED_INPUT1_TYPE, BLOCK_READ(input1 + common_input1_offset + i * INPUT1_SIZE_X));
+ }
+
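+        // Transpose the packed tile across the sub-group: after the shuffles, work-item lid holds all
+        // TILE_SIZE_K packed values of output column lid, matching the layout of the non-transposed path.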
+ PACKED_INPUT1_TYPE_VEC tile_input1_col0 = BLOCK_SHUFFLE(tile_input10, 0);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col1 = BLOCK_SHUFFLE(tile_input10, 1);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col2 = BLOCK_SHUFFLE(tile_input10, 2);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col3 = BLOCK_SHUFFLE(tile_input10, 3);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col4 = BLOCK_SHUFFLE(tile_input10, 4);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col5 = BLOCK_SHUFFLE(tile_input10, 5);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col6 = BLOCK_SHUFFLE(tile_input10, 6);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col7 = BLOCK_SHUFFLE(tile_input10, 7);
+#if SUB_GROUP_SIZE == 16
+ PACKED_INPUT1_TYPE_VEC tile_input1_col8 = BLOCK_SHUFFLE(tile_input10, 8);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col9 = BLOCK_SHUFFLE(tile_input10, 9);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col10 = BLOCK_SHUFFLE(tile_input10, 10);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col11 = BLOCK_SHUFFLE(tile_input10, 11);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col12 = BLOCK_SHUFFLE(tile_input10, 12);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col13 = BLOCK_SHUFFLE(tile_input10, 13);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col14 = BLOCK_SHUFFLE(tile_input10, 14);
+ PACKED_INPUT1_TYPE_VEC tile_input1_col15 = BLOCK_SHUFFLE(tile_input10, 15);
+#endif // SUB_GROUP_SIZE == 16
+
+ tile_input10.s0 = tile_input1_col0[lid];
+ tile_input10.s1 = tile_input1_col1[lid];
+ tile_input10.s2 = tile_input1_col2[lid];
+ tile_input10.s3 = tile_input1_col3[lid];
+ tile_input10.s4 = tile_input1_col4[lid];
+ tile_input10.s5 = tile_input1_col5[lid];
+ tile_input10.s6 = tile_input1_col6[lid];
+ tile_input10.s7 = tile_input1_col7[lid];
+#if SUB_GROUP_SIZE == 16
+ tile_input10.s8 = tile_input1_col8[lid];
+ tile_input10.s9 = tile_input1_col9[lid];
+ tile_input10.sa = tile_input1_col10[lid];
+ tile_input10.sb = tile_input1_col11[lid];
+ tile_input10.sc = tile_input1_col12[lid];
+ tile_input10.sd = tile_input1_col13[lid];
+ tile_input10.se = tile_input1_col14[lid];
+ tile_input10.sf = tile_input1_col15[lid];
+#endif // SUB_GROUP_SIZE == 16
+
+#endif // !TRANSPOSE_INPUT1
+
+#if !TRANSPOSE_INPUT0
+ const uint common_input0_offset = batch_offset_input0 + output_y_tile * TILE_SIZE_M * INPUT0_SIZE_X + k * TILE_SIZE_K;
+
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ tile_input00[i] = AS_TYPE(PACKED_INPUT0_TYPE, BLOCK_READ(input0 + common_input0_offset + i * INPUT0_SIZE_X));
+ }
+
+ tile_output00 = MMAD(tile_input00, tile_input10, tile_output00);
+
+#if TILE_NUM == 2
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ tile_input00[i] = AS_TYPE(PACKED_INPUT0_TYPE, BLOCK_READ(input0 + common_input0_offset + (TILE_SIZE_M_DIV + i) * INPUT0_SIZE_X));
+ }
+
+ tile_output01 = MMAD(tile_input00, tile_input10, tile_output01);
+#endif // TILE_NUM == 2
+
+#else // !TRANSPOSE_INPUT0
+ MAKE_VECTOR_TYPE(INPUT0_TYPE, PACK_SIZE) temp_input0[SUB_GROUP_SIZE];
+ const uint common_input0_offset = batch_offset_input0 + (k * TILE_SIZE_K + lid * PACK_SIZE) * INPUT0_SIZE_X + output_y_tile * TILE_SIZE_M;
+
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ temp_input0[i].s0 = input0[common_input0_offset + i];
+ temp_input0[i].s1 = input0[common_input0_offset + 1 * INPUT0_SIZE_X + i];
+ temp_input0[i].s2 = input0[common_input0_offset + 2 * INPUT0_SIZE_X + i];
+ temp_input0[i].s3 = input0[common_input0_offset + 3 * INPUT0_SIZE_X + i];
+
+ tile_input00[i] = AS_TYPE(PACKED_INPUT0_TYPE, temp_input0[i]);
+ }
+
+ tile_output00 = MMAD(tile_input00, tile_input10, tile_output00);
+
+#if TILE_NUM == 2
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ temp_input0[i].s0 = input0[common_input0_offset + TILE_SIZE_M_DIV + i];
+ temp_input0[i].s1 = input0[common_input0_offset + 1 * INPUT0_SIZE_X + TILE_SIZE_M_DIV + i];
+ temp_input0[i].s2 = input0[common_input0_offset + 2 * INPUT0_SIZE_X + TILE_SIZE_M_DIV + i];
+ temp_input0[i].s3 = input0[common_input0_offset + 3 * INPUT0_SIZE_X + TILE_SIZE_M_DIV + i];
+
+ tile_input00[i] = AS_TYPE(PACKED_INPUT0_TYPE, temp_input0[i]);
+ }
+
+ tile_output01 = MMAD(tile_input00, tile_input10, tile_output01);
+#endif // TILE_NUM == 2
+
+#endif // !TRANSPOSE_INPUT0
+ }
+
+#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
+ FUSED_OPS_PRELOAD;
+#endif // HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
+
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ ACTIVATION_TYPE dequantized = TO_ACTIVATION_TYPE(tile_output00[i]);
+ dequantized *= TO_ACTIVATION_TYPE(ALPHA);
+#ifdef INPUT2_TYPE
+ dequantized += TO_ACTIVATION_TYPE(BETA) * tile_input20[i];
+#endif // INPUT2_TYPE
+
+#if HAS_FUSED_OPS
+#if FUSED_OPS_CAN_USE_PRELOAD
+ FUSED_OPS_CALC;
+#else // FUSED_OPS_CAN_USE_PRELOAD
+ FUSED_OPS;
+#endif // FUSED_OPS_CAN_USE_PRELOAD
+
+ OUTPUT_TYPE res = FUSED_OPS_RESULT;
+ output[batch_offset_output + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = res;
+ output_y++;
+#else // HAS_FUSED_OPS
+ output[batch_offset_output + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = dequantized;
+#endif // HAS_FUSED_OPS
+ }
+
+#if TILE_NUM == 2
+#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
+ FUSED_OPS_PRELOAD;
+#endif
+
+ for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+ ACTIVATION_TYPE dequantized = TO_ACTIVATION_TYPE(tile_output01[i]);
+ dequantized *= TO_ACTIVATION_TYPE(ALPHA);
+
+#if HAS_FUSED_OPS
+#if FUSED_OPS_CAN_USE_PRELOAD
+ FUSED_OPS_CALC;
+#else // FUSED_OPS_CAN_USE_PRELOAD
+ FUSED_OPS;
+#endif // FUSED_OPS_CAN_USE_PRELOAD
+
+ OUTPUT_TYPE res = FUSED_OPS_RESULT;
+ output[batch_offset_output + (output_y_tile * TILE_SIZE_M + TILE_SIZE_M_DIV + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = res;
+ output_y++;
+#else // HAS_FUSED_OPS
+ output[batch_offset_output + (output_y_tile * TILE_SIZE_M + TILE_SIZE_M_DIV + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = dequantized;
+#endif // HAS_FUSED_OPS
+ }
+#endif // TILE_NUM == 2
+
+}
+#endif // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K
+
+#undef PACK_SIZE
+#undef AS_TYPE
+#undef ACCUMULATOR_TYPE_VEC
+#undef ACTIVATION_TYPE_VEC
+#undef PACKED_INPUT0_TYPE_VEC
+#undef PACKED_INPUT1_TYPE_VEC
+#undef BLOCK_READ
+#undef BLOCK_SHUFFLE
+#undef MMAD
+#undef TILE_SIZE_M_DIV
/*
-// Copyright (c) 2016 Intel Corporation
+// Copyright (c) 2016-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
return ret;
}
+inline int FUNC(mmad_4)(char4 input, char4 weight, int acc) __attribute__((overloadable))
+{
+ acc += (input[0] * weight[0]);
+ acc += (input[1] * weight[1]);
+ acc += (input[2] * weight[2]);
+ acc += (input[3] * weight[3]);
+ return acc;
+}
+
+inline int FUNC(mmad_4)(char4 input, uchar4 weight, int acc) __attribute__((overloadable))
+{
+ acc += (input[0] * weight[0]);
+ acc += (input[1] * weight[1]);
+ acc += (input[2] * weight[2]);
+ acc += (input[3] * weight[3]);
+ return acc;
+}
+
inline int FUNC(mmad_4)(uchar4 input, char4 weight, int acc) __attribute__((overloadable))
{
acc += (input[0] * weight[0]);
return acc;
}
-inline int FUNC(mmad_4)(char4 input, char4 weight, int acc) __attribute__((overloadable))
+inline int FUNC(mmad_4)(uchar4 input, uchar4 weight, int acc) __attribute__((overloadable))
{
acc += (input[0] * weight[0]);
acc += (input[1] * weight[1]);
return acc;
}
+inline int FUNC(mmad8)(int8 A_scalars, int8 B_vectors, int acc) __attribute__((overloadable))
+{
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_char4(B_vectors[0]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_char4(B_vectors[1]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[2]), as_char4(B_vectors[2]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[3]), as_char4(B_vectors[3]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[4]), as_char4(B_vectors[4]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_char4(B_vectors[5]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_char4(B_vectors[6]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_char4(B_vectors[7]), acc);
+
+ return acc;
+}
+
+inline int FUNC(mmad8)(int8 A_scalars, uint8 B_vectors, int acc) __attribute__((overloadable))
+{
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_uchar4(B_vectors[0]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_uchar4(B_vectors[1]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[2]), as_uchar4(B_vectors[2]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[3]), as_uchar4(B_vectors[3]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[4]), as_uchar4(B_vectors[4]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_uchar4(B_vectors[5]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_uchar4(B_vectors[6]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_uchar4(B_vectors[7]), acc);
+
+ return acc;
+}
+
inline int FUNC(mmad8)(uint8 A_scalars, int8 B_vectors, int acc) __attribute__((overloadable))
{
acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[0]), as_char4(B_vectors[0]), acc);
return acc;
}
-inline int FUNC(mmad8)(int8 A_scalars, int8 B_vectors, int acc) __attribute__((overloadable))
+
+inline int FUNC(mmad8)(uint8 A_scalars, uint8 B_vectors, int acc) __attribute__((overloadable))
+{
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[0]), as_uchar4(B_vectors[0]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[1]), as_uchar4(B_vectors[1]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[2]), as_uchar4(B_vectors[2]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[3]), as_uchar4(B_vectors[3]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[4]), as_uchar4(B_vectors[4]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[5]), as_uchar4(B_vectors[5]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[6]), as_uchar4(B_vectors[6]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[7]), as_uchar4(B_vectors[7]), acc);
+
+ return acc;
+}
+
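+// 16-wide MMAD helpers mirroring the 8-wide variants above; used through MMAD_16x16 by kernels with a sub-group size of 16.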
+inline int FUNC(mmad16)(int16 A_scalars, int16 B_vectors, int acc) __attribute__((overloadable))
{
acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_char4(B_vectors[0]), acc);
acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_char4(B_vectors[1]), acc);
acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_char4(B_vectors[5]), acc);
acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_char4(B_vectors[6]), acc);
acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_char4(B_vectors[7]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[8]), as_char4(B_vectors[8]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[9]), as_char4(B_vectors[9]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[10]), as_char4(B_vectors[10]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[11]), as_char4(B_vectors[11]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[12]), as_char4(B_vectors[12]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[13]), as_char4(B_vectors[13]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[14]), as_char4(B_vectors[14]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[15]), as_char4(B_vectors[15]), acc);
return acc;
}
+inline int FUNC(mmad16)(int16 A_scalars, uint16 B_vectors, int acc) __attribute__((overloadable))
+{
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_uchar4(B_vectors[0]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_uchar4(B_vectors[1]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[2]), as_uchar4(B_vectors[2]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[3]), as_uchar4(B_vectors[3]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[4]), as_uchar4(B_vectors[4]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_uchar4(B_vectors[5]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_uchar4(B_vectors[6]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_uchar4(B_vectors[7]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[8]), as_uchar4(B_vectors[8]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[9]), as_uchar4(B_vectors[9]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[10]), as_uchar4(B_vectors[10]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[11]), as_uchar4(B_vectors[11]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[12]), as_uchar4(B_vectors[12]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[13]), as_uchar4(B_vectors[13]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[14]), as_uchar4(B_vectors[14]), acc);
+ acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[15]), as_uchar4(B_vectors[15]), acc);
+
+ return acc;
+}
+
+inline int FUNC(mmad16)(uint16 A_scalars, int16 B_vectors, int acc) __attribute__((overloadable))
+{
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[0]), as_char4(B_vectors[0]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[1]), as_char4(B_vectors[1]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[2]), as_char4(B_vectors[2]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[3]), as_char4(B_vectors[3]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[4]), as_char4(B_vectors[4]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[5]), as_char4(B_vectors[5]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[6]), as_char4(B_vectors[6]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[7]), as_char4(B_vectors[7]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[8]), as_char4(B_vectors[8]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[9]), as_char4(B_vectors[9]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[10]), as_char4(B_vectors[10]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[11]), as_char4(B_vectors[11]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[12]), as_char4(B_vectors[12]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[13]), as_char4(B_vectors[13]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[14]), as_char4(B_vectors[14]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[15]), as_char4(B_vectors[15]), acc);
+
+ return acc;
+}
+
+inline int FUNC(mmad16)(uint16 A_scalars, uint16 B_vectors, int acc) __attribute__((overloadable))
+{
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[0]), as_uchar4(B_vectors[0]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[1]), as_uchar4(B_vectors[1]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[2]), as_uchar4(B_vectors[2]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[3]), as_uchar4(B_vectors[3]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[4]), as_uchar4(B_vectors[4]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[5]), as_uchar4(B_vectors[5]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[6]), as_uchar4(B_vectors[6]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[7]), as_uchar4(B_vectors[7]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[8]), as_uchar4(B_vectors[8]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[9]), as_uchar4(B_vectors[9]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[10]), as_uchar4(B_vectors[10]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[11]), as_uchar4(B_vectors[11]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[12]), as_uchar4(B_vectors[12]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[13]), as_uchar4(B_vectors[13]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[14]), as_uchar4(B_vectors[14]), acc);
+ acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[15]), as_uchar4(B_vectors[15]), acc);
+
+ return acc;
+}
+
+inline int4 FUNC(mmad4x8)(int4 A_vectors, int8 B_vectors, int4 acc) __attribute__((overloadable))
+{
+ int4 ret;
+ for(uint i = 0; i < 4; i++)
+ {
+ int8 A_scalars;
+ A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+ A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+ A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+ A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+ A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+ A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+ A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+ A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+ ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]);
+ }
+ return ret;
+}
+
+inline int4 FUNC(mmad4x8)(int4 A_vectors, uint8 B_vectors, int4 acc) __attribute__((overloadable))
+{
+ int4 ret;
+ for(uint i = 0; i < 4; i++)
+ {
+ int8 A_scalars;
+ A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+ A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+ A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+ A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+ A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+ A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+ A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+ A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+ ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]);
+ }
+ return ret;
+}
+
inline int4 FUNC(mmad4x8)(uint4 A_vectors, int8 B_vectors, int4 acc) __attribute__((overloadable))
{
int4 ret;
return ret;
}
-inline int4 FUNC(mmad4x8)(int4 A_vectors, int8 B_vectors, int4 acc) __attribute__((overloadable))
+inline int4 FUNC(mmad4x8)(uint4 A_vectors, uint8 B_vectors, int4 acc) __attribute__((overloadable))
{
int4 ret;
for(uint i = 0; i < 4; i++)
{
+ uint8 A_scalars;
+ A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+ A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+ A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+ A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+ A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+ A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+ A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+ A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+ ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]);
+ }
+ return ret;
+}
+
+inline int8 FUNC(mmad8x8)(int8 A_vectors, int8 B_vectors, int8 acc) __attribute__((overloadable))
+{
+ int8 ret;
+ for(uint i = 0; i < 8; i++)
+ {
+ int8 A_scalars;
+ A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+ A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+ A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+ A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+ A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+ A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+ A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+ A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+ ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]);
+ }
+ return ret;
+}
+
+inline int8 FUNC(mmad8x8)(int8 A_vectors, uint8 B_vectors, int8 acc) __attribute__((overloadable))
+{
+ int8 ret;
+ for(uint i = 0; i < 8; i++)
+ {
int8 A_scalars;
A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
return ret;
}
-inline int8 FUNC(mmad8x8)(int8 A_vectors, int8 B_vectors, int8 acc) __attribute__((overloadable))
+inline int8 FUNC(mmad8x8)(uint8 A_vectors, uint8 B_vectors, int8 acc) __attribute__((overloadable))
{
int8 ret;
for(uint i = 0; i < 8; i++)
{
- int8 A_scalars;
+ uint8 A_scalars;
A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
return ret;
}
-// TODO: remove it when cl_intel_subgroups_char extension will work
+inline int16 FUNC(mmad16x16)(int16 A_vectors, int16 B_vectors, int16 acc) __attribute__((overloadable))
+{
+ int16 ret;
+ for(uint i = 0; i < 16; i++)
+ {
+ int16 A_scalars;
+ A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+ A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+ A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+ A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+ A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+ A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+ A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+ A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+ A_scalars.s8 = sub_group_broadcast(A_vectors[i], 8);
+ A_scalars.s9 = sub_group_broadcast(A_vectors[i], 9);
+ A_scalars.sa = sub_group_broadcast(A_vectors[i], 10);
+ A_scalars.sb = sub_group_broadcast(A_vectors[i], 11);
+ A_scalars.sc = sub_group_broadcast(A_vectors[i], 12);
+ A_scalars.sd = sub_group_broadcast(A_vectors[i], 13);
+ A_scalars.se = sub_group_broadcast(A_vectors[i], 14);
+ A_scalars.sf = sub_group_broadcast(A_vectors[i], 15);
+ ret[i] = FUNC_CALL(mmad16)(A_scalars, B_vectors, acc[i]);
+ }
+ return ret;
+}
+
+inline int16 FUNC(mmad16x16)(int16 A_vectors, uint16 B_vectors, int16 acc) __attribute__((overloadable))
+{
+ int16 ret;
+ for(uint i = 0; i < 16; i++)
+ {
+ int16 A_scalars;
+ A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+ A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+ A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+ A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+ A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+ A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+ A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+ A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+ A_scalars.s8 = sub_group_broadcast(A_vectors[i], 8);
+ A_scalars.s9 = sub_group_broadcast(A_vectors[i], 9);
+ A_scalars.sa = sub_group_broadcast(A_vectors[i], 10);
+ A_scalars.sb = sub_group_broadcast(A_vectors[i], 11);
+ A_scalars.sc = sub_group_broadcast(A_vectors[i], 12);
+ A_scalars.sd = sub_group_broadcast(A_vectors[i], 13);
+ A_scalars.se = sub_group_broadcast(A_vectors[i], 14);
+ A_scalars.sf = sub_group_broadcast(A_vectors[i], 15);
+ ret[i] = FUNC_CALL(mmad16)(A_scalars, B_vectors, acc[i]);
+ }
+ return ret;
+}
+
+inline int16 FUNC(mmad16x16)(uint16 A_vectors, int16 B_vectors, int16 acc) __attribute__((overloadable))
+{
+ int16 ret;
+ for(uint i = 0; i < 16; i++)
+ {
+ uint16 A_scalars;
+ A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+ A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+ A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+ A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+ A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+ A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+ A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+ A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+ A_scalars.s8 = sub_group_broadcast(A_vectors[i], 8);
+ A_scalars.s9 = sub_group_broadcast(A_vectors[i], 9);
+ A_scalars.sa = sub_group_broadcast(A_vectors[i], 10);
+ A_scalars.sb = sub_group_broadcast(A_vectors[i], 11);
+ A_scalars.sc = sub_group_broadcast(A_vectors[i], 12);
+ A_scalars.sd = sub_group_broadcast(A_vectors[i], 13);
+ A_scalars.se = sub_group_broadcast(A_vectors[i], 14);
+ A_scalars.sf = sub_group_broadcast(A_vectors[i], 15);
+ ret[i] = FUNC_CALL(mmad16)(A_scalars, B_vectors, acc[i]);
+ }
+ return ret;
+}
+
+inline int16 FUNC(mmad16x16)(uint16 A_vectors, uint16 B_vectors, int16 acc) __attribute__((overloadable))
+{
+ int16 ret;
+ for(uint i = 0; i < 16; i++)
+ {
+ uint16 A_scalars;
+ A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+ A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+ A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+ A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+ A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+ A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+ A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+ A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+ A_scalars.s8 = sub_group_broadcast(A_vectors[i], 8);
+ A_scalars.s9 = sub_group_broadcast(A_vectors[i], 9);
+ A_scalars.sa = sub_group_broadcast(A_vectors[i], 10);
+ A_scalars.sb = sub_group_broadcast(A_vectors[i], 11);
+ A_scalars.sc = sub_group_broadcast(A_vectors[i], 12);
+ A_scalars.sd = sub_group_broadcast(A_vectors[i], 13);
+ A_scalars.se = sub_group_broadcast(A_vectors[i], 14);
+ A_scalars.sf = sub_group_broadcast(A_vectors[i], 15);
+ ret[i] = FUNC_CALL(mmad16)(A_scalars, B_vectors, acc[i]);
+ }
+ return ret;
+}
+
inline void FUNC(sub_group_block_write_uchar16)(__global uchar* outPtr, uchar16 v)
{
#ifdef cl_intel_subgroups_char
ret.s7 = ptr[idx]; idx += get_max_sub_group_size();
return ret;
+
#endif
}
#endif
}
-//
-
-
#define MMAD_8(A, B, C) FUNC_CALL(mmad8)(A, B, C)
#define MMAD_4x8(A, B, C) FUNC_CALL(mmad4x8)(A, B, C)
#define MMAD_8x8(A, B, C) FUNC_CALL(mmad8x8)(A, B, C)
+#define MMAD_16x16(A, B, C) FUNC_CALL(mmad16x16)(A, B, C)
#define SLM_BLOCK_WRITE_4(A, B) (FUNC_CALL(intel_sub_group_block_write_4)(A, B))
#define SLM_BLOCK_READ_4(A) (FUNC_CALL(intel_sub_group_block_read_uint4)(A))
#define SLM_BLOCK_READ_8(A) (FUNC_CALL(intel_sub_group_block_read_uint8)(A))
bool can_fuse_parent1 = (parent1->is_type<convolution>() && conv_supports_fusings(parent1->as<convolution>())) ||
(parent1->is_type<mvn>() && mvn_supports_fusings(parent1->as<mvn>())) ||
- (parent1->is_type<deconvolution>()) || (parent1->is_type<permute>());
+ (parent1->is_type<deconvolution>()) || (parent1->is_type<permute>()) ||
+ (parent1->is_type<gemm>());
bool can_fuse_parent2 = (parent2->is_type<convolution>() && conv_supports_fusings(parent2->as<convolution>())) ||
(parent2->is_type<mvn>() && mvn_supports_fusings(parent2->as<mvn>())) ||
- (parent2->is_type<deconvolution>()) || (parent2->is_type<permute>());
+ (parent2->is_type<deconvolution>()) || (parent2->is_type<permute>()) ||
+ (parent2->is_type<gemm>());
std::vector<bool> can_fuse_parents = { can_fuse_parent1, can_fuse_parent2 };
struct gemm_test_params {
std::vector<tensor> in_shapes;
+ tensor out_shape;
tensor kernel;
tensor pad;
data_types data_type_in0;
layout get_per_channel_layout(gemm_test_params& p) {
return layout{ p.default_type, p.default_format, tensor{1, p.in_shapes.at(0).feature[0], 1, 1} };
}
+
+ layout get_output_layout(gemm_test_params& p) {
+ return layout{ p.default_type, p.input_format, p.out_shape };
+ }
};
// in_shape; out_shape; kernel; stride; pad; dilation; groups; data_type; input_format; weights_type; weights_format; default_type; default_format;
#define CASE_FC_U8S8_2 {2, 1, 3, 1}, {2, 4, 1, 1}, {4, 1, 3, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_U8S8_3 {2, 32, 1, 1}, {2, 16, 1, 1}, {16, 32, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx
-#define CASE_GEMM_3IN_S8S8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_3IN_S8S8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}, {1, 2, 256, 128}}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_3IN_S8S8_3 {{1, 1, 8, 16}, {1, 1, 32, 8}, {1, 1, 32, 16}}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_3IN_S8S8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_3IN_S8S8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}, {1, 2, 256, 128}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_3IN_S8S8_3 {{1, 1, 8, 16}, {1, 1, 32, 8}, {1, 1, 32, 16}}, {1, 1, 32, 16}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+
-#define CASE_GEMM_2IN_U8U8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_U8U8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_U8U8_3 {{1, 1, 16, 32}, {1, 1, 12, 16}}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_U8U8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_U8U8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_U8U8_3 {{1, 1, 16, 32}, {1, 1, 32, 16}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, {1, 1, 8, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_ELTWISE_2IN_U8S8_1 {{1, 1, 4, 4}, {1, 1, 4, 4}}, {1, 1, 4, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_ELTWISE_2IN_S8U8_1 {{1, 1, 32, 32}, {1, 1, 32, 32}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_NORMALIZE_I8_1 {1, 2, 3, 3}, data_types::u8, format::bfyx, data_types::f32, format::bfyx
gemm_test_params{ CASE_GEMM_2IN_S8U8_1, 3, 6 },
}), );
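+
+// Fusing pattern under test: gemm -> activation (exp) -> scale -> quantize -> eltwise (sum).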
+class gemm_int8_2in_act_scale_quantize_eltwise_i8 : public GemmFusingTest {};
+TEST_P(gemm_int8_2in_act_scale_quantize_eltwise_i8, basic) {
+ auto p = GetParam();
+ create_topologies(input_layout("input0", get_input_layout(p, 0)),
+ input_layout("input1", get_input_layout(p, 1)),
+ data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
+ data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
+ data("out_lo", get_mem(get_single_element_layout(p), -127)),
+ data("out_hi", get_mem(get_single_element_layout(p), 127)),
+ data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
+ data("eltwise_data", get_mem(get_output_layout(p))),
+ gemm("gemm_prim", { "input0", "input1" }, data_types::f32),
+ activation("activation", "gemm_prim", activation_func::exp),
+ scale("scale", "activation", "scale_data"),
+ quantize("quantize", "scale", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
+ eltwise("sum", { "quantize", "eltwise_data"}, eltwise_mode::sum, data_types::f32),
+ reorder("reorder_bfyx", "sum", p.default_format, data_types::f32)
+ );
+
+ tolerance = 1.0f;
+ execute(p);
+}
+
+INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_act_scale_quantize_eltwise_i8,
+ ::testing::ValuesIn(std::vector<gemm_test_params>{
+ gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_1, 3, 7 },
+ gemm_test_params{ CASE_GEMM_ELTWISE_2IN_S8U8_1, 3, 7 },
+}), );
+
/* ----------------------------------------------------------------------------------------------------- */
/* ---------------------------------------- Resample cases --------------------------------------------- */
/* ----------------------------------------------------------------------------------------------------- */
-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2018-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
auto output_ptr = output.pointer<float>();
EXPECT_EQ(output_ptr.size(), (uint32_t)8);
- for (uint32_t i = 0; i < out_data.size(); ++i) {
+ for (uint32_t i = 0; i < out_data.size(); ++i) {
EXPECT_FLOAT_EQ(output_ptr[i], out_data[i]);
}
}
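+
+// Standalone accuracy tests for the int8 GEMM kernels: a host-side reference result is computed
+// and compared element-wise against the GPU output.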
+struct gemm_int8_test_params {
+ size_t m_size;
+ size_t n_size;
+ size_t k_size;
+ size_t b0_num;
+ size_t f0_num;
+ size_t b1_num;
+ size_t f1_num;
+ size_t b2_num;
+ size_t f2_num;
+ size_t b_out_num;
+ size_t f_out_num;
+ bool transpose_input0;
+ bool transpose_input1;
+ float alpha;
+ float beta;
+ std::string kernel_name;
+};
+
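+// Case layout: m, n, k, b0, f0, b1, f1, b2, f2, b_out, f_out, transpose_input0, transpose_input1, alpha, beta
+// (kernel_name is appended at test instantiation).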
+#define CASE_GEMM_INT8_NN_TRANSPOSITION 64, 64, 64, 1, 2, 1, 2, 1, 2, 1, 2, false, false, 1.5f, 2.0f
+#define CASE_GEMM_INT8_NT_TRANSPOSITION 32, 64, 32, 2, 1, 2, 1, 2, 1, 2, 1, false, true, 1.7f, 1.3f
+#define CASE_GEMM_INT8_TN_TRANSPOSITION 128, 64, 32, 2, 2, 2, 2, 2, 2, 2, 2, true, false, 1.0f, 0.0f
+#define CASE_GEMM_INT8_TT_TRANSPOSITION 32, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.2f, 0.5f
+
+#define CASE_GEMM_INT8_BROADCAST_1 32, 32, 32, 1, 2, 1, 1, 1, 1, 1, 2, false, false, 1.5f, 2.0f
+#define CASE_GEMM_INT8_BROADCAST_2 32, 32, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, false, 1.7f, 1.3f
+#define CASE_GEMM_INT8_BROADCAST_3 64, 32, 32, 1, 2, 2, 1, 1, 2, 2, 2, false, false, 1.0f, 1.5f
+#define CASE_GEMM_INT8_BROADCAST_4 32, 64, 32, 1, 1, 2, 2, 2, 2, 2, 2, false, false, 1.2f, 0.5f
+
+#define CASE_GEMM_INT8_LEFTOVERS_1 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, false, 1.5f, 2.0f
+#define CASE_GEMM_INT8_LEFTOVERS_2 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, true, 1.6f, 1.0f
+#define CASE_GEMM_INT8_LEFTOVERS_3 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, false, 1.0f, 1.5f
+#define CASE_GEMM_INT8_LEFTOVERS_4 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.7f, 1.3f
+#define CASE_GEMM_INT8_LEFTOVERS_5 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, false, 1.5f, 2.0f
+#define CASE_GEMM_INT8_LEFTOVERS_6 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, true, 1.6f, 1.0f
+#define CASE_GEMM_INT8_LEFTOVERS_7 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, false, 1.0f, 1.5f
+#define CASE_GEMM_INT8_LEFTOVERS_8 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.7f, 1.3f
+#define CASE_GEMM_INT8_LEFTOVERS_9 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, false, false, 1.5f, 2.0f
+#define CASE_GEMM_INT8_LEFTOVERS_10 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, false, true, 1.6f, 1.0f
+#define CASE_GEMM_INT8_LEFTOVERS_11 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, true, false, 1.0f, 1.5f
+#define CASE_GEMM_INT8_LEFTOVERS_12 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.7f, 1.3f
+
+#define CASE_GEMM_INT8_COMBO_1 8, 8, 32, 1, 2, 1, 1, 1, 1, 1, 2, false, false, 1.5f, 2.0f
+#define CASE_GEMM_INT8_COMBO_2 16, 16, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, true, 1.7f, 0.0f
+#define CASE_GEMM_INT8_COMBO_3 11, 31, 21, 7, 15, 7, 15, 7, 15, 7, 15, true, false, 1.0f, 1.5f
+#define CASE_GEMM_INT8_COMBO_4 32, 32, 32, 3, 6, 3, 6, 3, 6, 3, 6, true, true, 1.2f, 4.0f
+
+template <typename T>
+class GemmInt8Test : public ::testing::TestWithParam<T> {
+public:
+
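+    // Flat bfyx index; the modulo wrapping lets an input with smaller batch/feature extents be broadcast over a larger output.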
+ inline size_t getGemmIndex(size_t x, size_t y, size_t f, size_t b, size_t x_size, size_t y_size, size_t f_num, size_t b_num,
+ size_t x_pitch, size_t y_pitch, size_t f_pitch, size_t b_pitch) {
+ return (x % x_size) * x_pitch + (y % y_size) * y_pitch + (f % f_num) * f_pitch + (b % b_num) * b_pitch;
+ }
+
+ void execute(T& p) {
+ const auto& engine = get_test_engine();
+
+ auto y0_size = p.m_size;
+ auto y0_pitch = p.k_size;
+ auto x0_size = p.k_size;
+ auto x0_pitch = 1;
+ auto f0_pitch = y0_size * x0_size;
+ auto b0_pitch = p.f0_num * f0_pitch;
+
+ auto y1_size = p.k_size;
+ auto y1_pitch = p.n_size;
+ auto x1_size = p.n_size;
+ auto x1_pitch = 1;
+ auto f1_pitch = y1_size * x1_size;
+ auto b1_pitch = p.f1_num * f1_pitch;
+
+ auto y2_size = p.m_size;
+ auto y2_pitch = p.n_size;
+ auto x2_size = p.n_size;
+ auto x2_pitch = 1;
+ auto f2_pitch = y2_size * x2_size;
+ auto b2_pitch = p.f2_num * f2_pitch;
+
+ auto y_out_size = p.m_size;
+ auto y_out_pitch = p.n_size;
+ auto x_out_size = p.n_size;
+ auto x_out_pitch = 1;
+ auto f_out_pitch = y_out_size * x_out_size;
+ auto b_out_pitch = p.f_out_num * f_out_pitch;
+
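+        // Transposed inputs swap their x/y extents and pitches; the f/b pitches are unchanged because y_size * x_size stays the same.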
+ if (p.transpose_input0) {
+ y0_size = p.k_size;
+ y0_pitch = p.m_size;
+ x0_size = p.m_size;
+ x0_pitch = 1;
+ }
+
+ if (p.transpose_input1) {
+ y1_size = p.n_size;
+ y1_pitch = p.k_size;
+ x1_size = p.k_size;
+ x1_pitch = 1;
+ }
+
+ auto input0_size = tensor((int)p.b0_num, (int)p.f0_num, (int)x0_size, (int)y0_size);
+ auto input0_data = generate_random_4d<int8_t>(p.b0_num, p.f0_num, x0_size, y0_size, -128, 127, 1);
+ auto input0_data_bfyx = flatten_4d(format::bfyx, input0_data);
+ auto input0_mem = memory::allocate(engine, { data_types::i8, format::bfyx, input0_size });
+ set_values(input0_mem, input0_data_bfyx);
+
+ auto input1_size = tensor((int)p.b1_num, (int)p.f1_num, (int)x1_size, (int)y1_size);
+ auto input1_data = generate_random_4d<uint8_t>(p.b1_num, p.f1_num, x1_size, y1_size, 0, 255, 1);
+ auto input1_data_bfyx = flatten_4d(format::bfyx, input1_data);
+ auto input1_mem = memory::allocate(engine, { data_types::u8, format::bfyx, input1_size });
+ set_values(input1_mem, input1_data_bfyx);
+
+ auto input2_size = tensor((int)p.b2_num, (int)p.f2_num, (int)x2_size, (int)y2_size);
+ auto input2_data = generate_random_4d<float>(p.b2_num, p.f2_num, x2_size, y2_size, -10, 10);
+ auto input2_data_bfyx = flatten_4d(format::bfyx, input2_data);
+ auto input2_mem = memory::allocate(engine, { data_types::f32, format::bfyx, input2_size });
+ set_values(input2_mem, input2_data_bfyx);
+
+ std::vector<float> out_data(p.b_out_num * p.f_out_num * p.m_size * p.n_size);
+
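+        // Host reference: acc = sum_k in0[i][k] * in1[k][j] accumulated in int32, then out = alpha * acc + beta * in2.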
+ for (size_t b = 0; b < p.b_out_num; ++b) {
+ for (size_t f = 0; f < p.f_out_num; ++f) {
+ for (size_t i = 0; i < p.m_size; ++i) {
+ for (size_t j = 0; j < p.n_size; ++j) {
+ size_t input2_data_index = getGemmIndex(j, i, f, b, x2_size, y2_size, p.f2_num, p.b2_num, x2_pitch, y2_pitch, f2_pitch, b2_pitch);
+ size_t out_data_index = getGemmIndex(j, i, f, b, x_out_size, y_out_size, p.f_out_num, p.b_out_num,
+ x_out_pitch, y_out_pitch, f_out_pitch, b_out_pitch);
+ int32_t acc = 0;
+
+ for (size_t k = 0; k < p.k_size; ++k) {
+ size_t input0_data_index = getGemmIndex(k * (!p.transpose_input0) + i * p.transpose_input0, i * (!p.transpose_input0) +
+ k * p.transpose_input0, f, b, x0_size, y0_size, p.f0_num, p.b0_num, x0_pitch, y0_pitch, f0_pitch, b0_pitch);
+ size_t input1_data_index = getGemmIndex(j * (!p.transpose_input1) + k * p.transpose_input1, k * (!p.transpose_input1) +
+ j * p.transpose_input1, f, b, x1_size, y1_size, p.f1_num, p.b1_num, x1_pitch, y1_pitch, f1_pitch, b1_pitch);
+
+ acc += input0_data_bfyx[input0_data_index] * input1_data_bfyx[input1_data_index];
+ }
+
+ out_data[out_data_index] = (float)acc;
+ out_data[out_data_index] *= p.alpha;
+ out_data[out_data_index] += p.beta * input2_data_bfyx[input2_data_index];
+ }
+ }
+ }
+ }
+
+ topology topology;
+ topology.add(input_layout("input0", input0_mem.get_layout()));
+ topology.add(input_layout("input1", input1_mem.get_layout()));
+ topology.add(input_layout("input2", input2_mem.get_layout()));
+ topology.add(gemm("output", { "input0", "input1", "input2" }, data_types::f32, p.transpose_input0, p.transpose_input1, p.alpha, p.beta));
+
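+        // Force the kernel under test (p.kernel_name) as the implementation of the gemm primitive.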
+ build_options options;
+ implementation_desc gemm_int8_impl = { format::bfyx, p.kernel_name };
+ options.set_option(build_option::force_implementations({ {"output", gemm_int8_impl} }));
+
+ network network(engine, topology, options);
+ network.set_input_data("input0", input0_mem);
+ network.set_input_data("input1", input1_mem);
+ network.set_input_data("input2", input2_mem);
+ auto outputs = network.execute();
+
+ auto output = outputs.at("output").get_memory();
+ auto output_ptr = output.pointer<float>();
+
+ EXPECT_EQ(output_ptr.size(), (size_t)(p.b_out_num * p.f_out_num * p.m_size * p.n_size));
+ for (size_t i = 0; i < out_data.size(); ++i) {
+ EXPECT_FLOAT_EQ(output_ptr[i], out_data[i]);
+ }
+ }
+};
+
+class gemm_int8_transposition_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
+TEST_P(gemm_int8_transposition_tests, basic) { auto p = GetParam(); execute(p); }
+
+INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_transposition_tests, ::testing::ValuesIn(std::vector<gemm_int8_test_params>{
+ gemm_int8_test_params{ CASE_GEMM_INT8_NN_TRANSPOSITION, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_NT_TRANSPOSITION, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_TN_TRANSPOSITION, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_TT_TRANSPOSITION, "gemm_mmad_int8" },
+}), );
+
+class gemm_int8_broadcast_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
+TEST_P(gemm_int8_broadcast_tests, basic) { auto p = GetParam(); execute(p); }
+
+INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_broadcast_tests, ::testing::ValuesIn(std::vector<gemm_int8_test_params>{
+ gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_1, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_2, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_3, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_4, "gemm_mmad_int8" },
+}), );
+
+class gemm_int8_leftovers_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
+TEST_P(gemm_int8_leftovers_tests, basic) { auto p = GetParam(); execute(p); }
+
+INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_leftovers_tests, ::testing::ValuesIn(std::vector<gemm_int8_test_params>{
+ gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_1, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_2, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_3, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_4, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_5, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_6, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_7, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_8, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_9, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_10, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_11, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_12, "gemm_mmad_int8" },
+}), );
+
+class gemm_int8_combo_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
+TEST_P(gemm_int8_combo_tests, basic) { auto p = GetParam(); execute(p); }
+
+INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_combo_tests, ::testing::ValuesIn(std::vector<gemm_int8_test_params>{
+ gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_1, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_2, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_3, "gemm_mmad_int8" },
+ gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_4, "gemm_mmad_int8" },
+}), );