Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
TORCH_CHECK(!std.is_complex(), "normal expects standard deviation to be non-complex");
TORCH_CHECK(
- std.min().ge(0).item<bool>(),
+ std.numel() == 0 || std.min().ge(0).item<bool>(),
"normal expects all elements of std >= 0.0");
bool is_deprecated_th_impl = resize_output_for_normal(output, mean, std);
normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
self.assertEqual(t_transform(r[:, :50]).std(), std_transform(4), atol=0.3, rtol=0)
self.assertEqual(t_transform(r[:, 50:]).std(), std_transform(1), atol=0.2, rtol=0)
+ # test empty mean/std
+ out = torch.normal(mean=torch.empty((0, 2)), std=torch.empty((0, 1)))
+ self.assertEqual(out.size(), torch.Size([0, 2]))
+
r.fill_(42)
r = torch.normal(2, 3, (100, 100), dtype=dtype, device=device)
self.assertEqual(r.dtype, dtype)