case dnn::FilterLayout::kOutputInputYX:
format = CUDNN_TENSOR_NCHW;
break;
+ case dnn::FilterLayout::kOutputYXInput:
+ format = CUDNN_TENSOR_NHWC;
+ break;
case dnn::FilterLayout::kOutputInputYX4:
format = CUDNN_TENSOR_NCHW_VECT_C;
break;
output_profile_result);
}
-namespace {
-// NOTE(keveman): Temporary data layout transformation until cuDNN supports
-// kBatchYXDepth for backward pass. This function allocates temporary memory,
-// lays out the source data into the temporary but in the kBatchDepthYX
-// layout, and returns the temporary memory. The caller is responsible for
-// deallocating the temporary. Since the allocation is done using Stream's
-// AllocateTemporaryArray, a later BlockHostUntilDone could be used for
-// deallocation.
-//
-// If output_descriptor is already kBatchDepthYX, backward_output_data is
-// returned as-is and nothing is allocated. Any layout other than
-// kBatchDepthYX or kBatchYXDepth fails the CHECK below. After the transform,
-// *output_descriptor is updated in place to kBatchDepthYX and the scratch
-// memory is returned; a cudnnTransformTensor failure is fatal (LOG(FATAL)).
-//
-// transform_scratch is populated with a legitimate temporary allocation iff
-// the original output data needs to be transformed.
-template <class T>
-DeviceMemory<T> MaybeTransformLayout(
- Stream* stream, const CudnnHandle& cudnn,
- dnn::BatchDescriptor* output_descriptor,
- DeviceMemory<T> backward_output_data,
- std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch) {
- if (output_descriptor->layout() == dnn::DataLayout::kBatchDepthYX) {
- return backward_output_data;
- }
- CHECK(output_descriptor->layout() == dnn::DataLayout::kBatchYXDepth);
- *transform_scratch =
- stream->AllocateTemporaryArray<T>(backward_output_data.ElementCount())
- .ConsumeValueOrDie();
- dnn::BatchDescriptor transformed_output_descriptor;
- transformed_output_descriptor.CloneFrom(*output_descriptor);
- transformed_output_descriptor.set_layout(dnn::DataLayout::kBatchDepthYX);
- cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
- ScopedTensorDescriptor orig_out_back_nd(*output_descriptor, cudnn_type);
- ScopedTensorDescriptor transformed_out_back_nd(transformed_output_descriptor,
- cudnn_type);
-
- float alpha = 1.0f;
- float beta = 0.0f;
- auto status = cudnnTransformTensor(
- cudnn.handle(), &alpha, orig_out_back_nd.handle(),
- backward_output_data.opaque(), &beta, transformed_out_back_nd.handle(),
- (*transform_scratch)->mutable_device_memory()->opaque());
-
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "Failed to transform the data layout.";
- }
- output_descriptor->set_layout(dnn::DataLayout::kBatchDepthYX);
- return (*transform_scratch)->device_memory();
-}
-} // namespace
-
bool CudnnSupport::DoTransformTensor(Stream* stream,
const dnn::BatchDescriptor& input_desc,
dnn::DataType input_type,
bool CudnnSupport::DoConvolveBackwardDataImpl(
Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<T>& filter_data,
- const dnn::BatchDescriptor& output_descriptor_in,
+ const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<T> backward_output_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& input_descriptor,
auto cudnn = cudnn_->GetHandle(parent_, stream);
- // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
- dnn::BatchDescriptor output_descriptor;
- output_descriptor.CloneFrom(output_descriptor_in);
- std::unique_ptr<TemporaryDeviceMemory<T>> transform_scratch;
- backward_output_data =
- MaybeTransformLayout(stream, cudnn, &output_descriptor,
- backward_output_data, &transform_scratch);
-
ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
ScopedTensorDescriptor in_back_nd(input_descriptor, cudnn_type);
ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
bool CudnnSupport::DoConvolveBackwardFilterImpl(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
- const dnn::BatchDescriptor& output_descriptor_in,
+ const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<T> backward_output_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
auto cudnn = cudnn_->GetHandle(parent_, stream);
- // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
- dnn::BatchDescriptor output_descriptor;
- output_descriptor.CloneFrom(output_descriptor_in);
- std::unique_ptr<TemporaryDeviceMemory<T>> transform_scratch;
- backward_output_data =
- MaybeTransformLayout(stream, cudnn, &output_descriptor,
- backward_output_data, &transform_scratch);
-
ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);