[IE CLDNN] Fixed gemm fusings with FP precision (#1490)
authorVladimir Paramuzov <vladimir.paramuzov@intel.com>
Mon, 27 Jul 2020 15:49:54 +0000 (18:49 +0300)
committerGitHub <noreply@github.com>
Mon, 27 Jul 2020 15:49:54 +0000 (18:49 +0300)
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gemm/gemm_kernel_ref.h
inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp
inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp

index 7441bc7..7ba6604 100644 (file)
@@ -30,7 +30,8 @@ protected:
     std::vector<FusedOpType> GetSupportedFusedOps() const override {
         return { FusedOpType::QUANTIZE,
                  FusedOpType::ACTIVATION,
-                 FusedOpType::SCALE };
+                 FusedOpType::SCALE,
+                 FusedOpType::ELTWISE };
     }
     bool Validate(const Params& params, const optional_params& options) const override;
     JitConstants GetJitConstants(const gemm_params& params) const override;
index 1915300..3edfa76 100644 (file)
@@ -306,20 +306,27 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
             auto in1_dt = node.get_dependency(1).get_output_layout().data_type;
             auto in0_fmt = node.get_dependency(0).get_output_layout().format;
             auto in1_fmt = node.get_dependency(1).get_output_layout().format;
-            if ((in0_dt == data_types::u8 || in0_dt == data_types::i8) &&
-                (in1_dt == data_types::u8 || in1_dt == data_types::i8) &&
-                in0_fmt == format::bfyx && in1_fmt == format::bfyx)
+
+            if (data_type_traits::is_floating_point(in0_dt) &&
+                data_type_traits::is_floating_point(in1_dt))
                 does_support_fusings = true;
 
-            if (node.inputs_count() == 3) {
-                auto in2_dt = node.get_dependency(2).get_output_layout().data_type;
-                auto in2_fmt = node.get_dependency(2).get_output_layout().format;
-                if ((in2_dt == data_types::u8 || in2_dt == data_types::i8) &&
-                    in2_fmt == format::bfyx)
+            if ((in0_dt == data_types::u8 || in0_dt == data_types::i8) &&
+                (in1_dt == data_types::u8 || in1_dt == data_types::i8) &&
+                in0_fmt == format::bfyx && in1_fmt == format::bfyx) {
+                if (node.inputs_count() == 3) {
+                    auto in2_dt = node.get_dependency(2).get_output_layout().data_type;
+                    auto in2_fmt = node.get_dependency(2).get_output_layout().format;
+                    if ((in2_dt == data_types::u8 || in2_dt == data_types::i8) &&
+                        in2_fmt == format::bfyx)
+                        does_support_fusings = true;
+                    else
+                        does_support_fusings = false;
+                } else {
                     does_support_fusings = true;
-                else
-                    does_support_fusings = false;
+                }
             }
+
             return does_support_fusings;
         };
 
@@ -524,12 +531,14 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
             bool can_fuse_parent1 = (parent1->is_type<convolution>() && conv_supports_fusings(parent1->as<convolution>())) ||
                                     (parent1->is_type<mvn>() && mvn_supports_fusings(parent1->as<mvn>())) ||
                                     (parent1->is_type<deconvolution>()) || (parent1->is_type<permute>()) ||
-                                    (parent1->is_type<depth_to_space>()) || (parent1->is_type<space_to_depth>()) || (parent1->is_type<gemm>());
+                                    (parent1->is_type<depth_to_space>()) || (parent1->is_type<space_to_depth>()) ||
+                                    (parent1->is_type<gemm>() && gemm_supports_fusings(parent1->as<gemm>()));
 
             bool can_fuse_parent2 = (parent2->is_type<convolution>() && conv_supports_fusings(parent2->as<convolution>())) ||
                                     (parent2->is_type<mvn>() && mvn_supports_fusings(parent2->as<mvn>())) ||
                                     (parent2->is_type<deconvolution>()) || (parent2->is_type<permute>()) ||
-                                    (parent1->is_type<depth_to_space>()) || (parent1->is_type<space_to_depth>()) || (parent2->is_type<gemm>());
+                                    (parent2->is_type<depth_to_space>()) || (parent2->is_type<space_to_depth>()) ||
+                                    (parent2->is_type<gemm>() && gemm_supports_fusings(parent2->as<gemm>()));
 
             std::vector<bool> can_fuse_parents = { can_fuse_parent1, can_fuse_parent2 };
 
index b48e04d..f1a041b 100644 (file)
@@ -500,20 +500,6 @@ public:
 #define CASE_FC_U8S8_2 {2, 1, 3, 1}, {2, 4, 1, 1}, {4, 1, 3, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx
 #define CASE_FC_U8S8_3 {2, 32, 1, 1}, {2, 16, 1, 1}, {16, 32, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx
 
