From e489b600f388ae345387881a85368af3cd373ba2 Mon Sep 17 00:00:00 2001 From: Sami Kama Date: Thu, 12 Apr 2018 18:07:50 -0700 Subject: [PATCH] Replace tuple for version info with a class in DnnSupportr::GetVersion() (#18434) * Replace tuple for version info with a class * Removed clang-format modifications on non-edited code * Update dnn.h Update the comment as per request of reviewer --- .../gpu/cudnn_convolution_algorithm_picker.cc | 4 ++-- tensorflow/stream_executor/cuda/cuda_dnn.cc | 7 ++++--- tensorflow/stream_executor/cuda/cuda_dnn.h | 2 +- tensorflow/stream_executor/dnn.h | 20 ++++++++++++++++++-- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index d6b457a..1eccfe8 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -99,9 +99,9 @@ bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape, const ConvolutionDimensionNumbers& dnums, se::StreamExecutor* stream_exec) { // Skip this check for cudnn7 and newer. - se::port::StatusOr> version = + auto version = stream_exec->AsDnn()->GetVersion(); - if (version.ok() && std::get<0>(version.ValueOrDie()) >= 7) { + if (version.ok() && version.ValueOrDie().major_version() >= 7) { return true; } diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 1dc7f99..a11b644 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -477,11 +477,12 @@ port::Status CudnnSupport::Init() { ToString(status))}; } -port::StatusOr> CudnnSupport::GetVersion() { +port::StatusOr +CudnnSupport::GetVersion() { CudnnVersion version; TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version)); - return std::make_tuple(version.major_version, version.minor_version, - version.patch_level); + return perftools::gputools::dnn::VersionInfo( + version.major_version, version.minor_version, version.patch_level); } // Turns a BatchDescriptor structure into a cudnn tensor handle within a scope. diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 0e5368a..09d248f 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -46,7 +46,7 @@ class CudnnSupport : public dnn::DnnSupport { ~CudnnSupport() override; port::Status Init() override; - port::StatusOr> GetVersion() override; + port::StatusOr GetVersion() override; port::StatusOr> createRnnDescriptor( int num_layers, int hidden_size, int input_size, diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 3c47d2c..47dcd80 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -876,6 +876,22 @@ enum class ElementwiseOperation { kAdd, kMultiply }; string ElementwiseOperationString(ElementwiseOperation op); +// A simple class representing the version of the backing library, to +// workaround the "too perfect forwarding" issue in gcc6+ compilers. +// See PR#16309 and issue #18402 for links discussing the issue. +class VersionInfo { + public: + VersionInfo(int major = 0, int minor = 0, int patch = 0) + : major_(major), minor_(minor), patch_(patch) {} + int major_version() { return major_; } + int minor_version() { return minor_; } + int patch() { return patch_; } + private: + int major_; + int minor_; + int patch_; +}; + // Suite of operations typically used for implementing Deep/Convolutional Neural // Nets. Note: A false return value of an operation indicates the // implementation is not available. @@ -886,8 +902,8 @@ class DnnSupport { virtual port::Status Init() = 0; - // Gets the version of the backing library, as a {major, minor, patch} tuple. - virtual port::StatusOr> GetVersion() { + // Gets the version of the backing library, as a VersionInfo object. + virtual port::StatusOr GetVersion() { return port::UnimplementedError( "DnnSupport::GetVersion not implemented on this platform."); } -- 2.7.4