Fix fx2trt SplitterBase non_tensor_input logic (#64286)
authorKefei Lu <kefeilu@fb.com>
Tue, 7 Sep 2021 11:00:49 +0000 (04:00 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 7 Sep 2021 11:02:29 +0000 (04:02 -0700)
commitadbcc819cd40deaa2755383815896d8c9dffb881
treef832c50ab0f9a66374bbe8fdacaf6cc8b363f5d8
parent32fbeb170d57ab6a5af9ca6de23a54a6a910a433
Fix fx2trt SplitterBase non_tensor_input logic (#64286)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64286

During graph splitting, `_SplitterBase` supports taking into consideration whether the subnet boundary nodes
produces "supported" outputs that will cross the acc/non-acc boundary. Specifically, if the backend only
supports Tensor-based data passing cross boundary, then we cannot split the graph at a place where the node
output is a non-Tensor type (e.g., `Tuple[Tensor]`).

There's currently a bug in this logic that it does not correctly detect the output type of a Node. Instead of
using `Node.meta['tensor_meta']`, we should instead check `Node.meta['type']`.

`Node.meta['tensor_meta']` is not appropriate because this key will exist if the node output is an iterable
and one of the element is of type `Tensor`. So `Tuple[Tensor]` will be wrongly considered "supported".

Test Plan:
arc lint
run CI tests

Reviewed By: yinghai, 842974287

Differential Revision: D30617147

fbshipit-source-id: e8ba70dfaddc05cafb8037d58fca73b7ccbb1a49
torch/fx/passes/splitter_base.py
torch/fx/passes/tools_common.py