COMPMID-3959: Update Mali-G52 heuristic for CLGEMM - F32
author Gian Marco Iodice <gianmarco.iodice@arm.com>
Tue, 10 Nov 2020 10:41:37 +0000 (10:41 +0000)
committer Gian Marco Iodice <gianmarco.iodice@arm.com>
Fri, 13 Nov 2020 11:52:32 +0000 (11:52 +0000)
- Add heuristic in CLGEMMKernelSelection
- Add heuristic in CLGEMMReshapedOnlyRHS
- Add heuristic in CLGEMMReshaped

Change-Id: Ibaa13398f7a5976418a0ab1b6696ace09cc480fa
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4366
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>

src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h
src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h
src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp
src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h

index c1ca187a70c153b6feec65a54809fab3eaedf53b..70992974a3a88e289c4636fccc6b2e746d9f5980 100644 (file)
@@ -60,6 +60,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }
     };
 
+    // Configurations for Mali-G52: maps each data type to the member function returning its GEMMLHSMatrixInfo/GEMMRHSMatrixInfo pair
+    static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G52 =
+    {
+        { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G52_f32 }, // only F32 gets a G52-specific heuristic
+        { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 }, // F16 reuses the G7x heuristic
+        { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }, // the four 8-bit quantized types all reuse the G7x u8 heuristic
+        { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
+        { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 },
+        { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }
+    }; // NOTE(review): no hunk shown for this file adds a `case GPUTarget::G52` - confirm this table is actually dispatched on
+
     // Configurations for Mali-G7x
     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
     {
@@ -153,6 +164,105 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
     }
 }
 
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{ // Mali-G52 F32 heuristic: a decision tree over ratios of m, n, k, b that returns one LHS/RHS matrix-info pair
+    const float r_mn     = static_cast<float>(m) / static_cast<float>(n); // ratio m / n
+    const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; // m * n * b divided by 20; NOTE(review): the meaning of the 20 is not stated here
+    const float r_mk     = static_cast<float>(m) / static_cast<float>(k); // ratio m / k
+    const float r_nk     = static_cast<float>(n) / static_cast<float>(k); // ratio n / k
+
+    GEMMLHSMatrixInfo lhs_info_buf; // the *_buf and *_img pairs hold the two candidates passed to select_lhs_rhs_info()
+    GEMMRHSMatrixInfo rhs_info_buf;
+    GEMMLHSMatrixInfo lhs_info_img;
+    GEMMRHSMatrixInfo rhs_info_img;
+
+    if(workload <= 274.4000f) // NOTE(review): thresholds in this tree look machine-tuned; their origin is not documented here
+    {
+        if(r_nk <= 0.7461f)
+        {
+            if(r_mn <= 21.1667f)
+            {
+                return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4, false, true, true, false, false); // single fixed configuration, no img/buf selection
+            }
+            else
+            {
+                std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true); // candidates differ only in the last argument (presumably the OpenCL-image export flag - confirm)
+                std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
+
+                return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), // the helper picks one of the two candidates using n, k, b and the data type
+                                           std::make_pair(lhs_info_buf, rhs_info_buf),
+                                           n, k, b, DataType::F32);
+            }
+        }
+        else
+        {
+            std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true); // same candidate pair as the r_mn > 21.1667 leaf above
+            std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
+
+            return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                       std::make_pair(lhs_info_buf, rhs_info_buf),
+                                       n, k, b, DataType::F32);
+        }
+    }
+    else
+    {
+        if(r_mk <= 17.3926f)
+        {
+            if(workload <= 542.4000f)
+            {
+                std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
+                std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
+
+                return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                           std::make_pair(lhs_info_buf, rhs_info_buf),
+                                           n, k, b, DataType::F32);
+            }
+            else
+            {
+                std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true); // second candidate family: (..., 2, 1, ...) instead of (..., 4, 2, ...)
+                std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
+
+                return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                           std::make_pair(lhs_info_buf, rhs_info_buf),
+                                           n, k, b, DataType::F32);
+            }
+        }
+        else
+        {
+            if(r_nk <= 0.5463f)
+            {
+                if(workload <= 11767.6001f)
+                {
+                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
+                    std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
+
+                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                               std::make_pair(lhs_info_buf, rhs_info_buf),
+                                               n, k, b, DataType::F32);
+                }
+                else
+                {
+                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
+                    std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
+
+                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                               std::make_pair(lhs_info_buf, rhs_info_buf),
+                                               n, k, b, DataType::F32);
+                }
+            }
+            else
+            {
+                std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
+                std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
+
+                return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                           std::make_pair(lhs_info_buf, rhs_info_buf),
+                                           n, k, b, DataType::F32);
+            }
+        }
+    }
+}
+
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
 {
     ARM_COMPUTE_UNUSED(k);
index e3b62ced6a754124b4251afd75722020320b6ac0..6c67b70962545bc1387510019db57eac164622fe 100644 (file)
@@ -45,6 +45,7 @@ public:
 
 private:
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+    std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
index 3105db6693553f98acbc4770e96434f097c6912a..188ba4d9c513d411435d9f13faa03fcb88f5893a 100644 (file)
@@ -61,6 +61,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 }
     };
 
