From adbcc819cd40deaa2755383815896d8c9dffb881 Mon Sep 17 00:00:00 2001
From: Kefei Lu
Date: Tue, 7 Sep 2021 04:00:49 -0700
Subject: [PATCH] Fix fx2trt SplitterBase non_tensor_input logic (#64286)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64286

During graph splitting, `_SplitterBase` can take into consideration whether
the subnet boundary nodes produce "supported" outputs that cross the
acc/non-acc boundary. Specifically, if the backend only supports passing
Tensor-based data across the boundary, then we cannot split the graph at a
place where the node output is a non-Tensor type (e.g., `Tuple[Tensor]`).

There is currently a bug in this logic: it does not correctly detect the
output type of a Node. Instead of using `Node.meta['tensor_meta']`, we should
check `Node.meta['type']`.

`Node.meta['tensor_meta']` is not appropriate because this key exists
whenever the node output is an iterable and one of its elements is of type
`Tensor`. So a `Tuple[Tensor]` output would be wrongly considered
"supported".

Test Plan:
arc lint
run CI tests

Reviewed By: yinghai, 842974287

Differential Revision: D30617147

fbshipit-source-id: e8ba70dfaddc05cafb8037d58fca73b7ccbb1a49
---
 torch/fx/passes/splitter_base.py | 15 ++++++++++-----
 torch/fx/passes/tools_common.py  | 11 +++++++++++
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py
index 42087bd..6541905 100644
--- a/torch/fx/passes/splitter_base.py
+++ b/torch/fx/passes/splitter_base.py
@@ -2,6 +2,7 @@ import argparse
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Dict, Optional, Tuple
+import logging
 
 import torch
 from torch.fx.experimental.graph_manipulation import get_size_of_node
@@ -20,8 +21,12 @@ from .tools_common import (
     Tensors,
     NodeList,
     NodeSet,
+    is_node_output_tensor,
 )
 
+_LOGGER = logging.getLogger(__name__)
+
+
 class _SplitterSettingBase:
     def __init__(self):
         parser = argparse.ArgumentParser()
@@ -98,7 +103,7 @@ class FxNetAccNodesFinder:
             for user in node.users:
                 if user in self.acc_nodes:
                     self.acc_nodes.remove(user)
-                    if "tensor_meta" not in user.meta:
+                    if not is_node_output_tensor(user):
                         cpu_worklist.append(user)
 
     def reduce_acc_nodes_non_tensor_input(self):
@@ -113,7 +118,7 @@
                 continue
             if node in self.acc_nodes:
                 continue
-            if "tensor_meta" in node.meta:
+            if is_node_output_tensor(node):
                 continue
 
             non_tensor_cpu_nodes.append(node)
@@ -128,7 +133,7 @@
             new_cpu_nodes: NodeList = []
 
             for acc_node in self.acc_nodes:
-                if "tensor_meta" in acc_node.meta:
+                if is_node_output_tensor(acc_node):
                     continue
                 for user in acc_node.users:
                     if user not in self.acc_nodes:
@@ -461,7 +466,7 @@
         reports += "Checking inputs...\n"
         for n in submod.graph.nodes:
             if n.op == "placeholder":
-                if "tensor_meta" not in n.meta:
+                if not is_node_output_tensor(n):
                     reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                 else:
                     total_input_bytes += get_size_of_node(submod, n)[0]
@@ -473,7 +478,7 @@
         def get_bytes(node: torch.fx.Node):
             nonlocal total_output_bytes
             nonlocal reports
-            if "tensor_meta" not in node.meta:
+            if not is_node_output_tensor(node):
                 reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
             else:
                 total_output_bytes += get_size_of_node(submod, node)[0]
diff --git a/torch/fx/passes/tools_common.py b/torch/fx/passes/tools_common.py
index a996dc8..8274f4b 100644
--- a/torch/fx/passes/tools_common.py
+++ b/torch/fx/passes/tools_common.py
@@ -48,6 +48,17 @@ def get_node_target(submodules: Dict[str, torch.nn.Module], node: torch.fx.Node)
     return node.target
 
 
+def is_node_output_tensor(node: torch.fx.Node) -> bool:
+    """Checks if the node output produces a Tensor or not.
+
+    NOTE: This requires to run `ShapeProp` on the containing fx graph before
+    calling this function. This is because it works by checking the `type`
+    metadata on the node. This metadata is produced by the `ShapeProp`.
+    """
+    type_ = node.meta.get("type", None)
+    return type_ is not None and issubclass(type_, torch.Tensor)
+
+
 class FxNetAccFusionsFinder:
     """
     Finds groups of connected ACC nodes that pass non-tensor data between each other.
-- 
2.7.4
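
For illustration, a minimal sketch (not part of the patch) of how the new
`is_node_output_tensor` helper behaves once `ShapeProp` has populated
`node.meta["type"]`. The `TupleOut` module below is made up for the example;
`ShapeProp` and `is_node_output_tensor` are the torch.fx APIs referenced by
this change, and the behavior shown assumes a build that includes this patch.

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.shape_prop import ShapeProp
    from torch.fx.passes.tools_common import is_node_output_tensor

    # Hypothetical module for illustration only.
    class TupleOut(torch.nn.Module):
        def forward(self, x):
            y = x + 1             # plain Tensor output
            return (y, y * 2)     # Tuple[Tensor] output

    gm = symbolic_trace(TupleOut())
    # ShapeProp records node.meta["type"] (and "tensor_meta") for every node.
    ShapeProp(gm).propagate(torch.randn(2, 3))

    for node in gm.graph.nodes:
        # The output node returns a Tuple[Tensor]: it still gets "tensor_meta"
        # (which the old check keyed on), but its meta["type"] is tuple, so
        # is_node_output_tensor() correctly reports False for it.
        print(node.name, node.meta.get("type"), is_node_output_tensor(node))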