Add fx2trt pass for removing duplicate output args (#64433)
authorKefei Lu <kefeilu@fb.com>
Thu, 2 Sep 2021 17:38:43 +0000 (10:38 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 17:40:37 +0000 (10:40 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64433

Fx2TRT does not support duplicate nodes in the output args tuple.

This pass removes duplicate output args from the target subnets and fixes their uses in the top level module where the subnets are called. This pass must be called after acc split on the top-level net and subsequent calls to the acc trace on the subnets.

This pass will change both the subnets and top level module.

Test Plan:
Run:

```
buck run mode/opt -c python.package_style=inplace //caffe2/torch/fb/fx2trt/tests/passes/:test_remove_duplicate_output_args

```

Reviewed By: 842974287

Differential Revision: D30468409

fbshipit-source-id: b4d91b76ab5d8a5275d68dd48d1327a44c22568e

torch/fx/experimental/fx2trt/passes/remove_duplicate_output_args.py [new file with mode: 0644]

diff --git a/torch/fx/experimental/fx2trt/passes/remove_duplicate_output_args.py b/torch/fx/experimental/fx2trt/passes/remove_duplicate_output_args.py
new file mode 100644 (file)
index 0000000..488ce45
--- /dev/null
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+
+import operator
+import typing as t
+import logging
+import torch.fx as fx
+import dataclasses as dc
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def remove_duplicate_output_args(
+    top_level: fx.GraphModule,
+    target_subnets: t.Collection[str]
+) -> t.Mapping[str, "RemoveDuplicateResult"]:
+    """Removes duplicate output args.
+
+    This pass removes duplicate output args from the target subnets and fixes
+    their uses in the top level module where the subnets are called. This pass
+    must be called after acc split on the top-level net and subsequent calls to
+    the acc trace on the subnets.
+
+    This pass will change both the subnets and top level module.
+
+    Returns:
+        a mapping of the target subnet name to its dedupcate result
+    """
+
+    processed_subnets = {}
+    for node in top_level.graph.nodes:
+        node: fx.Node
+        if node.op == "call_module" and node.name in target_subnets:
+            sub_gm = top_level.get_submodule(node.target)
+            assert isinstance(sub_gm, fx.GraphModule)
+
+            replace_res = _remove_duplicate_output_args(sub_gm)
+            processed_subnets[node.name] = replace_res
+            if replace_res.replacement_map is None:
+                continue
+            sub_gm.recompile()
+
+            needs_recompile = False
+            # iterate on the copy since we will be changing elements of node.users
+            for user in list(node.users):
+                idx = _ensure_proper_output_use(user, node)
+                idx_new = replace_res.replacement_map[idx]
+                if idx_new != idx:
+                    user.args = (user.args[0], idx_new)
+                    needs_recompile = True
+
+            if needs_recompile:
+                top_level.recompile()
+    return processed_subnets
+
+
+@dc.dataclass(frozen=True)
+class RemoveDuplicateResult:
+    replacement_map: t.Optional[t.List[int]]
+    module: fx.GraphModule
+
+
+def _ensure_proper_output_use(user: fx.Node, target_node: fx.Node) -> int:
+    """
+    Ensures the node looks in proper form of calling the output of an fx2trt
+    splitter sub-net. Specifically:
+
+    1. op is call function, target: operator.getitem
+    2. args is a 2-element tuple
+    3. args[0] is the name of the subnet's output
+    4. args[1] is the index into the subnet output tuple
+
+    E.g.:
+
+        %getitem_4 : [#users=1] = call_function[target=operator.getitem](args = (%_run_on_acc_1, 4), kwargs = {})
+
+    returns the index into the subnet output tuple
+    """
+    _LOGGER.info(f"Checking user node: {user.format_node()}")
+    assert (
+        user.op == "call_function"
+        and user.target == operator.getitem
+        and len(user.args) == 2
+        and user.args[0].name == target_node.name
+        and isinstance(user.args[1], int)
+    ), f"Node is not a proper user of splitter output: {user.format_node()}"
+
+    return user.args[1]
+
+
+def _remove_duplicate_output_args(gm: fx.GraphModule) -> RemoveDuplicateResult:
+    output_nodes = [n for n in gm.graph.nodes if n.op == "output"]
+    assert len(output_nodes) == 1, \
+           f"Expecting exactly one `output` node, but got {len(output_nodes)}"
+
+    changed = False
+    # arg node name to its index in the new output args tuple
+    name_to_idx: t.Dict[str, int] = {}
+    output_node = output_nodes[0]
+
+    # Output op only uses its `args[0]`, and it does not have `kwargs`.
+    # https://pytorch.org/docs/stable/fx.html#torch.fx.Node
+    args = output_node.args[0]
+
+    # Only concern outselves to the case where the args is an iterable of fx.Node.
+    # Other return cases (e.g., a single value) is possible and we don't handle
+    # that in this pass.
+    if not (isinstance(args, t.Iterable) and all(isinstance(a, fx.Node) for a in args)):
+        return RemoveDuplicateResult(replacement_map=None, module=gm)
+
+    # Map old index of the arg node to the remaining node's idx,
+    # initialized to `i => i`
+    replacement_map: t.List[int] = list(range(len(args)))
+    args_new = []
+    for idx, a in enumerate(args):
+        assert isinstance(a, fx.Node), \
+               f"Expecting fx.Node instance, but got: {type(a)}"
+
+        if a.name not in name_to_idx:
+            args_new.append(a)
+            name_to_idx[a.name] = len(args_new) - 1
+        else:
+            changed = True
+            _LOGGER.warning(
+                f"Replaced duplicate output arg '{a.name}': "
+                f"{idx} -> {name_to_idx[a.name]}"
+            )
+        replacement_map[idx] = name_to_idx[a.name]
+
+    output_node.args = (tuple(args_new),)
+    if changed:
+        gm.recompile()
+    return RemoveDuplicateResult(replacement_map, module=gm)