Fix autodiff of nll_loss
author Alex Şuhan <asuhan@google.com>
Fri, 8 Feb 2019 01:31:52 +0000 (17:31 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 8 Feb 2019 01:42:01 +0000 (17:42 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16851
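
The optional weight input to aten::nll_loss is represented as a prim::None constant rather than prim::Undefined, so the isDifferentiable check and the backward-graph construction in GradientHelper are updated to match, as the diff below shows. A minimal sketch of a script that exercises this path, assuming only the public torch.jit.script and torch.nn.functional.nll_loss APIs (weight left at its default None); this example is illustrative and not part of the PR:

    # Hypothetical repro/usage sketch: scripted nll_loss with weight omitted,
    # the case the differentiability check now accepts.
    import torch
    import torch.nn.functional as F

    @torch.jit.script
    def scripted_nll(logits, target):
        # weight stays at its default None, so the JIT sees a prim::None constant
        return F.nll_loss(F.log_softmax(logits, dim=1), target)

    logits = torch.randn(4, 3, requires_grad=True)
    target = torch.tensor([0, 2, 1, 0])
    scripted_nll(logits, target).backward()
    print(logits.grad.shape)  # expected: torch.Size([4, 3])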

Differential Revision: D13995046

Pulled By: wanchaol

fbshipit-source-id: 557c99f1d1825fa9b6031dd9fa8ba9b54205e8c4

torch/csrc/jit/autodiff.cpp

index 8122aea..c43669d 100644
@@ -129,7 +129,7 @@ bool isDifferentiable(Node* n) {
   if (n->matches(
           "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
     // TODO(asuhan): support weight
-    return n->namedInput(attr::weight)->node()->kind() == prim::Undefined;
+    return n->namedInput(attr::weight)->node()->kind() == prim::None;
   }
 
   // linear blocks may appear as inputs to graph executors, but they are removed
@@ -717,7 +717,7 @@ class GradientHelper {
             "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
       auto graph = node->owningGraph();
       auto total_weight = graph->insertNode(graph->createUndefined());
-      auto weight = graph->insertNode(graph->createUndefined());
+      auto weight = graph->insertNode(graph->createNone(TensorType::get()));
       auto backward_value = graph->insert(
           aten::nll_loss_backward,
           {grads.at(0).value(),