From 196eee6ccd9bbab46d712c863265a433832b8f9c Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Tue, 8 Jan 2019 12:34:43 -0800
Subject: [PATCH] Fix sum_to behavior with zero dimensions (#15796)

Summary:
Fixes #15223.

This fixes an autograd bug where backprop either fails or produces
gradients of incorrect sizes when tensors with zero-sized dimensions
are involved.

Previously, when summing to a size in autograd, we reduced only along
dimensions of size greater than 1. This is incorrect: we should also
reduce along dimensions of size 0, producing a tensor of size 1 in
that dimension that then gets viewed to the correct shape.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15796

Differential Revision: D13593199

Pulled By: zou3519

fbshipit-source-id: 2e2acac34943a9b7fabadc10c9efd4f66db298fd
---
 aten/src/ATen/ExpandUtils.h |  2 +-
 test/test_autograd.py       | 10 ++++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h
index 416d84f..f287d56 100644
--- a/aten/src/ATen/ExpandUtils.h
+++ b/aten/src/ATen/ExpandUtils.h
@@ -147,7 +147,7 @@ static inline Tensor sum_to(Tensor tensor, const IntList shape) {
     reduce_dims.push_back(i);
   }
   for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
-    if (shape[i - leading_dims] == 1 && sizes[i] > 1) {
+    if (shape[i - leading_dims] == 1 && sizes[i] != 1) {
       reduce_dims.push_back(i);
     }
   }
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 695e6ad..284c45a 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -202,6 +202,16 @@ class TestAutograd(TestCase):
         x_grad, x_grad_clone = compute_grad(create_graph=True)
         self.assertEqual(x_grad, x_grad_clone)
 
+    def test_sum_to_with_empty_dim_grad(self):
+        a = torch.rand(4, 0, requires_grad=True)
+        b = torch.rand(4, 1, requires_grad=True)
+        c = a + b
+        assert c.shape == (4, 0)
+        c.sum().backward()
+
+        self.assertEqual(b.grad, torch.zeros(4, 1))
+        self.assertEqual(a.grad, torch.zeros(4, 0))
+
     def test_hessian_vector(self):
         x = torch.randn(2, 2, requires_grad=True)
         y = torch.randn(2, 2, requires_grad=True)
--
2.7.4
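
For context, below is a minimal Python sketch of the reduction logic that
the patched C++ in ExpandUtils.h implements. The function is an
illustrative approximation written for this note, not the actual ATen
code; its name and structure merely mirror the C++ hunk above.

    import torch

    def sum_to_sketch(grad, shape):
        # Reduce a broadcast gradient `grad` back to `shape` (the size of
        # the tensor that was broadcast). Hypothetical helper for exposition.
        leading_dims = grad.dim() - len(shape)
        # Extra leading dimensions introduced by broadcasting are always reduced.
        reduce_dims = list(range(leading_dims))
        for i in range(leading_dims, grad.dim()):
            # Before the fix this condition was `grad.size(i) > 1`, so a
            # size-0 dimension was never reduced and the result could not be
            # viewed back to `shape`. With `!= 1`, summing over a size-0
            # dimension yields a size-1 dimension filled with zeros.
            if shape[i - leading_dims] == 1 and grad.size(i) != 1:
                reduce_dims.append(i)
        if reduce_dims:
            grad = grad.sum(dim=reduce_dims, keepdim=True)
        return grad.view(shape) if leading_dims > 0 else grad

    # Mirrors the regression test: the gradient of `a + b` w.r.t. `b`,
    # where `a` has shape (4, 0) and `b` has shape (4, 1).
    g = sum_to_sketch(torch.ones(4, 0), (4, 1))
    print(g.shape)  # torch.Size([4, 1]); g is all zeros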