static_cast<bool(*)(Node*)>(isDifferentiable));
}
-// TODO: Remove this after #15355.
-// Helper (deleted by this diff): inlines `callee` into `g` and, when the
-// callee's single output is tuple-typed, unpacks it so callers receive the
-// element Values directly. Superseded by inlineCallTo(..., unpack_outputs).
-namespace {
- std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
- auto outputs = inlineCallTo(g, callee, inputs);
- if (callee.outputs().size() == 1 && callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
- auto tc = createTupleUnpack(outputs.at(0));
- outputs = std::vector<Value*>(tc.begin(), tc.end());
- }
- return outputs;
- }
-} //anonymous namespace
-
-
// NB: Gradients are written in TorchScript.
// For example, the node aten::mul() should be defined as follows:
// def forward(x, y):
// NOTE(review): fragment — the enclosing function/lambda begins above this
// view, and at least one loop-closing brace is elided below.
{
// Insert the inlined forward graph immediately after `node`.
WithInsertPoint guard(node->next());
auto fw_graph = compiled_graphs->forward;
// Inline the compiled forward graph; unpack_outputs=true flattens a single
// tuple return into individual Values (replaces the old inlineUnpackedCallTo).
- new_outputs = inlineUnpackedCallTo(*graph, *fw_graph, node->inputs());
+ new_outputs = inlineCallTo(*graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true);
// Copy the original output types onto the inlined outputs.
for (size_t i = 0; i < node->outputs().size(); ++i) {
new_outputs.at(i)->setType(node->outputs()[i]->type());
// NOTE(review): this replaces uses of the *inlined* output with the original
// node's output — confirm the direction is intended; elided context may
// rewire the node's outputs afterwards.
new_outputs.at(i)->replaceAllUsesWith(node->outputs()[i]);
// (elided context above) Prepend the forward's last output — presumably a
// saved context value, TODO confirm — to the gradient inputs before inlining
// the backward graph.
auto it = grad_vec.begin();
grad_vec.insert(it, new_outputs.back());
ArrayRef<Value*> grad(grad_vec);
// Inline the backward graph the same way, unpacking a tuple return.
- auto grad_inputs = inlineUnpackedCallTo(*graph, *bw_graph, grad);
+ auto grad_inputs = inlineCallTo(*graph, *bw_graph, grad, /*unpack_outputs=*/true);
return grad_inputs;
};
return g.insertNode(g.createTupleUnpack(v))->outputs();
}
// Inline `callee`'s graph into `g` at the current insert point, substituting
// `inputs` for the callee's formal inputs, and return the Values in `g` that
// correspond to the callee's outputs.
// unpack_outputs: when true and the callee returns exactly one tuple-typed
// value, a tuple-unpack node is inserted and the element Values are returned
// instead of the tuple itself.
-std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
+std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs, bool unpack_outputs) {
std::unordered_map<Value*, Value*> value_map;
auto value_map_func = [&](Value* v) { return value_map.at(v); };
JIT_ASSERT(callee.inputs().size() == inputs.size());
// NOTE(review): the node-cloning loop that populates `value_map` and declares
// `outputs` is elided from this view.
// Map each callee output back to the corresponding cloned Value in `g`.
for (auto* output : callee.outputs()) {
outputs.push_back(value_map_func(output));
}
+
+ // Optionally flatten a single tuple-typed return into its elements.
+ if (unpack_outputs && outputs.size() == 1 &&
+ callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
+ auto tup = outputs[0];
+ outputs.clear();
+ for(Value* v : createTupleUnpack(tup)) {
+ outputs.emplace_back(v);
+ }
+ // if this was a peephole tuple unpack we can just get rid of
+ // the tuple construct here and prevent needing DCE
+ if (tup->node()->kind() == prim::TupleConstruct && !tup->node()->hasUses()) {
+ tup->node()->destroy();
+ }
+ }
+
return outputs;
}
// NOTE(review): presumably validates internal Graph invariants — confirm
// against the implementation before relying on this description.
TORCH_API void LintGraph(std::shared_ptr<Graph>& graph);
// Inserts a tuple-unpack node consuming `v` into its graph and returns the
// unpacked output Values.
TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);
-TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs);
+// unpack_outputs - if true and the callee returns a single tuple value, insert a
+// tuple unpack node and return the resulting element values instead.
+TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs, bool unpack_outputs=false);
}} // namespace torch::jit
}
// Inline `callee` into `g`, flattening a single tuple-typed return into its
// element Values. Kept as a thin backward-compatible wrapper now that the
// unpacking logic lives in inlineCallTo (unpack_outputs=true); this span
// previously contained both the old hand-rolled unpack body (diff residue)
// and the new delegation — resolved here to the delegating form.
std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
  return inlineCallTo(g, callee, inputs, /*unpack_outputs=*/true);
}
// Replace an aten operator node with the corresponding BatchTensor operator graph.