Allow linspace and logspace with steps=1 and start != end like numpy (#14748)
authorJan Schlüter <github@jan-schlueter.de>
Thu, 6 Dec 2018 17:29:01 +0000 (09:29 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 6 Dec 2018 17:30:55 +0000 (09:30 -0800)
Summary:
`torch.linspace(0, 1, 1)` fails with `RuntimeError: invalid argument 3: invalid number of points at ../aten/src/TH/generic/THTensorMoreMath.cpp:2119`, while `np.linspace(0, 1, 1)` works fine.
Looking at the code, there is even a comment by gchanan asking: "NumPy allows you to pass different points even if n <= 1 -- should we?"
I would say "yes". Currently, I would need to handle the case of `steps == 1` or `steps == 0` separately, making sure to change the `end` when calling `torch.linspace`. This is impractical. If we support `start != end`, there are two possibilities for the result: Either we ensure the first value in the resulting sequence always equals `start`, or we ensure the last value in the resulting sequence always equals `end`. Numpy chose the former, which also allows it to support a boolean `endpoint` flag. I'd say we should follow numpy.

This PR adapts `linspace` and `logspace` to mimic the behavior of numpy, adapts the tests accordingly, and extends the docstrings to make clear what happens when passing `steps=1`.

If you decide against this PR, the error message should become explicit about what I did wrong, and the documentation should be extended to mention this restriction.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14748

Differential Revision: D13356136

Pulled By: ezyang

fbshipit-source-id: db85b8f0a98a5e24b3acd766132ab71c91794a82

aten/src/TH/generic/THTensorMoreMath.cpp
aten/src/THC/generic/THCTensorMath.cu
test/test_torch.py
torch/_torch_docs.py

index bbbd4ad..56382e2 100644 (file)
@@ -2135,8 +2135,7 @@ void THTensor_(linspace)(THTensor *r_, scalar_t a, scalar_t b, int64_t n)
 {
   scalar_t i = 0;
 
-  // NumPy allows you to pass different points even if n <= 1 -- should we?
-  THArgCheck(n > 1 || ((n == 0 || n == 1) && (a == b)), 3, "invalid number of points");
+  THArgCheck((n >= 0), 3, "number of points must be non-negative");
 
   if (THTensor_(nElement)(r_) != n) {
     THTensor_(resize1d)(r_, n);
@@ -2157,8 +2156,7 @@ void THTensor_(logspace)(THTensor *r_, scalar_t a, scalar_t b, int64_t n)
 {
   scalar_t i = 0;
 
-  // NumPy allows you to pass different points even if n <= 1 -- should we?
-  THArgCheck(n > 1 || ((n == 0 || n == 1) && (a == b)), 3, "invalid number of points");
+  THArgCheck((n >= 0), 3, "number of points must be non-negative");
 
   if (THTensor_(nElement)(r_) != n) {
     THTensor_(resize1d)(r_, n);
index fb266ca..07e3efc 100644 (file)
@@ -395,8 +395,7 @@ accreal THCTensor_(trace)(THCState *state, THCTensor *src_) {
 
 void THCTensor_(linspace)(THCState *state, THCTensor *r_, scalar_t a, scalar_t b, int64_t n) {
   THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, r_));
-  // NumPy allows you to pass different points even if n <= 1 -- should we?
-  THArgCheck(n > 1 || ((n == 0 || n == 1) && (a == b)), 3, "invalid number of points");
+  THArgCheck((n >= 0), 3, "number of points must be non-negative");
   if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n);
   if (n == 0) {
     // skip
@@ -419,8 +418,7 @@ void THCTensor_(linspace)(THCState *state, THCTensor *r_, scalar_t a, scalar_t b
 
 void THCTensor_(logspace)(THCState *state, THCTensor *r_, scalar_t a, scalar_t b, int64_t n) {
   THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, r_));
-  // NumPy allows you to pass different points even if n <= 1 -- should we?
-  THArgCheck(n > 1 || ((n == 0 || n == 1) && (a == b)), 3, "invalid number of points");
+  THArgCheck((n >= 0), 3, "number of points must be non-negative");
   if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n);
   if (n == 0) {
     // skip
index 602bc6f..2f5352d 100644 (file)
@@ -3945,8 +3945,8 @@ class _TestTorchMixin(object):
         res2 = torch.Tensor()
         torch.linspace(_from, to, 137, out=res2)
         self.assertEqual(res1, res2, 0)
-        self.assertRaises(RuntimeError, lambda: torch.linspace(0, 1, 1))
-        self.assertEqual(torch.linspace(0, 0, 1), torch.zeros(1), 0)
+        self.assertRaises(RuntimeError, lambda: torch.linspace(0, 1, -1))
+        self.assertEqual(torch.linspace(0, 1, 1), torch.zeros(1), 0)
 
         # Check linspace for generating with start > end.
         self.assertEqual(torch.linspace(2, 0, 3), torch.Tensor((2, 1, 0)), 0)
@@ -3963,8 +3963,8 @@ class _TestTorchMixin(object):
         res2 = torch.Tensor()
         torch.logspace(_from, to, 137, out=res2)
         self.assertEqual(res1, res2, 0)
-        self.assertRaises(RuntimeError, lambda: torch.logspace(0, 1, 1))
-        self.assertEqual(torch.logspace(0, 0, 1), torch.ones(1), 0)
+        self.assertRaises(RuntimeError, lambda: torch.logspace(0, 1, -1))
+        self.assertEqual(torch.logspace(0, 1, 1), torch.ones(1), 0)
 
         # Check logspace_ for generating with start > end.
         self.assertEqual(torch.logspace(1, 0, 2), torch.Tensor((10, 1)), 0)
index 9d0371f..906bea3 100644 (file)
@@ -2267,6 +2267,8 @@ Example::
     tensor([-10.,  -5.,   0.,   5.,  10.])
     >>> torch.linspace(start=-10, end=10, steps=5)
     tensor([-10.,  -5.,   0.,   5.,  10.])
+    >>> torch.linspace(start=-10, end=10, steps=1)
+    tensor([-10.])
 """.format(**factory_common_args))
 
 add_docstr(torch.log,
@@ -2395,6 +2397,8 @@ Example::
     tensor([ 1.0000e-10,  1.0000e-05,  1.0000e+00,  1.0000e+05,  1.0000e+10])
     >>> torch.logspace(start=0.1, end=1.0, steps=5)
     tensor([  1.2589,   2.1135,   3.5481,   5.9566,  10.0000])
+    >>> torch.logspace(start=0.1, end=1.0, steps=1)
+    tensor([1.2589])
 """.format(**factory_common_args))
 
 add_docstr(torch.logsumexp,