Fix bad use of channels last kernel in sync batch norm backward (#64100)
author    Peter Bell <peterbell10@live.co.uk>
          Mon, 30 Aug 2021 19:14:09 +0000 (12:14 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Mon, 30 Aug 2021 19:16:30 +0000 (12:16 -0700)
commit    5b0dfd0f8aff50e2fce8f2f1fe6f2ef0594a9e25
tree      87f82f098b01255770bf39195f445abd36b51d61
parent    ac99d63f83ceaee4a95e7baa8a52fba09903d00b
Fix bad use of channels last kernel in sync batch norm backward (#64100)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/64039

There are two distinct problems here.
1. If `grad_output` is channels last but `input` is not, then `input` would be read as if it were channels last, so the wrong values were read. See the first sketch below.
2. `use_channels_last_kernels` doesn't guarantee that `suggest_memory_format` will actually return channels last, so use `empty_like` instead so that the output strides always match the input's. See the second sketch below.
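
A minimal Python-level sketch of problem 1 (an illustration of the stride mismatch, not the internal CUDA kernel code): reading one tensor's storage with another tensor's channels-last strides recovers the wrong elements.

```python
import torch

# Same logical values, different physical layouts: grad_output in
# channels-last, input in the default contiguous format.
inp = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)
grad_output = inp.clone().to(memory_format=torch.channels_last)

assert torch.equal(inp, grad_output)          # same values...
assert inp.stride() != grad_output.stride()   # ...different strides

# Indexing input's storage with grad_output's channels-last strides
# reads the wrong values -- problem 1 in miniature.
mixed_up = torch.as_strided(inp, grad_output.size(), grad_output.stride())
assert not torch.equal(mixed_up, inp)
```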
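
And a sketch of why problem 2 arises, assuming the usual stride-based contiguity rules: a layout can satisfy both contiguity checks at once (e.g. when spatial dims are 1), so a stride-based format suggestion may legitimately pick the contiguous format even when the channels-last kernel was selected, whereas `empty_like` (with its default `torch.preserve_format`) reproduces the input's exact strides for dense tensors.

```python
import torch

# Ambiguous layout: a channels-last tensor whose H and W are 1 passes
# both contiguity checks, so a format suggestion based on strides need
# not return channels last.
inp = torch.empty(2, 3, 1, 1).to(memory_format=torch.channels_last)
assert inp.is_contiguous()                                   # True
assert inp.is_contiguous(memory_format=torch.channels_last)  # also True

# empty_like preserves the input's exact strides for dense tensors, so
# the kernel's reads and writes always agree on the layout.
grad_input = torch.empty_like(inp)
assert grad_input.stride() == inp.stride()
```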

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64100

Reviewed By: mruberry

Differential Revision: D30622127

Pulled By: ngimel

fbshipit-source-id: e28cc57215596817f1432fcdd6c49d69acfedcf2
aten/src/ATen/native/cuda/Normalization.cu
aten/src/ATen/native/cuda/Normalization.cuh
test/test_nn.py