Add mlgo to graph examples
authorSiCong Li <sicong.li@arm.com>
Wed, 3 Feb 2021 12:17:35 +0000 (12:17 +0000)
committerSiCong Li <sicong.li@arm.com>
Tue, 9 Feb 2021 16:02:44 +0000 (16:02 +0000)
Resolves COMPMID-3847

Change-Id: I99f73bfc8eda66e8ce1dd1f2a18be76e9d826569
Signed-off-by: SiCong Li <sicong.li@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5033
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>

31 files changed:
arm_compute/graph/Types.h
arm_compute/graph/backends/CL/CLDeviceBackend.h
examples/graph_alexnet.cpp
examples/graph_deepspeech_v0_4_1.cpp
examples/graph_edsr.cpp
examples/graph_googlenet.cpp
examples/graph_inception_resnet_v1.cpp
examples/graph_inception_resnet_v2.cpp
examples/graph_inception_v3.cpp
examples/graph_inception_v4.cpp
examples/graph_lenet.cpp
examples/graph_mnist.cpp
examples/graph_mobilenet.cpp
examples/graph_mobilenet_v2.cpp
examples/graph_resnet12.cpp
examples/graph_resnet50.cpp
examples/graph_resnet_v2_50.cpp
examples/graph_resnext50.cpp
examples/graph_shufflenet.cpp
examples/graph_squeezenet.cpp
examples/graph_squeezenet_v1_1.cpp
examples/graph_srcnn955.cpp
examples/graph_ssd_mobilenet.cpp
examples/graph_vgg16.cpp
examples/graph_vgg19.cpp
examples/graph_vgg_vdsr.cpp
examples/graph_yolov3.cpp
src/graph/backends/CL/CLDeviceBackend.cpp
tests/benchmark_examples/RunExample.cpp
utils/CommonGraphOptions.cpp
utils/CommonGraphOptions.h

index c5d3d17a9bbe3c13ae6952d0538f4212c4452faf..b891c1772fcc77f73419154c01491ae108586889 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -87,6 +87,7 @@ struct GraphConfig
     CLTunerMode tuner_mode{ CLTunerMode::EXHAUSTIVE }; /**< Tuner mode to be used by the CL tuner */
     int         num_threads{ -1 };                     /**< Number of threads to use (thread capable backends), if 0 the backend will auto-initialize, if -1 the backend will stay as it is. */
     std::string tuner_file{ "acl_tuner.csv" };         /**< File to load/store tuning values from */
+    std::string mlgo_file{ "heuristics.mlgo" };        /**< Filename to load MLGO heuristics from */
 };
 
 /**< Device target types */
index a8ee25d7e225f7b8982df6debc1240672c757367..82c0eacd113a0c025e195af1bc3e7efd6fa96fef 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,6 +27,7 @@
 #include "arm_compute/graph/IDeviceBackend.h"
 
 #include "arm_compute/runtime/CL/CLBufferAllocator.h"
+#include "arm_compute/runtime/CL/CLGEMMHeuristicsHandle.h"
 #include "arm_compute/runtime/CL/CLTuner.h"
 
 namespace arm_compute
@@ -70,10 +71,11 @@ public:
     std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() override;
 
 private:
-    int                                _context_count; /**< Counts how many contexts are currently using the backend */
-    CLTuner                            _tuner;         /**< CL kernel tuner */
-    std::unique_ptr<CLBufferAllocator> _allocator;     /**< CL buffer affinity allocator */
-    std::string                        _tuner_file;    /**< Filename to load/store the tuner's values from */
+    int                                _context_count;   /**< Counts how many contexts are currently using the backend */
+    CLTuner                            _tuner;           /**< CL kernel tuner */
+    CLGEMMHeuristicsHandle             _gemm_heuristics; /**< GEMM heuristics */
+    std::unique_ptr<CLBufferAllocator> _allocator;       /**< CL buffer affinity allocator */
+    std::string                        _tuner_file;      /**< Filename to load/store the tuner's values from */
 };
 } // namespace backends
 } // namespace graph
index ce398be6cf9f8c13f8448508cd98c917979a123d..7f4e75aaf87aae32aee3b3befa7dae6294680d83 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -159,6 +159,7 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         // Load the precompiled kernels from a file into the kernel library, in this way the next time they are needed
         // compilation won't be required.
index 4a8a8b15a93d0be434fa3db9ee95ba8b3086ff8c..a5658625c71f4e8dbb8cab40c58da2348bee0f34 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -211,6 +211,7 @@ public:
         config.num_threads      = common_params.threads;
         config.use_tuner        = common_params.enable_tuner;
         config.tuner_file       = common_params.tuner_file;
+        config.mlgo_file        = common_params.mlgo_file;
         config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
 
         graph.finalize(common_params.target, config);
