From 1c8d41a08d30185109a506fd984210ec130ae213 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Jan=20Schl=C3=BCter?= Date: Thu, 6 Dec 2018 09:29:01 -0800 Subject: [PATCH] Allow linspace and logspace with steps=1 and start != end like numpy (#14748) Summary: `torch.linspace(0, 1, 1)` fails with `RuntimeError: invalid argument 3: invalid number of points at ../aten/src/TH/generic/THTensorMoreMath.cpp:2119`, while `np.linspace(0, 1, 1)` works fine. Looking at the code, there is even a comment by gchanan asking: "NumPy allows you to pass different points even if n <= 1 -- should we?" I would say "yes". Currently, I would need to handle the case of `steps == 1` or `steps == 0` separately, making sure to change the `end` when calling `torch.linspace`. This is impractical. If we support `start != end`, there are two possibilities for the result: Either we ensure the first value in the resulting sequence always equals `start`, or we ensure the last value in the resulting sequence always equals `end`. Numpy chose the former, which also allows it to support a boolean `endpoint` flag. I'd say we should follow numpy. This PR adapts `linspace` and `logspace` to mimic the behavior of numpy, adapts the tests accordingly, and extends the docstrings to make clear what happens when passing `steps=1`. If you decide against this PR, the error message should become explicit about what I did wrong, and the documentation should be extended to mention this restriction. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14748 Differential Revision: D13356136 Pulled By: ezyang fbshipit-source-id: db85b8f0a98a5e24b3acd766132ab71c91794a82 --- aten/src/TH/generic/THTensorMoreMath.cpp | 6 ++---- aten/src/THC/generic/THCTensorMath.cu | 6 ++---- test/test_torch.py | 8 ++++---- torch/_torch_docs.py | 4 ++++ 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp index bbbd4ad..56382e2 100644 --- a/aten/src/TH/generic/THTensorMoreMath.cpp +++ b/aten/src/TH/generic/THTensorMoreMath.cpp @@ -2135,8 +2135,7 @@ void THTensor_(linspace)(THTensor *r_, scalar_t a, scalar_t b, int64_t n) { scalar_t i = 0; - // NumPy allows you to pass different points even if n <= 1 -- should we? - THArgCheck(n > 1 || ((n == 0 || n == 1) && (a == b)), 3, "invalid number of points"); + THArgCheck((n >= 0), 3, "number of points must be non-negative"); if (THTensor_(nElement)(r_) != n) { THTensor_(resize1d)(r_, n); @@ -2157,8 +2156,7 @@ void THTensor_(logspace)(THTensor *r_, scalar_t a, scalar_t b, int64_t n) { scalar_t i = 0; - // NumPy allows you to pass different points even if n <= 1 -- should we? - THArgCheck(n > 1 || ((n == 0 || n == 1) && (a == b)), 3, "invalid number of points"); + THArgCheck((n >= 0), 3, "number of points must be non-negative"); if (THTensor_(nElement)(r_) != n) { THTensor_(resize1d)(r_, n); diff --git a/aten/src/THC/generic/THCTensorMath.cu b/aten/src/THC/generic/THCTensorMath.cu index fb266ca..07e3efc 100644 --- a/aten/src/THC/generic/THCTensorMath.cu +++ b/aten/src/THC/generic/THCTensorMath.cu @@ -395,8 +395,7 @@ accreal THCTensor_(trace)(THCState *state, THCTensor *src_) { void THCTensor_(linspace)(THCState *state, THCTensor *r_, scalar_t a, scalar_t b, int64_t n) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, r_)); - // NumPy allows you to pass different points even if n <= 1 -- should we? - THArgCheck(n > 1 || ((n == 0 || n == 1) && (a == b)), 3, "invalid number of points"); + THArgCheck((n >= 0), 3, "number of points must be non-negative"); if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n); if (n == 0) { // skip @@ -419,8 +418,7 @@ void THCTensor_(linspace)(THCState *state, THCTensor *r_, scalar_t a, scalar_t b void THCTensor_(logspace)(THCState *state, THCTensor *r_, scalar_t a, scalar_t b, int64_t n) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, r_)); - // NumPy allows you to pass different points even if n <= 1 -- should we? - THArgCheck(n > 1 || ((n == 0 || n == 1) && (a == b)), 3, "invalid number of points"); + THArgCheck((n >= 0), 3, "number of points must be non-negative"); if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n); if (n == 0) { // skip diff --git a/test/test_torch.py b/test/test_torch.py index 602bc6f..2f5352d 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -3945,8 +3945,8 @@ class _TestTorchMixin(object): res2 = torch.Tensor() torch.linspace(_from, to, 137, out=res2) self.assertEqual(res1, res2, 0) - self.assertRaises(RuntimeError, lambda: torch.linspace(0, 1, 1)) - self.assertEqual(torch.linspace(0, 0, 1), torch.zeros(1), 0) + self.assertRaises(RuntimeError, lambda: torch.linspace(0, 1, -1)) + self.assertEqual(torch.linspace(0, 1, 1), torch.zeros(1), 0) # Check linspace for generating with start > end. self.assertEqual(torch.linspace(2, 0, 3), torch.Tensor((2, 1, 0)), 0) @@ -3963,8 +3963,8 @@ class _TestTorchMixin(object): res2 = torch.Tensor() torch.logspace(_from, to, 137, out=res2) self.assertEqual(res1, res2, 0) - self.assertRaises(RuntimeError, lambda: torch.logspace(0, 1, 1)) - self.assertEqual(torch.logspace(0, 0, 1), torch.ones(1), 0) + self.assertRaises(RuntimeError, lambda: torch.logspace(0, 1, -1)) + self.assertEqual(torch.logspace(0, 1, 1), torch.ones(1), 0) # Check logspace_ for generating with start > end. self.assertEqual(torch.logspace(1, 0, 2), torch.Tensor((10, 1)), 0) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 9d0371f..906bea3 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2267,6 +2267,8 @@ Example:: tensor([-10., -5., 0., 5., 10.]) >>> torch.linspace(start=-10, end=10, steps=5) tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) """.format(**factory_common_args)) add_docstr(torch.log, @@ -2395,6 +2397,8 @@ Example:: tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) >>> torch.logspace(start=0.1, end=1.0, steps=5) tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) """.format(**factory_common_args)) add_docstr(torch.logsumexp, -- 2.7.4