From 1921816f85b6db978175de5029bab03f29718d8d Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Tue, 4 Dec 2018 07:02:01 -0800
Subject: [PATCH] Fix clamp when min/max are both None (#14716)

Summary:
Before this PR, tensor.clamp() would return an empty tensor if min and max
were not specified. This is a regression from 0.4.1, which would throw an
error. This PR restores that error message.

Fixes #14470
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14716

Differential Revision: D13311031

Pulled By: zou3519

fbshipit-source-id: 87894db582d5749eaccfc22ba06aac4e10983880
---
 aten/src/ATen/native/UnaryOps.cpp          | 12 +++---------
 aten/src/ATen/native/cuda/CUDAUnaryOps.cpp | 12 +++---------
 test/test_torch.py                         |  6 ++++++
 3 files changed, 12 insertions(+), 18 deletions(-)

diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index d5cd2ff..04b5932 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -46,15 +46,7 @@ Tensor clamp_min(const Tensor& self, Scalar min) {
 }
 
 Tensor& _clamp__cpu(Tensor& self, optional<Scalar> min, optional<Scalar> max) {
-  if (min && max) {
-    return _th_clamp_out(self, self, *min, *max);
-  } else if (max) {
-    return _th_clamp_max_out(self, self, *max);
-  } else if (min) {
-    return _th_clamp_min_out(self, self, *min);
-  } else {
-    return self;
-  }
+  return _clamp_out_cpu(self, self, min, max);
 }
 
 Tensor& _clamp_out_cpu(
@@ -68,6 +60,8 @@ Tensor& _clamp_out_cpu(
     _th_clamp_max_out(result, self, *max);
   } else if (min) {
     _th_clamp_min_out(result, self, *min);
+  } else {
+    AT_ERROR("At least one of 'min' or 'max' must not be None");
   }
   return result;
 }
diff --git a/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp b/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp
index 905c243..f3c1a4b 100644
--- a/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp
+++ b/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp
@@ -3,15 +3,7 @@
 namespace at { namespace native {
 
 Tensor& _clamp__cuda(Tensor& self, optional<Scalar> min, optional<Scalar> max) {
-  if (min && max) {
-    return _th_clamp_out(self, self, *min, *max);
-  } else if (max) {
-    return _th_clamp_max_out(self, self, *max);
-  } else if (min) {
-    return _th_clamp_min_out(self, self, *min);
-  } else {
-    return self;
-  }
+  return _clamp_out_cuda(self, self, min, max);
 }
 
 Tensor& _clamp_out_cuda(
@@ -25,6 +17,8 @@ Tensor& _clamp_out_cuda(
     _th_clamp_max_out(result, self, *max);
   } else if (min) {
     _th_clamp_min_out(result, self, *min);
+  } else {
+    AT_ERROR("At least one of 'min' or 'max' must not be None");
   }
   return result;
 }
diff --git a/test/test_torch.py b/test/test_torch.py
index d3d5d4c..f9634dc 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1722,6 +1722,12 @@ class _TestTorchMixin(object):
         torch.clamp(m1, max=max_val, out=out)
         self.assertEqual(out, res1)
 
+        error_msg = 'At least one of \'min\' or \'max\' must not be None'
+        with self.assertRaisesRegex(RuntimeError, error_msg):
+            m1.clamp()
+        with self.assertRaisesRegex(RuntimeError, error_msg):
+            m1.clamp_()
+
     def test_pow(self):
         # [res] torch.pow([res,] x)
-- 
2.7.4
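
For context, here is a small Python sketch (not part of the patch itself) of the
user-facing behavior it restores. It assumes a PyTorch build that includes this
fix; the tensor name and the example bounds are placeholders.

    # Illustrative sketch only; assumes a build with this fix applied.
    import torch

    t = torch.rand(100)

    # Clamping with at least one bound still works as before.
    t.clamp(min=0.25)   # values below 0.25 are raised to 0.25
    t.clamp(max=0.75)   # values above 0.75 are lowered to 0.75

    # With neither bound given, clamp() now raises instead of silently
    # returning an empty tensor (AT_ERROR surfaces as a RuntimeError).
    try:
        t.clamp()
    except RuntimeError as e:
        print(e)  # At least one of 'min' or 'max' must not be None

Note that the in-place entry points (_clamp__cpu / _clamp__cuda) now simply
forward to the out-of-place _clamp_out_* functions, so the same check covers
clamp(), clamp_(), and torch.clamp(..., out=...).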