Fix clamp when min/max are both None (#14716)
authorRichard Zou <zou3519@gmail.com>
Tue, 4 Dec 2018 15:02:01 +0000 (07:02 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 15:07:09 +0000 (07:07 -0800)
Summary:
Before this PR, tensor.clamp() would return an empty tensor if min and
max were not specified. This is a regression from 0.4.1, which would
throw an error. This PR restores that error message.

Fixes #14470
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14716

Differential Revision: D13311031

Pulled By: zou3519

fbshipit-source-id: 87894db582d5749eaccfc22ba06aac4e10983880

aten/src/ATen/native/UnaryOps.cpp
aten/src/ATen/native/cuda/CUDAUnaryOps.cpp
test/test_torch.py

index d5cd2ff..04b5932 100644 (file)
@@ -46,15 +46,7 @@ Tensor clamp_min(const Tensor& self, Scalar min) {
 }
 
 Tensor& _clamp__cpu(Tensor& self, optional<Scalar> min, optional<Scalar> max) {
-  if (min && max) {
-    return _th_clamp_out(self, self, *min, *max);
-  } else if (max) {
-    return _th_clamp_max_out(self, self, *max);
-  } else if (min) {
-    return _th_clamp_min_out(self, self, *min);
-  } else {
-    return self;
-  }
+  return _clamp_out_cpu(self, self, min, max);
 }
 
 Tensor& _clamp_out_cpu(
@@ -68,6 +60,8 @@ Tensor& _clamp_out_cpu(
     _th_clamp_max_out(result, self, *max);
   } else if (min) {
     _th_clamp_min_out(result, self, *min);
+  } else {
+    AT_ERROR("At least one of 'min' or 'max' must not be None");
   }
   return result;
 }
index 905c243..f3c1a4b 100644 (file)
@@ -3,15 +3,7 @@
 namespace at { namespace native {
 
 Tensor& _clamp__cuda(Tensor& self, optional<Scalar> min, optional<Scalar> max) {
-  if (min && max) {
-    return _th_clamp_out(self, self, *min, *max);
-  } else if (max) {
-    return _th_clamp_max_out(self, self, *max);
-  } else if (min) {
-    return _th_clamp_min_out(self, self, *min);
-  } else {
-    return self;
-  }
+  return _clamp_out_cuda(self, self, min, max);
 }
 
 Tensor& _clamp_out_cuda(
@@ -25,6 +17,8 @@ Tensor& _clamp_out_cuda(
     _th_clamp_max_out(result, self, *max);
   } else if (min) {
     _th_clamp_min_out(result, self, *min);
+  } else {
+    AT_ERROR("At least one of 'min' or 'max' must not be None");
   }
   return result;
 }
index d3d5d4c..f9634dc 100644 (file)
@@ -1722,6 +1722,12 @@ class _TestTorchMixin(object):
         torch.clamp(m1, max=max_val, out=out)
         self.assertEqual(out, res1)
 
+        error_msg = 'At least one of \'min\' or \'max\' must not be None'
+        with self.assertRaisesRegex(RuntimeError, error_msg):
+            m1.clamp()
+        with self.assertRaisesRegex(RuntimeError, error_msg):
+            m1.clamp_()
+
     def test_pow(self):
         # [res] torch.pow([res,] x)