Support enough of closures to write autograd functions (#15411)
author    Zachary DeVito <zdevito@fb.com>
          Thu, 20 Dec 2018 22:26:06 +0000 (14:26 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
          Thu, 20 Dec 2018 22:39:11 +0000 (14:39 -0800)
commit    1a2ec10bd4a73f4ebeefeb7a7cc781a80b14794c
tree      e0ccf3ecf68a98420a543918006a4469260f2839
parent    3fdf567752b484c5542efea02afc64e4f875eeca
Support enough of closures to write autograd functions (#15411)

Summary:
This PR adds enough of the infrastructure for closures (inner script functions) to let us express symbolic gradients with them. We never actually run graphs that contain these closures: the symbolic_script infrastructure extracts them from the original forward graph and turns them into discrete forward/backward pairs. This cuts down on the type annotations needed to write forward/backward pairs and aligns closely with the "differentiator" function approach to expressing reverse-mode AD.

Example:

This code:
```
import torch

r = torch.jit.CompilationUnit(
'''
def mul_forward(self, other):
    def backward(grad_output):
        grad_self = (grad_output * other).sum_to_size(self.size())
        grad_other = (grad_output * self).sum_to_size(other.size())
        return grad_self, grad_other
    return self * other, backward
''')

print(r.module.code)
```

will produce this graph (pretty-printed for clarity):

```
def mul_forward(self,
    self: Tensor,
    other: Tensor) -> Tuple[Tensor, Tuple[None, Tuple[Tensor, Tensor]]]:
  backward = (self.__lambda, (other, self))
  return (torch.mul(self, other), backward)

def __lambda(self,
    context: Tuple[Tensor, Tensor],
    grad_output: Tensor) -> Tuple[Tensor, Tensor]:
  other, self, = context
  grad_self = torch.sum_to_size(torch.mul(grad_output, other), torch.size(self))
  grad_other = torch.sum_to_size(torch.mul(grad_output, self), torch.size(other))
  return (grad_self, grad_other)
```
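
For reference, the closure value in this graph is just a pair of (lifted function, captured context), applied by passing the context back in as the first argument. Below is a plain-Python sketch of the same structure (a hypothetical eager-mode rendering, not what the JIT executes):

```
import torch

# Hypothetical eager-mode rendering of the graph above: the inner `backward`
# is lifted to a free function plus a tuple of captured values.
def __lambda(context, grad_output):
    other, self_ = context
    grad_self = (grad_output * other).sum_to_size(self_.size())
    grad_other = (grad_output * self_).sum_to_size(other.size())
    return grad_self, grad_other

def mul_forward(self_, other):
    backward = (__lambda, (other, self_))  # (function, captured context)
    return self_ * other, backward

out, (fn, ctx) = mul_forward(torch.randn(2, 3), torch.randn(3))
grads = fn(ctx, torch.ones_like(out))      # unpack the closure to run backward
```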

symbolic_script will then do some modifications to remove the unsupported prim::Function node, yielding:

```
def mul_forward(self,
    self: Tensor,
    other: Tensor) -> Tuple[Tensor, Tuple[None, Tuple[Tensor, Tensor]]]:
  return (torch.mul(self, other), (other, self))

def backward(self,
    context: Tuple[Tensor, Tensor],
    grad_output: Tensor) -> Tuple[Tensor, Tensor]:
  other, self, = context
  grad_self = torch.sum_to_size(torch.mul(grad_output, other), torch.size(self))
  grad_other = torch.sum_to_size(torch.mul(grad_output, self), torch.size(other))
  return (grad_self, grad_other)
```
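
For comparison, the same forward/backward pair could be written as an eager-mode torch.autograd.Function. This is only an illustrative sketch of what the pair computes; the symbolic_script path above works at the graph level and does not go through autograd.Function:

```
import torch

# Illustrative sketch only: shows the same derivative in eager mode.
class MulFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, self_, other):
        ctx.save_for_backward(self_, other)  # plays the role of the context tuple
        return self_ * other

    @staticmethod
    def backward(ctx, grad_output):
        self_, other = ctx.saved_tensors
        # sum_to_size reduces broadcasted dimensions back to each input's shape
        grad_self = (grad_output * other).sum_to_size(self_.size())
        grad_other = (grad_output * self_).sum_to_size(other.size())
        return grad_self, grad_other

a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
MulFn.apply(a, b).sum().backward()           # fills in a.grad and b.grad
```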
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15411

Differential Revision: D13523340

Pulled By: zdevito

fbshipit-source-id: 4d4a269460e595b16802c00ec55ae00e3e682d49
aten/src/ATen/core/interned_strings.h
test/expect/TestFuser.test_lstm_cuda-backward.expect
test/expect/TestFuser.test_milstm_cuda-backward.expect
test/expect/TestJit.test_cpp_cuda.expect
test/test_jit.py
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/parser.cpp
torch/csrc/jit/script/tree_views.h
torch/csrc/jit/symbolic_script.cpp