  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
-  if (at::cuda::detail::canUse32BitIndexMath(self) && batch_norm_use_channels_last_kernels(self)){
+  // Use the channels-last kernel only when both self (grad_out) and input
+  // are channels-last compatible; mixed memory formats fall through to the
+  // generic kernel below.
+  if (at::cuda::detail::canUse32BitIndexMath(self) &&
+      batch_norm_use_channels_last_kernels(self) &&
+      batch_norm_use_channels_last_kernels(input)) {
    return batch_norm_backward_elemt_channels_last_cuda_template(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
  }
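
The guard matters because the channels-last template indexes every operand in NHWC memory order: two tensors can hold logically equal values yet lay them out differently, so a kernel assuming one order misreads the other. A minimal sketch of the layout mismatch, using only public PyTorch API:

    import torch

    a = torch.arange(24.).reshape(2, 3, 2, 2)              # NCHW layout
    b = a.contiguous(memory_format=torch.channels_last)    # NHWC layout
    assert torch.equal(a, b)           # logically identical values
    assert a.stride() != b.stride()    # physically different memory order
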
  const auto stride = input.sizes()[1];
  const auto reduction_size = input.numel() / stride;
-  at::Tensor grad_input = at::empty_like(input, input.suggest_memory_format());
+  // Input is guaranteed to be channels-last compatible
+  at::Tensor grad_input = at::empty_like(input);
  dim3 block;
  dim3 grid;
  const auto reduction_size = input.numel() / stride;
  auto norm_fct = 1.0 / reduction_size;
-  at::Tensor grad_input = at::empty_like(input, input.suggest_memory_format());
+  // Input is guaranteed to be channels-last compatible
+  at::Tensor grad_input = at::empty_like(input);
  dim3 block;
  dim3 grid;
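
For context on the empty_like change in both hunks above: at::empty_like defaults to preserving the source tensor's memory format, so grad_input now matches input's actual strides instead of a format re-derived by suggest_memory_format(). A minimal sketch of that default behavior, using only public PyTorch API:

    import torch

    x = torch.rand(2, 3, 4, 4).contiguous(memory_format=torch.channels_last)
    # torch.empty_like defaults to torch.preserve_format: the new tensor
    # copies x's exact channels-last strides rather than guessing a layout.
    y = torch.empty_like(x)
    assert y.stride() == x.stride()
    assert y.is_contiguous(memory_format=torch.channels_last)
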
        self.assertEqual(layer.state_dict()[key], converted_layer.state_dict()[key])
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    def test_sync_batchnorm_backward_elemt(self):
+        device = 'cuda'
+        # Arbitrary shapes and statistics; the test checks only that the
+        # result is independent of the inputs' memory formats.
+        saved_input = torch.rand(2, 3, 2, 1, device=device)
+        grad_output = torch.rand(2, 3, 2, 1, device=device)
+        mean = torch.rand(3, device=device)
+        invstd = torch.rand(3, device=device)
+        weight = torch.rand(3, device=device)
+        sum_dy = torch.rand(3, device=device)
+        sum_dy_xmu = torch.rand(3, device=device)
+        count_tensor = torch.tensor([5, 5, 5], dtype=torch.int32, device=device)
+
+        gI_contiguous = torch.batch_norm_backward_elemt(
+            grad_output,
+            saved_input,
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu,
+            count_tensor
+        )
+
+        # Test that batch_norm_backward_elemt gives the same answer for all
+        # combinations of contiguous and channels_last inputs
+        for a, b in [
+            (torch.channels_last, torch.contiguous_format),
+            (torch.contiguous_format, torch.channels_last),
+            (torch.channels_last, torch.channels_last),
+        ]:
+            gI_actual = torch.batch_norm_backward_elemt(
+                grad_output.contiguous(memory_format=a),
+                saved_input.contiguous(memory_format=b),
+                mean,
+                invstd,
+                weight,
+                sum_dy,
+                sum_dy_xmu,
+                count_tensor
+            )
+            self.assertEqual(gI_actual, gI_contiguous)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_sync_batchnorm_accuracy_cuda(self):
        # The target of this test is to test the functionality and accuracy of
        # those single-GPU cuda kernels used in SyncBatchNorm