Add (partial) autodiff support for nll_loss (#14305)
authorAlex Şuhan <asuhan@google.com>
Thu, 6 Dec 2018 16:56:25 +0000 (08:56 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 6 Dec 2018 16:58:54 +0000 (08:58 -0800)
Summary:
Not ready yet, need some comments / help with this. It's good enough for https://github.com/pytorch/xla immediate goals (forward + backward trace fusion), but there are at least two issues with it:

1. If we don't allow it, `test/test_jit.py` fails to cover the change.
2. If we allow the weight to be set, running `test/test_jit.py TestJitGenerated.test_nn_nll_loss` fails with:

```
======================================================================
ERROR: test_nn_nll_loss (__main__.TestJitGenerated)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_jit.py", line 10001, in do_test
    fn, f_args_variable, kwargs_variable, no_grad=no_grad)
  File "test/test_jit.py", line 9360, in check_against_reference
    outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
  File "test/test_jit.py", line 425, in runAndSaveRNG
    results = func(*inputs, **kwargs)
  File "test/test_jit.py", line 9298, in script_fn
    self.assertExportImport(CU.the_method.graph, tensors)
  File "test/test_jit.py", line 415, in assertExportImport
    self.assertExportImportModule(m, inputs)
  File "test/test_jit.py", line 419, in assertExportImportModule
    self.assertEqual(self.runAndSaveRNG(m.forward, inputs),
  File "test/test_jit.py", line 425, in runAndSaveRNG
    results = func(*inputs, **kwargs)
RuntimeError:
arguments for call are not valid:

  for operator aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight, *, Tensor out) -> Tensor:
  expected a value of type Tensor for argument 'total_weight' but found bool
  <internally-created-node>
  ~ <--- HERE

  for operator aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor:
  expected a value of type Tensor for argument 'total_weight' but found bool
  <internally-created-node>
  ~ <--- HERE
for call at:
<internally-created-node>
~ <--- HERE
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14305

Differential Revision: D13356265

Pulled By: ezyang

fbshipit-source-id: 504d783b2d87f923e698a6a4efc0fd9935a94a41

test/test_jit.py
torch/csrc/jit/autodiff.cpp

index 52d1a7d..ed6706f 100644 (file)
@@ -9835,6 +9835,7 @@ DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
     'test_nn_avg_pool2d',
     'test_nn_log_softmax',
     'test_nn_threshold',
+    'test_nn_nll_loss',
 }
 
 
index aa42855..27015d1 100644 (file)
@@ -94,6 +94,7 @@ bool isDifferentiable(Node * n) {
   // "aten::min(Tensor self) -> Tensor"
 
   if (n->kind() == prim::Constant ||
+      n->kind() == prim::Undefined ||
       n->kind() == prim::AutogradAdd ||
       n->kind() == prim::ConstantChunk ||
       n->kind() == prim::None)
@@ -109,6 +110,10 @@ bool isDifferentiable(Node * n) {
     return n->get<std::vector<int64_t>>(attr::size) &&
       n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
   }
+  if (n->matches("aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
+    // TODO(asuhan): support weight
+    return n->namedInput(attr::weight)->node()->kind() == prim::Undefined;
+  }
 
   // linear blocks may appear as inputs to graph executors, but they are removed
   // before differentiation occurs
@@ -483,6 +488,21 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
       JIT_ASSERT(tuple_outputs.size() == size_t(3));
       return {tuple_outputs[0], tuple_outputs[1], tuple_outputs[2], nullptr, nullptr, nullptr, nullptr, nullptr};
 
+    } else if (node->matches("aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
+      auto graph = node->owningGraph();
+      auto total_weight = graph->insertNode(graph->createUndefined());
+      auto weight = graph->insertNode(graph->createUndefined());
+      auto backward_value = graph->insert(aten::nll_loss_backward, {
+        grads.at(0).value(),
+        inputs.at(0).value(),
+        inputs.at(1).value(),
+        weight->output(),
+        inputs.at(3).value(),
+        inputs.at(4).value(),
+        total_weight->output()
+      });
+      return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr};
+
     } else if (node->matches("aten::log_softmax(Tensor self, int dim) -> Tensor")) {
       JIT_ASSERT(grads.size() == 1);
       auto graph = node->owningGraph();
@@ -494,7 +514,7 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
       });
       return {backward_value->node()->output(0), nullptr};
 
-    } else if (node->kind() == prim::Constant || node->kind() == prim::None) {
+    } else if (node->kind() == prim::Constant || node->kind() == prim::Undefined || node->kind() == prim::None) {
       return {};
     }
     throw std::runtime_error(std::string("failed to differentiate `") + node->kind().toDisplayString() + "`");