+    // Configurations for Mali-G52: maps each data type to the member function returning its GEMMLHSMatrixInfo/GEMMRHSMatrixInfo pair
+    static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G52 =
+    {
+        { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f32 }, // only F32 gets a G52-specific heuristic
+        { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f16 }, // F16 reuses the G7x heuristic
+        { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, // the four 8-bit quantized types all reuse the G7x u8 heuristic
+        { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
+        { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
+        { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }
+    };
+
     // Configurations for Mali-G76
     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
     {
@@ -94,6 +105,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
             {
                 ARM_COMPUTE_ERROR("Not supported data type");
             }
+        case GPUTarget::G52: // Mali-G52: dispatch through the G52 table for the requested data type
+            if(gemm_configs_G52.find(data_type) != gemm_configs_G52.end())
+            {
+                return (this->*gemm_configs_G52[data_type])(m, n, k, b); // invoke the member function registered for data_type
+            }
+            else
+            {
+                ARM_COMPUTE_ERROR("Not supported data type"); // data_type has no entry in the G52 table
+            }
         case GPUTarget::G51:
             if(gemm_configs_G51.find(data_type) != gemm_configs_G51.end())
             {
@@ -201,6 +221,50 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
     }
 }
 
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{ // Mali-G52 F32 heuristic for the reshaped-only-RHS configuration: returns one LHS/RHS matrix-info pair chosen from m, n, k, b
+    const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; // m * n * b divided by 20; NOTE(review): the meaning of the 20 is not stated here
+    const float r_nk     = static_cast<float>(n) / static_cast<float>(k); // ratio n / k
+
+    GEMMLHSMatrixInfo lhs_info_buf; // the *_buf and *_img pairs hold the two candidates passed to select_lhs_rhs_info()
+    GEMMRHSMatrixInfo rhs_info_buf;
+    GEMMLHSMatrixInfo lhs_info_img;
+    GEMMRHSMatrixInfo rhs_info_img;
+
+    if(m == 1) // single-row case gets its own sub-tree, keyed on r_nk only
+    {
+        if(r_nk <= 0.4664f)
+        {
+            return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false); // single fixed configuration, no img/buf selection
+        }
+        else
+        {
+            std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true); // candidates differ only in the last argument (presumably the OpenCL-image export flag - confirm)
+            std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false);
+
+            return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), // the helper picks one of the two candidates using n, k, b and the data type
+                                       std::make_pair(lhs_info_buf, rhs_info_buf),
+                                       n, k, b, DataType::F32);
+        }
+    }
+    else
+    {
+        if(workload <= 274.4000f) // for m != 1 the choice depends on workload only
+        {
+            return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false); // single fixed configuration, no img/buf selection
+        }
+        else
+        {
+            std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true);
+            std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false);
+
+            return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                       std::make_pair(lhs_info_buf, rhs_info_buf),
+                                       n, k, b, DataType::F32);
+        }
+    }
+}
+
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
 {
     ARM_COMPUTE_UNUSED(k);
index 618dbd9923111fb2bc758927a84b08fe56187176..3dfd96a822538fcdb83f09db559c4eb3a8641dc6 100644 (file)
@@ -46,6 +46,7 @@ public:
 private:
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+    std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
index 7c6efe3f114835df8a9bac80d66c5eee5cb0abcb..c77746a04472c12986610a917c9a8fb616f11de2 100644 (file)
@@ -68,6 +68,17 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::select_kernel(const CLGEMMKernelS
         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMKernelSelectionBifrost::default_q8 }
     };
 