index 77783d97ed1f8b84f9d0e6f62653a52e4d565951..0e41f1215561730c88f2b4353dd48bee05bd684b 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020 Arm Limited.
+ * Copyright (c) 2020-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -75,6 +75,7 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         context.set_config(config);
 
index 0a53355611ddef1f8089fe1cc0d9211f79c4562f..7555d805c1905285eb09d4810746371fdead42af 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -130,6 +130,7 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         graph.finalize(common_params.target, config);
 
index 7a55733a209c5d76889e750b44448b1d916b4f94..6ae5b5dc774da4bbf61ec6c01aa75fb90b4eb9b2 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -215,6 +215,7 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         graph.finalize(common_params.target, config);
 
index 60236d078038200a4e1693e0a055633eb044e22a..ae37ee507d12948c7fa1326cfb661b697d673aea 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -196,6 +196,7 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         graph.finalize(common_params.target, config);
 
index 5cacbcb6e19709fb12bac3233c139d06e32fd8c1..8ceeb5c68ef6718873f92d238bcc7c07b0a08fbf 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -201,6 +201,7 @@ public:
         config.use_tuner        = common_params.enable_tuner;
         config.tuner_mode       = common_params.tuner_mode;
         config.tuner_file       = common_params.tuner_file;
+        config.mlgo_file        = common_params.mlgo_file;
         config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
 
         graph.finalize(common_params.target, config);
index db2a31047e590c0c7122a5e1a9b83a1cbfd2e023..cafa5c9f1023adc1dbc00cbe09fd849ae452b609 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -156,6 +156,7 @@ public:
         config.use_tuner        = common_params.enable_tuner;
         config.tuner_mode       = common_params.tuner_mode;
         config.tuner_file       = common_params.tuner_file;
+        config.mlgo_file        = common_params.mlgo_file;
         config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
 
         // Load the precompiled kernels from a file into the kernel library, in this way the next time they are needed
index e5783078f1b43a1aa6af359e8efa635262b7f64a..6560a980cc861bbdc327d4668b66b84a506d001c 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -111,6 +111,7 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         graph.finalize(common_params.target, config);
 
index 85ab0ab97282625c247e813e9da27c630867970c..4ef96cc596bb3a36ec5789b4306300f08481eecd 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -140,6 +140,7 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         graph.finalize(common_params.target, config);
 
index b73f7a2abdb423f0e8f9b608331c6afdc9a93c85..09b6e6e097d8e9f847526ed65520bf5b735e39f1 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -100,6 +100,8 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         graph.finalize(common_params.target, config);
 
index fa16c94645c7116147a417813b9e26c6e9e527f3..b1b33be2f54dbf49b93dac80e84edc7f379185e5 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -91,6 +91,7 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         graph.finalize(common_params.target, config);
 
index ebd2e5dd16ec629fbd5e2a22293abec8dc29f354..8818cf742ad913b8d081363b844b05d7261b4464 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -136,6 +136,7 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         graph.finalize(common_params.target, config);
 
index 47d258ede706c8634cc542e5e75bf1f834e63c12..b585284c60dfd49536860153004ac270388f1171 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -115,6 +115,7 @@ public:
         config.use_tuner        = common_params.enable_tuner;
         config.tuner_mode       = common_params.tuner_mode;
         config.tuner_file       = common_params.tuner_file;
+        config.mlgo_file        = common_params.mlgo_file;
         config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
 
         graph.finalize(common_params.target, config);
index 921fb145d6d6c41a491f4ecd89e2e563d253a543..472bf02b47ecf7d4b3c54de196fa71fb86104c73 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -118,6 +118,7 @@ public:
         config.use_tuner        = common_params.enable_tuner;
         config.tuner_mode       = common_params.tuner_mode;
         config.tuner_file       = common_params.tuner_file;
+        config.mlgo_file        = common_params.mlgo_file;
         config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
 
         graph.finalize(common_params.target, config);
index 1d9ed8dc890e3491e535cd5331f762e21ac27cbe..ec87e0b882f9521de4c31e8e75c77cd74753f165 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -102,6 +102,7 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         graph.finalize(common_params.target, config);
 
index 300d0f15a1a987a9ff0aefded153218f1ac16522..f90f36149d03451cc3861ac6a840846c03a539ca 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -148,6 +148,7 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         graph.finalize(common_params.target, config);
 
index 2e72c14763e3f56444de6a58e98420211fc003a7..3d32794e8d9fcd47de4b46e7163ed66263f06708 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -168,6 +168,7 @@ public:
         config.use_tuner        = common_params.enable_tuner;
         config.tuner_mode       = common_params.tuner_mode;
         config.tuner_file       = common_params.tuner_file;
