Add debugging checks for setting cuda stream, so it will check fail if the
authorGuangda Lai <laigd@google.com>
Fri, 13 Apr 2018 17:02:25 +0000 (10:02 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 17:05:23 +0000 (10:05 -0700)
stream is not set or set to a wrong one when running cudnn methods that
conceptually require a stream.

Also add missing cudnnSetStream()s for DoRnnForwardImpl() and
DoRnnBackwardImpl().

Implementation details:
1. a current_cudnn_stream_ member is added which will be set in cudnnSetStream()
2. a different macro is used to wrap cudnn methods that require a stream in
   order to verify whether the provided stream is same as current_cudnn_stream_,
   and the program will check fail if not

PiperOrigin-RevId: 192783913

tensorflow/stream_executor/cuda/cuda_dnn.cc
tensorflow/stream_executor/cuda/cuda_dnn.h

index 1dc7f99..4a6b2bf 100644 (file)
@@ -169,11 +169,34 @@ static port::ThreadPool* GetCudaThreadpool() {
     }                                                              \
   } __name;
 
+#define PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM(__name)        \
+  struct WrapperShim__##__name {                                         \
+    template <typename... Args>                                          \
+    cudnnStatus_t operator()(CudnnSupport* dnn, Stream* s, Args... args) \
+        SHARED_LOCKS_REQUIRED(dnn->dnn_handle_mutex_) {                  \
+      CHECK_NOTNULL(s);                                                  \
+      CHECK_EQ(s, dnn->GetCurrentDnnStream())                            \
+          << "Stream is not set correctly!";                             \
+      cuda::ScopedActivateExecutorContext sac{dnn->GetParentExecutor()}; \
+      cudnnStatus_t retval = ::__name(args...);                          \
+      return retval;                                                     \
+    }                                                                    \
+  } __name;
+
+// Handles cudnnSetStream differently in order to add debug information.
+struct WrapperShim__cudnnSetStream {
+  cudnnStatus_t operator()(CudnnSupport* dnn, Stream* stream,
+                           cudnnHandle_t handle)
+      EXCLUSIVE_LOCKS_REQUIRED(dnn->dnn_handle_mutex_) {
+    dnn->SetCurrentDnnStream(stream);
+    cuda::ScopedActivateExecutorContext sac{dnn->GetParentExecutor()};
+    cudnnStatus_t retval = ::cudnnSetStream(handle, AsCUDAStreamValue(stream));
+    return retval;
+  }
+} cudnnSetStream;
+
 // clang-format off
 #define CUDNN_DNN_ROUTINE_EACH(__macro)                   \
-  __macro(cudnnBatchNormalizationBackward)                \
-  __macro(cudnnBatchNormalizationForwardInference)        \
-  __macro(cudnnBatchNormalizationForwardTraining)         \
   __macro(cudnnGetConvolutionNdForwardOutputDim)          \
   __macro(cudnnGetConvolutionForwardAlgorithm)            \
   __macro(cudnnCreateTensorDescriptor)                    \
