From 1dbc7cff3e654af6d74ebbfb81c6f045c45d7182 Mon Sep 17 00:00:00 2001 From: David Riazati Date: Mon, 17 Dec 2018 13:08:03 -0800 Subject: [PATCH] Fix tensor printing bug in Python 2 (#12732) Summary: `rsplit` doesn't have kwargs in Python 2 so this line raises an error Fixes #15135 Pull Request resolved: https://github.com/pytorch/pytorch/pull/12732 Differential Revision: D10458630 Pulled By: driazati fbshipit-source-id: a63e42fbc0e39e4291480775b516c98122ec05a1 --- test/test_jit.py | 12 ++++++++++++ torch/_tensor_str.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/test/test_jit.py b/test/test_jit.py index 81e8e5a..77df56b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -8764,6 +8764,18 @@ a") self.assertExpectedGraph(foo.graph) + def test_cpp_function_tensor_str(self): + x = torch.randn(2, 2) + scale = torch.randn(2, 2, requires_grad=True) + shift = torch.randn(2, 2, requires_grad=True) + + @torch.jit.script + def fn(x, scale, shift): + return scale * x + shift + + with self.capture_stdout() as captured: + print(fn(x, scale, shift)) + class MnistNet(nn.Module): def __init__(self): diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index a00c32e..4132729 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -282,7 +282,7 @@ def _str(self): if self.grad_fn is not None: name = type(self.grad_fn).__name__ if name == 'CppFunction': - name = self.grad_fn.name().rsplit('::', maxsplit=1)[-1] + name = self.grad_fn.name().rsplit('::', 1)[-1] suffixes.append('grad_fn=<{}>'.format(name)) elif self.requires_grad: suffixes.append('requires_grad=True') -- 2.7.4