Use stacklevel for floordiv deprecation warnings (#64034)
authorSaketh Are <saketh.are@gmail.com>
Tue, 31 Aug 2021 17:59:57 +0000 (10:59 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 18:27:56 +0000 (11:27 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/60548

`Tensor.__floordiv__` was indirectly deprecated by deprecation of `torch.floor_divide` (see https://github.com/pytorch/pytorch/issues/43874). Deprecating it directly provides clearer feedback.

Repro:
```
import torch
x = torch.tensor(0)
x // 1
```

Before this change, a deprecation warning was triggered within the C++ implementation of floor_divide:
```
UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:571.)
  return torch.floor_divide(self, other)
```

After this change, the warning instead cites the user's offending line of Python code:
```
UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  x // 1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64034

Reviewed By: mruberry

Differential Revision: D30658010

Pulled By: saketh-are

fbshipit-source-id: b0e6c5008d741897509d102f4a89efb47de4aa2a

test/test_binary_ufuncs.py
test/test_sparse.py
torch/_tensor.py

index 1e9e804..2695ab6 100644 (file)
@@ -1622,7 +1622,7 @@ class TestBinaryUfuncs(TestCase):
         x = torch.randn(10, device=device).mul(30).to(dtype)
         y = torch.arange(1, 11, dtype=dtype, device=device)
 
-        with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
+        with self.assertWarnsOnceRegex(UserWarning, "__floordiv__"):
             z = x // y
         z_alt = torch.trunc(x.double() / y.double()).to(dtype)
 
@@ -1634,7 +1634,7 @@ class TestBinaryUfuncs(TestCase):
     def test_floor_divide_scalar(self, device, dtype):
         x = torch.randn(100, device=device).mul(10).to(dtype)
 
-        with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
+        with self.assertWarnsOnceRegex(UserWarning, "__floordiv__"):
             z = x // 3
         z_alt = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=x.dtype, device=device)
 
index aaf045c..8fa32ed 100644 (file)
@@ -1562,7 +1562,7 @@ class TestSparse(TestCase):
         self.assertEqual(self.safeToDense(y1), expected)
         self.assertEqual(self.safeToDense(y2), expected)
 
-        with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'):
+        with self.assertWarnsOnceRegex(UserWarning, '__floordiv__'):
             y1 = x1 // 37.5
         y2 = x1.clone()
         with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'):
@@ -2915,7 +2915,7 @@ class TestSparse(TestCase):
                                / torch.tensor(1., device=device).to_sparse())
 
     def test_floor_divide_by_sparse_error(self, device):
-        self.assertRaisesRegex(RuntimeError, 'Sparse floor division requires',
+        self.assertRaisesRegex(RuntimeError, 'Sparse division requires',
                                lambda: torch.tensor(1., device=device).to_sparse()
                                // torch.tensor(1., device=device).to_sparse())
 
index b4cee9a..e7bc4ed 100644 (file)
@@ -582,11 +582,21 @@ class Tensor(torch._C._TensorBase):
 
     @_wrap_type_error_to_not_implemented
     def __floordiv__(self, other):
-        return torch.floor_divide(self, other)
+        warnings.warn("__floordiv__ is deprecated, and its behavior will change in a future version of pytorch. "
+                      "It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). "
+                      "This results in incorrect rounding for negative values. "
+                      "To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), "
+                      "or for actual floor division, use torch.div(a, b, rounding_mode='floor').", stacklevel=3)
+        return torch.div(self, other, rounding_mode='trunc')
 
     @_wrap_type_error_to_not_implemented
     def __rfloordiv__(self, other):
-        return torch.floor_divide(other, self)
+        warnings.warn("__rfloordiv__ is deprecated, and its behavior will change in a future version of pytorch. "
+                      "It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). "
+                      "This results in incorrect rounding for negative values. "
+                      "To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), "
+                      "or for actual floor division, use torch.div(a, b, rounding_mode='floor').", stacklevel=3)
+        return torch.div(other, self, rounding_mode='trunc')
 
     @_wrap_type_error_to_not_implemented
     def __rlshift__(self, other):