const ConvolutionDimensionNumbers& dnums,
se::StreamExecutor* stream_exec) {
// Skip this check for cudnn7 and newer.
- se::port::StatusOr<std::tuple<int, int, int>> version =
+ auto version =
stream_exec->AsDnn()->GetVersion();
- if (version.ok() && std::get<0>(version.ValueOrDie()) >= 7) {
+ if (version.ok() && version.ValueOrDie().major_version() >= 7) {
return true;
}
ToString(status))};
}
-port::StatusOr<std::tuple<int, int, int>> CudnnSupport::GetVersion() {
+port::StatusOr<perftools::gputools::dnn::VersionInfo>
+CudnnSupport::GetVersion() {
CudnnVersion version;
TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version));
- return std::make_tuple(version.major_version, version.minor_version,
- version.patch_level);
+ return perftools::gputools::dnn::VersionInfo(
+ version.major_version, version.minor_version, version.patch_level);
}
// Turns a BatchDescriptor structure into a cudnn tensor handle within a scope.
~CudnnSupport() override;
port::Status Init() override;
- port::StatusOr<std::tuple<int, int, int>> GetVersion() override;
+ port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
int num_layers, int hidden_size, int input_size,
string ElementwiseOperationString(ElementwiseOperation op);
+// A simple class representing the version of the backing library, to
+// workaround the "too perfect forwarding" issue in gcc6+ compilers.
+// See PR#16309 and issue #18402 for links discussing the issue.
+class VersionInfo {
+ public:
+ VersionInfo(int major = 0, int minor = 0, int patch = 0)
+ : major_(major), minor_(minor), patch_(patch) {}
+ int major_version() { return major_; }
+ int minor_version() { return minor_; }
+ int patch() { return patch_; }
+ private:
+ int major_;
+ int minor_;
+ int patch_;
+};
+
// Suite of operations typically used for implementing Deep/Convolutional Neural
// Nets. Note: A false return value of an operation indicates the
// implementation is not available.
virtual port::Status Init() = 0;
- // Gets the version of the backing library, as a {major, minor, patch} tuple.
- virtual port::StatusOr<std::tuple<int, int, int>> GetVersion() {
+ // Gets the version of the backing library, as a VersionInfo object.
+ virtual port::StatusOr<VersionInfo> GetVersion() {
return port::UnimplementedError(
"DnnSupport::GetVersion not implemented on this platform.");
}