            def forward(ctx, foo):
                return foo.clone()
+        class BadBw2(Function):
+            @staticmethod
+            def forward(ctx, foo):
+                return foo.clone()
+
+            @staticmethod
+            def backward(ctx, foo):
+                return foo
+
+            @staticmethod
+            def vjp(ctx, foo):
+                return foo
+
        inp = torch.rand(1, requires_grad=True)
        with self.assertRaisesRegex(NotImplementedError, "must implement the forward"):
            BadFw.apply(inp)
-        with self.assertRaisesRegex(RuntimeError, "must implement the backward"):
+        with self.assertRaisesRegex(RuntimeError, "must implement either the backward"):
            BadBw.apply(inp).sum().backward()
+        with self.assertRaisesRegex(RuntimeError, "Implementing both 'backward' and 'vjp'"):
+            BadBw2.apply(inp).sum().backward()
+
    def test_custom_function_local_inplace(self):
        class MyFn(torch.autograd.Function):
            @staticmethod
class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
    def apply(self, *args):
        # _forward_cls is defined by derived class
-        return self._forward_cls.backward(self, *args)  # type: ignore[attr-defined]
+        # The user should define either backward or vjp but never both.
+        backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
+        vjp_fn = self._forward_cls.vjp  # type: ignore[attr-defined]
+        if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
+            raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom "
+                               "Function is not allowed. You should only implement one "
+                               "of them.")
+        user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
+        return user_fn(self, *args)
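
# Hypothetical usage sketch (illustration only, not part of this patch): with the
# dispatch above, a custom Function can define just `vjp`; its `backward` stays the
# base Function.backward default, so `apply` resolves to the user's `vjp`. The names
# below (ScaleByTwo, x) are made up for the example.
import torch

class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    def vjp(ctx, grad_output):
        # d(2 * x)/dx == 2, so the vector-Jacobian product is 2 * grad_output.
        return 2 * grad_output

x = torch.rand(3, requires_grad=True)
ScaleByTwo.apply(x).sum().backward()
assert torch.equal(x.grad, torch.full((3,), 2.0))
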
class FunctionMeta(type):
    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
-        r"""Defines a formula for differentiating the operation.
+        r"""Defines a formula for differentiating the operation with backward mode
+        automatic differentiation.

        This function is to be overridden by all subclasses.

        E.g., :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
        first input to :func:`forward` needs gradient computed w.r.t. the
        output.
        """
-        raise NotImplementedError("You must implement the backward function for custom"
-                                  " autograd.Function.")
+        raise NotImplementedError("You must implement either the backward or vjp method for "
+                                  "your custom autograd.Function to use it with backward "
+                                  "mode AD.")
+
+    # vjp and backward are aliases of each other
+    vjp = backward
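
# Hypothetical sketch (illustration only, not part of this patch): the common case
# still overrides `backward`; because the subclass leaves `vjp` at the base-class
# default, the dispatch in BackwardCFunction.apply picks this `backward`. The names
# below (MyMul, a, b) are made up for the example.
import torch

class MyMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # One gradient per forward input; skip the work when an input does not
        # need a gradient.
        grad_a = grad_output * b if ctx.needs_input_grad[0] else None
        grad_b = grad_output * a if ctx.needs_input_grad[1] else None
        return grad_a, grad_b

a = torch.rand(2, requires_grad=True)
b = torch.rand(2, requires_grad=True)
MyMul.apply(a, b).sum().backward()  # a.grad == b, b.grad == a
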
def once_differentiable(fn):