INSTANTIATE_TEST_CASE_P(fusings_gpu, deconv_scale_actv_quant_u8_eltw_scale_actv_quant_i8,
::testing::ValuesIn(std::vector<deconv_test_params>{
- deconv_test_params{ CASE_DECONV_FP32_1, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_2, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_3, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_4, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_5, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_6, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_7, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_8, 6, 9 },
-
- deconv_test_params{ CASE_DECONV_FP16_1, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_2, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_3, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_4, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_5, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_6, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_7, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_8, 6, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_1, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_2, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_3, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_4, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_5, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_6, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_7, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_8, 4, 9 },
+
+ deconv_test_params{ CASE_DECONV_FP16_1, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_2, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_3, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_4, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_5, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_6, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_7, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_8, 4, 9 },
deconv_test_params{ CASE_DECONV_U8S8_1, 2, 9 },
deconv_test_params{ CASE_DECONV_U8S8_2, 2, 9 },
deconv_test_params{ CASE_DECONV_S8S8_7, 2, 9 },
deconv_test_params{ CASE_DECONV_S8S8_8, 2, 9 },
- deconv_test_params{ CASE_DECONV_FP32_3D_1, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_3D_2, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_3D_3, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_3D_4, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_3D_5, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_3D_6, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_3D_7, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP32_3D_8, 6, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_3D_1, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_3D_2, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_3D_3, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_3D_4, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_3D_5, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_3D_6, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_3D_7, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP32_3D_8, 4, 9 },
// deconv_test_params{ CASE_DECONV_FP32_3D_9, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_3D_1, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_3D_2, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_3D_3, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_3D_4, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_3D_5, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_3D_6, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_3D_7, 6, 9 },
- deconv_test_params{ CASE_DECONV_FP16_3D_8, 6, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_3D_1, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_3D_2, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_3D_3, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_3D_4, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_3D_5, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_3D_6, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_3D_7, 4, 9 },
+ deconv_test_params{ CASE_DECONV_FP16_3D_8, 4, 9 },
// deconv_test_params{ CASE_DECONV_FP16_3D_9, 6, 9 },
deconv_test_params{ CASE_DECONV_U8S8_3D_1, 2, 9 },
size_t expected_not_fused_primitives;
};
-#define CASE_ELTWISE_FP32_1 {2, 16, 4, 4}, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_FP32_2 {2, 16, 4, 4}, data_types::f32, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_FP32_3 {2, 16, 4, 4}, data_types::f32, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_FP16_1 {2, 16, 4, 4}, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_FP16_2 {2, 16, 4, 4}, data_types::f16, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_FP16_3 {2, 16, 4, 4}, data_types::f16, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::b_fs_yx_fsv16, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_I8_1 {2, 16, 4, 4}, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_I8_2 {2, 16, 4, 4}, data_types::i8, data_types::i8, format::bfzyx, data_types::f32, format::bfzyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_I8_3 {2, 16, 4, 4}, data_types::i8, data_types::i8, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_U8_1 {2, 16, 4, 4}, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_U8_2 {2, 16, 4, 4}, data_types::u8, data_types::u8, format::bfzyx, data_types::f32, format::bfzyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_U8_3 {2, 16, 4, 4}, data_types::u8, data_types::u8, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_FP32_FP16_1 {2, 16, 4, 4}, data_types::f32, data_types::f16, format::bfyx, data_types::f32, format::bfyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_FP32_FP16_2 {2, 16, 4, 4}, data_types::f32, data_types::f16, format::bfzyx, data_types::f32, format::bfzyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_FP32_FP16_3 {2, 16, 4, 4}, data_types::f32, data_types::f16, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_FP16_FP32_1 {2, 16, 4, 4}, data_types::f16, data_types::f32, format::bfyx, data_types::f16, format::bfyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_FP16_FP32_2 {2, 16, 4, 4}, data_types::f16, data_types::f32, format::bfzyx, data_types::f16, format::bfzyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_FP16_FP32_3 {2, 16, 4, 4}, data_types::f16, data_types::f32, format::b_fs_yx_fsv16, data_types::f16, format::b_fs_yx_fsv16, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_I8_FP16_1 {2, 16, 4, 4}, data_types::i8, data_types::f16, format::bfyx, data_types::f32, format::bfyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_I8_FP16_2 {2, 16, 4, 4}, data_types::i8, data_types::f16, format::bfzyx, data_types::f32, format::bfzyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_I8_FP16_3 {2, 16, 4, 4}, data_types::i8, data_types::f16, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_I8_FP32_1 {2, 16, 4, 4}, data_types::i8, data_types::f32, format::bfyx, data_types::f16, format::bfyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_I8_FP32_2 {2, 16, 4, 4}, data_types::i8, data_types::f32, format::bfzyx, data_types::f16, format::bfzyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_I8_FP32_3 {2, 16, 4, 4}, data_types::i8, data_types::f32, format::b_fs_yx_fsv16, data_types::f16, format::b_fs_yx_fsv16, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_U8_FP16_1 {2, 16, 4, 4}, data_types::u8, data_types::f16, format::bfyx, data_types::f32, format::bfyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_U8_FP16_2 {2, 16, 4, 4}, data_types::u8, data_types::f16, format::bfzyx, data_types::f32, format::bfzyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_U8_FP16_3 {2, 16, 4, 4}, data_types::u8, data_types::f16, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_U8_FP32_1 {2, 16, 4, 4}, data_types::u8, data_types::f32, format::bfyx, data_types::f16, format::bfyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_U8_FP32_2 {2, 16, 4, 4}, data_types::u8, data_types::f32, format::bfzyx, data_types::f16, format::bfzyx, eltwise_mode::sum, 3, 4
-#define CASE_ELTWISE_U8_FP32_3 {2, 16, 4, 4}, data_types::u8, data_types::f32, format::b_fs_yx_fsv16, data_types::f16, format::b_fs_yx_fsv16, eltwise_mode::sum, 3, 4
+#define CASE_ELTWISE_FP32_1 {2, 16, 4, 4}, data_types::f32, data_types::f32, format::bfyx, data_types::f32, format::bfyx, eltwise_mode::sum  // CASE_ELTWISE_* expand to the leading initializers of eltwise_test_params; each use site appends two integer counts (e.g. `, 3, 4`)
+#define CASE_ELTWISE_FP32_2 {2, 16, 4, 4}, data_types::f32, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx, eltwise_mode::sum
+#define CASE_ELTWISE_FP32_3 {2, 16, 4, 4}, data_types::f32, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16, eltwise_mode::sum
+#define CASE_ELTWISE_FP32_4 {2, 16, 4, 4}, data_types::f32, data_types::f32, format::bfwzyx, data_types::f32, format::bfwzyx, eltwise_mode::sum  // bfwzyx case, used by the eltwise_fp32_scale instantiation
+#define CASE_ELTWISE_FP16_1 {2, 16, 4, 4}, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx, eltwise_mode::sum
+#define CASE_ELTWISE_FP16_2 {2, 16, 4, 4}, data_types::f16, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx, eltwise_mode::sum
+#define CASE_ELTWISE_FP16_3 {2, 16, 4, 4}, data_types::f16, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::b_fs_yx_fsv16, eltwise_mode::sum
+#define CASE_ELTWISE_I8_1 {2, 16, 4, 4}, data_types::i8, data_types::i8, format::bfyx, data_types::f32, format::bfyx, eltwise_mode::sum
+#define CASE_ELTWISE_I8_2 {2, 16, 4, 4}, data_types::i8, data_types::i8, format::bfzyx, data_types::f32, format::bfzyx, eltwise_mode::sum
+#define CASE_ELTWISE_I8_3 {2, 16, 4, 4}, data_types::i8, data_types::i8, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16, eltwise_mode::sum
+#define CASE_ELTWISE_U8_1 {2, 16, 4, 4}, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx, eltwise_mode::sum
+#define CASE_ELTWISE_U8_2 {2, 16, 4, 4}, data_types::u8, data_types::u8, format::bfzyx, data_types::f32, format::bfzyx, eltwise_mode::sum
+#define CASE_ELTWISE_U8_3 {2, 16, 4, 4}, data_types::u8, data_types::u8, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16, eltwise_mode::sum
+#define CASE_ELTWISE_FP32_FP16_1 {2, 16, 4, 4}, data_types::f32, data_types::f16, format::bfyx, data_types::f32, format::bfyx, eltwise_mode::sum
+#define CASE_ELTWISE_FP32_FP16_2 {2, 16, 4, 4}, data_types::f32, data_types::f16, format::bfzyx, data_types::f32, format::bfzyx, eltwise_mode::sum
+#define CASE_ELTWISE_FP32_FP16_3 {2, 16, 4, 4}, data_types::f32, data_types::f16, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16, eltwise_mode::sum
+#define CASE_ELTWISE_FP16_FP32_1 {2, 16, 4, 4}, data_types::f16, data_types::f32, format::bfyx, data_types::f16, format::bfyx, eltwise_mode::sum
+#define CASE_ELTWISE_FP16_FP32_2 {2, 16, 4, 4}, data_types::f16, data_types::f32, format::bfzyx, data_types::f16, format::bfzyx, eltwise_mode::sum
+#define CASE_ELTWISE_FP16_FP32_3 {2, 16, 4, 4}, data_types::f16, data_types::f32, format::b_fs_yx_fsv16, data_types::f16, format::b_fs_yx_fsv16, eltwise_mode::sum
+#define CASE_ELTWISE_I8_FP16_1 {2, 16, 4, 4}, data_types::i8, data_types::f16, format::bfyx, data_types::f32, format::bfyx, eltwise_mode::sum
+#define CASE_ELTWISE_I8_FP16_2 {2, 16, 4, 4}, data_types::i8, data_types::f16, format::bfzyx, data_types::f32, format::bfzyx, eltwise_mode::sum
+#define CASE_ELTWISE_I8_FP16_3 {2, 16, 4, 4}, data_types::i8, data_types::f16, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16, eltwise_mode::sum
+#define CASE_ELTWISE_I8_FP32_1 {2, 16, 4, 4}, data_types::i8, data_types::f32, format::bfyx, data_types::f16, format::bfyx, eltwise_mode::sum
+#define CASE_ELTWISE_I8_FP32_2 {2, 16, 4, 4}, data_types::i8, data_types::f32, format::bfzyx, data_types::f16, format::bfzyx, eltwise_mode::sum
+#define CASE_ELTWISE_I8_FP32_3 {2, 16, 4, 4}, data_types::i8, data_types::f32, format::b_fs_yx_fsv16, data_types::f16, format::b_fs_yx_fsv16, eltwise_mode::sum
+#define CASE_ELTWISE_U8_FP16_1 {2, 16, 4, 4}, data_types::u8, data_types::f16, format::bfyx, data_types::f32, format::bfyx, eltwise_mode::sum
+#define CASE_ELTWISE_U8_FP16_2 {2, 16, 4, 4}, data_types::u8, data_types::f16, format::bfzyx, data_types::f32, format::bfzyx, eltwise_mode::sum
+#define CASE_ELTWISE_U8_FP16_3 {2, 16, 4, 4}, data_types::u8, data_types::f16, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16, eltwise_mode::sum
+#define CASE_ELTWISE_U8_FP32_1 {2, 16, 4, 4}, data_types::u8, data_types::f32, format::bfyx, data_types::f16, format::bfyx, eltwise_mode::sum
+#define CASE_ELTWISE_U8_FP32_2 {2, 16, 4, 4}, data_types::u8, data_types::f32, format::bfzyx, data_types::f16, format::bfzyx, eltwise_mode::sum
+#define CASE_ELTWISE_U8_FP32_3 {2, 16, 4, 4}, data_types::u8, data_types::f32, format::b_fs_yx_fsv16, data_types::f16, format::b_fs_yx_fsv16, eltwise_mode::sum
class EltwiseFusingTest : public ::BaseFusingTest<eltwise_test_params> {
INSTANTIATE_TEST_CASE_P(fusings_gpu,
eltwise_quantize,
::testing::ValuesIn(std::vector<eltwise_test_params>{
- eltwise_test_params{CASE_ELTWISE_FP16_1},
- eltwise_test_params{CASE_ELTWISE_FP16_2},
- eltwise_test_params{CASE_ELTWISE_FP16_3},
- eltwise_test_params{CASE_ELTWISE_FP32_1},
- eltwise_test_params{CASE_ELTWISE_FP32_2},
- eltwise_test_params{CASE_ELTWISE_FP32_3},
- eltwise_test_params{CASE_ELTWISE_FP32_FP16_1},
- eltwise_test_params{CASE_ELTWISE_FP32_FP16_2},
- eltwise_test_params{CASE_ELTWISE_FP32_FP16_3},
- eltwise_test_params{CASE_ELTWISE_FP16_FP32_1},
- eltwise_test_params{CASE_ELTWISE_FP16_FP32_2},
- eltwise_test_params{CASE_ELTWISE_FP16_FP32_3},
- eltwise_test_params{CASE_ELTWISE_I8_FP32_1},
- eltwise_test_params{CASE_ELTWISE_I8_FP32_2},
- eltwise_test_params{CASE_ELTWISE_I8_FP32_3},
- eltwise_test_params{CASE_ELTWISE_U8_FP32_1},
- eltwise_test_params{CASE_ELTWISE_U8_FP32_2},
- eltwise_test_params{CASE_ELTWISE_U8_FP32_3},
- eltwise_test_params{CASE_ELTWISE_I8_FP16_1},
- eltwise_test_params{CASE_ELTWISE_I8_FP16_2},
- eltwise_test_params{CASE_ELTWISE_I8_FP16_3},
- eltwise_test_params{CASE_ELTWISE_U8_FP16_1},
- eltwise_test_params{CASE_ELTWISE_U8_FP16_2},
- eltwise_test_params{CASE_ELTWISE_U8_FP16_3},
+ eltwise_test_params{CASE_ELTWISE_FP16_1, 3, 4},  // the counts are now spelled out per entry instead of living in the CASE_* macro; presumably expected fused / not-fused primitive counts - TODO confirm against eltwise_test_params
+ eltwise_test_params{CASE_ELTWISE_FP16_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP16_3, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_1, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_3, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_FP16_1, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_FP16_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_FP16_3, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP16_FP32_1, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP16_FP32_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP16_FP32_3, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_I8_FP32_1, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_I8_FP32_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_I8_FP32_3, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_U8_FP32_1, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_U8_FP32_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_U8_FP32_3, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_I8_FP16_1, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_I8_FP16_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_I8_FP16_3, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_U8_FP16_1, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_U8_FP16_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_U8_FP16_3, 3, 4},
+ }), );
+
+class eltwise_fp32_fused_prims : public EltwiseFusingTest {};  // fixture for eltwise followed by scale/eltwise + activation primitives
+TEST_P(eltwise_fp32_fused_prims, scale_activation) {  // topology: eltwise(input, input2) -> scale -> abs activation -> reorder to f32
+ auto p = GetParam();
+ create_topologies(input_layout("input", get_input_layout(p)),
+ input_layout("input2", get_input_layout2(p)),
+ data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),  // constant scale data with the per-channel layout
+ eltwise("eltwise", {"input", "input2"}, p.mode, p.default_type),
+ scale("scale", "eltwise", "scale_data"),
+ activation("activation", "scale", activation_func::abs),
+ reorder("out", "activation", p.default_format, data_types::f32));  // final output is always reordered to f32 in the default format
+
+ tolerance = 1e-5f;
+ execute(p);
+}
+
+TEST_P(eltwise_fp32_fused_prims, eltwise_activation) {  // topology: eltwise(input, input2) -> eltwise prod with constant data -> abs activation -> reorder to f32
+ auto p = GetParam();
+ create_topologies(input_layout("input", get_input_layout(p)),
+ input_layout("input2", get_input_layout2(p)),
+ data("eltwise_data", get_mem(get_input_layout2(p), -10, 10)),  // constant operand uses the same layout as input2
+ eltwise("eltwise1", {"input", "input2"}, p.mode, p.default_type),
+ eltwise("eltwise2", {"eltwise1", "eltwise_data"}, eltwise_mode::prod, p.default_type),
+ activation("activation", "eltwise2", activation_func::abs),
+ reorder("out", "activation", p.default_format, data_types::f32));
+
+ tolerance = 1e-5f;
+ execute(p);
+}
+
+TEST_P(eltwise_fp32_fused_prims, eltwise_activation_with_broadcast) {  // same chain as eltwise_activation, but the constant operand is per-channel
+ auto p = GetParam();
+ create_topologies(input_layout("input", get_input_layout(p)),
+ input_layout("input2", get_input_layout2(p)),
+ data("eltwise_data", get_mem(get_per_channel_layout(p), -10, 10)),  // per-channel layout instead of the input2 layout
+ eltwise("eltwise1", {"input", "input2"}, p.mode, p.default_type),
+ eltwise("eltwise2", {"eltwise1", "eltwise_data"}, eltwise_mode::prod, p.default_type),
+ activation("activation", "eltwise2", activation_func::abs),
+ reorder("out", "activation", p.default_format, data_types::f32));
+
+ tolerance = 1e-5f;
+ execute(p);
+}
+
+INSTANTIATE_TEST_CASE_P(fusings_gpu,
+ eltwise_fp32_fused_prims,  // NOTE(review): fixture name says fp32, but FP16, I8 and U8 input cases are instantiated below as well
+ ::testing::ValuesIn(std::vector<eltwise_test_params>{
+ eltwise_test_params{CASE_ELTWISE_FP16_1, 3, 5},  // trailing counts are 3 and 5 for every case in this list
+ eltwise_test_params{CASE_ELTWISE_FP16_2, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_FP16_3, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_FP32_1, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_FP32_2, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_FP32_3, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_FP32_FP16_1, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_FP32_FP16_2, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_FP32_FP16_3, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_FP16_FP32_1, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_FP16_FP32_2, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_FP16_FP32_3, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_I8_FP32_1, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_I8_FP32_2, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_I8_FP32_3, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_U8_FP32_1, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_U8_FP32_2, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_U8_FP32_3, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_I8_FP16_1, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_I8_FP16_2, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_I8_FP16_3, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_U8_FP16_1, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_U8_FP16_2, 3, 5},
+ eltwise_test_params{CASE_ELTWISE_U8_FP16_3, 3, 5},
+ }), );
+
+class eltwise_fp32_scale : public EltwiseFusingTest {};  // fixture for eltwise followed by a scale primitive
+TEST_P(eltwise_fp32_scale, 6d) {  // topology: eltwise(input, input2) -> scale -> reorder to f32; instantiated only with the bfwzyx case
+ auto p = GetParam();
+ create_topologies(input_layout("input", get_input_layout(p)),
+ input_layout("input2", get_input_layout2(p)),
+ data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
+ eltwise("eltwise", {"input", "input2"}, p.mode, p.default_type),
+ scale("scale", "eltwise", "scale_data"),
+ reorder("out", "scale", p.default_format, data_types::f32));
+
+ tolerance = 1e-5f;
+ execute(p);
+}
+
+INSTANTIATE_TEST_CASE_P(fusings_gpu,
+ eltwise_fp32_scale,
+ ::testing::ValuesIn(std::vector<eltwise_test_params>{
+ eltwise_test_params{CASE_ELTWISE_FP32_4, 3, 4},  // single case: f32 inputs in format::bfwzyx
+ }), );
+
+/* ----------------------------------------------------------------------------------------------------- */
+/* ---------------------------------------- Scale cases ------------------------------------------------ */
+/* ----------------------------------------------------------------------------------------------------- */
+struct scale_test_params {  // parameters of one scale fusing test case; the CASE_SCALE_* macros fill the first five fields in this order
+ tensor input_size;  // size of the "input" primitive
+ data_types input_type;  // data type of the "input" primitive
+ format input_format;  // format of the "input" primitive
+ data_types default_type;  // type used for the per-channel constant data and passed to eltwise in the tests
+ format default_format;  // format used for the per-channel constant data and for the final reorder
+ size_t expected_fused_primitives;  // presumably the primitive count expected in the fused network - checked outside this struct, TODO confirm
+ size_t expected_not_fused_primitives;  // presumably the primitive count expected in the non-fused network - TODO confirm
+};
+
+// Scale uses the same kernel as the eltwise primitive, so the kernel itself is already well covered by the eltwise tests above.
+// Here we only need to check that the fused scale kernel is constructed correctly (inputs are set correctly, fused precision is propagated, etc.)
+// and that the fusing conditions in the graph are correct.
+#define CASE_SCALE_FP32_1 {2, 16, 4, 4}, data_types::f32, format::bfyx, data_types::f32, format::bfyx  // order: input_size, input_type, input_format, default_type, default_format
+#define CASE_SCALE_FP32_2 {2, 16, 4, 4}, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_SCALE_FP32_3 {2, 16, 4, 4}, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::b_fs_yx_fsv16
+
+class ScaleFusingTest : public ::BaseFusingTest<scale_test_params> {  // fixture for scale fusing tests, parameterized by scale_test_params
+public:
+ void execute(scale_test_params& p) {  // builds the non-fused and fused networks, feeds both the same input and passes them to compare()
+ auto input_prim = get_mem(get_input_layout(p));  // one input buffer shared by both networks
+
+ network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
+ network network_fused(this->engine, this->topology_fused, bo_fused);
+
+ network_fused.set_input_data("input", input_prim);
+ network_not_fused.set_input_data("input", input_prim);
+
+ compare(network_not_fused, network_fused, p);
+ }
+
+ layout get_input_layout(scale_test_params& p) { return layout{p.input_type, p.input_format, p.input_size}; }  // layout of the "input" primitive, taken directly from the params
+
+ layout get_per_channel_layout(scale_test_params& p) {  // layout for per-feature constants (scale, bias, eltwise operand, quantize input ranges)
+ return layout{p.default_type, p.default_format, tensor{1, p.input_size.feature[0], 1, 1}};  // tensor{1, input feature count, 1, 1} in the default type/format
+ }
+};
+
+class scale_basic : public ScaleFusingTest {};  // fixture for the scale-first fusing chains below
+TEST_P(scale_basic, no_bias_act_eltwise) {  // topology: scale (no bias) -> negative activation -> eltwise prod with per-channel data -> reorder to f32
+ auto p = GetParam();
+ create_topologies(input_layout("input", get_input_layout(p)),
+ data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
+ scale("scale", "input", "scale_data"),
+ activation("activation", "scale", activation_func::negative),
+ data("eltwise_data", get_mem(get_per_channel_layout(p), -10, 10)),
+ eltwise("eltwise", {"activation", "eltwise_data"}, eltwise_mode::prod, p.default_type),
+ reorder("out", "eltwise", p.default_format, data_types::f32));
+
+ tolerance = 1e-5f;
+ execute(p);
+}
+
+TEST_P(scale_basic, bias_act_eltwise) {  // topology: scale with bias -> negative activation -> eltwise prod with per-channel data -> reorder to f32
+ auto p = GetParam();
+ create_topologies(input_layout("input", get_input_layout(p)),
+ data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
+ data("bias_data", get_mem(get_per_channel_layout(p), -10, 10)),
+ scale("scale", "input", "scale_data", "bias_data"),  // bias is passed as the extra fourth argument
+ activation("activation", "scale", activation_func::negative),
+ data("eltwise_data", get_mem(get_per_channel_layout(p), -10, 10)),
+ eltwise("eltwise", {"activation", "eltwise_data"}, eltwise_mode::prod, p.default_type),
+ reorder("out", "eltwise", p.default_format, data_types::f32));
+
+ tolerance = 1e-5f;
+ execute(p);
+}
+
+TEST_P(scale_basic, bias_act_scale) {  // topology: scale with bias -> negative activation -> second scale (no bias) -> reorder to f32
+ auto p = GetParam();
+ create_topologies(input_layout("input", get_input_layout(p)),
+ data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
+ data("bias_data", get_mem(get_per_channel_layout(p), -10, 10)),
+ scale("scale", "input", "scale_data", "bias_data"),
+ activation("activation", "scale", activation_func::negative),
+ data("scale_data2", get_mem(get_per_channel_layout(p), -10, 10)),
+ scale("scale2", "activation", "scale_data2"),
+ reorder("out", "scale2", p.default_format, data_types::f32));
+
+ tolerance = 1e-5f;
+ execute(p);
+}
+
+TEST_P(scale_basic, bias_act_quantize) {  // topology: scale with bias -> negative activation -> quantize to i8 -> reorder to f32
+ auto p = GetParam();
+ create_topologies(input_layout("input", get_input_layout(p)),
+ data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
+ data("bias_data", get_mem(get_per_channel_layout(p), -10, 10)),
+ scale("scale", "input", "scale_data", "bias_data"),
+ activation("activation", "scale", activation_func::negative),
+ data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),  // per-channel input range bounds
+ data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
+ data("out_lo", get_mem(get_single_element_layout(p), -128)),  // single-element output range bounds: -128 and 127
+ data("out_hi", get_mem(get_single_element_layout(p), 127)),
+ quantize("quantize", "activation", "in_lo", "in_hi", "out_lo", "out_hi", 256, data_types::i8),
+ reorder("out", "quantize", p.default_format, data_types::f32));
+
+ tolerance = 1.f;  // much looser than the 1e-5f used by the other scale_basic tests
+ execute(p);
+}
+
+INSTANTIATE_TEST_CASE_P(fusings_gpu,
+ scale_basic,
+ ::testing::ValuesIn(std::vector<scale_test_params>{
+ scale_test_params{CASE_SCALE_FP32_1, 2, 4},  // expected_fused_primitives = 2, expected_not_fused_primitives = 4
+ scale_test_params{CASE_SCALE_FP32_2, 2, 4},
+ scale_test_params{CASE_SCALE_FP32_3, 2, 4},
}), );
class eltwise_no_pitches_same_dims_quantize : public EltwiseFusingTest {};
INSTANTIATE_TEST_CASE_P(fusings_gpu,
eltwise_no_pitches_same_dims_quantize,
::testing::ValuesIn(std::vector<eltwise_test_params>{
- eltwise_test_params{CASE_ELTWISE_FP16_1},
- eltwise_test_params{CASE_ELTWISE_FP16_2},
- eltwise_test_params{CASE_ELTWISE_FP16_3},
- eltwise_test_params{CASE_ELTWISE_FP32_1},
- eltwise_test_params{CASE_ELTWISE_FP32_2},
- eltwise_test_params{CASE_ELTWISE_FP32_3},
+ eltwise_test_params{CASE_ELTWISE_FP16_1, 3, 4},  // counts are passed explicitly here because the CASE_ELTWISE_* macros no longer carry them
+ eltwise_test_params{CASE_ELTWISE_FP16_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP16_3, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_1, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_3, 3, 4},
}), );
class eltwise_activation : public EltwiseFusingTest {};
INSTANTIATE_TEST_CASE_P(fusings_gpu,
eltwise_activation,
::testing::ValuesIn(std::vector<eltwise_test_params>{
- eltwise_test_params{CASE_ELTWISE_FP16_1},
- eltwise_test_params{CASE_ELTWISE_FP16_2},
- eltwise_test_params{CASE_ELTWISE_FP16_3},
- eltwise_test_params{CASE_ELTWISE_FP32_1},
- eltwise_test_params{CASE_ELTWISE_FP32_2},
- eltwise_test_params{CASE_ELTWISE_FP32_3},
- eltwise_test_params{CASE_ELTWISE_FP32_FP16_1},
- eltwise_test_params{CASE_ELTWISE_FP32_FP16_2},
- eltwise_test_params{CASE_ELTWISE_FP32_FP16_3},
- eltwise_test_params{CASE_ELTWISE_FP16_FP32_1},
- eltwise_test_params{CASE_ELTWISE_FP16_FP32_2},
- eltwise_test_params{CASE_ELTWISE_FP16_FP32_3}
+ eltwise_test_params{CASE_ELTWISE_FP16_1, 3, 4},  // counts are passed explicitly here because the CASE_ELTWISE_* macros no longer carry them
+ eltwise_test_params{CASE_ELTWISE_FP16_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP16_3, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_1, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_3, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_FP16_1, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_FP16_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP32_FP16_3, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP16_FP32_1, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP16_FP32_2, 3, 4},
+ eltwise_test_params{CASE_ELTWISE_FP16_FP32_3, 3, 4}
}), );
/* ----------------------------------------------------------------------------------------------------- */