#define CASE_FC_U8S8_2 {2, 1, 3, 1}, {2, 4, 1, 1}, {4, 1, 3, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_U8S8_3 {2, 32, 1, 1}, {2, 16, 1, 1}, {16, 32, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx
-#define CASE_GEMM_3IN_S8S8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_3IN_S8S8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}, {1, 2, 256, 128}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_3IN_S8S8_3 {{1, 1, 8, 16}, {1, 1, 32, 8}, {1, 1, 32, 16}}, {1, 1, 32, 16}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
-
-#define CASE_GEMM_2IN_U8U8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_U8U8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_U8U8_3 {{1, 1, 16, 32}, {1, 1, 32, 16}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-
-#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, {1, 1, 8, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-
-#define CASE_GEMM_ELTWISE_2IN_U8S8_1 {{1, 1, 4, 4}, {1, 1, 4, 4}}, {1, 1, 4, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-#define CASE_GEMM_ELTWISE_2IN_S8U8_1 {{1, 1, 32, 32}, {1, 1, 32, 32}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
-
#define CASE_NORMALIZE_I8_1 {1, 2, 3, 3}, data_types::u8, format::bfyx, data_types::f32, format::bfyx
/* ----------------------------------------------------------------------------------------------------- */
bc_test_params{CASE_FC_U8S8_3, 2, 5},
}), );
-class gemm_int8_3in_quantize_i8 : public GemmFusingTest {};
-TEST_P(gemm_int8_3in_quantize_i8, basic) {
+
+/* ----------------------------------------------------------------------------------------------------- */
+/* ---------------------------------------- Gemm cases ------------------------------------------------- */
+/* ----------------------------------------------------------------------------------------------------- */
+#define CASE_GEMM_3IN_FP32_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_3IN_FP16_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+#define CASE_GEMM_3IN_S8S8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_3IN_S8S8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}, {1, 2, 256, 128}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_3IN_S8S8_3 {{1, 1, 8, 16}, {1, 1, 32, 8}, {1, 1, 32, 16}}, {1, 1, 32, 16}, tensor{1}, tensor{0}, data_types::i8, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx
+
+#define CASE_GEMM_2IN_FP32_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_FP16_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+#define CASE_GEMM_2IN_U8U8_1 {{1, 1, 2, 2}, {1, 1, 2, 2}}, {1, 1, 2, 2}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_U8U8_2 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_U8U8_3 {{1, 1, 16, 32}, {1, 1, 32, 16}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+
+#define CASE_GEMM_2IN_U8S8_1 {{1, 1, 4, 2}, {1, 1, 8, 4}}, {1, 1, 8, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_2IN_S8U8_1 {{1, 2, 64, 128}, {1, 2, 256, 64}}, {1, 2, 256, 128}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+
+#define CASE_GEMM_ELTWISE_2IN_FP32_1 {{1, 1, 4, 4}, {1, 1, 4, 4}}, {1, 1, 4, 4}, tensor{1}, tensor{0}, data_types::f32, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_ELTWISE_2IN_FP16_1 {{1, 1, 32, 32}, {1, 1, 32, 32}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+#define CASE_GEMM_ELTWISE_2IN_U8S8_1 {{1, 1, 4, 4}, {1, 1, 4, 4}}, {1, 1, 4, 4}, tensor{1}, tensor{0}, data_types::u8, data_types::i8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+#define CASE_GEMM_ELTWISE_2IN_S8U8_1 {{1, 1, 32, 32}, {1, 1, 32, 32}}, {1, 1, 32, 32}, tensor{1}, tensor{0}, data_types::i8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
+
+class gemm_3in_quantize_i8 : public GemmFusingTest {};
+TEST_P(gemm_3in_quantize_i8, basic) {
auto p = GetParam();
create_topologies(input_layout("input0", get_input_layout(p, 0)),
input_layout("input1", get_input_layout(p, 1)),
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
);
- tolerance = 1e-5f;
+ tolerance = 1.0f;
execute(p);
}
-INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_3in_quantize_i8,
+INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_3in_quantize_i8,
::testing::ValuesIn(std::vector<gemm_test_params>{
+ gemm_test_params{ CASE_GEMM_3IN_FP32_1, 4, 5 },
+ gemm_test_params{ CASE_GEMM_3IN_FP16_1, 4, 5 },
gemm_test_params{ CASE_GEMM_3IN_S8S8_1, 4, 5 },
gemm_test_params{ CASE_GEMM_3IN_S8S8_2, 4, 5 },
gemm_test_params{ CASE_GEMM_3IN_S8S8_3, 4, 5 },
}), );
-class gemm_int8_2in_quantize_u8 : public GemmFusingTest {};
-TEST_P(gemm_int8_2in_quantize_u8, basic) {
+class gemm_2in_quantize_u8 : public GemmFusingTest {};
+TEST_P(gemm_2in_quantize_u8, basic) {
auto p = GetParam();
create_topologies(input_layout("input0", get_input_layout(p, 0)),
input_layout("input1", get_input_layout(p, 1)),
- data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
+ data("in_lo", get_mem(get_per_channel_layout(p), 0)),
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
data("out_lo", get_mem(get_single_element_layout(p), 0)),
data("out_hi", get_mem(get_single_element_layout(p), 255)),
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
);
- tolerance = 1e-5f;
+ tolerance = 1.0f;
execute(p);
}
-INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_quantize_u8,
+INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_quantize_u8,
::testing::ValuesIn(std::vector<gemm_test_params>{
+ gemm_test_params{ CASE_GEMM_2IN_FP32_1, 3, 4 },
+ gemm_test_params{ CASE_GEMM_2IN_FP16_1, 3, 4 },
gemm_test_params{ CASE_GEMM_2IN_U8U8_1, 3, 4 },
gemm_test_params{ CASE_GEMM_2IN_U8U8_2, 3, 4 },
gemm_test_params{ CASE_GEMM_2IN_U8U8_3, 3, 4 },
}), );
-class gemm_int8_2in_act_scale_quantize_i8 : public GemmFusingTest {};
-TEST_P(gemm_int8_2in_act_scale_quantize_i8, basic) {
+class gemm_2in_act_scale_quantize_i8 : public GemmFusingTest {};
+TEST_P(gemm_2in_act_scale_quantize_i8, basic) {
auto p = GetParam();
create_topologies(input_layout("input0", get_input_layout(p, 0)),
input_layout("input1", get_input_layout(p, 1)),
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
);
- tolerance = 1e-5f;
+ tolerance = 1.0f;
execute(p);
}
-INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_act_scale_quantize_i8,
+INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_act_scale_quantize_i8,
::testing::ValuesIn(std::vector<gemm_test_params>{
+ gemm_test_params{ CASE_GEMM_2IN_FP32_1, 3, 6 },
+ gemm_test_params{ CASE_GEMM_2IN_FP16_1, 3, 6 },
gemm_test_params{ CASE_GEMM_2IN_U8S8_1, 3, 6 },
gemm_test_params{ CASE_GEMM_2IN_S8U8_1, 3, 6 },
}), );
-class gemm_int8_2in_act_scale_quantize_eltwise_i8 : public GemmFusingTest {};
-TEST_P(gemm_int8_2in_act_scale_quantize_eltwise_i8, basic) {
+class gemm_2in_act_scale_quantize_eltwise_i8 : public GemmFusingTest {};
+TEST_P(gemm_2in_act_scale_quantize_eltwise_i8, basic) {
auto p = GetParam();
create_topologies(input_layout("input0", get_input_layout(p, 0)),
input_layout("input1", get_input_layout(p, 1)),
execute(p);
}
-INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_act_scale_quantize_eltwise_i8,
+INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_act_scale_quantize_eltwise_i8,
::testing::ValuesIn(std::vector<gemm_test_params>{
+ gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP32_1, 3, 7 },
+ gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP16_1, 3, 7 },
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_1, 3, 7 },
gemm_test_params{ CASE_GEMM_ELTWISE_2IN_S8U8_1, 3, 7 },
}), );
+class gemm_2in_act_scale_eltwise : public GemmFusingTest {};
+TEST_P(gemm_2in_act_scale_eltwise, basic) {
+ auto p = GetParam();
+ create_topologies(input_layout("input0", get_input_layout(p, 0)),
+ input_layout("input1", get_input_layout(p, 1)),
+ data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
+ data("eltwise_data", get_mem(get_output_layout(p))),
+ gemm("gemm_prim", { "input0", "input1" }, data_types::f32),
+ scale("scale", "gemm_prim", "scale_data"),
+ activation("activation", "scale", activation_func::negative),
+ eltwise("sum", { "activation", "eltwise_data"}, eltwise_mode::sum, data_types::f32),
+ reorder("reorder_bfyx", "sum", p.default_format, data_types::f32)
+ );
+
+ tolerance = 1e-4f;
+ execute(p);
+}
+
+TEST_P(gemm_2in_act_scale_eltwise, broadcast_eltwise) {
+ auto p = GetParam();
+ create_topologies(input_layout("input0", get_input_layout(p, 0)),
+ input_layout("input1", get_input_layout(p, 1)),
+ data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
+ data("eltwise_data", get_mem(get_single_element_layout(p))),
+ gemm("gemm_prim", { "input0", "input1" }, data_types::f32),
+ scale("scale", "gemm_prim", "scale_data"),
+ activation("activation", "scale", activation_func::negative),
+ eltwise("sum", { "activation", "eltwise_data"}, eltwise_mode::sum, data_types::f32),
+ reorder("reorder_bfyx", "sum", p.default_format, data_types::f32)
+ );
+
+ tolerance = 1e-4f;
+ execute(p);
+}
+
+INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_2in_act_scale_eltwise,
+ ::testing::ValuesIn(std::vector<gemm_test_params>{
+ gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP32_1, 3, 6 },
+ gemm_test_params{ CASE_GEMM_ELTWISE_2IN_FP16_1, 3, 6 },
+ gemm_test_params{ CASE_GEMM_ELTWISE_2IN_U8S8_1, 3, 6 },
+ gemm_test_params{ CASE_GEMM_ELTWISE_2IN_S8U8_1, 3, 6 },
+}), );
+
/* ----------------------------------------------------------------------------------------------------- */
/* ---------------------------------------- Resample cases --------------------------------------------- */
/* ----------------------------------------------------------------------------------------------------- */