ROCm MIOpen NHWC Convolution support (#63617)
authorAswin John Mathews <Aswin.Mathews@amd.com>
Fri, 10 Sep 2021 15:05:21 +0000 (08:05 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 15:06:32 +0000 (08:06 -0700)
Summary:
- Added 2D-Convolution NHWC support
  - on ROCm 4.3, with `PYTORCH_MIOPEN_SUGGEST_NHWC=1` flag
  - May need to force MIOpen to search for solutions ( see examples below for flags )

**PYTORCH_MIOPEN_SUGGEST_NHWC Environment Flag**
MIOpen does not officially support NHWC yet, although convolution support has been added to tip-of-tree of MIOpen. This flag is intended to be a short-lived flag to explicitly turn on NHWC support until ROCm officially supports NHWC and performance is verified.

**Examples**
1. Example usage 1 : Run test on ROCm4.3
`PYTORCH_TEST_WITH_ROCM=1 PYTORCH_MIOPEN_SUGGEST_NHWC=1 MIOPEN_FIND_ENFORCE=4 MIOPEN_DEBUG_CONV_GEMM=0 MIOPEN_FIND_MODE=1 pytest test_nn.py -v -k "test_conv_cudnn_nhwc" `
2. Example usage 2: Run the following with `PYTORCH_MIOPEN_SUGGEST_NHWC=1` on ROCm4.3.
```
#!/usr/bin/env python3
import torch
model = torch.nn.Conv2d(8, 4, 3).cuda().half()
model = model.to(memory_format=torch.channels_last)
input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)

# should print True for is_contiguous(channels_last), and strides must match NHWC format
print(input.is_contiguous(memory_format=torch.channels_last), input.shape, input.stride() )

out = model(input)

# should print True for is_contiguous(channels_last), and strides must match NHWC format
print("Contiguous channel last :", out.is_contiguous(memory_format=torch.channels_last), " out shape :",  out.shape, "out stride :", out.stride() )
```

See https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html for more examples.

cc jeffdaily sunway513 jithunnair-amd ROCmSupport

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63617

Reviewed By: saketh-are

Differential Revision: D30730800

Pulled By: ezyang

fbshipit-source-id: 61906a0f30be8299e6547d312ae6ac91cc7c3238

aten/src/ATen/miopen/Descriptors.cpp
aten/src/ATen/miopen/Descriptors.h
aten/src/ATen/native/ConvUtils.h
aten/src/ATen/native/Convolution.cpp
aten/src/ATen/native/miopen/Conv_miopen.cpp
c10/util/env.h
test/test_nn.py
torch/testing/_internal/common_device_type.py
torch/testing/_internal/common_utils.py

index 3887519..6911b1a 100644 (file)
@@ -90,7 +90,7 @@ std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
 
 void TensorDescriptor::print() { std::cout << *this; }
 
-void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
+void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad) {
   auto dim = t.ndimension();
   if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
 #define _STR(X) #X
@@ -98,9 +98,9 @@ void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
     throw std::runtime_error("MIOpen supports only up to " STR(MIOPEN_DIM_MAX) " dimensions");
 #undef _STR
 #undef STR
-  if (!t.is_contiguous()) {
-    throw std::runtime_error("MIOpen filters (a.k.a. weights) must be contiguous");
-  }
+  TORCH_CHECK(t.is_contiguous(memory_format),
+      "MIOpen filters (a.k.a. weights) must be contiguous");
+
   int size[MIOPEN_DIM_MAX];
   int stride[MIOPEN_DIM_MAX];
   for (int i = 0; i < dim; ++i) {
@@ -109,9 +109,15 @@ void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
   for (int i = dim; i < pad; ++i) {
     size[i] = (int) 1;
   }
-  for (int i = dim - 1; i >=0; --i) {
-    stride[i] = (i == dim - 1) ? 1 : stride[i+1] * size[i+1];
+
+  for (int i = pad; i >= dim; --i ) {
+      stride[i] = 1;
   }
+  for (int i = dim-1 ; i >=0; --i ) {
+      // Pass-through
+      stride[i] = t.stride(i);
+  }
+
   dim = std::max(dim, pad);
   set(getDataType(t), (int) dim, size, stride);
 }
index 69304a3..a376b30 100644 (file)
@@ -18,20 +18,6 @@ inline int dataSize(miopenDataType_t dataType)
   }
 }
 
-// This function modifies 'stride' in place so that the stride for
-// dim i is the product of the sizes of dims i+1 to the end.
-static inline void fixSizeOneDimStride(int dim, const int *size, int *stride) {
-  int64_t z = 1;
-  for(int d = dim-1; d >= 0; d--)
-  {
-    if (size[d] == 1) {
-      stride[d] = z;
-    } else {
-      z *= size[d];
-    }
-  }
-}
-
 template <typename T, miopenStatus_t (*dtor)(T*)>
 struct DescriptorDeleter {
   void operator()(T* x) {
@@ -96,7 +82,6 @@ public:
 
 private:
   void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
-    fixSizeOneDimStride(dim, size, stride);
     MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
   }
 };
@@ -108,12 +93,15 @@ class FilterDescriptor
                       &miopenCreateTensorDescriptor,
                       &miopenDestroyTensorDescriptor>
 {
-public:
-  void set(const at::Tensor &t, int64_t pad = 0);
+ public:
+  void set(const at::Tensor &t, int64_t pad = 0) {
+    set(t, at::MemoryFormat::Contiguous, pad);
+  }
+
+  void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);
 
 private:
   void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
-    fixSizeOneDimStride(dim, size, stride);
     MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
   }
 };
