#include <gtest/gtest.h>
+#include <ATen/core/boxing/impl/test_helpers.h>
#include <torch/torch.h>
+#include <ATen/core/op_registration/op_registration.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
}
}
+/**
+ * Tests for AutogradNotImplementedFallback
+ * - Check that a NotImplemented node is created when any input requires grad,
+ *   and that no such node is created when no input requires grad
+ * - check_inplace logic
+ * - view ops (TODO: not an official view yet, update this once the InplaceOrView kernel has landed)
+ * - TODO: Tests for NDEBUG checks?
+ * - tensorlist input and output
+ * - multiple outputs / non-tensor output
+ * - rebase_history vs set_history
+ */
+namespace {
+
+// ---------------------------------------------------------------------------
+// CPU kernels registered by the tests below. Each one is deliberately tiny:
+// the tests only care about how the autograd fallback wraps them.
+// ---------------------------------------------------------------------------
+
+// Adds `other` into `self` in place and returns `self`.
+torch::Tensor inplace_op(const torch::Tensor& self, const torch::Tensor& other) {
+  self.add_(other);
+  return self;
+}
+
+// Mutates both arguments in place (`other` first, then `self`) and returns both.
+std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op(const torch::Tensor& self, const torch::Tensor& other) {
+  other.add_(self);
+  self.add_(other);
+  return std::make_tuple(self, other);
+}
+
+// Returns a flattened view of each argument.
+// This is not allowed: the test below checks that calling into the boxed
+// kernel with this schema raises an error.
+std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op(const torch::Tensor& self, const torch::Tensor& other) {
+  auto flat_self = self.view(-1);
+  auto flat_other = other.view(-1);
+  return std::make_tuple(flat_self, flat_other);
+}
+
+// Ignores its inputs and returns a fixed integer.
+int64_t ret_single_non_tensor(const torch::Tensor& self, const torch::Tensor& other) {
+  constexpr int64_t kResult = 12;
+  return kResult;
+}
+
+// Returns `self + other` when `other` holds a value, otherwise a copy of `self`.
+torch::Tensor opt_op(const torch::Tensor& self, const c10::optional<at::Tensor>& other) {
+  if (!other.has_value()) {
+    return self.clone();
+  }
+  return self + other.value();
+}
+
+// Plain out-of-place addition.
+torch::Tensor my_custom_op(const torch::Tensor& self, const torch::Tensor& other) {
+  auto sum = self + other;
+  return sum;
+}
+
+// Returns (self - other, self + other, 12): two tensors followed by a non-tensor.
+std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(const torch::Tensor& self, const torch::Tensor& other) {
+  auto difference = self - other;
+  auto sum = self + other;
+  return std::make_tuple(difference, sum, static_cast<int64_t>(12));
+}
+
+// Returns a flattened view of `self`; `other` is unused.
+torch::Tensor view_op(const torch::Tensor& self, const torch::Tensor& other) {
+  auto flat = self.view(-1);
+  return flat;
+}
+
+// Returns {self + other, self - other}.
+std::vector<at::Tensor> ret_tensor_vector(const torch::Tensor& self, const torch::Tensor& other) {
+  return {self + other, self - other};
+}
+
+// Returns a copy of `self` with every tensor of `other` added to it.
+torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) {
+  auto accumulated = self.clone();
+  for (const auto& addend : other) {
+    accumulated.add_(addend);
+  }
+  return accumulated;
+}
+
+// Defines `schema` in the `_test` library, registers `fn` as the CPU kernel for
+// `name`, and registers autogradNotImplementedFallback() as its Autograd kernel.
+// The macro declares the locals `m`, `m_autograd` and `m_cpu` in the enclosing
+// scope, so it can be expanded at most once per scope.
+#define REGISTER_TEST_OP(name, schema, fn) \
+  auto m = MAKE_TORCH_LIBRARY(_test); \
+  m.def(schema); \
+  auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd); \
+  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); \
+  m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn)); \
+  m_autograd.impl(name, c10::DispatchKey::Autograd, autogradNotImplementedFallback());
+
+// Shared checks for a binary op `op(Tensor, Tensor) -> Tensor`:
+//  * if an input requires grad, backward through the result must fail with
+//    the "is not implemented" error;
+//  * if no input requires grad, the result must not have a grad_fn.
+template <typename F>
+void assertBasicChecks(F op) {
+  auto with_grad = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  auto plain_lhs = torch::tensor({1.}, {torch::kFloat32});
+  auto plain_rhs = torch::tensor({1.}, {torch::kFloat32});
+
+  // One input requires grad: backward must hit the NotImplemented node
+  auto tracked_out = op(with_grad, plain_lhs);
+  ASSERT_THROWS_WITH(tracked_out.backward(), "is not implemented");
+
+  // No input requires grad: the output must not have a grad_fn
+  auto untracked_out = op(plain_lhs, plain_rhs);
+  ASSERT_THROWS_WITH(untracked_out.backward(), "element 0 of tensors does not require grad and does not have a grad_fn");
+
+  // TODO: Forward AD Tests?
+}
+
+} // namespace
+
+// An op whose only return is a non-tensor must produce the same value through
+// the dispatcher (and therefore through the fallback) as when called directly.
+TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) {
+  REGISTER_TEST_OP("ret_single_non_tensor", "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int", ret_single_non_tensor);
+  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_single_non_tensor", "");
+  auto dispatched = [&](const torch::Tensor& lhs, const torch::Tensor& rhs) {
+    return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>(handle, lhs, rhs);
+  };
+
+  auto with_grad = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  auto plain = torch::tensor({1.}, {torch::kFloat32});
+
+  const int64_t via_dispatcher = dispatched(with_grad, plain);
+  const int64_t direct = ret_single_non_tensor(with_grad, plain);
+  ASSERT_EQ(via_dispatcher, direct);
+}
+
+// A schema with two non-write aliased outputs must be rejected by the fallback.
+TEST(TestAutogradNotImplementedFallback, DoubleViewOP) {
+  REGISTER_TEST_OP("two_pairs_of_view_op", "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))", two_pairs_of_view_op);
+  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::two_pairs_of_view_op", "");
+  using ViewPair = std::tuple<torch::Tensor, torch::Tensor>;
+  auto dispatched = [&](const torch::Tensor& lhs, const torch::Tensor& rhs) {
+    return callOpUnboxed<ViewPair, const torch::Tensor&, const torch::Tensor&>(handle, lhs, rhs);
+  };
+
+  auto with_grad = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  auto plain = torch::tensor({1.}, {torch::kFloat32});
+
+  ASSERT_THROWS_WITH(
+      dispatched(with_grad, plain),
+      "Expected only a single output in the operator schema to have a non-write alias annotation");
+}
+
+// Exercises the check_inplace logic of the fallback for a single in-place
+// argument: leaf tensors that require grad are rejected, non-leaf tensors are
+// accepted, and views created in no_grad mode are rejected.
+TEST(TestAutogradNotImplementedFallback, InplaceOp) {
+  REGISTER_TEST_OP("inplace_op", "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)", inplace_op);
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
+  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
+    return callOpUnboxed<torch::Tensor, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2);
+  };
+
+  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  auto b = torch::tensor({1.}, {torch::kFloat32});
+
+  // Check in-place
+  ASSERT_THROWS_WITH(op(a, b),
+    "a leaf Variable that requires grad is being used in an in-place operation");
+  op(b, a);
+  a = a.clone();
+  b = b.clone();
+  // Compute the reference on a copy *before* dispatching. `op(a, b)` returns a
+  // tensor aliasing `a`, so calling inplace_op(a, b) afterwards would mutate
+  // `a` again and compare the tensor with itself, which can never fail.
+  auto expected = inplace_op(a.clone(), b);
+  auto c = op(a, b);
+  ASSERT_TRUE(torch::allclose(c, expected));
+
+  // Test in-place on view
+  auto base = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
+  auto view = base.view(-1);
+  auto t = torch::tensor({1.}, {torch::kFloat32});
+
+  torch::Tensor v_nograd;
+  {
+    c10::NoGradGuard guard;
+    v_nograd = base.view(-1);
+    op(v_nograd, t);
+  }
+
+  ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode");
+  // The in-place op must hand back the very tensor it was given
+  ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());
+
+  // TODO: once we have InplaceOrView kernel, renable this since version counter would actually
+  // be incremented
+  // ASSERT_THAT(op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward"));
+}
+
+// When both arguments are written in place, a leaf that requires grad must be
+// rejected regardless of which argument position it occupies.
+TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
+  REGISTER_TEST_OP("two_arg_inplace_op", "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))", two_arg_inplace_op);
+  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::two_arg_inplace_op", "");
+  using TensorPair = std::tuple<torch::Tensor, torch::Tensor>;
+  auto dispatched = [&](const torch::Tensor& lhs, const torch::Tensor& rhs) {
+    return callOpUnboxed<TensorPair, const torch::Tensor&, const torch::Tensor&>(handle, lhs, rhs);
+  };
+
+  auto leaf_with_grad = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  auto plain = torch::tensor({1.}, {torch::kFloat32});
+
+  // Both are modified in-place!
+  ASSERT_THROWS_WITH(
+      dispatched(leaf_with_grad, plain),
+      "a leaf Variable that requires grad is being used in an in-place operation");
+  ASSERT_THROWS_WITH(
+      dispatched(plain, leaf_with_grad),
+      "a leaf Variable that requires grad is being used in an in-place operation");
+}
+
+// Optional tensor arguments, both with and without a value, must pass through
+// the fallback and yield the same result as a direct call.
+TEST(TestAutogradNotImplementedFallback, OptOp) {
+  REGISTER_TEST_OP("opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op);
+  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", "");
+  auto dispatched = [&](const torch::Tensor& self, const c10::optional<torch::Tensor>& maybe_other) {
+    return callOpUnboxed<torch::Tensor, const torch::Tensor&, const c10::optional<torch::Tensor>&>(handle, self, maybe_other);
+  };
+
+  auto with_grad = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  auto plain = torch::tensor({1.}, {torch::kFloat32});
+
+  // Optional holding a value
+  ASSERT_TRUE(torch::allclose(dispatched(with_grad, plain), opt_op(with_grad, plain)));
+  // Empty optional
+  ASSERT_TRUE(torch::allclose(dispatched(with_grad, c10::nullopt), opt_op(with_grad, c10::nullopt)));
+}
+
+// Basic grad_fn checks for a plain out-of-place op with a single tensor output.
+TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) {
+  REGISTER_TEST_OP("my_custom_op", "_test::my_custom_op(Tensor self, Tensor other) -> Tensor", my_custom_op);
+  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", "");
+  auto dispatched = [&](const torch::Tensor& lhs, const torch::Tensor& rhs) {
+    return callOpUnboxed<torch::Tensor, const torch::Tensor&, const torch::Tensor&>(handle, lhs, rhs);
+  };
+
+  assertBasicChecks(dispatched);
+}
+
+// Basic grad_fn checks for an op returning (Tensor, Tensor, int); only the
+// first tensor of the tuple is inspected.
+TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) {
+  REGISTER_TEST_OP("ret_tuple_non_tensor", "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)", ret_tuple_non_tensor);
+  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_tuple_non_tensor", "");
+  using Result = std::tuple<torch::Tensor, torch::Tensor, int64_t>;
+  auto first_output = [&](const torch::Tensor& lhs, const torch::Tensor& rhs) {
+    Result result = callOpUnboxed<Result, const torch::Tensor&, const torch::Tensor&>(handle, lhs, rhs);
+    torch::Tensor first = std::get<0>(std::move(result));
+    return first;
+  };
+
+  assertBasicChecks(first_output);
+}
+
+// Basic grad_fn checks for an op whose schema marks the output as a view of `self`.
+TEST(TestAutogradNotImplementedFallback, ViewOp) {
+  REGISTER_TEST_OP("view_op", "_test::view_op(Tensor(a) self, Tensor other) -> Tensor(a)", view_op);
+  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
+  auto dispatched = [&](const torch::Tensor& base, const torch::Tensor& unused) {
+    return callOpUnboxed<torch::Tensor, const torch::Tensor&, const torch::Tensor&>(handle, base, unused);
+  };
+
+  assertBasicChecks(dispatched);
+}
+
+// Basic grad_fn checks for an op returning Tensor[]; only element 0 is inspected.
+TEST(TestAutogradNotImplementedFallback, RetTensorVector) {
+  REGISTER_TEST_OP("ret_tensor_vector", "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]", ret_tensor_vector);
+  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_tensor_vector", "");
+  auto first_output = [&](const torch::Tensor& lhs, const torch::Tensor& rhs) {
+    std::vector<at::Tensor> outputs =
+        callOpUnboxed<std::vector<at::Tensor>, const torch::Tensor&, const torch::Tensor&>(handle, lhs, rhs);
+    at::Tensor first = outputs[0];
+    return first;
+  };
+
+  assertBasicChecks(first_output);
+}
+
+// Tensors nested inside a Tensor[] argument must be seen by the fallback: a
+// list element that requires grad makes the output depend on a NotImplemented
+// node, while an element that does not require grad stays non-differentiable.
+TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
+  REGISTER_TEST_OP("tensorlist_op", "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor", tensorlist_op);
+  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::tensorlist_op", "");
+  auto dispatched = [&](torch::Tensor self, at::TensorList others) {
+    return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>(handle, self, others);
+  };
+
+  auto self_plain = torch::tensor({1.}, {torch::kFloat32});
+  auto elem_plain = torch::tensor({1.}, {torch::kFloat32});
+  auto elem_with_grad = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
+  std::vector<torch::Tensor> addends = {elem_plain, elem_with_grad};
+  auto result = dispatched(self_plain, addends);
+
+  ASSERT_THROWS_WITH(
+      torch::autograd::grad({result}, {addends[0]}),
+      "One of the differentiated Tensors does not require grad");
+  ASSERT_THROWS_WITH(
+      torch::autograd::grad({result}, {addends[1]}),
+      "is not implemented");
+
+  // Forward values must match a direct call of the kernel
+  ASSERT_TRUE(at::allclose(dispatched(self_plain, addends), tensorlist_op(self_plain, addends)));
+}
+
+
// TODO add these tests if needed
// test_once_differentiable
// test_sparse_backward
--- /dev/null
+#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
+
+#include <c10/util/irange.h>
+
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/core/ivalue.h>
+
+#include <torch/csrc/autograd/autograd.h>
+#include <torch/csrc/autograd/function.h>
+#include <torch/csrc/autograd/functions/basic_ops.h>
+#include <torch/csrc/autograd/functions/utils.h>
+#include <torch/csrc/autograd/VariableTypeUtils.h>
+
+#include <vector>
+
+namespace torch { namespace autograd {
+
+namespace {
+
+// Visits every tensor stored in the `size` stack slots that begin at
+// `stack_start`, calling fn(running_tensor_index, slot_offset, tensor).
+//  - A slot holding a tensor (including an optional tensor that has a value)
+//    is visited once.
+//  - A slot holding a TensorList is visited once per element; all elements
+//    share the slot offset of their list.
+//  - Slots holding anything else are skipped.
+// `running_tensor_index` counts visited tensors from zero; `slot_offset` is
+// relative to `stack_start`.
+template <typename F>
+void _foreach_tensor(
+    F fn,
+    torch::jit::Stack* stack,
+    size_t stack_start,
+    size_t size) {
+  int tensors_seen = 0;
+  const auto visit = [&](size_t slot, const at::Tensor& tensor) {
+    fn(tensors_seen, slot, tensor);
+    ++tensors_seen;
+  };
+
+  for (const auto slot : c10::irange(size)) {
+    auto& entry = (*stack)[stack_start + slot];
+    if (entry.isTensorList()) {
+      for (const auto& element : entry.toListRef()) {
+        visit(slot, element.toTensor());
+      }
+    } else if (entry.isTensor()) { // true for optional tensor that has value
+      visit(slot, entry.toTensor());
+    }
+  }
+}
+
+}
+
+/**
+ * Boxed Autograd kernel for operators that have no derivative formula.
+ *
+ * Steps, in order:
+ *  1. Read the schema's alias annotations to find in-place outputs
+ *     ('Tensor(a!)') and the single allowed input/output view pair
+ *     ('Tensor(a)'). More than one non-write alias on either side is an error.
+ *  2. Collect the tensor inputs that require grad (only while grad mode is
+ *     enabled) and reject inputs that carry a forward-AD gradient.
+ *  3. Run check_inplace on every tensor passed to a write-aliased argument.
+ *  4. If any input requires grad, create a NotImplemented node whose next
+ *     edges are those inputs.
+ *  5. Redispatch to the kernels after autograd.
+ *  6. If any input requires grad, attach the node to every output of
+ *     differentiable type: rebase_history for in-place outputs, set_history
+ *     otherwise.
+ *
+ * When NDEBUG is not defined, TensorImpl/Storage sanity checks run around the
+ * redispatch.
+ *
+ * @param op            handle of the operator being called
+ * @param dispatch_keys dispatch key set of the current call
+ * @param stack         holds the arguments on entry and the returns on exit
+ */
+void autogradNotImplementedFallbackImpl(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
+  // Mimics the logic of a VariableType NotImplemented kernel
+  const auto& schema = op.schema();
+  const auto& op_name = schema.operator_name().name;
+  const auto& arguments = schema.arguments();
+  const auto& returns = schema.returns();
+  const auto num_arguments = arguments.size();
+  const auto num_returns = returns.size();
+  const auto stack_start = stack->size() - num_arguments;
+  const bool grad_mode = GradMode::is_enabled();
+  std::vector<const at::Tensor*> tensors_requiring_grad_on_stack;
+
+  // Keep track of which outputs are output of in-place modification
+  // so we can rebase_history if necessary
+  std::vector<bool> is_inplace_output;
+  bool any_is_inplace_output = false;
+  std::vector<bool> is_aliased_output;
+  is_inplace_output.reserve(num_returns);
+  is_aliased_output.reserve(num_returns);
+
+  for (const auto i : c10::irange(num_returns)) {
+    const auto& alias_info = returns[i].alias_info();
+    const bool is_write = alias_info.has_value() && alias_info->isWrite();
+    is_inplace_output.push_back(is_write);
+    any_is_inplace_output |= is_write;
+    is_aliased_output.push_back(alias_info.has_value());
+  }
+
+  // Locate the (at most one) non-write aliased output and input, i.e. the
+  // base/view pair. -1 means "none".
+  int aliased_input_idx = -1;
+  int aliased_output_idx = -1;
+  for (const auto i : c10::irange(num_returns)) {
+    const auto& alias_info = returns[i].alias_info();
+    if (alias_info.has_value() && !alias_info->isWrite()) {
+      AT_ASSERT(
+        aliased_output_idx == -1,
+        "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
+        "Non-composite functions where multiple outputs are aliased with inputs aren't supported. "
+        "Please rewrite your function as a composite function.");
+      aliased_output_idx = static_cast<int>(i);
+    }
+  }
+  for (const auto i : c10::irange(num_arguments)) {
+    const auto& alias_info = arguments[i].alias_info();
+    if (alias_info.has_value() && !alias_info->isWrite()) {
+      AT_ASSERT(
+        aliased_input_idx == -1,
+        "Expected only a single input in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). "
+        "Non-composite functions where multiple inputs are aliased with outputs aren't supported. "
+        "Please rewrite your function as a composite function.");
+      aliased_input_idx = static_cast<int>(i);
+    }
+  }
+
+  size_t num_tensor_inputs = 0;  // Only used for DEBUG-only checks
+
+  _foreach_tensor([&](size_t _, size_t idx_arg, const at::Tensor& t) {
+    if (grad_mode && t.requires_grad()) {
+      // Pointers stay valid because the stack is not modified before
+      // collect_next_edges consumes them below.
+      tensors_requiring_grad_on_stack.push_back(&t);
+    }
+    num_tensor_inputs++;
+    TORCH_CHECK_NOT_IMPLEMENTED(!isFwGradDefined(t), "Trying to use forward AD with ", op_name, " that does not support it.");
+  }, stack, stack_start, num_arguments);
+
+  const bool any_requires_grad = tensors_requiring_grad_on_stack.size() > 0;
+
+  // Tensors passed to a write-aliased argument are about to be mutated
+  _foreach_tensor([&](size_t _, size_t i, const at::Tensor& t) {
+    const auto& alias_info = arguments[i].alias_info();
+    if (alias_info.has_value() && alias_info->isWrite()) {
+      check_inplace(t, any_requires_grad);
+    }
+  }, stack, stack_start, num_arguments);
+
+  std::shared_ptr<NotImplemented> grad_fn;
+  if (any_requires_grad) {
+    grad_fn = std::shared_ptr<NotImplemented>(new NotImplemented(op_name), deleteNode);
+    grad_fn->set_next_edges(collect_next_edges(tensors_requiring_grad_on_stack));
+  }
+
+  #ifndef NDEBUG
+  // See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
+  // The redispatch replaces the arguments on the stack with the returns, so
+  // keep a copy of the arguments for the checks that follow it.
+  auto stack_args_copy = std::vector<c10::IValue>(stack->begin() + stack_start, stack->end());
+  std::vector<c10::intrusive_ptr<c10::TensorImpl>> impl_saved;
+  impl_saved.reserve(num_tensor_inputs);
+  std::vector<c10::optional<c10::Storage>> storage_saved;
+  storage_saved.reserve(num_tensor_inputs);
+  _foreach_tensor([&](size_t idx, size_t _, const at::Tensor& t) {
+    storage_saved.push_back(t.has_storage() ? c10::optional<c10::Storage>(t.storage()) : c10::nullopt);
+    impl_saved.push_back(t.getIntrusivePtr());
+  }, &stack_args_copy, 0, num_arguments);
+  #endif
+  if (aliased_input_idx != -1 || any_is_inplace_output) {
+    at::AutoDispatchBelowAutograd guard;
+    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
+  } else {
+    // If neither in-place nor view
+    at::AutoDispatchBelowADInplaceOrView guard;
+    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
+  }
+  #ifndef NDEBUG
+  // Inputs must still have the same storage and TensorImpl as before the call
+  _foreach_tensor([&](size_t idx_tensor, size_t _, const at::Tensor& t) {
+    if (storage_saved.at(idx_tensor).has_value())
+      TORCH_INTERNAL_ASSERT(storage_saved.at(idx_tensor).value().is_alias_of(t.storage()), op_name);
+    if (impl_saved.at(idx_tensor))
+      TORCH_INTERNAL_ASSERT(impl_saved.at(idx_tensor) == t.getIntrusivePtr(), op_name);
+  }, &stack_args_copy, 0, num_arguments);
+  _foreach_tensor([&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
+    if (!is_inplace_output[idx_ret])
+      TORCH_INTERNAL_ASSERT(t.use_count() <= 1, op_name);  // Okay to return undefined tensor
+    if (!is_aliased_output[idx_ret] && t.has_storage())
+      TORCH_INTERNAL_ASSERT(t.storage().use_count() == 1);
+  }, stack, stack->size() - num_returns, num_returns);
+  // There should be only a single base-view pair, make sure their storage is aliased
+  if (aliased_input_idx != -1 && aliased_output_idx != -1) {
+    const c10::IValue& aliased_input_iv = stack_args_copy[aliased_input_idx];
+    const c10::IValue& aliased_output_iv = (*stack)[stack->size() - num_returns + aliased_output_idx];
+    // We do not support views embedded inside tensorlist
+    TORCH_INTERNAL_ASSERT(aliased_input_iv.isTensor(), op_name);
+    TORCH_INTERNAL_ASSERT(aliased_output_iv.isTensor(), op_name);
+    const at::Tensor& aliased_input = aliased_input_iv.toTensor();
+    // Read the output from the *output* IValue; reading it from the input
+    // would compare the input's storage with itself.
+    const at::Tensor& aliased_output = aliased_output_iv.toTensor();
+    // is_aliased_output has one entry per return, so index it with the
+    // output index rather than the argument index.
+    if (is_aliased_output[aliased_output_idx] && aliased_input.has_storage())
+      TORCH_INTERNAL_ASSERT(aliased_input.storage().is_alias_of(aliased_output.storage()), op_name);
+  }
+  #endif
+
+  if (any_requires_grad) {
+    _foreach_tensor([&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
+      if (isDifferentiableType(t.scalar_type())) {
+        if (is_inplace_output[idx_ret]) {
+          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+          rebase_history(const_cast<at::Tensor&>(t), grad_fn);
+        } else {
+          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+          set_history(const_cast<at::Tensor&>(t), grad_fn);
+        }
+      }
+    }, stack, stack->size() - num_returns, num_returns);
+  }
+}
+
+// Wraps autogradNotImplementedFallbackImpl as a boxed torch::CppFunction so it
+// can be registered as an operator's kernel.
+torch::CppFunction autogradNotImplementedFallback() {
+  auto boxed_fallback =
+      torch::CppFunction::makeFromBoxedFunction<&autogradNotImplementedFallbackImpl>();
+  return boxed_fallback;
+}
+
+}} // namespace torch::autograd