dnn::DataType output_type, float scale,
DeviceMemoryBase* output_data) {
mutex_lock lock{dnn_handle_mutex_};
+ cudnnStatus_t status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+ AsCUDAStreamValue(stream));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
+ }
+
float beta = 0.0f;
ScopedTensorDescriptor input_tensor_desc(
parent_, input_desc, ToCudnnDataType(input_type, input_desc.layout()));
ScopedTensorDescriptor output_tensor_desc(
parent_, output_desc, ToCudnnDataType(output_type, output_desc.layout()));
- cudnnStatus_t status = wrap::cudnnTransformTensor(
+ status = wrap::cudnnTransformTensor(
parent_, ToHandle(dnn_handle_), &scale, input_tensor_desc.handle(),
input_data.opaque(), &beta, output_tensor_desc.handle(),
output_data->opaque());