ROCm MIOpen NHWC Convolution support (#63617)
authorAswin John Mathews <Aswin.Mathews@amd.com>
Fri, 10 Sep 2021 15:05:21 +0000 (08:05 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 15:06:32 +0000 (08:06 -0700)
commit63b180beed0453141ab11d1df0527d808801e583
treea0d559a8208942f2637eadfd0fac391cc2ed7d32
parent2a81e8b8f1526b375d1e78402f91bf8fd82d2b68
ROCm MIOpen NHWC Convolution support (#63617)

Summary:
- Added 2D-Convolution NHWC support
  - on ROCm 4.3, with `PYTORCH_MIOPEN_SUGGEST_NHWC=1` flag
  - May need to force MIOpen to search for solutions ( see examples below for flags )

**PYTORCH_MIOPEN_SUGGEST_NHWC Environment Flag**
MIOpen does not officially support NHWC yet, although convolution support has been added to tip-of-tree of MIOpen. This flag is intended to be a short-lived flag to explicitly turn on NHWC support until ROCm officially supports NHWC and performance is verified.

**Examples**
1. Example usage 1 : Run test on ROCm4.3
`PYTORCH_TEST_WITH_ROCM=1 PYTORCH_MIOPEN_SUGGEST_NHWC=1 MIOPEN_FIND_ENFORCE=4 MIOPEN_DEBUG_CONV_GEMM=0 MIOPEN_FIND_MODE=1 pytest test_nn.py -v -k "test_conv_cudnn_nhwc" `
2. Example usage 2: Run the following with `PYTORCH_MIOPEN_SUGGEST_NHWC=1` on ROCm4.3.
```
#!/usr/bin/env python3
import torch
model = torch.nn.Conv2d(8, 4, 3).cuda().half()
model = model.to(memory_format=torch.channels_last)
input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)

# should print True for is_contiguous(channels_last), and strides must match NHWC format
print(input.is_contiguous(memory_format=torch.channels_last), input.shape, input.stride() )

out = model(input)

# should print True for is_contiguous(channels_last), and strides must match NHWC format
print("Contiguous channel last :", out.is_contiguous(memory_format=torch.channels_last), " out shape :",  out.shape, "out stride :", out.stride() )
```

See https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html for more examples.

cc jeffdaily sunway513 jithunnair-amd ROCmSupport

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63617

Reviewed By: saketh-are

Differential Revision: D30730800

Pulled By: ezyang

fbshipit-source-id: 61906a0f30be8299e6547d312ae6ac91cc7c3238
aten/src/ATen/miopen/Descriptors.cpp
aten/src/ATen/miopen/Descriptors.h
aten/src/ATen/native/ConvUtils.h
aten/src/ATen/native/Convolution.cpp
aten/src/ATen/native/miopen/Conv_miopen.cpp
c10/util/env.h
test/test_nn.py
torch/testing/_internal/common_device_type.py
torch/testing/_internal/common_utils.py