self.assertExpectedGraph(foo.graph)
+ def test_cpp_function_tensor_str(self):
+ x = torch.randn(2, 2)
+ scale = torch.randn(2, 2, requires_grad=True)
+ shift = torch.randn(2, 2, requires_grad=True)
+
+ @torch.jit.script
+ def fn(x, scale, shift):
+ return scale * x + shift
+
+ with self.capture_stdout() as captured:
+ print(fn(x, scale, shift))
+
class MnistNet(nn.Module):
def __init__(self):
if self.grad_fn is not None:
name = type(self.grad_fn).__name__
if name == 'CppFunction':
- name = self.grad_fn.name().rsplit('::', maxsplit=1)[-1]
+ name = self.grad_fn.name().rsplit('::', 1)[-1]
suffixes.append('grad_fn=<{}>'.format(name))
elif self.requires_grad:
suffixes.append('requires_grad=True')