From 48f5f524b8b38bcfe36fb20459436abd8c188a1a Mon Sep 17 00:00:00 2001 From: Vladimir Paramuzov Date: Mon, 27 Jul 2020 18:49:54 +0300 Subject: [PATCH] [IE CLDNN] Fixed gemm fusings with FP precision (#1490) --- .../core/actual_kernels/gemm/gemm_kernel_ref.h | 3 +- .../graph_optimizer/prepare_primitive_fusing.cpp | 33 ++++-- .../clDNN/tests/test_cases/fusings_gpu_test.cpp | 121 ++++++++++++++++----- 3 files changed, 114 insertions(+), 43 deletions(-) diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_ref.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_ref.h index 7441bc7..7ba6604 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_ref.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_ref.h @@ -30,7 +30,8 @@ protected: std::vector GetSupportedFusedOps() const override { return { FusedOpType::QUANTIZE, FusedOpType::ACTIVATION, - FusedOpType::SCALE }; + FusedOpType::SCALE, + FusedOpType::ELTWISE }; } bool Validate(const Params& params, const optional_params& options) const override; JitConstants GetJitConstants(const gemm_params& params) const override; diff --git a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp index 1915300..3edfa76 100644 --- a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp +++ b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp @@ -306,20 +306,27 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { auto in1_dt = node.get_dependency(1).get_output_layout().data_type; auto in0_fmt = node.get_dependency(0).get_output_layout().format; auto in1_fmt = node.get_dependency(1).get_output_layout().format; - if ((in0_dt == data_types::u8 || in0_dt == data_types::i8) && - (in1_dt == data_types::u8 || in1_dt == data_types::i8) && - in0_fmt == format::bfyx && in1_fmt == format::bfyx) + + if (data_type_traits::is_floating_point(in0_dt) && + data_type_traits::is_floating_point(in1_dt)) does_support_fusings = true; - if (node.inputs_count() == 3) { - auto in2_dt = node.get_dependency(2).get_output_layout().data_type; - auto in2_fmt = node.get_dependency(2).get_output_layout().format; - if ((in2_dt == data_types::u8 || in2_dt == data_types::i8) && - in2_fmt == format::bfyx) + if ((in0_dt == data_types::u8 || in0_dt == data_types::i8) && + (in1_dt == data_types::u8 || in1_dt == data_types::i8) && + in0_fmt == format::bfyx && in1_fmt == format::bfyx) { + if (node.inputs_count() == 3) { + auto in2_dt = node.get_dependency(2).get_output_layout().data_type; + auto in2_fmt = node.get_dependency(2).get_output_layout().format; + if ((in2_dt == data_types::u8 || in2_dt == data_types::i8) && + in2_fmt == format::bfyx) + does_support_fusings = true; + else + does_support_fusings = false; + } else { does_support_fusings = true; - else - does_support_fusings = false; + } } + return does_support_fusings; }; @@ -524,12 +531,14 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { bool can_fuse_parent1 = (parent1->is_type() && conv_supports_fusings(parent1->as())) || (parent1->is_type() && mvn_supports_fusings(parent1->as())) || (parent1->is_type()) || (parent1->is_type()) || - (parent1->is_type()) || (parent1->is_type()) || (parent1->is_type()); + (parent1->is_type()) || (parent1->is_type()) || + (parent1->is_type() && gemm_supports_fusings(parent1->as())); bool can_fuse_parent2 = (parent2->is_type() && conv_supports_fusings(parent2->as())) || (parent2->is_type() && mvn_supports_fusings(parent2->as())) || (parent2->is_type()) || (parent2->is_type()) || - (parent1->is_type()) || (parent1->is_type()) || (parent2->is_type()); + (parent2->is_type()) || (parent2->is_type()) || + (parent2->is_type() && gemm_supports_fusings(parent2->as())); std::vector can_fuse_parents = { can_fuse_parent1, can_fuse_parent2 }; diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index b48e04d..f1a041b 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -500,20 +500,6 @@ public: #define CASE_FC_U8S8_2 {2, 1, 3, 1}, {2, 4, 1, 1}, {4, 1, 3, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx #define CASE_FC_U8S8_3 {2, 32, 1, 1}, {2, 16, 1, 1}, {16, 32, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx -#define CASE_GEMM_3IN_S8S8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx -#define CASE_GEMM_3IN_S8S8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}, {1, 2, 256, 128}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx -#define CASE_GEMM_3IN_S8S8_3 {{1, 1, 8, 16}, {1, 1, 32, 8}, {1, 1, 32, 16}}, {1, 1, 32, 16}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx - -#define CASE_GEMM_2IN_U8U8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx -#define CASE_GEMM_2IN_U8U8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx -#define CASE_GEMM_2IN_U8U8_3 {{1, 1, 16, 32}, {1, 1, 32, 16}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx - -#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, {1, 1, 8, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx -#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx - -#define CASE_GEMM_ELTWISE_2IN_U8S8_1 {{1, 1, 4, 4}, {1, 1, 4, 4}}, {1, 1, 4, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx -#define CASE_GEMM_ELTWISE_2IN_S8U8_1 {{1, 1, 32, 32}, {1, 1, 32, 32}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx - #define CASE_NORMALIZE_I8_1 {1, 2, 3, 3}, data_types::u8, format::bfyx, data_types::f32, format::bfyx /* ----------------------------------------------------------------------------------------------------- */ @@ -2376,8 +2362,32 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, fc_int8_scale_activation_quantize_i8, bc_test_params{CASE_FC_U8S8_3, 2, 5}, }), ); -class gemm_int8_3in_quantize_i8 : public GemmFusingTest {}; -TEST_P(gemm_int8_3in_quantize_i8, basic) { + +/* ----------------------------------------------------------------------------------------------------- */ +/* ---------------------------------------- Gemm cases ------------------------------------------------- */ +/* ----------------------------------------------------------------------------------------------------- */ +#define CASE_GEMM_3IN_FP32_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_3IN_FP16_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx +#define CASE_GEMM_3IN_S8S8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_3IN_S8S8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}, {1, 2, 256, 128}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_3IN_S8S8_3 {{1, 1, 8, 16}, {1, 1, 32, 8}, {1, 1, 32, 16}}, {1, 1, 32, 16}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx + +#define CASE_GEMM_2IN_FP32_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_2IN_FP16_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx +#define CASE_GEMM_2IN_U8U8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_2IN_U8U8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_2IN_U8U8_3 {{1, 1, 16, 32}, {1, 1, 32, 16}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx + +#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, {1, 1, 8, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx + +#define CASE_GEMM_ELTWISE_2IN_FP32_1 {{1, 1, 4, 4}, {1, 1, 4, 4}}, {1, 1, 4, 4}, tensor{1}, tensor{0}, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_ELTWISE_2IN_FP16_1 {{1, 1, 32, 32}, {1, 1, 32, 32}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx +#define CASE_GEMM_ELTWISE_2IN_U8S8_1 {{1, 1, 4, 4}, {1, 1, 4, 4}}, {1, 1, 4, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx +#define CASE_GEMM_ELTWISE_2IN_S8U8_1 {{1, 1, 32, 32}, {1, 1, 32, 32}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx + +class gemm_3in_quantize_i8 : public GemmFusingTest {}; +TEST_P(gemm_3in_quantize_i8, basic) { auto p = GetParam(); create_topologies(input_layout("input0", get_input_layout(p, 0)), input_layout("input1", get_input_layout(p, 1)), @@ -2391,23 +2401,25 @@ TEST_P(gemm_int8_3in_quantize_i8, basic) { reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32) ); - tolerance = 1e-5f; + tolerance = 1.0f; execute(p); } -INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_3in_quantize_i8, +INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_3in_quantize_i8, ::testing::ValuesIn(std::vector{ + gemm_test_params{ CASE_GEMM_3IN_FP32_1, 4, 5 }, + gemm_test_params{ CASE_GEMM_3IN_FP16_1, 4, 5 }, gemm_test_params{ CASE_GEMM_3IN_S8S8_1, 4, 5 }, gemm_test_params{ CASE_GEMM_3IN_S8S8_2, 4, 5 }, gemm_test_params{ CASE_GEMM_3IN_S8S8_3, 4, 5 }, }), ); -class gemm_int8_2in_quantize_u8 : public GemmFusingTest {}; -TEST_P(gemm_int8_2in_quantize_u8, basic) { +class gemm_2in_quantize_u8 : public GemmFusingTest {}; +TEST_P(gemm_2in_quantize_u8, basic) { auto p = GetParam(); create_topologies(input_layout("input0", get_input_layout(p, 0)), input_layout("input1", get_input_layout(p, 1)), - data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)), + data("in_lo", get_mem(get_per_channel_layout(p), 0)), data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)), data("out_lo", get_mem(get_single_element_layout(p), 0)), data("out_hi", get_mem(get_single_element_layout(p), 255)), @@ -2416,19 +2428,21 @@ TEST_P(gemm_int8_2in_quantize_u8, basic) { reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32) ); - tolerance = 1e-5f; + tolerance = 1.0f; execute(p); } -INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_quantize_u8, +INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_quantize_u8, ::testing::ValuesIn(std::vector{ + gemm_test_params{ CASE_GEMM_2IN_FP32_1, 3, 4 }, + gemm_test_params{ CASE_GEMM_2IN_FP16_1, 3, 4 }, gemm_test_params{ CASE_GEMM_2IN_U8U8_1, 3, 4 }, gemm_test_params{ CASE_GEMM_2IN_U8U8_2, 3, 4 }, gemm_test_params{ CASE_GEMM_2IN_U8U8_3, 3, 4 }, }), ); -class gemm_int8_2in_act_scale_quantize_i8 : public GemmFusingTest {}; -TEST_P(gemm_int8_2in_act_scale_quantize_i8, basic) { +class gemm_2in_act_scale_quantize_i8 : public GemmFusingTest {}; +TEST_P(gemm_2in_act_scale_quantize_i8, basic) { auto p = GetParam(); create_topologies(input_layout("input0", get_input_layout(p, 0)), input_layout("input1", get_input_layout(p, 1)), @@ -2444,18 +2458,20 @@ TEST_P(gemm_int8_2in_act_scale_quantize_i8, basic) { reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32) ); - tolerance = 1e-5f; + tolerance = 1.0f; execute(p); } -INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_act_scale_quantize_i8, +INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_act_scale_quantize_i8, ::testing::ValuesIn(std::vector{ + gemm_test_params{ CASE_GEMM_2IN_FP32_1, 3, 6 }, + gemm_test_params{ CASE_GEMM_2IN_FP16_1, 3, 6 }, gemm_test_params{ CASE_GEMM_2IN_U8S8_1, 3, 6 }, gemm_test_params{ CASE_GEMM_2IN_S8U8_1, 3, 6 }, }), ); -class gemm_int8_2in_act_scale_quantize_eltwise_i8 : public GemmFusingTest {}; -TEST_P(gemm_int8_2in_act_scale_quantize_eltwise_i8, basic) { +class gemm_2in_act_scale_quantize_eltwise_i8 : public GemmFusingTest {}; +TEST_P(gemm_2in_act_scale_quantize_eltwise_i8, basic) { auto p = GetParam(); create_topologies(input_layout("input0", get_input_layout(p, 0)), input_layout("input1", get_input_layout(p, 1)), @@ -2477,12 +2493,57 @@ TEST_P(gemm_int8_2in_act_scale_quantize_eltwise_i8, basic) { execute(p); } -INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_act_scale_quantize_eltwise_i8, +INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_act_scale_quantize_eltwise_i8, ::testing::ValuesIn(std::vector{ + gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP32_1, 3, 7 }, + gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP16_1, 3, 7 }, gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_1, 3, 7 }, gemm_test_params{ CASE_GEMM_ELTWISE_2IN_S8U8_1, 3, 7 }, }), ); +class gemm_2in_act_scale_eltwise : public GemmFusingTest {}; +TEST_P(gemm_2in_act_scale_eltwise, basic) { + auto p = GetParam(); + create_topologies(input_layout("input0", get_input_layout(p, 0)), + input_layout("input1", get_input_layout(p, 1)), + data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)), + data("eltwise_data", get_mem(get_output_layout(p))), + gemm("gemm_prim", { "input0", "input1" }, data_types::f32), + scale("scale", "gemm_prim", "scale_data"), + activation("activation", "scale", activation_func::negative), + eltwise("sum", { "activation", "eltwise_data"}, eltwise_mode::sum, data_types::f32), + reorder("reorder_bfyx", "sum", p.default_format, data_types::f32) + ); + + tolerance = 1e-4f; + execute(p); +} + +TEST_P(gemm_2in_act_scale_eltwise, broadcast_eltwise) { + auto p = GetParam(); + create_topologies(input_layout("input0", get_input_layout(p, 0)), + input_layout("input1", get_input_layout(p, 1)), + data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)), + data("eltwise_data", get_mem(get_single_element_layout(p))), + gemm("gemm_prim", { "input0", "input1" }, data_types::f32), + scale("scale", "gemm_prim", "scale_data"), + activation("activation", "scale", activation_func::negative), + eltwise("sum", { "activation", "eltwise_data"}, eltwise_mode::sum, data_types::f32), + reorder("reorder_bfyx", "sum", p.default_format, data_types::f32) + ); + + tolerance = 1e-4f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_act_scale_eltwise, + ::testing::ValuesIn(std::vector{ + gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP32_1, 3, 6 }, + gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP16_1, 3, 6 }, + gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_1, 3, 6 }, + gemm_test_params{ CASE_GEMM_ELTWISE_2IN_S8U8_1, 3, 6 }, +}), ); + /* ----------------------------------------------------------------------------------------------------- */ /* ---------------------------------------- Resample cases --------------------------------------------- */ /* ----------------------------------------------------------------------------------------------------- */ -- 2.7.4