index 191c016..9f0abaf 100644 (file)
@@ -1,5 +1,6 @@
 #pragma once
 #include <ATen/detail/CUDAHooksInterface.h>
+#include <c10/util/env.h>
 
 namespace at { namespace native {
 
@@ -106,4 +107,33 @@ static inline bool cudnn_conv_use_channels_last(const at::Tensor& input, const a
   return can_use_cudnn_channels_last_2d || can_use_cudnn_channels_last_3d;
 }
 
+static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
+
+  // disable NHWC for float64 input.
+  if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
+      input.scalar_type() == at::kDouble ||
+      weight.scalar_type() == at::kDouble) {
+    return false;
+  }
+
+  bool can_use_miopen_channels_last_2d = false;
+#if defined(USE_ROCM) && (ROCM_VERSION >= 40300)
+  // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
+  // See #64427
+  static c10::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
+
+  auto input_memory_format = input.suggest_memory_format();
+  auto weight_memory_format = weight.suggest_memory_format();
+
+  can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC &&  *PYTORCH_MIOPEN_SUGGEST_NHWC && (
+            ( (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
+            (weight_memory_format == at::MemoryFormat::ChannelsLast) )
+        );
+#endif
+
+  bool can_use_miopen_channels_last_3d = false;
+
+  return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
+}
+
 }} // namespace at::native
index 4a4d83f..78eb889 100644 (file)
@@ -838,9 +838,14 @@ at::Tensor _convolution(
     weight = view4d(weight);
   }
 
-  at::MemoryFormat cudnn_memory_format = at::MemoryFormat::Contiguous;
-  if (cudnn_conv_use_channels_last(input, weight)) {
-    cudnn_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
+  at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
+
+  if (detail::getCUDAHooks().compiledWithCuDNN() && cudnn_conv_use_channels_last(input, weight)) {
+    backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
+  }
+
+  if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
+    backend_memory_format = (k == 5) ? at::MemoryFormat::Contiguous /*at::MemoryFormat::ChannelsLast3d*/ : at::MemoryFormat::ChannelsLast;
   }
 
   Tensor output;
