Set the stream in TransformTensor.
authorGuangda Lai <laigd@google.com>
Sat, 24 Mar 2018 00:28:10 +0000 (17:28 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 25 Mar 2018 11:45:46 +0000 (04:45 -0700)
PiperOrigin-RevId: 190300166

tensorflow/stream_executor/cuda/cuda_dnn.cc

index 03e3e08..ab5e659 100644 (file)
@@ -3157,12 +3157,18 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
                                      dnn::DataType output_type, float scale,
                                      DeviceMemoryBase* output_data) {
   mutex_lock lock{dnn_handle_mutex_};
+  cudnnStatus_t status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+                                              AsCUDAStreamValue(stream));
+  if (status != CUDNN_STATUS_SUCCESS) {
+    LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
+  }
+
   float beta = 0.0f;
   ScopedTensorDescriptor input_tensor_desc(
       parent_, input_desc, ToCudnnDataType(input_type, input_desc.layout()));
   ScopedTensorDescriptor output_tensor_desc(
       parent_, output_desc, ToCudnnDataType(output_type, output_desc.layout()));
-  cudnnStatus_t status = wrap::cudnnTransformTensor(
+  status = wrap::cudnnTransformTensor(
       parent_, ToHandle(dnn_handle_), &scale, input_tensor_desc.handle(),
       input_data.opaque(), &beta, output_tensor_desc.handle(),
       output_data->opaque());