+        config.mlgo_file        = common_params.mlgo_file;
         config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
 
         graph.finalize(common_params.target, config);
index 1708ac2f5afa4a3f361dac3587a14ad34af83c5e..6d4ffee994e702bd232401cc370efcba6496bcaf 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -168,6 +168,7 @@ public:
         config.use_tuner        = common_params.enable_tuner;
         config.tuner_mode       = common_params.tuner_mode;
         config.tuner_file       = common_params.tuner_file;
+        config.mlgo_file        = common_params.mlgo_file;
         config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
 
         graph.finalize(common_params.target, config);
index bcc3824c602b62f25aa1266d62474d79d19742bc..f4ffc021305cd8405eb93a410c842e7e99006d3b 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -119,6 +119,7 @@ public:
         config.use_tuner        = common_params.enable_tuner;
         config.tuner_mode       = common_params.tuner_mode;
         config.tuner_file       = common_params.tuner_file;
+        config.mlgo_file        = common_params.mlgo_file;
         config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
 
         graph.finalize(common_params.target, config);
index f5af84f4d48b9048b26bd6bf69319a65c1d11f0b..c0859227ab14d57c35debadc305678151caf9ddb 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -97,6 +97,7 @@ public:
         config.num_threads = common_params.threads;
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         graph.finalize(common_params.target, config);
 
index a4c5e6bbd22a1dacfd271a0ccaade1539bd014dd..83e663798bb88226e923d303a9a673f51b8b99e4 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -216,6 +216,7 @@ public:
         config.use_tuner        = common_params.enable_tuner;
         config.tuner_mode       = common_params.tuner_mode;
         config.tuner_file       = common_params.tuner_file;
+        config.mlgo_file        = common_params.mlgo_file;
         config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
 
         graph.finalize(common_params.target, config);
index c95fb0336858f3227ff345ff37410908474ace02..03f7e1606cb1f8a015c3c3e71fc39ec0f5c4669b 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -227,6 +227,7 @@ public:
         config.use_tuner        = common_params.enable_tuner;
         config.tuner_mode       = common_params.tuner_mode;
         config.tuner_file       = common_params.tuner_file;
+        config.mlgo_file        = common_params.mlgo_file;
         config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
 
         graph.finalize(common_params.target, config);
index 3fa7dd13300cf067ed8040d0f62cd3717eeb158e..bdb898081dc14a94dc30cb45f0990d437934f202 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -140,6 +140,7 @@ public:
         config.use_tuner        = common_params.enable_tuner;
         config.tuner_mode       = common_params.tuner_mode;
         config.tuner_file       = common_params.tuner_file;
+        config.mlgo_file        = common_params.mlgo_file;
         config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
 
         graph.finalize(common_params.target, config);
index 54aaf201cb2e7598e1591df715c88a35d149dd8e..3c8ddbffd812b8d2b09ca8240aecbdb14d3513d5 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -401,6 +401,7 @@ public:
         config.use_tuner   = common_params.enable_tuner;
         config.tuner_mode  = common_params.tuner_mode;
         config.tuner_file  = common_params.tuner_file;
+        config.mlgo_file   = common_params.mlgo_file;
 
         graph.finalize(common_params.target, config);
 
index 50dd799ee19067ed0badbc4a1806d63922b96f68..f8e22ca7a071767daefcbbb0f1584951876cb876 100644 (file)
@@ -65,7 +65,7 @@ bool file_exists(const std::string &filename)
 static detail::BackendRegistrar<CLDeviceBackend> CLDeviceBackend_registrar(Target::CL);
 
 CLDeviceBackend::CLDeviceBackend()
-    : _context_count(0), _tuner(), _allocator(nullptr), _tuner_file()
+    : _context_count(0), _tuner(), _gemm_heuristics(), _allocator(nullptr), _tuner_file()
 {
 }
 
@@ -87,7 +87,7 @@ void CLDeviceBackend::set_kernel_tuning_mode(CLTunerMode tuning_mode)
 void CLDeviceBackend::initialize_backend()
 {
     // Setup Scheduler
-    CLScheduler::get().default_init(&_tuner);
+    CLScheduler::get().default_init(&_tuner, &_gemm_heuristics);
     // Create allocator with new context
     _allocator = std::make_unique<CLBufferAllocator>(nullptr /* legacy path for CLCoreRuntimeContext */);
 }
@@ -123,6 +123,10 @@ void CLDeviceBackend::setup_backend_context(GraphContext &ctx)
     set_kernel_tuning(ctx.config().use_tuner);
     set_kernel_tuning_mode(ctx.config().tuner_mode);
 
