From e322547fe6dd4f0ca9261a1ac2ae7095800b98a1 Mon Sep 17 00:00:00 2001
From: Alban Desmaison
Date: Wed, 1 Sep 2021 13:34:48 -0700
Subject: [PATCH] Add forward AD support for custom Functions (#64061)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64061

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D30640868

Pulled By: albanD

fbshipit-source-id: b0e6610430a879074d6d5306443772fc154b431f
---
 test/test_autograd.py                   | 114 +++++++++++++++++++
 torch/autograd/function.py              |  24 ++++
 torch/csrc/autograd/custom_function.cpp | 187 ++++++++++++++++++++++++++++++--
 torch/csrc/autograd/custom_function.h   |  12 +-
 torch/csrc/autograd/python_function.cpp |  59 +++++++++-
 5 files changed, 385 insertions(+), 11 deletions(-)

diff --git a/test/test_autograd.py b/test/test_autograd.py
index 8b3c8bd..ebe3aa5 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5494,6 +5494,11 @@ for shape in [(1,), ()]:
             def vjp(ctx, foo):
                 return foo
 
+        class BadJvp(Function):
+            @staticmethod
+            def forward(ctx, foo):
+                return foo.clone()
+
         inp = torch.rand(1, requires_grad=True)
         with self.assertRaisesRegex(NotImplementedError, "must implement the forward"):
             BadFw.apply(inp)
@@ -5504,6 +5509,115 @@ for shape in [(1,), ()]:
         with self.assertRaisesRegex(RuntimeError, "Implementing both 'backward' and 'vjp'"):
             BadBw2.apply(inp).sum().backward()
 
+        with self.assertRaisesRegex(RuntimeError, "must implement the jvp function"):
+            with fwAD.dual_level():
+                d = fwAD.make_dual(inp, torch.rand_like(inp))
+                res = BadJvp.apply(d)
+
+    def test_custom_function_forward_mode_view_checks(self):
+        flag_to_error = {
+            "ok": None,
+            "not_a_view": "jvp is not returning a view",
+            "not_a_view_of_inp": "jvp is not returning a view of the given",
+            "not_a_view_of_inp_base": "jvp is not returning a view of the same base",
+        }
+
+        class ViewFn(Function):
+            @staticmethod
+            def forward(ctx, foo, flag):
+                ctx.flag = flag
+                ctx.size = foo.size()
+                return foo.narrow(0, 0, 2)
+
+            @staticmethod
+            def vjp(ctx, gO):
+                gI = gO.new_zeros(ctx.size)
+                gI.narrow(0, 0, 2).copy_(gO)
+                return gI, None
+
+            @staticmethod
+            def jvp(ctx, gI, _):
+                res = gI.narrow(0, 0, 2)
+                if ctx.flag != "ok":
+                    # Break the view in the gradients!
+                    res = res.clone()
+                if ctx.flag in ["not_a_view_of_inp", "not_a_view_of_inp_base"]:
+                    # Result should be a view, just of the wrong thing
+                    res = res.view_as(res)
+                return res
+
+        inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
+
+        for flag, msg in flag_to_error.items():
+            def test_fn(inp):
+                if flag == "not_a_view_of_inp_base":
+                    inp = inp.view_as(inp)
+                return ViewFn.apply(inp, flag)
+
+            if msg is None:
+                gradcheck(test_fn, inp, check_forward_ad=True)
+            else:
+                with self.assertRaisesRegex(RuntimeError, msg):
+                    gradcheck(test_fn, inp, check_forward_ad=True)
+
+    def test_custom_function_forward_mode_inplace_checks(self):
+        class InplaceFn(Function):
+            @staticmethod
+            def forward(ctx, foo, flag):
+                ctx.mark_dirty(foo)
+                ctx.flag = flag
+                foo.mul_(2)
+                return foo
+
+            @staticmethod
+            def vjp(ctx, gO):
+                return 2 * gO, None
+
+            @staticmethod
+            def jvp(ctx, gI, _):
+                if ctx.flag:
+                    # Don't do the change inplace
+                    return 2 * gI
+                else:
+                    gI.mul_(2)
+                    return gI
+
+        inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
+
+        def test_fn(inp, flag):
+            inp = inp.clone()
+            return InplaceFn.apply(inp, flag)
+
+        gradcheck(test_fn, (inp, False), check_forward_ad=True)
+
+        with self.assertRaisesRegex(RuntimeError, "inplace custom Function is not modifying the forward mode gradients inplace"):
+            gradcheck(test_fn, (inp, True), check_forward_ad=True)
+
+    def test_custom_function_forward_mode_wrong_formula(self):
+        class UserFn(Function):
+            @staticmethod
+            def forward(ctx, foo, should_fail):
+                ctx.should_fail = should_fail
+                return foo * 2
+
+            @staticmethod
+            def vjp(ctx, gO):
+                return 2 * gO, None
+
+            @staticmethod
+            def jvp(ctx, gI, _):
+                if ctx.should_fail:
+                    # Wrong gradient formula
+                    return 3 * gI
+                else:
+                    return 2 * gI
+
+        inp = torch.rand(10, dtype=torch.double, requires_grad=True)
+        gradcheck(UserFn.apply, (inp, False), check_forward_ad=True)
+
+        with self.assertRaisesRegex(RuntimeError, "Jacobian computed with forward mode mismatch for output 0"):
+            gradcheck(UserFn.apply, (inp, True), check_forward_ad=True)
+
     def test_custom_function_local_inplace(self):
         class MyFn(torch.autograd.Function):
             @staticmethod
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index 90aeea5..909e719 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -198,6 +198,10 @@ class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
         user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
         return user_fn(self, *args)
 
+    def apply_jvp(self, *args):
+        # _forward_cls is defined by derived class
+        return self._forward_cls.jvp(self, *args)  # type: ignore[attr-defined]
+
 
 class FunctionMeta(type):
     """Function metaclass.
@@ -307,6 +311,26 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, FunctionCtx, _Hook
     # vjp and backward are alias of each other
     vjp = backward
 
+    @staticmethod
+    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
+        r"""Defines a formula for differentiating the operation with forward mode
+        automatic differentiation.
+        This function is to be overridden by all subclasses.
+        It must accept a context :attr:`ctx` as the first argument, followed by
+        as many inputs as the :func:`forward` got (None will be passed in
+        for non-Tensor inputs of the forward function),
+        and it should return as many tensors as there were outputs to
+        :func:`forward`. Each argument is the gradient w.r.t. the given input,
+        and each returned value should be the gradient w.r.t. the
+        corresponding output. If an output is not a Tensor or the function is not
+        differentiable with respect to that output, you can just pass None as a
+        gradient for that output.
+
+        You can use the :attr:`ctx` object to pass any value from the forward to this
+        function.
+        """
+        raise NotImplementedError("You must implement the jvp function for custom "
+                                  "autograd.Function to use it with forward mode AD.")
 
 
 def once_differentiable(fn):
diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp
index fdcf997..1bb4cb8 100644
--- a/torch/csrc/autograd/custom_function.cpp
+++ b/torch/csrc/autograd/custom_function.cpp
@@ -26,8 +26,175 @@ Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
   }
 }
 
+// This function has two main goals:
+//  1) Use the user-provided jvp function to populate the outputs' forward gradients
+//  2) Perform error checking to ensure that view and inplace ops are properly handled
+//
+// For 1) we have to:
+//  - Create a variable_list of grad_inputs based on the function inputs
+//  - Call the user jvp function with these to get the grad_outputs
+//  - Set the forward grad field on each output based on these grad_outputs
+//
+// For 2) we want to check the following:
+//  - If an output is a view, then the generated forward grad must be a view as well and
+//    the output's base's forward grad must be the output's forward grad's base.
+//  - If an input was modified inplace (it must be an output as well) we make sure that its
+//    forward grad was also modified inplace and already present on the corresponding output.
+void _process_forward_mode_AD(const variable_list &inputs,
+  std::unordered_map<at::TensorImpl*, size_t> inputs_mapping,
+  const at::ArrayRef<c10::optional<Variable>> raw_outputs,
+  const optional_variable_list &outputs,
+  const std::unordered_set<at::TensorImpl*> &non_differentiable,
+  const std::unordered_set<at::TensorImpl*> &dirty_inputs,
+  _jvp_fn_t jvp_user_function) {
+
+  // TODO handle multiple levels here
+  uint64_t level = 0;
+
+  const auto num_inputs = inputs.size();
+  const auto num_outputs = outputs.size();
+
+  // The tracking info below is used to perform the view and inplace checks.
+  // It is lazily initialized to reduce the cost of this function in the common
+  // case where the user is not using forward mode AD.
+  variable_list input_grads;
+  std::vector<int64_t> grad_versions;
+  std::vector<at::TensorImpl*> grad_impls;
+  std::unordered_map<at::TensorImpl*, size_t> inputs_bases;
+
+  auto init_tracked_info = [&] () {
+    input_grads.resize(num_inputs);
+    grad_versions.resize(num_inputs);
+    grad_impls.resize(num_inputs);
+
+    for (const auto i: c10::irange(num_inputs)) {
+      const auto& inp = inputs[i];
+      if (inp.is_view() && impl::get_view_autograd_meta(inp)->has_fw_view()) {
+        inputs_bases.emplace(impl::get_view_autograd_meta(inp)->get_forward_view().base_.unsafeGetTensorImpl(), i);
+      } else {
+        inputs_bases.emplace(inp.unsafeGetTensorImpl(), i);
+      }
+
+    }
+  };
+
+  bool any_input_has_grad = false;
+  // Extract the input's forward gradients and record any info we will need later
+  for (const auto i : c10::irange(num_inputs)) {
+    const auto& inp = inputs[i];
+    if (!inp.defined()) {
+      continue;
+    }
+    const auto& fw_grad = inp._fw_grad(level);
+    if (fw_grad.defined()) {
+      if (!any_input_has_grad) {
+        any_input_has_grad = true;
+        init_tracked_info();
+      }
+      input_grads[i] = fw_grad;
+      grad_versions[i] = fw_grad._version();
+      grad_impls[i] = fw_grad.unsafeGetTensorImpl();
+    }
+  }
+
+  // If no input has forward grad, nothing to do here
+  if (!any_input_has_grad) {
+    return;
+  }
+
+  auto forward_grads = jvp_user_function(inputs, input_grads);
+
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  const auto num_forward_grads = forward_grads.size();
+  // Contrary to backward mode, we don't allow returning too many gradients
+  TORCH_CHECK(num_forward_grads == num_outputs, "Function's jvp returned "
+    "an invalid number of forward gradients (expected ", num_outputs,
+    " but got ", num_forward_grads, ")");
+
+  for (const auto i : c10::irange(num_outputs)) {
+    const auto& out = outputs[i].has_value() ? outputs[i].value() : at::Tensor();
+    const auto& out_grad = forward_grads[i];
+    if (!out.defined()) {
+      TORCH_CHECK(!out_grad.defined(), "Function's jvp returned a gradient at position ", i, ", but "
+        "the corresponding forward output is not a differentiable Tensor");
+      continue;
+    }
+
+    TORCH_INTERNAL_ASSERT(raw_outputs[i].has_value());
+    auto out_tensor_impl = raw_outputs[i].value().unsafeGetTensorImpl();
+    bool is_input = inputs_mapping.count(out_tensor_impl) > 0;
+    bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
+
+    if (is_modified) {
+      TORCH_CHECK(is_input, "Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there"
+        " is no need to pass it to mark_dirty().");
+      auto inp_idx = inputs_mapping[out_tensor_impl];
+      if (grad_impls[inp_idx]) {
+        // If there was already a forward grad for that input,
+        // just make sure that it is modified inplace and returned as-is
+        TORCH_CHECK(out_grad._version() != grad_versions[inp_idx], "An inplace custom Function is not modifying the "
+          "forward mode gradients inplace. If the forward is modifying an input inplace, then the jvp "
+          "function must modify the corresponding gradient inplace.")
+        TORCH_CHECK(out_grad.unsafeGetTensorImpl() == grad_impls[inp_idx], "An inplace custom Function is not returning the "
+          "forward mode gradients as-is. If the forward is modifying an input inplace, then the jvp "
+          "function must modify the gradient inplace and return it as-is.")
+      } else {
+        // If that Tensor didn't have gradients already, set the newly returned one
+        // We could also use inputs[inp_idx] here as it is the same as out
+        out._set_fw_grad(out_grad, level, /* is_inplace_op */ true);
+      }
+    } else {
+      // At this point, outputs[i] cannot be one of the inputs (raw_outputs[i] might be but was changed by the backward code)
+      TORCH_INTERNAL_ASSERT(!is_input);
+
+      if (out.is_view() && impl::get_view_autograd_meta(out)->has_fw_view()) {
+        // If the output is a view
+        const auto& out_view_info = impl::get_view_autograd_meta(out)->get_forward_view();
+        if (inputs_bases.count(out_view_info.base_.unsafeGetTensorImpl())) {
+          // And it is a view of an input (either that input is its base or they have a common base)
+          const auto matching_input_idx = inputs_bases[out_view_info.base_.unsafeGetTensorImpl()];
+          const auto& matching_input = inputs[matching_input_idx];
+
+          const auto& matching_input_grad = matching_input._fw_grad(level);
+
+          // If the matching input has a forward grad, the user should have returned a view of that Tensor
+          if (matching_input_grad.defined()) {
+            TORCH_CHECK(out_grad.is_view() && impl::get_view_autograd_meta(out_grad)->has_fw_view(),
+              "A custom Function's forward is returning a view but the jvp is not returning a view.");
+
+            const auto& out_grad_base = impl::get_view_autograd_meta(out_grad)->get_forward_view().base_;
+            if (matching_input_grad.is_view() && impl::get_view_autograd_meta(matching_input_grad)->has_fw_view()) {
+              // If the matching input's grad is a view, ensure that the out_grad is a view of the same base
+              const auto& matching_input_grad_base = impl::get_view_autograd_meta(matching_input_grad)->get_forward_view().base_;
+              TORCH_CHECK(matching_input_grad_base.unsafeGetTensorImpl() == out_grad_base.unsafeGetTensorImpl(),
+                "A custom Function is returning a view but the jvp is not returning a view of the same base as "
+                "the given grad input.");
+            } else {
+              // If the matching input's grad is not a view, then it must be the output gradient's base
+              TORCH_CHECK(matching_input_grad.unsafeGetTensorImpl() == out_grad_base.unsafeGetTensorImpl(),
+                "A custom Function is returning a view but the jvp is not returning a view of the given grad input.");
+            }
+          } else {
+            // We have a view op where the input didn't have a forward grad but the user returned one for the output
+            // To ensure that we maintain the view/inplace constraints, we consider this as an inplace op
+            // This case CANNOT happen in codegen as all view ops are mapping from one Tensor to one Tensor and so the output
+            // of the view cannot have a forward grad if the base does not.
+            out._set_fw_grad(out_grad, level, /* is_inplace_op */ true);
+            return;
+          }
+        }
+      }
+
+      out._set_fw_grad(out_grad, level, /* is_inplace_op */ false);
+    }
+  }
+}
+
 optional_variable_list _process_backward_mode_ad(
-    const std::unordered_set<at::TensorImpl*> &inputs_set,
+    const std::unordered_map<at::TensorImpl*, size_t> &inputs_mapping,
     const std::unordered_set<at::TensorImpl*> &non_differentiable,
     const std::unordered_set<at::TensorImpl*> &dirty_inputs,
     const at::ArrayRef<c10::optional<Variable>> raw_outputs,
@@ -121,7 +288,7 @@ optional_variable_list _process_backward_mode_ad(
     Variable var = raw_outputs[i].value();
 
     auto out_tensor_impl = var.unsafeGetTensorImpl();
-    bool is_input = inputs_set.count(out_tensor_impl) > 0;
+    bool is_input = inputs_mapping.count(out_tensor_impl) > 0;
     bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
     bool is_differentiable = cdata && non_differentiable.count(out_tensor_impl) == 0 &&
         isDifferentiableType(var.scalar_type());
@@ -179,16 +346,20 @@ optional_variable_list _wrap_outputs(const variable_list &input_vars,
     const std::unordered_set<at::TensorImpl*> &non_differentiable,
     const std::unordered_set<at::TensorImpl*> &dirty_inputs,
     const at::ArrayRef<c10::optional<Variable>> raw_outputs,
-    const std::shared_ptr<Node> &cdata) {
+    const std::shared_ptr<Node> &cdata,
+    _jvp_fn_t jvp_user_function) {
 
-  std::unordered_set<at::TensorImpl*> inputs_set;
-  inputs_set.reserve(input_vars.size());
-  for (auto& var : input_vars) {
-    inputs_set.emplace(var.unsafeGetTensorImpl());
+  std::unordered_map<at::TensorImpl*, size_t> inputs_mapping;
+  inputs_mapping.reserve(input_vars.size());
+  for (const auto i: c10::irange(input_vars.size())) {
+    inputs_mapping.emplace(input_vars[i].unsafeGetTensorImpl(), i);
   }
 
-  auto outputs = _process_backward_mode_ad(inputs_set, non_differentiable, dirty_inputs, raw_outputs, cdata);
+  auto outputs = _process_backward_mode_ad(inputs_mapping, non_differentiable, dirty_inputs, raw_outputs, cdata);
+
+  // This must happen after the backward processing as we expect the computations happening here to track
+  // backward mode gradients.
+  _process_forward_mode_AD(input_vars, inputs_mapping, raw_outputs, outputs, non_differentiable, dirty_inputs, jvp_user_function);
 
   return outputs;
 }
diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h
index 376cab6..94e62bf 100644
--- a/torch/csrc/autograd/custom_function.h
+++ b/torch/csrc/autograd/custom_function.h
@@ -10,13 +10,15 @@ namespace torch { namespace autograd {
 
 using optional_variable_list = std::vector<c10::optional<Variable>>;
+using _jvp_fn_t = std::function<variable_list(variable_list, variable_list)>;
 
 TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
   const variable_list &input_vars,
   const std::unordered_set<at::TensorImpl*> &non_differentiable,
   const std::unordered_set<at::TensorImpl*> &dirty_inputs,
   const at::ArrayRef<c10::optional<Variable>> raw_outputs,
-  const std::shared_ptr<Node> &cdata);
+  const std::shared_ptr<Node> &cdata,
+  _jvp_fn_t jvp_user_function);
 
 TORCH_API void check_variable_result(const Variable& original, const Variable& result, std::string hook_name);
@@ -265,12 +267,18 @@ auto Function<T>::apply(Args&&... args) -> std::enable_if_t<std::is_same<X, T>::v
     outputs = T::forward(&node->ctx_, std::forward<Args>(args)...);
   }
 
+  _jvp_fn_t jvp_fn = [](variable_list inputs, variable_list gI) -> variable_list {
+    TORCH_CHECK(false, "jvp is not implemented for the c++ API of custom Function yet.",
+                "Please open a feature request on Github if you need this.");
+  };
+
   auto wrapped_outputs = _wrap_outputs(
     input_vars,
     node->ctx_.get_non_differentiable(),
     node->ctx_.get_and_bump_dirty(),
     to_optional(outputs),
-    is_executable ? node : nullptr);
+    is_executable ? node : nullptr,
+    jvp_fn);
 
   node->output_info_.reserve(wrapped_outputs.size());
   for (auto& output : wrapped_outputs) {
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 1487418..eee56f7 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -340,8 +340,61 @@ static void _wrap_outputs(const std::shared_ptr<Node>& cdata, THPFunction *sel
     }
   }
 
+  _jvp_fn_t jvp_user_function = [self](variable_list inputs, variable_list grad_inputs) {
+    pybind11::gil_scoped_acquire gil;
+
+    // Massage a C++ variable_list into a Python arguments tuple,
+    // making sure to introduce the proper None for non-Tensor inputs
+    auto num_inputs = self->is_variable_input.size();
+    THPObjectPtr pyInputs(PyTuple_New(num_inputs));
+    if (!pyInputs) throw_python_error();
+    auto var_input_idx = 0;
+    for (const auto i : c10::irange(num_inputs)) {
+      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+      PyObject* input;
+      if (self->is_variable_input[i]) {
+        if (grad_inputs[i].defined() || !self->materialize_grads) {
+          input = THPVariable_Wrap(grad_inputs[i]);
+        } else {
+          input = THPVariable_Wrap(at::zeros_like(inputs[i]));
+        }
+        if (!input) throw_python_error();
+      } else {
+        Py_INCREF(Py_None);
+        input = Py_None;
+      }
+      PyTuple_SET_ITEM(pyInputs.get(), i, input);
+    }
+
+    THPObjectPtr apply_jvp_fn(PyObject_GetAttrString((PyObject*)self, "apply_jvp"));
+    if (!apply_jvp_fn) throw_python_error();
+    THPObjectPtr r(PyObject_CallObject(apply_jvp_fn, pyInputs.get()));
+    if (!r) throw_python_error();
+    ensure_tuple(r);
+
+    // Massage the Python results tuple back into a C++ variable_list
+    // Don't do any check on the number of results here as
+    // it is handled by the caller
+    const int num_outputs = PyTuple_GET_SIZE(r.get());
+    variable_list results;
+    results.reserve(num_outputs);
+    for (int i = 0; i != num_outputs; ++i) {
+      PyObject* output = PyTuple_GET_ITEM(r.get(), i);
+      if (output == Py_None) {
+        results.emplace_back();
+      } else {
+        TORCH_CHECK(THPVariable_Check(output), "expected Variable or None (got ",
+                    THPUtils_typename(output), ") for grad output ", i, ".")
+        results.emplace_back(THPVariable_Unpack(output));
+      }
+    }
+
+    return results;
+  };
+
   // Wrap only the tensor outputs.
-  auto wrapped_outputs = _wrap_outputs(input_vars, non_differentiable, dirty_inputs, raw_output_vars, cdata_if_executable);
+  auto wrapped_outputs = _wrap_outputs(input_vars, non_differentiable, dirty_inputs,
+      raw_output_vars, cdata_if_executable, jvp_user_function);
 
   for(const auto i : c10::irange(num_outputs)) {
     PyObject* obj = PyTuple_GetItem(raw_output, i);
@@ -571,6 +624,9 @@ PyObject* process_outputs(PyObject *op_obj, const std::shared_ptr<PyNode>& cdata
   bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
   _wrap_outputs(cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable);
   _trace_post_record(node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);
+
+  // It is important that creating the SavedVariables happens after the output wrapping as the
+  // outputs must have their grad_fn/fw_grad properly set before we save them.
   if (is_executable) {
     _save_variables(cdata, grad_fn);
   } else {
@@ -651,6 +707,7 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
   THPObjectPtr tensor_outputs;
   {
     AutoGradMode grad_mode(false);
+    at::AutoFwGradMode fw_grad_mode(false);
     THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
     if (!forward_fn) return nullptr;
     tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple);
--
2.7.4
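
For reference, the Python-side usage this patch enables looks roughly like the sketch below. It mirrors the new tests in test_autograd.py and the jvp docstring added above; the ScaleBy2 class, its gradient formula, and the tangent values are illustrative only and are not part of the patch.

import torch
import torch.autograd.forward_ad as fwAD
from torch.autograd import Function, gradcheck

class ScaleBy2(Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def vjp(ctx, gO):
        # Backward mode: gradient w.r.t. the (single) input
        return 2 * gO

    @staticmethod
    def jvp(ctx, gI):
        # Forward mode: tangent of the output, given the input tangent gI
        return 2 * gI

inp = torch.rand(10, dtype=torch.double, requires_grad=True)

# Forward-mode gradients now flow through the custom Function
with fwAD.dual_level():
    dual = fwAD.make_dual(inp, torch.ones_like(inp))
    out = ScaleBy2.apply(dual)
    _, tangent = fwAD.unpack_dual(out)  # tangent == 2 * torch.ones_like(inp)

# gradcheck can also verify the jvp formula against numerical derivatives
gradcheck(ScaleBy2.apply, (inp,), check_forward_ad=True)

As the checks in _process_forward_mode_AD enforce, two cases need extra care: if forward marks an input dirty with ctx.mark_dirty() and modifies it in place, jvp must modify the corresponding tangent in place and return it as-is; and if forward returns a view of an input, jvp must return a view of that input's tangent rather than a freshly allocated tensor.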