@@ -853,7 +858,7 @@ at::Tensor _convolution(
       auto dilation = params.dilation;
       if (params.use_cudnn_depthwise(input, weight)) {
         output = at::cudnn_convolution(
-            input.contiguous(cudnn_memory_format), weight,
+            input.contiguous(backend_memory_format), weight,
             padding, stride, dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
         if (bias.defined()) {
           output.add_(reshape_bias(input.dim(), bias));
@@ -861,7 +866,7 @@ at::Tensor _convolution(
 
       } else if (params.use_miopen(input, weight, bias.defined())){
         output = at::miopen_depthwise_convolution(
-            input.contiguous(), weight, bias,
+            input.contiguous(backend_memory_format), weight, bias,
             padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
       } else {
           if (input.ndimension() == 4) {
@@ -882,14 +887,14 @@ at::Tensor _convolution(
 
     if (params.transposed) {
       output = at::cudnn_convolution_transpose(
-          input.contiguous(cudnn_memory_format), weight,
+          input.contiguous(backend_memory_format), weight,
           params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
       if (bias.defined()) {
         output.add_(reshape_bias(input.dim(), bias));
       }
     } else {
       output = at::cudnn_convolution(
-          input.contiguous(cudnn_memory_format), weight,
+          input.contiguous(backend_memory_format), weight,
           params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
       if (bias.defined()) {
         output.add_(reshape_bias(input.dim(), bias));
@@ -905,11 +910,11 @@ at::Tensor _convolution(
 
     if (params.transposed) {
       output = at::miopen_convolution_transpose(
-          input.contiguous(), weight, bias,
+          input.contiguous(backend_memory_format), weight, bias,
           params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
     } else {
       output = at::miopen_convolution(
-          input.contiguous(), weight, bias,
+          input.contiguous(backend_memory_format), weight, bias,
           params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
     }
   } else if (params.use_mkldnn(input, weight)) {
index b6ffa91..39f214a 100644 (file)
@@ -524,7 +524,21 @@ void miopen_convolution_add_bias_(CheckedFrom c, const TensorArg& output, const
   checkSize(c, bias, { output->size(output_channels_dim) });
 
   TensorDescriptor bdesc, odesc;
-  bdesc.set(bias->expand({1, bias->size(0)}), output->dim());
+
+  auto memory_format = output->suggest_memory_format();
+
+  std::vector<int64_t> shape( output->dim(), 1);
+  shape[output_channels_dim] = -1;
+  at::Tensor bias_contig =  bias->reshape(shape).contiguous(memory_format);
+  // Make sure that NC11 strides follow formula
+  bias_contig.resize_(bias_contig.sizes(), memory_format );
+
+  // TODO: Workaround since MIOpen does not support NHWC bias
+  // See #64426
+  output->add_( bias_contig );
+
+  /* MIOpen does not support NHWC bias; Activate once support is added.
+  bdesc.set( bias_contig );
   odesc.set(*output);
 
   auto handle = getMiopenHandle();
@@ -534,6 +548,7 @@ void miopen_convolution_add_bias_(CheckedFrom c, const TensorArg& output, const
 
   MIOPEN_CHECK(miopenConvolutionForwardBias(handle, &one, bdesc.desc(), bias->data_ptr(),
                                      &zero, odesc.desc(), output->data_ptr()));
+  */
 }
 
 // see NOTE [ Convolution design ] in src/Aten/native/cudnn/Conv.cpp
@@ -566,7 +581,7 @@ void raw_miopen_convolution_forward_out(
   args.handle = getMiopenHandle();
   setConvolutionParams(&args.params, args.handle, input, weight, padding, stride, dilation, groups, deterministic);
   args.idesc.set(input);
-  args.wdesc.set(weight);
+  args.wdesc.set(weight, input.suggest_memory_format(), 0);
   args.odesc.set(output);
   args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
 
@@ -593,10 +608,19 @@ Tensor miopen_convolution_forward(
   checkAllSameType(c, {input, weight});
   checkAllSameGPU(c, {input, weight});
 
-  auto output_t = at::empty(
+  auto memory_format = at::MemoryFormat::Contiguous;
+  if (miopen_conv_use_channels_last(*input, *weight)) {
+    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+  }
+
+  auto output_t = at::native::empty_cuda(
                     conv_output_size(input->sizes(), weight->sizes(),
                                      padding, stride, dilation),
-                    input->options());
+                    /*dtype=*/input->scalar_type(),
+                    /*layout=*/c10::nullopt,
+                    /*device=*/kCUDA,
+                    /*pin_memory=*/c10::nullopt,
+                    /*memory_format=*/memory_format);
 
   if (output_t.numel() == 0) {
     return output_t;
@@ -607,10 +631,16 @@ Tensor miopen_convolution_forward(
   convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);
 
   // See #4500
-  Tensor weight_contig = weight->contiguous();
+  Tensor weight_contig = weight->contiguous(memory_format);
+  // Make sure that NC11 strides follow formula
+  weight_contig.resize_(weight_contig.sizes(), memory_format);
+  Tensor input_contig = input->contiguous(memory_format);
+  input_contig.resize_(input_contig.sizes(), memory_format);
+
+
 
   raw_miopen_convolution_forward_out(
-      *output, *input, weight_contig,
+      *output, input_contig, weight_contig,
       padding, stride, dilation, groups, benchmark, deterministic);
 
   return *output;
@@ -650,7 +680,7 @@ void raw_miopen_depthwise_convolution_forward_out(
   args.handle = getMiopenHandle();
   setConvolutionParams(&args.params, args.handle, input, weight, padding, stride, dilation, groups, deterministic);
   args.idesc.set(input);
-  args.wdesc.set(weight);
+  args.wdesc.set(weight, input.suggest_memory_format(), 0);
   args.odesc.set(output);
   args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
 
@@ -677,18 +707,32 @@ Tensor miopen_depthwise_convolution_forward(
   checkAllSameType(c, {input, weight});
   checkAllSameGPU(c, {input, weight});
 
-  auto output_t = at::empty(
+  auto memory_format = at::MemoryFormat::Contiguous;
+  if (miopen_conv_use_channels_last(*input, *weight)) {
+    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+  }
+
+  auto output_t = at::native::empty_cuda(
                     conv_output_size(input->sizes(), weight->sizes(),
                                      padding, stride, dilation),
-                    input->options());
+                    /*dtype=*/input->scalar_type(),
+                    /*layout=*/c10::nullopt,
+                    /*device=*/kCUDA,
+                    /*pin_memory=*/c10::nullopt,
+                    /*memory_format=*/memory_format);
 
   TensorArg output{ output_t, "result", 0 };
   convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);
 
-  Tensor weight_contig = weight->contiguous();
+  // See #4500
+  Tensor weight_contig = weight->contiguous(memory_format);
+  // Make sure that NC11 strides follow formula
+  weight_contig.resize_(weight_contig.sizes(), memory_format);
+  Tensor input_contig = input->contiguous(memory_format);
+  input_contig.resize_(input_contig.sizes(), memory_format);
 
   raw_miopen_depthwise_convolution_forward_out(
-      *output, *input, weight_contig,
+      *output, input_contig, weight_contig,
       padding, stride, dilation, groups, benchmark, deterministic);
 
   return *output;
@@ -768,7 +812,7 @@ void raw_miopen_convolution_backward_input_out(
   args.handle = getMiopenHandle();
   setConvolutionParams(&args.params, args.handle, grad_input, weight, padding, stride, dilation, groups, deterministic);
   args.idesc.set(grad_input);
-  args.wdesc.set(weight);
+  args.wdesc.set(weight, grad_output.suggest_memory_format(), 0);
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, c_mode, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
 
@@ -797,17 +841,33 @@ Tensor miopen_convolution_backward_input(
   checkAllSameType(c, {grad_output, weight});
   checkAllSameGPU(c, {grad_output, weight});
 
-  auto grad_input_t = at::empty(input_size, grad_output->options());
+  auto memory_format = at::MemoryFormat::Contiguous;
+  if (miopen_conv_use_channels_last(*grad_output, *weight)) {
+    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+  }
+
+  auto grad_input_t = at::native::empty_cuda(
+                    input_size,
+                    /*dtype=*/grad_output->scalar_type(),
+                    /*layout=*/c10::nullopt,
+                    /*device=*/kCUDA,
+                    /*pin_memory=*/c10::nullopt,
+                    /*memory_format=*/memory_format);
 
   // Avoid "grad_input" when this is being used as transposed convolution
   TensorArg grad_input{ grad_input_t, "result", 0 };
   convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);
 
   // See #4500
-  Tensor weight_contig = weight->contiguous();
+  Tensor weight_contig = weight->contiguous(memory_format);
+  // Make sure that NC11 strides follow formula
+  weight_contig.resize_(weight_contig.sizes(), memory_format);
+
+  Tensor grad_output_contig = grad_output->contiguous(memory_format);
+  grad_output_contig.resize_(grad_output_contig.sizes(), memory_format);
 
   raw_miopen_convolution_backward_input_out(
-      *grad_input, *grad_output, weight_contig,
+      *grad_input, grad_output_contig, weight_contig,
       padding, stride, dilation, groups, benchmark, deterministic);
 
   return *grad_input;
@@ -853,7 +913,7 @@ void raw_miopen_depthwise_convolution_backward_input_out(
   args.handle = getMiopenHandle();
   setConvolutionParams(&args.params, args.handle, grad_input, weight, padding, stride, dilation, groups, deterministic);
   args.idesc.set(grad_input);
-  args.wdesc.set(weight);
+  args.wdesc.set(weight, grad_output.suggest_memory_format(), 0);
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, c_mode, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
 
@@ -880,15 +940,32 @@ Tensor miopen_depthwise_convolution_backward_input(
   checkAllSameType(c, {grad_output, weight});
   checkAllSameGPU(c, {grad_output, weight});
 
-  auto grad_input_t = at::empty(input_size, grad_output->options());
+  auto memory_format = at::MemoryFormat::Contiguous;
+  if (miopen_conv_use_channels_last(*grad_output, *weight)) {
+    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+  }
+
+  auto grad_input_t = at::native::empty_cuda(
+                    input_size,
+                    /*dtype=*/grad_output->scalar_type(),
+                    /*layout=*/c10::nullopt,
+                    /*device=*/kCUDA,
+                    /*pin_memory=*/c10::nullopt,
+                    /*memory_format=*/memory_format);
 
   TensorArg grad_input{ grad_input_t, "result", 0 };
   convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);
 
-  Tensor weight_contig = weight->contiguous();
+  // See #4500
+  Tensor weight_contig = weight->contiguous(memory_format);
+  // Make sure that NC11 strides follow formula
+  weight_contig.resize_(weight_contig.sizes(), memory_format);
+
+  Tensor grad_output_contig = grad_output->contiguous(memory_format);
+  grad_output_contig.resize_(grad_output_contig.sizes(), memory_format);
 
   raw_miopen_depthwise_convolution_backward_input_out(
-      *grad_input, *grad_output, weight_contig,
+      *grad_input, grad_output_contig, weight_contig,
       padding, stride, dilation, groups, benchmark, deterministic);
 
   return *grad_input;
@@ -912,7 +989,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_convolution_backward(
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
     bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
 
-  Tensor grad_output = grad_output_t.contiguous();
+  Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
 
   Tensor grad_input, grad_weight, grad_bias;
   if (output_mask[0]) {
@@ -988,7 +1065,7 @@ void raw_miopen_convolution_backward_weight_out(
   args.handle = getMiopenHandle();
   setConvolutionParams(&args.params, args.handle, input, grad_weight, padding, stride, dilation, groups, deterministic);
   args.idesc.set(input);
-  args.wdesc.set(grad_weight);
+  args.wdesc.set(grad_weight, input.suggest_memory_format(), 0);
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
 
@@ -1016,15 +1093,29 @@ Tensor miopen_convolution_backward_weight(
   checkAllSameType(c, {grad_output, input});
   checkAllSameGPU(c, {grad_output, input});
 
-  auto grad_weight_t = at::empty(weight_size, grad_output->options());
+  auto memory_format = at::MemoryFormat::Contiguous;
+  if (miopen_conv_use_channels_last(*input, *grad_output)) {
+    memory_format = (input->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+  }
+
+  Tensor grad_output_contig_t = grad_output->contiguous(memory_format);
+  // Make sure that NC11 strides follow formula
+  grad_output_contig_t.resize_(grad_output_contig_t.sizes(), memory_format);
+  TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 };
+
+  Tensor input_contig_t = input->contiguous(memory_format);
+  input_contig_t.resize_(input_contig_t.sizes(), memory_format);
+  TensorArg input_contig{ input_contig_t, "input", 2};
+
+  auto grad_weight_t = at::empty(weight_size, grad_output_contig->options(), memory_format);
 
   // For uniformity with everything else, although it seems grad_weight
   // would be unambiguous too.
   TensorArg grad_weight{ grad_weight_t, "result", 0 };
-  convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups);
+  convolution_shape_check(c, input, grad_weight, grad_output_contig, padding, stride, dilation, groups);
 
   raw_miopen_convolution_backward_weight_out(
-      *grad_weight, *grad_output, *input,
+      *grad_weight, *grad_output_contig, *input_contig,
       padding, stride, dilation, groups, benchmark, deterministic);
 
   return grad_weight_t;
@@ -1043,7 +1134,7 @@ void raw_miopen_depthwise_convolution_backward_weight_out(
   args.handle = getMiopenHandle();
   setConvolutionParams(&args.params, args.handle, input, grad_weight, padding, stride, dilation, groups, deterministic);
   args.idesc.set(input);
-  args.wdesc.set(grad_weight);
+  args.wdesc.set(grad_weight, input.suggest_memory_format(), 0);
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
 
@@ -1071,15 +1162,29 @@ Tensor miopen_depthwise_convolution_backward_weight(
   checkAllSameType(c, {grad_output, input});
   checkAllSameGPU(c, {grad_output, input});
 
-  auto grad_weight_t = at::empty(weight_size, grad_output->options());
+  auto memory_format = at::MemoryFormat::Contiguous;
+  if (miopen_conv_use_channels_last(*input, *grad_output)) {
+    memory_format = (input->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+  }
+
+  Tensor grad_output_contig_t = grad_output->contiguous(memory_format);
+  // Make sure that NC11 strides follow formula
+  grad_output_contig_t.resize_(grad_output_contig_t.sizes(), memory_format);
+  TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 };
+
+  Tensor input_contig_t = input->contiguous(memory_format);
+  input_contig_t.resize_(input_contig_t.sizes(), memory_format);
+  TensorArg input_contig{ input_contig_t, "input", 2};
+
+  auto grad_weight_t = at::empty(weight_size, grad_output_contig->options(), memory_format);
 
   // For uniformity with everything else, although it seems grad_weight
   // would be unambiguous too.
   TensorArg grad_weight{ grad_weight_t, "result", 0 };
-  convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups);
+  convolution_shape_check(c, input, grad_weight, grad_output_contig, padding, stride, dilation, groups);
 
   raw_miopen_depthwise_convolution_backward_weight_out(
-      *grad_weight, *grad_output, *input,
+      *grad_weight, *grad_output_contig, *input_contig,
       padding, stride, dilation, groups, benchmark, deterministic);
 
   return grad_weight_t;
@@ -1141,6 +1246,25 @@ Tensor miopen_convolution_backward_bias(
 {
   TensorArg grad_output{ grad_output_t, "grad_output", 1 };
 
+  // TODO: Workaround since MIOpen does not support NHWC bias
+  // See #64426
+  std::vector<int64_t> discard_dims;
+  for( int i = 0; i < grad_output_t.dim(); i++ ) {
+      if(i != output_channels_dim ) {
+          discard_dims.push_back(i);
+      }
+  }
+
+  Tensor outputBias = at::squeeze( at::sum(grad_output_t, discard_dims, true) );
+  if( outputBias.dim() == 0 ) {
+      // always return a tensor of shape [_]
+      return outputBias.unsqueeze(0);
+  }
+  else {
+      return outputBias;
+  }
+
+/* MIOpen does not support NHWC bias. Activate once support is added.
   auto grad_bias_t = at::empty( { grad_output->size(output_channels_dim) }, grad_output->options());
 
   TensorArg grad_bias{ grad_bias_t, "result", 0 };
@@ -1157,6 +1281,7 @@ Tensor miopen_convolution_backward_bias(
   MIOPEN_CHECK(miopenConvolutionBackwardBias(handle, &one, odesc.desc(), grad_output->data_ptr(),
                                                    &zero, bdesc.desc(), grad_bias->data_ptr()));
   return *grad_bias;
+*/
 }
 
 
index 4f28a2a..ec48903 100644 (file)
@@ -13,7 +13,7 @@ namespace utils {
 //
 // NB:
 // Issues a warning if the value of the environment variable is not 0 or 1.
-optional<bool> check_env(const char* name) {
+inline optional<bool> check_env(const char* name) {
   auto envar = std::getenv(name);
   if (envar) {
     if (strcmp(envar, "0") == 0) {
index cc702df..1a6416d 100644 (file)
@@ -35,7 +35,7 @@ from torch.nn.parameter import UninitializedParameter, UninitializedBuffer
 from torch.nn.parallel._functions import Broadcast
 from torch.testing._internal.common_dtype import integral_types, get_all_fp_dtypes, get_all_math_dtypes
 from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
-    TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, download_file, \
+    skipIfRocmVersionLessThan, skipIfNotMiopenSuggestNHWC, TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, download_file, \
     get_function_arglist, load_tests, repeat_test_for_types, ALL_TENSORTYPES, \
     ALL_TENSORTYPES2, suppress_warnings, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC
 from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
@@ -44,8 +44,8 @@ from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, Criteri
     ctcloss_reference, new_module_tests, single_batch_reference_fn
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \
     dtypesIfCUDA, precisionOverride, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
-    skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, onlyOnCPUAndCUDA, \
-    deviceCountAtLeast, largeTensorTest, expectedFailureMeta, skipMeta
+    skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, skipCUDAIfRocmVersionLessThan, skipCUDAIfNotMiopenSuggestNHWC, \
+    onlyOnCPUAndCUDA, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, skipMeta
 from torch.nn import MultiheadAttention
 
 from hypothesis import given
@@ -10835,14 +10835,15 @@ class TestNN(NNTestCase):
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
-    @skipIfRocm
+    @skipIfRocmVersionLessThan((4, 3))
+    @skipIfNotMiopenSuggestNHWC
     def test_grouped_conv_cudnn_nhwc_support(self):
         # in order to catch the hols in grouped convolution in nhwc support for earlier cudnn version
         input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last)
         weight = torch.randn((8, 4, 3, 3), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last)
-        out = torch.cudnn_convolution(input, weight, None, (1, 1), (1, 1), (1, 1), 4, False, False)
+        out = torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), False, (0, 0), 4)
         input = torch.randn((16, 8, 8, 8), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last)
-        out = torch.cudnn_convolution_transpose(input, weight, None, (1, 1), (0, 0), (1, 1), (1, 1), 4, False, False)
+        out_transpose = torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), True, (0, 0), 4)
 
     @unittest.expectedFailure
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@@ -16964,7 +16965,8 @@ class TestNNDeviceType(NNTestCase):
             self._test_bfloat16_ops(torch.nn.Softmax(dim=dim), device, inp_dims=(16, 33, 15, 16), prec=0.05, scale_factor=1000.0)
 
     @onlyCUDA
-    @skipCUDAIfRocm
+    @skipCUDAIfRocmVersionLessThan((4, 3))
+    @skipCUDAIfNotMiopenSuggestNHWC
     @skipCUDAIfCudnnVersionLessThan(7603)
     @dtypes(torch.half, torch.float)
     def test_conv_cudnn_nhwc(self, device, dtype):
@@ -17107,7 +17109,8 @@ class TestNNDeviceType(NNTestCase):
                                    ref_out, input_format, w_f, g_f, output_format)
 
     @onlyCUDA
-    @skipCUDAIfRocm
+    @skipCUDAIfRocmVersionLessThan((4, 3))
+    @skipCUDAIfNotMiopenSuggestNHWC
     @skipCUDAIfCudnnVersionLessThan(7603)
     @tf32_on_and_off(0.05)
     def test_conv_cudnn_mismatch_memory_format(self, device):
index 23e431d..971b3a6 100644 (file)
@@ -12,7 +12,8 @@ import os
 import torch
 from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
     skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \
-    IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, DeterministicGuard, TEST_SKIP_NOARCH
+    IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, DeterministicGuard, TEST_SKIP_NOARCH, \
+    TEST_WITH_MIOPEN_SUGGEST_NHWC
 from torch.testing._internal.common_cuda import _get_torch_cuda_version
 from torch.testing._internal.common_dtype import get_all_dtypes
 
@@ -1166,6 +1167,32 @@ def skipCUDAIfRocm(fn):
 def skipCUDAIfNotRocm(fn):
     return skipCUDAIf(not TEST_WITH_ROCM, "test doesn't currently work on the CUDA stack")(fn)
 
+# Skips a test on CUDA if ROCm is unavailable or its version is lower than requested.
+def skipCUDAIfRocmVersionLessThan(version=None):
+
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if self.device_type == 'cuda':
+                if not TEST_WITH_ROCM:
+                    reason = "ROCm not available"
+                    raise unittest.SkipTest(reason)
+                rocm_version = str(torch.version.hip)
+                rocm_version = rocm_version.split("-")[0]    # ignore git sha
+                rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
+                if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
+                    reason = "ROCm {0} is available but {1} required".format(rocm_version_tuple, version)
+                    raise unittest.SkipTest(reason)
+
+            return fn(self, *args, **kwargs)
+
+        return wrap_fn
+    return dec_fn
+
+# Skips a test on CUDA when using ROCm.
+def skipCUDAIfNotMiopenSuggestNHWC(fn):
+    return skipCUDAIf(not TEST_WITH_MIOPEN_SUGGEST_NHWC, "test doesn't currently work without MIOpen NHWC activation")(fn)
+
 # Skips a test for specified CUDA versions, given in the form of a list of [major, minor]s.
 def skipCUDAVersionIn(versions : List[Tuple[int, int]] = None):
     def dec_fn(fn):
index 0a265b5..11364c3 100644 (file)
@@ -428,6 +428,11 @@ TEST_WITH_DEV_DBG_ASAN = os.getenv('PYTORCH_TEST_WITH_DEV_DBG_ASAN', '0') == '1'
 TEST_WITH_TSAN = os.getenv('PYTORCH_TEST_WITH_TSAN', '0') == '1'
 TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1'
 TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1'
+
+# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
+# See #64427
+TEST_WITH_MIOPEN_SUGGEST_NHWC = os.getenv('PYTORCH_MIOPEN_SUGGEST_NHWC', '0') == '1'
+
 # Enables tests that are slow to run (disabled by default)
 TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'
 
@@ -500,6 +505,33 @@ def skipIfRocm(fn):
             fn(*args, **kwargs)
     return wrapper
 
+# Skips a test on CUDA if ROCm is unavailable or its version is lower than requested.
+def skipIfRocmVersionLessThan(version=None):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if not TEST_WITH_ROCM:
+                reason = "ROCm not available"
+                raise unittest.SkipTest(reason)
+            rocm_version = str(torch.version.hip)
+            rocm_version = rocm_version.split("-")[0]    # ignore git sha
+            rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
+            if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
+                reason = "ROCm {0} is available but {1} required".format(rocm_version_tuple, version)
+                raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
+def skipIfNotMiopenSuggestNHWC(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not TEST_WITH_MIOPEN_SUGGEST_NHWC:
+            raise unittest.SkipTest("test doesn't currently work without MIOpen NHWC activation")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
 # Context manager for setting deterministic flag and automatically
 # resetting it to its original value
 class DeterministicGuard: