COMPMID-4018: Fix heuristic fallback for CLGEMMReshapedRHSOnly for
authorGian Marco Iodice <gianmarco.iodice@arm.com>
Mon, 23 Nov 2020 16:10:27 +0000 (16:10 +0000)
committerPablo Marquez Tello <pablo.tello@arm.com>
Mon, 23 Nov 2020 19:27:24 +0000 (19:27 +0000)
Mali-G52

- Missing fallback in case of export to cl_image

Change-Id: I5bb3013fd1350628f16e4709c4bb31999fece22d
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4531
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
(cherry picked from commit 8919a1a849e425aefcd09c5db5f6f9f2e403d4e9)
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4170
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp

index 59a2a82edf717a5b660f9761e5b86ccfb1bf5d29..46eeff3524c9a58a69083f54c155dbe0bdccf74c 100644 (file)
@@ -269,13 +269,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
 
     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
 
-    if(workload <= 232.8000f)
+    if(workload <= 323.4000f)
     {
-        return configure_lhs_rhs_info(m, n, 2, 4, 4, 4, 4, true, true, true, false, false);
+        return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8, false, false, false, true, false);
     }
     else
     {
-        return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, true, true, true, false, false);
+        return configure_lhs_rhs_info(m, n, 4, 8, 4, 2, 2, true, true, true, false, false);
     }
 }
 
index a2c1ed2c8e80e3b749ee0028c1a15d33dc5a04bd..d5b76d8eafdd2191cc096395f56ad0be609938c9 100644 (file)
@@ -322,8 +322,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
     const float r_mk = static_cast<float>(m) / static_cast<float>(k);
     const float r_nk = static_cast<float>(n) / static_cast<float>(k);
 
+    GEMMLHSMatrixInfo lhs_info_buf;
+    GEMMRHSMatrixInfo rhs_info_buf;
+    GEMMLHSMatrixInfo lhs_info_img;
+    GEMMRHSMatrixInfo rhs_info_img;
+
     if(m == 1)
     {
+        std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false);
+
         if(r_mk <= 0.0026f)
         {
             if(r_nk <= 0.4664f)
@@ -332,7 +339,10 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
             }
             else
             {
-                return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
+                std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
+                return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                           std::make_pair(lhs_info_buf, rhs_info_buf),
+                                           n, k, b, DataType::F16);
             }
         }
         else
@@ -343,12 +353,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
             }
             else
             {
-                return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
+                std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
+                return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                           std::make_pair(lhs_info_buf, rhs_info_buf),
+                                           n, k, b, DataType::F16);
             }
         }
     }
     else
     {
+        std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false);
+
         if(workload <= 362.6000f)
         {
             return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
@@ -359,7 +374,10 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
             {
                 if(workload <= 708.8000f)
                 {
-                    return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
+                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
+                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                               std::make_pair(lhs_info_buf, rhs_info_buf),
+                                               n, k, b, DataType::F16);
                 }
                 else
                 {
@@ -374,7 +392,10 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
                 }
                 else
                 {
-                    return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
+                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
+                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                               std::make_pair(lhs_info_buf, rhs_info_buf),
+                                               n, k, b, DataType::F16);
                 }
             }
         }
index 46d07fffba3810eeecbb563e264719c19383b8a7..0bda38e5e98378917a9a1bb50ba9615e4e94c24d 100644 (file)
@@ -445,8 +445,6 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f16(unsigned int m, unsigned
 
 CLGEMMKernelType CLGEMMKernelSelectionBifrost::g52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
 {
-    ARM_COMPUTE_UNUSED(b);
-
     if (!is_rhs_constant)
     {
         return CLGEMMKernelType::NATIVE_V1;
@@ -457,26 +455,25 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::g52_f16(unsigned int m, unsigned
         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
     }
 
-    const float r_mn  = static_cast<float>(m) / static_cast<float>(n);
-    const float r_mk  = static_cast<float>(m) / static_cast<float>(k);
-    const float r_nk  = static_cast<float>(n) / static_cast<float>(k);
-    const float r_mnk = static_cast<float>(m) / (static_cast<float>(n) * static_cast<float>(k));
-
-    if(r_mn <= 22.9200f)
+    if(n <= 127.0000f)
     {
-        if(r_mk <= 0.0157f)
+        if(n <= 63.5000f)
         {
             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
         }
         else
         {
-            if(r_mnk <= 7809.3750f)
+            if(m <= 3616.0000f)
             {
-                if(r_mnk <= 101.7937f)
+                if(b <= 18.5000f)
                 {
-                    if(r_mn <= 0.4594f)
+                    if(m <= 2970.5000f)
                     {
-                        if(r_mk <= 0.0557f)
+                        return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+                    }
+                    else
+                    {
+                        if(k <= 104.0000f)
                         {
                             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
                         }
@@ -485,80 +482,76 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::g52_f16(unsigned int m, unsigned
                             return CLGEMMKernelType::RESHAPED;
                         }
                     }
-                    else
-                    {
-                        return CLGEMMKernelType::RESHAPED_ONLY_RHS;
-                    }
                 }
                 else
                 {
-                    if(r_nk <= 0.4396f)
+                    return CLGEMMKernelType::RESHAPED;
+                }
+            }
+            else
+            {
+                return CLGEMMKernelType::RESHAPED;
+            }
+        }
+    }
+    else
+    {
+        if(m <= 12.5000f)
+        {
+            return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+        }
+        else
+        {
+            if(k <= 104.0000f)
+            {
+                if(b <= 18.5000f)
+                {
+                    if(m <= 490.0000f)
                     {
-                        if(r_mn <= 1.5182f)
+                        if(n <= 272.0000f)
                         {
-                            if(r_mnk <= 1709.9167f)
-                            {
-                                return CLGEMMKernelType::RESHAPED;
-                            }
-                            else
-                            {
-                                return CLGEMMKernelType::RESHAPED_ONLY_RHS;
-                            }
+                            return CLGEMMKernelType::RESHAPED_ONLY_RHS;
                         }
                         else
                         {
-                            if(r_mnk <= 1330.6000f)
-                            {
-                                return CLGEMMKernelType::RESHAPED_ONLY_RHS;
-                            }
-                            else
-                            {
-                                return CLGEMMKernelType::RESHAPED;
-                            }
+                            return CLGEMMKernelType::RESHAPED;
                         }
                     }
                     else
                     {
-                        if(r_mn <= 2.5896f)
+                        return CLGEMMKernelType::RESHAPED;
+                    }
+                }
+                else
+                {
+                    return CLGEMMKernelType::RESHAPED;
+                }
+            }
+            else
+            {
+                if(m <= 226.0000f)
+                {
+                    if(n <= 140.0000f)
+                    {
+                        if(m <= 179.5000f)
                         {
                             return CLGEMMKernelType::RESHAPED;
                         }
                         else
                         {
-                            if(r_mnk <= 326.6667f)
-                            {
-                                return CLGEMMKernelType::RESHAPED_ONLY_RHS;
-                            }
-                            else
-                            {
-                                return CLGEMMKernelType::RESHAPED;
-                            }
+                            return CLGEMMKernelType::RESHAPED_ONLY_RHS;
                         }
                     }
+                    else
+                    {
+                        return CLGEMMKernelType::RESHAPED;
+                    }
+                }
+                else
+                {
+                    return CLGEMMKernelType::RESHAPED;
                 }
             }
-            else
-            {
-                return CLGEMMKernelType::RESHAPED_ONLY_RHS;
-            }
-        }
-    }
-    else
-    {
-        if(r_mn <= 86.7578f)
-        {
-            if(r_mnk <= 11231.6406f)
-            {
-                return CLGEMMKernelType::RESHAPED_ONLY_RHS;
-            }
-            else
-            {
-                return CLGEMMKernelType::RESHAPED;
-            }
-        }
-        else
-        {
-            return CLGEMMKernelType::RESHAPED_ONLY_RHS;
         }
     }
 }