@@ -190,16 +213,25 @@ static port::ThreadPool* GetCudaThreadpool() {
   __macro(cudnnDestroyConvolutionDescriptor)              \
   __macro(cudnnCreate)                                    \
   __macro(cudnnDestroy)                                   \
-  __macro(cudnnSetStream)                                 \
-  __macro(cudnnActivationForward)                         \
-  __macro(cudnnConvolutionForward)                        \
-  __macro(cudnnConvolutionBackwardBias)                   \
   __macro(cudnnGetConvolutionForwardWorkspaceSize)        \
-  __macro(cudnnTransformTensor)                           \
   __macro(cudnnSetConvolutionNdDescriptor)                \
   __macro(cudnnSetTensor4dDescriptor)                     \
   __macro(cudnnSetTensorNdDescriptor)                     \
-  __macro(cudnnSetFilterNdDescriptor)                     \
+  __macro(cudnnSetFilterNdDescriptor)
+
+// clang-format on
+CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+#undef CUDNN_DNN_ROUTINE_EACH
+
+// clang-format off
+#define CUDNN_DNN_ROUTINE_EACH_WITH_STREAM(__macro)       \
+  __macro(cudnnBatchNormalizationBackward)                \
+  __macro(cudnnBatchNormalizationForwardInference)        \
+  __macro(cudnnBatchNormalizationForwardTraining)         \
+  __macro(cudnnActivationForward)                         \
+  __macro(cudnnConvolutionForward)                        \
+  __macro(cudnnConvolutionBackwardBias)                   \
+  __macro(cudnnTransformTensor)                           \
   __macro(cudnnPoolingForward)                            \
   __macro(cudnnPoolingBackward)                           \
   __macro(cudnnLRNCrossChannelForward)                    \
@@ -207,9 +239,11 @@ static port::ThreadPool* GetCudaThreadpool() {
   __macro(cudnnAddTensor)                                 \
   __macro(cudnnConvolutionBackwardData)                   \
   __macro(cudnnConvolutionBackwardFilter)
-// clang-format on
 
-CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+// clang-format on
+CUDNN_DNN_ROUTINE_EACH_WITH_STREAM(
+    PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM)
+#undef CUDNN_DNN_ROUTINE_EACH_WITH_STREAM
 
 // APIs available after R3:
 #if CUDNN_VERSION >= 3000
@@ -225,14 +259,15 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
 // APIs in R3 but not in R5
 // clang-format off
 #if CUDNN_VERSION >= 3000 && CUDNN_VERSION < 5000
-#define CUDNN_DNN_ROUTINE_EACH_R3(__macro)                    \
+#define CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM(__macro)        \
   __macro(cudnnAddTensor_v3)                                  \
   __macro(cudnnConvolutionBackwardData_v3)                    \
   __macro(cudnnConvolutionBackwardFilter_v3)
 // clang-format on
 
-CUDNN_DNN_ROUTINE_EACH_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
-#undef CUDNN_DNN_ROUTINE_EACH_R3
+CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM(
+    PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM)
+#undef CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM
 #endif
 
 // APIs in R5
@@ -254,29 +289,44 @@ CUDNN_DNN_ROUTINE_EACH_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
   __macro(cudnnGetRNNTrainingReserveSize)                     \
   __macro(cudnnGetRNNLinLayerMatrixParams)                    \
   __macro(cudnnGetRNNLinLayerBiasParams)                      \
-  __macro(cudnnRNNForwardInference)                           \
-  __macro(cudnnRNNForwardTraining)                            \
-  __macro(cudnnRNNBackwardData)                               \
-  __macro(cudnnRNNBackwardWeights)                            \
   __macro(cudnnSetRNNDescriptor)                              \
   __macro(cudnnGetFilterNdDescriptor)
 
 // clang-format on
-
 CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH_R5
+
+// clang-format off
+#define CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM(__macro)        \
+  __macro(cudnnRNNForwardInference)                           \
+  __macro(cudnnRNNForwardTraining)                            \
+  __macro(cudnnRNNBackwardData)                               \
+  __macro(cudnnRNNBackwardWeights)
+
+// clang-format on
+CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM(
+    PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM)
+#undef CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM
 #endif
 
 // APIs in R6
 // clang-format off
 #if CUDNN_VERSION >= 6000
 #define CUDNN_DNN_ROUTINE_EACH_R6(__macro)                    \
-  __macro(cudnnConvolutionBiasActivationForward)              \
   __macro(cudnnSetRNNDescriptor_v6)
 
 // clang-format on
 CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH_R6
+
+// clang-format off
+#define CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(__macro)        \
+  __macro(cudnnConvolutionBiasActivationForward)
+
+// clang-format on
+CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(
+    PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM)
+#undef CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM
 #endif
 
 // APIs in R7
@@ -291,8 +341,6 @@ CUDNN_DNN_ROUTINE_EACH_R7(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH_R7
 #endif
 
-#undef CUDNN_DNN_ROUTINE_EACH
-
 }  // namespace wrap
 
 namespace {
@@ -419,7 +467,7 @@ port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
 }  // namespace
 
 CudnnSupport::CudnnSupport(CUDAExecutor* parent)
-    : parent_(parent), dnn_handle_(nullptr) {}
+    : parent_(parent), dnn_handle_(nullptr), current_dnn_stream_(nullptr) {}
 
 CudnnSupport::~CudnnSupport() {
   auto status = wrap::cudnnDestroy(parent_, ToHandle(dnn_handle_));
@@ -1660,6 +1708,12 @@ bool CudnnSupport::DoRnnForwardImpl(
 
   // check params size
   mutex_lock lock{dnn_handle_mutex_};
+  auto set_stream_status =
+      wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
+  if (set_stream_status != CUDNN_STATUS_SUCCESS) {
+    LOG(FATAL) << "failed to set stream for cudnn handle: "
+               << ToString(set_stream_status);
+  }
 
   if (!CheckRNNParameterSize(parent_, ToHandle(dnn_handle_), rnn_desc,
                              input_desc)) {
@@ -1720,7 +1774,7 @@ bool CudnnSupport::DoRnnForwardImpl(
   cudnnStatus_t status;
   if (!is_training) {
     status = wrap::cudnnRNNForwardInference(
-        parent_, ToHandle(dnn_handle_) /*handle*/,
+        this, stream, ToHandle(dnn_handle_) /*handle*/,
         rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
         input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
         input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
@@ -1733,7 +1787,7 @@ bool CudnnSupport::DoRnnForwardImpl(
         workspace.size() /*workSpaceSizeInBytes*/);
   } else {
     status = wrap::cudnnRNNForwardTraining(
-        parent_, ToHandle(dnn_handle_) /*handle*/,
+        this, stream, ToHandle(dnn_handle_) /*handle*/,
         rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
         input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
         input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
@@ -1810,6 +1864,12 @@ bool CudnnSupport::DoRnnBackwardImpl(
 
   // check params size
   mutex_lock lock{dnn_handle_mutex_};
+  auto set_stream_status =
+      wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
+  if (set_stream_status != CUDNN_STATUS_SUCCESS) {
+    LOG(FATAL) << "failed to set stream for cudnn handle: "
+               << ToString(set_stream_status);
+  }
 
   if (!CheckRNNParameterSize(parent_, ToHandle(dnn_handle_), rnn_desc,
                              input_desc)) {
@@ -1841,10 +1901,11 @@ bool CudnnSupport::DoRnnBackwardImpl(
   }
   // make the backward data call
   cudnnStatus_t status = wrap::cudnnRNNBackwardData(
-      parent_, ToHandle(dnn_handle_) /*handle*/, rnn_desc.handle() /*rnnDesc*/,
-      model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/,
-      output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/,
-      output_backprop_data.opaque() /*dy*/, output_h_desc.handle() /*dhyDesc*/,
+      this, stream, ToHandle(dnn_handle_) /*handle*/,
+      rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
+      output_desc.handles() /*yDesc*/, output_data.opaque() /*y*/,
+      output_desc.handles() /*dyDesc*/, output_backprop_data.opaque() /*dy*/,
+      output_h_desc.handle() /*dhyDesc*/,
       output_h_backprop_data.opaque() /*dhy*/,
       output_c_desc.handle() /*dcyDesc*/,
       output_c_backprop_data.opaque() /*dcy*/,
@@ -1873,7 +1934,7 @@ bool CudnnSupport::DoRnnBackwardImpl(
     stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
     // make the backward weight call
     status = wrap::cudnnRNNBackwardWeights(
-        parent_, ToHandle(dnn_handle_) /*handle*/,
+        this, stream, ToHandle(dnn_handle_) /*handle*/,
         rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
         input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
         input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
@@ -2517,8 +2578,7 @@ bool CudnnSupport::DoConvolveImpl(
                                    GetConvComputeType<T>()};
 
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
   }
@@ -2668,7 +2728,7 @@ bool CudnnSupport::DoConvolveImpl(
     }
   }
   status = wrap::cudnnConvolutionForward(
-      parent_, ToHandle(dnn_handle_),
+      this, stream, ToHandle(dnn_handle_),
       /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
       /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
       /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
@@ -2737,8 +2797,7 @@ bool CudnnSupport::DoFusedConvolveImpl(
       static_cast<cudnnDataType_t>(cudnn_compute_type)};
 
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   CHECK(status == CUDNN_STATUS_SUCCESS)
       << "failed to set stream for cudnn handle: " << ToString(status);
 
@@ -2804,7 +2863,7 @@ bool CudnnSupport::DoFusedConvolveImpl(
           << "\noutput_data->opaque() = " << output_data->opaque();
 
   status = wrap::cudnnConvolutionBiasActivationForward(
-      parent_, ToHandle(dnn_handle_), /*alpha1=*/&conv_input_scale,
+      this, stream, ToHandle(dnn_handle_), /*alpha1=*/&conv_input_scale,
       /*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
       /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
       /*convDesc=*/conv.handle(), algo, /*workSpace=*/scratch.opaque(),
@@ -3009,8 +3068,7 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
     bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var,
     std::function<void()> inv_var_to_var) {
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
     return false;
@@ -3046,7 +3104,7 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
     }
 
     status = wrap::cudnnBatchNormalizationForwardTraining(
-        parent_, ToHandle(dnn_handle_), mode, &one, &zero,
+        this, stream, ToHandle(dnn_handle_), mode, &one, &zero,
         x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(),
         scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(), 1.0,
         batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(),
@@ -3063,7 +3121,7 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
     const void* maybe_inv_var = estimated_variance.opaque();
 #endif
     status = wrap::cudnnBatchNormalizationForwardInference(
-        parent_, ToHandle(dnn_handle_), mode, &one, &zero,
+        this, stream, ToHandle(dnn_handle_), mode, &one, &zero,
         x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(),
         scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(),
         estimated_mean.opaque(), maybe_inv_var, epsilon);
@@ -3114,8 +3172,7 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl(
     DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
     DeviceMemory<U>* offset_backprop) {
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
     return false;
@@ -3136,7 +3193,7 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl(
   float zero = 0.0;
 
   status = wrap::cudnnBatchNormalizationBackward(
-      parent_, ToHandle(dnn_handle_), mode, &one, &zero, &one, &zero,
+      this, stream, ToHandle(dnn_handle_), mode, &one, &zero, &one, &zero,
       x_descriptor.handle(), x.opaque(), x_descriptor.handle(),
       y_backprop.opaque(), x_descriptor.handle(), x_backprop->opaque(),
       scale_offset_descriptor.handle(), scale.opaque(),
@@ -3326,7 +3383,7 @@ DeviceMemory<T> CudnnSupport::MaybeTransformLayout(
   float alpha = 1.0f;
   float beta = 0.0f;
   auto status = wrap::cudnnTransformTensor(
-      parent_, ToHandle(dnn_handle_), &alpha, orig_out_back_nd.handle(),
+      this, stream, ToHandle(dnn_handle_), &alpha, orig_out_back_nd.handle(),
       backward_output_data.opaque(), &beta, transformed_out_back_nd.handle(),
       (*transform_scratch)->mutable_device_memory()->opaque());
 
@@ -3345,8 +3402,7 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
                                      dnn::DataType output_type, float scale,
                                      DeviceMemoryBase* output_data) {
   mutex_lock lock{dnn_handle_mutex_};
-  cudnnStatus_t status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                              AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
   }
@@ -3357,7 +3413,7 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
   ScopedTensorDescriptor output_tensor_desc(
       parent_, output_desc, ToCudnnDataType(output_type, output_desc.layout()));
   status = wrap::cudnnTransformTensor(
-      parent_, ToHandle(dnn_handle_), &scale, input_tensor_desc.handle(),
+      this, stream, ToHandle(dnn_handle_), &scale, input_tensor_desc.handle(),
       input_data.opaque(), &beta, output_tensor_desc.handle(),
       output_data->opaque());
   if (status != CUDNN_STATUS_SUCCESS) {
@@ -3384,8 +3440,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
   }
@@ -3554,7 +3609,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
 #else
   status = wrap::cudnnConvolutionBackwardData_v3(
 #endif
-      parent_, ToHandle(dnn_handle_),
+      this, stream, ToHandle(dnn_handle_),
       /*alpha=*/alpha,
       /*filterDesc=*/filter.handle(),
       /*filterData=*/filter_data.opaque(),
@@ -3655,8 +3710,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
   }
@@ -3826,7 +3880,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
 #else
   status = wrap::cudnnConvolutionBackwardFilter_v3(
 #endif
-      parent_, ToHandle(dnn_handle_), /*alpha=*/alpha,
+      this, stream, ToHandle(dnn_handle_), /*alpha=*/alpha,
       /*srcDesc=*/input_nd.handle(),
       /*srcData=*/input_data.opaque(),
       /*diffDesc=*/out_back_nd.handle(),
@@ -3922,8 +3976,7 @@ bool CudnnSupport::DoConvolveBackwardBiasImpl(
     const dnn::BatchDescriptor& bias_descriptor,
     DeviceMemory<T>* backward_bias_data) {
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
   }
@@ -3938,7 +3991,7 @@ bool CudnnSupport::DoConvolveBackwardBiasImpl(
   float beta = 0.0;
 
   status = wrap::cudnnConvolutionBackwardBias(
-      parent_, ToHandle(dnn_handle_), &alpha, input_nd.handle(),
+      this, stream, ToHandle(dnn_handle_), &alpha, input_nd.handle(),
       input_data.opaque(), &beta, bias_nd.handle(),
       backward_bias_data->opaque());
   if (status != CUDNN_STATUS_SUCCESS) {
@@ -4143,8 +4196,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
   }
 
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
     return false;
@@ -4158,7 +4210,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
 #else
   status = wrap::cudnnAddTensor_v3(
 #endif
-      parent_, ToHandle(dnn_handle_), &alpha, bias_descriptor.handle(),
+      this, stream, ToHandle(dnn_handle_), &alpha, bias_descriptor.handle(),
       biases.opaque(), &beta, input_descriptor.handle(), output_data->opaque());
 
   if (status != CUDNN_STATUS_SUCCESS) {
@@ -4176,8 +4228,7 @@ bool CudnnSupport::DoActivate(Stream* stream,
                               DeviceMemory<float>* output_data,
                               uint64 options) {
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
     return false;
@@ -4221,7 +4272,7 @@ bool CudnnSupport::DoActivate(Stream* stream,
   // Beta is the output scaling factor.
   float beta = 0.0;
   status = wrap::cudnnActivationForward(
-      parent_, ToHandle(dnn_handle_),
+      this, stream, ToHandle(dnn_handle_),
 #if CUDNN_VERSION >= 5000
       activation_desc.handle(),
 #else
@@ -4245,8 +4296,7 @@ bool CudnnSupport::DoPoolForward(
     const dnn::BatchDescriptor& output_dimensions,
     DeviceMemory<double>* output_data) {
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
     return false;
@@ -4262,7 +4312,7 @@ bool CudnnSupport::DoPoolForward(
                                    CUDNN_DATA_DOUBLE};
   ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
   status = wrap::cudnnPoolingForward(
-      parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
+      this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
       src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
       output_data->opaque());
   if (status != CUDNN_STATUS_SUCCESS) {
@@ -4280,8 +4330,7 @@ bool CudnnSupport::DoPoolForward(
     const dnn::BatchDescriptor& output_dimensions,
     DeviceMemory<float>* output_data) {
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
     return false;
@@ -4297,7 +4346,7 @@ bool CudnnSupport::DoPoolForward(
                                    CUDNN_DATA_FLOAT};
   ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
   status = wrap::cudnnPoolingForward(
-      parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
+      this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
       src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
       output_data->opaque());
   if (status != CUDNN_STATUS_SUCCESS) {
@@ -4315,8 +4364,7 @@ bool CudnnSupport::DoPoolForward(
     const dnn::BatchDescriptor& output_dimensions,
     DeviceMemory<Eigen::half>* output_data) {
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
     return false;
@@ -4331,7 +4379,7 @@ bool CudnnSupport::DoPoolForward(
   ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
   ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
   status = wrap::cudnnPoolingForward(
-      parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
+      this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
       src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
       output_data->opaque());
   if (status != CUDNN_STATUS_SUCCESS) {
@@ -4351,8 +4399,7 @@ bool CudnnSupport::DoPoolBackward(
     const DeviceMemory<double>& input_diff_data,
     DeviceMemory<double>* output_diff_data) {
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
     return false;
@@ -4368,7 +4415,7 @@ bool CudnnSupport::DoPoolBackward(
                                    CUDNN_DATA_DOUBLE};
   ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
   status = wrap::cudnnPoolingBackward(
-      parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
+      this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
       dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
       input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
       src_desc.handle(), output_diff_data->opaque());
@@ -4389,8 +4436,7 @@ bool CudnnSupport::DoPoolBackward(
     const DeviceMemory<float>& input_diff_data,
     DeviceMemory<float>* output_diff_data) {
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
     return false;
@@ -4406,7 +4452,7 @@ bool CudnnSupport::DoPoolBackward(
                                    CUDNN_DATA_FLOAT};
   ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
   status = wrap::cudnnPoolingBackward(
-      parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
+      this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
       dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
       input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
       src_desc.handle(), output_diff_data->opaque());
@@ -4427,8 +4473,7 @@ bool CudnnSupport::DoPoolBackward(
     const DeviceMemory<Eigen::half>& input_diff_data,
     DeviceMemory<Eigen::half>* output_diff_data) {
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
     return false;
@@ -4443,7 +4488,7 @@ bool CudnnSupport::DoPoolBackward(
   ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
   ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
   status = wrap::cudnnPoolingBackward(
-      parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
+      this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
       dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
       input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
       src_desc.handle(), output_diff_data->opaque());
@@ -4478,8 +4523,7 @@ bool CudnnSupport::DoNormalizeWithDimensions(
 
   // Launch the normalization.
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
     return false;
@@ -4494,7 +4538,7 @@ bool CudnnSupport::DoNormalizeWithDimensions(
   float beta = 0.0f;
 
   status = wrap::cudnnLRNCrossChannelForward(
-      parent_, ToHandle(dnn_handle_), normalize.handle(),
+      this, stream, ToHandle(dnn_handle_), normalize.handle(),
       CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dims.handle(), input_data.opaque(),
       &beta, dims.handle(), output_data->opaque());
   if (status != CUDNN_STATUS_SUCCESS) {
@@ -4521,8 +4565,7 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
   }
 
   mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
-                                     AsCUDAStreamValue(stream));
+  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
     return false;
@@ -4535,7 +4578,7 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
   float beta = 0.0f;
 
   status = wrap::cudnnLRNCrossChannelBackward(
-      parent_, ToHandle(dnn_handle_), normalize.handle(),
+      this, stream, ToHandle(dnn_handle_), normalize.handle(),
       CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dims.handle(),
       normalized_data.opaque(), dims.handle(),
       normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
index 0e5368a..7518b23 100644 (file)
@@ -625,10 +625,27 @@ class CudnnSupport : public dnn::DnnSupport {
                          dnn::DataType output_type, float scale,
                          DeviceMemoryBase* output_data) override;
 
- private:
-  // Guards the enqueueing of DNN operations via the dnn_handle_ below.
+  const Stream* GetCurrentDnnStream() const
+      SHARED_LOCKS_REQUIRED(dnn_handle_mutex_) {
+    return current_dnn_stream_;
+  }
+
+  void SetCurrentDnnStream(Stream* stream)
+      EXCLUSIVE_LOCKS_REQUIRED(dnn_handle_mutex_) {
+    current_dnn_stream_ = stream;
+  }
+
+  CUDAExecutor* GetParentExecutor() { return parent_; }
+
+  // Guards the enqueueing of DNN operations via the dnn_handle_ below, and
+  // access to current_dnn_stream_.
+  //
+  // This is a public member because we need to add thread safty annotations in
+  // the cudnn wrapper functions in the cc file, which need to access this
+  // mutex (the annotations require C++ permission checks).
   mutex dnn_handle_mutex_;
 
+ private:
   CUDAExecutor* parent_;  // Parent executor object. Not owned.
 
   // cudnn library handle. cudnnHandle_t type is not present in this header to
@@ -636,6 +653,9 @@ class CudnnSupport : public dnn::DnnSupport {
   // single cuda_dnn translation unit.
   void* dnn_handle_ GUARDED_BY(dnn_handle_mutex_);
 
+  // The current cudnn stream that is set by cudnnSetStream().
+  Stream* current_dnn_stream_ GUARDED_BY(dnn_handle_mutex_);
+
   // NOTE(keveman): Temporary data layout transformation until cuDNN supports
   // kBatchYXDepth for backward pass. This function allocates temporary memory,
   // lays out the source data into the temporary but in the kBatchDepthXY