if (n->matches(
"aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
// TODO(asuhan): support weight
- return n->namedInput(attr::weight)->node()->kind() == prim::Undefined;
+ return n->namedInput(attr::weight)->node()->kind() == prim::None;
}
// linear blocks may appear as inputs to graph executors, but they are removed
"aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
auto graph = node->owningGraph();
auto total_weight = graph->insertNode(graph->createUndefined());
- auto weight = graph->insertNode(graph->createUndefined());
+ auto weight = graph->insertNode(graph->createNone(TensorType::get()));
auto backward_value = graph->insert(
aten::nll_loss_backward,
{grads.at(0).value(),