Fix tensor printing bug in Python 2 (#12732)
authorDavid Riazati <davidriazati@fb.com>
Mon, 17 Dec 2018 21:08:03 +0000 (13:08 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 17 Dec 2018 21:17:51 +0000 (13:17 -0800)
Summary:
`rsplit` doesn't have kwargs in Python 2 so this line raises an error

Fixes #15135
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12732

Differential Revision: D10458630

Pulled By: driazati

fbshipit-source-id: a63e42fbc0e39e4291480775b516c98122ec05a1

test/test_jit.py
torch/_tensor_str.py

index 81e8e5a..77df56b 100644 (file)
@@ -8764,6 +8764,18 @@ a")
 
         self.assertExpectedGraph(foo.graph)
 
+    def test_cpp_function_tensor_str(self):
+        x = torch.randn(2, 2)
+        scale = torch.randn(2, 2, requires_grad=True)
+        shift = torch.randn(2, 2, requires_grad=True)
+
+        @torch.jit.script
+        def fn(x, scale, shift):
+            return scale * x + shift
+
+        with self.capture_stdout() as captured:
+            print(fn(x, scale, shift))
+
 
 class MnistNet(nn.Module):
     def __init__(self):
index a00c32e..4132729 100644 (file)
@@ -282,7 +282,7 @@ def _str(self):
     if self.grad_fn is not None:
         name = type(self.grad_fn).__name__
         if name == 'CppFunction':
-            name = self.grad_fn.name().rsplit('::', maxsplit=1)[-1]
+            name = self.grad_fn.name().rsplit('::', 1)[-1]
         suffixes.append('grad_fn=<{}>'.format(name))
     elif self.requires_grad:
         suffixes.append('requires_grad=True')