Allow implementing either backward or vjp for Function (#63434)
author     Alban Desmaison <albandes@fb.com>
Mon, 23 Aug 2021 14:05:51 +0000 (07:05 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 23 Aug 2021 14:07:11 +0000 (07:07 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63434

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30431968

Pulled By: albanD

fbshipit-source-id: 0bb88664283486a9fd3364e6c3d79442a44625c2
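
For reference, a minimal sketch of what this change enables (hypothetical example, not part of the commit): a custom autograd.Function may now override vjp() instead of backward(), and backward-mode AD dispatches to whichever of the two was overridden.

    import torch
    from torch.autograd import Function

    class ScaleByTwo(Function):          # illustrative name, not from the PR
        @staticmethod
        def forward(ctx, x):
            return 2 * x

        @staticmethod
        def vjp(ctx, grad_output):
            # Same signature and contract as backward(): return one gradient
            # per input to forward().
            return 2 * grad_output

    inp = torch.rand(3, requires_grad=True)
    ScaleByTwo.apply(inp).sum().backward()   # gradient flows through vjp()
    assert torch.allclose(inp.grad, torch.full((3,), 2.))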

test/test_autograd.py
torch/autograd/function.py

diff --git a/test/test_autograd.py b/test/test_autograd.py
index 7200bd5..8b7aeb4 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5477,13 +5477,29 @@ for shape in [(1,), ()]:
             def forward(ctx, foo):
                 return foo.clone()
 
+        class BadBw2(Function):
+            @staticmethod
+            def forward(ctx, foo):
+                return foo.clone()
+
+            @staticmethod
+            def backward(ctx, foo):
+                return foo
+
+            @staticmethod
+            def vjp(ctx, foo):
+                return foo
+
         inp = torch.rand(1, requires_grad=True)
         with self.assertRaisesRegex(NotImplementedError, "must implement the forward"):
             BadFw.apply(inp)
 
-        with self.assertRaisesRegex(RuntimeError, "must implement the backward"):
+        with self.assertRaisesRegex(RuntimeError, "must implement either the backward"):
             BadBw.apply(inp).sum().backward()
 
+        with self.assertRaisesRegex(RuntimeError, "Implementing both 'backward' and 'vjp'"):
+            BadBw2.apply(inp).sum().backward()
+
     def test_custom_function_local_inplace(self):
         class MyFn(torch.autograd.Function):
             @staticmethod
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index 4d61229..90aeea5 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -188,7 +188,15 @@ class _HookMixin(object):
 class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
     def apply(self, *args):
         # _forward_cls is defined by derived class
-        return self._forward_cls.backward(self, *args)  # type: ignore[attr-defined]
+        # The user should define either backward or vjp but never both.
+        backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
+        vjp_fn = self._forward_cls.vjp  # type: ignore[attr-defined]
+        if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
+            raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom "
+                               "Function is not allowed. You should only implement one "
+                               "of them.")
+        user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
+        return user_fn(self, *args)
 
 
 class FunctionMeta(type):
@@ -271,7 +279,8 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, FunctionCtx, _Hook
 
     @staticmethod
     def backward(ctx: Any, *grad_outputs: Any) -> Any:
-        r"""Defines a formula for differentiating the operation.
+        r"""Defines a formula for differentiating the operation with backward mode
+        automatic differentiation.
 
         This function is to be overridden by all subclasses.
 
@@ -291,8 +300,12 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, FunctionCtx, _Hook
         first input to :func:`forward` needs gradient computed w.r.t. the
         output.
         """
-        raise NotImplementedError("You must implement the backward function for custom"
-                                  " autograd.Function.")
+        raise NotImplementedError("You must implement either the backward or vjp method for "
+                                  "your custom autograd.Function to use it with backward "
+                                  "mode AD.")
+
+    # vjp and backward are aliases of each other
+    vjp = backward
 
 
 def once_differentiable(fn):
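
Conversely, as the new BadBw2 test above exercises, overriding both backward() and vjp() on the same Function is rejected when the backward pass reaches it. A hypothetical sketch of the failure mode (names are illustrative):

    import torch
    from torch.autograd import Function

    class Both(Function):
        @staticmethod
        def forward(ctx, x):
            return x.clone()

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

        @staticmethod
        def vjp(ctx, grad_output):
            return grad_output

    out = Both.apply(torch.rand(1, requires_grad=True))
    out.sum().backward()  # raises RuntimeError: Implementing both 'backward'
                          # and 'vjp' for a custom Function is not allowed.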