COMPMID-3981: Fix missing fallback for export to cl_image
author Gian Marco Iodice <gianmarco.iodice@arm.com>
Wed, 18 Nov 2020 14:56:09 +0000 (14:56 +0000)
committer Pablo Marquez Tello <pablo.tello@arm.com>
Thu, 19 Nov 2020 15:49:23 +0000 (15:49 +0000)
- Fix missing fallback in the CLGEMMReshaped heuristic on Mali-G77

Change-Id: I0a243c7ed153216966d0809a3b3348f030a845eb
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4463
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
(cherry picked from commit 8b84aaa4db093ac08efa96c2cbf766e800465529)
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4162
Reviewed-by: SiCong Li <sicong.li@arm.com>
src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp

index 3f82dcab00ba6af286b699169f545c23bad8ec56..4fd446f64752dfc51f31f9f8e006525b3df9739d 100644 (file)
@@ -95,6 +95,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
     const float r_mk = static_cast<float>(m) / static_cast<float>(k);
     const float r_nk = static_cast<float>(n) / static_cast<float>(k);
 
+    GEMMLHSMatrixInfo lhs_info_buf;
+    GEMMRHSMatrixInfo rhs_info_buf;
+    GEMMLHSMatrixInfo lhs_info_img;
+    GEMMRHSMatrixInfo rhs_info_img;
+
+    std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, false, false, true, false, false);
+
     if(r_mk <= 0.11824845522642136)
     {
         if(workload <= 880.0)
@@ -111,7 +118,11 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
                 }
                 else
                 {
-                    return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true);
+                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true);
+
+                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                               std::make_pair(lhs_info_buf, rhs_info_buf),
+                                               n, k, b, DataType::F16);
                 }
             }
             else
@@ -135,7 +146,11 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
             {
                 if(r_mn <= 2.545312523841858)
                 {
-                    return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true);
+                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true);
+
+                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                               std::make_pair(lhs_info_buf, rhs_info_buf),
+                                               n, k, b, DataType::F16);
                 }
                 else
                 {
@@ -146,11 +161,19 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
             {
                 if(workload <= 2881.199951171875)
                 {
-                    return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, false, false, true, false, true);
+                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, false, false, true, false, true);
+
+                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                               std::make_pair(lhs_info_buf, rhs_info_buf),
+                                               n, k, b, DataType::F16);
                 }
                 else
                 {
-                    return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true);
+                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true);
+
+                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                               std::make_pair(lhs_info_buf, rhs_info_buf),
+                                               n, k, b, DataType::F16);
                 }
             }
         }
@@ -160,16 +183,28 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
             {
                 if(r_mn <= 6.010416746139526)
                 {
-                    return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true);
+                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true);
+
+                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                               std::make_pair(lhs_info_buf, rhs_info_buf),
+                                               n, k, b, DataType::F16);
                 }
                 else
                 {
-                    return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, false, true, false, true);
+                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, false, true, false, true);
+
+                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                               std::make_pair(lhs_info_buf, rhs_info_buf),
+                                               n, k, b, DataType::F16);
                 }
             }
             else
             {
-                return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, false, true, false, true);
+                std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, false, true, false, true);
+
+                return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                            std::make_pair(lhs_info_buf, rhs_info_buf),
+                                            n, k, b, DataType::F16);
             }
         }
     }