+    // Attempt to load mlgo heuristics
+    ARM_COMPUTE_ERROR_ON(CLScheduler::get().gemm_heuristics() == nullptr);
+    CLScheduler::get().gemm_heuristics()->reload_from_file(ctx.config().mlgo_file);
+
     // Setup a management backend
     if(ctx.memory_management_ctx(Target::CL) == nullptr)
     {
index 925daaf1561d59b5d0c7f94c14818c19d00667f3..8adcd95ff6928759cf6929ff5ce17a1076978a04 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -34,6 +34,7 @@
 #include "utils/command_line/CommandLineParser.h"
 
 #ifdef ARM_COMPUTE_CL
+#include "arm_compute/runtime/CL/CLGEMMHeuristicsHandle.h"
 #include "arm_compute/runtime/CL/CLHelpers.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
 #endif /* ARM_COMPUTE_CL */
@@ -127,12 +128,13 @@ int run_example(int argc, char **argv, std::unique_ptr<Example> example)
     }
 
 #ifdef ARM_COMPUTE_CL
+    CLGEMMHeuristicsHandle gemm_h;
     if(opencl_is_available())
     {
         auto ctx_dev_err = create_opencl_context_and_device();
         ARM_COMPUTE_ERROR_ON_MSG(std::get<2>(ctx_dev_err) != CL_SUCCESS, "Failed to create OpenCL context");
         CLScheduler::get()
-        .default_init_with_context(std::get<1>(ctx_dev_err), std::get<0>(ctx_dev_err));
+        .default_init_with_context(std::get<1>(ctx_dev_err), std::get<0>(ctx_dev_err), nullptr, &gemm_h);
     }
 #endif /* ARM_COMPUTE_CL */
 
index d262ea86e930e2ea5f957563b3fecd6dbb73d210..44d66fa91b2719e14765188a93e0f5caef9be1b7 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -87,6 +87,7 @@ namespace utils
     os << "Cache enabled? : " << (common_params.enable_cl_cache ? true_str : false_str) << std::endl;
     os << "Tuner mode : " << common_params.tuner_mode << std::endl;
     os << "Tuner file : " << common_params.tuner_file << std::endl;
+    os << "MLGO file : " << common_params.mlgo_file << std::endl;
     os << "Fast math enabled? : " << (common_params.fast_math_hint == FastMathHint::Enabled ? true_str : false_str) << std::endl;
     if(!common_params.data_path.empty())
     {
@@ -129,7 +130,8 @@ CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser)
       validation_file(parser.add_option<SimpleOption<std::string>>("validation-file")),
       validation_path(parser.add_option<SimpleOption<std::string>>("validation-path")),
       validation_range(parser.add_option<SimpleOption<std::string>>("validation-range")),
-      tuner_file(parser.add_option<SimpleOption<std::string>>("tuner-file"))
+      tuner_file(parser.add_option<SimpleOption<std::string>>("tuner-file")),
+      mlgo_file(parser.add_option<SimpleOption<std::string>>("mlgo-file"))
 {
     std::set<arm_compute::graph::Target> supported_targets
     {
@@ -183,6 +185,7 @@ CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser)
     validation_path->set_help("Path to the validation data");
     validation_range->set_help("Range of the images to validate for (Format : start,end)");
     tuner_file->set_help("File to load/save CLTuner values");
+    mlgo_file->set_help("File to load MLGO heuristics");
 }
 
 CommonGraphParams consume_common_graph_parameters(CommonGraphOptions &options)
@@ -211,6 +214,7 @@ CommonGraphParams consume_common_graph_parameters(CommonGraphOptions &options)
     common_params.validation_range_start = validation_range.first;
     common_params.validation_range_end   = validation_range.second;
     common_params.tuner_file             = options.tuner_file->value();
+    common_params.mlgo_file              = options.mlgo_file->value();
 
     return common_params;
 }
index dac2e10b193d7418403e76e347a35c0d69e6b1a1..13cd653e46c96e3e07e40664ea91f044ae665240 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -108,6 +108,7 @@ struct CommonGraphParams
     std::string                      validation_file{};
     std::string                      validation_path{};
     std::string                      tuner_file{};
+    std::string                      mlgo_file{};
     unsigned int                     validation_range_start{ 0 };
     unsigned int                     validation_range_end{ std::numeric_limits<unsigned int>::max() };
 };
@@ -165,6 +166,7 @@ public:
     SimpleOption<std::string>              *validation_path;  /**< Validation data path */
     SimpleOption<std::string>              *validation_range; /**< Validation range */
     SimpleOption<std::string>              *tuner_file;       /**< File to load/store the tuner's values from */
+    SimpleOption<std::string>              *mlgo_file;        /**< File to load the MLGO heuristics from */
 };
 
 /** Consumes the common graph options and creates a structure containing any information