BenchmarkCache<miopenConvBwdDataAlgorithm_t> bwd_data_algos;
BenchmarkCache<miopenConvBwdWeightsAlgorithm_t> bwd_filter_algos;
+BenchmarkCache<size_t> fwd_wssizes;
+BenchmarkCache<size_t> bwd_data_wssizes;
+BenchmarkCache<size_t> bwd_filter_wssizes;
+
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
data = THCudaMalloc(globalContext().lazyInitCUDA(), size);
static constexpr auto DEFAULT_ALGO = miopenConvolutionFwdAlgoGEMM;
static BenchmarkCache<algo_t>& cache() { return fwd_algos; }
+ static BenchmarkCache<size_t>& wsscache() { return fwd_wssizes; }
static perf_t findAlgorithm(const ConvolutionArgs& args) {
int perf_count;
static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdDataAlgoGEMM;
static BenchmarkCache<algo_t>& cache() { return bwd_data_algos; }
+ static BenchmarkCache<size_t>& wsscache() { return bwd_data_wssizes; }
static perf_t findAlgorithm(const ConvolutionArgs& args) {
int perf_count;
static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdWeightsAlgoGEMM;
static BenchmarkCache<algo_t>& cache() { return bwd_filter_algos; }
+ static BenchmarkCache<size_t>& wsscache() { return bwd_filter_wssizes; }
static perf_t findAlgorithm(const ConvolutionArgs& args) {
int perf_count;
void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
using search = algorithm_search<algo_t>;
auto& cache = search::cache();
+ auto& wsscache = search::wsscache();
if (cache.find(args.params, algo)) {
return;
*algo = reinterpret_cast<algo_t&>(perfResults);
cache.insert(args.params, *algo);
+ wsscache.insert(args.params, perfResults.memory);
THCCachingAllocator_emptyCache();
}
using search = algorithm_search<algo_t>;
size_t workspace_size;
- workspace_size = getWorkspaceSize(args, *algo);
+ search::wsscache().find(args.params, &workspace_size);
try {
return Workspace(workspace_size);
} catch (const std::exception& e) {
// switch to default algorithm and record it in the cache to prevent
// further OOM errors
*algo = search::DEFAULT_ALGO;
- search::cache().insert(args.params, *algo);
-
workspace_size = getWorkspaceSize(args, *algo);
+ search::cache().insert(args.params, *algo);
+ search::wsscache().insert(args.params, workspace_size);
return Workspace(workspace_size);
}
}