Automated g4 rollback of changelist 197118212
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 18 May 2018 17:42:50 +0000 (10:42 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 18 May 2018 17:45:22 +0000 (10:45 -0700)
PiperOrigin-RevId: 197167501

tensorflow/stream_executor/cuda/cuda_dnn.cc

index d82d36c..7ace7fd 100644 (file)
@@ -53,8 +53,6 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
 
 namespace {
 
-static_assert(CUDNN_VERSION >= 6000, "cuDNN needs to be version 6.0 or higher");
-
 // Converts (via narrowing) a type T value to a type U, and checks that the
 // value has no value change due to the conversion.
 template <typename WideT, typename NarrowT>
@@ -95,6 +93,7 @@ string ToString(cudnnStatus_t status) {
   }
 }
 
+#if CUDNN_VERSION >= 6000
 string ToString(libraryPropertyType type) {
   switch (type) {
     case MAJOR_VERSION:
@@ -108,6 +107,7 @@ string ToString(libraryPropertyType type) {
           "<unknown libraryPropertyType: ", static_cast<int>(type), ">");
   }
 }
+#endif
 
 template <typename T>
 cudnnDataType_t GetCudnnDataType();
@@ -213,8 +213,12 @@ cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
     case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
     case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
     case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
+#if CUDNN_VERSION >= 5000
     case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
+#endif
+#if CUDNN_VERSION >= 5100
     case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
+#endif
       return algo;
     default:
       LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
@@ -231,8 +235,12 @@ cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
+#if CUDNN_VERSION >= 5000
     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
+#endif
+#if CUDNN_VERSION >= 5100
     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
+#endif
       return algo;
     default:
       LOG(FATAL)
@@ -250,12 +258,11 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
+#if CUDNN_VERSION >= 5100
     // Based on cudnn.h, the following is not implemented.
     // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
-      // Produces incorrect results for some shapes. Disabled for now, see
-      // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
-      // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
+#endif
       return algo;
     default:
       LOG(FATAL)
@@ -264,6 +271,7 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
   }
 }
 
+#if CUDNN_VERSION >= 6000
 port::Status GetCudnnProperty(libraryPropertyType type, int* value) {
   cudnnStatus_t status = cudnnGetProperty(type, value);
   if (status != CUDNN_STATUS_SUCCESS) {
@@ -292,11 +300,19 @@ cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) {
     }
   }
 }
+#endif
 
 port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
+#if CUDNN_VERSION >= 6000
   TF_RETURN_IF_ERROR(GetCudnnProperty(MAJOR_VERSION, &version->major_version));
   TF_RETURN_IF_ERROR(GetCudnnProperty(MINOR_VERSION, &version->minor_version));
   TF_RETURN_IF_ERROR(GetCudnnProperty(PATCH_LEVEL, &version->patch_level));
+#else
+  size_t loaded_version = ::cudnnGetVersion();
+  version->major_version = loaded_version / 1000;
+  version->minor_version = (loaded_version / 100) % 10;
+  version->patch_level = loaded_version % 100;
+#endif
   return port::Status::OK();
 }
 
@@ -402,6 +418,7 @@ class ScopedTensorDescriptor {
                      << " to cudnn tensor descriptor: " << ToString(status);
         }
       } break;
