From 4d3ddc168410f2461c0dfe37a7e1bd9dbda0f41e Mon Sep 17 00:00:00 2001
From: Ilya Znamenskiy
Date: Fri, 5 Jun 2020 14:28:21 +0300
Subject: [PATCH] [IE CLDNN] GEMM int8 optimization using MMAD macro (#635)

---
 .../core/actual_kernels/gemm/gemm_kernel_base.cpp  |   9 +-
 .../core/actual_kernels/gemm/gemm_kernel_base.h    |   4 +-
 .../actual_kernels/gemm/gemm_kernel_mmad_int8.cpp  | 204 ++++++++
 .../actual_kernels/gemm/gemm_kernel_mmad_int8.h    |  55 +++
 .../actual_kernels/gemm/gemm_kernel_selector.cpp   |  10 +-
 .../core/cl_kernels/gemm_mmad_int8.cl              | 518 +++++++++++++++++++++
 .../core/cl_kernels/include/mmad.cl                | 337 +++++++++++++-
 .../graph_optimizer/prepare_primitive_fusing.cpp   |   6 +-
 .../clDNN/tests/test_cases/fusings_gpu_test.cpp    |  53 ++-
 .../clDNN/tests/test_cases/gemm_gpu_test.cpp       | 224 ++++++++-
 10 files changed, 1385 insertions(+), 35 deletions(-)
 create mode 100644 inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.cpp
 create mode 100644 inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.h
 create mode 100644 inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gemm_mmad_int8.cl

diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_base.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_base.cpp
index 6155e1b..249e47f 100644
--- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_base.cpp
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_base.cpp
@@ -72,13 +72,6 @@ KernelsData GemmKernelBase::GetCommonKernelsData(const Params& params,
     auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options);
     auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
 
-    uint32_t fused_deps_total = 0;
-    for (auto& fused_dep : prim_params.fused_ops) {
-        for (int i = 0; i < static_cast<int>(fused_dep.dep_size); i++) {
-            fused_deps_total++;
-        }
-    }
-
     auto& kernel = k_data.kernels[0];
     FillCLKernelData(kernel,
                      run_info,
@@ -90,7 +83,7 @@ KernelsData GemmKernelBase::GetCommonKernelsData(const Params& params,
                      false,
                      false,
                      (uint32_t)prim_params.inputs.size(),
-                     fused_deps_total);
+                     GetFusedPrimitiveInputsCount(params));
 
     k_data.estimatedTime = estimated_time;
 
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_base.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_base.h
index 8b68ccc..d30d454 100644
--- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_base.h
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_base.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2018-2020 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -57,7 +57,7 @@ public: protected: virtual JitConstants GetJitConstants(const gemm_params& params) const; - DispatchData SetDefault(const gemm_params& params) const; + virtual DispatchData SetDefault(const gemm_params& params) const; KernelsData GetCommonKernelsData(const Params& params, const optional_params&, float estimated_time) const; // Fused ops virtual JitConstants GetFusedPrimitivesJitConstants(const gemm_params& params, const DispatchData& kd) const; diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.cpp new file mode 100644 index 0000000..62978d7 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.cpp @@ -0,0 +1,204 @@ +/* +// Copyright (c) 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +#include "gemm_kernel_mmad_int8.h" + +namespace kernel_selector { +ParamsKey GemmKernelMMADint8::GetSupportedKey() const { + ParamsKey k; + + k.EnableInputDataType(Datatype::INT8); + k.EnableInputDataType(Datatype::UINT8); + k.EnableInputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::INT8); + k.EnableOutputDataType(Datatype::UINT8); + k.EnableInputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfyx); + k.EnableInputLayout(DataLayout::bfzyx); + k.EnableOutputLayout(DataLayout::bfzyx); + k.EnableInputLayout(DataLayout::bfwzyx); + k.EnableOutputLayout(DataLayout::bfwzyx); + + k.EnableBatching(); + k.EnableDifferentTypes(); + k.EnableTensorPitches(); + k.EnableQuantization(QuantizationType::SYMMETRIC); + + return k; +} + +JitConstants GemmKernelMMADint8::GetJitConstants(const gemm_params& params) const { + JitConstants jit = Parent::GetJitConstants(params); + GemmTuningData td = SetTuningParams(params); + + jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", td.simd_size)); + jit.Merge(MakeTypeJitConstants(Datatype::INT32, "ACCUMULATOR")); + jit.Merge(MakeTypeJitConstants(Datatype::F32, "ACTIVATION")); + jit.Merge(MakeTypeJitConstants(params.inputs[0].GetDType() == Datatype::INT8 ? Datatype::INT32 : Datatype::UINT32, "PACKED_INPUT0")); + jit.Merge(MakeTypeJitConstants(params.inputs[1].GetDType() == Datatype::INT8 ? 
Datatype::INT32 : Datatype::UINT32, "PACKED_INPUT1")); + jit.AddConstant(MakeJitConstant("TILE_NUM", td.tile_num)); + jit.AddConstant(MakeJitConstant("TILE_SIZE_M", td.simd_size * td.tile_num)); + jit.AddConstant(MakeJitConstant("TILE_SIZE_N", td.simd_size)); + jit.AddConstant(MakeJitConstant("TILE_SIZE_K", td.simd_size * td.pack_size)); + jit.AddConstant(MakeJitConstant("OUTPUT_LEFTOVERS_M", td.size_m % (td.simd_size * td.tile_num))); + jit.AddConstant(MakeJitConstant("OUTPUT_LEFTOVERS_N", td.size_n % td.simd_size)); + jit.AddConstant(MakeJitConstant("OUTPUT_LEFTOVERS_K", td.size_k % (td.simd_size * td.pack_size))); + + if (!params.fused_ops.empty()) { + auto input_dt = GetActivationType(params); + FusedOpsConfiguration conf = { "", {"b", "f", "output_y", "output_x"}, "dequantized", input_dt, 1 }; + conf.SetLoopAxes({ Tensor::DataChannelName::Y }, true); + jit.Merge(MakeFusedOpsJitConstants(params, { conf })); + } + + return jit; +} + +GemmKernelBase::DispatchData GemmKernelMMADint8::SetDefault(const gemm_params& params) const { + const auto& output = params.output; + auto total_batches = output.LogicalSize() / (output.X().v * output.Y().v); + + DispatchData kd; + GemmTuningData td = SetTuningParams(params); + + std::vector global = { Align(output.X().v, td.simd_size), + Align(output.Y().v, td.simd_size * td.tile_num) / (td.simd_size * td.tile_num), + total_batches }; + + std::vector local = { td.simd_size, 1, 1 }; + + kd.gws0 = global[0]; + kd.gws1 = global[1]; + kd.gws2 = global[2]; + + kd.lws0 = local[0]; + kd.lws1 = local[1]; + kd.lws2 = local[2]; + + return kd; +} + +GemmKernelMMADint8::GemmTuningData GemmKernelMMADint8::InitGemmTuningData(const gemm_params& params) const { + GemmTuningData tuning_data; + + tuning_data.size_m = params.output.Y().v; + tuning_data.size_n = params.output.X().v; + tuning_data.size_k = params.transpose_input0 ? params.inputs[0].Y().v : params.inputs[0].X().v; + + return tuning_data; +} + +inline size_t GemmKernelMMADint8::GetMmadOperationsNumber(const GemmTuningData& tuning_data) const { + return tuning_data.size_m * tuning_data.size_n * tuning_data.size_k; +} + +bool GemmKernelMMADint8::HasLeftovers(const GemmTuningData& tuning_data, int tile_size) const { + if (tile_size == 32) { + return tuning_data.size_m % 32 || tuning_data.size_n % 16 || tuning_data.size_k % 64; + } else if (tile_size == 16) { + return tuning_data.size_m % 16 || tuning_data.size_n % 16 || tuning_data.size_k % 64; + } else if (tile_size == 8) { + return tuning_data.size_m % 8 || tuning_data.size_n % 8 || tuning_data.size_k % 32; + } else { + return true; + } +} + +GemmKernelMMADint8::GemmTuningData GemmKernelMMADint8::SetTuningParams(const gemm_params& params) const { + GemmTuningData tuning_data = InitGemmTuningData(params); + auto mmad_operations_number = GetMmadOperationsNumber(tuning_data); + + bool leftovers_simd16x2 = HasLeftovers(tuning_data, 16*2); + bool leftovers_simd16 = HasLeftovers(tuning_data, 16); + bool leftovers_simd8 = HasLeftovers(tuning_data, 8); + + bool small_matrices = mmad_operations_number <= 128 * 128 * 128; + bool average_matrices = mmad_operations_number <= 448 * 448 * 448; + bool very_big_matrices = mmad_operations_number >= 1024 * 1024 * 1024; + bool no_input2 = params.inputs.size() == 3 ? 
false : true; + + size_t simd_size = 16; + size_t tile_num = 1; + + if (!leftovers_simd16x2 && very_big_matrices && no_input2) + { simd_size = 16; tile_num = 2; } + else if (leftovers_simd16 && !leftovers_simd8) + { simd_size = 8; } + else if ((!params.transpose_input0 && !params.transpose_input1) && average_matrices) + { simd_size = 8; } + else if (small_matrices) + { simd_size = 8; } + else + { simd_size = 16; } + + tuning_data.simd_size = simd_size; + tuning_data.tile_num = tile_num; + + return tuning_data; +} + +KernelsData GemmKernelMMADint8::GetKernelsData(const Params& params, const optional_params& options) const { + if (!Validate(params, options)) { + return KernelsData(); + } + + const auto& prim_params = static_cast(params); + + auto run_info = GemmKernelMMADint8::SetDefault(prim_params); + KernelData k_data = KernelData::Default(params); + + auto cldnn_jit = GetJitConstants(prim_params); + auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options); + auto jit = CreateJit(kernelName, cldnn_jit, entry_point); + + auto& kernel = k_data.kernels[0]; + FillCLKernelData(kernel, + run_info, + params.engineInfo, + kernelName, + jit, + entry_point, + DEFAULT, + false, + false, + (uint32_t)prim_params.inputs.size(), + GetFusedPrimitiveInputsCount(params)); + + GemmTuningData tuning_data = InitGemmTuningData(prim_params); + auto mmad_operations_number = GetMmadOperationsNumber(tuning_data); + + k_data.estimatedTime = mmad_operations_number < 4096 ? DONT_USE_IF_HAVE_SOMETHING_ELSE : FORCE_PRIORITY_3; + + return {k_data}; +} + +bool GemmKernelMMADint8::Validate(const Params& params, const optional_params& options) const { + if (!Parent::Validate(params, options)) + return false; + + const auto& gmm_params = static_cast(params); + auto input0_type = gmm_params.inputs[0].GetDType(); + auto input1_type = gmm_params.inputs[1].GetDType(); + + if ((input0_type != Datatype::UINT8 && input0_type != Datatype::INT8) || + (input1_type != Datatype::UINT8 && input1_type != Datatype::INT8)) + return false; + + return true; +} +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.h new file mode 100644 index 0000000..f7ff633 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_mmad_int8.h @@ -0,0 +1,55 @@ +// Copyright (c) 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "gemm_kernel_base.h" +#include + +namespace kernel_selector { +class GemmKernelMMADint8 : public GemmKernelBase { +public: + using Parent = GemmKernelBase; + using DispatchData = CommonDispatchData; + struct GemmTuningData { + size_t size_m; + size_t size_n; + size_t size_k; + + size_t simd_size = 16; + size_t tile_num = 1; + size_t pack_size = 4; + }; + + GemmKernelMMADint8() : GemmKernelBase("gemm_mmad_int8") {} + + KernelsData GetKernelsData(const Params& params, const optional_params& options) const override; + ParamsKey GetSupportedKey() const override; + +protected: + std::vector GetSupportedFusedOps() const override { + return { FusedOpType::QUANTIZE, + FusedOpType::ACTIVATION, + FusedOpType::SCALE, + FusedOpType::ELTWISE }; + } + bool Validate(const Params& params, const optional_params& options) const override; + JitConstants GetJitConstants(const gemm_params& params) const override; + DispatchData SetDefault(const gemm_params& params) const override; + GemmTuningData InitGemmTuningData(const gemm_params& params) const; + GemmTuningData SetTuningParams(const gemm_params& params) const; + size_t GetMmadOperationsNumber(const GemmTuningData& tuning_data) const; + bool HasLeftovers(const GemmTuningData& tuning_data, int tile_size) const; +}; +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_selector.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_selector.cpp index 0c44671..043bf4c 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_selector.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_selector.cpp @@ -1,5 +1,5 @@ /* -// Copyright (c) 2018 Intel Corporation +// Copyright (c) 2018-2020 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,11 +16,15 @@ #include "gemm_kernel_selector.h" #include "gemm_kernel_ref.h" +#include "gemm_kernel_mmad_int8.h" namespace kernel_selector { -gemm_kernel_selector::gemm_kernel_selector() { Attach(); } +gemm_kernel_selector::gemm_kernel_selector() { + Attach(); + Attach(); +} KernelsData gemm_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const { return GetNaiveBestKernel(params, options, KernelType::GEMM); } -} // namespace kernel_selector \ No newline at end of file +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gemm_mmad_int8.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gemm_mmad_int8.cl new file mode 100644 index 0000000..a52174d --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gemm_mmad_int8.cl @@ -0,0 +1,518 @@ +// Copyright (c) 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "include/include_all.cl" +#include "include/mmad.cl" + +#define PACK_SIZE 4 + +#define AS_TYPE(type, val) CAT(as_, type)(val) +#define ACCUMULATOR_TYPE_VEC CAT(ACCUMULATOR_TYPE, SUB_GROUP_SIZE) +#define ACTIVATION_TYPE_VEC CAT(ACTIVATION_TYPE, SUB_GROUP_SIZE) +#define PACKED_INPUT0_TYPE_VEC CAT(PACKED_INPUT0_TYPE, SUB_GROUP_SIZE) +#define PACKED_INPUT1_TYPE_VEC CAT(PACKED_INPUT1_TYPE, SUB_GROUP_SIZE) +#define BLOCK_READ(ptr) intel_sub_group_block_read((const __global uint*)(ptr)) +#define BLOCK_SHUFFLE intel_sub_group_shuffle + +#if SUB_GROUP_SIZE == 8 +#define MMAD MMAD_8x8 +#else // SUB_GROUP_SIZE == 8 +#define MMAD MMAD_16x16 +#define TILE_SIZE_M_DIV (TILE_SIZE_M / 2) +#endif // SUB_GROUP_SIZE == 8 + +inline uint FUNC(get_input0_batch_offset)(uint b, uint f, uint w, uint z) { +#if INPUT0_SIMPLE + return GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, 0, 0); +#else // INPUT0_SIMPLE +# error gemm_mmad_int8.cl : Unsupported input 0 format +#endif // INPUT0_SIMPLE +} + +inline uint FUNC(get_input1_batch_offset)(uint b, uint f, uint w, uint z) { +#if INPUT1_SIMPLE + return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, 0, 0); +#else // INPUT1_SIMPLE +# error gemm_mmad_int8.cl : Unsupported input 1 format +#endif // INPUT1_SIMPLE +} + +#ifdef INPUT2_TYPE +inline uint FUNC(get_input2_batch_offset)(uint b, uint f, uint w, uint z) { +#if INPUT2_SIMPLE + return GET_DATA_INDEX_6D_SAFE(INPUT2, b, f, w, z, 0, 0); +#else // INPUT2_SIMPLE +# error gemm_mmad_int8.cl : Unsupported input 2 format +#endif // INPUT2_SIMPLE +} +#endif // INPUT2_TYPE + +inline uint FUNC(get_output_batch_offset)(uint b, uint f, uint w, uint z) { +#if OUTPUT_SIMPLE + return GET_DATA_INDEX_6D(OUTPUT, b, f, w, z, 0, 0); +#else // OUTPUT_SIMPLE +# error gemm_mmad_int8.cl : Unsupported output format +#endif // OUTPUT_SIMPLE +} + +inline uint FUNC(get_common_input1_offset)(uint batch_offset_input1, uint k, uint i, uint output_x_tile, uint lid) { +#if !TRANSPOSE_INPUT1 + return batch_offset_input1 + (k * TILE_SIZE_K + i * PACK_SIZE) * INPUT1_SIZE_X + output_x_tile * TILE_SIZE_N; +#else + return batch_offset_input1 + (output_x_tile * TILE_SIZE_N + lid) * INPUT1_SIZE_X + k * TILE_SIZE_K + i * PACK_SIZE; +#endif +} + +inline uint FUNC(get_current_input1_offset)(uint common_input1_offset, uint i, uint lid) { +#if !TRANSPOSE_INPUT1 + return common_input1_offset + INPUT1_SIZE_X * i + lid; +#else + return common_input1_offset + i; +#endif +} + +inline uint FUNC(get_common_input0_offset)(uint batch_offset_input0, uint k, uint i, uint output_y_tile, uint lid) { +#if !TRANSPOSE_INPUT0 + return batch_offset_input0 + (output_y_tile * TILE_SIZE_M + i) * INPUT0_SIZE_X + k * TILE_SIZE_K; +#else + return batch_offset_input0 + (k * TILE_SIZE_K + lid * PACK_SIZE) * INPUT0_SIZE_X + output_y_tile * TILE_SIZE_M + i; +#endif +} + +inline uint FUNC(get_current_input0_offset)(uint common_input0_offset, uint i, uint lid) { +#if !TRANSPOSE_INPUT0 + return common_input0_offset + lid * PACK_SIZE + i; +#else + return common_input0_offset + INPUT0_SIZE_X * i; +#endif +} + +__attribute__((reqd_work_group_size(SUB_GROUP_SIZE, 1, 1))) +__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) +KERNEL(gemm_mmad_int8)( + const __global INPUT0_TYPE* input0, + const __global INPUT1_TYPE* input1, +#ifdef INPUT2_TYPE + const __global INPUT2_TYPE* input2, +#endif // INPUT2_TYPE + __global OUTPUT_TYPE* output +#if HAS_FUSED_OPS_DECLS + , FUSED_OPS_DECLS +#endif // HAS_FUSED_OPS_DECLS + ) + +// 
***************************************************************************************** // +// Kernel with leftovers for all sizes of input matrices and all transposition combinations. // +// ***************************************************************************************** // + +#if OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K +{ + const uint output_x = (uint)get_global_id(0); + const uint output_x_tile = output_x / TILE_SIZE_N; + const uint output_y_tile = (uint)get_global_id(1); +#if HAS_FUSED_OPS + uint output_y = output_y_tile * TILE_SIZE_M; +#endif // HAS_FUSED_OPS + uint batch = get_global_id(2); + const uint lid = (uint)get_local_id(0); + + const uint z = batch % OUTPUT_SIZE_Z; + batch /= OUTPUT_SIZE_Z; + const uint w = batch % OUTPUT_SIZE_W; + batch /= OUTPUT_SIZE_W; + const uint f = batch % OUTPUT_FEATURE_NUM; + batch /= OUTPUT_FEATURE_NUM; + const uint b = batch % OUTPUT_BATCH_NUM; + + const uint batch_offset_input0 = FUNC_CALL(get_input0_batch_offset)(b, f, w, z); + const uint batch_offset_input1 = FUNC_CALL(get_input1_batch_offset)(b, f, w, z); +#ifdef INPUT2_TYPE + const uint batch_offset_input2 = FUNC_CALL(get_input2_batch_offset)(b, f, w, z); +#endif // INPUT2_TYPE + const uint batch_offset_output = FUNC_CALL(get_output_batch_offset)(b, f, w, z); + + PACKED_INPUT0_TYPE_VEC tile_input0; + PACKED_INPUT1_TYPE_VEC tile_input1; +#ifdef INPUT2_TYPE + ACTIVATION_TYPE_VEC tile_input2; +#if OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + if (output_y_tile * TILE_SIZE_M + i >= OUTPUT_SIZE_Y) continue; + if (output_x_tile * TILE_SIZE_N + lid >= OUTPUT_SIZE_X) continue; + + tile_input2[i] = TO_ACTIVATION_TYPE(input2[batch_offset_input2 + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + + output_x_tile * TILE_SIZE_N + lid]); + } +#else // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + tile_input2[i] = TO_ACTIVATION_TYPE(input2[batch_offset_input2 + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + + output_x_tile * TILE_SIZE_N + lid]); + } +#endif // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N +#endif // INPUT2_TYPE + + ACCUMULATOR_TYPE_VEC tile_output = (ACCUMULATOR_TYPE_VEC)(ACCUMULATOR_VAL_ZERO); + +#if !TRANSPOSE_INPUT0 + const uint K_BLOCK_NUM = (INPUT0_SIZE_X - 1) / TILE_SIZE_K + 1; + const uint K_SIZE = INPUT0_SIZE_X; +#else // !TRANSPOSE_INPUT0 + const uint K_BLOCK_NUM = (INPUT0_SIZE_Y - 1) / TILE_SIZE_K + 1; + const uint K_SIZE = INPUT0_SIZE_Y; +#endif // !TRANSPOSE_INPUT0 + + for (uint k = 0; k < K_BLOCK_NUM; k++) { + MAKE_VECTOR_TYPE(INPUT0_TYPE, PACK_SIZE) temp_input0[SUB_GROUP_SIZE]; + MAKE_VECTOR_TYPE(INPUT1_TYPE, PACK_SIZE) temp_input1[SUB_GROUP_SIZE]; + + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + const uint common_input1_offset = FUNC_CALL(get_common_input1_offset)(batch_offset_input1, k, i, output_x_tile, lid); + +#if OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K + const uint cur_n = output_x_tile * TILE_SIZE_N + lid; + const uint cur_k = k * TILE_SIZE_K + i * PACK_SIZE; + + temp_input1[i] = 0; + + if (cur_n < OUTPUT_SIZE_X) { + if (cur_k + 3 < K_SIZE) { + temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)]; + temp_input1[i].s1 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 1, lid)]; + temp_input1[i].s2 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 2, lid)]; + temp_input1[i].s3 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 3, lid)]; + } else if (cur_k + 2 
< K_SIZE) { + temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)]; + temp_input1[i].s1 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 1, lid)]; + temp_input1[i].s2 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 2, lid)]; + } else if (cur_k + 1 < K_SIZE) { + temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)]; + temp_input1[i].s1 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 1, lid)]; + } else if (cur_k < K_SIZE) { + temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)]; + } + } +#else // OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K + temp_input1[i].s0 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 0, lid)]; + temp_input1[i].s1 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 1, lid)]; + temp_input1[i].s2 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 2, lid)]; + temp_input1[i].s3 = input1[FUNC_CALL(get_current_input1_offset)(common_input1_offset, 3, lid)]; +#endif // OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K + + tile_input1[i] = AS_TYPE(PACKED_INPUT1_TYPE, temp_input1[i]); + } + + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + const uint common_input0_offset = FUNC_CALL(get_common_input0_offset)(batch_offset_input0, k, i, output_y_tile, lid); + +#if OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_K + const uint cur_m = output_y_tile * TILE_SIZE_M + i; + const uint cur_k = k * TILE_SIZE_K + lid * PACK_SIZE; + + temp_input0[i] = 0; + + if (cur_m < OUTPUT_SIZE_Y) { + if (cur_k + 3 < K_SIZE) { + temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)]; + temp_input0[i].s1 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 1, lid)]; + temp_input0[i].s2 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 2, lid)]; + temp_input0[i].s3 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 3, lid)]; + } else if (cur_k + 2 < K_SIZE) { + temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)]; + temp_input0[i].s1 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 1, lid)]; + temp_input0[i].s2 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 2, lid)]; + } else if (cur_k + 1 < K_SIZE) { + temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)]; + temp_input0[i].s1 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 1, lid)]; + } else if (cur_k < K_SIZE) { + temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)]; + } + } + + tile_input0[i] = AS_TYPE(PACKED_INPUT0_TYPE, temp_input0[i]); +#else // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_K + +#if !TRANSPOSE_INPUT0 + tile_input0[i] = AS_TYPE(PACKED_INPUT0_TYPE, BLOCK_READ(input0 + common_input0_offset)); +#else // !TRANSPOSE_INPUT0 + temp_input0[i].s0 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 0, lid)]; + temp_input0[i].s1 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 1, lid)]; + temp_input0[i].s2 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 2, lid)]; + temp_input0[i].s3 = input0[FUNC_CALL(get_current_input0_offset)(common_input0_offset, 3, lid)]; + + tile_input0[i] = AS_TYPE(PACKED_INPUT0_TYPE, temp_input0[i]); +#endif // !TRANSPOSE_INPUT0 + +#endif // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_K + } + + 
tile_output = MMAD(tile_input0, tile_input1, tile_output); + } + +#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD + FUSED_OPS_PRELOAD; +#endif // HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD + + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { +#if OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N + if (output_y_tile * TILE_SIZE_M + i >= OUTPUT_SIZE_Y) continue; + if (output_x_tile * TILE_SIZE_N + lid >= OUTPUT_SIZE_X) continue; +#endif // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N + + ACTIVATION_TYPE dequantized = TO_ACTIVATION_TYPE(tile_output[i]); + dequantized *= TO_ACTIVATION_TYPE(ALPHA); + +#ifdef INPUT2_TYPE + dequantized += TO_ACTIVATION_TYPE(BETA) * tile_input2[i]; +#endif // INPUT2_TYPE + +#if HAS_FUSED_OPS +#if FUSED_OPS_CAN_USE_PRELOAD + FUSED_OPS_CALC; +#else // FUSED_OPS_CAN_USE_PRELOAD + FUSED_OPS; +#endif // FUSED_OPS_CAN_USE_PRELOAD + OUTPUT_TYPE res = FUSED_OPS_RESULT; + output[batch_offset_output + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = res; + output_y++; +#else // HAS_FUSED_OPS + output[batch_offset_output + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = dequantized; +#endif // HAS_FUSED_OPS + } +} + +// ******************************************************************************************************************************** // +// Optimized kernel without leftovers (for tiling parameters M = 8, N = 8, K = 32; M = 16, N = 16, K = 64; M = 32, N = 16, K = 64). // +// ******************************************************************************************************************************** // + +#else // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K +{ + const uint output_x = (uint)get_global_id(0); + const uint output_x_tile = output_x / TILE_SIZE_N; + const uint output_y_tile = (uint)get_global_id(1); +#if HAS_FUSED_OPS + uint output_y = output_y_tile * TILE_SIZE_M; +#endif // HAS_FUSED_OPS + uint batch = get_global_id(2); + const uint lid = (uint)get_local_id(0); + + const uint z = batch % OUTPUT_SIZE_Z; + batch /= OUTPUT_SIZE_Z; + const uint w = batch % OUTPUT_SIZE_W; + batch /= OUTPUT_SIZE_W; + const uint f = batch % OUTPUT_FEATURE_NUM; + batch /= OUTPUT_FEATURE_NUM; + const uint b = batch % OUTPUT_BATCH_NUM; + + const uint batch_offset_input0 = FUNC_CALL(get_input0_batch_offset)(b, f, w, z); + const uint batch_offset_input1 = FUNC_CALL(get_input1_batch_offset)(b, f, w, z); +#ifdef INPUT2_TYPE + const uint batch_offset_input2 = FUNC_CALL(get_input2_batch_offset)(b, f, w, z); +#endif // INPUT2_TYPE + const uint batch_offset_output = FUNC_CALL(get_output_batch_offset)(b, f, w, z); + + PACKED_INPUT0_TYPE_VEC tile_input00; + PACKED_INPUT1_TYPE_VEC tile_input10; + +#ifdef INPUT2_TYPE + ACTIVATION_TYPE_VEC tile_input20; + + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + tile_input20[i] = TO_ACTIVATION_TYPE(input2[batch_offset_input2 + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + + output_x_tile * TILE_SIZE_N + lid]); + } +#endif // INPUT2_TYPE + + ACCUMULATOR_TYPE_VEC tile_output00 = (ACCUMULATOR_TYPE_VEC)(ACCUMULATOR_VAL_ZERO); +#if TILE_NUM == 2 + ACCUMULATOR_TYPE_VEC tile_output01 = (ACCUMULATOR_TYPE_VEC)(ACCUMULATOR_VAL_ZERO); +#endif // TILE_NUM == 2 + +#if !TRANSPOSE_INPUT0 + const uint K_BLOCK_NUM = INPUT0_SIZE_X / TILE_SIZE_K; +#else // !TRANSPOSE_INPUT0 + const uint K_BLOCK_NUM = INPUT0_SIZE_Y / TILE_SIZE_K; +#endif // !TRANSPOSE_INPUT0 + + for (uint k = 0; k < K_BLOCK_NUM; k++) { +#if !TRANSPOSE_INPUT1 + MAKE_VECTOR_TYPE(INPUT1_TYPE, PACK_SIZE) 
temp_input1[SUB_GROUP_SIZE]; + const uint common_input1_offset = batch_offset_input1 + k * TILE_SIZE_K * INPUT1_SIZE_X + output_x_tile * TILE_SIZE_N; + + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + temp_input1[i].s0 = input1[common_input1_offset + i * PACK_SIZE * INPUT1_SIZE_X + lid]; + temp_input1[i].s1 = input1[common_input1_offset + i * PACK_SIZE * INPUT1_SIZE_X + INPUT1_SIZE_X + lid]; + temp_input1[i].s2 = input1[common_input1_offset + i * PACK_SIZE * INPUT1_SIZE_X + 2 * INPUT1_SIZE_X + lid]; + temp_input1[i].s3 = input1[common_input1_offset + i * PACK_SIZE * INPUT1_SIZE_X + 3 * INPUT1_SIZE_X + lid]; + + tile_input10[i] = AS_TYPE(PACKED_INPUT1_TYPE, temp_input1[i]); + } +#else // !TRANSPOSE_INPUT1 + const uint common_input1_offset = batch_offset_input1 + output_x_tile * TILE_SIZE_N * INPUT1_SIZE_X + k * TILE_SIZE_K; + + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + tile_input10[i] = AS_TYPE(PACKED_INPUT1_TYPE, BLOCK_READ(input1 + common_input1_offset + i * INPUT1_SIZE_X)); + } + + PACKED_INPUT1_TYPE_VEC tile_input1_col0 = BLOCK_SHUFFLE(tile_input10, 0); + PACKED_INPUT1_TYPE_VEC tile_input1_col1 = BLOCK_SHUFFLE(tile_input10, 1); + PACKED_INPUT1_TYPE_VEC tile_input1_col2 = BLOCK_SHUFFLE(tile_input10, 2); + PACKED_INPUT1_TYPE_VEC tile_input1_col3 = BLOCK_SHUFFLE(tile_input10, 3); + PACKED_INPUT1_TYPE_VEC tile_input1_col4 = BLOCK_SHUFFLE(tile_input10, 4); + PACKED_INPUT1_TYPE_VEC tile_input1_col5 = BLOCK_SHUFFLE(tile_input10, 5); + PACKED_INPUT1_TYPE_VEC tile_input1_col6 = BLOCK_SHUFFLE(tile_input10, 6); + PACKED_INPUT1_TYPE_VEC tile_input1_col7 = BLOCK_SHUFFLE(tile_input10, 7); +#if SUB_GROUP_SIZE == 16 + PACKED_INPUT1_TYPE_VEC tile_input1_col8 = BLOCK_SHUFFLE(tile_input10, 8); + PACKED_INPUT1_TYPE_VEC tile_input1_col9 = BLOCK_SHUFFLE(tile_input10, 9); + PACKED_INPUT1_TYPE_VEC tile_input1_col10 = BLOCK_SHUFFLE(tile_input10, 10); + PACKED_INPUT1_TYPE_VEC tile_input1_col11 = BLOCK_SHUFFLE(tile_input10, 11); + PACKED_INPUT1_TYPE_VEC tile_input1_col12 = BLOCK_SHUFFLE(tile_input10, 12); + PACKED_INPUT1_TYPE_VEC tile_input1_col13 = BLOCK_SHUFFLE(tile_input10, 13); + PACKED_INPUT1_TYPE_VEC tile_input1_col14 = BLOCK_SHUFFLE(tile_input10, 14); + PACKED_INPUT1_TYPE_VEC tile_input1_col15 = BLOCK_SHUFFLE(tile_input10, 15); +#endif // SUB_GROUP_SIZE == 16 + + tile_input10.s0 = tile_input1_col0[lid]; + tile_input10.s1 = tile_input1_col1[lid]; + tile_input10.s2 = tile_input1_col2[lid]; + tile_input10.s3 = tile_input1_col3[lid]; + tile_input10.s4 = tile_input1_col4[lid]; + tile_input10.s5 = tile_input1_col5[lid]; + tile_input10.s6 = tile_input1_col6[lid]; + tile_input10.s7 = tile_input1_col7[lid]; +#if SUB_GROUP_SIZE == 16 + tile_input10.s8 = tile_input1_col8[lid]; + tile_input10.s9 = tile_input1_col9[lid]; + tile_input10.sa = tile_input1_col10[lid]; + tile_input10.sb = tile_input1_col11[lid]; + tile_input10.sc = tile_input1_col12[lid]; + tile_input10.sd = tile_input1_col13[lid]; + tile_input10.se = tile_input1_col14[lid]; + tile_input10.sf = tile_input1_col15[lid]; +#endif // SUB_GROUP_SIZE == 16 + +#endif // !TRANSPOSE_INPUT1 + +#if !TRANSPOSE_INPUT0 + const uint common_input0_offset = batch_offset_input0 + output_y_tile * TILE_SIZE_M * INPUT0_SIZE_X + k * TILE_SIZE_K; + + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + tile_input00[i] = AS_TYPE(PACKED_INPUT0_TYPE, BLOCK_READ(input0 + common_input0_offset + i * INPUT0_SIZE_X)); + } + + tile_output00 = MMAD(tile_input00, tile_input10, tile_output00); + +#if TILE_NUM == 2 + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + tile_input00[i] = 
AS_TYPE(PACKED_INPUT0_TYPE, BLOCK_READ(input0 + common_input0_offset + (TILE_SIZE_M_DIV + i) * INPUT0_SIZE_X)); + } + + tile_output01 = MMAD(tile_input00, tile_input10, tile_output01); +#endif // TILE_NUM == 2 + +#else // !TRANSPOSE_INPUT0 + MAKE_VECTOR_TYPE(INPUT0_TYPE, PACK_SIZE) temp_input0[SUB_GROUP_SIZE]; + const uint common_input0_offset = batch_offset_input0 + (k * TILE_SIZE_K + lid * PACK_SIZE) * INPUT0_SIZE_X + output_y_tile * TILE_SIZE_M; + + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + temp_input0[i].s0 = input0[common_input0_offset + i]; + temp_input0[i].s1 = input0[common_input0_offset + 1 * INPUT0_SIZE_X + i]; + temp_input0[i].s2 = input0[common_input0_offset + 2 * INPUT0_SIZE_X + i]; + temp_input0[i].s3 = input0[common_input0_offset + 3 * INPUT0_SIZE_X + i]; + + tile_input00[i] = AS_TYPE(PACKED_INPUT0_TYPE, temp_input0[i]); + } + + tile_output00 = MMAD(tile_input00, tile_input10, tile_output00); + +#if TILE_NUM == 2 + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + temp_input0[i].s0 = input0[common_input0_offset + TILE_SIZE_M_DIV + i]; + temp_input0[i].s1 = input0[common_input0_offset + 1 * INPUT0_SIZE_X + TILE_SIZE_M_DIV + i]; + temp_input0[i].s2 = input0[common_input0_offset + 2 * INPUT0_SIZE_X + TILE_SIZE_M_DIV + i]; + temp_input0[i].s3 = input0[common_input0_offset + 3 * INPUT0_SIZE_X + TILE_SIZE_M_DIV + i]; + + tile_input00[i] = AS_TYPE(PACKED_INPUT0_TYPE, temp_input0[i]); + } + + tile_output01 = MMAD(tile_input00, tile_input10, tile_output01); +#endif // TILE_NUM == 2 + +#endif // !TRANSPOSE_INPUT0 + } + +#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD + FUSED_OPS_PRELOAD; +#endif // HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD + + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + ACTIVATION_TYPE dequantized = TO_ACTIVATION_TYPE(tile_output00[i]); + dequantized *= TO_ACTIVATION_TYPE(ALPHA); +#ifdef INPUT2_TYPE + dequantized += TO_ACTIVATION_TYPE(BETA) * tile_input20[i]; +#endif // INPUT2_TYPE + +#if HAS_FUSED_OPS +#if FUSED_OPS_CAN_USE_PRELOAD + FUSED_OPS_CALC; +#else // FUSED_OPS_CAN_USE_PRELOAD + FUSED_OPS; +#endif // FUSED_OPS_CAN_USE_PRELOAD + + OUTPUT_TYPE res = FUSED_OPS_RESULT; + output[batch_offset_output + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = res; + output_y++; +#else // HAS_FUSED_OPS + output[batch_offset_output + (output_y_tile * TILE_SIZE_M + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = dequantized; +#endif // HAS_FUSED_OPS + } + +#if TILE_NUM == 2 +#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD + FUSED_OPS_PRELOAD; +#endif + + for (uint i = 0; i < SUB_GROUP_SIZE; i++) { + ACTIVATION_TYPE dequantized = TO_ACTIVATION_TYPE(tile_output01[i]); + dequantized *= TO_ACTIVATION_TYPE(ALPHA); + +#if HAS_FUSED_OPS +#if FUSED_OPS_CAN_USE_PRELOAD + FUSED_OPS_CALC; +#else // FUSED_OPS_CAN_USE_PRELOAD + FUSED_OPS; +#endif // FUSED_OPS_CAN_USE_PRELOAD + + OUTPUT_TYPE res = FUSED_OPS_RESULT; + output[batch_offset_output + (output_y_tile * TILE_SIZE_M + TILE_SIZE_M_DIV + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = res; + output_y++; +#else // HAS_FUSED_OPS + output[batch_offset_output + (output_y_tile * TILE_SIZE_M + TILE_SIZE_M_DIV + i) * OUTPUT_SIZE_X + output_x_tile * TILE_SIZE_N + lid] = dequantized; +#endif // HAS_FUSED_OPS + } +#endif // TILE_NUM == 2 + +} +#endif // OUTPUT_LEFTOVERS_M || OUTPUT_LEFTOVERS_N || OUTPUT_LEFTOVERS_K + +#undef PACK_SIZE +#undef AS_TYPE +#undef ACCUMULATOR_TYPE_VEC +#undef ACTIVATION_TYPE_VEC +#undef PACKED_INPUT0_TYPE_VEC +#undef PACKED_INPUT1_TYPE_VEC +#undef BLOCK_READ 
+#undef BLOCK_SHUFFLE +#undef MMAD +#undef TILE_SIZE_M_DIV diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/include/mmad.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/include/mmad.cl index 80fab34..3aab503 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/include/mmad.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/include/mmad.cl @@ -1,5 +1,5 @@ /* -// Copyright (c) 2016 Intel Corporation +// Copyright (c) 2016-2020 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -55,6 +55,24 @@ uint8 FUNC(intel_sub_group_block_read_uint8)(const __local uint* p) return ret; } +inline int FUNC(mmad_4)(char4 input, char4 weight, int acc) __attribute__((overloadable)) +{ + acc += (input[0] * weight[0]); + acc += (input[1] * weight[1]); + acc += (input[2] * weight[2]); + acc += (input[3] * weight[3]); + return acc; +} + +inline int FUNC(mmad_4)(char4 input, uchar4 weight, int acc) __attribute__((overloadable)) +{ + acc += (input[0] * weight[0]); + acc += (input[1] * weight[1]); + acc += (input[2] * weight[2]); + acc += (input[3] * weight[3]); + return acc; +} + inline int FUNC(mmad_4)(uchar4 input, char4 weight, int acc) __attribute__((overloadable)) { acc += (input[0] * weight[0]); @@ -64,7 +82,7 @@ inline int FUNC(mmad_4)(uchar4 input, char4 weight, int acc) __attribute__((over return acc; } -inline int FUNC(mmad_4)(char4 input, char4 weight, int acc) __attribute__((overloadable)) +inline int FUNC(mmad_4)(uchar4 input, uchar4 weight, int acc) __attribute__((overloadable)) { acc += (input[0] * weight[0]); acc += (input[1] * weight[1]); @@ -73,6 +91,34 @@ inline int FUNC(mmad_4)(char4 input, char4 weight, int acc) __attribute__((overl return acc; } +inline int FUNC(mmad8)(int8 A_scalars, int8 B_vectors, int acc) __attribute__((overloadable)) +{ + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_char4(B_vectors[0]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_char4(B_vectors[1]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[2]), as_char4(B_vectors[2]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[3]), as_char4(B_vectors[3]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[4]), as_char4(B_vectors[4]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_char4(B_vectors[5]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_char4(B_vectors[6]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_char4(B_vectors[7]), acc); + + return acc; +} + +inline int FUNC(mmad8)(int8 A_scalars, uint8 B_vectors, int acc) __attribute__((overloadable)) +{ + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_uchar4(B_vectors[0]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_uchar4(B_vectors[1]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[2]), as_uchar4(B_vectors[2]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[3]), as_uchar4(B_vectors[3]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[4]), as_uchar4(B_vectors[4]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_uchar4(B_vectors[5]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_uchar4(B_vectors[6]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_uchar4(B_vectors[7]), acc); + + return acc; +} + inline int FUNC(mmad8)(uint8 A_scalars, int8 B_vectors, int acc) __attribute__((overloadable)) { acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[0]), 
as_char4(B_vectors[0]), acc); @@ -86,7 +132,22 @@ inline int FUNC(mmad8)(uint8 A_scalars, int8 B_vectors, int acc) __attribute__(( return acc; } -inline int FUNC(mmad8)(int8 A_scalars, int8 B_vectors, int acc) __attribute__((overloadable)) + +inline int FUNC(mmad8)(uint8 A_scalars, uint8 B_vectors, int acc) __attribute__((overloadable)) +{ + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[0]), as_uchar4(B_vectors[0]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[1]), as_uchar4(B_vectors[1]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[2]), as_uchar4(B_vectors[2]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[3]), as_uchar4(B_vectors[3]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[4]), as_uchar4(B_vectors[4]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[5]), as_uchar4(B_vectors[5]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[6]), as_uchar4(B_vectors[6]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[7]), as_uchar4(B_vectors[7]), acc); + + return acc; +} + +inline int FUNC(mmad16)(int16 A_scalars, int16 B_vectors, int acc) __attribute__((overloadable)) { acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_char4(B_vectors[0]), acc); acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_char4(B_vectors[1]), acc); @@ -96,10 +157,122 @@ inline int FUNC(mmad8)(int8 A_scalars, int8 B_vectors, int acc) __attribute__((o acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_char4(B_vectors[5]), acc); acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_char4(B_vectors[6]), acc); acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_char4(B_vectors[7]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[8]), as_char4(B_vectors[8]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[9]), as_char4(B_vectors[9]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[10]), as_char4(B_vectors[10]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[11]), as_char4(B_vectors[11]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[12]), as_char4(B_vectors[12]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[13]), as_char4(B_vectors[13]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[14]), as_char4(B_vectors[14]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[15]), as_char4(B_vectors[15]), acc); return acc; } +inline int FUNC(mmad16)(int16 A_scalars, uint16 B_vectors, int acc) __attribute__((overloadable)) +{ + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[0]), as_uchar4(B_vectors[0]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[1]), as_uchar4(B_vectors[1]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[2]), as_uchar4(B_vectors[2]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[3]), as_uchar4(B_vectors[3]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[4]), as_uchar4(B_vectors[4]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[5]), as_uchar4(B_vectors[5]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[6]), as_uchar4(B_vectors[6]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[7]), as_uchar4(B_vectors[7]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[8]), as_uchar4(B_vectors[8]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[9]), as_uchar4(B_vectors[9]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[10]), as_uchar4(B_vectors[10]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[11]), as_uchar4(B_vectors[11]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[12]), as_uchar4(B_vectors[12]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[13]), as_uchar4(B_vectors[13]), acc); + acc = 
FUNC_CALL(mmad_4)(as_char4(A_scalars[14]), as_uchar4(B_vectors[14]), acc); + acc = FUNC_CALL(mmad_4)(as_char4(A_scalars[15]), as_uchar4(B_vectors[15]), acc); + + return acc; +} + +inline int FUNC(mmad16)(uint16 A_scalars, int16 B_vectors, int acc) __attribute__((overloadable)) +{ + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[0]), as_char4(B_vectors[0]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[1]), as_char4(B_vectors[1]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[2]), as_char4(B_vectors[2]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[3]), as_char4(B_vectors[3]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[4]), as_char4(B_vectors[4]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[5]), as_char4(B_vectors[5]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[6]), as_char4(B_vectors[6]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[7]), as_char4(B_vectors[7]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[8]), as_char4(B_vectors[8]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[9]), as_char4(B_vectors[9]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[10]), as_char4(B_vectors[10]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[11]), as_char4(B_vectors[11]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[12]), as_char4(B_vectors[12]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[13]), as_char4(B_vectors[13]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[14]), as_char4(B_vectors[14]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[15]), as_char4(B_vectors[15]), acc); + + return acc; +} + +inline int FUNC(mmad16)(uint16 A_scalars, uint16 B_vectors, int acc) __attribute__((overloadable)) +{ + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[0]), as_uchar4(B_vectors[0]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[1]), as_uchar4(B_vectors[1]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[2]), as_uchar4(B_vectors[2]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[3]), as_uchar4(B_vectors[3]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[4]), as_uchar4(B_vectors[4]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[5]), as_uchar4(B_vectors[5]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[6]), as_uchar4(B_vectors[6]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[7]), as_uchar4(B_vectors[7]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[8]), as_uchar4(B_vectors[8]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[9]), as_uchar4(B_vectors[9]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[10]), as_uchar4(B_vectors[10]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[11]), as_uchar4(B_vectors[11]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[12]), as_uchar4(B_vectors[12]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[13]), as_uchar4(B_vectors[13]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[14]), as_uchar4(B_vectors[14]), acc); + acc = FUNC_CALL(mmad_4)(as_uchar4(A_scalars[15]), as_uchar4(B_vectors[15]), acc); + + return acc; +} + +inline int4 FUNC(mmad4x8)(int4 A_vectors, int8 B_vectors, int4 acc) __attribute__((overloadable)) +{ + int4 ret; + for(uint i = 0; i < 4; i++) + { + int8 A_scalars; + A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0); + A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1); + A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2); + A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3); + A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4); + A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5); 
+ A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6); + A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7); + ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]); + } + return ret; +} + +inline int4 FUNC(mmad4x8)(int4 A_vectors, uint8 B_vectors, int4 acc) __attribute__((overloadable)) +{ + int4 ret; + for(uint i = 0; i < 4; i++) + { + int8 A_scalars; + A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0); + A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1); + A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2); + A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3); + A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4); + A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5); + A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6); + A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7); + ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]); + } + return ret; +} + inline int4 FUNC(mmad4x8)(uint4 A_vectors, int8 B_vectors, int4 acc) __attribute__((overloadable)) { int4 ret; @@ -119,11 +292,49 @@ inline int4 FUNC(mmad4x8)(uint4 A_vectors, int8 B_vectors, int4 acc) __attribute return ret; } -inline int4 FUNC(mmad4x8)(int4 A_vectors, int8 B_vectors, int4 acc) __attribute__((overloadable)) +inline int4 FUNC(mmad4x8)(uint4 A_vectors, uint8 B_vectors, int4 acc) __attribute__((overloadable)) { int4 ret; for(uint i = 0; i < 4; i++) { + uint8 A_scalars; + A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0); + A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1); + A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2); + A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3); + A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4); + A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5); + A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6); + A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7); + ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]); + } + return ret; +} + +inline int8 FUNC(mmad8x8)(int8 A_vectors, int8 B_vectors, int8 acc) __attribute__((overloadable)) +{ + int8 ret; + for(uint i = 0; i < 8; i++) + { + int8 A_scalars; + A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0); + A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1); + A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2); + A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3); + A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4); + A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5); + A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6); + A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7); + ret[i] = FUNC_CALL(mmad8)(A_scalars, B_vectors, acc[i]); + } + return ret; +} + +inline int8 FUNC(mmad8x8)(int8 A_vectors, uint8 B_vectors, int8 acc) __attribute__((overloadable)) +{ + int8 ret; + for(uint i = 0; i < 8; i++) + { int8 A_scalars; A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0); A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1); @@ -157,12 +368,12 @@ inline int8 FUNC(mmad8x8)(uint8 A_vectors, int8 B_vectors, int8 acc) __attribute return ret; } -inline int8 FUNC(mmad8x8)(int8 A_vectors, int8 B_vectors, int8 acc) __attribute__((overloadable)) +inline int8 FUNC(mmad8x8)(uint8 A_vectors, uint8 B_vectors, int8 acc) __attribute__((overloadable)) { int8 ret; for(uint i = 0; i < 8; i++) { - int8 A_scalars; + uint8 A_scalars; A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0); A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1); A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2); @@ -176,7 +387,114 @@ inline int8 FUNC(mmad8x8)(int8 A_vectors, int8 B_vectors, int8 acc) __attribute_ return ret; } -// TODO: remove it when 
cl_intel_subgroups_char extension will work +inline int16 FUNC(mmad16x16)(int16 A_vectors, int16 B_vectors, int16 acc) __attribute__((overloadable)) +{ + int16 ret; + for(uint i = 0; i < 16; i++) + { + int16 A_scalars; + A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0); + A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1); + A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2); + A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3); + A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4); + A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5); + A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6); + A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7); + A_scalars.s8 = sub_group_broadcast(A_vectors[i], 8); + A_scalars.s9 = sub_group_broadcast(A_vectors[i], 9); + A_scalars.sa = sub_group_broadcast(A_vectors[i], 10); + A_scalars.sb = sub_group_broadcast(A_vectors[i], 11); + A_scalars.sc = sub_group_broadcast(A_vectors[i], 12); + A_scalars.sd = sub_group_broadcast(A_vectors[i], 13); + A_scalars.se = sub_group_broadcast(A_vectors[i], 14); + A_scalars.sf = sub_group_broadcast(A_vectors[i], 15); + ret[i] = FUNC_CALL(mmad16)(A_scalars, B_vectors, acc[i]); + } + return ret; +} + +inline int16 FUNC(mmad16x16)(int16 A_vectors, uint16 B_vectors, int16 acc) __attribute__((overloadable)) +{ + int16 ret; + for(uint i = 0; i < 16; i++) + { + int16 A_scalars; + A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0); + A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1); + A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2); + A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3); + A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4); + A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5); + A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6); + A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7); + A_scalars.s8 = sub_group_broadcast(A_vectors[i], 8); + A_scalars.s9 = sub_group_broadcast(A_vectors[i], 9); + A_scalars.sa = sub_group_broadcast(A_vectors[i], 10); + A_scalars.sb = sub_group_broadcast(A_vectors[i], 11); + A_scalars.sc = sub_group_broadcast(A_vectors[i], 12); + A_scalars.sd = sub_group_broadcast(A_vectors[i], 13); + A_scalars.se = sub_group_broadcast(A_vectors[i], 14); + A_scalars.sf = sub_group_broadcast(A_vectors[i], 15); + ret[i] = FUNC_CALL(mmad16)(A_scalars, B_vectors, acc[i]); + } + return ret; +} + +inline int16 FUNC(mmad16x16)(uint16 A_vectors, int16 B_vectors, int16 acc) __attribute__((overloadable)) +{ + int16 ret; + for(uint i = 0; i < 16; i++) + { + uint16 A_scalars; + A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0); + A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1); + A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2); + A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3); + A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4); + A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5); + A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6); + A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7); + A_scalars.s8 = sub_group_broadcast(A_vectors[i], 8); + A_scalars.s9 = sub_group_broadcast(A_vectors[i], 9); + A_scalars.sa = sub_group_broadcast(A_vectors[i], 10); + A_scalars.sb = sub_group_broadcast(A_vectors[i], 11); + A_scalars.sc = sub_group_broadcast(A_vectors[i], 12); + A_scalars.sd = sub_group_broadcast(A_vectors[i], 13); + A_scalars.se = sub_group_broadcast(A_vectors[i], 14); + A_scalars.sf = sub_group_broadcast(A_vectors[i], 15); + ret[i] = FUNC_CALL(mmad16)(A_scalars, B_vectors, acc[i]); + } + return ret; +} + +inline int16 FUNC(mmad16x16)(uint16 A_vectors, uint16 B_vectors, int16 
acc) __attribute__((overloadable)) +{ + int16 ret; + for(uint i = 0; i < 16; i++) + { + uint16 A_scalars; + A_scalars.s0 = sub_group_broadcast(A_vectors[i], 0); + A_scalars.s1 = sub_group_broadcast(A_vectors[i], 1); + A_scalars.s2 = sub_group_broadcast(A_vectors[i], 2); + A_scalars.s3 = sub_group_broadcast(A_vectors[i], 3); + A_scalars.s4 = sub_group_broadcast(A_vectors[i], 4); + A_scalars.s5 = sub_group_broadcast(A_vectors[i], 5); + A_scalars.s6 = sub_group_broadcast(A_vectors[i], 6); + A_scalars.s7 = sub_group_broadcast(A_vectors[i], 7); + A_scalars.s8 = sub_group_broadcast(A_vectors[i], 8); + A_scalars.s9 = sub_group_broadcast(A_vectors[i], 9); + A_scalars.sa = sub_group_broadcast(A_vectors[i], 10); + A_scalars.sb = sub_group_broadcast(A_vectors[i], 11); + A_scalars.sc = sub_group_broadcast(A_vectors[i], 12); + A_scalars.sd = sub_group_broadcast(A_vectors[i], 13); + A_scalars.se = sub_group_broadcast(A_vectors[i], 14); + A_scalars.sf = sub_group_broadcast(A_vectors[i], 15); + ret[i] = FUNC_CALL(mmad16)(A_scalars, B_vectors, acc[i]); + } + return ret; +} + inline void FUNC(sub_group_block_write_uchar16)(__global uchar* outPtr, uchar16 v) { #ifdef cl_intel_subgroups_char @@ -272,6 +590,7 @@ inline uchar8 FUNC(sub_group_block_read_uchar8)(const __global uchar* ptr) ret.s7 = ptr[idx]; idx += get_max_sub_group_size(); return ret; + #endif } @@ -361,12 +680,10 @@ inline uchar FUNC(sub_group_block_read_uchar)(const __global uchar* ptr) #endif } -// - - #define MMAD_8(A, B, C) FUNC_CALL(mmad8)(A, B, C) #define MMAD_4x8(A, B, C) FUNC_CALL(mmad4x8)(A, B, C) #define MMAD_8x8(A, B, C) FUNC_CALL(mmad8x8)(A, B, C) +#define MMAD_16x16(A, B, C) FUNC_CALL(mmad16x16)(A, B, C) #define SLM_BLOCK_WRITE_4(A, B) (FUNC_CALL(intel_sub_group_block_write_4)(A, B)) #define SLM_BLOCK_READ_4(A) (FUNC_CALL(intel_sub_group_block_read_uint4)(A)) #define SLM_BLOCK_READ_8(A) (FUNC_CALL(intel_sub_group_block_read_uint8)(A)) diff --git a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp index db98f15..a759875 100644 --- a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp +++ b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp @@ -511,11 +511,13 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { bool can_fuse_parent1 = (parent1->is_type() && conv_supports_fusings(parent1->as())) || (parent1->is_type() && mvn_supports_fusings(parent1->as())) || - (parent1->is_type()) || (parent1->is_type()); + (parent1->is_type()) || (parent1->is_type()) || + (parent1->is_type()); bool can_fuse_parent2 = (parent2->is_type() && conv_supports_fusings(parent2->as())) || (parent2->is_type() && mvn_supports_fusings(parent2->as())) || - (parent2->is_type()) || (parent2->is_type()); + (parent2->is_type()) || (parent2->is_type()) || + (parent2->is_type()); std::vector can_fuse_parents = { can_fuse_parent1, can_fuse_parent2 }; diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index 4e96c35..f6ef20f 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -74,6 +74,7 @@ struct bc_test_params { struct gemm_test_params { std::vector in_shapes; + tensor out_shape; tensor kernel; tensor pad; data_types data_type_in0; @@ -360,6 +361,10 
@@ public: layout get_per_channel_layout(gemm_test_params& p) { return layout{ p.default_type, p.default_format, tensor{1, p.in_shapes.at(0).feature[0], 1, 1} }; } + + layout get_output_layout(gemm_test_params& p) { + return layout{ p.default_type, p.input_format, p.out_shape }; + } }; // in_shape; out_shape; kernel; stride; pad; dilation; groups; data_type; input_format; weights_type; weights_format; default_type; default_format; @@ -431,16 +436,19 @@ public: #define CASE_FC_U8S8_2 {2, 1, 3, 1}, {2, 4, 1, 1}, {4, 1, 3, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx #define CASE_FC_U8S8_3 {2, 32, 1, 1}, {2, 16, 1, 1}, {16, 32, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx -#define CASE_GEMM_3IN_S8S8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx -#define CASE_GEMM_3IN_S8S8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}, {1, 2, 256, 128}}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx -#define CASE_GEMM_3IN_S8S8_3 {{1, 1, 8, 16}, {1, 1, 32, 8}, {1, 1, 32, 16}}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_3IN_S8S8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_3IN_S8S8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}, {1, 2, 256, 128}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_3IN_S8S8_3 {{1, 1, 8, 16}, {1, 1, 32, 8}, {1, 1, 32, 16}}, {1, 1, 32, 16}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx + +#define CASE_GEMM_2IN_U8U8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_2IN_U8U8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_2IN_U8U8_3 {{1, 1, 16, 32}, {1, 1, 32, 16}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx -#define CASE_GEMM_2IN_U8U8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx -#define CASE_GEMM_2IN_U8U8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx -#define CASE_GEMM_2IN_U8U8_3 {{1, 1, 16, 32}, {1, 1, 12, 16}}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, {1, 1, 8, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, 
data_types::f32, format::bfyx -#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx -#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_ELTWISE_2IN_U8S8_1 {{1, 1, 4, 4}, {1, 1, 4, 4}}, {1, 1, 4, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_ELTWISE_2IN_S8U8_1 {{1, 1, 32, 32}, {1, 1, 32, 32}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx #define CASE_NORMALIZE_I8_1 {1, 2, 3, 3}, data_types::u8, format::bfyx, data_types::f32, format::bfyx @@ -2270,6 +2278,35 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_act_scale_quantize_i8, gemm_test_params{ CASE_GEMM_2IN_S8U8_1, 3, 6 }, }), ); +class gemm_int8_2in_act_scale_quantize_eltwise_i8 : public GemmFusingTest {}; +TEST_P(gemm_int8_2in_act_scale_quantize_eltwise_i8, basic) { + auto p = GetParam(); + create_topologies(input_layout("input0", get_input_layout(p, 0)), + input_layout("input1", get_input_layout(p, 1)), + data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)), + data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)), + data("out_lo", get_mem(get_single_element_layout(p), -127)), + data("out_hi", get_mem(get_single_element_layout(p), 127)), + data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)), + data("eltwise_data", get_mem(get_output_layout(p))), + gemm("gemm_prim", { "input0", "input1" }, data_types::f32), + activation("activation", "gemm_prim", activation_func::exp), + scale("scale", "activation", "scale_data"), + quantize("quantize", "scale", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8), + eltwise("sum", { "quantize", "eltwise_data"}, eltwise_mode::sum, data_types::f32), + reorder("reorder_bfyx", "sum", p.default_format, data_types::f32) + ); + + tolerance = 1.0f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_act_scale_quantize_eltwise_i8, + ::testing::ValuesIn(std::vector<gemm_test_params>{ + gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_1, 3, 7 }, + gemm_test_params{ CASE_GEMM_ELTWISE_2IN_S8U8_1, 3, 7 }, +}), ); + /* ----------------------------------------------------------------------------------------------------- */ /* ---------------------------------------- Resample cases --------------------------------------------- */ /* ----------------------------------------------------------------------------------------------------- */ diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/gemm_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/gemm_gpu_test.cpp index e184963..d1e1e41 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/gemm_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/gemm_gpu_test.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018 Intel Corporation +// Copyright (c) 2018-2020 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
@@ -3232,8 +3232,228 @@ TEST(gemm_gpu, basic_smarcink2) { auto output_ptr = output.pointer<float>(); EXPECT_EQ(output_ptr.size(), (uint32_t)8); - for (uint32_t i = 0; i < out_data.size(); ++i) { + for (uint32_t i = 0; i < out_data.size(); ++i) { EXPECT_FLOAT_EQ(output_ptr[i], out_data[i]); } } +struct gemm_int8_test_params { + size_t m_size; + size_t n_size; + size_t k_size; + size_t b0_num; + size_t f0_num; + size_t b1_num; + size_t f1_num; + size_t b2_num; + size_t f2_num; + size_t b_out_num; + size_t f_out_num; + bool transpose_input0; + bool transpose_input1; + float alpha; + float beta; + std::string kernel_name; +}; + +#define CASE_GEMM_INT8_NN_TRANSPOSITION 64, 64, 64, 1, 2, 1, 2, 1, 2, 1, 2, false, false, 1.5f, 2.0f +#define CASE_GEMM_INT8_NT_TRANSPOSITION 32, 64, 32, 2, 1, 2, 1, 2, 1, 2, 1, false, true, 1.7f, 1.3f +#define CASE_GEMM_INT8_TN_TRANSPOSITION 128, 64, 32, 2, 2, 2, 2, 2, 2, 2, 2, true, false, 1.0f, 0.0f +#define CASE_GEMM_INT8_TT_TRANSPOSITION 32, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.2f, 0.5f + +#define CASE_GEMM_INT8_BROADCAST_1 32, 32, 32, 1, 2, 1, 1, 1, 1, 1, 2, false, false, 1.5f, 2.0f +#define CASE_GEMM_INT8_BROADCAST_2 32, 32, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, false, 1.7f, 1.3f +#define CASE_GEMM_INT8_BROADCAST_3 64, 32, 32, 1, 2, 2, 1, 1, 2, 2, 2, false, false, 1.0f, 1.5f +#define CASE_GEMM_INT8_BROADCAST_4 32, 64, 32, 1, 1, 2, 2, 2, 2, 2, 2, false, false, 1.2f, 0.5f + +#define CASE_GEMM_INT8_LEFTOVERS_1 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, false, 1.5f, 2.0f +#define CASE_GEMM_INT8_LEFTOVERS_2 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, true, 1.6f, 1.0f +#define CASE_GEMM_INT8_LEFTOVERS_3 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, false, 1.0f, 1.5f +#define CASE_GEMM_INT8_LEFTOVERS_4 13, 32, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.7f, 1.3f +#define CASE_GEMM_INT8_LEFTOVERS_5 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, false, 1.5f, 2.0f +#define CASE_GEMM_INT8_LEFTOVERS_6 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, false, true, 1.6f, 1.0f +#define CASE_GEMM_INT8_LEFTOVERS_7 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, false, 1.0f, 1.5f +#define CASE_GEMM_INT8_LEFTOVERS_8 32, 13, 32, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.7f, 1.3f +#define CASE_GEMM_INT8_LEFTOVERS_9 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, false, false, 1.5f, 2.0f +#define CASE_GEMM_INT8_LEFTOVERS_10 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, false, true, 1.6f, 1.0f +#define CASE_GEMM_INT8_LEFTOVERS_11 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, true, false, 1.0f, 1.5f +#define CASE_GEMM_INT8_LEFTOVERS_12 32, 32, 13, 1, 1, 1, 1, 1, 1, 1, 1, true, true, 1.7f, 1.3f + +#define CASE_GEMM_INT8_COMBO_1 8, 8, 32, 1, 2, 1, 1, 1, 1, 1, 2, false, false, 1.5f, 2.0f +#define CASE_GEMM_INT8_COMBO_2 16, 16, 64, 2, 1, 1, 1, 1, 1, 2, 1, false, true, 1.7f, 0.0f +#define CASE_GEMM_INT8_COMBO_3 11, 31, 21, 7, 15, 7, 15, 7, 15, 7, 15, true, false, 1.0f, 1.5f +#define CASE_GEMM_INT8_COMBO_4 32, 32, 32, 3, 6, 3, 6, 3, 6, 3, 6, true, true, 1.2f, 4.0f + +template <typename T> +class GemmInt8Test : public ::testing::TestWithParam<T> { +public: + + inline size_t getGemmIndex(size_t x, size_t y, size_t f, size_t b, size_t x_size, size_t y_size, size_t f_num, size_t b_num, + size_t x_pitch, size_t y_pitch, size_t f_pitch, size_t b_pitch) { + return (x % x_size) * x_pitch + (y % y_size) * y_pitch + (f % f_num) * f_pitch + (b % b_num) * b_pitch; + } + + void execute(T& p) { + const auto& engine = get_test_engine(); + + auto y0_size = p.m_size; + auto y0_pitch = p.k_size; + auto x0_size = p.k_size; + auto x0_pitch = 1; + auto f0_pitch = y0_size * x0_size; + auto
b0_pitch = p.f0_num * f0_pitch; + + auto y1_size = p.k_size; + auto y1_pitch = p.n_size; + auto x1_size = p.n_size; + auto x1_pitch = 1; + auto f1_pitch = y1_size * x1_size; + auto b1_pitch = p.f1_num * f1_pitch; + + auto y2_size = p.m_size; + auto y2_pitch = p.n_size; + auto x2_size = p.n_size; + auto x2_pitch = 1; + auto f2_pitch = y2_size * x2_size; + auto b2_pitch = p.f2_num * f2_pitch; + + auto y_out_size = p.m_size; + auto y_out_pitch = p.n_size; + auto x_out_size = p.n_size; + auto x_out_pitch = 1; + auto f_out_pitch = y_out_size * x_out_size; + auto b_out_pitch = p.f_out_num * f_out_pitch; + + if (p.transpose_input0) { + y0_size = p.k_size; + y0_pitch = p.m_size; + x0_size = p.m_size; + x0_pitch = 1; + } + + if (p.transpose_input1) { + y1_size = p.n_size; + y1_pitch = p.k_size; + x1_size = p.k_size; + x1_pitch = 1; + } + + auto input0_size = tensor((int)p.b0_num, (int)p.f0_num, (int)x0_size, (int)y0_size); + auto input0_data = generate_random_4d<int8_t>(p.b0_num, p.f0_num, x0_size, y0_size, -128, 127, 1); + auto input0_data_bfyx = flatten_4d(format::bfyx, input0_data); + auto input0_mem = memory::allocate(engine, { data_types::i8, format::bfyx, input0_size }); + set_values(input0_mem, input0_data_bfyx); + + auto input1_size = tensor((int)p.b1_num, (int)p.f1_num, (int)x1_size, (int)y1_size); + auto input1_data = generate_random_4d<uint8_t>(p.b1_num, p.f1_num, x1_size, y1_size, 0, 255, 1); + auto input1_data_bfyx = flatten_4d(format::bfyx, input1_data); + auto input1_mem = memory::allocate(engine, { data_types::u8, format::bfyx, input1_size }); + set_values(input1_mem, input1_data_bfyx); + + auto input2_size = tensor((int)p.b2_num, (int)p.f2_num, (int)x2_size, (int)y2_size); + auto input2_data = generate_random_4d<float>(p.b2_num, p.f2_num, x2_size, y2_size, -10, 10); + auto input2_data_bfyx = flatten_4d(format::bfyx, input2_data); + auto input2_mem = memory::allocate(engine, { data_types::f32, format::bfyx, input2_size }); + set_values(input2_mem, input2_data_bfyx); + + std::vector<float> out_data(p.b_out_num * p.f_out_num * p.m_size * p.n_size); + + for (size_t b = 0; b < p.b_out_num; ++b) { + for (size_t f = 0; f < p.f_out_num; ++f) { + for (size_t i = 0; i < p.m_size; ++i) { + for (size_t j = 0; j < p.n_size; ++j) { + size_t input2_data_index = getGemmIndex(j, i, f, b, x2_size, y2_size, p.f2_num, p.b2_num, x2_pitch, y2_pitch, f2_pitch, b2_pitch); + size_t out_data_index = getGemmIndex(j, i, f, b, x_out_size, y_out_size, p.f_out_num, p.b_out_num, + x_out_pitch, y_out_pitch, f_out_pitch, b_out_pitch); + int32_t acc = 0; + + for (size_t k = 0; k < p.k_size; ++k) { + size_t input0_data_index = getGemmIndex(k * (!p.transpose_input0) + i * p.transpose_input0, i * (!p.transpose_input0) + + k * p.transpose_input0, f, b, x0_size, y0_size, p.f0_num, p.b0_num, x0_pitch, y0_pitch, f0_pitch, b0_pitch); + size_t input1_data_index = getGemmIndex(j * (!p.transpose_input1) + k * p.transpose_input1, k * (!p.transpose_input1) + + j * p.transpose_input1, f, b, x1_size, y1_size, p.f1_num, p.b1_num, x1_pitch, y1_pitch, f1_pitch, b1_pitch); + + acc += input0_data_bfyx[input0_data_index] * input1_data_bfyx[input1_data_index]; + } + + out_data[out_data_index] = (float)acc; + out_data[out_data_index] *= p.alpha; + out_data[out_data_index] += p.beta * input2_data_bfyx[input2_data_index]; + } + } + } + } + + topology topology; + topology.add(input_layout("input0", input0_mem.get_layout())); + topology.add(input_layout("input1", input1_mem.get_layout())); + topology.add(input_layout("input2", input2_mem.get_layout())); +
topology.add(gemm("output", { "input0", "input1", "input2" }, data_types::f32, p.transpose_input0, p.transpose_input1, p.alpha, p.beta)); + + build_options options; + implementation_desc gemm_int8_impl = { format::bfyx, p.kernel_name }; + options.set_option(build_option::force_implementations({ {"output", gemm_int8_impl} })); + + network network(engine, topology, options); + network.set_input_data("input0", input0_mem); + network.set_input_data("input1", input1_mem); + network.set_input_data("input2", input2_mem); + auto outputs = network.execute(); + + auto output = outputs.at("output").get_memory(); + auto output_ptr = output.pointer<float>(); + + EXPECT_EQ(output_ptr.size(), (size_t)(p.b_out_num * p.f_out_num * p.m_size * p.n_size)); + for (size_t i = 0; i < out_data.size(); ++i) { + EXPECT_FLOAT_EQ(output_ptr[i], out_data[i]); + } + } +}; + +class gemm_int8_transposition_tests : public ::GemmInt8Test<gemm_int8_test_params> {}; +TEST_P(gemm_int8_transposition_tests, basic) { auto p = GetParam(); execute(p); } + +INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_transposition_tests, ::testing::ValuesIn(std::vector<gemm_int8_test_params> { + gemm_int8_test_params{ CASE_GEMM_INT8_NN_TRANSPOSITION, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_NT_TRANSPOSITION, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_TN_TRANSPOSITION, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_TT_TRANSPOSITION, "gemm_mmad_int8" }, +}), ); + +class gemm_int8_broadcast_tests : public ::GemmInt8Test<gemm_int8_test_params> {}; +TEST_P(gemm_int8_broadcast_tests, basic) { auto p = GetParam(); execute(p); } + +INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_broadcast_tests, ::testing::ValuesIn(std::vector<gemm_int8_test_params> { + gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_1, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_2, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_3, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_BROADCAST_4, "gemm_mmad_int8" }, +}), ); + +class gemm_int8_leftovers_tests : public ::GemmInt8Test<gemm_int8_test_params> {}; +TEST_P(gemm_int8_leftovers_tests, basic) { auto p = GetParam(); execute(p); } + +INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_leftovers_tests, ::testing::ValuesIn(std::vector<gemm_int8_test_params> { + gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_1, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_2, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_3, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_4, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_5, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_6, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_7, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_8, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_9, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_10, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_11, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_LEFTOVERS_12, "gemm_mmad_int8" }, +}), ); + +class gemm_int8_combo_tests : public ::GemmInt8Test<gemm_int8_test_params> {}; +TEST_P(gemm_int8_combo_tests, basic) { auto p = GetParam(); execute(p); } + +INSTANTIATE_TEST_CASE_P(gemm_gpu, gemm_int8_combo_tests, ::testing::ValuesIn(std::vector<gemm_int8_test_params> { + gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_1, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_2, "gemm_mmad_int8" }, + gemm_int8_test_params{ CASE_GEMM_INT8_COMBO_3, "gemm_mmad_int8" }, + gemm_int8_test_params{
CASE_GEMM_INT8_COMBO_4, "gemm_mmad_int8" }, +}), ); -- 2.7.4
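
For readers skimming the MMAD helpers added to mmad.cl earlier in this patch: each mmad8/mmad16 step is built from packed 8-bit dot products accumulated into 32-bit integers, and the new mmad16x16 overloads repeat that for 16 accumulator rows, using sub_group_broadcast to share every lane's packed A value across the 16-wide sub-group before calling mmad16 against the B vectors. Below is a minimal host-side C++ model of a single packed step; it is an illustration only, not code from this patch, and the helper name mmad_scalar is made up here.

#include <cstdint>

// Illustrative model of one MMAD element: each 32-bit operand carries four
// packed 8-bit values; the pairs are multiplied and accumulated into a
// 32-bit integer (signed bytes here; the uint overloads in the patch treat
// the packed bytes as unsigned instead).
inline int32_t mmad_scalar(int32_t a_packed, int32_t b_packed, int32_t acc) {
    for (int byte = 0; byte < 4; ++byte) {
        int8_t a = static_cast<int8_t>((a_packed >> (8 * byte)) & 0xFF);
        int8_t b = static_cast<int8_t>((b_packed >> (8 * byte)) & 0xFF);
        acc += static_cast<int32_t>(a) * static_cast<int32_t>(b);
    }
    return acc;
}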
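The reference values the new gemm_gpu_test cases check against follow out = alpha * (A x B) + beta * C, with the int8/uint8 products accumulated in a 32-bit integer before the float scaling, as the nested loops in execute() show. A simplified single-matrix sketch of that reference is given below; it is illustrative only, the function name reference_gemm_int8 is made up, and the patch's version additionally handles transposition and batch/feature broadcasting through getGemmIndex.

#include <cstdint>
#include <vector>

// out[i][j] = alpha * sum_k A[i][k] * B[k][j] + beta * C[i][j],
// with the dot product accumulated in int32, matching the test reference.
std::vector<float> reference_gemm_int8(const std::vector<int8_t>& A,   // M x K, row-major
                                       const std::vector<uint8_t>& B,  // K x N, row-major
                                       const std::vector<float>& C,    // M x N, row-major
                                       size_t M, size_t N, size_t K,
                                       float alpha, float beta) {
    std::vector<float> out(M * N);
    for (size_t i = 0; i < M; ++i) {
        for (size_t j = 0; j < N; ++j) {
            int32_t acc = 0;
            for (size_t k = 0; k < K; ++k)
                acc += static_cast<int32_t>(A[i * K + k]) * static_cast<int32_t>(B[k * N + j]);
            out[i * N + j] = alpha * static_cast<float>(acc) + beta * C[i * N + j];
        }
    }
    return out;
}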