[IE CLDNN] GEMM int8 optimization using MMAD macro (#635)
author Ilya Znamenskiy <ilya.znamenskiy@intel.com>
Fri, 5 Jun 2020 11:28:21 +0000 (14:28 +0300)
committer GitHub <noreply@github.com>
Fri, 5 Jun 2020 11:28:21 +0000 (14:28 +0300)
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_base.cpp
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_base.h
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.cpp [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.h [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_selector.cpp
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gemm_mmad_int8.cl [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/include/mmad.cl
inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp
inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp
inference-engine/thirdparty/clDNN/tests/test_cases/gemm_gpu_test.cpp

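For context before the per-file diffs: the MMAD macros this commit builds on accumulate a dot product of four packed 8-bit values into a 32-bit accumulator per step (PACK_SIZE = 4). A minimal C++ sketch mirroring the mmad_4 helper added to mmad.cl below (the char4/uchar4 overloads differ only in signedness):

    #include <cstdint>

    // One MMAD step: 4-element int8 dot product accumulated into int32.
    int32_t mmad_4(const int8_t a[4], const int8_t b[4], int32_t acc) {
        for (int i = 0; i < 4; ++i)
            acc += int32_t(a[i]) * int32_t(b[i]);
        return acc;
    }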
index 6155e1b..249e47f 100644
@@ -72,13 +72,6 @@ KernelsData GemmKernelBase::GetCommonKernelsData(const Params& params,
     auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options);
     auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
 
-    uint32_t fused_deps_total = 0;
-    for (auto& fused_dep : prim_params.fused_ops) {
-        for (int i = 0; i < static_cast<int>(fused_dep.dep_size); i++) {
-            fused_deps_total++;
-        }
-    }
-
     auto& kernel = k_data.kernels[0];
     FillCLKernelData(kernel,
                      run_info,
@@ -90,7 +83,7 @@ KernelsData GemmKernelBase::GetCommonKernelsData(const Params& params,
                      false,
                      false,
                      (uint32_t)prim_params.inputs.size(),
-                     fused_deps_total);
+                     GetFusedPrimitiveInputsCount(params));
 
     k_data.estimatedTime = estimated_time;
 
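The hunk above swaps a hand-rolled dependency count for the existing base-class helper GetFusedPrimitiveInputsCount(params). For reference, a C++ sketch of what the removed loop computed (the helper is assumed to return the same total):

    // Sum of dep_size over all fused ops, as the removed loop did.
    uint32_t fused_deps_total = 0;
    for (const auto& fused_dep : prim_params.fused_ops)
        fused_deps_total += static_cast<uint32_t>(fused_dep.dep_size);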
index 8b68ccc..d30d454 100644
@@ -1,4 +1,4 @@
-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2018-2020 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -57,7 +57,7 @@ public:
 
 protected:
     virtual JitConstants GetJitConstants(const gemm_params& params) const;
-    DispatchData SetDefault(const gemm_params& params) const;
+    virtual DispatchData SetDefault(const gemm_params& params) const;
     KernelsData GetCommonKernelsData(const Params& params, const optional_params&, float estimated_time) const;
     // Fused ops
     virtual JitConstants GetFusedPrimitivesJitConstants(const gemm_params& params, const DispatchData& kd) const;
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.cpp
new file mode 100644
index 0000000..62978d7
--- /dev/null
@@ -0,0 +1,204 @@
+/*
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+*/
+
+#include "gemm_kernel_mmad_int8.h"
+
+namespace kernel_selector {
+ParamsKey GemmKernelMMADint8::GetSupportedKey() const {
+    ParamsKey k;
+
+    k.EnableInputDataType(Datatype::INT8);
+    k.EnableInputDataType(Datatype::UINT8);
+    k.EnableInputDataType(Datatype::F32);
+    k.EnableOutputDataType(Datatype::F32);
+    k.EnableOutputDataType(Datatype::F16);
+    k.EnableOutputDataType(Datatype::INT8);
+    k.EnableOutputDataType(Datatype::UINT8);
+    k.EnableInputLayout(DataLayout::bfyx);
+    k.EnableOutputLayout(DataLayout::bfyx);
+    k.EnableInputLayout(DataLayout::bfzyx);
+    k.EnableOutputLayout(DataLayout::bfzyx);
+    k.EnableInputLayout(DataLayout::bfwzyx);
+    k.EnableOutputLayout(DataLayout::bfwzyx);
+
+    k.EnableBatching();
+    k.EnableDifferentTypes();
+    k.EnableTensorPitches();
+    k.EnableQuantization(QuantizationType::SYMMETRIC);
+
+    return k;
+}
+
+JitConstants GemmKernelMMADint8::GetJitConstants(const gemm_params& params) const {
+    JitConstants jit = Parent::GetJitConstants(params);
+    GemmTuningData td = SetTuningParams(params);
+
+    jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", td.simd_size));
+    jit.Merge(MakeTypeJitConstants(Datatype::INT32, "ACCUMULATOR"));
+    jit.Merge(MakeTypeJitConstants(Datatype::F32, "ACTIVATION"));
+    jit.Merge(MakeTypeJitConstants(params.inputs[0].GetDType() == Datatype::INT8 ? Datatype::INT32 : Datatype::UINT32, "PACKED_INPUT0"));
+    jit.Merge(MakeTypeJitConstants(params.inputs[1].GetDType() == Datatype::INT8 ? Datatype::INT32 : Datatype::UINT32, "PACKED_INPUT1"));
+    jit.AddConstant(MakeJitConstant("TILE_NUM", td.tile_num));
+    jit.AddConstant(MakeJitConstant("TILE_SIZE_M", td.simd_size * td.tile_num));
+    jit.AddConstant(MakeJitConstant("TILE_SIZE_N", td.simd_size));
+    jit.AddConstant(MakeJitConstant("TILE_SIZE_K", td.simd_size * td.pack_size));
+    jit.AddConstant(MakeJitConstant("OUTPUT_LEFTOVERS_M", td.size_m % (td.simd_size * td.tile_num)));
+    jit.AddConstant(MakeJitConstant("OUTPUT_LEFTOVERS_N", td.size_n % td.simd_size));
+    jit.AddConstant(MakeJitConstant("OUTPUT_LEFTOVERS_K", td.size_k % (td.simd_size * td.pack_size)));
+
+    if (!params.fused_ops.empty()) {
+        auto input_dt = GetActivationType(params);
+        FusedOpsConfiguration conf = { "", {"b", "f", "output_y", "output_x"}, "dequantized", input_dt, 1 };
+        conf.SetLoopAxes({ Tensor::DataChannelName::Y }, true);
+        jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
+    }
+
+    return jit;
+}
+
+GemmKernelBase::DispatchData GemmKernelMMADint8::SetDefault(const gemm_params& params) const {
+    const auto& output = params.output;
+    auto total_batches = output.LogicalSize() / (output.X().v * output.Y().v);
+
+    DispatchData kd;
+    GemmTuningData td = SetTuningParams(params);
+
+    std::vector<size_t> global = { Align(output.X().v, td.simd_size),
+                                   Align(output.Y().v, td.simd_size * td.tile_num) / (td.simd_size * td.tile_num),
+                                   total_batches };
+
+    std::vector<size_t> local = { td.simd_size, 1, 1 };
+
+    kd.gws0 = global[0];
+    kd.gws1 = global[1];
+    kd.gws2 = global[2];
+
+    kd.lws0 = local[0];
+    kd.lws1 = local[1];
+    kd.lws2 = local[2];
+
+    return kd;
+}
+
+GemmKernelMMADint8::GemmTuningData GemmKernelMMADint8::InitGemmTuningData(const gemm_params& params) const {
+    GemmTuningData tuning_data;
+
+    tuning_data.size_m = params.output.Y().v;
+    tuning_data.size_n = params.output.X().v;
+    tuning_data.size_k = params.transpose_input0 ? params.inputs[0].Y().v : params.inputs[0].X().v;
+
+    return tuning_data;
+}
+
+inline size_t GemmKernelMMADint8::GetMmadOperationsNumber(const GemmTuningData& tuning_data) const {
+    return tuning_data.size_m * tuning_data.size_n * tuning_data.size_k;
+}
+
+bool GemmKernelMMADint8::HasLeftovers(const GemmTuningData& tuning_data, int tile_size) const {
+    if (tile_size == 32) {
+        return tuning_data.size_m % 32 || tuning_data.size_n % 16 || tuning_data.size_k % 64;
+    } else if (tile_size == 16) {
+        return tuning_data.size_m % 16 || tuning_data.size_n % 16 || tuning_data.size_k % 64;
+    } else if (tile_size == 8) {
+        return tuning_data.size_m % 8 || tuning_data.size_n % 8 || tuning_data.size_k % 32;
+    } else {
+        return true;
+    }
+}
+
+GemmKernelMMADint8::GemmTuningData GemmKernelMMADint8::SetTuningParams(const gemm_params& params) const {
+    GemmTuningData tuning_data = InitGemmTuningData(params);
+    auto mmad_operations_number = GetMmadOperationsNumber(tuning_data);
+
+    bool leftovers_simd16x2 = HasLeftovers(tuning_data, 16 * 2);
+    bool leftovers_simd16 = HasLeftovers(tuning_data, 16);
+    bool leftovers_simd8 = HasLeftovers(tuning_data, 8);
+
+    bool small_matrices = mmad_operations_number <= 128 * 128 * 128;
+    bool average_matrices = mmad_operations_number <= 448 * 448 * 448;
+    bool very_big_matrices = mmad_operations_number >= 1024 * 1024 * 1024;
+    bool no_input2 = params.inputs.size() != 3;
+
+    size_t simd_size = 16;
+    size_t tile_num = 1;
+
+    if (!leftovers_simd16x2 && very_big_matrices && no_input2)
+        { simd_size = 16; tile_num = 2; }
+    else if (leftovers_simd16 && !leftovers_simd8)
+        { simd_size = 8; }
+    else if ((!params.transpose_input0 && !params.transpose_input1) && average_matrices)
+        { simd_size = 8; }
+    else if (small_matrices)
+        { simd_size = 8; }
+    else
+        { simd_size = 16; }
+
+    tuning_data.simd_size = simd_size;
+    tuning_data.tile_num = tile_num;
+
+    return tuning_data;
+}
+
+KernelsData GemmKernelMMADint8::GetKernelsData(const Params& params, const optional_params& options) const {
+    if (!Validate(params, options)) {
+        return KernelsData();
+    }
+
+    const auto& prim_params = static_cast<const gemm_params&>(params);
+
+    auto run_info = GemmKernelMMADint8::SetDefault(prim_params);
+    KernelData k_data = KernelData::Default<gemm_params>(params);
+
+    auto cldnn_jit = GetJitConstants(prim_params);
+    auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options);
+    auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
+
+    auto& kernel = k_data.kernels[0];
+    FillCLKernelData(kernel,
+                     run_info,
+                     params.engineInfo,
+                     kernelName,
+                     jit,
+                     entry_point,
+                     DEFAULT,
+                     false,
+                     false,
+                     (uint32_t)prim_params.inputs.size(),
+                     GetFusedPrimitiveInputsCount(params));
+
+    GemmTuningData tuning_data = InitGemmTuningData(prim_params);
+    auto mmad_operations_number = GetMmadOperationsNumber(tuning_data);
+
+    k_data.estimatedTime = mmad_operations_number < 4096 ? DONT_USE_IF_HAVE_SOMETHING_ELSE : FORCE_PRIORITY_3;
+
+    return {k_data};
+}
+
+bool GemmKernelMMADint8::Validate(const Params& params, const optional_params& options) const {
+    if (!Parent::Validate(params, options))
+        return false;
+
+    const auto& gmm_params = static_cast<const gemm_params&>(params);
+    auto input0_type = gmm_params.inputs[0].GetDType();
+    auto input1_type = gmm_params.inputs[1].GetDType();
+
+    if ((input0_type != Datatype::UINT8 && input0_type != Datatype::INT8) ||
+        (input1_type != Datatype::UINT8 && input1_type != Datatype::INT8))
+        return false;
+
+    return true;
+}
+}  // namespace kernel_selector
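As a worked example of the tuning heuristic above, consider two hypothetical square GEMMs with no transposed inputs and no third input; the sketch below reproduces the HasLeftovers checks and annotates the resulting choice:

    #include <cstddef>
    #include <cstdio>

    // Mirrors GemmKernelMMADint8::HasLeftovers for a hypothetical M, N, K.
    static bool has_leftovers(size_t m, size_t n, size_t k, int tile) {
        if (tile == 32) return m % 32 || n % 16 || k % 64;
        if (tile == 16) return m % 16 || n % 16 || k % 64;
        if (tile == 8)  return m % 8  || n % 8  || k % 32;
        return true;
    }

    int main() {
        // M = N = K = 512: no leftovers; 512^3 MMAD ops is neither small
        // (<= 128^3) nor average (<= 448^3) nor very big (>= 1024^3),
        // so the final else applies: simd_size = 16, tile_num = 1.
        std::printf("512: leftovers = %d\n", has_leftovers(512, 512, 512, 32));      // 0
        // M = N = K = 2048: no leftovers and 2048^3 >= 1024^3 with no input2,
        // so simd_size = 16, tile_num = 2 (tiles M = 32, N = 16, K = 64).
        std::printf("2048: leftovers = %d\n", has_leftovers(2048, 2048, 2048, 32));  // 0
        return 0;
    }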
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.h
new file mode 100644
index 0000000..f7ff633
--- /dev/null
@@ -0,0 +1,55 @@
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "gemm_kernel_base.h"
+#include <vector>
+
+namespace kernel_selector {
+class GemmKernelMMADint8 : public GemmKernelBase {
+public:
+    using Parent = GemmKernelBase;
+    using DispatchData = CommonDispatchData;
+    struct GemmTuningData {
+        size_t size_m;
+        size_t size_n;
+        size_t size_k;
+
+        size_t simd_size = 16;
+        size_t tile_num = 1;
+        size_t pack_size = 4;
+    };
+
+    GemmKernelMMADint8() : GemmKernelBase("gemm_mmad_int8") {}
+
+    KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
+    ParamsKey GetSupportedKey() const override;
+
+protected:
+    std::vector<FusedOpType> GetSupportedFusedOps() const override {
+        return { FusedOpType::QUANTIZE,
+                 FusedOpType::ACTIVATION,
+                 FusedOpType::SCALE,
+                 FusedOpType::ELTWISE };
+    }
+    bool Validate(const Params& params, const optional_params& options) const override;
+    JitConstants GetJitConstants(const gemm_params& params) const override;
+    DispatchData SetDefault(const gemm_params& params) const override;
+    GemmTuningData InitGemmTuningData(const gemm_params& params) const;
+    GemmTuningData SetTuningParams(const gemm_params& params) const;
+    size_t GetMmadOperationsNumber(const GemmTuningData& tuning_data) const;
+    bool HasLeftovers(const GemmTuningData& tuning_data, int tile_size) const;
+};
+}  // namespace kernel_selector
index 0c44671..043bf4c 100644
@@ -1,5 +1,5 @@
 /*
-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2018-2020 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 
 #include "gemm_kernel_selector.h"
 #include "gemm_kernel_ref.h"
+#include "gemm_kernel_mmad_int8.h"
 
 namespace kernel_selector {
-gemm_kernel_selector::gemm_kernel_selector() { Attach<GemmKernelRef>(); }
+gemm_kernel_selector::gemm_kernel_selector() {
+    Attach<GemmKernelRef>();
+    Attach<GemmKernelMMADint8>();
+}
 
 KernelsData gemm_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
     return GetNaiveBestKernel(params, options, KernelType::GEMM);
 }
-}  // namespace kernel_selector
\ No newline at end of file
+}  // namespace kernel_selector
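With the MMAD kernel attached alongside the reference kernel, GetNaiveBestKernel now has two candidates per GEMM node. A hedged reading of the selection, based only on what the diffs above show and not on the selector internals:

    // Assumed model (simplified, not the actual GetNaiveBestKernel body):
    //   - only kernels whose Validate(params, options) passes produce KernelsData;
    //   - GemmKernelMMADint8 reports FORCE_PRIORITY_3 for problems with at least
    //     4096 MMAD operations and DONT_USE_IF_HAVE_SOMETHING_ELSE otherwise,
    //     so GemmKernelRef remains the fallback for tiny or non-int8 cases.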
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gemm_mmad_int8.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gemm_mmad_int8.cl
new file mode 100644
index 0000000..a52174d
--- /dev/null
@@ -0,0 +1,518 @@
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "include/include_all.cl"
+#include "include/mmad.cl"
+
+#define PACK_SIZE               4
+
+#define AS_TYPE(type, val)      CAT(as_, type)(val)
+#define ACCUMULATOR_TYPE_VEC    CAT(ACCUMULATOR_TYPE, SUB_GROUP_SIZE)
+#define ACTIVATION_TYPE_VEC     CAT(ACTIVATION_TYPE, SUB_GROUP_SIZE)
+#define PACKED_INPUT0_TYPE_VEC  CAT(PACKED_INPUT0_TYPE, SUB_GROUP_SIZE)
+#define PACKED_INPUT1_TYPE_VEC  CAT(PACKED_INPUT1_TYPE, SUB_GROUP_SIZE)
+#define BLOCK_READ(ptr)         intel_sub_group_block_read((const __global uint*)(ptr))
+#define BLOCK_SHUFFLE           intel_sub_group_shuffle
+
+#if SUB_GROUP_SIZE == 8
+#define MMAD                    MMAD_8x8
+#else // SUB_GROUP_SIZE == 8
+#define MMAD                    MMAD_16x16
+#define TILE_SIZE_M_DIV         (TILE_SIZE_M / 2)
+#endif // SUB_GROUP_SIZE == 8
+
+inline uint FUNC(get_input0_batch_offset)(uint b, uint f, uint w, uint z) {
+#if INPUT0_SIMPLE
+    return GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, 0, 0);
+#else // INPUT0_SIMPLE
+#   error gemm_mmad_int8.cl : Unsupported input 0 format
+#endif // INPUT0_SIMPLE
+}
+
+inline uint FUNC(get_input1_batch_offset)(uint b, uint f, uint w, uint z) {
+#if INPUT1_SIMPLE
+    return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, 0, 0);
+#else // INPUT1_SIMPLE
+#   error gemm_mmad_int8.cl : Unsupported input 1 format
+#endif // INPUT1_SIMPLE
+}
+
+#ifdef INPUT2_TYPE
+inline uint FUNC(get_input2_batch_offset)(uint b, uint f, uint w, uint z) {
+#if INPUT2_SIMPLE
+    return GET_DATA_INDEX_6D_SAFE(INPUT2, b, f, w, z, 0, 0);
+#else // INPUT2_SIMPLE
+#   error gemm_mmad_int8.cl : Unsupported input 2 format
+#endif // INPUT2_SIMPLE
+}
+#endif // INPUT2_TYPE
+
+inline uint FUNC(get_output_batch_offset)(uint b, uint f, uint w, uint z) {
+#if OUTPUT_SIMPLE
+    return GET_DATA_INDEX_6D(OUTPUT, b, f, w, z, 0, 0);
+#else // OUTPUT_SIMPLE
+#   error gemm_mmad_int8.cl : Unsupported output format
+#endif // OUTPUT_SIMPLE
+}
+
+inline uint FUNC(get_common_input1_offset)(uint batch_offset_input1, uint k, uint i, uint output_x_tile, uint lid) {
+#if !TRANSPOSE_INPUT1
+    return batch_offset_input1 + (k * TILE_SIZE_K + i * PACK_SIZE) * INPUT1_SIZE_X + output_x_tile * TILE_SIZE_N;
+#else
+    return batch_offset_input1 + (output_x_tile * TILE_SIZE_N + lid) * INPUT1_SIZE_X + k * TILE_SIZE_K + i * PACK_SIZE;
+#endif
+}
+
+inline uint FUNC(get_current_input1_offset)(uint common_input1_offset, uint i, uint lid) {
+#if !TRANSPOSE_INPUT1
+    return common_input1_offset + INPUT1_SIZE_X * i + lid;
+#else
+    return common_input1_offset + i;
+#endif
+}
+
+inline uint FUNC(get_common_input0_offset)(uint batch_offset_input0, uint k, uint i, uint output_y_tile, uint lid) {
+#if !TRANSPOSE_INPUT0
+    return batch_offset_input0 + (output_y_tile * TILE_SIZE_M + i) * INPUT0_SIZE_X + k * TILE_SIZE_K;
+#else
+    return batch_offset_input0 + (k * TILE_SIZE_K + lid * PACK_SIZE) * INPUT0_SIZE_X + output_y_tile * TILE_SIZE_M + i;
+#endif
+}
+
+inline uint FUNC(get_current_input0_offset)(uint common_input0_offset, uint i, uint lid) {
+#if !TRANSPOSE_INPUT0
+    return common_input0_offset + lid * PACK_SIZE + i;
+#else
+    return common_input0_offset + INPUT0_SIZE_X * i;
+#endif
+}
+
+__attribute__((reqd_work_group_size(SUB_GROUP_SIZE, 1, 1)))
+__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE)))
+KERNEL(gemm_mmad_int8)(
+    const __global INPUT0_TYPE* input0,
+    const __global INPUT1_TYPE* input1,
+#ifdef INPUT2_TYPE
+    const __global INPUT2_TYPE* input2,
+#endif // INPUT2_TYPE
+    __global OUTPUT_TYPE* output
+#if HAS_FUSED_OPS_DECLS
+    , FUSED_OPS_DECLS
+#endif // HAS_FUSED_OPS_DECLS
+    )
+
+// ***************************************************************************************** //
+// Kernel with leftovers for all sizes of input matrices and all transposition combinations. //
+// ***************************************************************************************** //
+
+#if OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K
+{
+    const uint output_x = (uint)get_global_id(0);
+    const uint output_x_tile = output_x / TILE_SIZE_N;
+    const uint output_y_tile = (uint)get_global_id(1);
+#if HAS_FUSED_OPS
+    uint output_y = output_y_tile * TILE_SIZE_M;
+#endif // HAS_FUSED_OPS
+    uint batch = get_global_id(2);
+    const uint lid = (uint)get_local_id(0);
+
+    const uint z = batch % OUTPUT_SIZE_Z;
+    batch /= OUTPUT_SIZE_Z;
+    const uint w = batch % OUTPUT_SIZE_W;
+    batch /= OUTPUT_SIZE_W;
+    const uint f = batch % OUTPUT_FEATURE_NUM;
+    batch /= OUTPUT_FEATURE_NUM;
+    const uint b = batch % OUTPUT_BATCH_NUM;
+
+    const uint batch_offset_input0 = FUNC_CALL(get_input0_batch_offset)(b, f, w, z);
+    const uint batch_offset_input1 = FUNC_CALL(get_input1_batch_offset)(b, f, w, z);
+#ifdef INPUT2_TYPE
+    const uint batch_offset_input2 = FUNC_CALL(get_input2_batch_offset)(b, f, w, z);
+#endif // INPUT2_TYPE
+    const uint batch_offset_output = FUNC_CALL(get_output_batch_offset)(b, f, w, z);
+
+    PACKED_INPUT0_TYPE_VEC tile_input0;
+    PACKED_INPUT1_TYPE_VEC tile_input1;
+#ifdef INPUT2_TYPE
+    ACTIVATION_TYPE_VEC tile_input2;
+#if OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N
+    for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+        if (output_y_tile * TILE_SIZE_M + i >= OUTPUT_SIZE_Y) continue;
+        if (output_x_tile * TILE_SIZE_N + lid >= OUTPUT_SIZE_X) continue;
+
+        tile_input2[i] = TO_ACTIVATION_TYPE(input2[batch_offset_input2 + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X +
+                                                   output_x_tile * TILE_SIZE_N + lid]);
+    }
+#else // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N
+    for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+        tile_input2[i] = TO_ACTIVATION_TYPE(input2[batch_offset_input2 + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X +
+                                                   output_x_tile * TILE_SIZE_N + lid]);
+    }
+#endif // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N
+#endif // INPUT2_TYPE
+
+    ACCUMULATOR_TYPE_VEC tile_output = (ACCUMULATOR_TYPE_VEC)(ACCUMULATOR_VAL_ZERO);
+
+#if !TRANSPOSE_INPUT0
+    const uint K_BLOCK_NUM = (INPUT0_SIZE_X - 1) / TILE_SIZE_K + 1;
+    const uint K_SIZE = INPUT0_SIZE_X;
+#else // !TRANSPOSE_INPUT0
+    const uint K_BLOCK_NUM = (INPUT0_SIZE_Y - 1) / TILE_SIZE_K + 1;
+    const uint K_SIZE = INPUT0_SIZE_Y;
+#endif // !TRANSPOSE_INPUT0
+
+    for (uint k = 0; k < K_BLOCK_NUM; k++) {
+        MAKE_VECTOR_TYPE(INPUT0_TYPE, PACK_SIZE) temp_input0[SUB_GROUP_SIZE];
+        MAKE_VECTOR_TYPE(INPUT1_TYPE, PACK_SIZE) temp_input1[SUB_GROUP_SIZE];
+
+        for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+            const uint common_input1_offset = FUNC_CALL(get_common_input1_offset)(batch_offset_input1, k, i, output_x_tile, lid);
+
+#if OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K
+            const uint cur_n = output_x_tile * TILE_SIZE_N + lid;
+            const uint cur_k = k * TILE_SIZE_K + i * PACK_SIZE;
+
+            temp_input1[i] = 0;
+
+            if (cur_n < OUTPUT_SIZE_X) {
+                if (cur_k + 3 < K_SIZE) {
+                    temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)];
+                    temp_input1[i].s1 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 1, lid)];
+                    temp_input1[i].s2 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 2, lid)];
+                    temp_input1[i].s3 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 3, lid)];
+                } else if (cur_k + 2 < K_SIZE) {
+                    temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)];
+                    temp_input1[i].s1 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 1, lid)];
+                    temp_input1[i].s2 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 2, lid)];
+                } else if (cur_k + 1 < K_SIZE) {
+                    temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)];
+                    temp_input1[i].s1 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 1, lid)];
+                } else if (cur_k < K_SIZE) {
+                    temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)];
+                }
+            }
+#else // OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K
+            temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)];
+            temp_input1[i].s1 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 1, lid)];
+            temp_input1[i].s2 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 2, lid)];
+            temp_input1[i].s3 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 3, lid)];
+#endif // OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K
+
+            tile_input1[i] = AS_TYPE(PACKED_INPUT1_TYPE, temp_input1[i]);
+        }
+
+        for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+            const uint common_input0_offset = FUNC_CALL(get_common_input0_offset)(batch_offset_input0, k, i, output_y_tile, lid);
+
+#if OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_K
+            const uint cur_m = output_y_tile * TILE_SIZE_M + i;
+            const uint cur_k = k * TILE_SIZE_K + lid * PACK_SIZE;
+
+            temp_input0[i] = 0;
+
+            if (cur_m < OUTPUT_SIZE_Y) {
+                if (cur_k + 3 < K_SIZE) {
+                    temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)];
+                    temp_input0[i].s1 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 1, lid)];
+                    temp_input0[i].s2 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 2, lid)];
+                    temp_input0[i].s3 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 3, lid)];
+                } else if (cur_k + 2 < K_SIZE) {
+                    temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)];
+                    temp_input0[i].s1 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 1, lid)];
+                    temp_input0[i].s2 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 2, lid)];
+                } else if (cur_k + 1 < K_SIZE) {
+                    temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)];
+                    temp_input0[i].s1 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 1, lid)];
+                } else if (cur_k < K_SIZE) {
+                    temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)];
+                }
+            }
+
+            tile_input0[i] = AS_TYPE(PACKED_INPUT0_TYPE, temp_input0[i]);
+#else // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_K
+
+#if !TRANSPOSE_INPUT0
+            tile_input0[i] = AS_TYPE(PACKED_INPUT0_TYPE, BLOCK_READ(input0 + common_input0_offset));
+#else // !TRANSPOSE_INPUT0
+            temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)];
+            temp_input0[i].s1 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 1, lid)];
+            temp_input0[i].s2 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 2, lid)];
+            temp_input0[i].s3 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 3, lid)];
+
+            tile_input0[i] = AS_TYPE(PACKED_INPUT0_TYPE, temp_input0[i]);
+#endif // !TRANSPOSE_INPUT0
+
+#endif // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_K
+        }
+
+        tile_output = MMAD(tile_input0, tile_input1, tile_output);
+    }
+
+#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
+    FUSED_OPS_PRELOAD;
+#endif // HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
+
+    for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+#if OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N
+        if (output_y_tile * TILE_SIZE_M + i >= OUTPUT_SIZE_Y) continue;
+        if (output_x_tile * TILE_SIZE_N + lid >= OUTPUT_SIZE_X) continue;
+#endif // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N
+
+        ACTIVATION_TYPE dequantized = TO_ACTIVATION_TYPE(tile_output[i]);
+        dequantized *= TO_ACTIVATION_TYPE(ALPHA);
+
+#ifdef INPUT2_TYPE
+        dequantized += TO_ACTIVATION_TYPE(BETA) * tile_input2[i];
+#endif // INPUT2_TYPE
+
+#if HAS_FUSED_OPS
+#if FUSED_OPS_CAN_USE_PRELOAD
+        FUSED_OPS_CALC;
+#else // FUSED_OPS_CAN_USE_PRELOAD
+        FUSED_OPS;
+#endif // FUSED_OPS_CAN_USE_PRELOAD
+        OUTPUT_TYPE res = FUSED_OPS_RESULT;
+        output[batch_offset_output + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = res;
+        output_y++;
+#else // HAS_FUSED_OPS
+        output[batch_offset_output + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = dequantized;
+#endif // HAS_FUSED_OPS
+    }
+}
+
+// ******************************************************************************************************************************** //
+// Optimized kernel without leftovers (for tiling parameters M = 8, N = 8, K = 32; M = 16, N = 16, K = 64; M = 32, N = 16, K = 64). //
+// ******************************************************************************************************************************** //
+
+#else // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K
+{
+    const uint output_x = (uint)get_global_id(0);
+    const uint output_x_tile = output_x / TILE_SIZE_N;
+    const uint output_y_tile = (uint)get_global_id(1);
+#if HAS_FUSED_OPS
+    uint output_y = output_y_tile * TILE_SIZE_M;
+#endif // HAS_FUSED_OPS
+    uint batch = get_global_id(2);
+    const uint lid = (uint)get_local_id(0);
+
+    const uint z = batch % OUTPUT_SIZE_Z;
+    batch /= OUTPUT_SIZE_Z;
+    const uint w = batch % OUTPUT_SIZE_W;
+    batch /= OUTPUT_SIZE_W;
+    const uint f = batch % OUTPUT_FEATURE_NUM;
+    batch /= OUTPUT_FEATURE_NUM;
+    const uint b = batch % OUTPUT_BATCH_NUM;
+
+    const uint batch_offset_input0 = FUNC_CALL(get_input0_batch_offset)(b, f, w, z);
+    const uint batch_offset_input1 = FUNC_CALL(get_input1_batch_offset)(b, f, w, z);
+#ifdef INPUT2_TYPE
+    const uint batch_offset_input2 = FUNC_CALL(get_input2_batch_offset)(b, f, w, z);
+#endif // INPUT2_TYPE
+    const uint batch_offset_output = FUNC_CALL(get_output_batch_offset)(b, f, w, z);
+
+    PACKED_INPUT0_TYPE_VEC tile_input00;
+    PACKED_INPUT1_TYPE_VEC tile_input10;
+
+#ifdef INPUT2_TYPE
+    ACTIVATION_TYPE_VEC tile_input20;
+
+    for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+        tile_input20[i] = TO_ACTIVATION_TYPE(input2[batch_offset_input2 + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X +
+                                                    output_x_tile * TILE_SIZE_N + lid]);
+    }
+#endif // INPUT2_TYPE
+
+    ACCUMULATOR_TYPE_VEC tile_output00 = (ACCUMULATOR_TYPE_VEC)(ACCUMULATOR_VAL_ZERO);
+#if TILE_NUM == 2
+    ACCUMULATOR_TYPE_VEC tile_output01 = (ACCUMULATOR_TYPE_VEC)(ACCUMULATOR_VAL_ZERO);
+#endif // TILE_NUM == 2
+
+#if !TRANSPOSE_INPUT0
+    const uint K_BLOCK_NUM = INPUT0_SIZE_X / TILE_SIZE_K;
+#else // !TRANSPOSE_INPUT0
+    const uint K_BLOCK_NUM = INPUT0_SIZE_Y / TILE_SIZE_K;
+#endif // !TRANSPOSE_INPUT0
+
+    for (uint k = 0; k < K_BLOCK_NUM; k++) {
+#if !TRANSPOSE_INPUT1
+        MAKE_VECTOR_TYPE(INPUT1_TYPE, PACK_SIZE) temp_input1[SUB_GROUP_SIZE];
+        const uint common_input1_offset = batch_offset_input1 + k * TILE_SIZE_K * INPUT1_SIZE_X + output_x_tile * TILE_SIZE_N;
+
+        for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+            temp_input1[i].s0 = input1[common_input1_offset + i * PACK_SIZE * INPUT1_SIZE_X + lid];
+            temp_input1[i].s1 = input1[common_input1_offset + i * PACK_SIZE * INPUT1_SIZE_X + INPUT1_SIZE_X + lid];
+            temp_input1[i].s2 = input1[common_input1_offset + i * PACK_SIZE * INPUT1_SIZE_X + 2 * INPUT1_SIZE_X + lid];
+            temp_input1[i].s3 = input1[common_input1_offset + i * PACK_SIZE * INPUT1_SIZE_X + 3 * INPUT1_SIZE_X + lid];
+
+            tile_input10[i] = AS_TYPE(PACKED_INPUT1_TYPE, temp_input1[i]);
+        }
+#else // !TRANSPOSE_INPUT1
+        const uint common_input1_offset = batch_offset_input1 + output_x_tile * TILE_SIZE_N * INPUT1_SIZE_X + k * TILE_SIZE_K;
+
+        for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+            tile_input10[i] = AS_TYPE(PACKED_INPUT1_TYPE, BLOCK_READ(input1 + common_input1_offset  + i * INPUT1_SIZE_X));
+        }
+
+        PACKED_INPUT1_TYPE_VEC tile_input1_col0 = BLOCK_SHUFFLE(tile_input10, 0);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col1 = BLOCK_SHUFFLE(tile_input10, 1);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col2 = BLOCK_SHUFFLE(tile_input10, 2);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col3 = BLOCK_SHUFFLE(tile_input10, 3);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col4 = BLOCK_SHUFFLE(tile_input10, 4);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col5 = BLOCK_SHUFFLE(tile_input10, 5);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col6 = BLOCK_SHUFFLE(tile_input10, 6);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col7 = BLOCK_SHUFFLE(tile_input10, 7);
+#if SUB_GROUP_SIZE == 16
+        PACKED_INPUT1_TYPE_VEC tile_input1_col8 = BLOCK_SHUFFLE(tile_input10, 8);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col9 = BLOCK_SHUFFLE(tile_input10, 9);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col10 = BLOCK_SHUFFLE(tile_input10, 10);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col11 = BLOCK_SHUFFLE(tile_input10, 11);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col12 = BLOCK_SHUFFLE(tile_input10, 12);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col13 = BLOCK_SHUFFLE(tile_input10, 13);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col14 = BLOCK_SHUFFLE(tile_input10, 14);
+        PACKED_INPUT1_TYPE_VEC tile_input1_col15 = BLOCK_SHUFFLE(tile_input10, 15);
+#endif // SUB_GROUP_SIZE == 16
+
+        tile_input10.s0 = tile_input1_col0[lid];
+        tile_input10.s1 = tile_input1_col1[lid];
+        tile_input10.s2 = tile_input1_col2[lid];
+        tile_input10.s3 = tile_input1_col3[lid];
+        tile_input10.s4 = tile_input1_col4[lid];
+        tile_input10.s5 = tile_input1_col5[lid];
+        tile_input10.s6 = tile_input1_col6[lid];
+        tile_input10.s7 = tile_input1_col7[lid];
+#if SUB_GROUP_SIZE == 16
+        tile_input10.s8 = tile_input1_col8[lid];
+        tile_input10.s9 = tile_input1_col9[lid];
+        tile_input10.sa = tile_input1_col10[lid];
+        tile_input10.sb = tile_input1_col11[lid];
+        tile_input10.sc = tile_input1_col12[lid];
+        tile_input10.sd = tile_input1_col13[lid];
+        tile_input10.se = tile_input1_col14[lid];
+        tile_input10.sf = tile_input1_col15[lid];
+#endif // SUB_GROUP_SIZE == 16
+
+#endif // !TRANSPOSE_INPUT1
+
+#if !TRANSPOSE_INPUT0
+        const uint common_input0_offset = batch_offset_input0 + output_y_tile * TILE_SIZE_M * INPUT0_SIZE_X + k * TILE_SIZE_K;
+
+        for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+            tile_input00[i] = AS_TYPE(PACKED_INPUT0_TYPE, BLOCK_READ(input0 + common_input0_offset + i * INPUT0_SIZE_X));
+        }
+
+        tile_output00 = MMAD(tile_input00, tile_input10, tile_output00);
+
+#if TILE_NUM == 2
+        for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+            tile_input00[i] = AS_TYPE(PACKED_INPUT0_TYPE, BLOCK_READ(input0 + common_input0_offset + (TILE_SIZE_M_DIV + i) * INPUT0_SIZE_X));
+        }
+
+        tile_output01 = MMAD(tile_input00, tile_input10, tile_output01);
+#endif // TILE_NUM == 2
+
+#else // !TRANSPOSE_INPUT0
+        MAKE_VECTOR_TYPE(INPUT0_TYPE, PACK_SIZE) temp_input0[SUB_GROUP_SIZE];
+        const uint common_input0_offset = batch_offset_input0 + (k * TILE_SIZE_K + lid * PACK_SIZE) * INPUT0_SIZE_X + output_y_tile * TILE_SIZE_M;
+
+        for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+            temp_input0[i].s0 = input0[common_input0_offset + i];
+            temp_input0[i].s1 = input0[common_input0_offset + 1 * INPUT0_SIZE_X + i];
+            temp_input0[i].s2 = input0[common_input0_offset + 2 * INPUT0_SIZE_X + i];
+            temp_input0[i].s3 = input0[common_input0_offset + 3 * INPUT0_SIZE_X + i];
+
+            tile_input00[i] = AS_TYPE(PACKED_INPUT0_TYPE, temp_input0[i]);
+        }
+
+        tile_output00 = MMAD(tile_input00, tile_input10, tile_output00);
+
+#if TILE_NUM == 2
+        for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+            temp_input0[i].s0 = input0[common_input0_offset + TILE_SIZE_M_DIV + i];
+            temp_input0[i].s1 = input0[common_input0_offset + 1 * INPUT0_SIZE_X + TILE_SIZE_M_DIV + i];
+            temp_input0[i].s2 = input0[common_input0_offset + 2 * INPUT0_SIZE_X + TILE_SIZE_M_DIV + i];
+            temp_input0[i].s3 = input0[common_input0_offset + 3 * INPUT0_SIZE_X + TILE_SIZE_M_DIV + i];
+
+            tile_input00[i] = AS_TYPE(PACKED_INPUT0_TYPE, temp_input0[i]);
+        }
+
+        tile_output01 = MMAD(tile_input00, tile_input10, tile_output01);
+#endif // TILE_NUM == 2
+
+#endif // !TRANSPOSE_INPUT0
+    }
+
+#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
+    FUSED_OPS_PRELOAD;
+#endif // HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
+
+    for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+        ACTIVATION_TYPE dequantized = TO_ACTIVATION_TYPE(tile_output00[i]);
+        dequantized *= TO_ACTIVATION_TYPE(ALPHA);
+#ifdef INPUT2_TYPE
+        dequantized += TO_ACTIVATION_TYPE(BETA) * tile_input20[i];
+#endif // INPUT2_TYPE
+
+#if HAS_FUSED_OPS
+#if FUSED_OPS_CAN_USE_PRELOAD
+        FUSED_OPS_CALC;
+#else // FUSED_OPS_CAN_USE_PRELOAD
+        FUSED_OPS;
+#endif // FUSED_OPS_CAN_USE_PRELOAD
+
+        OUTPUT_TYPE res = FUSED_OPS_RESULT;
+        output[batch_offset_output + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = res;
+        output_y++;
+#else // HAS_FUSED_OPS
+        output[batch_offset_output + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = dequantized;
+#endif // HAS_FUSED_OPS
+    }
+
+#if TILE_NUM == 2
+#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
+    FUSED_OPS_PRELOAD;
+#endif
+
+    for (uint i = 0; i < SUB_GROUP_SIZE; i++) {
+        ACTIVATION_TYPE dequantized = TO_ACTIVATION_TYPE(tile_output01[i]);
+        dequantized *= TO_ACTIVATION_TYPE(ALPHA);
+
+#if HAS_FUSED_OPS
+#if FUSED_OPS_CAN_USE_PRELOAD
+        FUSED_OPS_CALC;
+#else // FUSED_OPS_CAN_USE_PRELOAD
+        FUSED_OPS;
+#endif // FUSED_OPS_CAN_USE_PRELOAD
+
+        OUTPUT_TYPE res = FUSED_OPS_RESULT;
+        output[batch_offset_output + (output_y_tile * TILE_SIZE_M + TILE_SIZE_M_DIV + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = res;
+        output_y++;
+#else // HAS_FUSED_OPS
+        output[batch_offset_output + (output_y_tile * TILE_SIZE_M + TILE_SIZE_M_DIV + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = dequantized;
+#endif // HAS_FUSED_OPS
+    }
+#endif // TILE_NUM == 2
+
+}
+#endif // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K
+
+#undef PACK_SIZE
+#undef AS_TYPE
+#undef ACCUMULATOR_TYPE_VEC
+#undef ACTIVATION_TYPE_VEC
+#undef PACKED_INPUT0_TYPE_VEC
+#undef PACKED_INPUT1_TYPE_VEC
+#undef BLOCK_READ
+#undef BLOCK_SHUFFLE
+#undef MMAD
+#undef TILE_SIZE_M_DIV
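Tying the kernel back to GemmKernelMMADint8::SetDefault: each SUB_GROUP_SIZE-wide sub-group covers a TILE_SIZE_M x TILE_SIZE_N output tile and walks K in packed TILE_SIZE_K chunks. A runnable C++ sketch of the dispatch arithmetic for a hypothetical 2048x2048 output with simd_size = 16 and tile_num = 2:

    #include <cstddef>
    #include <cstdio>

    static size_t align_up(size_t v, size_t a) { return (v + a - 1) / a * a; }

    int main() {
        const size_t simd = 16, tile_num = 2, X = 2048, Y = 2048;  // hypothetical shape
        // gws0: one lane per output column (N); gws1: one work-group per 32-row tile (M).
        std::printf("gws = { %zu, %zu, batches }, lws = { %zu, 1, 1 }\n",
                    align_up(X, simd),
                    align_up(Y, simd * tile_num) / (simd * tile_num),  // 2048 / 32 = 64
                    simd);
        // K_BLOCK_NUM in the no-leftovers path: K / TILE_SIZE_K = 2048 / (16 * 4) = 32.
        std::printf("K blocks: %zu\n", (size_t)2048 / (simd * 4));
        return 0;
    }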
index 80fab34..3aab503 100644
@@ -1,5 +1,5 @@
 /*
-// Copyright (c) 2016 Intel Corporation
+// Copyright (c) 2016-2020 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -55,6 +55,24 @@ uint8 FUNC(intel_sub_group_block_read_uint8)(const __local uint* p)
     return ret;
 }
 
+inline int FUNC(mmad_4)(char4 input, char4 weight, int acc) __attribute__((overloadable))
+{
+    acc += (input[0] * weight[0]);
+    acc += (input[1] * weight[1]);
+    acc += (input[2] * weight[2]);
+    acc += (input[3] * weight[3]);
+    return acc;
+}
+
+inline int FUNC(mmad_4)(char4 input, uchar4 weight, int acc) __attribute__((overloadable))
+{
+    acc += (input[0] * weight[0]);
+    acc += (input[1] * weight[1]);
+    acc += (input[2] * weight[2]);
+    acc += (input[3] * weight[3]);
+    return acc;
+}
+
 inline int FUNC(mmad_4)(uchar4 input, char4 weight, int acc) __attribute__((overloadable))
 {
     acc += (input[0] * weight[0]);
@@ -64,7 +82,7 @@ inline int FUNC(mmad_4)(uchar4 input, char4 weight, int acc) __attribute__((over
     return acc;
 }
 
-inline int FUNC(mmad_4)(char4 input, char4 weight, int acc) __attribute__((overloadable))
+inline int FUNC(mmad_4)(uchar4 input, uchar4 weight, int acc) __attribute__((overloadable))
 {
     acc += (input[0] * weight[0]);
     acc += (input[1] * weight[1]);
@@ -73,6 +91,34 @@ inline int FUNC(mmad_4)(char4 input, char4 weight, int acc) __attribute__((overl
     return acc;
 }
 
+inline int FUNC(mmad8)(int8 A_scalars, int8 B_vectors, int acc) __attribute__((overloadable))
+{
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_char4(B_vectors[0]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_char4(B_vectors[1]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[2]), as_char4(B_vectors[2]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[3]), as_char4(B_vectors[3]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[4]), as_char4(B_vectors[4]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_char4(B_vectors[5]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_char4(B_vectors[6]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_char4(B_vectors[7]), acc);
+
+    return acc;
+}
+
+inline int FUNC(mmad8)(int8 A_scalars, uint8 B_vectors, int acc) __attribute__((overloadable))
+{
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_uchar4(B_vectors[0]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_uchar4(B_vectors[1]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[2]), as_uchar4(B_vectors[2]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[3]), as_uchar4(B_vectors[3]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[4]), as_uchar4(B_vectors[4]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_uchar4(B_vectors[5]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_uchar4(B_vectors[6]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_uchar4(B_vectors[7]), acc);
+
+    return acc;
+}
+
 inline int FUNC(mmad8)(uint8 A_scalars, int8 B_vectors, int acc) __attribute__((overloadable))
 {
     acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[0]), as_char4(B_vectors[0]), acc);
@@ -86,7 +132,22 @@ inline int FUNC(mmad8)(uint8 A_scalars, int8 B_vectors, int acc) __attribute__((
 
     return acc;
 }
-inline int FUNC(mmad8)(int8 A_scalars, int8 B_vectors, int acc) __attribute__((overloadable))
+
+inline int FUNC(mmad8)(uint8 A_scalars, uint8 B_vectors, int acc) __attribute__((overloadable))
+{
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[0]), as_uchar4(B_vectors[0]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[1]), as_uchar4(B_vectors[1]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[2]), as_uchar4(B_vectors[2]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[3]), as_uchar4(B_vectors[3]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[4]), as_uchar4(B_vectors[4]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[5]), as_uchar4(B_vectors[5]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[6]), as_uchar4(B_vectors[6]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[7]), as_uchar4(B_vectors[7]), acc);
+
+    return acc;
+}
+
+inline int FUNC(mmad16)(int16 A_scalars, int16 B_vectors, int acc) __attribute__((overloadable))
 {
     acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_char4(B_vectors[0]), acc);
     acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_char4(B_vectors[1]), acc);
@@ -96,10 +157,122 @@ inline int FUNC(mmad8)(int8 A_scalars, int8 B_vectors, int acc) __attribute__((o
     acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_char4(B_vectors[5]), acc);
     acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_char4(B_vectors[6]), acc);
     acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_char4(B_vectors[7]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[8]), as_char4(B_vectors[8]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[9]), as_char4(B_vectors[9]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[10]), as_char4(B_vectors[10]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[11]), as_char4(B_vectors[11]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[12]), as_char4(B_vectors[12]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[13]), as_char4(B_vectors[13]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[14]), as_char4(B_vectors[14]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[15]), as_char4(B_vectors[15]), acc);
 
     return acc;
 }
 
+inline int FUNC(mmad16)(int16 A_scalars, uint16 B_vectors, int acc) __attribute__((overloadable))
+{
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_uchar4(B_vectors[0]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_uchar4(B_vectors[1]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[2]), as_uchar4(B_vectors[2]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[3]), as_uchar4(B_vectors[3]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[4]), as_uchar4(B_vectors[4]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_uchar4(B_vectors[5]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_uchar4(B_vectors[6]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_uchar4(B_vectors[7]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[8]), as_uchar4(B_vectors[8]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[9]), as_uchar4(B_vectors[9]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[10]), as_uchar4(B_vectors[10]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[11]), as_uchar4(B_vectors[11]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[12]), as_uchar4(B_vectors[12]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[13]), as_uchar4(B_vectors[13]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[14]), as_uchar4(B_vectors[14]), acc);
+    acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[15]), as_uchar4(B_vectors[15]), acc);
+
+    return acc;
+}
+
+inline int FUNC(mmad16)(uint16 A_scalars, int16 B_vectors, int acc) __attribute__((overloadable))
+{
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[0]), as_char4(B_vectors[0]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[1]), as_char4(B_vectors[1]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[2]), as_char4(B_vectors[2]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[3]), as_char4(B_vectors[3]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[4]), as_char4(B_vectors[4]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[5]), as_char4(B_vectors[5]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[6]), as_char4(B_vectors[6]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[7]), as_char4(B_vectors[7]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[8]), as_char4(B_vectors[8]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[9]), as_char4(B_vectors[9]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[10]), as_char4(B_vectors[10]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[11]), as_char4(B_vectors[11]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[12]), as_char4(B_vectors[12]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[13]), as_char4(B_vectors[13]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[14]), as_char4(B_vectors[14]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[15]), as_char4(B_vectors[15]), acc);
+
+    return acc;
+}
+
+inline int FUNC(mmad16)(uint16 A_scalars, uint16 B_vectors, int acc) __attribute__((overloadable))
+{
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[0]), as_uchar4(B_vectors[0]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[1]), as_uchar4(B_vectors[1]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[2]), as_uchar4(B_vectors[2]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[3]), as_uchar4(B_vectors[3]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[4]), as_uchar4(B_vectors[4]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[5]), as_uchar4(B_vectors[5]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[6]), as_uchar4(B_vectors[6]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[7]), as_uchar4(B_vectors[7]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[8]), as_uchar4(B_vectors[8]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[9]), as_uchar4(B_vectors[9]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[10]), as_uchar4(B_vectors[10]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[11]), as_uchar4(B_vectors[11]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[12]), as_uchar4(B_vectors[12]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[13]), as_uchar4(B_vectors[13]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[14]), as_uchar4(B_vectors[14]), acc);
+    acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[15]), as_uchar4(B_vectors[15]), acc);
+
+    return acc;
+}
+
+inline int4 FUNC(mmad4x8)(int4 A_vectors, int8 B_vectors, int4 acc) __attribute__((overloadable))
+{
+    int4 ret;
+    for(uint i = 0; i < 4; i++)
+    {
+        int8 A_scalars;
+        A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+        A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+        A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+        A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+        A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+        A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+        A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+        A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+        ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]);
+    }
+    return ret;
+}
+
+inline int4 FUNC(mmad4x8)(int4 A_vectors, uint8 B_vectors, int4 acc) __attribute__((overloadable))
+{
+    int4 ret;
+    for(uint i = 0; i < 4; i++)
+    {
+        int8 A_scalars;
+        A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+        A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+        A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+        A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+        A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+        A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+        A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+        A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+        ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]);
+    }
+    return ret;
+}
+
 inline int4 FUNC(mmad4x8)(uint4 A_vectors, int8 B_vectors, int4 acc) __attribute__((overloadable))
 {
     int4 ret;
@@ -119,11 +292,49 @@ inline int4 FUNC(mmad4x8)(uint4 A_vectors, int8 B_vectors, int4 acc) __attribute
     return ret;
 }
 
-inline int4 FUNC(mmad4x8)(int4 A_vectors, int8 B_vectors, int4 acc) __attribute__((overloadable))
+inline int4 FUNC(mmad4x8)(uint4 A_vectors, uint8 B_vectors, int4 acc) __attribute__((overloadable))
 {
     int4 ret;
     for(uint i = 0; i < 4; i++)
     {
+        uint8 A_scalars;
+        A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+        A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+        A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+        A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+        A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+        A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+        A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+        A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+        ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]);
+    }
+    return ret;
+}
+
+inline int8 FUNC(mmad8x8)(int8 A_vectors, int8 B_vectors, int8 acc) __attribute__((overloadable))
+{
+    int8 ret;
+    for(uint i = 0; i < 8; i++)
+    {
+        int8 A_scalars;
+        A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+        A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+        A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+        A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+        A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+        A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+        A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+        A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+        ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]);
+    }
+    return ret;
+}
+
+inline int8 FUNC(mmad8x8)(int8 A_vectors, uint8 B_vectors, int8 acc) __attribute__((overloadable))
+{
+    int8 ret;
+    for(uint i = 0; i < 8; i++)
+    {
         int8 A_scalars;
         A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
         A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
@@ -157,12 +368,12 @@ inline int8 FUNC(mmad8x8)(uint8 A_vectors, int8 B_vectors, int8 acc) __attribute
     return ret;
 }
 
-inline int8 FUNC(mmad8x8)(int8 A_vectors, int8 B_vectors, int8 acc) __attribute__((overloadable))
+inline int8 FUNC(mmad8x8)(uint8 A_vectors, uint8 B_vectors, int8 acc) __attribute__((overloadable))
 {
     int8 ret;
     for(uint i = 0; i < 8; i++)
     {
-        int8 A_scalars;
+        uint8 A_scalars;
         A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
         A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
         A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
@@ -176,7 +387,114 @@ inline int8 FUNC(mmad8x8)(int8 A_vectors, int8 B_vectors, int8 acc) __attribute_
     return ret;
 }
 
-// TODO: remove it when cl_intel_subgroups_char extension will work
+inline int16 FUNC(mmad16x16)(int16 A_vectors, int16 B_vectors, int16 acc) __attribute__((overloadable))
+{
+    int16 ret;
+    for(uint i = 0; i < 16; i++)
+    {
+        int16 A_scalars;
+        A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+        A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+        A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+        A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+        A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+        A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+        A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+        A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+        A_scalars.s8 = sub_group_broadcast(A_vectors[i], 8);
+        A_scalars.s9 = sub_group_broadcast(A_vectors[i], 9);
+        A_scalars.sa = sub_group_broadcast(A_vectors[i], 10);
+        A_scalars.sb = sub_group_broadcast(A_vectors[i], 11);
+        A_scalars.sc = sub_group_broadcast(A_vectors[i], 12);
+        A_scalars.sd = sub_group_broadcast(A_vectors[i], 13);
+        A_scalars.se = sub_group_broadcast(A_vectors[i], 14);
+        A_scalars.sf = sub_group_broadcast(A_vectors[i], 15);
+        ret[i] = FUNC_CALL(mmad16)(A_scalars, B_vectors, acc[i]);
+    }
+    return ret;
+}
+
+inline int16 FUNC(mmad16x16)(int16 A_vectors, uint16 B_vectors, int16 acc) __attribute__((overloadable))
+{
+    int16 ret;
+    for(uint i = 0; i < 16; i++)
+    {
+        int16 A_scalars;
+        A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+        A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+        A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+        A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+        A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+        A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+        A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+        A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+        A_scalars.s8 = sub_group_broadcast(A_vectors[i], 8);
+        A_scalars.s9 = sub_group_broadcast(A_vectors[i], 9);
+        A_scalars.sa = sub_group_broadcast(A_vectors[i], 10);
+        A_scalars.sb = sub_group_broadcast(A_vectors[i], 11);
+        A_scalars.sc = sub_group_broadcast(A_vectors[i], 12);
+        A_scalars.sd = sub_group_broadcast(A_vectors[i], 13);
+        A_scalars.se = sub_group_broadcast(A_vectors[i], 14);
+        A_scalars.sf = sub_group_broadcast(A_vectors[i], 15);
+        ret[i] = FUNC_CALL(mmad16)(A_scalars, B_vectors, acc[i]);
+    }
+    return ret;
+}
+
+inline int16 FUNC(mmad16x16)(uint16 A_vectors, int16 B_vectors, int16 acc) __attribute__((overloadable))
+{
+    int16 ret;
+    for(uint i = 0; i < 16; i++)
+    {
+        uint16 A_scalars;
+        A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+        A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+        A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+        A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+        A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+        A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+        A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+        A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+        A_scalars.s8 = sub_group_broadcast(A_vectors[i], 8);
+        A_scalars.s9 = sub_group_broadcast(A_vectors[i], 9);
+        A_scalars.sa = sub_group_broadcast(A_vectors[i], 10);
+        A_scalars.sb = sub_group_broadcast(A_vectors[i], 11);
+        A_scalars.sc = sub_group_broadcast(A_vectors[i], 12);
+        A_scalars.sd = sub_group_broadcast(A_vectors[i], 13);
+        A_scalars.se = sub_group_broadcast(A_vectors[i], 14);
+        A_scalars.sf = sub_group_broadcast(A_vectors[i], 15);
+        ret[i] = FUNC_CALL(mmad16)(A_scalars, B_vectors, acc[i]);
+    }
+    return ret;
+}
+
+inline int16 FUNC(mmad16x16)(uint16 A_vectors, uint16 B_vectors, int16 acc) __attribute__((overloadable))
+{
+    int16 ret;
+    for(uint i = 0; i < 16; i++)
+    {
+        uint16 A_scalars;
+        A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0);
+        A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1);
+        A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2);
+        A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3);
+        A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4);
+        A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5);
+        A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6);
+        A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7);
+        A_scalars.s8 = sub_group_broadcast(A_vectors[i], 8);
+        A_scalars.s9 = sub_group_broadcast(A_vectors[i], 9);
+        A_scalars.sa = sub_group_broadcast(A_vectors[i], 10);
+        A_scalars.sb = sub_group_broadcast(A_vectors[i], 11);
+        A_scalars.sc = sub_group_broadcast(A_vectors[i], 12);
+        A_scalars.sd = sub_group_broadcast(A_vectors[i], 13);
+        A_scalars.se = sub_group_broadcast(A_vectors[i], 14);
+        A_scalars.sf = sub_group_broadcast(A_vectors[i], 15);
+        ret[i] = FUNC_CALL(mmad16)(A_scalars, B_vectors, acc[i]);
+    }
+    return ret;
+}
+
 inline void FUNC(sub_group_block_write_uchar16)(__global uchar* outPtr, uchar16 v)
 {
 #ifdef cl_intel_subgroups_char
@@ -272,6 +590,7 @@ inline uchar8 FUNC(sub_group_block_read_uchar8)(const __global uchar* ptr)
     ret.s7 = ptr[idx]; idx += get_max_sub_group_size();
 
     return ret;
+
 #endif
 }
 
@@ -361,12 +680,10 @@ inline uchar FUNC(sub_group_block_read_uchar)(const __global uchar* ptr)
 #endif
 }
 
-//
-
-
 #define MMAD_8(A, B, C) FUNC_CALL(mmad8)(A, B, C)
 #define MMAD_4x8(A, B, C) FUNC_CALL(mmad4x8)(A, B, C)
 #define MMAD_8x8(A, B, C) FUNC_CALL(mmad8x8)(A, B, C)
+#define MMAD_16x16(A, B, C) FUNC_CALL(mmad16x16)(A, B, C)
 #define SLM_BLOCK_WRITE_4(A, B) (FUNC_CALL(intel_sub_group_block_write_4)(A, B))
 #define SLM_BLOCK_READ_4(A) (FUNC_CALL(intel_sub_group_block_read_uint4)(A))
 #define SLM_BLOCK_READ_8(A) (FUNC_CALL(intel_sub_group_block_read_uint8)(A))
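The new MMAD_16x16 macro follows the same decomposition as the existing 8-wide helpers: each mmad16x16 overload broadcasts one 32-bit lane of A across the sub-group and reduces it against all of B via mmad16, with separate overloads covering the signed/unsigned combinations of the two operands. A minimal scalar sketch of what one output element is expected to compute, assuming mmad16 performs the same dp4a-style reduction as the existing mmad8 path (four 8-bit multiply-accumulates per packed 32-bit lane; byte signedness follows the chosen overload):

    // Sketch, not the shipped kernel: scalar model of one mmad16x16 lane.
    int mmad16_scalar(const int A_scalars[16], const int B_vectors[16], int acc) {
        for (int i = 0; i < 16; ++i) {
            const signed char* a = (const signed char*)&A_scalars[i];  // 4 packed 8-bit values
            const signed char* b = (const signed char*)&B_vectors[i];  // 4 packed 8-bit values
            for (int j = 0; j < 4; ++j)
                acc += (int)a[j] * (int)b[j];  // dp4a-style multiply-accumulate
        }
        return acc;
    }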
index db98f15..a759875 100644 (file)
@@ -511,11 +511,13 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
 
             bool can_fuse_parent1 = (parent1->is_type<convolution>() && conv_supports_fusings(parent1->as<convolution>())) ||
                                     (parent1->is_type<mvn>() && mvn_supports_fusings(parent1->as<mvn>())) ||
-                                    (parent1->is_type<deconvolution>()) || (parent1->is_type<permute>());
+                                    (parent1->is_type<deconvolution>()) || (parent1->is_type<permute>()) ||
+                                    (parent1->is_type<gemm>());
 
             bool can_fuse_parent2 = (parent2->is_type<convolution>() && conv_supports_fusings(parent2->as<convolution>())) ||
                                     (parent2->is_type<mvn>() && mvn_supports_fusings(parent2->as<mvn>())) ||
-                                    (parent2->is_type<deconvolution>()) || (parent2->is_type<permute>());
+                                    (parent2->is_type<deconvolution>()) || (parent2->is_type<permute>()) ||
+                                    (parent2->is_type<gemm>());
 
             std::vector<bool> can_fuse_parents = { can_fuse_parent1, can_fuse_parent2 };
 
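With gemm added to the fusible parent types, the eltwise fusing step in prepare_primitive_fusing can now absorb an element-wise sum whose parent is an int8 GEMM; this is the pattern exercised by the new fusing test below. A minimal sketch of the now-fusible topology, built with the same primitive constructors the tests use (identifiers are illustrative):

    topology.add(gemm("gemm_prim", { "input0", "input1" }, data_types::f32));
    topology.add(eltwise("sum", { "gemm_prim", "eltwise_data" }, eltwise_mode::sum, data_types::f32));
    // After the fusing pass, "sum" is expected to be folded into "gemm_prim"
    // rather than executing as a standalone primitive.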
index 4e96c35..f6ef20f 100644 (file)
@@ -74,6 +74,7 @@ struct bc_test_params {
 
 struct gemm_test_params {
     std::vector<tensor> in_shapes;
+    tensor out_shape;
     tensor kernel;
     tensor pad;
     data_types data_type_in0;
@@ -360,6 +361,10 @@ public:
     layout get_per_channel_layout(gemm_test_params& p) {
         return layout{ p.default_type, p.default_format, tensor{1, p.in_shapes.at(0).feature[0], 1, 1} };
     }
+
+    layout get_output_layout(gemm_test_params& p) {
+        return layout{ p.default_type, p.input_format, p.out_shape };
+    }
 };
 
 // in_shape; out_shape; kernel; stride; pad; dilation; groups; data_type; input_format; weights_type; weights_format; default_type; default_format;
@@ -431,16 +436,19 @@ public:
 #define CASE_FC_U8S8_2 {2, 1, 3, 1}, {2, 4, 1, 1}, {4, 1, 3, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx
 #define CASE_FC_U8S8_3 {2, 32, 1, 1}, {2, 16, 1, 1}, {16, 32, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx
 
-#define CASE_GEMM_3IN_S8S8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_3IN_S8S8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}, {1, 2, 256, 128}}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_3IN_S8S8_3 {{1, 1, 8, 16}, {1, 1, 32, 8}, {1, 1, 32, 16}}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_3IN_S8S8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_3IN_S8S8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}, {1, 2, 256, 128}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_3IN_S8S8_3 {{1, 1, 8, 16}, {1, 1, 32, 8}, {1, 1, 32, 16}}, {1, 1, 32, 16}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+
+#define CASE_GEMM_2IN_U8U8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_U8U8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_U8U8_3 {{1, 1, 16, 32}, {1, 1, 32, 16}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
 
-#define CASE_GEMM_2IN_U8U8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, tensor{1}, tensor{0}, data_types::u8,  data_types::u8,  data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_U8U8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}}, tensor{1}, tensor{0}, data_types::u8,  data_types::u8,  data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_U8U8_3 {{1, 1, 16, 32}, {1, 1, 12, 16}}, tensor{1}, tensor{0}, data_types::u8,  data_types::u8,  data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, {1, 1, 8, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
 
-#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, tensor{1}, tensor{0}, data_types::u8,  data_types::i8,  data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, tensor{1}, tensor{0}, data_types::i8,  data_types::u8,  data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_ELTWISE_2IN_U8S8_1 {{1, 1, 4, 4}, {1, 1, 4, 4}}, {1, 1, 4, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_ELTWISE_2IN_S8U8_1 {{1, 1, 32, 32}, {1, 1, 32, 32}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
 
 #define CASE_NORMALIZE_I8_1 {1, 2, 3, 3}, data_types::u8, format::bfyx, data_types::f32, format::bfyx
 
@@ -2270,6 +2278,35 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_act_scale_quantize_i8,
                         gemm_test_params{ CASE_GEMM_2IN_S8U8_1, 3, 6 },
 }), );
 
+class gemm_int8_2in_act_scale_quantize_eltwise_i8 : public GemmFusingTest {};
+TEST_P(gemm_int8_2in_act_scale_quantize_eltwise_i8, basic) {
+    auto p = GetParam();
+    create_topologies(input_layout("input0", get_input_layout(p, 0)),
+        input_layout("input1", get_input_layout(p, 1)),
+        data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
+        data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
+        data("out_lo", get_mem(get_single_element_layout(p), -127)),
+        data("out_hi", get_mem(get_single_element_layout(p), 127)),
+        data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
+        data("eltwise_data", get_mem(get_output_layout(p))),
+        gemm("gemm_prim", { "input0", "input1" }, data_types::f32),
+        activation("activation", "gemm_prim", activation_func::exp),
+        scale("scale", "activation", "scale_data"),
+        quantize("quantize", "scale", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
+        eltwise("sum", { "quantize", "eltwise_data"}, eltwise_mode::sum,  data_types::f32),
+        reorder("reorder_bfyx", "sum", p.default_format, data_types::f32)
+    );
+
+    tolerance = 1.0f;
+    execute(p);
+}
+
+INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_act_scale_quantize_eltwise_i8,
+    ::testing::ValuesIn(std::vector<gemm_test_params>{
+                        gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_1, 3, 7 },
+                        gemm_test_params{ CASE_GEMM_ELTWISE_2IN_S8U8_1, 3, 7 },
+}), );
+
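The two trailing integers in the gemm_test_params entries appear to follow the convention of the other fusing tests in this file: the expected number of executed primitives with fusing enabled, then without it. Read that way (an assumption, shown for illustration), the new eltwise cases expect the seven-primitive chain to collapse to three:

    // Assumed reading of the parameter tail:
    gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_1,
                      3,    // expected primitives with fusing
                      7 };  // expected primitives without fusing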
 /* ----------------------------------------------------------------------------------------------------- */
 /* ---------------------------------------- Resample cases --------------------------------------------- */
 /* ----------------------------------------------------------------------------------------------------- */
index e184963..d1e1e41 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2018-2020 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -3232,8 +3232,228 @@ TEST(gemm_gpu, basic_smarcink2) {
     auto output_ptr = output.pointer<float>();
 
     EXPECT_EQ(output_ptr.size(), (uint32_t)8);
-    for (uint32_t i = 0; i < out_data.size(); ++i) {         
+    for (uint32_t i = 0; i < out_data.size(); ++i) {
         EXPECT_FLOAT_EQ(output_ptr[i], out_data[i]);
     }
 }
 
+struct gemm_int8_test_params {
+    size_t m_size;
+    size_t n_size;
+    size_t k_size;
+    size_t b0_num;
+    size_t f0_num;
+    size_t b1_num;
+    size_t f1_num;
+    size_t b2_num;
+    size_t f2_num;
+    size_t b_out_num;
+    size_t f_out_num;
+    bool transpose_input0;
+    bool transpose_input1;
+    float alpha;
+    float beta;
+    std::string kernel_name;
+};
+
+#define CASE_GEMM_INT8_NN_TRANSPOSITION 64, 64, 64, 1, 2, 1, 2, 1, 2, 1, 2, false, false, 1.5f, 2.0f
+#define CASE_GEMM_INT8_NT_TRANSPOSITION 32, 64, 32, 2, 1, 2, 1, 2, 1, 2, 1, false, true, 1.7f, 1.3f
+#define CASE_GEMM_INT8_TN_TRANSPOSITION 128, 64, 32, 2, 2, 2, 2, 2, 2, 2, 2, true, false, 1.0f, 0.0f
+#define CASE_GEMM_INT8_TT_TRANSPOSITION 32, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.2f, 0.5f
+
+#define CASE_GEMM_INT8_BROADCAST_1 32, 32, 32, 1, 2, 1, 1, 1, 1, 1, 2, false, false, 1.5f, 2.0f
+#define CASE_GEMM_INT8_BROADCAST_2 32, 32, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, false, 1.7f, 1.3f
+#define CASE_GEMM_INT8_BROADCAST_3 64, 32, 32, 1, 2, 2, 1, 1, 2, 2, 2, false, false, 1.0f, 1.5f
+#define CASE_GEMM_INT8_BROADCAST_4 32, 64, 32, 1, 1, 2, 2, 2, 2, 2, 2, false, false, 1.2f, 0.5f
+
+#define CASE_GEMM_INT8_LEFTOVERS_1 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, false, 1.5f, 2.0f
+#define CASE_GEMM_INT8_LEFTOVERS_2 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, true, 1.6f, 1.0f
+#define CASE_GEMM_INT8_LEFTOVERS_3 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, false, 1.0f, 1.5f
+#define CASE_GEMM_INT8_LEFTOVERS_4 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.7f, 1.3f
+#define CASE_GEMM_INT8_LEFTOVERS_5 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, false, 1.5f, 2.0f
+#define CASE_GEMM_INT8_LEFTOVERS_6 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, true, 1.6f, 1.0f
+#define CASE_GEMM_INT8_LEFTOVERS_7 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, false, 1.0f, 1.5f
+#define CASE_GEMM_INT8_LEFTOVERS_8 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.7f, 1.3f
+#define CASE_GEMM_INT8_LEFTOVERS_9 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, false, false, 1.5f, 2.0f
+#define CASE_GEMM_INT8_LEFTOVERS_10 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, false, true, 1.6f, 1.0f
+#define CASE_GEMM_INT8_LEFTOVERS_11 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, true, false, 1.0f, 1.5f
+#define CASE_GEMM_INT8_LEFTOVERS_12 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.7f, 1.3f
+
+#define CASE_GEMM_INT8_COMBO_1 8, 8, 32, 1, 2, 1, 1, 1, 1, 1, 2, false, false, 1.5f, 2.0f
+#define CASE_GEMM_INT8_COMBO_2 16, 16, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, true, 1.7f, 0.0f
+#define CASE_GEMM_INT8_COMBO_3 11, 31, 21, 7, 15, 7, 15, 7, 15, 7, 15, true, false, 1.0f, 1.5f
+#define CASE_GEMM_INT8_COMBO_4 32, 32, 32, 3, 6, 3, 6, 3, 6, 3, 6, true, true, 1.2f, 4.0f
+
+template <typename T>
+class GemmInt8Test : public ::testing::TestWithParam<T> {
+public:
+
+    inline size_t getGemmIndex(size_t x, size_t y, size_t f, size_t b, size_t x_size, size_t y_size, size_t f_num, size_t b_num,
+                               size_t x_pitch, size_t y_pitch, size_t f_pitch, size_t b_pitch) {
+        return (x % x_size) * x_pitch + (y % y_size) * y_pitch + (f % f_num) * f_pitch + (b % b_num) * b_pitch;
+    }
+
+    void execute(T& p) {
+        const auto& engine = get_test_engine();
+
+        auto y0_size = p.m_size;
+        auto y0_pitch = p.k_size;
+        auto x0_size = p.k_size;
+        auto x0_pitch = 1;
+        auto f0_pitch = y0_size * x0_size;
+        auto b0_pitch = p.f0_num * f0_pitch;
+
+        auto y1_size = p.k_size;
+        auto y1_pitch = p.n_size;
+        auto x1_size = p.n_size;
+        auto x1_pitch = 1;
+        auto f1_pitch = y1_size * x1_size;
+        auto b1_pitch = p.f1_num * f1_pitch;
+
+        auto y2_size = p.m_size;
+        auto y2_pitch = p.n_size;
+        auto x2_size = p.n_size;
+        auto x2_pitch = 1;
+        auto f2_pitch = y2_size * x2_size;
+        auto b2_pitch = p.f2_num * f2_pitch;
+
+        auto y_out_size = p.m_size;
+        auto y_out_pitch = p.n_size;
+        auto x_out_size = p.n_size;
+        auto x_out_pitch = 1;
+        auto f_out_pitch = y_out_size * x_out_size;
+        auto b_out_pitch = p.f_out_num * f_out_pitch;
+
+        if (p.transpose_input0) {
+            y0_size = p.k_size;
+            y0_pitch = p.m_size;
+            x0_size = p.m_size;
+            x0_pitch = 1;
+        }
+
+        if (p.transpose_input1) {
+            y1_size = p.n_size;
+            y1_pitch = p.k_size;
+            x1_size = p.k_size;
+            x1_pitch = 1;
+        }
+
+        auto input0_size = tensor((int)p.b0_num, (int)p.f0_num, (int)x0_size, (int)y0_size);
+        auto input0_data = generate_random_4d<int8_t>(p.b0_num, p.f0_num, x0_size, y0_size, -128, 127, 1);
+        auto input0_data_bfyx = flatten_4d(format::bfyx, input0_data);
+        auto input0_mem = memory::allocate(engine, { data_types::i8, format::bfyx, input0_size });
+        set_values(input0_mem, input0_data_bfyx);
+
+        auto input1_size = tensor((int)p.b1_num, (int)p.f1_num, (int)x1_size, (int)y1_size);
+        auto input1_data = generate_random_4d<uint8_t>(p.b1_num, p.f1_num, x1_size, y1_size, 0, 255, 1);
+        auto input1_data_bfyx = flatten_4d(format::bfyx, input1_data);
+        auto input1_mem = memory::allocate(engine, { data_types::u8, format::bfyx, input1_size });
+        set_values(input1_mem, input1_data_bfyx);
+
+        auto input2_size = tensor((int)p.b2_num, (int)p.f2_num, (int)x2_size, (int)y2_size);
+        auto input2_data = generate_random_4d<float>(p.b2_num, p.f2_num, x2_size, y2_size, -10, 10);
+        auto input2_data_bfyx = flatten_4d(format::bfyx, input2_data);
+        auto input2_mem = memory::allocate(engine, { data_types::f32, format::bfyx, input2_size });
+        set_values(input2_mem, input2_data_bfyx);
+
+        std::vector<float> out_data(p.b_out_num * p.f_out_num * p.m_size * p.n_size);
+
+        for (size_t b = 0; b < p.b_out_num; ++b) {
+            for (size_t f = 0; f < p.f_out_num; ++f) {
+                for (size_t i = 0; i < p.m_size; ++i) {
+                    for (size_t j = 0; j < p.n_size; ++j) {
+                        size_t input2_data_index = getGemmIndex(j, i, f, b, x2_size, y2_size, p.f2_num, p.b2_num, x2_pitch, y2_pitch, f2_pitch, b2_pitch);
+                        size_t out_data_index = getGemmIndex(j, i, f, b, x_out_size, y_out_size, p.f_out_num, p.b_out_num,
+                                                          x_out_pitch, y_out_pitch, f_out_pitch, b_out_pitch);
+                        int32_t acc = 0;
+
+                        for (size_t k = 0; k < p.k_size; ++k) {
+                            size_t input0_data_index = getGemmIndex(k * (!p.transpose_input0) + i * p.transpose_input0,
+                                                                    i * (!p.transpose_input0) + k * p.transpose_input0,
+                                                                    f, b, x0_size, y0_size, p.f0_num, p.b0_num,
+                                                                    x0_pitch, y0_pitch, f0_pitch, b0_pitch);
+                            size_t input1_data_index = getGemmIndex(j * (!p.transpose_input1) + k * p.transpose_input1,
+                                                                    k * (!p.transpose_input1) + j * p.transpose_input1,
+                                                                    f, b, x1_size, y1_size, p.f1_num, p.b1_num,
+                                                                    x1_pitch, y1_pitch, f1_pitch, b1_pitch);
+
+                            acc += input0_data_bfyx[input0_data_index] * input1_data_bfyx[input1_data_index];
+                        }
+
+                        out_data[out_data_index] = (float)acc;
+                        out_data[out_data_index] *= p.alpha;
+                        out_data[out_data_index] += p.beta * input2_data_bfyx[input2_data_index];
+                    }
+                }
+            }
+        }
+
+        topology topology;
+        topology.add(input_layout("input0", input0_mem.get_layout()));
+        topology.add(input_layout("input1", input1_mem.get_layout()));
+        topology.add(input_layout("input2", input2_mem.get_layout()));
+        topology.add(gemm("output", { "input0", "input1", "input2" }, data_types::f32, p.transpose_input0, p.transpose_input1, p.alpha, p.beta));
+
+        build_options options;
+        implementation_desc gemm_int8_impl = { format::bfyx, p.kernel_name };
+        options.set_option(build_option::force_implementations({ {"output", gemm_int8_impl} }));
+
+        network network(engine, topology, options);
+        network.set_input_data("input0", input0_mem);
+        network.set_input_data("input1", input1_mem);
+        network.set_input_data("input2", input2_mem);
+        auto outputs = network.execute();
+
+        auto output = outputs.at("output").get_memory();
+        auto output_ptr = output.pointer<float>();
+
+        EXPECT_EQ(output_ptr.size(), (size_t)(p.b_out_num * p.f_out_num * p.m_size * p.n_size));
+        for (size_t i = 0; i < out_data.size(); ++i) {
+            EXPECT_FLOAT_EQ(output_ptr[i], out_data[i]);
+        }
+    }
+};
+
+class gemm_int8_transposition_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
+TEST_P(gemm_int8_transposition_tests, basic) { auto p = GetParam(); execute(p); }
+
+INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_transposition_tests, ::testing::ValuesIn(std::vector<gemm_int8_test_params>{
+                        gemm_int8_test_params{ CASE_GEMM_INT8_NN_TRANSPOSITION, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_NT_TRANSPOSITION, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_TN_TRANSPOSITION, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_TT_TRANSPOSITION, "gemm_mmad_int8" },
+}), );
+
+class gemm_int8_broadcast_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
+TEST_P(gemm_int8_broadcast_tests, basic) { auto p = GetParam(); execute(p); }
+
+INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_broadcast_tests, ::testing::ValuesIn(std::vector<gemm_int8_test_params>{
+                        gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_1, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_2, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_3, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_4, "gemm_mmad_int8" },
+}), );
+
+class gemm_int8_leftovers_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
+TEST_P(gemm_int8_leftovers_tests, basic) { auto p = GetParam(); execute(p); }
+
+INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_leftovers_tests, ::testing::ValuesIn(std::vector<gemm_int8_test_params>{
+                        gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_1, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_2, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_3, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_4, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_5, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_6, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_7, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_8, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_9, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_10, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_11, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_12, "gemm_mmad_int8" },
+}), );
+
+class gemm_int8_combo_tests : public ::GemmInt8Test<gemm_int8_test_params> {};
+TEST_P(gemm_int8_combo_tests, basic) { auto p = GetParam(); execute(p); }
+
+INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_combo_tests, ::testing::ValuesIn(std::vector<gemm_int8_test_params>{
+                        gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_1, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_2, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_3, "gemm_mmad_int8" },
+                        gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_4, "gemm_mmad_int8" },
+}), );