Move fx2trt and oss_acc_tracer to oss (#63101)
authorShiyan Deng <dsy842974287@fb.com>
Sun, 15 Aug 2021 18:52:20 +0000 (11:52 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sun, 15 Aug 2021 18:53:36 +0000 (11:53 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63101

Move internal fx2trt to torch/fx/experimental/fx2trt and merge the two TRT interpreters we have right now. cc: mortzur, as this might affect the uru exporting script.

Move oss_acc_tracer to torch/fx/experimental/fx_acc.

Test Plan: CI

Reviewed By: jerryzh168

Differential Revision: D30257909

fbshipit-source-id: 4e374965fbf88d72e91844d9e9b6ff9b98f467d1

12 files changed:
test/test_fx_experimental.py
torch/fx/experimental/fx2trt/converters/__init__.py
torch/fx/experimental/fx2trt/converters/acc_ops_converters.py [new file with mode: 0644]
torch/fx/experimental/fx2trt/converters/add.py
torch/fx/experimental/fx2trt/converters/mul.py
torch/fx/experimental/fx2trt/fx2trt.py
torch/fx/experimental/fx_acc/__init__.py [new file with mode: 0644]
torch/fx/experimental/fx_acc/acc_normalizer.py [new file with mode: 0644]
torch/fx/experimental/fx_acc/acc_ops.py [new file with mode: 0644]
torch/fx/experimental/fx_acc/acc_tracer.py [new file with mode: 0644]
torch/fx/experimental/fx_acc/acc_utils.py [new file with mode: 0644]
torch/fx/experimental/graph_manipulation.py

index 87bf7ec..00f3201 100644 (file)
@@ -568,7 +568,7 @@ class TestFXExperimental(JitTestCase):
             node_to_partition_id = {}
             partition_to_logical_devices = {}
             count = 0
-            GraphManipulation.get_size_of_all_nodes(traced, [a])
+            graph_manipulation.get_size_of_all_nodes(traced, [a])
             for node in traced.graph.nodes:
                 if node.op not in {"placeholder", "get_attr", "output"}:
                     node_to_partition_id[node] = count
index a73d459..4c62156 100644 (file)
@@ -8,3 +8,4 @@ from .maxpool import *  # noqa: F403
 from .mul import *  # noqa: F403
 from .transformation import *  # noqa: F403
 from .quantization import *  # noqa: F403
+from .acc_ops_converters import *  # noqa: F403
diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
new file mode 100644 (file)
index 0000000..b85ab40
--- /dev/null
@@ -0,0 +1,1080 @@
+# type: ignore[attr-defined]
+import math
+import operator
+
+import torch.fx.experimental.fx_acc.acc_ops as acc_ops
+import torch.fx.experimental.fx_acc.acc_utils as acc_utils
+import numpy as np
+import tensorrt as trt
+import torch
+from torch.fx.experimental.fx2trt.fx2trt import (
+    tensorrt_converter,
+    torch_dtype_from_trt,
+    get_dynamic_dims,
+)
+
+
+def to_numpy(tensor: torch.Tensor):
+    """
+    Convert a PyTorch Tensor to a Numpy Array.
+    """
+    if tensor is None:
+        return tensor
+
+    if tensor.is_quantized:
+        tensor = tensor.dequantize()
+
+    return tensor.cpu().detach().contiguous().numpy()
+
+
+def has_dynamic_shape(shape):
+    return any(s == -1 for s in shape)
+
+
+def get_axes_for_reduce_op(dim, has_implicit_batch_dimension):
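+    # TensorRT reduce layers take the reduction axes as a bitmask. Build it by
+    # setting the bit for each dim, shifting down by one when the batch dim is
+    # implicit since dim 0 is then not part of the TensorRT shape.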
+    if isinstance(dim, int):
+        dim = (dim,)
+
+    if has_implicit_batch_dimension:
+        assert 0 not in dim, "Can't reduce over batch dimension when it's implicit."
+
+    axes = 0
+    for d in dim:
+        axes |= 1 << (d - (1 if has_implicit_batch_dimension else 0))
+
+    return axes
+
+
+def create_constant(network, tensor, name):
+    if isinstance(tensor, int):
+        tensor = torch.IntTensor([tensor])
+
+    if isinstance(tensor, float):
+        tensor = torch.Tensor([tensor])
+
+    shape = tuple(tensor.shape)
+
+    # Remove all preceding 1s as they can be re-inserted later during broadcasting.
+    num_preceding_ones = 0
+    for j in range(len(shape)):
+        if int(shape[j]) == 1:
+            num_preceding_ones += 1
+        else:
+            break
+
+    # If the shape is all 1s, keep the last dim.
+    shape = shape[num_preceding_ones:] if num_preceding_ones < len(shape) else (1,)
+    constant = network.add_constant(shape, to_numpy(tensor))
+    constant.name = name
+    return constant.get_output(0)
+
+
+def get_trt_tensor(network, input_val, name):
+    if isinstance(input_val, (torch.Tensor, int, float)):
+        return create_constant(network, input_val, name)
+    elif not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"Received input {input_val} of name {name} that "
+            "is not part of the TensorRT region!"
+        )
+    else:
+        return input_val
+
+
+def append_ones(network, input, name, num_prepend_ones):
+    layer = network.add_shuffle(input)
+
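+    # With a dynamic shape we can't compute the reshape dims statically, so
+    # build them at runtime: concatenate a constant of `num_prepend_ones` ones
+    # with the input's shape tensor and feed that into the shuffle layer.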
+    if has_dynamic_shape(input.shape):
+        input_shape_layer = network.add_shape(input)
+        input_shape_layer.name = f"{name}_broadcast_orig_shape"
+        prepend_shape_layer = network.add_constant(
+            (num_prepend_ones,), np.ones((num_prepend_ones,), dtype=np.int32)
+        )
+        prepend_shape_layer.name = f"{name}_broadcast_prepend_ones"
+        reshape_dim_layer = network.add_concatenation(
+            [prepend_shape_layer.get_output(0), input_shape_layer.get_output(0)]
+        )
+        reshape_dim_layer.axis = 0
+        reshape_dim_layer.name = f"{name}_broadcast_final_shape"
+        layer.set_input(1, reshape_dim_layer.get_output(0))
+    else:
+        layer.reshape_dims = (1,) * num_prepend_ones + tuple(input.shape)
+
+    layer.name = name
+    return layer.get_output(0)
+
+
+def broadcast(network, a, b, a_name, b_name, preset_diff=0):
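+    # Prepend 1s to whichever tensor has the lower rank so both ranks match,
+    # since elementwise layers require equal ranks. `preset_diff` biases the
+    # comparison, e.g. for matmul where one operand is treated as a vector.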
+    a_shape = tuple(a.shape)
+    b_shape = tuple(b.shape)
+
+    diff = len(a_shape) - len(b_shape) - preset_diff
+    if diff > 0:
+        b = append_ones(network, b, f"{b_name}_broadcast", diff)
+    elif diff < 0:
+        a = append_ones(network, a, f"{a_name}_broadcast", -diff)
+
+    return a, b
+
+
+def add_binary_elementwise_layer(network, lhs_val, rhs_val, op_type, name):
+    lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs")
+    rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs")
+    lhs_val, rhs_val = broadcast(
+        network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
+    )
+    layer = network.add_elementwise(lhs_val, rhs_val, op_type)
+    layer.name = name
+    return layer.get_output(0)
+
+
+def add_unary_layer(network, input_val, operation_type, name):
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"{operation_type} received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    layer = network.add_unary(input_val, operation_type)
+    layer.name = name
+    return layer.get_output(0)
+
+
+def add_activation_layer(network, input_val, operation_type, name):
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"{operation_type} received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    layer = network.add_activation(input_val, operation_type)
+    layer.name = name
+    return layer.get_output(0)
+
+
+def add_transpose_layer(
+    network, input_val, dim_0, dim_1, name, ignore_implicit_batch=False
+):
+    """Adds a transpose layer to the TensorRT network
+    Args:
+        network: TensorRT Network object
+        input_val: tensorrt.ITensor
+        dim_0, dim_1: dimensions for transpose, e.g. dim_0=1, dim_1=0 means transpose
+        the first two dimensions
+        name: Name of the layer
+        ignore_implicit_batch: activations might have an implicit batch dim, but
+        weights do not. When this is True, we ignore the implicit batch dim and use
+        the dimension arguments as-is
+    Returns:
+        output TensorRT ITensor from the transpose layer
+    """
+    if not ignore_implicit_batch and network.has_implicit_batch_dimension:
+        assert (
+            dim_0 != 0 and dim_1 != 0
+        ), "It's not allowed to call transpose on non-constant when batch dim is implicit!"
+        dim_0 -= 1
+        dim_1 -= 1
+
+    permutation = list(range(len(input_val.shape)))
+    permutation[dim_0] = dim_1
+    permutation[dim_1] = dim_0
+
+    layer = network.add_shuffle(input_val)
+    layer.second_transpose = tuple(permutation)
+    layer.name = name
+    return layer.get_output(0)
+
+
+def process_attr(val, num_elem):
+    if not isinstance(val, tuple):
+        val = (val,) * num_elem
+    return val
+
+
+@tensorrt_converter(acc_ops.conv2d)
+def acc_ops_conv2d(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"Conv2d received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    if has_dynamic_shape(input_val.shape):
+        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
+
+    kernel = to_numpy(kwargs["weight"])
+    bias = to_numpy(kwargs["bias"])
+
+    layer = network.add_convolution(
+        input=input_val,
+        num_output_maps=kernel.shape[0],
+        kernel_shape=kernel.shape[2:],
+        kernel=kernel,
+        bias=bias,
+    )
+
+    layer.name = name
+    layer.stride = kwargs["stride"]
+    layer.padding = kwargs["padding"]
+    layer.dilation = kwargs["dilation"]
+    if kwargs["groups"] is not None:
+        layer.num_groups = kwargs["groups"]
+
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.flatten)
+def acc_ops_flatten(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"flatten received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
+    start_dim = (kwargs["start_dim"] if "start_dim" in kwargs else 0) % num_dims
+    end_dim = (kwargs["end_dim"] if "end_dim" in kwargs else -1) % num_dims
+
+    if network.has_implicit_batch_dimension:
+        assert start_dim != 0, "Can't flatten batch dimension when it's implicit."
+        start_dim -= 1
+        end_dim -= 1
+
+    layer = network.add_shuffle(input_val)
+    layer.name = name
+
+    # If there are dynamic shapes then we need to use shape layers
+    # to figure out the final shape after flatten. We first slice
+    # the input shape into three parts:
+    #   1. dimensions before start_dim
+    #   2. dimensions between start_dim and end_dim
+    #   3. dimensions after end_dim
+    # Parts 1 and 3 might not exist if start_dim is 0 or end_dim is
+    # the last dim. Then we do a reduce multiplication over part 2 to
+    # get the flattened dim. Finally, we concatenate the three parts to
+    # get the final shape.
+    if has_dynamic_shape(input_val.shape):
+        input_shape_layer = network.add_shape(input_val)
+        input_shape_layer.name = f"{name}_orig_shape"
+
+        final_shapes = []
+
+        # Shapes before start_dim
+        if start_dim > 0:
+            prefix_shape_layer = network.add_slice(
+                input_shape_layer.get_output(0),
+                start=(0,),
+                shape=(start_dim,),
+                stride=(1,),
+            )
+            prefix_shape_layer.name = f"{name}_pre_shape"
+            final_shapes.append(prefix_shape_layer.get_output(0))
+
+        flatten_shape_layer = network.add_slice(
+            input_shape_layer.get_output(0),
+            start=(start_dim,),
+            shape=(end_dim - start_dim + 1,),
+            stride=(1,),
+        )
+        flatten_shape_layer.name = f"{name}_need_flatten"
+        flatten_shape_layer = network.add_reduce(
+            flatten_shape_layer.get_output(0),
+            trt.ReduceOperation.PROD,
+            axes=get_axes_for_reduce_op(0, False),
+            keep_dims=True,
+        )
+        flatten_shape_layer.name = f"{name}_flatten_dim"
+        final_shapes.append(flatten_shape_layer.get_output(0))
+
+        # Shapes after end_dim
+        if end_dim < len(input_val.shape) - 1:
+            suffix_shape_layer = network.add_slice(
+                input_shape_layer.get_output(0),
+                start=(end_dim + 1,),
+                shape=(len(input_val.shape) - end_dim - 1,),
+                stride=(1,),
+            )
+            suffix_shape_layer.name = f"{name}_suffix_shape"
+            final_shapes.append(suffix_shape_layer.get_output(0))
+
+        final_shape_layer = network.add_concatenation(final_shapes)
+        final_shape_layer.axis = 0
+        final_shape_layer.name = f"{name}_final_shape"
+        layer.set_input(1, final_shape_layer.get_output(0))
+    else:
+        final_shape = []
+        flatten_dim = 1
+        for i, s in enumerate(input_val.shape):
+            if i >= start_dim and i <= end_dim:
+                flatten_dim *= s
+            elif i == end_dim + 1:
+                final_shape.append(flatten_dim)
+                final_shape.append(s)
+            else:
+                final_shape.append(s)
+        if end_dim == len(input_val.shape) - 1:
+            final_shape.append(flatten_dim)
+
+        layer.reshape_dims = tuple(final_shape)
+
+    return layer.get_output(0)
+
+
+# For implicit batch dim mode, we use this to represent the batch dim if we
+# ever try to retrieve it via size(); we hope it will fail hard if
+# it's used somewhere else.
+IMPLICIT_BATCH_DIM = -999
+
+
+@tensorrt_converter(acc_ops.size)
+def acc_ops_size(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"size received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    if not has_dynamic_shape(input_val.shape):
+        if network.has_implicit_batch_dimension:
+            return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_val.shape))
+        return torch.Size(input_val.shape)
+
+    layer = network.add_shape(input_val)
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.batch_norm)
+def acc_ops_batch_norm(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"BatchNorm2d received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    if has_dynamic_shape(input_val.shape):
+        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
+
+    scale = to_numpy(kwargs["weight"]) / np.sqrt(
+        to_numpy(kwargs["running_var"]) + kwargs["eps"]
+    )
+    bias = (
+        to_numpy(kwargs["bias"])
+        - to_numpy(kwargs["running_mean"]) * scale
+    )
+    power = np.ones_like(scale)
+
+    layer = network.add_scale(input_val, trt.ScaleMode.CHANNEL, bias, scale, power)
+    layer.name = name
+
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.softmax)
+def acc_ops_softmax(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    dim = kwargs["dim"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"softmax received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
+    def get_softmax_dim(ndim):
+        if ndim == 0 or ndim == 1 or ndim == 3:
+            ret = 0
+        else:
+            ret = 1
+        return ret
+
+    if dim is None:
+        dim = get_softmax_dim(
+            len(input_val.shape)
+            if not network.has_implicit_batch_dimension
+            else len(input_val.shape) + 1
+        )
+
+    if network.has_implicit_batch_dimension:
+        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
+        dim = (dim % (len(input_val.shape) + 1)) - 1
+
+    layer = network.add_softmax(input_val)
+    layer.axes = 1 << dim
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.relu)
+def acc_ops_relu(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.RELU
+    return add_activation_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.sin)
+def acc_ops_sin(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.SIN
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.cos)
+def acc_ops_cos(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.COS
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.tan)
+def acc_ops_tan(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.TAN
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.sinh)
+def acc_ops_sinh(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.SINH
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.cosh)
+def acc_ops_cosh(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.COSH
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.tanh)
+def acc_ops_tanh(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.TANH
+    return add_activation_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.asin)
+def acc_ops_asin(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.ASIN
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.acos)
+def acc_ops_acos(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.ACOS
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.atan)
+def acc_ops_atan(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.ATAN
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.exp)
+def acc_ops_exp(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.EXP
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.log)
+def acc_ops_log(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.LOG
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.sqrt)
+def acc_ops_sqrt(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.SQRT
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.reciprocal)
+def acc_ops_reciprocal(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.RECIP
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.abs)
+def acc_ops_abs(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.ABS
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.neg)
+def acc_ops_neg(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.NEG
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.floor)
+def acc_ops_floor(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.FLOOR
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.ceil)
+def acc_ops_ceil(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.CEIL
+    return add_unary_layer(network, input_val, operation_type, name)
+
+
+@tensorrt_converter(acc_ops.sum)
+def acc_ops_sum(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"sum received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    # If dim is specified, then we are computing a reduced sum over certain dimensions.
+    # Otherwise, we are doing a summation over all elements, which is only supported in
+    # explicit batch dimension mode.
+    if "dim" not in kwargs:
+        assert (
+            not network.has_implicit_batch_dimension
+        ), "Do not support sum all the elements for implicit batch."
+        dim = range(0, len(input_val.shape))
+    else:
+        dim = kwargs["dim"]
+
+    keepdim = False if "keepdim" not in kwargs else kwargs["keepdim"]
+    layer = network.add_reduce(
+        input_val,
+        trt.ReduceOperation.SUM,
+        get_axes_for_reduce_op(dim, network.has_implicit_batch_dimension),
+        keepdim,
+    )
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.max_pool2d)
+def acc_ops_max_pool2d(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"MaxPool2d received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    kernel_size = process_attr(kwargs["kernel_size"], 2)
+    stride = process_attr(kwargs["stride"], 2)
+    padding = process_attr(kwargs["padding"], 2)
+    dilation = process_attr(kwargs["dilation"], 2)
+    ceil_mode = kwargs["ceil_mode"]
+
+    if dilation != (1, 1):
+        raise RuntimeError(
+            f"Only support dilation=(1, 1) for maxpool, but got {dilation}"
+        )
+
+    layer = network.add_pooling(
+        input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size
+    )
+    layer.stride = stride
+    layer.padding = padding
+    layer.name = name
+
+    if ceil_mode:
+        layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP
+
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.squeeze)
+def acc_ops_squeeze(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"squeeze received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    dim = kwargs["dim"] if "dim" in kwargs else None
+    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
+    # dim, which is a very rare case. For now we simply don't support dim=None.
+    assert dim is not None, "We don't support dim=None right now."
+
+    if network.has_implicit_batch_dimension:
+        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
+        dim -= 1
+
+    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
+    assert (
+        len(get_dynamic_dims(input_val.shape)) <= 1
+    ), "Currently more than one dynamic dim for input to squeeze is not supported."
+
+    output_shape = []
+    for i, s in enumerate(input_val.shape):
+        if i == dim and s == 1:
+            continue
+        output_shape.append(s)
+    layer = network.add_shuffle(input_val)
+    layer.reshape_dims = tuple(output_shape)
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.add)
+def acc_ops_add(network, target, args, kwargs, name):
+    return add_binary_elementwise_layer(
+        network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.SUM, name
+    )
+
+
+@tensorrt_converter(acc_ops.sub)
+def acc_ops_sub(network, target, args, kwargs, name):
+    return add_binary_elementwise_layer(
+        network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.SUB, name
+    )
+
+
+@tensorrt_converter(acc_ops.div)
+def acc_ops_div(network, target, args, kwargs, name):
+    return add_binary_elementwise_layer(
+        network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.DIV, name
+    )
+
+
+@tensorrt_converter(acc_ops.mul)
+def acc_ops_mul(network, target, args, kwargs, name):
+    return add_binary_elementwise_layer(
+        network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.PROD, name
+    )
+
+
+@tensorrt_converter(acc_ops.min_two_tensors_input)
+def acc_ops_min_two_tensors_input(network, target, args, kwargs, name):
+    return add_binary_elementwise_layer(
+        network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.MIN, name
+    )
+
+
+@tensorrt_converter(acc_ops.adaptive_avg_pool2d)
+def acc_ops_adaptive_avg_pool2d(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"AdaptiveAvgPool2d received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    assert (
+        input_val.shape[-1] != -1 and input_val.shape[-2] != -1
+    ), "AdaptiveAvgPool2d currently doesn't support dynamic shapes for last two dims."
+    output_size = kwargs["output_size"]
+
+    for input_dim, output_dim in zip(input_val.shape[-2:], output_size):
+        if input_dim % output_dim != 0:
+            raise RuntimeError(
+                "For AdaptiveAvgPool, input dim has to be integer multiple of output dim."
+                f"Got input dim {input_dim}, output dim {output_dim}"
+            )
+
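+    # Since the output size evenly divides the input size, adaptive average
+    # pooling reduces to a regular average pool whose stride and window size
+    # are derived from the input/output ratio.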
+    stride = (
+        input_val.shape[-2] // output_size[0],
+        input_val.shape[-1] // output_size[1],
+    )
+    kernel_size = (
+        input_val.shape[-2] - (output_size[0] - 1) * stride[0],
+        input_val.shape[-1] - (output_size[1] - 1) * stride[1],
+    )
+    layer = network.add_pooling(
+        input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
+    )
+    layer.stride = stride
+    layer.name = name
+
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.avg_pool2d)
+def acc_ops_avg_pool2d(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"AvgPool2d received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    kernel_size = process_attr(kwargs["kernel_size"], 2)
+    stride = process_attr(kwargs["stride"], 2)
+    padding = process_attr(kwargs["padding"], 2)
+    ceil_mode = kwargs["ceil_mode"]
+    count_include_pad = kwargs["count_include_pad"]
+    divisor_override = kwargs["divisor_override"]
+
+    if divisor_override:
+        raise RuntimeError("TensorRT does not support divisor_override.")
+
+    layer = network.add_pooling(
+        input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
+    )
+    layer.stride = stride
+    layer.padding = padding
+    layer.average_count_excludes_padding = not count_include_pad
+
+    if ceil_mode:
+        layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP
+
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.reshape)
+def acc_ops_reshape(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"Reshape received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    shape = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "shape")
+    if network.has_implicit_batch_dimension:
+        shape = shape[1:]
+
+    layer = network.add_shuffle(input_val)
+
+    if all(isinstance(s, int) for s in shape):
+        layer.reshape_dims = tuple(shape)
+    else:
+        # Convert all the dimensions to trt Tensors.
+        trt_shape = []
+
+        for i, s in enumerate(shape):
+            if isinstance(s, trt.tensorrt.ITensor):
+                if len(s.shape) == 0:
+                    s = append_ones(network, s, f"{name}_{i}", 1)
+                trt_shape.append(s)
+            else:
+                trt_shape.append(get_trt_tensor(network, s, f"{name}_{i}"))
+
+        shape_layer = network.add_concatenation(inputs=trt_shape)
+        shape_layer.axis = 0
+        shape_layer.name = f"{name}_output_shape"
+        layer.set_input(1, shape_layer.get_output(0))
+
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.linear)
+def acc_ops_linear(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"Linear received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    dynamic_dims = get_dynamic_dims(input_val.shape)
+    assert len(dynamic_dims) < 2 and input_val.shape[-1] != -1, (
+        "Currently we only support one dynamic "
+        "dim for linear and it can't be the last dim."
+    )
+
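+    # TensorRT's fully connected layer operates on the last three dimensions,
+    # so append two unit dims before the FC layer and reshape back afterwards.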
+    layer = network.add_shuffle(input_val)
+    layer.reshape_dims = tuple(input_val.shape) + (1, 1)
+    layer.name = f"{name}_pre_shuffle"
+
+    # add fully connected
+    layer = network.add_fully_connected(
+        input=layer.get_output(0),
+        num_outputs=kwargs["weight"].shape[0],
+        kernel=to_numpy(kwargs["weight"]),
+        bias=to_numpy(kwargs["bias"]),
+    )
+    layer.name = f"{name}_linear"
+
+    # reshape back
+    layer = network.add_shuffle(layer.get_output(0))
+    layer.reshape_dims = tuple(input_val.shape[:-1]) + (kwargs["weight"].shape[0],)
+    layer.name = f"{name}_post_shuffle"
+
+    return layer.get_output(0)
+
+
+def add_clamp(network, input, val, op):
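+    # Emulate clamp with an elementwise MAX (for min) or MIN (for max) against
+    # a constant of all-ones shape filled with `val`, which broadcasts over the input.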
+    acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
+    acc_ops_clamp_tensor = (
+        val
+        * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
+        .cpu()
+        .numpy()
+    )
+    acc_ops_clamp_trt = network.add_constant(acc_ops_clamp_shape, acc_ops_clamp_tensor)
+    layer = network.add_elementwise(input, acc_ops_clamp_trt.get_output(0), op)
+
+    return layer
+
+
+@tensorrt_converter(acc_ops.clamp)
+def acc_ops_clamp(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    min_val = kwargs["min"]
+    max_val = kwargs["max"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"Clamp received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    if min_val is not None:
+        clamp_min_layer = add_clamp(
+            network, input_val, min_val, trt.ElementWiseOperation.MAX
+        )
+        clamp_min_layer.name = f"{name}_clamp_min"
+        input_val = clamp_min_layer.get_output(0)
+    if max_val is not None:
+        clamp_max_layer = add_clamp(
+            network, input_val, max_val, trt.ElementWiseOperation.MIN
+        )
+        clamp_max_layer.name = f"{name}_clamp_max"
+        input_val = clamp_max_layer.get_output(0)
+
+    return input_val
+
+
+@tensorrt_converter(acc_ops.getitem)
+def acc_ops_getitem(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    slices = kwargs["idx"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        return operator.getitem(input_val, slices)
+
+    assert not has_dynamic_shape(
+        input_val.shape
+    ), "Currently we don't support slicing tensor if it has dynamic shape."
+
+    def num_slice_types(slices):
+        """
+        Count the number of slice and int entries in getitem slices.
+        """
+        num_slice = 0
+        for s in slices:
+            if isinstance(s, slice) or isinstance(s, int):
+                num_slice += 1
+        return num_slice
+
+    def slice_to_trt_params(py_slice, dim_size):
+        """
+        Convert python slice to TensorRT slice layer parameters.
+        """
+        start = (py_slice.start % dim_size) if py_slice.start else 0
+        stride = py_slice.step if py_slice.step else 1
+        stop = (py_slice.stop % dim_size) if py_slice.stop else dim_size
+        size = math.ceil((stop - start) * 1.0 / stride)
+        return start, size, stride
+
+    if not isinstance(slices, tuple):
+        slices = (slices,)
+
+    if network.has_implicit_batch_dimension:
+        # Raise an error if it's trying to subscript batch dimension unless it's
+        # slice(None, None, None).
+        batch_subscript = slices[0]
+        if batch_subscript != slice(None, None, None):
+            raise RuntimeError(
+                f"Can't subscript batch dimension when it's implicit. Got {slices}"
+            )
+
+        # Remove batch_dim subscript
+        slices = slices[1:]
+
+    # Replace ellipsis with expanded slices.
+    # Compute the number of dim ellipsis represent.
+    num_ellipsis = len(input_val.shape) - num_slice_types(slices)
+    new_slices = []
+    for s in slices:
+        if s == Ellipsis:
+            while num_ellipsis > 0:
+                new_slices.append(slice(None, None, None))
+                num_ellipsis -= 1
+        else:
+            new_slices.append(s)
+    slices = new_slices
+
+    # Build trt slice layer params
+    start = []
+    size = []
+    stride = []
+
+    i = 0
+    for s in slices:
+        if s is None:
+            continue
+
+        if isinstance(s, slice):
+            params = slice_to_trt_params(s, input_val.shape[i])
+            start.append(params[0])
+            size.append(params[1])
+            stride.append(params[2])
+        else:
+            start.append(s % input_val.shape[i])
+            size.append(1)
+            stride.append(1)
+        i += 1
+
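+    # Any trailing dims not covered by the subscript are kept whole.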
+    while i < len(input_val.shape):
+        start.append(0)
+        size.append(input_val.shape[i])
+        stride.append(1)
+        i += 1
+
+    layer = network.add_slice(
+        input=input_val,
+        start=start,
+        shape=size,
+        stride=stride,
+    )
+    layer.name = name
+
+    # Add shuffle layer to insert dimensions for 'None' and remove dimensions for 'int'.
+    if any(not isinstance(s, slice) for s in slices):
+        slice_out = layer.get_output(0)
+        layer = network.add_shuffle(slice_out)
+        final_shape = []
+        original_idx = 0
+        for s in slices:
+            # If it's a slice, keep the dim.
+            if isinstance(s, slice):
+                final_shape.append(slice_out.shape[original_idx])
+                original_idx += 1
+            # If it's None, extend the dim.
+            elif s is None:
+                final_shape.append(1)
+            # If it's a int, remove the dim.
+            else:
+                original_idx += 1
+        layer.reshape_dims = tuple(final_shape) + tuple(slice_out.shape)[original_idx:]
+
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.cat)
+def acc_ops_cat(network, target, args, kwargs, name):
+    tensors = kwargs["tensors"]
+
+    if any(not isinstance(t, trt.tensorrt.ITensor) for t in tensors):
+        raise RuntimeError(
+            f"cat received inputs {tensors} that is not part " "of the TensorRT region!"
+        )
+
+    layer = network.add_concatenation(inputs=tensors)
+    layer.axis = kwargs["dim"] - (1 if network.has_implicit_batch_dimension else 0)
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.transpose)
+def acc_ops_transpose(network, target, args, kwargs, name):
+    input_val, dim_0, dim_1 = kwargs["input"], kwargs["dim0"], kwargs["dim1"]
+
+    # TODO: Remove this after enabling const folding in fx_acc
+    if isinstance(input_val, torch.Tensor):
+        return input_val.transpose(dim_0, dim_1).contiguous()
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"transpose received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    return add_transpose_layer(network, input_val, dim_0, dim_1, name)
+
+
+@tensorrt_converter(acc_ops.matmul)
+def acc_ops_matmul(network, target, args, kwargs, name):
+    input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
+    other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other")
+
+    for i in [input_val, other_val]:
+        if not isinstance(i, trt.tensorrt.ITensor):
+            raise RuntimeError(
+                f"matmul received input {i} that is not part " "of the TensorRT region!"
+            )
+
+    input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
+    preset_diff = 0
+
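+    # 1-D operands are treated as vectors by TensorRT's matrix multiply, which
+    # implicitly drops a dim; adjust `preset_diff` so broadcast() prepends the
+    # right number of 1s to the other operand.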
+    if len(input_val.shape) == 1:
+        preset_diff -= 1
+        input_matrix_op = trt.MatrixOperation.VECTOR
+
+    if len(other_val.shape) == 1:
+        preset_diff += 1
+        other_matrix_op = trt.MatrixOperation.VECTOR
+
+    input_val, other_val = broadcast(
+        network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff
+    )
+    layer = network.add_matrix_multiply(
+        input_val, input_matrix_op, other_val, other_matrix_op
+    )
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.sigmoid)
+def acc_ops_sigmoid(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"Sigmoid received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    layer = network.add_activation(input=input_val, type=trt.ActivationType.SIGMOID)
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(acc_ops.permute)
+def acc_ops_permute(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+    permutation = kwargs["permutation"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"permute received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    if network.has_implicit_batch_dimension:
+        assert permutation[0] == 0, "Can't permute batch dimension when it's implicit."
+        permutation = [i - 1 for i in permutation[1:]]
+
+    layer = network.add_shuffle(input_val)
+    layer.second_transpose = tuple(permutation)
+    layer.name = name
+    return layer.get_output(0)
index 22ad9dc..f8a201b 100644 (file)
@@ -8,14 +8,17 @@ from .helper_functions import get_dyn_range, mark_as_int8_layer
 @tensorrt_converter(operator.add)
 @tensorrt_converter(torch.add)
 def add(network, target, args, kwargs, layer_name):
-    if len(kwargs) != 0:
-        raise RuntimeError(f"Add receives unsupported kwargs: {kwargs}!")
-
-    assert len(args) == 2
-    if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in args):
+    # operator.add
+    if len(kwargs) == 0:
+        lhs_val, rhs_val = args
+    else:
+        # torch.add
+        lhs_val, rhs_val = kwargs["input"], kwargs["other"]
+        assert kwargs["alpha"] == 1
+
+    if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in [lhs_val, rhs_val]):
         raise RuntimeError("add() received an input that is not part of the TensorRT region!")
 
-    lhs_val, rhs_val = args
     layer = network.add_elementwise(lhs_val, rhs_val, trt.ElementWiseOperation.SUM)
     layer.name = layer_name
 
index 169bcf2..bb1838a 100644 (file)
@@ -8,11 +8,16 @@ from .helper_functions import get_dyn_range, mark_as_int8_layer
 @tensorrt_converter(torch.mul)
 @tensorrt_converter(operator.mul)
 def mul(network, target, args, kwargs, layer_name):
-    assert len(args) == 2
-    if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in args):
+    # operator.mul
+    if len(kwargs) == 0:
+        lhs_val, rhs_val = args
+    else:
+        # torch.mul
+        lhs_val, rhs_val = kwargs["input"], kwargs["other"]
+
+    if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in [lhs_val, rhs_val]):
         raise RuntimeError('mul() received an input that is not part of the TensorRT region!')
 
-    lhs_val, rhs_val = args
     layer = network.add_elementwise(lhs_val, rhs_val, trt.ElementWiseOperation.PROD)
     layer.name = layer_name
 
index 38059cd..6054586 100644 (file)
@@ -1,16 +1,14 @@
-import copy
 import warnings
 from typing import List, NamedTuple, Iterable, Any, Optional, Tuple
 
+import tensorrt as trt
 import torch
 import torch.fx
-import tensorrt as trt
-from torch.fx.experimental.normalize import NormalizeArgs
 
 
 # Borrowed from torch2trt
 def torch_dtype_to_trt(dtype):
-    if trt.__version__ >= '7.0' and dtype == torch.bool:
+    if trt.__version__ >= "7.0" and dtype == torch.bool:
         return trt.bool
     elif dtype == torch.int8:
         return trt.int8
@@ -27,7 +25,7 @@ def torch_dtype_to_trt(dtype):
 def torch_dtype_from_trt(dtype):
     if dtype == trt.int8:
         return torch.int8
-    elif trt.__version__ >= '7.0' and dtype == trt.bool:
+    elif trt.__version__ >= "7.0" and dtype == trt.bool:
         return torch.bool
     elif dtype == trt.int32:
         return torch.int32
@@ -40,7 +38,9 @@ def torch_dtype_from_trt(dtype):
 
 
 class TRTModule(torch.nn.Module):
-    def __init__(self, engine=None, input_names=None, output_names=None, fp16_output=False):
+    def __init__(
+        self, engine=None, input_names=None, output_names=None, fp16_output=False
+    ):
         super(TRTModule, self).__init__()
         self._register_state_dict_hook(TRTModule._on_state_dict)
         self.engine = engine
@@ -78,7 +78,9 @@ class TRTModule(torch.nn.Module):
         self.output_names = state_dict[prefix + "output_names"]
 
     def forward(self, *inputs):
-        assert len(inputs) == len(self.input_names), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
+        assert len(inputs) == len(
+            self.input_names
+        ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
         batch_size = inputs[0].shape[0]
         contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
         bindings: List[Any] = [None] * (len(self.input_names) + len(self.output_names))
@@ -119,8 +121,10 @@ class TRTModule(torch.nn.Module):
         return tuple(outputs)
 
     def enable_profiling(self):
-        raise RuntimeError("Profiling is not supported right now because it requires calling"
-                           " execute() instead of execute_async().")
+        raise RuntimeError(
+            "Profiling is not supported right now because it requires calling"
+            " execute() instead of execute_async()."
+        )
         if not self.context.profiler:
             self.context.profiler = trt.Profiler()
 
@@ -132,6 +136,7 @@ def tensorrt_converter(key):
     def register_converter(converter):
         CONVERTERS[key] = converter
         return converter
+
     return register_converter
 
 
@@ -157,11 +162,12 @@ class InputTensorSpec(NamedTuple):
     has_batch_dim: Whether the shape includes batch dimension. Batch dimension has to be provided
         if the engine want to run with dynamic shape.
     """
-    shape : torch.Size
-    dtype : torch.dtype
-    device : torch.device = torch.device("cpu")
-    shape_ranges : List[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]] = []
-    has_batch_dim : bool = True
+
+    shape: torch.Size
+    dtype: torch.dtype
+    device: torch.device = torch.device("cpu")
+    shape_ranges: List[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]] = []
+    has_batch_dim: bool = True
 
     @classmethod
     def from_tensor(cls, tensor: torch.Tensor):
@@ -182,13 +188,27 @@ def get_dynamic_dims(shape):
     return dynamic_dims
 
 
-class BaseTRTInterpreter(torch.fx.Interpreter):
+def create_inputs_from_specs(input_specs):
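+    # Build example input tensors from the specs: for dynamic shapes use the
+    # second (opt) shape from the first shape_range entry, and prepend a batch
+    # dim of 1 when the spec doesn't include one.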
+    inputs = []
+
+    for shape, dtype, device, shape_ranges, has_batch_dim in input_specs:
+        if len(get_dynamic_dims(shape)):
+            shape = shape_ranges[0][1]
+        elif not has_batch_dim:
+            shape = (1,) + tuple(shape)
+
+        inputs.append(torch.empty(shape, dtype=dtype, device=device))
+
+    return inputs
+
+
+class TRTInterpreter(torch.fx.Interpreter):
     def __init__(
         self,
-        module : torch.fx.GraphModule,
-        input_specs : List[InputTensorSpec],
-        explicit_batch_dimension : bool = False,
-        logger_level=trt.Logger.WARNING
+        module: torch.fx.GraphModule,
+        input_specs: List[InputTensorSpec],
+        explicit_batch_dimension: bool = False,
+        logger_level=trt.Logger.WARNING,
     ):
         super().__init__(module)
 
@@ -196,12 +216,14 @@ class BaseTRTInterpreter(torch.fx.Interpreter):
         self.builder = trt.Builder(self.logger)
 
         if explicit_batch_dimension:
-            EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+            EXPLICIT_BATCH = 1 << (int)(
+                trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH
+            )
             self.network = self.builder.create_network(EXPLICIT_BATCH)
         else:
             self.network = self.builder.create_network()
 
-        self.optimization_profiles : Optional[List] = None
+        self.optimization_profiles: Optional[List] = None
         self.input_specs = input_specs
         self.input_specs_iter = 0
         self.validate_input_specs()
@@ -212,35 +234,59 @@ class BaseTRTInterpreter(torch.fx.Interpreter):
     def validate_input_specs(self):
         for shape, dtpe, _, shape_ranges, has_batch_dim in self.input_specs:
             if not self.network.has_implicit_batch_dimension:
-                assert has_batch_dim, "It's required to specify batch dimension when it's explicit in TensorRT network."
+                assert (
+                    has_batch_dim
+                ), "It's required to specify batch dimension when it's explicit in TensorRT network."
 
             dynamic_dims = get_dynamic_dims(shape)
             if len(dynamic_dims):
-                assert not self.network.has_implicit_batch_dimension, "Can't have dynamic dim when " \
+                assert not self.network.has_implicit_batch_dimension, (
+                    "Can't have dynamic dim when "
                     f"batch dim is implicit, got {shape}."
-                assert len(shape_ranges), "shape_ranges must be provided when shape has dynamic dim."
+                )
+                assert len(
+                    shape_ranges
+                ), "shape_ranges must be provided when shape has dynamic dim."
 
                 if self.optimization_profiles:
-                    assert len(shape_ranges) == len(self.optimization_profiles), "Number of optimization " \
-                        f"profiles {len(self.optimization_profiles)} doesn't match with the number of shape_range" \
+                    assert len(shape_ranges) == len(self.optimization_profiles), (
+                        "Number of optimization "
+                        f"profiles {len(self.optimization_profiles)} doesn't match with the number of shape_range"
                         f" {len(shape_ranges)} provided."
+                    )
                 else:
-                    self.optimization_profiles = [self.builder.create_optimization_profile() for _ in range(len(shape_ranges))]
+                    self.optimization_profiles = [
+                        self.builder.create_optimization_profile()
+                        for _ in range(len(shape_ranges))
+                    ]
 
                 for shape_range in shape_ranges:
-                    assert len(shape_range) == 3, f"Expect three elements in shape_range, got {len(shape_range)}"
-                    assert all(len(s) == len(shape) for s in shape_range), "Expect elements in shape_range" \
+                    assert (
+                        len(shape_range) == 3
+                    ), f"Expect three elements in shape_range, got {len(shape_range)}"
+                    assert all(len(s) == len(shape) for s in shape_range), (
+                        "Expect elements in shape_range"
                         f" {shape_range} have the same number of dimension as the provided shape {len(shape)}"
+                    )
 
                     for i in range(len(shape)):
                         if i in dynamic_dims:
-                            assert all(shape_range[j][i] <= shape_range[j + 1][i] for j in range(2)), "Expect dynamic dim" \
+                            assert all(
+                                shape_range[j][i] <= shape_range[j + 1][i]
+                                for j in range(2)
+                            ), (
+                                "Expect dynamic dim"
                                 f" {i} to have incremental value for shapes in shape_range {shape_range}."
+                            )
                         else:
-                            assert all(s[i] == shape[i] for s in shape_range), f"Expect non dynamic dim {i} to be the same" \
+                            assert all(s[i] == shape[i] for s in shape_range), (
+                                f"Expect non dynamic dim {i} to be the same"
                                 f" for all shapes in shape_range {shape_range}."
+                            )
             else:
-                assert len(shape_ranges) == 0, "shape_ranges are provided for input that doesn't have dynamic dim."
+                assert (
+                    len(shape_ranges) == 0
+                ), "shape_ranges are provided for input that doesn't have dynamic dim."
 
     def run(
         self,
@@ -248,7 +294,7 @@ class BaseTRTInterpreter(torch.fx.Interpreter):
         max_workspace_size=1 << 25,
         fp16_mode=True,
         int8_mode=False,
-        strict_type_constraints=True
+        strict_type_constraints=True,
     ):
         # TODO hack, should check contents of args and remove fp16_mode probably
         self.fp16_mode = fp16_mode
@@ -279,7 +325,7 @@ class BaseTRTInterpreter(torch.fx.Interpreter):
                 builder_config.add_optimization_profile(optimization_profile)
 
         engine = self.builder.build_engine(self.network, builder_config)
-        assert(engine)
+        assert engine
         return engine, self._input_names, self._output_names
 
     def run_node(self, n):
@@ -288,7 +334,9 @@ class BaseTRTInterpreter(torch.fx.Interpreter):
 
     def placeholder(self, target, args, kwargs):
         self._input_names.append(target)
-        shape, dtype, _, shape_ranges, has_batch_dim = self.input_specs[self.input_specs_iter]
+        shape, dtype, _, shape_ranges, has_batch_dim = self.input_specs[
+            self.input_specs_iter
+        ]
         self.input_specs_iter += 1
 
         if self.network.has_implicit_batch_dimension:
@@ -299,7 +347,9 @@ class BaseTRTInterpreter(torch.fx.Interpreter):
                 assert self.optimization_profiles
                 self.optimization_profiles[i].set_shape(target, *shape_range)
 
-        return self.network.add_input(name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype))
+        return self.network.add_input(
+            name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
+        )
 
     def call_module(self, target, args, kwargs):
         assert isinstance(target, str)
@@ -307,7 +357,9 @@ class BaseTRTInterpreter(torch.fx.Interpreter):
         converter = CONVERTERS.get(type(submod))
 
         if not converter:
-            raise RuntimeError(f'Conversion of module of type {type(submod)} not currently supported!')
+            raise RuntimeError(
+                f"Conversion of module of type {type(submod)} not currently supported!"
+            )
 
         return converter(self.network, submod, args, kwargs, self._cur_node_name)
 
@@ -315,7 +367,9 @@ class BaseTRTInterpreter(torch.fx.Interpreter):
         converter = CONVERTERS.get(target)
 
         if not converter:
-            raise RuntimeError(f'Conversion of function {torch.typename(target)} not currently supported!')
+            raise RuntimeError(
+                f"Conversion of function {torch.typename(target)} not currently supported!"
+            )
 
         return converter(self.network, target, args, kwargs, self._cur_node_name)
 
@@ -324,7 +378,9 @@ class BaseTRTInterpreter(torch.fx.Interpreter):
         converter = CONVERTERS.get(target)
 
         if not converter:
-            raise RuntimeError(f'Conversion of method {target} not currently supported!')
+            raise RuntimeError(
+                f"Conversion of method {target} not currently supported!"
+            )
 
         return converter(self.network, target, args, kwargs, self._cur_node_name)
 
@@ -333,10 +389,10 @@ class BaseTRTInterpreter(torch.fx.Interpreter):
         outputs = args[0] if isinstance(args[0], tuple) else (args[0],)
 
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
-            raise RuntimeError('TensorRT requires all outputs to be Tensor!')
+            raise RuntimeError("TensorRT requires all outputs to be Tensor!")
 
         for i, output in enumerate(outputs):
-            name = f'output{i}'
+            name = f"output{i}"
             output.name = name
             self.network.mark_output(output)
             if self.fp16_mode:
@@ -344,18 +400,3 @@ class BaseTRTInterpreter(torch.fx.Interpreter):
             else:
                 output.dtype = trt.float32
             self._output_names.append(name)
-
-
-class TRTInterpreter(BaseTRTInterpreter):
-    """
-    Use this for general case where there're PyTorch vanilla ops in the FX mdoule.
-    """
-    def __init__(self, module : torch.nn.Module, input_specs : List[InputTensorSpec], logger_level=trt.Logger.WARNING):
-        # Preprocess the model
-        if not isinstance(module, torch.fx.GraphModule):
-            module = torch.fx.symbolic_trace(module)
-        else:
-            module = copy.deepcopy(module)
-        module = module.cpu().float()
-        module = NormalizeArgs(module).transform()
-        super().__init__(module, input_specs, logger_level=logger_level)
diff --git a/torch/fx/experimental/fx_acc/__init__.py b/torch/fx/experimental/fx_acc/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/torch/fx/experimental/fx_acc/acc_normalizer.py b/torch/fx/experimental/fx_acc/acc_normalizer.py
new file mode 100644 (file)
index 0000000..a5d116a
--- /dev/null
@@ -0,0 +1,401 @@
+# type: ignore[]
+import inspect
+import re
+from typing import NamedTuple, Optional, Callable, Dict, List, Tuple, Union, Any, Set
+
+import torch.fx.experimental.fx_acc.acc_utils as acc_utils
+import torch
+import torch.fx
+from torch.fx.node import _get_qualified_name
+
+# Need to keep up-to-date with https://fburl.com/codesearch/7r2hhh53
+ALIAS_MAP = {
+    "input": ("input", "x", "a", "x1"),
+    "dim": ("dim", "axis"),
+    "keepdim": ("keepdim", "keepdims"),
+    "other": ("other", "x2"),
+}
+
+# Type used for arg replacement tuples. The list represents the argument signature of
+# some callable. Each item in the list is a tuple whose members are:
+# - The first member is either:
+#   - A tuple of all potential alias kwarg str names of the source signature, or
+#   - A tuple of a single str representing the single kwarg name allowed.
+# - The second member is the str name of the kwarg to map it to. This is either from the
+#   signature of the acc_op, or for custom mapped nodes from the original unnormalized op.
+# - The third member is a bool representing whether this arg is optional, i.e. whether it
+#   is allowed to not be present in the original input args.
+ArgReplacementTuplesType = List[Tuple[Tuple[str, ...], str, bool]]
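+# For illustration only, a signature like torch.sum(input, dim, keepdim=False) could
+# be described as:
+#   [(("input",), "input", False), (("dim",), "dim", True), (("keepdim",), "keepdim", True)]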
+
+
+class NormalizationInfo(NamedTuple):
+    """
+    Holds normalization info for some FX node, where the FX node will be mapped either
+    via new_fn_target and arg_replacement_tuples, or via custom_mapping_fn.
+
+    If via new_fn_target and arg_replacement_tuples:
+      - new_fn_target is the target function to replace the original node with
+        (generally some function from acc_ops).
+
+      - arg_replacement_tuples describes how to map the original FX node's args/kwargs to
+        the new FX node. If set to None, then the kwargs are copied directly from the
+        original FX node. Else, this is list of three-member tuples, where each tuple
+        represents a mapping from either an arg or kwarg in the original FX node to the
+        kwarg it should be mapped to. For ops registered with `register_acc_op`, this
+        is a mapping to the new FX node for the acc_op. Otherwise it is for some
+        op registered with `register_custom_acc_mapper_fn`, in which case this is a
+        mapping for the original input node so its args are normalized to kwargs before
+        being custom normalized to acc_ops. The third member of the tuple is a bool
+        representing whether this argument is optional; if False and the arg is not
+        present then an assertion will be thrown. The index of the tuple indicates where
+        the original arg is in node.args and the string name indicates which original
+        kwarg it is.
+
+    If via custom_mapping_fn, then custom_mapping_fn is some function that takes the
+    original FX node as input and returns the FX node that should replace it. This means
+    it was registered via `register_custom_acc_mapper_fn`.
+    """
+
+    new_fn_target: Callable
+    arg_replacement_tuples: Optional[ArgReplacementTuplesType]
+    custom_mapping_fn: Optional[Callable]
+    kwargs_to_move_to_acc_out_ty: Optional[List[Tuple[str, str]]]
+    needs_shapes_for_normalization: bool
+
+
+# Dict from (op, target) to NormalizationInfo for that op.
+_normalization_dict: Dict[Tuple[str, Union[str, Callable]], NormalizationInfo] = {}
+
+# Set of all the acc ops.
+_acc_ops: Set[Callable] = set()
+
+
+def _insert_fun(
+    op_and_target: Tuple[str, Union[str, Callable]],
+    arg_replacement_tuples: List[Tuple],
+    new_fn_target: Optional[Callable] = None,
+    custom_mapping_fn: Optional[Callable] = None,
+    kwargs_to_move_to_acc_out_ty: Optional[List[Tuple[str, str]]] = None,
+    needs_shapes_for_normalization=False,
+    allow_normalize_from_torch_package=False,
+):
+    if op_and_target[0] == "call_function":
+        assert callable(op_and_target[1])
+    elif op_and_target[0] == "call_method":
+        assert isinstance(op_and_target[1], str)
+    elif op_and_target[0] == "call_module":
+        assert isinstance(op_and_target[1], type)
+
+    # Finalize arg replacement tuples.
+    # 1. Check to see if they have the `is_optional` bool, and if not, default it to
+    #   False.
+    # 2. Some kwargs might have aliases. e.g. "a", "x" and "x1" are aliases of "input".
+    #   Here we replace `orig_kwarg` with a tuple of all aliases if it has aliases.
+    final_arg_replacement_tuples = []
+    for arg_replacement_tuple in arg_replacement_tuples:
+        if len(arg_replacement_tuple) == 2:
+            orig_kwarg, new_kwarg, is_optional = *arg_replacement_tuple, False
+        else:
+            assert len(arg_replacement_tuple) == 3
+            orig_kwarg, new_kwarg, is_optional = arg_replacement_tuple
+
+        if not isinstance(orig_kwarg, tuple):
+            orig_kwarg = (orig_kwarg,)
+
+        # Use set to avoid duplicates.
+        orig_kwarg_set = set(orig_kwarg)
+
+        for k in orig_kwarg:
+            if k in ALIAS_MAP:
+                orig_kwarg_set.update(ALIAS_MAP[k])
+        final_arg_replacement_tuples.append(
+            (tuple(orig_kwarg_set), new_kwarg, is_optional)
+        )
+
+    assert op_and_target not in _normalization_dict.keys()
+    norm_info = NormalizationInfo(
+        new_fn_target=new_fn_target,
+        arg_replacement_tuples=final_arg_replacement_tuples,
+        custom_mapping_fn=custom_mapping_fn,
+        kwargs_to_move_to_acc_out_ty=kwargs_to_move_to_acc_out_ty,
+        needs_shapes_for_normalization=needs_shapes_for_normalization,
+    )
+    _normalization_dict[op_and_target] = norm_info
+
+    # If allow_normalize_from_torch_package then add another entry to
+    # _normalization_dict where we look for the qualified name of the target with the
+    # torch_package module prefix. Note that we leave off any integer at the end of
+    # "<torch_package_>" in order to allow for whatever mangling index is used.
+    if allow_normalize_from_torch_package:
+        torch_package_op_and_target = (
+            op_and_target[0],
+            f"<torch_package_>.{_get_qualified_name(op_and_target[1])}",
+        )
+        _normalization_dict[torch_package_op_and_target] = norm_info
+
+
+def _get_dup_signature_tuples(fn: Callable) -> List[Tuple[str, str]]:
+    """
+    Helper that inspects the arg signature of `fn` and returns a list of tuples, where
+    each tuple is a pair of duplicated names which is used for arg_replacement_tuples.
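+
+    For example, for an acc_op with signature `flatten(*, input, start_dim=0, end_dim=-1)`
+    (defined in acc_ops.py in this diff), this returns
+    [("input", "input"), ("start_dim", "start_dim"), ("end_dim", "end_dim")].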
+    """
+    sig_tuples: List[Tuple[str, str]] = []
+    for param in inspect.signature(inspect.unwrap(fn)).parameters:
+        sig_tuples.append((param, param))
+    return sig_tuples
+
+
+def register_acc_op(acc_op: Callable):
+    """
+    For a new acc op, add this as decorator to register it.
+    """
+    _acc_ops.add(acc_op)
+    return acc_op
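+
+
+# Typical usage (abridged from acc_ops.py in this diff): stack register_acc_op_mapping
+# (defined below) on top of register_acc_op to register both the acc_op itself and a
+# mapping from a non-acc op to it:
+#
+#     @register_acc_op_mapping(op_and_target=("call_function", operator.add))
+#     @register_acc_op
+#     def add(*, input, other):
+#         return input + other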
+
+
+def register_acc_op_mapping(
+    op_and_target: Tuple[str, Union[str, Callable]],
+    arg_replacement_tuples: Optional[
+        List[Tuple[Union[str, Tuple[str, ...]], str]]
+    ] = None,
+    kwargs_to_move_to_acc_out_ty: Optional[List[Tuple[str, str]]] = None,
+):
+    """
+    Use this decorator to map a non-acc operator to an acc operator.
+
+    Args:
+        op_and_target: A tuple that contains op and target of the node that represents the non-acc operator.
+        arg_replacement_tuples: Please refer to the comment above on `ArgReplacementTuplesType`.
+        kwargs_to_move_to_acc_out_ty: The kwargs we want to move out from the non-acc op kwargs to acc_out_ty.
+    """
+
+    def insert(new_fn_target: Callable):
+        # If arg_replacement_tuples is None then assume we use the same signature for
+        # the acc_op and the original op.
+        if arg_replacement_tuples is None:
+            final_arg_replacement_tuples = _get_dup_signature_tuples(new_fn_target)
+        else:
+            final_arg_replacement_tuples = arg_replacement_tuples
+
+        _insert_fun(
+            op_and_target=op_and_target,
+            new_fn_target=new_fn_target,
+            arg_replacement_tuples=final_arg_replacement_tuples,
+            kwargs_to_move_to_acc_out_ty=kwargs_to_move_to_acc_out_ty,
+        )
+        return new_fn_target
+
+    return insert
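+
+
+# Illustrative usage (abridged from acc_ops.py in this diff): map the Tensor.flatten
+# method to the flatten acc_op, treating start_dim/end_dim as optional:
+#
+#     @register_acc_op_mapping(
+#         op_and_target=("call_method", "flatten"),
+#         arg_replacement_tuples=[
+#             ("input", "input"),
+#             ("start_dim", "start_dim", this_arg_is_optional),
+#             ("end_dim", "end_dim", this_arg_is_optional),
+#         ],
+#     )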
+
+
+def register_custom_acc_mapper_fn(
+    op_and_target: Tuple[str, Union[str, Callable]],
+    arg_replacement_tuples: List[Tuple[Union[str, Tuple[str, ...]], str]],
+    needs_shapes_for_normalization=False,
+    allow_normalize_from_torch_package=False,
+):
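+    """
+    Use this decorator to register a custom mapping function for a non-acc op.
+
+    The decorated function is called with the original node (whose args/kwargs have been
+    normalized to kwargs according to `arg_replacement_tuples`) and the owning module.
+    It should either return a replacement node, or perform the replacement itself and
+    return None.
+    """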
+    def insert(custom_mapping_fn: Callable):
+        _insert_fun(
+            op_and_target=op_and_target,
+            custom_mapping_fn=custom_mapping_fn,
+            arg_replacement_tuples=arg_replacement_tuples,
+            needs_shapes_for_normalization=needs_shapes_for_normalization,
+            allow_normalize_from_torch_package=allow_normalize_from_torch_package,
+        )
+        return custom_mapping_fn
+
+    return insert
+
+
+def move_kwargs_to_acc_out_ty(
+    node_or_normalization_info: Union[NormalizationInfo, torch.fx.Node],
+    new_kwargs: Dict[str, Any],
+):
+    """
+    Given `node_or_normalization_info` which is either NormalizationInfo for a node, or
+    a node to fetch NormalizationInfo for, check if kwargs_to_move_to_acc_out_ty exists
+    in the NormalizationInfo, and if so perform the move of kwargs to acc_out_ty.
+    """
+    if isinstance(node_or_normalization_info, torch.fx.Node):
+        node = node_or_normalization_info
+        normalization_info = _normalization_dict.get((node.op, node.target))
+    else:
+        normalization_info = node_or_normalization_info
+
+    if normalization_info.kwargs_to_move_to_acc_out_ty is None:
+        return
+
+    assert acc_utils.is_acc_op_with_kwarg(
+        normalization_info.new_fn_target, "acc_out_ty"
+    )
+
+    # Build a dict representing the new TensorMetadata to use for acc_out_ty,
+    # and then remove the kwarg from the new_kwargs since it's passed in via
+    # acc_out_ty instead.
+    tmd_dict: Dict[str, Any] = {}
+    for (
+        orig_kwarg_name,
+        tmd_field_name,
+    ) in normalization_info.kwargs_to_move_to_acc_out_ty:
+        tmd_dict[tmd_field_name] = new_kwargs[orig_kwarg_name]
+        del new_kwargs[orig_kwarg_name]
+    # Note: allow_partial_spec here because we are only using the tensor metadata tuple
+    # here to pass specific values into the function. For example, for quantization we
+    # only need to provide dtype/q_scale/q_zero_point, but is_quantized and qscheme are
+    # not passed in.
+    new_kwargs["acc_out_ty"] = acc_utils.build_raw_tensor_meta(**tmd_dict)
+
+
+def get_normalized_kwargs(
+    node: torch.fx.Node, arg_replacement_tuples: ArgReplacementTuplesType
+):
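+    """
+    Map `node`'s args/kwargs to a new kwargs dict keyed by the new names given in
+    `arg_replacement_tuples`. Positional args are matched by index, kwargs are matched
+    by any of their allowed alias names, and a trailing "*" entry collects the remaining
+    positional args into a list.
+    """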
+    new_kwargs = {}
+    final_arg_is_varg = False
+    for i, replacement_tuple in enumerate(arg_replacement_tuples):
+        orig_kwargs_names, new_kwarg_name, is_optional = replacement_tuple
+
+        # Check if this is a varg and if so break/process the rest outside the loop.
+        if len(orig_kwargs_names) == 1 and orig_kwargs_names[0] == "*":
+            assert i == len(arg_replacement_tuples) - 1
+            final_arg_is_varg = True
+            break
+
+        # If nothing is found in node.kwargs, it means the kwarg is in node.args
+        # or it's optional. In this case, we set orig_kwargs_name to None.
+        assert isinstance(orig_kwargs_names, tuple)
+        orig_kwargs_name = next(
+            (key for key in orig_kwargs_names if key in node.kwargs),
+            None,
+        )
+
+        # If can't find in node.kwargs then it should be in the i index
+        # of node.args.
+        if orig_kwargs_name is None:
+            if i < len(node.args):
+                new_kwargs[new_kwarg_name] = node.args[i]
+            else:
+                # Verify the arg we're trying to normalize was optional.
+                assert is_optional
+        else:
+            new_kwargs[new_kwarg_name] = node.kwargs[orig_kwargs_name]
+
+    # If using var args then process the rest of the args now.
+    if final_arg_is_varg:
+        var_arg_idx = len(arg_replacement_tuples) - 1
+        new_kwarg_name = arg_replacement_tuples[var_arg_idx][1]
+        rest_of_args = []
+        for i in range(var_arg_idx, len(node.args)):
+            rest_of_args.append(node.args[i])
+        new_kwargs[new_kwarg_name] = rest_of_args
+
+    return new_kwargs
+
+
+def normalize(mod: torch.fx.GraphModule, expect_nodes_have_shapes: bool = False):
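+    """
+    Replace every node in `mod` that has an entry in the registered normalization dict
+    with the corresponding acc_ops call (or custom-mapped subgraph), in place. If
+    expect_nodes_have_shapes is False, nodes that need shape information for
+    normalization only get their args/kwargs normalized here and are otherwise skipped.
+    """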
+    assert len(_normalization_dict) > 0
+    graph = mod.graph
+
+    # For "call_module" node we return _base_class_origin if it's a
+    # RewrittenModule, otherwise, return its type. For other nodes,
+    # we return node.target.
+    def get_target(mod: torch.fx.GraphModule, node: torch.fx.Node):
+        if node.op != "call_module":
+            return node.target
+
+        # Find the module that node.target points to
+        m = dict(mod.named_modules())[node.target]
+        return getattr(m, "_base_class_origin", type(m))
+
+    def normalize_to_acc_op(
+        node: torch.fx.Node,
+        normalization_info: NormalizationInfo,
+        normalized_args: Tuple[Any, ...],
+        normalized_kwargs: Dict[str, Any],
+    ):
+        # If there's a custom mapping function then use it.
+        if normalization_info.custom_mapping_fn is not None:
+            # For custom mapping, the normalized_kwargs are used for the original op,
+            # i.e. *before* custom acc_ops normalization. Do that now.
+            node.args = normalized_args
+            node.kwargs = normalized_kwargs
+            new_node = normalization_info.custom_mapping_fn(node, mod)
+            # If a new node is returned then use it to replace the old node. Otherwise
+            # the custom mapping function did its own replacement, so return early.
+            if new_node is None:
+                return
+        else:
+            # If there's kwargs_to_move_to_acc_out_ty then use it to setup acc_out_ty in
+            # normalized_kwargs, and remove the kwarg from normalized_kwargs.
+            move_kwargs_to_acc_out_ty(normalization_info, normalized_kwargs)
+
+            # All acc ops are functions. Create a call to the correct acc_ops target using
+            # the normalized kwargs provided.
+            with graph.inserting_before(node):
+                new_node = graph.create_node(
+                    "call_function",
+                    normalization_info.new_fn_target,
+                    args=normalized_args,
+                    kwargs=normalized_kwargs,
+                    name=node.name,
+                )
+                new_node.meta = node.meta.copy()
+
+        # Finally replace the original node with the normalized node.
+        node.replace_all_uses_with(new_node)
+        graph.erase_node(node)
+
+    for node in graph.nodes:
+        if node.op in {"placeholder", "get_attr", "output"}:
+            continue
+
+        normalization_info = _normalization_dict.get((node.op, get_target(mod, node)))
+
+        # Also check if the torch_packaged version of the op was specified to be normalized.
+        if normalization_info is None and node.op == "call_function":
+            # Strip off the mangle_index suffix here before checking the map.
+            target = re.sub(
+                r"\A<torch_package_\d+>",
+                "<torch_package_>",
+                _get_qualified_name(node.target),
+            )
+            torch_package_op_and_target = (node.op, target)
+            normalization_info = _normalization_dict.get(torch_package_op_and_target)
+
+        if normalization_info is None:
+            continue
+
+        # Get the normalized kwargs to be used by normalize_to_acc_op below. If
+        # normalization_info.arg_replacement_tuples is empty then assume the function
+        # signature must be left as is.
+        if len(normalization_info.arg_replacement_tuples) == 0:
+            normalized_args = node.args
+            normalized_kwargs = node.kwargs
+        else:
+            normalized_args = ()
+            normalized_kwargs = get_normalized_kwargs(
+                node, normalization_info.arg_replacement_tuples
+            )
+
+        if (
+            normalization_info.needs_shapes_for_normalization
+            and not expect_nodes_have_shapes
+        ):
+            # All nodes needing shapes for normalization should be custom mapped.
+            assert normalization_info.custom_mapping_fn is not None
+            # For custom mapping, the normalized_kwargs are used for the original op,
+            # i.e. *before* custom acc_ops normalization. Do that now so that whoever
+            # consumes the graph next (e.g. shape inference) can use kwargs safely.
+            node.args = normalized_args
+            node.kwargs = normalized_kwargs
+            continue
+
+        try:
+            normalize_to_acc_op(
+                node, normalization_info, normalized_args, normalized_kwargs
+            )
+        except Exception:
+            print(f"Error during normalization for node: {node.format_node()}")
+            raise
+
+    # If there are any dead nodes left after normalization, eliminate them now.
+    mod.graph.eliminate_dead_code()
diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py
new file mode 100644 (file)
index 0000000..bc4dfb3
--- /dev/null
@@ -0,0 +1,1176 @@
+# encoding: utf-8
+# type: ignore[]
+import operator
+
+import torch  # isort:skip
+import torch.fx.experimental.fx_acc.acc_utils as acc_utils
+import torch.nn as nn
+from torch.fx.experimental.fx_acc.acc_normalizer import (
+    register_acc_op,
+    register_acc_op_mapping,
+    register_custom_acc_mapper_fn,
+)
+from torch.fx.passes.shape_prop import extract_tensor_metadata
+
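+# Readability alias used as the `is_optional` member of arg_replacement_tuples below.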
+this_arg_is_optional = True
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.linear))
+@register_acc_op
+def linear(*, input, weight, bias):
+    return nn.functional.linear(**locals())
+
+
+@register_acc_op
+def quantized_linear(*, input, weight, bias, acc_out_ty=None):
+    assert acc_out_ty is not None
+    return nn.quantized.functional.linear(
+        input,
+        weight,
+        bias,
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
+    )
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_method", "flatten"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("start_dim", "start_dim", this_arg_is_optional),
+        ("end_dim", "end_dim", this_arg_is_optional),
+    ],
+)
+@register_acc_op_mapping(op_and_target=("call_function", torch.flatten))
+@register_acc_op
+def flatten(*, input, start_dim=0, end_dim=-1):
+    return torch.flatten(**locals())
+
+
+@register_acc_op_mapping(
+    op_and_target=(
+        "call_method",
+        "squeeze",
+    ),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim", this_arg_is_optional),
+    ],
+)
+@register_acc_op
+def squeeze(*, input, dim=None):
+    if dim is None:
+        return input.squeeze()
+    return input.squeeze(dim=dim)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.max_pool2d))
+@register_acc_op
+def max_pool2d(
+    *, input, kernel_size, stride, padding, dilation, ceil_mode, return_indices
+):
+    return nn.functional.max_pool2d(**locals())
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", nn.functional.adaptive_avg_pool2d)
+)
+@register_acc_op
+def adaptive_avg_pool2d(*, input, output_size):
+    return nn.functional.adaptive_avg_pool2d(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.avg_pool2d))
+@register_acc_op
+def avg_pool2d(
+    *,
+    input,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad,
+    divisor_override,
+):
+    return nn.functional.avg_pool2d(**locals())
+
+
+@register_acc_op
+def size(*, input):
+    return input.size()
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", getattr),
+    arg_replacement_tuples=[],
+)
+def custom_getattr_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """
+    Custom function for mapping a call_function getattr to other ops. Currently only
+    supports a getattr called on a torch.Tensor with attr name "shape", which is
+    handled by mapping it to acc_ops.size().
+    """
+    # Have to use args here since getattr forces positional args.
+    input_obj = node.args[0]
+    attr_name = node.args[1]
+    assert (
+        input_obj.meta["type"] == torch.Tensor
+    ), f"Expected torch.Tensor type for {input_obj.meta['type']}"
+    assert (
+        attr_name == "shape"
+    ), f"Only supporting shape getattr for now, not {attr_name}"
+    with node.graph.inserting_before(node):
+        size_node = node.graph.call_function(size, kwargs={"input": input_obj})
+        size_node.meta = node.meta.copy()
+        return size_node
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "size"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim", this_arg_is_optional),
+    ],
+)
+def tensor_size_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """
+    Mapping from Tensor.size() to acc_ops.size. We map size() to acc_ops.size directly
+    and map size(dim) to acc_ops.size + acc_ops.getitem.
+    """
+
+    with node.graph.inserting_before(node):
+        size_node = node.graph.call_function(
+            size, kwargs={"input": node.kwargs["input"]}
+        )
+
+        if "dim" not in node.kwargs:
+            size_node.meta = node.meta.copy()
+            return size_node
+
+        size_node.meta["type"] = torch.Size
+        getitem_node = node.graph.call_function(
+            getitem, kwargs={"input": size_node, "idx": node.kwargs["dim"]}
+        )
+        getitem_node.meta = node.meta.copy()
+        return getitem_node
+
+
+@register_acc_op_mapping(op_and_target=("call_function", operator.add))
+@register_acc_op_mapping(op_and_target=("call_method", "add"))
+@register_acc_op
+def add(*, input, other):
+    return input + other
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.unsqueeze))
+@register_acc_op
+def unsqueeze(*, input, dim):
+    return torch.unsqueeze(**locals())
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.stack),
+    arg_replacement_tuples=[
+        ("tensors", "tensors"),
+        ("dim", "dim"),
+    ],
+)
+def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """
+    Map torch.stack to unsqueeze + cat.
+    """
+    with node.graph.inserting_before(node):
+        inputs = node.kwargs["tensors"]
+        unsqueeze_nodes = []
+        for i, t in enumerate(inputs):
+            new_node = node.graph.create_node(
+                "call_function",
+                unsqueeze,
+                kwargs={"input": t, "dim": node.kwargs["dim"]},
+                name=f"{node.name}_unsqueeze_{i}",
+            )
+            new_node.meta["type"] = torch.Tensor
+            unsqueeze_nodes.append(new_node)
+        cat_node = node.graph.create_node(
+            "call_function",
+            cat,
+            kwargs={"tensors": unsqueeze_nodes, "dim": node.kwargs["dim"]},
+        )
+        cat_node.meta = node.meta.copy()
+        return cat_node
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.clamp))
+@register_acc_op_mapping(op_and_target=("call_method", "clamp"))
+@register_acc_op
+def clamp(*, input, min, max):
+    return torch.clamp(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.cat))
+@register_acc_op
+def cat(*, tensors, dim):
+    return torch.cat(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.transpose))
+@register_acc_op_mapping(op_and_target=("call_method", "transpose"))
+@register_acc_op
+def transpose(*, input, dim0, dim1):
+    if input.dim() < 2:
+        return input
+    return torch.transpose(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.softmax))
+@register_acc_op
+def softmax(*, input, dim, dtype):
+    """
+    _stacklevel is ignored here.
+    """
+    return torch.nn.functional.softmax(**locals())
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.addmm),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("mat1", "mat1"),
+        ("mat2", "mat2"),
+        ("beta", "beta"),
+        ("alpha", "alpha"),
+    ],
+)
+def addmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """
+    Mapping from torch.addmm to acc_ops.matmul -> acc_ops.add. If alpha or beta is not 1,
+    we also insert acc_ops.mul in the right place.
+    """
+    with node.graph.inserting_before(node):
+        mm_kwargs = {"input": node.kwargs["mat1"], "other": node.kwargs["mat2"]}
+        mm_node = node.graph.create_node(
+            "call_function", matmul, kwargs=mm_kwargs, name=f"{node.name}_mm"
+        )
+        mm_node.meta = node.meta.copy()
+
+        if node.kwargs["alpha"] != 1:
+            mul_kwargs = {"input": mm_node, "other": node.kwargs["alpha"]}
+            mm_node = node.graph.create_node(
+                "call_function", mul, kwargs=mul_kwargs, name=f"{mm_node.name}_mul"
+            )
+        mm_node.meta = node.meta.copy()
+
+        input_node = node.kwargs["input"]
+        if node.kwargs["beta"] != 1:
+            mul_kwargs = {"input": input_node, "other": node.kwargs["beta"]}
+            new_input_node = node.graph.create_node(
+                "call_function", mul, kwargs=mul_kwargs, name=f"{node.name}_input_mul"
+            )
+            new_input_node.meta = input_node.meta.copy()
+            input_node = new_input_node
+
+        add_kwargs = {"input": mm_node, "other": input_node}
+        add_node = node.graph.create_node(
+            "call_function", add, kwargs=add_kwargs, name=f"{node.name}_add"
+        )
+        add_node.meta = node.meta.copy()
+        return add_node
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.t),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "t"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+def t_mapper(node: torch.fx.Node, _: nn.Module):
+    with node.graph.inserting_before(node):
+        new_node = node.graph.create_node(
+            "call_function",
+            transpose,
+            kwargs={"input": node.kwargs["input"], "dim0": 0, "dim1": 1},
+            name=node.name,
+        )
+        new_node.meta = node.meta.copy()
+        return new_node
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_method", "permute"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("*", "permutation"),
+    ],
+)
+@register_acc_op
+def permute(*, input, permutation):
+    return input.permute(*permutation)
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.square),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+def square_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    input_node = node.kwargs["input"]
+    with node.graph.inserting_before(node):
+        new_node = node.graph.call_function(
+            mul, kwargs={"input": input_node, "other": input_node}
+        )
+        new_node.meta = node.meta.copy()
+        return new_node
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.bmm),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("mat2", "other"),
+    ],
+)
+@register_acc_op_mapping(op_and_target=("call_function", torch.matmul))
+@register_acc_op
+def matmul(*, input, other):
+    return torch.matmul(**locals())
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.min),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("other", "other", this_arg_is_optional),
+        ("dim", "dim", this_arg_is_optional),
+        ("keepdim", "keepdim", this_arg_is_optional),
+    ],
+)
+def custom_torch_min_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
+    """
+    Add custom mapping for torch.min because torch.min has three input signatures, each
+    yielding a different output type:
+        1. torch.min(input); Output: tensor with the single minimum value
+        2. torch.min(input, other); Output: tensor with the element-wise min values
+        3. [Not supported] torch.min(input, dim, keepdim); Output: (min_values, min_indices)
+    """
+    with node.graph.inserting_before(node):
+        # Reduction over a dim is not supported, so assert that "dim" is not in kwargs.
+        assert "dim" not in node.kwargs, "dim is currently not supported for torch.min"
+
+        if "other" in node.kwargs and node.kwargs["other"] is not None:
+            # If kwargs[other] is a valid tensor, call min_two_tensors_input,
+            op_func = min_two_tensors_input
+        else:
+            # Otherwise, call min_single_tensor_input
+            op_func = min_single_tensor_input
+
+        new_node = node.graph.create_node(
+            "call_function",
+            op_func,
+            kwargs=node.kwargs,
+            name=node.name,
+        )
+        new_node.meta = node.meta
+        return new_node
+
+
+@register_acc_op
+def min_single_tensor_input(*, input):
+    return torch.min(input)
+
+
+@register_acc_op
+def min_two_tensors_input(*, input, other):
+    return torch.min(input, other)
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.ops.quantized.add),
+    arg_replacement_tuples=[
+        ("qa", "input"),
+        ("qb", "other"),
+        ("scale", "scale"),
+        ("zero_point", "zero_point"),
+    ],
+    kwargs_to_move_to_acc_out_ty=[
+        ("scale", "q_scale"),
+        ("zero_point", "q_zero_point"),
+    ],
+)
+@register_acc_op
+def quantized_add(*, input, other, acc_out_ty=None):
+    assert acc_out_ty is not None
+    return torch.ops.quantized.add(
+        input,
+        other,
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
+    )
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.ops.quantized.mul),
+    arg_replacement_tuples=[
+        ("qa", "input"),
+        ("qb", "other"),
+        ("scale", "scale"),
+        ("zero_point", "zero_point"),
+    ],
+    kwargs_to_move_to_acc_out_ty=[
+        ("scale", "q_scale"),
+        ("zero_point", "q_zero_point"),
+    ],
+)
+@register_acc_op
+def quantized_mul(*, input, other, acc_out_ty=None):
+    assert acc_out_ty is not None
+    return torch.ops.quantized.mul(
+        input,
+        other,
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
+    )
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.quantize_per_tensor),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("scale", "scale"),
+        ("zero_point", "zero_point"),
+        ("dtype", "dtype"),
+    ],
+    kwargs_to_move_to_acc_out_ty=[
+        ("dtype", "dtype"),
+        ("scale", "q_scale"),
+        ("zero_point", "q_zero_point"),
+    ],
+)
+@register_acc_op
+def quantize_per_tensor(*, input, acc_out_ty=None):
+    assert acc_out_ty is not None
+    return torch.quantize_per_tensor(
+        input,
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype"),
+    )
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.dequantize))
+@register_acc_op_mapping(op_and_target=("call_method", "dequantize"))
+@register_acc_op
+def dequantize(*, input):
+    return torch.dequantize(input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", operator.sub))
+@register_acc_op
+def sub(*, input, other):
+    return input - other
+
+
+@register_acc_op_mapping(op_and_target=("call_function", operator.mul))
+@register_acc_op
+def mul(*, input, other):
+    return input * other
+
+
+@register_acc_op_mapping(op_and_target=("call_function", operator.truediv))
+@register_acc_op
+def div(*, input, other):
+    return input / other
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.relu))
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.relu),
+    arg_replacement_tuples=[("input", "input")],
+)
+@register_acc_op_mapping(
+    op_and_target=("call_method", "relu"),
+    arg_replacement_tuples=[("input", "input")],
+)
+@register_acc_op
+def relu(*, input, inplace=False):
+    return nn.functional.relu(**locals())
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "sum"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim", this_arg_is_optional),
+        ("keepdim", "keepdim", this_arg_is_optional),
+        ("dtype", "dtype", this_arg_is_optional),
+    ],
+)
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.sum),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim", this_arg_is_optional),
+        ("keepdim", "keepdim", this_arg_is_optional),
+        ("dtype", "dtype", this_arg_is_optional),
+    ],
+)
+def add_sum_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -> torch.fx.Node:
+    with node.graph.inserting_before(node):
+        sum_kwargs = dict(node.kwargs)
+        if "dim" in sum_kwargs and isinstance(sum_kwargs["dim"], int):
+            sum_kwargs["dim"] = (sum_kwargs["dim"],)
+        sum_node = node.graph.call_function(sum, kwargs=sum_kwargs)
+        sum_node.meta = node.meta.copy()
+        return sum_node
+
+
+@register_acc_op
+def sum(*, input, dim=None, keepdim=False, dtype=None):
+    if dim:
+        return torch.sum(**locals())
+    else:
+        return input.sum(dtype=dtype)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.sigmoid))
+@register_acc_op_mapping(op_and_target=("call_method", "sigmoid"))
+@register_acc_op
+def sigmoid(*, input):
+    return torch.sigmoid(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.sinh))
+@register_acc_op
+def sinh(*, input):
+    return torch.sinh(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.cosh))
+@register_acc_op
+def cosh(*, input):
+    return torch.cosh(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.tanh))
+@register_acc_op
+def tanh(*, input):
+    return torch.tanh(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.asin))
+@register_acc_op
+def asin(*, input):
+    return torch.asin(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.acos))
+@register_acc_op
+def acos(*, input):
+    return torch.acos(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.atan))
+@register_acc_op
+def atan(*, input):
+    return torch.atan(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.exp))
+@register_acc_op
+def exp(*, input):
+    return torch.exp(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.log))
+@register_acc_op
+def log(*, input):
+    return torch.log(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.sqrt))
+@register_acc_op
+def sqrt(*, input):
+    return torch.sqrt(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.reciprocal))
+@register_acc_op
+def reciprocal(*, input):
+    return torch.reciprocal(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.abs))
+@register_acc_op
+def abs(*, input):
+    return torch.abs(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.neg))
+@register_acc_op
+def neg(*, input):
+    return torch.neg(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.floor))
+@register_acc_op
+def floor(*, input):
+    return torch.floor(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.ceil))
+@register_acc_op
+def ceil(*, input):
+    return torch.ceil(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.conv2d))
+@register_acc_op
+def conv2d(*, input, weight, bias, stride, padding, dilation, groups):
+    return nn.functional.conv2d(**locals())
+
+
+@register_acc_op
+def quantized_conv2d(
+    *,
+    input,
+    weight,
+    bias,
+    stride,
+    padding,
+    dilation,
+    groups,
+    padding_mode,
+    acc_out_ty=None,
+):
+    assert acc_out_ty is not None
+    return torch.nn.quantized.functional.conv2d(
+        input,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        groups,
+        padding_mode,
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
+    )
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.batch_norm))
+@register_acc_op
+def batch_norm(
+    *, input, running_mean, running_var, weight, bias, training, momentum, eps
+):
+    return nn.functional.batch_norm(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.layer_norm))
+@register_acc_op
+def layer_norm(*, input, normalized_shape, weight, bias, eps):
+    return nn.functional.layer_norm(**locals())
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "split"),
+    arg_replacement_tuples=[
+        ("tensor", "input"),
+        ("split_size_or_sections", "split_size_or_sections"),
+        ("dim", "dim"),
+    ],
+)
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.split),
+    arg_replacement_tuples=[
+        ("tensor", "input"),
+        ("split_size_or_sections", "split_size_or_sections"),
+        ("dim", "dim"),
+    ],
+)
+def torch_split_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
+    """
+    If split_size_or_sections is a list of sections, map the node to slice_tensor
+    + tuple_construct. Otherwise, if split_size_or_sections is a single split_size,
+    map the node to acc_ops.split.
+    """
+    split_size_or_sections = node.kwargs["split_size_or_sections"]
+    with node.graph.inserting_before(node):
+        if isinstance(split_size_or_sections, int):
+            new_kwargs = {
+                "input": node.kwargs["input"],
+                "split_size": split_size_or_sections,
+                "dim": node.kwargs["dim"],
+            }
+            new_node = node.graph.call_function(split, kwargs=new_kwargs)
+            new_node.meta = node.meta.copy()
+            return new_node
+
+        start = 0
+        slice_nodes = []
+        for i in split_size_or_sections:
+            new_kwargs = {
+                "input": node.kwargs["input"],
+                "dims": (node.kwargs["dim"],),
+                "starts": (start,),
+                "stops": (start + i,),
+                "steps": (1,),
+            }
+            new_node = node.graph.call_function(slice_tensor, kwargs=new_kwargs)
+            new_node.meta["type"] = torch.Tensor
+            slice_nodes.append(new_node)
+            start += i
+
+        new_node = node.graph.call_function(
+            tuple_construct, kwargs={"tensors": tuple(slice_nodes)}
+        )
+        new_node.meta = node.meta.copy()
+        return new_node
+
+
+@register_acc_op
+def split(*, input, split_size, dim):
+    return torch.split(input, split_size, dim)
+
+
+@register_acc_op
+def tuple_construct(*, tensors):
+    return tuple(tensors)
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.ops.quantized.batch_norm2d),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("weight", "weight"),
+        ("bias", "bias"),
+        ("running_mean", "running_mean"),
+        ("running_var", "running_var"),
+        ("eps", "eps"),
+        ("scale", "scale"),
+        ("zero_point", "zero_point"),
+    ],
+    kwargs_to_move_to_acc_out_ty=[
+        ("scale", "q_scale"),
+        ("zero_point", "q_zero_point"),
+    ],
+)
+@register_acc_op
+def quantized_batch_norm2d(
+    *, input, running_mean, running_var, weight, bias, eps, acc_out_ty
+):
+    return torch.ops.quantized.batch_norm2d(
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        eps,
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
+        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
+    )
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.embedding_bag))
+@register_acc_op
+def embedding_bag(
+    *,
+    input,
+    weight,
+    offsets,
+    max_norm,
+    norm_type,
+    scale_grad_by_freq,
+    mode,
+    sparse,
+    per_sample_weights,
+    include_last_offset,
+    padding_idx,
+):
+    return nn.functional.embedding_bag(**locals())
+
+
+@register_acc_op_mapping(
+    op_and_target=(
+        "call_function",
+        torch.ops.quantized.embedding_bag_byte_rowwise_offsets,
+    )
+)
+@register_acc_op
+def embedding_bag_byte_rowwise_offsets(
+    *,
+    weight,
+    input,
+    offsets,
+    scale_grad_by_freq,
+    mode,
+    pruned_weights,
+    per_sample_weights,
+    compressed_indices_mapping,
+    include_last_offset,
+):
+    return torch.ops.quantized.embedding_bag_byte_rowwise_offsets(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.sin))
+@register_acc_op
+def sin(*, input):
+    return torch.sin(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.cos))
+@register_acc_op
+def cos(*, input):
+    return torch.cos(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.tan))
+@register_acc_op
+def tan(*, input):
+    return torch.tan(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.topk))
+@register_acc_op
+def topk(*, input, k, dim, largest, sorted):
+    return torch.topk(**locals())
+
+
+@register_acc_op_mapping(op_and_target=("call_function", operator.getitem))
+@register_acc_op
+def getitem(*, input, idx):
+    return input[idx]
+
+
+@register_acc_op
+def slice_tensor(*, input, dims, starts, stops, steps):
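+    """
+    Slice `input` along each dim in `dims` using the corresponding start/stop/step
+    values; unspecified dims keep their full extent. For example, dims=(1,),
+    starts=(0,), stops=(3,), steps=(1,) is equivalent to input[:, 0:3].
+    """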
+    slices = [None for _ in range(input.dim())]
+
+    # For all provided dims, extract out a slice for starts/stops/steps.
+    for idx, dim in enumerate(dims):
+        slices[dim] = slice(starts[idx], stops[idx], steps[idx])
+
+    # For all unspecified dims, default to the full slice.
+    for idx, s in enumerate(slices):
+        if s is None:
+            slices[idx] = slice(None, None, None)
+
+    return input[slices]
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.narrow),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim"),
+        ("start", "start"),
+        ("length", "length"),
+    ],
+)
+def custom_narrow_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
+    kwargs = {
+        "input": node.kwargs["input"],
+        "dims": (node.kwargs["dim"],),
+        "starts": (node.kwargs["start"],),
+        "stops": (node.kwargs["start"] + node.kwargs["length"],),
+        "steps": (1,),
+    }
+    with node.graph.inserting_before(node):
+        new_node = node.graph.call_function(slice_tensor, kwargs=kwargs)
+    new_node.meta = node.meta.copy()
+    return new_node
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.reshape),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("shape", "shape"),
+    ],
+    kwargs_to_move_to_acc_out_ty=[("shape", "shape")],
+)
+@register_acc_op_mapping(
+    op_and_target=("call_method", "view"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("*", "shape"),
+    ],
+    kwargs_to_move_to_acc_out_ty=[("shape", "shape")],
+)
+@register_acc_op
+def reshape(*, input, acc_out_ty=None):
+    assert acc_out_ty is not None
+    return torch.reshape(
+        input, tuple(acc_utils.get_field_from_acc_out_ty(acc_out_ty, "shape"))
+    )
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "reshape"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("*", "shape"),
+    ],
+)
+def custom_tensor_reshape_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """
+    For Tensor.reshape node, args could be (input, 1, 2, 3) or (input, (1, 2, 3)).
+    Here we do some special handling with the `shape` arg in order to map it to
+    acc_ops.reshape. It also handles the case when `shape` is a list instead of
+    tuple.
+    """
+    input_node = node.kwargs["input"]
+    shape = node.kwargs["shape"]
+
+    if isinstance(shape[0], (tuple, list)):
+        shape = shape[0]
+
+    with node.graph.inserting_before(node):
+        new_node = node.graph.call_function(
+            reshape,
+            kwargs={
+                "input": input_node,
+                "acc_out_ty": acc_utils.build_raw_tensor_meta(shape=shape),
+            },
+        )
+        new_node.meta = node.meta.copy()
+        return new_node
+
+
+@register_acc_op
+def to_dtype(input, acc_out_ty=None):
+    assert acc_out_ty is not None, "valid acc_out_ty needed"
+    return input.to(dtype=acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype"))
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "to"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dtype", "dtype"),
+    ],
+)
+def custom_tensor_to_mapper(node: torch.fx.Node, _: nn.Module):
+    dest_dtype = node.kwargs["dtype"]
+    mem_format = node.kwargs.get("memory_format")
+    device = node.kwargs.get("device")
+    assert dest_dtype is not None
+    assert mem_format is None or mem_format == torch.preserve_format
+    assert device is None
+
+    new_kwargs = {
+        "input": node.kwargs["input"],
+        "acc_out_ty": acc_utils.build_raw_tensor_meta(dtype=dest_dtype),
+    }
+
+    with node.graph.inserting_before(node):
+        new_node = node.graph.create_node(
+            "call_function", to_dtype, kwargs=new_kwargs, name=node.name
+        )
+        new_node.meta = node.meta
+        return new_node
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.add),
+    # Note that we may have aliases for inputs here due to issues with deterministically
+    # knowing the correct target that will be resolved by pytorch.
+    arg_replacement_tuples=[
+        (("input", "a"), "input"),
+        (("other", "b"), "other"),
+        ("alpha", "alpha", this_arg_is_optional),
+    ],
+)
+def custom_torch_add_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
+    """
+    Add custom mapping for torch.add because it has an `alpha` parameter which scales
+    the `other` input, and we want to make that mul a separate node.
+    """
+    with node.graph.inserting_before(node):
+        # If alpha is in kwargs check if we need to add a mul, and use correct kwargs.
+        if "alpha" in node.kwargs:
+            # Add mul node only if it has a numerical impact, i.e. alpha != 1.0.
+            if node.kwargs["alpha"] != 1.0:
+                other_node = node.graph.create_node(
+                    "call_function",
+                    mul,
+                    kwargs={
+                        "input": node.kwargs["other"],
+                        "other": node.kwargs["alpha"],
+                    },
+                    name=node.name + "_mul_alpha",
+                )
+                other_node.meta = node.meta
+            else:
+                other_node = node.kwargs["other"]
+            add_kwargs = {"input": node.kwargs["input"], "other": other_node}
+        else:
+            add_kwargs = node.kwargs
+
+        new_node = node.graph.create_node(
+            "call_function", add, kwargs=add_kwargs, name=node.name
+        )
+        new_node.meta = node.meta
+        return new_node
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_module", nn.quantized.Linear),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+def packed_quantized_linear_mapper(
+    node: torch.fx.Node, mod: nn.Module
+) -> torch.fx.Node:
+    """
+    Mapping from nn.quantized.Linear module to acc_ops.quantized_linear. We unpack weight
+    and bias in this mapper and pass them directly to the quantized_linear node.
+    """
+    linear_module = dict(mod.named_modules())[node.target]
+    prefix = node.target.replace(".", "_")
+    weight_name = f"{prefix}_weight"
+    bias_name = f"{prefix}_bias"
+
+    # Store weight and bias in the main module
+    mod.register_buffer(weight_name, linear_module.weight())
+    if linear_module.bias() is not None:
+        mod.register_buffer(bias_name, linear_module.bias())
+
+    with node.graph.inserting_before(node):
+        # Insert get_attr nodes for weight and bias
+        get_weight = node.graph.get_attr(weight_name)
+        get_weight.meta["tensor_meta"] = extract_tensor_metadata(linear_module.weight())
+
+        get_bias = None
+        if linear_module.bias() is not None:
+            get_bias = node.graph.get_attr(bias_name)
+            get_bias.meta["tensor_meta"] = extract_tensor_metadata(linear_module.bias())
+
+        # Create kwargs for acc_op.quantized_linear
+        kwargs = {
+            "input": node.kwargs["input"],
+            "weight": get_weight,
+            "bias": get_bias,
+            "acc_out_ty": acc_utils.build_raw_tensor_meta(
+                q_scale=linear_module.scale, q_zero_point=linear_module.zero_point
+            ),
+        }
+
+        new_node = node.graph.call_function(quantized_linear, kwargs=kwargs)
+        new_node.meta = node.meta
+        return new_node
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_module", nn.quantized.Conv2d),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+def packed_quantized_conv2d_mapper(
+    node: torch.fx.Node, mod: nn.Module
+) -> torch.fx.Node:
+    """
+    Mapping from nn.quantized.Conv2d module to acc_ops.quantized_conv2d. We unpack all the
+    parameters in this mapper and pass them directly to the quantized_conv2d node.
+    """
+    conv_module = dict(mod.named_modules())[node.target]
+    prefix = node.target.replace(".", "_")
+    weight_name = f"{prefix}_weight"
+    bias_name = f"{prefix}_bias"
+
+    # Store weight and bias in the main module
+    mod.register_buffer(weight_name, conv_module.weight())
+    if conv_module.bias() is not None:
+        mod.register_buffer(bias_name, conv_module.bias())
+
+    with node.graph.inserting_before(node):
+        # Insert get_attr nodes for weight and bias
+        get_weight = node.graph.get_attr(weight_name)
+        get_weight.meta["tensor_meta"] = extract_tensor_metadata(conv_module.weight())
+
+        get_bias = None
+        if conv_module.bias() is not None:
+            get_bias = node.graph.get_attr(bias_name)
+            get_bias.meta["tensor_meta"] = extract_tensor_metadata(conv_module.bias())
+
+        # Create kwargs for acc_op.conv
+        kwargs = {
+            "input": node.kwargs["input"],
+            "weight": get_weight,
+            "bias": get_bias,
+            "stride": conv_module.stride,
+            "padding": conv_module.padding,
+            "dilation": conv_module.dilation,
+            "groups": conv_module.groups,
+            "padding_mode": conv_module.padding_mode,
+            "acc_out_ty": acc_utils.build_raw_tensor_meta(
+                q_scale=conv_module.scale, q_zero_point=conv_module.zero_point
+            ),
+        }
+
+        new_node = node.graph.call_function(quantized_conv2d, kwargs=kwargs)
+        new_node.meta = node.meta
+        return new_node
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.ops.quantized.add_relu),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("other", "other"),
+        ("scale", "scale"),
+        ("zero_point", "zero_point"),
+    ],
+)
+def add_relu_unfuse_mapper(
+    node: torch.fx.Node, mod: torch.fx.GraphModule
+) -> torch.fx.Node:
+    with node.graph.inserting_before(node):
+        add_kwargs = {
+            "input": node.kwargs["input"],
+            "other": node.kwargs["other"],
+            "acc_out_ty": acc_utils.build_raw_tensor_meta(
+                q_scale=node.kwargs["scale"],
+                q_zero_point=node.kwargs["zero_point"],
+            ),
+        }
+        add_node = node.graph.call_function(quantized_add, kwargs=add_kwargs)
+        add_node.meta = node.meta.copy()
+
+        relu_node = node.graph.call_function(
+            relu, kwargs={"input": add_node, "inplace": False}
+        )
+        relu_node.meta = node.meta
+        return relu_node
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_module", nn.intrinsic.quantized.ConvReLU2d),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+def packed_quantized_convrelu2d_mapper(
+    node: torch.fx.Node, mod: nn.Module
+) -> torch.fx.Node:
+    """
+    Mapping from quantized ConvReLU2d module to acc_ops.quantized_conv2d + acc_ops.relu.
+    We use packed_quantized_conv2d_mapper to unpack all the parameters in this mapper and
+    pass the returned quantized_conv2d node directly to the relu node.
+    """
+
+    with node.graph.inserting_before(node):
+        # conv2d op
+        conv2d_node = packed_quantized_conv2d_mapper(node, mod)
+
+        # relu op
+        relu_node = node.graph.call_function(
+            relu, kwargs={"input": conv2d_node, "inplace": False}
+        )
+        relu_node.meta = node.meta
+        return relu_node
diff --git a/torch/fx/experimental/fx_acc/acc_tracer.py b/torch/fx/experimental/fx_acc/acc_tracer.py
new file mode 100644 (file)
index 0000000..1f6a365
--- /dev/null
@@ -0,0 +1,433 @@
+# type: ignore[]
+import ast
+import builtins
+import copy
+import inspect
+import textwrap
+import warnings
+from types import FunctionType
+from typing import Dict, Optional, Any, Type, Tuple, Set, List
+
+import torch.fx.experimental.fx_acc.acc_normalizer as acc_normalizer
+import torch.fx.experimental.fx_acc.acc_ops  # noqa: F401
+import torch
+import torch.nn as nn
+from torch._sources import normalize_source_lines
+from torch.fx import Graph, Tracer
+from torch.fx.experimental.normalize import NormalizeArgs
+from torch.fx.passes import shape_prop
+
+
+def _get_exception_wrapper_attr_name(exc_type: Type[Exception]) -> str:
+    return f"_conditional_exception_wrapper_{exc_type.__name__}"
+
+
+class Acc_Rewriter(ast.NodeTransformer):
+    """
+    Take a FunctionType object representing a `forward` method, then
+    perform an AST rewrite to swap out nodes that are not symbolically
+    traceable with a callsite to the FX alternative.
+
+    To support swapping out an AST node, define a new `visit` method on
+    that node. For more details, see:
+    https://docs.python.org/3/library/ast.html#ast.NodeTransformer
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.exceptions_rewritten: Set[Type[Exception]] = set()
+
+    def rewrite(self, fn: FunctionType) -> Tuple[FunctionType, Set[Type[Exception]]]:
+
+        # Normalize the source lines
+        sourcelines, _ = inspect.getsourcelines(fn)
+        sourcelines = normalize_source_lines(sourcelines)
+        source = "".join(sourcelines)
+        normalized_str = textwrap.dedent(source)
+
+        # Rewrite the original AST
+        source_ast = ast.parse(normalized_str)
+        dest_ast = ast.fix_missing_locations(self.visit(source_ast))
+
+        # Pull out the compiled function from the newly-created Module
+        code = compile(dest_ast, "", "exec")
+        globals_dict = copy.copy(fn.__globals__)
+        keys_before = set(globals_dict.keys())
+        exec(code, globals_dict)
+        new_keys = list(set(globals_dict.keys()) - keys_before)
+        assert len(new_keys) <= 1
+        fn_compiled = globals_dict[fn.__name__]
+
+        # Return the correct FunctionType object and the Exceptions that were
+        # rewritten during visit_If.
+        return fn_compiled, self.exceptions_rewritten
+
+    def visit_Assert(self, node: ast.Assert):
+        """
+        Swap out the Assert node (Python's `assert`) with a callsite to the
+        symbolically-traceable torch._assert function
+        """
+        # Create the Call node
+        n = ast.parse("torch._assert()", mode="eval")
+        assert isinstance(n, ast.Expression)
+        call_node = n.body
+        assert isinstance(call_node, ast.Call)
+        msg = node.msg if node.msg else ast.Constant(value="", kind=None)
+        call_node.args = [node.test, msg]
+
+        # Ensure that the new node conforms to the Python AST grammar
+        expr_wrapper = ast.Expr(value=call_node)
+
+        # Return the new Call node to signify that we want to use it as
+        # a replacement for the original _assert node
+        return ast.copy_location(expr_wrapper, node)
+
+    def visit_If(self, if_node: ast.If):
+        """
+        Swap out the pattern `If(x): Raise(y)` with a ConditionalExceptionWrapper
+        specialized for the specific exception y. The specialized
+        ConditionalExceptionWrapper module will be added in the RewrittenModule.
+        Only works with builtin Exceptions, as we assume the Exception's __init__
+        takes a single string message.
+        """
+        raise_node = if_node.body[0]
+        if not isinstance(raise_node, ast.Raise):
+            return if_node
+
+        # Don't handle orelse for now.
+        # TODO: Move orelse to the body after calling ConditionalExceptionWrapper.
+        if len(if_node.orelse) != 0:
+            return if_node
+
+        def _reuse_loc(node):
+            return ast.copy_location(node, if_node)
+
+        # If the exception has a message then we expect the raise's exc to be a
+        # Call w/ a msg. Else if it's an exc Name then there's no msg to use.
+        node_for_exc = raise_node.exc
+        if isinstance(node_for_exc, ast.Name):
+            # E.g. `raise AssertionError`, i.e. without an exc_msg.
+            name_node_of_exc = node_for_exc
+            exc_msg = _reuse_loc(ast.Constant(None))
+        elif isinstance(node_for_exc, ast.Call):
+            # E.g. `raise AssertionError("error message")`
+            name_node_of_exc = node_for_exc.func
+            if not isinstance(name_node_of_exc, ast.Name):
+                return if_node
+            # Most raised exceptions just take a single string arg, but some may
+            # not; skip handling such exceptions for now.
+            if len(node_for_exc.args) != 1:
+                return if_node
+            exc_msg = node_for_exc.args[0]
+        else:
+            return if_node
+
+        # Convert what we expect is the name of the exception into its
+        # associated python class.
+        name_of_exc = name_node_of_exc.id
+        try:
+            exc_type = eval(name_of_exc)
+        except Exception:
+            return if_node
+
+        # Check that we actually have a builtin exception.
+        if (
+            not issubclass(exc_type, Exception)
+            or getattr(getattr(exc_type, "__class__", None), "__module__", None)
+            != "builtins"
+        ):
+            return if_node
+
+        # We need a ConditionalExceptionWrapper specialized for every kind of
+        # exception, so add it to exceptions_rewritten so that a specialized
+        # attr can be added for it later.
+        self.exceptions_rewritten.add(exc_type)
+
+        # From here we definitely should be able to do the replacement. Create a
+        # Call node to the ConditionalExceptionWrapper module we're replacing
+        # the If with, with args set as the If's condition and the string of the
+        # exception. The call to the self._conditional_exception_wrapper_*Error
+        # module is safe because the RewrittenModule will add it as an attr
+        # based on the returned exceptions_rewritten, and we assume we are
+        # currently modifying the AST of a method from a RewrittenModule.
+        exc_wrapper_node = ast.parse(
+            f"self.{_get_exception_wrapper_attr_name(exc_type)}()", mode="eval"
+        )
+        assert isinstance(exc_wrapper_node, ast.Expression)
+        exc_wrapper_call_node = exc_wrapper_node.body
+        assert isinstance(exc_wrapper_call_node, ast.Call)
+        exc_wrapper_call_node.args = [if_node.test, exc_msg]
+
+        # Ensure that the new node conforms to the Python AST grammar
+        expr_wrapper = _reuse_loc(ast.Expr(_reuse_loc(exc_wrapper_call_node)))
+
+        # Return the new node to signify that we want to use it as a replacement
+        # for the original `If x: Raise y` pattern.
+        return expr_wrapper
+
+
+class ConditionalExceptionWrapper(nn.Module):
+    """
+    This wrapper class is used to wrap conditional raising of exceptions during
+    rewriting. For example:
+
+    .. code-block:: python
+
+        if self.name != "x":
+            raise AssertionError(f"Name was not x: {self.name}")
+
+    Is rewritten into
+
+    .. code-block:: python
+
+        self._conditional_exception_wrapper_AssertionError(
+            self.name != "x", f"Name was not x: {self.name}"
+        )
+
+    Note that __init__ takes the Exception class that it is wrapping, while
+    forward takes the condition to check and the message for the exception.
+
+    """
+
+    # Mark as impure so that calls to it will not be removed during DCE.
+    _is_impure = True
+
+    def __init__(self, exc: Type[Exception]):
+        super().__init__()
+        self.exc = exc
+
+    def forward(self, cond: bool, msg: str):
+        if cond:
+            raise self.exc if msg is None else self.exc(msg)
+
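A minimal usage sketch of the wrapper above (values are made up for illustration):

    wrapper = ConditionalExceptionWrapper(AssertionError)
    wrapper(False, "never raised")  # condition is falsy, so nothing happens
    wrapper(True, "boom")           # raises AssertionError("boom")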
+
+# Custom tracer that traces to the functional level and rewrites asserts and
+# exceptions.
+class AccRewritingTracer(Tracer):
+    # Note: Treat ConditionalExceptionWrapper as a leaf so that we don't
+    # trace into it, because it contains control flow and raises an exception.
+    DEFAULT_LEAF_MODULE_LIST = {
+        ConditionalExceptionWrapper,
+        torch.nn.quantized.Linear,
+        torch.nn.quantized.Conv2d,
+        torch.nn.intrinsic.quantized.ConvReLU2d,
+    }
+
+    def is_leaf_module(self, m: nn.Module, mod_qual_name: str) -> bool:
+        return getattr(m, "_base_class_origin", type(m)) in self.leaf_module_list
+
+    def trace(
+        self,
+        root: nn.Module,
+        concrete_args: Optional[Dict[str, Any]] = None,
+        ast_rewriter_allow_list: Optional[Set] = None,
+        leaf_module_list: Optional[Set] = None,
+    ) -> Tuple[Graph, nn.Module]:
+        rewritten = _rewrite(root, ast_rewriter_allow_list)
+        # Copy the defaults so that user-provided leaf modules don't mutate the
+        # class-level DEFAULT_LEAF_MODULE_LIST.
+        self.leaf_module_list = set(self.DEFAULT_LEAF_MODULE_LIST)
+        if leaf_module_list:
+            self.leaf_module_list.update(leaf_module_list)
+        return super().trace(rewritten, concrete_args), rewritten
+
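A minimal sketch of driving this tracer directly; the public trace() helper below wraps essentially these steps (my_module is an illustrative stand-in for any nn.Module):

    tracer = AccRewritingTracer()
    graph, rewritten_mod = tracer.trace(my_module)
    gm = torch.fx.GraphModule(rewritten_mod, graph)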
+
+# List of modules that need rewriting to be supported for tracing.
+DEFAULT_REWRITE_ALLOW_LIST = {
+    nn.BatchNorm1d,
+    nn.BatchNorm2d,
+    nn.BatchNorm3d,
+}
+
+
+def _rewrite(mod_to_rewrite: nn.Module, allow_list: Optional[Set] = None) -> nn.Module:
+    if allow_list is None:
+        allow_list = DEFAULT_REWRITE_ALLOW_LIST
+    else:
+        # set.union returns a new set, so assign it back to keep the defaults.
+        allow_list = allow_list.union(DEFAULT_REWRITE_ALLOW_LIST)
+
+    # Rewrite this module's functions as well as all recursive modules'
+    # functions that are attrs of this module. Return the new, rewritten module
+    # hierarchy.
+    def rewrite_module(m: nn.Module):
+        base_class = type(m)
+
+        # Keep track of all the ConditionalExceptionWrappers that the
+        # Acc_Rewriter calls into in this module so we can add them in init
+        # below.
+        all_added_wrappers: Set[Type[Exception]] = set()
+
+        # Note: Make this a subclass of our base class.
+        class RewrittenModule(base_class):
+            # Keep track of the base_class so that symbolic tracing can
+            # determine what kind of module this originally was later on.
+            _base_class_origin = base_class
+            # Add suffix to qualname so it's easier to debug the origin of this module.
+            __qualname__ = f"{base_class.__qualname__}__AccRewrittenModule"
+
+            # Write all of the non-dunder (i.e. non-special) methods from
+            # base_class into RewrittenModule.
+            for method_name in dir(base_class):
+                method = getattr(base_class, method_name)
+                if builtins.type(method) is not FunctionType:
+                    continue
+
+                # Always skip rewriting dunder methods: leaving them alone hasn't
+                # (yet) been problematic, while modifying them has caused issues
+                # previously.
+                if method_name.startswith("__") and method_name.endswith("__"):
+                    continue
+
+                # Only rewrite those Modules explicitly in the allow_list.
+                if base_class not in allow_list:
+                    vars()[method_name] = method
+                else:
+                    vars()[method_name], added_wrappers = Acc_Rewriter().rewrite(method)
+                    all_added_wrappers.update(added_wrappers)
+
+            def __init__(self, orig):
+                nn.Module.__init__(self)
+                # Iterate over all added exception wrappers and add
+                # ConditionalExceptionWrapper attrs for each.
+                for exc_type in all_added_wrappers:
+                    wrapper_name = _get_exception_wrapper_attr_name(exc_type)
+                    assert not hasattr(self, wrapper_name)
+                    setattr(
+                        self,
+                        wrapper_name,
+                        ConditionalExceptionWrapper(exc_type),
+                    )
+                # Recursively rewrite and copy all module attrs of this module.
+                for k, v in orig.__dict__.items():
+                    if k == "_modules":
+                        for mod_k, mod_v in v.items():
+                            self._modules[mod_k] = rewrite_module(mod_v)
+                    else:
+                        self.__dict__[k] = v
+
+        # Add suffix to name so it's easier to debug the origin of this module.
+        RewrittenModule.__name__ = f"{base_class.__name__}__AccRewrittenModule"
+        return RewrittenModule(m)
+
+    return rewrite_module(mod_to_rewrite)
+
+
+def _remove_assertions(gm: torch.fx.GraphModule) -> bool:
+    """
+    Unconditionally removes all assertions found in GraphModule gm.
+    Returns whether the graph is modified.
+    """
+    changed = False
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch._assert:
+            gm.graph.erase_node(node)
+            changed = True
+    return changed
+
+
+def _remove_exceptions(gm: torch.fx.GraphModule) -> bool:
+    """
+    Unconditionally removes all call_modules to ConditionalExceptionWrappers
+    found in GraphModule gm. Returns whether the graph is modified.
+    """
+    changed = False
+    for node in gm.graph.nodes:
+        if node.op == "call_module" and isinstance(
+            gm.get_submodule(node.target), ConditionalExceptionWrapper
+        ):
+            gm.graph.erase_node(node)
+            changed = True
+    return changed
+
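Both passes are invoked by trace() below; as a hedged sketch, they can also be applied to an already-traced GraphModule, followed by dead-code elimination and recompilation:

    changed = _remove_assertions(gm)
    changed |= _remove_exceptions(gm)
    if changed:
        gm.graph.eliminate_dead_code()
        gm.recompile()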
+
+def trace(
+    mod: nn.Module,
+    sample_inputs: List[torch.Tensor],
+    remove_assertions: bool = True,
+    remove_exceptions: bool = True,
+    use_acc_normalization: bool = True,
+    ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None,
+    leaf_module_list: Optional[Set[Type[nn.Module]]] = None,
+) -> torch.fx.GraphModule:
+    """
+    Performs tracing and arg normalization specialized for accelerator lowering.
+
+    It first rewrites the AST of the module's methods (and all attr methods
+    recursively) to transform parts of the module that are not symbolically
+    traceable into a traceable form.
+
+    It then traces to the functional level so that optimizations and backend
+    accelerator importers have the ability to see and/or change inputs to each
+    op.
+
+    It then removes assertions and exception wrappers found during symbolic
+    tracing, if requested via remove_assertions and remove_exceptions.
+
+    Dead code is then eliminated, which will e.g. remove any nodes that were
+    only used by assertions or exceptions if they were removed.
+
+    It then performs normalization on args/kwargs, aligning any arg that can be
+    moved to kwarg to be so, and then making default values explicit.
+
+    Args:
+
+        mod (Module): The module to transform and trace.
+
+        sample_inputs (List[torch.Tensor]):
+                Sample inputs with which to run shape propagation.
+
+        remove_assertions (bool): Whether to remove assertion nodes from
+                                    the graph after symbolic tracing.
+
+        remove_exceptions (bool): Whether to remove exception wrapper nodes
+                                    from the graph after symbolic tracing.
+
+        use_acc_normalization (bool): Whether to apply acc-specific
+                                        normalization to all acc_ops.
+
+        ast_rewriter_allow_list (Optional[Set[Type[nn.Module]]]): Optional allow list
+                                            of module types that need AST rewriting.
+
+        leaf_module_list (Optional[Set[Type[nn.Module]]]): Optional leaf module list;
+                                            these modules will not be traced into.
+
+    """
+    if mod.training:
+        warnings.warn(
+            "acc_tracer does not support currently support models for training."
+            " Calling eval on model before tracing."
+        )
+        mod.eval()
+
+    # Rewrite the module to make it symbolic traceable, and then trace it.
+    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
+        mod,
+        ast_rewriter_allow_list=ast_rewriter_allow_list,
+        leaf_module_list=leaf_module_list,
+    )
+
+    assert isinstance(rewritten_mod, nn.Module)
+    # Note: use the rewritten_mod here as the root. This is necessary because
+    # RewrittenModule includes a new module for the ConditionalExceptionWrapper.
+    traced = torch.fx.GraphModule(rewritten_mod, rewritten_graph)
+
+    # Now remove all assertions and exceptions if requested.
+    if remove_assertions:
+        _remove_assertions(traced)
+    if remove_exceptions:
+        _remove_exceptions(traced)
+
+    # Cleanup any dead code from the original module as well as resulting dead
+    # nodes after removing assertions and exceptions.
+    traced.graph.eliminate_dead_code()
+
+    # Now normalize args/kwargs to make default values visible. Leave args/kwargs as
+    # they were, since all-kwarg normalization is broken, and we don't need it anyway.
+    shape_prop.ShapeProp(traced).propagate(*sample_inputs)
+    traced = NormalizeArgs(traced, normalize_to_only_use_kwargs=False).transform()
+
+    # Normalize to acc-specialized wrappers for consistent op naming and to
+    # ensure all inputs are passed as kwargs.
+    if use_acc_normalization:
+        acc_normalizer.normalize(traced)
+
+    traced.recompile()
+
+    return traced
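As an end-to-end sketch of the intended entry point (the module, shapes, and messages below are purely illustrative, not a test from this patch):

    import torch
    import torch.nn as nn
    import torch.fx.experimental.fx_acc.acc_tracer as acc_tracer

    class MyModule(nn.Module):
        def forward(self, x):
            assert x.dim() == 2, "expected a 2D input"
            if x.shape[0] == 0:
                raise ValueError("empty batch")
            return torch.relu(x) + 1

    mod = MyModule().eval()
    # MyModule is added to the AST-rewrite allow list so its assert/raise are
    # rewritten; both are then stripped since remove_assertions and
    # remove_exceptions default to True.
    traced = acc_tracer.trace(
        mod, [torch.randn(4, 8)], ast_rewriter_allow_list={MyModule}
    )
    print(traced.graph)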
diff --git a/torch/fx/experimental/fx_acc/acc_utils.py b/torch/fx/experimental/fx_acc/acc_utils.py
new file mode 100644 (file)
index 0000000..a77f3fa
--- /dev/null
@@ -0,0 +1,91 @@
+import inspect
+import json
+from typing import Any, Tuple, Callable, Union, Dict
+
+import torch
+import torch.fx
+from torch.fx.experimental.graph_manipulation import (
+    serialize_module,
+)
+from torch.fx.graph_module import GraphModule
+from torch.fx.passes import graph_drawer
+from torch.fx.passes.shape_prop import TensorMetadata
+
+
+def is_acc_op(node_or_target: Union[Callable, torch.fx.Node]) -> bool:
+    """
+    Returns whether `node_or_target` is an acc_op. If it's a node, checks whether it
+    is a call_function whose target is from the acc_ops module. Otherwise it's already
+    the target, which is similarly checked to see if it's from the acc_ops module.
+    """
+    if isinstance(node_or_target, torch.fx.Node):
+        # All acc_ops are call_functions.
+        if node_or_target.op != "call_function":
+            return False
+        target = node_or_target.target
+    else:
+        target = node_or_target
+    return "acc_ops" in target.__module__
+
+
+def is_acc_op_with_kwarg(
+    node_or_target: Union[Callable, torch.fx.Node], kwarg: str
+) -> bool:
+    """
+    Helper that inspects `node_or_target` and returns whether it is an acc_op node
+    (or a target for an acc_op) that has an arg signature that includes `kwarg`.
+    """
+    if not is_acc_op(node_or_target):
+        return False
+
+    target = (
+        node_or_target.target
+        if isinstance(node_or_target, torch.fx.Node)
+        else node_or_target
+    )
+    assert not isinstance(target, str)
+    return kwarg in inspect.signature(inspect.unwrap(target)).parameters
+
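As a hedged example (the kwarg name here is only illustrative, though "acc_out_ty" is also used by the helper below), this lets passes branch on whether a node's acc_op actually declares a given keyword:

    if is_acc_op_with_kwarg(node, "acc_out_ty"):
        out_ty = node.kwargs.get("acc_out_ty")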
+
+def get_field_from_acc_out_ty(
+    acc_out_ty_or_dict: Union[Tuple, Dict[str, Any]], field: str
+):
+    """
+    After tracing, NamedTuple inputs are converted to standard tuples, so we cannot
+    access their fields by name directly. Use this helper instead.
+    """
+    if isinstance(acc_out_ty_or_dict, dict):
+        acc_out_ty = acc_out_ty_or_dict["acc_out_ty"]
+    else:
+        acc_out_ty = acc_out_ty_or_dict
+    return acc_out_ty[TensorMetadata._fields.index(field)]
+
+
+def serialize_module_json_to_file(fx_module: GraphModule, fname: str):
+    weights: Dict = {}
+    serialized_json = json.dumps(serialize_module(fx_module, weights), indent=2)
+    with open(fname, "w") as ofile:
+        ofile.write(serialized_json)
+
+
+def build_raw_tensor_meta(
+    shape=None,
+    dtype=None,
+    requires_grad=None,
+    stride=None,
+    memory_format=None,
+    is_quantized=None,
+    qscheme=None,
+    q_scale=None,
+    q_zero_point=None,
+):
+    return TensorMetadata(**locals())
+
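A small sketch of how the two helpers above compose (assuming TensorMetadata's fields match the keyword arguments of build_raw_tensor_meta):

    meta = build_raw_tensor_meta(shape=(2, 3), dtype=torch.float32)
    get_field_from_acc_out_ty(meta, "dtype")                  # torch.float32
    get_field_from_acc_out_ty({"acc_out_ty": meta}, "shape")  # (2, 3)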
+
+def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph"):
+    if not fname.endswith(".svg"):
+        fname = fname + ".svg"
+    print(f"Writing FX graph to file: {fname}")
+    g = graph_drawer.FxGraphDrawer(traced, figname)
+    x = g.get_main_dot_graph()
+    x.write_svg(fname)
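A hedged sketch of using this file's helpers on a GraphModule produced by acc_tracer.trace (the `traced` name is illustrative; draw_graph additionally needs pydot/graphviz installed):

    acc_nodes = [n for n in traced.graph.nodes if is_acc_op(n)]
    serialize_module_json_to_file(traced, "traced_module.json")
    draw_graph(traced, "traced_module")  # writes traced_module.svg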
index 7f96fcf..9d0af53 100644 (file)
@@ -372,7 +372,7 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D
                 user_targets = {
                     _get_qualified_name(
                         n.target
-                    ).replace("glow.fb.fx.oss_acc_tracer.", "").replace("glow.fb.fx.", ""): n
+                    ).replace("torch.fx.experimental.fx_acc.", "").replace("glow.fb.fx.", ""): n
                     for n in node.users.keys()
                 }
                 if (