From c3ea586e32458ac8b18ae5f149446beb507b5076 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Thu, 14 Oct 2021 09:42:41 -0700 Subject: [PATCH] fix normal with empty std (#66524) --- aten/src/ATen/native/DistributionTemplates.h | 2 +- test/test_tensor_creation_ops.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/DistributionTemplates.h b/aten/src/ATen/native/DistributionTemplates.h index ed5aa84..0021194 100644 --- a/aten/src/ATen/native/DistributionTemplates.h +++ b/aten/src/ATen/native/DistributionTemplates.h @@ -238,7 +238,7 @@ template class normal_kernel, typename RNG> Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional gen) { TORCH_CHECK(!std.is_complex(), "normal expects standard deviation to be non-complex"); TORCH_CHECK( - std.min().ge(0).item(), + std.numel() == 0 || std.min().ge(0).item(), "normal expects all elements of std >= 0.0"); bool is_deprecated_th_impl = resize_output_for_normal(output, mean, std); normal_impl_(output, 0, 1, gen); diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index e72b156..9530fbf 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -3258,6 +3258,10 @@ class TestRandomTensorCreation(TestCase): self.assertEqual(t_transform(r[:, :50]).std(), std_transform(4), atol=0.3, rtol=0) self.assertEqual(t_transform(r[:, 50:]).std(), std_transform(1), atol=0.2, rtol=0) + # test empty mean/std + out = torch.normal(mean=torch.empty((0, 2)), std=torch.empty((0, 1))) + self.assertEqual(out.size(), torch.Size([0, 2])) + r.fill_(42) r = torch.normal(2, 3, (100, 100), dtype=dtype, device=device) self.assertEqual(r.dtype, dtype) -- 2.7.4