Cache workspace size in the BenchmarkCache. (#15742)
authormwootton <michael.wootton@amd.com>
Tue, 8 Jan 2019 20:23:30 +0000 (12:23 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 8 Jan 2019 21:10:15 +0000 (13:10 -0800)
Summary:
Cache the workspace size information for MIOpen for a given configuration as opposed to inquiring it every time. This reduces overhead significantly as inquiring the workspace size forces a full read of the performance database in MIOpen and this database has grown significantly in recent releases. This caching gets us back to ideal performance.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15742

Differential Revision: D13598932

Pulled By: bddppq

fbshipit-source-id: 4e65d247b71dec828293cf0562aac3fbd4fad83a

aten/src/ATen/native/miopen/Conv_miopen.cpp

index 1635161..c04e100 100644 (file)
@@ -336,6 +336,10 @@ BenchmarkCache<miopenConvFwdAlgorithm_t> fwd_algos;
 BenchmarkCache<miopenConvBwdDataAlgorithm_t> bwd_data_algos;
 BenchmarkCache<miopenConvBwdWeightsAlgorithm_t> bwd_filter_algos;
 
+BenchmarkCache<size_t> fwd_wssizes;
+BenchmarkCache<size_t> bwd_data_wssizes;
+BenchmarkCache<size_t> bwd_filter_wssizes;
+
 struct Workspace {
   Workspace(size_t size) : size(size), data(NULL) {
     data = THCudaMalloc(globalContext().lazyInitCUDA(), size);
@@ -409,6 +413,7 @@ struct algorithm_search<miopenConvFwdAlgorithm_t> {
 
   static constexpr auto DEFAULT_ALGO = miopenConvolutionFwdAlgoGEMM;
   static BenchmarkCache<algo_t>& cache() { return fwd_algos; }
+  static BenchmarkCache<size_t>& wsscache() { return fwd_wssizes; }
 
   static perf_t findAlgorithm(const ConvolutionArgs& args) {
     int perf_count;
@@ -438,6 +443,7 @@ struct algorithm_search<miopenConvBwdDataAlgorithm_t> {
 
   static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdDataAlgoGEMM;
   static BenchmarkCache<algo_t>& cache() { return bwd_data_algos; }
+  static BenchmarkCache<size_t>& wsscache() { return bwd_data_wssizes; }
 
   static perf_t findAlgorithm(const ConvolutionArgs& args) {
     int perf_count;
@@ -467,6 +473,7 @@ struct algorithm_search<miopenConvBwdWeightsAlgorithm_t> {
 
   static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdWeightsAlgoGEMM;
   static BenchmarkCache<algo_t>& cache() { return bwd_filter_algos; }
+  static BenchmarkCache<size_t>& wsscache() { return bwd_filter_wssizes; }
 
   static perf_t findAlgorithm(const ConvolutionArgs& args) {
     int perf_count;
@@ -493,6 +500,7 @@ template<typename algo_t>
 void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
   using search = algorithm_search<algo_t>;
   auto& cache = search::cache();
+  auto& wsscache = search::wsscache();
 
   if (cache.find(args.params, algo)) {
     return;
@@ -512,6 +520,7 @@ void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
   *algo = reinterpret_cast<algo_t&>(perfResults);
 
   cache.insert(args.params, *algo);
+  wsscache.insert(args.params, perfResults.memory);
 
   THCCachingAllocator_emptyCache();
 }
@@ -526,7 +535,7 @@ Workspace chooseAlgorithm(
 
   using search = algorithm_search<algo_t>;
   size_t workspace_size;
-  workspace_size = getWorkspaceSize(args, *algo);
+  search::wsscache().find(args.params, &workspace_size);
   try {
     return Workspace(workspace_size);
   } catch (const std::exception& e) {
@@ -535,9 +544,9 @@ Workspace chooseAlgorithm(
     // switch to default algorithm and record it in the cache to prevent
     // further OOM errors
     *algo = search::DEFAULT_ALGO;
-    search::cache().insert(args.params, *algo);
-
     workspace_size = getWorkspaceSize(args, *algo);
+    search::cache().insert(args.params, *algo);
+    search::wsscache().insert(args.params, workspace_size);
     return Workspace(workspace_size);
   }
 }