fix normal with empty std (#66524)
authorNatalia Gimelshein <ngimel@fb.com>
Thu, 14 Oct 2021 16:42:41 +0000 (09:42 -0700)
committerGitHub <noreply@github.com>
Thu, 14 Oct 2021 16:42:41 +0000 (09:42 -0700)
aten/src/ATen/native/DistributionTemplates.h
test/test_tensor_creation_ops.py

index ed5aa84..0021194 100644 (file)
@@ -238,7 +238,7 @@ template<template<typename> class normal_kernel, typename RNG>
 Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
   TORCH_CHECK(!std.is_complex(), "normal expects standard deviation to be non-complex");
   TORCH_CHECK(
-    std.min().ge(0).item<bool>(),
+    std.numel() == 0 || std.min().ge(0).item<bool>(),
     "normal expects all elements of std >= 0.0");
   bool is_deprecated_th_impl = resize_output_for_normal(output, mean, std);
   normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
index e72b156..9530fbf 100644 (file)
@@ -3258,6 +3258,10 @@ class TestRandomTensorCreation(TestCase):
             self.assertEqual(t_transform(r[:, :50]).std(), std_transform(4), atol=0.3, rtol=0)
             self.assertEqual(t_transform(r[:, 50:]).std(), std_transform(1), atol=0.2, rtol=0)
 
+            # test empty mean/std
+            out = torch.normal(mean=torch.empty((0, 2)), std=torch.empty((0, 1)))
+            self.assertEqual(out.size(), torch.Size([0, 2]))
+
             r.fill_(42)
             r = torch.normal(2, 3, (100, 100), dtype=dtype, device=device)
             self.assertEqual(r.dtype, dtype)