+    // Mali-G52 configurations: maps each data type to the member function that selects the CLGEMMKernelType
+    static std::map<DataType, FunctionExecutorPtr> gemm_g52_configs =
+    {
+        { DataType::F32, &CLGEMMKernelSelectionBifrost::g52_f32 }, // only F32 gets a G52-specific heuristic
+        { DataType::F16, &CLGEMMKernelSelectionBifrost::default_f16 }, // F16 uses the default F16 selection
+        { DataType::QASYMM8, &CLGEMMKernelSelectionBifrost::default_q8 }, // the four 8-bit quantized types all use the default q8 selection
+        { DataType::QASYMM8_SIGNED, &CLGEMMKernelSelectionBifrost::default_q8 },
+        { DataType::QSYMM8, &CLGEMMKernelSelectionBifrost::default_q8 },
+        { DataType::QSYMM8_PER_CHANNEL, &CLGEMMKernelSelectionBifrost::default_q8 }
+    };
+
     // Mali-G76 configurations
     static std::map<DataType, FunctionExecutorPtr> gemm_g76_configs =
     {
@@ -95,6 +106,12 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::select_kernel(const CLGEMMKernelS
                 return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
             }
             ARM_COMPUTE_ERROR("Not supported data type");
+        case GPUTarget::G52: // Mali-G52: dispatch through the G52 table for the requested data type
+            if(gemm_g52_configs.find(data_type) != gemm_g52_configs.end())
+            {
+                return (this->*gemm_g52_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); // invoke the member function registered for data_type
+            }
+            ARM_COMPUTE_ERROR("Not supported data type"); // reached only when data_type has no entry in the G52 table
         default:
             if(gemm_default_configs.find(data_type) != gemm_default_configs.end())
             {
@@ -237,6 +254,133 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f32(unsigned int m, unsigned
     }
 }
 
+CLGEMMKernelType CLGEMMKernelSelectionBifrost::g52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
+{ // Mali-G52 F32 kernel selection: returns NATIVE_V1, RESHAPED_ONLY_RHS or RESHAPED from m, n, k and is_rhs_constant
+    ARM_COMPUTE_UNUSED(b); // b takes no part in this heuristic
+
+    if (!is_rhs_constant) // a non-constant RHS always gets the native kernel
+    {
+        return CLGEMMKernelType::NATIVE_V1;
+    }
+
+    if (m == 1) // single-row case bypasses the ratio tree below
+    {
+        return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+    }
+
+    const float r_mn  = static_cast<float>(m) / static_cast<float>(n); // ratio m / n
+    const float r_mk  = static_cast<float>(m) / static_cast<float>(k); // ratio m / k
+    const float r_nk  = static_cast<float>(n) / static_cast<float>(k); // ratio n / k
+    const float r_mnk = static_cast<float>(m) / (static_cast<float>(n) * static_cast<float>(k)); // m / (n * k), i.e. r_mn / k; NOTE(review): for k >= 1 this never exceeds r_mn, so the r_mnk thresholds below (77.58 .. 9040.5) cannot be crossed inside branches bounded by r_mn <= 1.5469 / 17.7370 - confirm the intended formula
+
+    if(r_mn <= 1.5469f) // NOTE(review): thresholds in this tree look machine-tuned; their origin is not documented here
+    {
+        if(r_mk <= 0.8766f)
+        {
+            if(r_mk <= 0.0211f)
+            {
+                if(r_mnk <= 77.5833f) // NOTE(review): looks always true here (r_mnk <= r_mn <= 1.5469 for k >= 1)
+                {
+                    return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+                }
+                else
+                {
+                    return CLGEMMKernelType::RESHAPED;
+                }
+            }
+            else
+            {
+                if(r_nk <= 0.0832f)
+                {
+                    return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+                }
+                else
+                {
+                    return CLGEMMKernelType::RESHAPED;
+                }
+            }
+        }
+        else
+        {
+            if(r_mnk <= 193.0000f) // NOTE(review): looks always true here (r_mnk <= r_mn <= 1.5469 for k >= 1)
+            {
+                if(r_mn <= 0.9948f)
+                {
+                    if(r_mk <= 2.5453f)
+                    {
+                        return CLGEMMKernelType::RESHAPED;
+                    }
+                    else
+                    {
+                        return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+                    }
+                }
+                else
+                {
+                    return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+                }
+            }
+            else
+            {
+                return CLGEMMKernelType::RESHAPED;
+            }
+        }
+    }
+    else
+    {
+        if(r_mn <= 17.7370f)
+        {
+            if(r_mnk <= 1391.2875f) // NOTE(review): looks always true here (r_mnk <= r_mn <= 17.7370 for k >= 1), making the else branch unreachable
+            {
+                if(r_mk <= 2.9724f)
+                {
+                    return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+                }
+                else
+                {
+                    if(r_mnk <= 470.0000f) // NOTE(review): same bound applies, looks always true
+                    {
+                        return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+                    }
+                    else
+                    {
+                        return CLGEMMKernelType::RESHAPED;
+                    }
+                }
+            }
+            else
+            {
+                if(r_nk <= 0.1381f)
+                {
+                    if(r_mnk <= 9040.5000f)
+                    {
+                        return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+                    }
+                    else
+                    {
+                        return CLGEMMKernelType::RESHAPED;
+                    }
+                }
+                else
+                {
+                    if(r_mn <= 5.6790f)
+                    {
+                        return CLGEMMKernelType::RESHAPED;
+                    }
+                    else
+                    {
+                        return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+                    }
+                }
+            }
+        }
+        else
+        {
+            return CLGEMMKernelType::RESHAPED_ONLY_RHS; // r_mn > 17.7370 always selects the reshaped-only-RHS kernel
+        }
+    }
+}
+
 CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
 {
     ARM_COMPUTE_UNUSED(b);
index e3cc8e4a27962f903dec39b5055da04b8b9f7a74..fbafc531f5fda106dcac6b1d34b38da58e2ab698 100644 (file)
@@ -44,6 +44,7 @@ public:
     CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams &params) override;
 
 private:
+    CLGEMMKernelType g52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
     CLGEMMKernelType g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
     CLGEMMKernelType g76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
     CLGEMMKernelType g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);