From: Guangda Lai Date: Sat, 24 Mar 2018 00:28:10 +0000 (-0700) Subject: Set the stream in TransformTensor. X-Git-Tag: tflite-v0.1.7~106^2^2~19 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d40c53dd2cb7c0e3ec20ca56f5c3c95038820900;p=platform%2Fupstream%2Ftensorflow.git Set the stream in TransformTensor. PiperOrigin-RevId: 190300166 --- diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 03e3e08..ab5e659 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -3157,12 +3157,18 @@ bool CudnnSupport::DoTransformTensor(Stream* stream, dnn::DataType output_type, float scale, DeviceMemoryBase* output_data) { mutex_lock lock{dnn_handle_mutex_}; + cudnnStatus_t status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_), + AsCUDAStreamValue(stream)); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status); + } + float beta = 0.0f; ScopedTensorDescriptor input_tensor_desc( parent_, input_desc, ToCudnnDataType(input_type, input_desc.layout())); ScopedTensorDescriptor output_tensor_desc( parent_, output_desc, ToCudnnDataType(output_type, output_desc.layout())); - cudnnStatus_t status = wrap::cudnnTransformTensor( + status = wrap::cudnnTransformTensor( parent_, ToHandle(dnn_handle_), &scale, input_tensor_desc.handle(), input_data.opaque(), &beta, output_tensor_desc.handle(), output_data->opaque());