From: soulitzer
Date: Fri, 27 Aug 2021 21:59:08 +0000 (-0700)
Subject: Add autograd not implemented boxed fallback (#63458)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~632
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=90a6498a1288a4248b4cfe603949fd5b2e60dc0f;p=platform%2Fupstream%2Fpytorch.git

Add autograd not implemented boxed fallback (#63458)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63458

See description and discussion from https://github.com/pytorch/pytorch/pull/62450

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D30518572

Pulled By: soulitzer

fbshipit-source-id: 3b1504d49abb84560ae17077f0dec335749c9882
---

diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp
index 80d892d..edb73f9 100644
--- a/test/cpp/api/autograd.cpp
+++ b/test/cpp/api/autograd.cpp
@@ -1,6 +1,8 @@
 #include
+#include
 #include
+#include
 #include
@@ -869,6 +871,261 @@ TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
   }
 }
+/**
+ * Tests for AutogradNotImplementedFallback
+ * - Check that the NotImplemented node is created when inputs require grad,
+ *   and that it is not created when no inputs require grad
+ * - check_inplace logic
+ * - view ops (TODO: not an official view yet, update this once InplaceOrView kernel is landed)
+ * - TODO: Tests for NDEBUG checks?
+ * - tensorlist input and output
+ * - multiple outputs / non-tensor output
+ * - rebase_history vs set_history
+ */
+namespace {
+
+torch::Tensor inplace_op(const torch::Tensor& self, const torch::Tensor& other) {
+  return self.add_(other);
+}
+
+std::tuple two_arg_inplace_op(const torch::Tensor& self, const torch::Tensor& other) {
+  other.add_(self);
+  self.add_(other);
+  return std::tuple(self, other);
+}
+
+std::tuple two_pairs_of_view_op(const torch::Tensor& self, const torch::Tensor& other) {
+  // This is not allowed. We test below that calling this through the boxed kernel raises an error.
+  auto self_view = self.view(-1);
+  auto other_view = other.view(-1);
+  return std::tuple(self_view, other_view);
+}
+
+int64_t ret_single_non_tensor(const torch::Tensor& self, const torch::Tensor& other) {
+  return 12;
+}
+
+torch::Tensor opt_op(const torch::Tensor& self, const c10::optional& other) {
+  if (other.has_value()) {
+    return self + other.value();
+  } else {
+    return self.clone();
+  }
+}
+
+torch::Tensor my_custom_op(const torch::Tensor& self, const torch::Tensor& other) {
+  return self + other;
+}
+
+std::tuple ret_tuple_non_tensor(const torch::Tensor& self, const torch::Tensor& other) {
+  auto a = self - other;
+  auto b = self + other;
+  return std::tuple(a, b, 12);
+}
+
+torch::Tensor view_op(const torch::Tensor& self, const torch::Tensor& other) {
+  return self.view(-1);
+}
+
+std::vector ret_tensor_vector(const torch::Tensor& self, const torch::Tensor& other) {
+  std::vector out;
+  out.push_back(self + other);
+  out.push_back(self - other);
+  return out;
+}
+
+torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) {
+  const auto& res = self.clone();
+  for (const auto& t : other) {
+    res.add_(t);
+  }
+  return res;
+}
+
+#define REGISTER_TEST_OP(name, schema, fn) \
+  auto m = MAKE_TORCH_LIBRARY(_test); \
+  m.def(schema); \
+  auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd); \
+  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); \
+  m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn)); \
+  m_autograd.impl(name, c10::DispatchKey::Autograd, autogradNotImplementedFallback());
+
+template
+void assertBasicChecks(F op) {
+  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  auto b = torch::tensor({1.}, {torch::kFloat32});
+  auto c = torch::tensor({1.}, {torch::kFloat32});
+
+  // If any inputs require grad, the grad_fn is NotImplemented, so backward raises
+  auto out1 = op(a, b);
+  ASSERT_THROWS_WITH(out1.backward(), "is not implemented");
+
+  // Should not have grad_fn if no inputs require grad
+  auto out2 = op(b, c);
+  ASSERT_THROWS_WITH(out2.backward(), "element 0 of tensors does not require grad and does not have a grad_fn");
+
+  // TODO: Forward AD Tests?
+}
+
+} // namespace
+
+TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) {
+  REGISTER_TEST_OP("ret_single_non_tensor", "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int", ret_single_non_tensor);
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_single_non_tensor", "");
+  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
+    return callOpUnboxed(opHandle, _1, _2);
+  };
+
+  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  auto b = torch::tensor({1.}, {torch::kFloat32});
+
+  ASSERT_EQ(op(a, b), ret_single_non_tensor(a, b));
+}
+
+TEST(TestAutogradNotImplementedFallback, DoubleViewOP) {
+  REGISTER_TEST_OP("two_pairs_of_view_op", "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))", two_pairs_of_view_op);
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::two_pairs_of_view_op", "");
+  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
+    return callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2);
+  };
+  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  auto b = torch::tensor({1.}, {torch::kFloat32});
+  ASSERT_THROWS_WITH(op(a, b),
+    "Expected only a single output in the operator schema to have a non-write alias annotation");
+}
+
+TEST(TestAutogradNotImplementedFallback, InplaceOp) {
+  REGISTER_TEST_OP("inplace_op", "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)", inplace_op);
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
+  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
+    return callOpUnboxed(opHandle, _1, _2);
+  };
+
+  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  auto b = torch::tensor({1.}, {torch::kFloat32});
+
+  // Check in-place
+  ASSERT_THROWS_WITH(op(a, b),
+    "a leaf Variable that requires grad is being used in an in-place operation");
+  op(b, a);
+  a = a.clone();
+  b = b.clone();
+  auto c = op(a, b);
+  ASSERT_TRUE(torch::allclose(c, inplace_op(a, b)));
+
+  // Test in-place on view
+  auto base = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
+  auto view = base.view(-1);
+  auto t = torch::tensor({1.}, {torch::kFloat32});
+
+  torch::Tensor v_nograd;
+  {
+    c10::NoGradGuard guard;
+    v_nograd = base.view(-1);
+    op(v_nograd, t);
+  }
+
+  ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode");
+  ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());
+
+  // TODO: once we have the InplaceOrView kernel, re-enable this since the version counter
+  // would actually be incremented
+  // ASSERT_THAT(op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward"));
+}
+
+TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
+  REGISTER_TEST_OP("two_arg_inplace_op", "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))", two_arg_inplace_op);
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::two_arg_inplace_op", "");
+  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
+    return callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2);
+  };
+  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  auto b = torch::tensor({1.}, {torch::kFloat32});
+
+  // Both are modified in-place!
+  ASSERT_THROWS_WITH(op(a, b),
+    "a leaf Variable that requires grad is being used in an in-place operation");
+  ASSERT_THROWS_WITH(op(b, a),
+    "a leaf Variable that requires grad is being used in an in-place operation");
+}
+
+TEST(TestAutogradNotImplementedFallback, OptOp) {
+  REGISTER_TEST_OP("opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op);
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", "");
+  auto op = [&](const torch::Tensor& _1, const c10::optional& _2) {
+    return callOpUnboxed&>(opHandle, _1, _2);
+  };
+
+  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  auto b = torch::tensor({1.}, {torch::kFloat32});
+
+  ASSERT_TRUE(torch::allclose(op(a, b), opt_op(a, b)));
+  ASSERT_TRUE(torch::allclose(op(a, {}), opt_op(a, {})));
+}
+
+TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) {
+  REGISTER_TEST_OP("my_custom_op", "_test::my_custom_op(Tensor self, Tensor other) -> Tensor", my_custom_op);
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", "");
+  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
+    return callOpUnboxed(opHandle, _1, _2);
+  };
+
+  assertBasicChecks(op);
+}
+
+TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) {
+  REGISTER_TEST_OP("ret_tuple_non_tensor", "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)", ret_tuple_non_tensor);
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_tuple_non_tensor", "");
+  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
+    torch::Tensor out0;
+    torch::Tensor out1;
+    int64_t out2;
+    auto out = callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2);
+    std::tie(out0, out1, out2) = std::move(out);
+    return out0;
+  };
+
+  assertBasicChecks(op);
+}
+
+TEST(TestAutogradNotImplementedFallback, ViewOp) {
+  REGISTER_TEST_OP("view_op", "_test::view_op(Tensor(a) self, Tensor other) -> Tensor(a)", view_op);
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
+  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
+    return callOpUnboxed(opHandle, _1, _2);
+  };
+  assertBasicChecks(op);
+}
+
+TEST(TestAutogradNotImplementedFallback, RetTensorVector) {
+  REGISTER_TEST_OP("ret_tensor_vector", "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]", ret_tensor_vector);
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_tensor_vector", "");
+  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
+    return callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2)[0];
+  };
+  assertBasicChecks(op);
+}
+
+TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
+  REGISTER_TEST_OP("tensorlist_op", "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor", tensorlist_op);
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::tensorlist_op", "");
+  auto op = [&](torch::Tensor _1, at::TensorList _2) {
+    return callOpUnboxed(opHandle, _1, _2);
+  };
+
+  auto a = torch::tensor({1.}, {torch::kFloat32});
+  auto b = torch::tensor({1.}, {torch::kFloat32});
+  auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  std::vector vec = {b, c};
+  auto out = op(a, vec);
+
+  ASSERT_THROWS_WITH(torch::autograd::grad({out}, {vec[0]}), "One of the differentiated Tensors does not require grad");
+  ASSERT_THROWS_WITH(torch::autograd::grad({out}, {vec[1]}), "is not implemented");
+
+  ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec)));
+}
+
+
 // TODO add these tests if needed
 // test_once_differentiable
 // test_sparse_backward
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 650830b..b2a1016 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -128,6 +128,7 @@ libtorch_edge_profiler_sources = libtorch_profiler_sources + [
 core_trainer_sources = [
     "torch/csrc/autograd/anomaly_mode.cpp",
    "torch/csrc/autograd/autograd.cpp",
+    "torch/csrc/autograd/autograd_not_implemented_fallback.cpp",
     "torch/csrc/autograd/cpp_hook.cpp",
     "torch/csrc/autograd/custom_function.cpp",
     "torch/csrc/autograd/engine.cpp",
diff --git a/torch/csrc/api/include/torch/autograd.h b/torch/csrc/api/include/torch/autograd.h
index 83aa102..809fbe8 100644
--- a/torch/csrc/api/include/torch/autograd.h
+++ b/torch/csrc/api/include/torch/autograd.h
@@ -2,3 +2,4 @@
 #include
 #include
+#include
diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp
new file mode 100644
index 0000000..ab9cb49
--- /dev/null
+++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp
@@ -0,0 +1,189 @@
+#include
+
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+namespace torch { namespace autograd {
+
+namespace {
+
+template
+void _foreach_tensor(
+    F fn,
+    torch::jit::Stack* stack,
+    size_t stack_start,
+    size_t size) {
+  // Enumerate over tensors in a stack, including ones in TensorLists
+  int idx_tensor = 0;
+  for (const auto idx_arg : c10::irange(size)) {
+    auto& ivalue = (*stack)[stack_start + idx_arg];
+    if (ivalue.isTensor()) { // true for optional tensor that has value
+      const auto& tensor = ivalue.toTensor();
+      fn(idx_tensor, idx_arg, tensor);
+      idx_tensor++;
+    } else if (ivalue.isTensorList()) {
+      for (const auto& iv : ivalue.toListRef()) {
+        const auto& tensor = iv.toTensor();
+        fn(idx_tensor, idx_arg, tensor);
+        idx_tensor++;
+      }
+    }
+  }
+}
+
+}
+
+void autogradNotImplementedFallbackImpl(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
+  // Mimics the logic of a VariableType NotImplemented kernel
+  const auto& schema = op.schema();
+  const auto& op_name = schema.operator_name().name;
+  const auto& arguments = schema.arguments();
+  const auto& returns = schema.returns();
+  const auto num_arguments = arguments.size();
+  const auto num_returns = returns.size();
+  const auto stack_start = stack->size() - num_arguments;
+  const bool grad_mode = GradMode::is_enabled();
+  std::vector tensors_requiring_grad_on_stack;
+
+  // Keep track of which outputs are the result of in-place modification
+  // so we can rebase_history if necessary
+  std::vector is_inplace_output;
+  bool any_is_inplace_output = false;
+  std::vector is_aliased_output;
+  is_inplace_output.reserve(num_returns);
+  is_aliased_output.reserve(num_returns);
+
+  for (const auto i : c10::irange(num_returns)) {
+    const auto& alias_info = returns[i].alias_info();
+    is_inplace_output.push_back(alias_info.has_value() && alias_info->isWrite());
+    any_is_inplace_output |= alias_info.has_value() && alias_info->isWrite();
+    is_aliased_output.push_back(alias_info.has_value());
+  }
+  int aliased_input_idx = -1;
+  int aliased_output_idx = -1;
+  for (const auto i : c10::irange(num_returns)) {
+    const auto& alias_info = returns[i].alias_info();
+    if (alias_info.has_value() && !alias_info->isWrite()) {
+      AT_ASSERT(
+        aliased_output_idx == -1,
+        "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
+        "Non-composite functions where multiple outputs are aliased with inputs aren't supported. "
+        "Please rewrite your function as a composite function.");
+      aliased_output_idx = i;
+    }
+  }
+  for (const auto i : c10::irange(num_arguments)) {
+    const auto& alias_info = arguments[i].alias_info();
+    if (alias_info.has_value() && !alias_info->isWrite()) {
+      AT_ASSERT(
+        aliased_input_idx == -1,
+        "Expected only a single input in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
+        "Non-composite functions where multiple inputs are aliased with outputs aren't supported. "
+        "Please rewrite your function as a composite function.");
+      aliased_input_idx = i;
+    }
+  }
+
+  size_t num_tensor_inputs = 0;  // Only used for DEBUG-only checks
+
+  _foreach_tensor([&](size_t _, size_t idx_arg, const at::Tensor& t) {
+    if (grad_mode && t.requires_grad()) {
+      tensors_requiring_grad_on_stack.push_back(&t);
+    }
+    num_tensor_inputs++;
+    TORCH_CHECK_NOT_IMPLEMENTED(!isFwGradDefined(t), "Trying to use forward AD with ", op_name, " that does not support it.");
+  }, stack, stack_start, num_arguments);
+
+  const bool any_requires_grad = tensors_requiring_grad_on_stack.size() > 0;
+
+  _foreach_tensor([&](size_t _, size_t i, const at::Tensor& t) {
+    const auto& alias_info = arguments[i].alias_info();
+    if (alias_info.has_value() && alias_info->isWrite()) {
+      check_inplace(t, any_requires_grad);
+    }
+  }, stack, stack_start, num_arguments);
+
+  std::shared_ptr grad_fn;
+  if (any_requires_grad) {
+    grad_fn = std::shared_ptr(new NotImplemented(op_name), deleteNode);
+    grad_fn->set_next_edges(collect_next_edges(tensors_requiring_grad_on_stack));
+  }
+
+  #ifndef NDEBUG
+  // See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
+  auto stack_args_copy = std::vector(stack->begin() + stack_start, stack->end());
+  std::vector> impl_saved;
+  impl_saved.reserve(num_tensor_inputs);
+  std::vector> storage_saved;
+  storage_saved.reserve(num_tensor_inputs);
+  _foreach_tensor([&](size_t idx, size_t _, const at::Tensor& t) {
+    storage_saved.push_back(t.has_storage() ? c10::optional(t.storage()) : c10::nullopt);
+    impl_saved.push_back(t.getIntrusivePtr());
+  }, &stack_args_copy, 0, num_arguments);
+  #endif
+  if (aliased_input_idx != -1 || any_is_inplace_output) {
+    at::AutoDispatchBelowAutograd guard;
+    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
+  } else {
+    // If neither in-place nor view
+    at::AutoDispatchBelowADInplaceOrView guard;
+    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
+  }
+  #ifndef NDEBUG
+  _foreach_tensor([&](size_t idx_tensor, size_t _, const at::Tensor& t) {
+    if (storage_saved.at(idx_tensor).has_value())
+      TORCH_INTERNAL_ASSERT(storage_saved.at(idx_tensor).value().is_alias_of(t.storage()), op_name);
+    if (impl_saved.at(idx_tensor))
+      TORCH_INTERNAL_ASSERT(impl_saved.at(idx_tensor) == t.getIntrusivePtr(), op_name);
+  }, &stack_args_copy, 0, num_arguments);
+  _foreach_tensor([&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
+    if (!is_inplace_output[idx_ret])
+      TORCH_INTERNAL_ASSERT(t.use_count() <= 1, op_name);  // Okay to return undefined tensor
+    if (!is_aliased_output[idx_ret] && t.has_storage())
+      TORCH_INTERNAL_ASSERT(t.storage().use_count() == 1);
+  }, stack, stack->size() - num_returns, num_returns);
+  // There should be only a single base-view pair; make sure their storage is aliased
+  if (aliased_input_idx != -1 && aliased_output_idx != -1) {
+    const c10::IValue& aliased_input_iv = stack_args_copy[aliased_input_idx];
+    const c10::IValue& aliased_output_iv = (*stack)[stack->size() - num_returns + aliased_output_idx];
+    // We do not support views embedded inside tensorlist
+    TORCH_INTERNAL_ASSERT(aliased_input_iv.isTensor(), op_name);
+    TORCH_INTERNAL_ASSERT(aliased_output_iv.isTensor(), op_name);
+    const at::Tensor& aliased_input = aliased_input_iv.toTensor();
+    const at::Tensor& aliased_output = aliased_output_iv.toTensor();
+    if (is_aliased_output[aliased_input_idx] && aliased_input.has_storage())
+      TORCH_INTERNAL_ASSERT(aliased_input.storage().is_alias_of(aliased_output.storage()), op_name);
+  }
+  #endif
+
+  if (any_requires_grad) {
+    _foreach_tensor([&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
+      if (isDifferentiableType(t.scalar_type())) {
+        if (is_inplace_output[idx_ret]) {
+          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+          rebase_history(const_cast(t), grad_fn);
+        } else {
+          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+          set_history(const_cast(t), grad_fn);
+        }
+      }
+    }, stack, stack->size() - num_returns, num_returns);
+  }
+}
+
+torch::CppFunction autogradNotImplementedFallback() {
+  return torch::CppFunction::makeFromBoxedFunction<&autogradNotImplementedFallbackImpl>();
+}
+
+}} // namespace torch::autograd
diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.h b/torch/csrc/autograd/autograd_not_implemented_fallback.h
new file mode 100644
index 0000000..4b2cbd1
--- /dev/null
+++ b/torch/csrc/autograd/autograd_not_implemented_fallback.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include
+#include
+
+namespace torch {
+namespace autograd {
+
+TORCH_API torch::CppFunction autogradNotImplementedFallback();
+
+}} // namespace torch::autograd
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 25336df..2a1de8e 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -563,6 +563,14 @@ struct MakeNextFunctionList : IterArgs {
       next_edges.emplace_back();
     }
   }
+  void operator()(const Variable* variable) {
+    // NOLINTNEXTLINE(bugprone-branch-clone)
+    if (variable->defined()) {
+      next_edges.push_back(impl::gradient_edge(*variable));
+    } else {
+      next_edges.emplace_back();
+    }
+  }
   void operator()(const c10::optional& variable) {
     // NOLINTNEXTLINE(bugprone-branch-clone)
     if (variable.has_value() && variable->defined()) {
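
A minimal usage sketch, not part of the patch itself: the test-only REGISTER_TEST_OP macro above shows the intended registration pattern, and a downstream extension could wire its own operator through the new fallback the same way. The library name my_ext, the operator my_op, and the CPU kernel my_op_cpu below are hypothetical placeholders; TORCH_LIBRARY, TORCH_LIBRARY_IMPL, and TORCH_FN are the existing registration APIs, and autogradNotImplementedFallback() is the function this patch adds (its declaration is reachable via torch/torch.h once torch/autograd.h pulls in the new header).

#include <torch/torch.h>
#include <torch/library.h>

// Hypothetical forward-only CPU kernel for the example op.
torch::Tensor my_op_cpu(const torch::Tensor& self, const torch::Tensor& other) {
  return self + other;
}

TORCH_LIBRARY(my_ext, m) {
  m.def("my_op(Tensor self, Tensor other) -> Tensor");
}

TORCH_LIBRARY_IMPL(my_ext, CPU, m) {
  m.impl("my_op", TORCH_FN(my_op_cpu));
}

TORCH_LIBRARY_IMPL(my_ext, Autograd, m) {
  // Same pattern as REGISTER_TEST_OP above: route the Autograd dispatch key
  // for this op through the new boxed fallback.
  m.impl("my_op", torch::autograd::autogradNotImplementedFallback());
}

With this registration, calling the op on an input with requires_grad=true produces an output whose grad_fn is the NotImplemented node, so out.backward() throws an "is not implemented" error instead of the result silently not being tracked by autograd, which is exactly what assertBasicChecks in the test file verifies.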