-#define CASE_GEMM_3IN_S8S8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_3IN_S8S8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}, {1, 2, 256, 128}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_3IN_S8S8_3 {{1, 1, 8, 16}, {1, 1, 32, 8}, {1, 1, 32, 16}}, {1, 1, 32, 16}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
-
-#define CASE_GEMM_2IN_U8U8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_U8U8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_U8U8_3 {{1, 1, 16, 32}, {1, 1, 32, 16}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-
-#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, {1, 1, 8, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-
-#define CASE_GEMM_ELTWISE_2IN_U8S8_1 {{1, 1, 4, 4}, {1, 1, 4, 4}}, {1, 1, 4, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_ELTWISE_2IN_S8U8_1 {{1, 1, 32, 32}, {1, 1, 32, 32}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-
 #define CASE_NORMALIZE_I8_1 {1, 2, 3, 3}, data_types::u8, format::bfyx, data_types::f32, format::bfyx
 
 /* ----------------------------------------------------------------------------------------------------- */
@@ -2376,8 +2362,32 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, fc_int8_scale_activation_quantize_i8,
         bc_test_params{CASE_FC_U8S8_3, 2, 5},
         }), );
 
-class gemm_int8_3in_quantize_i8 : public GemmFusingTest {};
-TEST_P(gemm_int8_3in_quantize_i8, basic) {
+
+/* ----------------------------------------------------------------------------------------------------- */
+/* ---------------------------------------- Gemm cases ------------------------------------------------- */
+/* ----------------------------------------------------------------------------------------------------- */
+#define CASE_GEMM_3IN_FP32_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_3IN_FP16_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+#define CASE_GEMM_3IN_S8S8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_3IN_S8S8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}, {1, 2, 256, 128}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_3IN_S8S8_3 {{1, 1, 8, 16}, {1, 1, 32, 8}, {1, 1, 32, 16}}, {1, 1, 32, 16}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+
+#define CASE_GEMM_2IN_FP32_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_FP16_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+#define CASE_GEMM_2IN_U8U8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_U8U8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_U8U8_3 {{1, 1, 16, 32}, {1, 1, 32, 16}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+
+#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, {1, 1, 8, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+
+#define CASE_GEMM_ELTWISE_2IN_FP32_1 {{1, 1, 4, 4}, {1, 1, 4, 4}}, {1, 1, 4, 4}, tensor{1}, tensor{0}, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_ELTWISE_2IN_FP16_1 {{1, 1, 32, 32}, {1, 1, 32, 32}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+#define CASE_GEMM_ELTWISE_2IN_U8S8_1 {{1, 1, 4, 4}, {1, 1, 4, 4}}, {1, 1, 4, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_ELTWISE_2IN_S8U8_1 {{1, 1, 32, 32}, {1, 1, 32, 32}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+
+class gemm_3in_quantize_i8 : public GemmFusingTest {};
+TEST_P(gemm_3in_quantize_i8, basic) {
     auto p = GetParam();
     create_topologies(input_layout("input0", get_input_layout(p, 0)),
         input_layout("input1", get_input_layout(p, 1)),
@@ -2391,23 +2401,25 @@ TEST_P(gemm_int8_3in_quantize_i8, basic) {
         reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
     );
 
-    tolerance = 1e-5f;
+    tolerance = 1.0f;
     execute(p);
 }
 
-INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_3in_quantize_i8,
+INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_3in_quantize_i8,
     ::testing::ValuesIn(std::vector<gemm_test_params>{
+                        gemm_test_params{ CASE_GEMM_3IN_FP32_1, 4, 5 },
+                        gemm_test_params{ CASE_GEMM_3IN_FP16_1, 4, 5 },
                         gemm_test_params{ CASE_GEMM_3IN_S8S8_1, 4, 5 },
                         gemm_test_params{ CASE_GEMM_3IN_S8S8_2, 4, 5 },
                         gemm_test_params{ CASE_GEMM_3IN_S8S8_3, 4, 5 },
 }), );
 
-class gemm_int8_2in_quantize_u8 : public GemmFusingTest {};
-TEST_P(gemm_int8_2in_quantize_u8, basic) {
+class gemm_2in_quantize_u8 : public GemmFusingTest {};
+TEST_P(gemm_2in_quantize_u8, basic) {
     auto p = GetParam();
     create_topologies(input_layout("input0", get_input_layout(p, 0)),
         input_layout("input1", get_input_layout(p, 1)),
-        data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
+        data("in_lo", get_mem(get_per_channel_layout(p), 0)),
         data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
         data("out_lo", get_mem(get_single_element_layout(p), 0)),
         data("out_hi", get_mem(get_single_element_layout(p), 255)),
@@ -2416,19 +2428,21 @@ TEST_P(gemm_int8_2in_quantize_u8, basic) {
         reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
     );
 
-    tolerance = 1e-5f;
+    tolerance = 1.0f;
     execute(p);
 }
 
-INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_quantize_u8,
+INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_quantize_u8,
     ::testing::ValuesIn(std::vector<gemm_test_params>{
+                        gemm_test_params{ CASE_GEMM_2IN_FP32_1, 3, 4 },
+                        gemm_test_params{ CASE_GEMM_2IN_FP16_1, 3, 4 },
                         gemm_test_params{ CASE_GEMM_2IN_U8U8_1, 3, 4 },
                         gemm_test_params{ CASE_GEMM_2IN_U8U8_2, 3, 4 },
                         gemm_test_params{ CASE_GEMM_2IN_U8U8_3, 3, 4 },
 }), );
 
-class gemm_int8_2in_act_scale_quantize_i8 : public GemmFusingTest {};
-TEST_P(gemm_int8_2in_act_scale_quantize_i8, basic) {
+class gemm_2in_act_scale_quantize_i8 : public GemmFusingTest {};
+TEST_P(gemm_2in_act_scale_quantize_i8, basic) {
     auto p = GetParam();
     create_topologies(input_layout("input0", get_input_layout(p, 0)),
         input_layout("input1", get_input_layout(p, 1)),
@@ -2444,18 +2458,20 @@ TEST_P(gemm_int8_2in_act_scale_quantize_i8, basic) {
         reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
     );
 
-    tolerance = 1e-5f;
+    tolerance = 1.0f;
     execute(p);
 }
 
-INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_act_scale_quantize_i8,
+INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_act_scale_quantize_i8,
     ::testing::ValuesIn(std::vector<gemm_test_params>{
+                        gemm_test_params{ CASE_GEMM_2IN_FP32_1, 3, 6 },
+                        gemm_test_params{ CASE_GEMM_2IN_FP16_1, 3, 6 },
                         gemm_test_params{ CASE_GEMM_2IN_U8S8_1, 3, 6 },
                         gemm_test_params{ CASE_GEMM_2IN_S8U8_1, 3, 6 },
 }), );
 
-class gemm_int8_2in_act_scale_quantize_eltwise_i8 : public GemmFusingTest {};
-TEST_P(gemm_int8_2in_act_scale_quantize_eltwise_i8, basic) {
+class gemm_2in_act_scale_quantize_eltwise_i8 : public GemmFusingTest {};
+TEST_P(gemm_2in_act_scale_quantize_eltwise_i8, basic) {
     auto p = GetParam();
     create_topologies(input_layout("input0", get_input_layout(p, 0)),
         input_layout("input1", get_input_layout(p, 1)),
@@ -2477,12 +2493,57 @@ TEST_P(gemm_int8_2in_act_scale_quantize_eltwise_i8, basic) {
     execute(p);
 }
 
-INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_act_scale_quantize_eltwise_i8,
+INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_act_scale_quantize_eltwise_i8,
     ::testing::ValuesIn(std::vector<gemm_test_params>{
+                        gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP32_1, 3, 7 },
+                        gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP16_1, 3, 7 },
                         gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_1, 3, 7 },
                         gemm_test_params{ CASE_GEMM_ELTWISE_2IN_S8U8_1, 3, 7 },
 }), );
 
+class gemm_2in_act_scale_eltwise : public GemmFusingTest {};
+TEST_P(gemm_2in_act_scale_eltwise, basic) {
+    auto p = GetParam();
+    create_topologies(input_layout("input0", get_input_layout(p, 0)),
+        input_layout("input1", get_input_layout(p, 1)),
+        data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
+        data("eltwise_data", get_mem(get_output_layout(p))),
+        gemm("gemm_prim", { "input0", "input1" }, data_types::f32),
+        scale("scale", "gemm_prim", "scale_data"),
+        activation("activation", "scale", activation_func::negative),
+        eltwise("sum", { "activation", "eltwise_data"}, eltwise_mode::sum,  data_types::f32),
+        reorder("reorder_bfyx", "sum", p.default_format, data_types::f32)
+    );
+
+    tolerance = 1e-4f;
+    execute(p);
+}
+
+TEST_P(gemm_2in_act_scale_eltwise, broadcast_eltwise) {
+    auto p = GetParam();
+    create_topologies(input_layout("input0", get_input_layout(p, 0)),
+        input_layout("input1", get_input_layout(p, 1)),
+        data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
+        data("eltwise_data", get_mem(get_single_element_layout(p))),
+        gemm("gemm_prim", { "input0", "input1" }, data_types::f32),
+        scale("scale", "gemm_prim", "scale_data"),
+        activation("activation", "scale", activation_func::negative),
+        eltwise("sum", { "activation", "eltwise_data"}, eltwise_mode::sum,  data_types::f32),
+        reorder("reorder_bfyx", "sum", p.default_format, data_types::f32)
+    );
+
+    tolerance = 1e-4f;
+    execute(p);
+}
+
+INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_act_scale_eltwise,
+    ::testing::ValuesIn(std::vector<gemm_test_params>{
+                        gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP32_1, 3, 6 },
+                        gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP16_1, 3, 6 },
+                        gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_1, 3, 6 },
+                        gemm_test_params{ CASE_GEMM_ELTWISE_2IN_S8U8_1, 3, 6 },
+}), );
+
 /* ----------------------------------------------------------------------------------------------------- */
 /* ---------------------------------------- Resample cases --------------------------------------------- */
 /* ----------------------------------------------------------------------------------------------------- */