Summary:
This adds a symbolic gradient for `aten::nll_loss` in the JIT. Not ready yet; I need some comments / help with this. It's good enough for the immediate goals of https://github.com/pytorch/xla (forward + backward trace fusion), but there are at least two issues with it:
1. If we don't allow the weight to be set, `test/test_jit.py` doesn't cover the change.
2. If we allow the weight to be set, running `test/test_jit.py TestJitGenerated.test_nn_nll_loss` fails with:
```
======================================================================
ERROR: test_nn_nll_loss (__main__.TestJitGenerated)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_jit.py", line 10001, in do_test
    fn, f_args_variable, kwargs_variable, no_grad=no_grad)
  File "test/test_jit.py", line 9360, in check_against_reference
    outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
  File "test/test_jit.py", line 425, in runAndSaveRNG
    results = func(*inputs, **kwargs)
  File "test/test_jit.py", line 9298, in script_fn
    self.assertExportImport(CU.the_method.graph, tensors)
  File "test/test_jit.py", line 415, in assertExportImport
    self.assertExportImportModule(m, inputs)
  File "test/test_jit.py", line 419, in assertExportImportModule
    self.assertEqual(self.runAndSaveRNG(m.forward, inputs),
  File "test/test_jit.py", line 425, in runAndSaveRNG
    results = func(*inputs, **kwargs)
RuntimeError:
arguments for call are not valid:
for operator aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight, *, Tensor out) -> Tensor:
expected a value of type Tensor for argument 'total_weight' but found bool
<internally-created-node>
~ <--- HERE
for operator aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor:
expected a value of type Tensor for argument 'total_weight' but found bool
<internally-created-node>
~ <--- HERE
for call at:
<internally-created-node>
~ <--- HERE
```
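For context, the generated test scripts an `nll_loss` call, differentiates it, and round-trips the graph through export/import; the failure comes from the re-imported graph. A rough, hypothetical repro sketch of what it exercises (not the generated test's literal code, and scripting details for `F.nll_loss` may differ in this era of the JIT):

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def scripted_nll_loss(input, target):
    # No explicit weight: the weight input stays undefined, which is the
    # only case the new symbolic gradient covers.
    return F.nll_loss(input, target)

input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])
# Differentiating the scripted graph emits aten::nll_loss_backward; the
# export/import round trip is where the bool-vs-Tensor mismatch shows up.
scripted_nll_loss(input, target).backward()
```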
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14305
Differential Revision: D13356265
Pulled By: ezyang
fbshipit-source-id: 504d783b2d87f923e698a6a4efc0fd9935a94a41
Excerpts from the diff follow. In `test/test_jit.py`, `test_nn_nll_loss` joins a set of excluded generated tests (consistent with issue 2 above):

```diff
     'test_nn_avg_pool2d',
     'test_nn_log_softmax',
     'test_nn_threshold',
+    'test_nn_nll_loss',
 }
```
In the autodiff pass, `prim::Undefined` is treated as differentiable (it needs no gradient), and `aten::nll_loss` is accepted only when its weight input is undefined:

```diff
       // "aten::min(Tensor self) -> Tensor"
   if (n->kind() == prim::Constant ||
+      n->kind() == prim::Undefined ||
       n->kind() == prim::AutogradAdd ||
       n->kind() == prim::ConstantChunk ||
       n->kind() == prim::None)
     return n->get<std::vector<int64_t>>(attr::size) &&
         n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
   }
+  if (n->matches("aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
+    // TODO(asuhan): support weight
+    return n->namedInput(attr::weight)->node()->kind() == prim::Undefined;
+  }
   // linear blocks may appear as inputs to graph executors, but they are removed
   // before differentiation occurs
```
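To see which case applies, one can inspect the scripted graph: with no weight passed, the `weight` input of the `aten::nll_loss` node comes from an undefined/None constant, which is exactly what the check above keys on. A sketch (graph printout details vary by version):

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def unweighted(input, target):
    return F.nll_loss(input, target)

# The printed graph shows the nll_loss node taking its weight from an
# undefined/None constant when no weight tensor is supplied.
print(unweighted.graph)
```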
The gradient construction then emits `aten::nll_loss_backward`, passing freshly created `prim::Undefined` values for both `weight` and `total_weight`, and `prim::Undefined` is also added to the kinds that produce no gradients:

```diff
   JIT_ASSERT(tuple_outputs.size() == size_t(3));
   return {tuple_outputs[0], tuple_outputs[1], tuple_outputs[2], nullptr, nullptr, nullptr, nullptr, nullptr};
+} else if (node->matches("aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
+  auto graph = node->owningGraph();
+  auto total_weight = graph->insertNode(graph->createUndefined());
+  auto weight = graph->insertNode(graph->createUndefined());
+  auto backward_value = graph->insert(aten::nll_loss_backward, {
+      grads.at(0).value(),
+      inputs.at(0).value(),
+      inputs.at(1).value(),
+      weight->output(),
+      inputs.at(3).value(),
+      inputs.at(4).value(),
+      total_weight->output()
+  });
+  return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr};
+
 } else if (node->matches("aten::log_softmax(Tensor self, int dim) -> Tensor")) {
   JIT_ASSERT(grads.size() == 1);
   auto graph = node->owningGraph();
   });
   return {backward_value->node()->output(0), nullptr};
-} else if (node->kind() == prim::Constant || node->kind() == prim::None) {
+} else if (node->kind() == prim::Constant || node->kind() == prim::Undefined || node->kind() == prim::None) {
   return {};
 }
 throw std::runtime_error(std::string("failed to differentiate `") + node->kind().toDisplayString() + "`");
```
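For reference, the backward graph built above corresponds to the eager-mode call below (a sketch using the internal `torch._C._nn` bindings, whose schemas match the ones quoted in the error). Note that in eager mode `total_weight` is a real tensor produced by `nll_loss_forward`, whereas the symbolic gradient substitutes a `prim::Undefined` value for it, which is presumably what surfaces as the bool-vs-Tensor mismatch after export/import:

```python
import torch

input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])

# reduction: 0 = none, 1 = mean, 2 = sum; ignore_index defaults to -100.
# nll_loss_forward returns (output, total_weight).
output, total_weight = torch._C._nn.nll_loss_forward(
    input, target, None, 1, -100)
grad_input = torch._C._nn.nll_loss_backward(
    torch.ones_like(output), input, target, None, 1, -100, total_weight)
```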