+#if CUDNN_VERSION >= 6000
       case dnn::DataLayout::kBatchDepthYX4: {
         status = cudnnSetTensor4dDescriptor(
             handle_, CUDNN_TENSOR_NCHW_VECT_C, elem_type,
@@ -413,6 +430,7 @@ class ScopedTensorDescriptor {
                      << " to cudnn tensor descriptor: " << ToString(status);
         }
       } break;
+#endif
       default:
         LOG(FATAL) << "Unsupported tensor format "
                    << DataLayoutString(batch_descriptor.layout());
@@ -448,6 +466,7 @@ class ScopedFilterDescriptor {
                  << ToString(status);
     }
 
+#if CUDNN_VERSION >= 5000
     // TODO(b/23032134): Even if the filter layout is not supported,
     // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because it
     // does not take layout as an input. Maybe force cuDNN by giving wrong
@@ -457,14 +476,17 @@ class ScopedFilterDescriptor {
       case dnn::FilterLayout::kOutputInputYX:
         format = CUDNN_TENSOR_NCHW;
         break;
+#if CUDNN_VERSION >= 6000
       case dnn::FilterLayout::kOutputInputYX4:
         format = CUDNN_TENSOR_NCHW_VECT_C;
         break;
+#endif
       default:
         LOG(FATAL) << "Unsupported filter format "
                    << FilterLayoutString(filter_descriptor.layout());
         break;
     }
+#endif
 
     std::vector<int> dims(2 + filter_descriptor.ndims());
     dims[0] = filter_descriptor.output_feature_map_count();
@@ -472,8 +494,11 @@ class ScopedFilterDescriptor {
     const auto& spatial_dims = filter_descriptor.input_filter_dims();
     std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
 
-    status = cudnnSetFilterNdDescriptor(handle_, elem_type, format, dims.size(),
-                                        dims.data());
+    status = cudnnSetFilterNdDescriptor(handle_, elem_type,
+#if CUDNN_VERSION >= 5000
+                                        format,
+#endif
+                                        dims.size(), dims.data());
     if (status != CUDNN_STATUS_SUCCESS) {
       LOG(FATAL) << "could not set cudnn filter descriptor: "
                  << ToString(status);
@@ -667,8 +692,10 @@ class ScopedPoolingDescriptor {
         (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
              ? CUDNN_POOLING_MAX
              : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
-        propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
-        shape.data(), padding.data(), strides.data());
+#if CUDNN_VERSION >= 5000
+        propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN,
+#endif
+        nd, shape.data(), padding.data(), strides.data());
     if (status != CUDNN_STATUS_SUCCESS) {
       LOG(FATAL) << "could not set cudnn pooling descriptor: "
                  << ToString(status);
@@ -744,6 +771,7 @@ class ScopedNormalizeDescriptor {
   SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor);
 };
 
+#if CUDNN_VERSION >= 5000
 // Turns a ActivationDescriptor structure into a cudnn activation
 // descriptor handle within a scope.
 class ScopedActivationDescriptor {
@@ -806,6 +834,7 @@ class ScopedActivationDescriptor {
 
   SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor);
 };
+#endif
 
 cudnnDataType_t ToCudnnDataType(
     dnn::DataType data_type,
@@ -815,14 +844,18 @@ cudnnDataType_t ToCudnnDataType(
     case dnn::DataType::kDouble:
     case dnn::DataType::kHalf:
       return static_cast<cudnnDataType_t>(data_type);
+#if CUDNN_VERSION >= 6000
     case dnn::DataType::kInt8:
       return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4
                                                             : CUDNN_DATA_INT8;
+#endif
     default:
       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
   }
 }
 
+#if CUDNN_VERSION >= 5000
+
 cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
   switch (input_mode) {
     case dnn::RnnInputMode::kRnnLinearSkip:
@@ -870,11 +903,15 @@ int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
   }
 }
 
+#endif  // CUDNN_VERSION
+
 template <typename Base>
 class MixinBase : public Base {};
 template <>
 class MixinBase<void> {};
 
+#if CUDNN_VERSION >= 5000
+
 #define CUDNN_RETURN_IF_FAIL(STATUS, ...)                                \
   if (!SE_PREDICT_TRUE((STATUS) == CUDNN_STATUS_SUCCESS)) {              \
     string error_msg = port::StrCat(ToString(STATUS), " ", __VA_ARGS__); \
@@ -1005,7 +1042,9 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
         hidden_size_(hidden_size),
         input_size_(input_size),
         batch_size_(batch_size),
+#if CUDNN_VERSION >= 6000
         rnn_plan_(nullptr),
+#endif
         input_mode_(input_mode),
         direction_mode_(direction_mode),
         rnn_mode_(rnn_mode),
@@ -1023,6 +1062,7 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
     // Create the RNN handle
     cudnnStatus_t status = cudnnCreateRNNDescriptor(&rnn_desc_);
     CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor");
+#if CUDNN_VERSION >= 6000
     // TODO: allow the user to choose an algorithm.
     rnn_algo_ = ToCudnnRNNAlgo(algorithm_config_.algorithm());
     status = cudnnSetRNNDescriptor_v6(
@@ -1044,6 +1084,16 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
       status = cudnnSetPersistentRNNPlan(rnn_desc_, rnn_plan_);
       CUDNN_RETURN_IF_FAIL(status, "Unable to update persistent RNN plan.");
     }
+#else
+    CHECK(algorithm_config_.is_default())
+        << "Non-default algorithm not supported for CUDA version < 6.0";
+    status = cudnnSetRNNDescriptor(
+        /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size,
+        /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(),
+        /*inputMode=*/input_mode, /*direction=*/direction_mode,
+        /*mode=*/rnn_mode, /*dataType=*/compute_type);
+    CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor");
+#endif
 
     // Create the params handle.
     cudnn_params_desc_.reset(new CudnnRnnParamsDescriptor(cudnn, *this));
@@ -1056,10 +1106,12 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
   ~CudnnRnnDescriptor() override {
     if (rnn_desc_) {
       cudnnStatus_t status;
+#if CUDNN_VERSION >= 6000
       if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC && rnn_plan_) {
         status = cudnnDestroyPersistentRNNPlan(rnn_plan_);
         CUDNN_RETURN_IF_FAIL(status, "Unable to destroy persistent RNN plan.");
       }
+#endif
       status = cudnnDestroyRNNDescriptor(rnn_desc_);
       CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor");
     }
@@ -1120,8 +1172,10 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
   // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
   // algorithm.
   int batch_size_;
+#if CUDNN_VERSION >= 6000
   cudnnRNNAlgo_t rnn_algo_;
   cudnnPersistentRNNPlan_t rnn_plan_;
+#endif
   cudnnRNNInputMode_t input_mode_;
   cudnnDirectionMode_t direction_mode_;
   cudnnRNNMode_t rnn_mode_;
@@ -1752,6 +1806,8 @@ bool CudnnSupport::DoRnnBackwardImpl(
   return true;
 }
 
+#endif  // CUDNN_VERSION
+
 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
 CudnnSupport::createRnnDescriptor(
     int num_layers, int hidden_size, int input_size, int batch_size,
@@ -1759,6 +1815,7 @@ CudnnSupport::createRnnDescriptor(
     dnn::RnnMode rnn_mode, dnn::DataType data_type,
     const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
     ScratchAllocator* state_allocator) {
+#if CUDNN_VERSION >= 5000
   // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
   // not enqueueing anything into a stream, we pass in the null stream.
   auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
@@ -1773,12 +1830,20 @@ CudnnSupport::createRnnDescriptor(
   }
   return port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>(
       std::move(rnn_desc));
+#else
+  string error_msg =
+      port::StrCat("createRnnDescriptor needs at least Cudnn 5.0 to work. ",
+                   "Current Cudnn version: ", CUDNN_VERSION, ". ");
+  LOG(ERROR) << error_msg;
+  return port::Status(port::error::UNIMPLEMENTED, error_msg);
+#endif  // CUDNN_VERSION
 }
 
 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
 CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
                                                 int data_size,
                                                 dnn::DataType data_type) {
+#if CUDNN_VERSION >= 5000
   std::unique_ptr<CudnnRnnSequenceTensorDescriptor> seq_desc(
       new CudnnRnnSequenceTensorDescriptor(parent_, seq_length, batch_size,
                                            data_size,
@@ -1788,12 +1853,20 @@ CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
   }
   return port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>(
       std::move(seq_desc));
+#else
+  string error_msg = port::StrCat(
+      "createRnnSequenceTensorDescriptor needs at least Cudnn 5.0 to work. ",
+      "Current Cudnn version: ", CUDNN_VERSION, ". ");
+  LOG(ERROR) << error_msg;
+  return port::Status(port::error::UNIMPLEMENTED, error_msg);
+#endif  // CUDNN_VERSION
 }
 
 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
 CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                              int data_size,
                                              dnn::DataType data_type) {
+#if CUDNN_VERSION >= 5000
   std::unique_ptr<CudnnRnnStateTensorDescriptor> state_desc(
       new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
                                         data_size, ToCudnnDataType(data_type)));
@@ -1802,6 +1875,13 @@ CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
   }
   return port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>(
       std::move(state_desc));
+#else
+  string error_msg = port::StrCat(
+      "createRnnStateTensorDescriptor needs at least Cudnn 5.0 to work. ",
+      "Current Cudnn version: ", CUDNN_VERSION, ". ");
+  LOG(ERROR) << error_msg;
+  return port::Status(port::error::UNIMPLEMENTED, error_msg);
+#endif  // CUDNN_VERSION
 }
 
 bool CudnnSupport::DoRnnForward(
@@ -1822,6 +1902,7 @@ bool CudnnSupport::DoRnnForward(
     ScratchAllocator* reserve_space_allocator,
     ScratchAllocator* workspace_allocator,
     dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION >= 5000
   const CudnnRnnDescriptor& cudnn_rnn_desc =
       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
@@ -1843,6 +1924,9 @@ bool CudnnSupport::DoRnnForward(
       output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
       output_c_data, is_training, reserve_space_allocator, workspace_allocator,
       output_profile_result);
+#else
+  return false;
+#endif  // CUDNN_VERSION
 }
 
 bool CudnnSupport::DoRnnForward(
@@ -1862,6 +1946,7 @@ bool CudnnSupport::DoRnnForward(
     ScratchAllocator* reserve_space_allocator,
     ScratchAllocator* workspace_allocator,
     dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION >= 5000
   const CudnnRnnDescriptor& cudnn_rnn_desc =
       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
@@ -1883,6 +1968,9 @@ bool CudnnSupport::DoRnnForward(
       output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
       output_c_data, is_training, reserve_space_allocator, workspace_allocator,
       output_profile_result);
+#else
+  return false;
+#endif  // CUDNN_VERSION
 }
 
 bool CudnnSupport::DoRnnForward(
@@ -1903,6 +1991,7 @@ bool CudnnSupport::DoRnnForward(
     ScratchAllocator* reserve_space_allocator,
     ScratchAllocator* workspace_allocator,
     dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION >= 5000
   const CudnnRnnDescriptor& cudnn_rnn_desc =
       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
@@ -1924,6 +2013,9 @@ bool CudnnSupport::DoRnnForward(
       output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
       output_c_data, is_training, reserve_space_allocator, workspace_allocator,
       output_profile_result);
+#else
+  return false;
+#endif  // CUDNN_VERSION
 }
 
 bool CudnnSupport::DoRnnBackward(
@@ -1951,6 +2043,7 @@ bool CudnnSupport::DoRnnBackward(
     DeviceMemory<uint8>* reserve_space_data,
     ScratchAllocator* workspace_allocator,
     dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION >= 5000
   const CudnnRnnDescriptor& cudnn_rnn_desc =
       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
@@ -1974,6 +2067,9 @@ bool CudnnSupport::DoRnnBackward(
       output_c_backprop_data, input_backprop_data, input_h_backprop_data,
       input_c_backprop_data, params_backprop_data, reserve_space_data,
       workspace_allocator, output_profile_result);
+#else
+  return false;
+#endif  // CUDNN_VERSION
 }
 
 bool CudnnSupport::DoRnnBackward(
@@ -2000,6 +2096,7 @@ bool CudnnSupport::DoRnnBackward(
     DeviceMemory<uint8>* reserve_space_data,
     ScratchAllocator* workspace_allocator,
     dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION >= 5000
   const CudnnRnnDescriptor& cudnn_rnn_desc =
       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
@@ -2023,6 +2120,9 @@ bool CudnnSupport::DoRnnBackward(
       output_c_backprop_data, input_backprop_data, input_h_backprop_data,
       input_c_backprop_data, params_backprop_data, reserve_space_data,
       workspace_allocator, output_profile_result);
+#else
+  return false;
+#endif  // CUDNN_VERSION
 }
 
 bool CudnnSupport::DoRnnBackward(
@@ -2050,6 +2150,7 @@ bool CudnnSupport::DoRnnBackward(
     DeviceMemory<uint8>* reserve_space_data,
     ScratchAllocator* workspace_allocator,
     dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION >= 5000
   const CudnnRnnDescriptor& cudnn_rnn_desc =
       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
@@ -2073,6 +2174,9 @@ bool CudnnSupport::DoRnnBackward(
       output_c_backprop_data, input_backprop_data, input_h_backprop_data,
       input_c_backprop_data, params_backprop_data, reserve_space_data,
       workspace_allocator, output_profile_result);
+#else
+  return false;
+#endif  // CUDNN_VERSION
 }
 
 namespace {
@@ -2207,12 +2311,16 @@ class CudnnEnvVar {
 };
 
 // A helper struct to decide whether to enable the FFT_TILING algorithms for
-// forward convolution. It is disabled for cuDNN < 7 due to memory corruption
-// caused by some shapes with this algorithm. Users can explicitly enable the
-// algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
+// forward convolution. Before cudnn v5.1 it works fine but since cudnn v5.1
+// it is turned off due to memory corruption caused by some shapes with this
+// algorithm.
+// Before NVIDIA fixes the memory corruption bug, users can explicitly
+// enable the algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
 struct FftTilingForward {
   static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
-  static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7000;
+  // TODO(yangzihao): turn the default to True when the memory corruption bug
+  // is fixed.
+  static constexpr bool kDefaultFlag = CUDNN_VERSION < 5100;
 };
 
 // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
@@ -2221,9 +2329,10 @@ struct FftTilingForward {
 // https://github.com/tensorflow/tensorflow/pull/4901
 struct WinogradNonfused {
   static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
-  // NVIDIA has fixed winograd nonfused bug for cudnn v>=7. For older versions,
-  // we have a workaround.
-  static constexpr bool kDefaultFlag = true;
+  // NVIDIA has fixed winograd nonfused bug for cudnn v>=7.
+  // For cudnn v>=5.1, we have a workaround and for any lower version, we
+  // disable it by default.
+  static constexpr bool kDefaultFlag = CUDNN_VERSION >= 5100;
 };
 
 // A helper struct to decide whether to use FP32 as the internal compute type
@@ -2512,6 +2621,11 @@ bool CudnnSupport::DoFusedConvolveImpl(
     DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION < 6000
+  LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only "
+                "supported for cuDNN version >= 6";
+  return false;
+#else
   ScopedTensorDescriptor conv_input_nd(
       conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
   ScopedTensorDescriptor output_nd(
@@ -2618,27 +2732,32 @@ bool CudnnSupport::DoFusedConvolveImpl(
   }
 
   return true;
+#endif  // CUDNN_VERSION < 6000
 }
 
 bool CudnnSupport::GetConvolveAlgorithms(
     bool with_winograd_nonfused, int cc_major, int cc_minor,
     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
-      // clang-format off
+    // clang-format off
     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
     CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
     CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
     CUDNN_CONVOLUTION_FWD_ALGO_FFT,
+#if CUDNN_VERSION >= 5000
     CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
-      // clang-format on
+#endif
+    // clang-format on
   };
   if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
     algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
   }
+#if CUDNN_VERSION >= 5100
   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
     algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
   }
+#endif
 
   out_algorithms->clear();
   for (auto i : algo_types) {
@@ -2653,11 +2772,13 @@ bool CudnnSupport::GetConvolveAlgorithms(
 bool CudnnSupport::GetRnnAlgorithms(
     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
-      // clang-format off
+  // clang-format off
+#if CUDNN_VERSION >= 6000
     CUDNN_RNN_ALGO_STANDARD,
     CUDNN_RNN_ALGO_PERSIST_STATIC,
     CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
-      // clang-format on
+#endif
+    // clang-format on
   };
 
   out_algorithms->clear();
@@ -2676,17 +2797,21 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
     bool with_winograd_nonfused, int cc_major, int cc_minor,
     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
-      // clang-format off
+    // clang-format off
     CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
     CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
+#if CUDNN_VERSION >= 5000
     CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
-      // clang-format on
+#endif
+    // clang-format on
   };
+#if CUDNN_VERSION >= 5100
   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
   }
+#endif
 
   out_algorithms->clear();
   for (auto i : algo_types) {
@@ -2709,15 +2834,13 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
       // Based on cudnn.h, the following is not implemented.
       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
-
-      // Produces incorrect results for some shapes. Disabled for now, see
-      // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
-      // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
       // clang-format on
   };
+#if CUDNN_VERSION >= 5100
   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
   }
+#endif
 
   out_algorithms->clear();
   for (auto i : algo_types) {
@@ -2816,8 +2939,17 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
         scale.opaque(), offset.opaque(), 1.0, batch_mean_opaque,
         batch_var_opaque, epsilon, saved_mean->opaque(),
         saved_inv_var->opaque());
+#if CUDNN_VERSION < 5000
+    CHECK(inv_var_to_var);
+    inv_var_to_var();
+#endif
   } else {
+#if CUDNN_VERSION < 5000
+    CHECK(var_to_inv_var);
+    const void* maybe_inv_var = var_to_inv_var().opaque();
+#else
     const void* maybe_inv_var = estimated_variance.opaque();
+#endif
     status = cudnnBatchNormalizationForwardInference(
         cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
@@ -3027,6 +3159,11 @@ bool CudnnSupport::DoFusedConvolve(
     DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION < 6000
+  LOG(WARNING) << "cudnnConvolutionBiasActivationForward() is only "
+                  "supported for cuDNN version >= 6";
+  return false;
+#else
   int cc_major, cc_minor;
   stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                    &cc_minor);
@@ -3042,6 +3179,7 @@ bool CudnnSupport::DoFusedConvolve(
       side_input_scale, bias_descriptor, biases, activation_mode,
       output_descriptor, output_data, scratch_allocator, algorithm_config,
       output_profile_result);
+#endif
 }
 
 namespace {
@@ -3290,8 +3428,13 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
     timer->Start(AsCUDAStream(stream));
   }
 
+#if CUDNN_VERSION >= 5000
   auto status =
       cudnnConvolutionBackwardData(cudnn.handle(),
+#else
+  auto status =
+      cudnnConvolutionBackwardData_v3(cudnn.handle(),
+#endif
                                    /*alpha=*/alpha,
                                    /*wDesc=*/filter.handle(),
                                    /*w=*/filter_data.opaque(),
@@ -3554,8 +3697,13 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
     timer->Start(AsCUDAStream(stream));
   }
 
+#if CUDNN_VERSION >= 5000
   auto status = cudnnConvolutionBackwardFilter(
       cudnn.handle(),
+#else
+  auto status = cudnnConvolutionBackwardFilter_v3(
+      cudnn.handle(),
+#endif
       /*alpha=*/alpha,
       /*srcDesc=*/input_nd.handle(),
       /*srcData=*/input_data.opaque(),
@@ -3868,7 +4016,11 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
 
   auto cudnn = cudnn_->GetHandle(parent_, stream);
 
+#if CUDNN_VERSION >= 5000
   auto status = cudnnAddTensor(
+#else
+  auto status = cudnnAddTensor_v3(
+#endif
       cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(), &beta,
       input_descriptor.handle(), output_data->opaque());
 
@@ -3886,8 +4038,37 @@ bool CudnnSupport::DoActivate(Stream* stream,
                               const DeviceMemory<float>& input_data,
                               DeviceMemory<float>* output_data,
                               uint64 options) {
+#if CUDNN_VERSION >= 5000
   ScopedActivationDescriptor activation_desc(
       activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
+#else
+  cudnnActivationMode_t mode;
+  switch (activation_mode) {
+    case dnn::ActivationMode::kRelu6:
+      // TODO(leary) should probably do a post-pass to clip at 6?
+      LOG(WARNING) << "user requested Relu6, but providing Relu instead";
+      mode = CUDNN_ACTIVATION_RELU;
+      break;
+    case dnn::ActivationMode::kReluX:
+      // TODO(broune) should probably do a post-pass to clip at X?
+      LOG(WARNING) << "user requested ReluX, but providing Relu instead";
+      mode = CUDNN_ACTIVATION_RELU;
+      break;
+    case dnn::ActivationMode::kRelu:
+      mode = CUDNN_ACTIVATION_RELU;
+      break;
+    case dnn::ActivationMode::kSigmoid:
+      mode = CUDNN_ACTIVATION_SIGMOID;
+      break;
+    case dnn::ActivationMode::kTanh:
+      mode = CUDNN_ACTIVATION_TANH;
+      break;
+    default:
+      LOG(ERROR) << "unrecognized activation mode: "
+                 << static_cast<int>(activation_mode);
+      return false;
+  }
+#endif
 
   ScopedTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
   // Alpha is the input scaling factor.
@@ -3896,9 +4077,15 @@ bool CudnnSupport::DoActivate(Stream* stream,
   float beta = 0.0;
 
   auto cudnn = cudnn_->GetHandle(parent_, stream);
-  auto status = cudnnActivationForward(
-      cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
-      input_data.opaque(), &beta, input_nd.handle(), output_data->opaque());
+  auto status =
+      cudnnActivationForward(cudnn.handle(),
+#if CUDNN_VERSION >= 5000
+                             activation_desc.handle(),
+#else
+                             mode,
+#endif
+                             &alpha, input_nd.handle(), input_data.opaque(),
+                             &beta, input_nd.handle(), output_data->opaque());
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "stream " << stream
                << " could not enqueue activation: " << ToString(status);