From 2e7cc86a62aa2affd41b0b28bae7ccb2a4cf148b Mon Sep 17 00:00:00 2001
From: =?utf8?q?Alex=20=C5=9Euhan?=
Date: Thu, 6 Dec 2018 08:56:25 -0800
Subject: [PATCH] Add (partial) autodiff support for nll_loss (#14305)

Summary:
Not ready yet, need some comments / help with this. It's good enough for the immediate goals of https://github.com/pytorch/xla (forward + backward trace fusion), but there are at least two issues with it:

1. If we don't allow the weight to be set, `test/test_jit.py` fails to cover the change.
2. If we allow the weight to be set, running `test/test_jit.py TestJitGenerated.test_nn_nll_loss` fails with:

```
======================================================================
ERROR: test_nn_nll_loss (__main__.TestJitGenerated)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_jit.py", line 10001, in do_test
    fn, f_args_variable, kwargs_variable, no_grad=no_grad)
  File "test/test_jit.py", line 9360, in check_against_reference
    outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
  File "test/test_jit.py", line 425, in runAndSaveRNG
    results = func(*inputs, **kwargs)
  File "test/test_jit.py", line 9298, in script_fn
    self.assertExportImport(CU.the_method.graph, tensors)
  File "test/test_jit.py", line 415, in assertExportImport
    self.assertExportImportModule(m, inputs)
  File "test/test_jit.py", line 419, in assertExportImportModule
    self.assertEqual(self.runAndSaveRNG(m.forward, inputs),
  File "test/test_jit.py", line 425, in runAndSaveRNG
    results = func(*inputs, **kwargs)
RuntimeError: arguments for call are not valid:

for operator aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight, *, Tensor out) -> Tensor:
expected a value of type Tensor for argument 'total_weight' but found bool
~ <--- HERE

for operator aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor:
expected a value of type Tensor for argument 'total_weight' but found bool
~ <--- HERE
for call at:
~ <--- HERE
```
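For context, a minimal illustrative sketch (not part of the patch; shapes and values are made up) of the two cases the issues above refer to: `nll_loss` with `weight` left unset, which is the only case this change marks as differentiable, versus an explicit per-class weight, which is left to the fallback path (see the `TODO(asuhan)` in the diff below):

```
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3, requires_grad=True)
target = torch.tensor([0, 2, 1, 0])

# Case covered by this patch: no weight is passed; in the traced graph the
# weight input of aten::nll_loss is an undefined tensor (prim::Undefined),
# which is what the new isDifferentiable() check requires.
loss = F.nll_loss(F.log_softmax(logits, dim=1), target)

# Case not covered yet (issue 2 above): an explicit per-class weight.
class_weight = torch.tensor([1.0, 2.0, 0.5])
weighted_loss = F.nll_loss(F.log_softmax(logits, dim=1), target, weight=class_weight)
```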
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14305

Differential Revision: D13356265

Pulled By: ezyang

fbshipit-source-id: 504d783b2d87f923e698a6a4efc0fd9935a94a41
---
 test/test_jit.py            |  1 +
 torch/csrc/jit/autodiff.cpp | 22 +++++++++++++++++++++-
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 52d1a7d..ed6706f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -9835,6 +9835,7 @@ DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
     'test_nn_avg_pool2d',
     'test_nn_log_softmax',
     'test_nn_threshold',
+    'test_nn_nll_loss',
 }
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index aa42855..27015d1 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -94,6 +94,7 @@ bool isDifferentiable(Node * n) {
     // "aten::min(Tensor self) -> Tensor"

   if (n->kind() == prim::Constant ||
+      n->kind() == prim::Undefined ||
       n->kind() == prim::AutogradAdd ||
       n->kind() == prim::ConstantChunk ||
       n->kind() == prim::None)
@@ -109,6 +110,10 @@ bool isDifferentiable(Node * n) {
     return n->get<std::vector<int64_t>>(attr::size) &&
       n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
   }
+  if (n->matches("aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
+    // TODO(asuhan): support weight
+    return n->namedInput(attr::weight)->node()->kind() == prim::Undefined;
+  }

   // linear blocks may appear as inputs to graph executors, but they are removed
   // before differentiation occurs
@@ -483,6 +488,21 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     JIT_ASSERT(tuple_outputs.size() == size_t(3));
     return {tuple_outputs[0], tuple_outputs[1], tuple_outputs[2], nullptr, nullptr, nullptr, nullptr, nullptr};

+  } else if (node->matches("aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
+    auto graph = node->owningGraph();
+    auto total_weight = graph->insertNode(graph->createUndefined());
+    auto weight = graph->insertNode(graph->createUndefined());
+    auto backward_value = graph->insert(aten::nll_loss_backward, {
+        grads.at(0).value(),
+        inputs.at(0).value(),
+        inputs.at(1).value(),
+        weight->output(),
+        inputs.at(3).value(),
+        inputs.at(4).value(),
+        total_weight->output()
+    });
+    return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr};
+
   } else if (node->matches("aten::log_softmax(Tensor self, int dim) -> Tensor")) {
     JIT_ASSERT(grads.size() == 1);
     auto graph = node->owningGraph();
@@ -494,7 +514,7 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     });
     return {backward_value->node()->output(0), nullptr};

-  } else if (node->kind() == prim::Constant || node->kind() == prim::None) {
+  } else if (node->kind() == prim::Constant || node->kind() == prim::Undefined || node->kind() == prim::None) {
     return {};
   }
   throw std::runtime_error(std::string("failed to differentiate `") + node->kind().toDisplayString() + "`");
-- 
2.7.4
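As a rough usage note (not part of the patch, and not a test from the repository): the change is motivated by forward + backward trace fusion for pytorch/xla, so one plausible way to exercise it is to trace an `nll_loss` computation and differentiate through the traced function. The sketch below is illustrative only; whether the symbolic gradient actually kicks in depends on the graph executor and fusion settings.

```
import torch
import torch.nn.functional as F

def loss_fn(logits, target):
    return F.nll_loss(F.log_softmax(logits, dim=1), target)

logits = torch.randn(4, 3, requires_grad=True)
target = torch.tensor([0, 2, 1, 0])

# Trace the forward computation, then backprop through the traced function.
# With this change the JIT can emit aten::nll_loss_backward symbolically
# (weight and total_weight passed as undefined tensors) instead of falling
# back to autograd for the nll_loss node.
traced = torch.jit.trace(loss_fn, (logits, target))
traced(logits, target).backward()

# Sanity check against the eager-mode gradient.
ref = logits.detach().requires_grad_()
loss_fn(ref, target).backward()
print(torch.allclose(logits.grad, ref.grad))
```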