From bafd875f743d93ccb3463676ea29101cae1760d7 Mon Sep 17 00:00:00 2001
From: Alban Desmaison
Date: Mon, 23 Aug 2021 07:05:51 -0700
Subject: [PATCH] Allow implementing either backward or vjp for Function (#63434)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63434

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30431968

Pulled By: albanD

fbshipit-source-id: 0bb88664283486a9fd3364e6c3d79442a44625c2
---
 test/test_autograd.py      | 18 +++++++++++++++++-
 torch/autograd/function.py | 21 +++++++++++++++++----
 2 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/test/test_autograd.py b/test/test_autograd.py
index 7200bd5..8b7aeb4 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5477,13 +5477,29 @@ for shape in [(1,), ()]:
             def forward(ctx, foo):
                 return foo.clone()
 
+        class BadBw2(Function):
+            @staticmethod
+            def forward(ctx, foo):
+                return foo.clone()
+
+            @staticmethod
+            def backward(ctx, foo):
+                return foo
+
+            @staticmethod
+            def vjp(ctx, foo):
+                return foo
+
         inp = torch.rand(1, requires_grad=True)
         with self.assertRaisesRegex(NotImplementedError, "must implement the forward"):
             BadFw.apply(inp)
 
-        with self.assertRaisesRegex(RuntimeError, "must implement the backward"):
+        with self.assertRaisesRegex(RuntimeError, "must implement either the backward"):
             BadBw.apply(inp).sum().backward()
 
+        with self.assertRaisesRegex(RuntimeError, "Implementing both 'backward' and 'vjp'"):
+            BadBw2.apply(inp).sum().backward()
+
     def test_custom_function_local_inplace(self):
         class MyFn(torch.autograd.Function):
             @staticmethod
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index 4d61229..90aeea5 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -188,7 +188,15 @@ class _HookMixin(object):
 class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
     def apply(self, *args):
         # _forward_cls is defined by derived class
-        return self._forward_cls.backward(self, *args)  # type: ignore[attr-defined]
+        # The user should define either backward or vjp but never both.
+        backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
+        vjp_fn = self._forward_cls.vjp  # type: ignore[attr-defined]
+        if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
+            raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom "
+                               "Function is not allowed. You should only implement one "
+                               "of them.")
+        user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
+        return user_fn(self, *args)
 
 
 class FunctionMeta(type):
@@ -271,7 +279,8 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, FunctionCtx, _Hook
 
     @staticmethod
     def backward(ctx: Any, *grad_outputs: Any) -> Any:
-        r"""Defines a formula for differentiating the operation.
+        r"""Defines a formula for differentiating the operation with backward mode
+        automatic differentiation.
 
         This function is to be overridden by all subclasses.
 
@@ -291,8 +300,12 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, FunctionCtx, _Hook
             first input to :func:`forward` needs gradient computated w.r.t. the
             output.
         """
-        raise NotImplementedError("You must implement the backward function for custom"
-                                  " autograd.Function.")
+        raise NotImplementedError("You must implement either the backward or vjp method for "
+                                  "your custom autograd.Function to use it with backward "
+                                  "mode AD.")
+
+    # vjp and backward are alias of each other
+    vjp = backward
 
 
 def once_differentiable(fn):
-- 
2.7.4
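
Example usage after applying this patch: a minimal sketch showing a custom Function that implements `vjp` instead of `backward`. The `ScaleByTwo` class below is a hypothetical illustration, not part of this diff; it assumes a PyTorch build containing this change.

```python
import torch
from torch.autograd import Function

# Hypothetical example (not part of the diff): the Function overrides `vjp`
# instead of `backward`. Since the two names are aliases on the base class,
# BackwardCFunction.apply dispatches to the user-provided `vjp`; overriding
# both would raise the RuntimeError added in this patch (see BadBw2 in the test).
class ScaleByTwo(Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def vjp(ctx, grad_output):
        # d(2x)/dx = 2, so the vector-Jacobian product is 2 * grad_output.
        return grad_output * 2

inp = torch.rand(3, requires_grad=True)
ScaleByTwo.apply(inp).sum().backward()
print(inp.grad)  # tensor([2., 2., 2.])
```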