node_to_partition_id = {}
partition_to_logical_devices = {}
count = 0
- GraphManipulation.get_size_of_all_nodes(traced, [a])
+ graph_manipulation.get_size_of_all_nodes(traced, [a])
for node in traced.graph.nodes:
if node.op not in {"placeholder", "get_attr", "output"}:
node_to_partition_id[node] = count
from .mul import * # noqa: F403
from .transformation import * # noqa: F403
from .quantization import * # noqa: F403
+from .acc_ops_converters import * # noqa: F403
--- /dev/null
+# type: ignore[attr-defined]
+import math
+import operator
+
+import torch.fx.experimental.fx_acc.acc_ops as acc_ops
+import torch.fx.experimental.fx_acc.acc_utils as acc_utils
+import numpy as np
+import tensorrt as trt
+import torch
+from torch.fx.experimental.fx2trt.fx2trt import (
+ tensorrt_converter,
+ torch_dtype_from_trt,
+ get_dynamic_dims,
+)
+
+
def to_numpy(tensor: torch.Tensor):
    """
    Return the contents of ``tensor`` as a Numpy array.

    ``None`` is passed through unchanged. A quantized tensor is dequantized
    first; the tensor is then moved to CPU, detached and made contiguous
    before being converted.
    """
    if tensor is None:
        return None

    dense = tensor.dequantize() if tensor.is_quantized else tensor
    return dense.cpu().detach().contiguous().numpy()
+
+
def has_dynamic_shape(shape):
    """Tell whether any entry of ``shape`` is the dynamic marker -1."""
    for extent in shape:
        if extent == -1:
            return True
    return False
+
+
def get_axes_for_reduce_op(dim, has_implicit_batch_dimension):
    """
    Build the axes bitmask used by reduce layers.

    Args:
        dim: a single int or an iterable of ints naming the dimensions to
            reduce over. Dimensions must be non-negative; callers are
            expected to normalize negative dims beforehand because the rank
            is not known here.
        has_implicit_batch_dimension: when True, every dim is shifted down by
            one (the batch dim is not part of the tensor shape) and reducing
            over dim 0 is rejected.

    Returns:
        An int in which bit ``d`` (after the batch shift) is set for each
        requested dimension.

    Raises:
        AssertionError: dim 0 requested while the batch dim is implicit.
        ValueError: a negative dimension was passed.
    """
    # Materialize once so one-shot iterables survive the membership check below.
    dim = (dim,) if isinstance(dim, int) else tuple(dim)

    if has_implicit_batch_dimension:
        assert 0 not in dim, "Can't reduce over batch dimension when it's implicit."

    offset = 1 if has_implicit_batch_dimension else 0
    axes = 0
    for d in dim:
        if d < 0:
            # A negative shift would raise an opaque "negative shift count".
            raise ValueError(
                f"Negative reduce dim {d} must be normalized before building axes."
            )
        axes |= 1 << (d - offset)

    return axes
+
+
def create_constant(network, tensor, name):
    """
    Add a constant layer holding ``tensor`` to ``network`` and return its output.

    Python ints and floats are first wrapped into one-element tensors. Leading
    dimensions of size 1 are dropped from the declared shape (they can be
    re-inserted later during broadcasting); a shape made only of 1s, or an
    empty shape, is declared as ``(1,)``. The layer is named ``name``.
    """
    if isinstance(tensor, int):
        tensor = torch.IntTensor([tensor])

    if isinstance(tensor, float):
        tensor = torch.Tensor([tensor])

    full_shape = tuple(tensor.shape)

    # Count how many dimensions at the front have extent 1.
    leading_ones = 0
    for extent in full_shape:
        if int(extent) != 1:
            break
        leading_ones += 1

    if leading_ones < len(full_shape):
        declared_shape = full_shape[leading_ones:]
    else:
        # Nothing but 1s (or no dims at all): keep a single dimension.
        declared_shape = (1,)

    const_layer = network.add_constant(declared_shape, to_numpy(tensor))
    const_layer.name = name
    return const_layer.get_output(0)
+
+
def get_trt_tensor(network, input_val, name):
    """
    Return ``input_val`` as a TensorRT tensor.

    Torch tensors, ints and floats are turned into constants named ``name``;
    an existing ``ITensor`` is returned untouched; anything else raises
    ``RuntimeError``.
    """
    if isinstance(input_val, (torch.Tensor, int, float)):
        return create_constant(network, input_val, name)

    if isinstance(input_val, trt.tensorrt.ITensor):
        return input_val

    raise RuntimeError(
        f"Received input {input_val} of name {name} that "
        "is not part of the TensorRT region!"
    )
+
+
def append_ones(network, input, name, num_prepend_ones):
    """
    Reshape ``input`` with a shuffle layer so that ``num_prepend_ones``
    dimensions of size 1 sit in front of its existing dimensions.

    When the input shape contains a dynamic (-1) entry the target shape is
    assembled inside the network by concatenating a constant of ones with the
    output of a shape layer; otherwise the shape is set statically.

    Returns the output of the shuffle layer, which is named ``name``.
    """
    shuffle = network.add_shuffle(input)

    if not has_dynamic_shape(input.shape):
        shuffle.reshape_dims = (1,) * num_prepend_ones + tuple(input.shape)
    else:
        orig_shape = network.add_shape(input)
        orig_shape.name = f"{name}_broadcast_orig_shape"

        ones = network.add_constant(
            (num_prepend_ones,), np.ones((num_prepend_ones,), dtype=np.int32)
        )
        ones.name = f"{name}_broadcast_prepend_ones"

        target_shape = network.add_concatenation(
            [ones.get_output(0), orig_shape.get_output(0)]
        )
        target_shape.axis = 0
        target_shape.name = f"{name}_broadcast_final_shape"

        # Second shuffle input supplies the reshape dimensions at runtime.
        shuffle.set_input(1, target_shape.get_output(0))

    shuffle.name = name
    return shuffle.get_output(0)
+
+
def broadcast(network, a, b, a_name, b_name, preset_diff=0):
    """
    Equalize the ranks of ``a`` and ``b`` by prepending size-1 dimensions to
    the lower-rank tensor.

    ``preset_diff`` is subtracted from ``rank(a) - rank(b)`` before deciding
    which side to pad. Returns the (possibly reshaped) pair ``(a, b)``.
    """
    rank_gap = len(tuple(a.shape)) - len(tuple(b.shape)) - preset_diff

    if rank_gap > 0:
        return a, append_ones(network, b, f"{b_name}_broadcast", rank_gap)
    if rank_gap < 0:
        return append_ones(network, a, f"{a_name}_broadcast", -rank_gap), b
    return a, b
+
+
def add_binary_elementwise_layer(network, lhs_val, rhs_val, op_type, name):
    """
    Add an elementwise layer of kind ``op_type`` over two operands.

    Each operand is first converted to a TensorRT tensor (constants are
    created for torch tensors and Python scalars) and the two are brought to
    the same rank. Returns the output of the layer, which is named ``name``.
    """
    lhs_name, rhs_name = f"{name}_lhs", f"{name}_rhs"
    lhs = get_trt_tensor(network, lhs_val, lhs_name)
    rhs = get_trt_tensor(network, rhs_val, rhs_name)
    lhs, rhs = broadcast(network, lhs, rhs, lhs_name, rhs_name)

    elementwise = network.add_elementwise(lhs, rhs, op_type)
    elementwise.name = name
    return elementwise.get_output(0)
+
+
def add_unary_layer(network, input_val, operation_type, name):
    """
    Add a unary layer applying ``operation_type`` to ``input_val``.

    Raises ``RuntimeError`` if ``input_val`` is not a TensorRT ``ITensor``.
    Returns the output of the layer, which is named ``name``.
    """
    if isinstance(input_val, trt.tensorrt.ITensor):
        unary = network.add_unary(input_val, operation_type)
        unary.name = name
        return unary.get_output(0)

    raise RuntimeError(
        f"{operation_type} received input {input_val} that is not part "
        "of the TensorRT region!"
    )
+
+
def add_activation_layer(network, input_val, operation_type, name):
    """
    Add an activation layer of kind ``operation_type`` on ``input_val``.

    Raises ``RuntimeError`` if ``input_val`` is not a TensorRT ``ITensor``.
    Returns the output of the layer, which is named ``name``.
    """
    if isinstance(input_val, trt.tensorrt.ITensor):
        activation = network.add_activation(input_val, operation_type)
        activation.name = name
        return activation.get_output(0)

    raise RuntimeError(
        f"{operation_type} received input {input_val} that is not part "
        "of the TensorRT region!"
    )
+
+
def add_transpose_layer(
    network, input_val, dim_0, dim_1, name, ignore_implicit_batch=False
):
    """Adds a transpose layer to the TensorRT network
    Args:
        network: TensorRT Network object
        input_val: tensorrt.ITensor
        dim_0, dim_1: dimensions for transpose, e.g. dim_0=1, dim_1=0 means transpose
        the first two dimensions. Negative values count from the last dimension.
        name: Name of the layer
        ignore_implicit_batch: activations might have implicit batch, but weights do
        not, when this is True, we'll ignore the implicit batch and use the dimension
        argument as is
    Returns:
        output TensorRT ITensor from the transpose layer
    """
    num_dims = len(input_val.shape)
    batch_is_implicit = (
        not ignore_implicit_batch and network.has_implicit_batch_dimension
    )
    # Rank as seen by the caller: includes the batch dim when it is implicit.
    full_rank = num_dims + (1 if batch_is_implicit else 0)

    # Resolve negative dims first; otherwise they would be shifted by the
    # batch offset or written into the permutation as negative entries.
    if dim_0 < 0:
        dim_0 += full_rank
    if dim_1 < 0:
        dim_1 += full_rank

    if batch_is_implicit:
        assert (
            dim_0 != 0 and dim_1 != 0
        ), "It's not allowed to call transpose on non-constant when batch dim is implicit!"
        dim_0 -= 1
        dim_1 -= 1

    # Identity permutation with the two requested positions swapped.
    permutation = list(range(num_dims))
    permutation[dim_0] = dim_1
    permutation[dim_1] = dim_0

    layer = network.add_shuffle(input_val)
    layer.second_transpose = tuple(permutation)
    layer.name = name
    return layer.get_output(0)
+
+
def process_attr(val, num_elem):
    """
    Normalize an attribute to a tuple.

    A tuple or list is returned as a tuple of its own elements; any other
    value (e.g. a single int) is repeated ``num_elem`` times.
    """
    if isinstance(val, (tuple, list)):
        # Lists previously fell through and were repeated as a whole.
        return tuple(val)
    return (val,) * num_elem
+
+
@tensorrt_converter(acc_ops.conv2d)
def acc_ops_conv2d(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.conv2d`` into a TensorRT convolution layer.

    Reads ``input``, ``weight``, ``bias``, ``stride``, ``padding``,
    ``dilation`` and ``groups`` from ``kwargs``. ``stride``, ``padding`` and
    ``dilation`` may each be a single int (applied to both spatial dims) or a
    two-element sequence.

    Raises:
        RuntimeError: ``input`` is not a TensorRT ``ITensor``.
        AssertionError: the input is dynamic and its dim 1 is -1.
    """
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"Conv2d received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    if has_dynamic_shape(input_val.shape):
        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."

    def as_pair(value):
        # Mirror the pooling converters: a bare int covers both spatial dims.
        return (value, value) if isinstance(value, int) else tuple(value)

    kernel = to_numpy(kwargs["weight"])
    bias = to_numpy(kwargs["bias"])

    layer = network.add_convolution(
        input=input_val,
        num_output_maps=kernel.shape[0],
        kernel_shape=kernel.shape[2:],
        kernel=kernel,
        bias=bias,
    )

    layer.name = name
    layer.stride = as_pair(kwargs["stride"])
    layer.padding = as_pair(kwargs["padding"])
    layer.dilation = as_pair(kwargs["dilation"])
    if kwargs["groups"] is not None:
        layer.num_groups = kwargs["groups"]

    return layer.get_output(0)
+
+
@tensorrt_converter(acc_ops.flatten)
def acc_ops_flatten(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.flatten`` into a shuffle (reshape) layer.

    ``start_dim`` defaults to 0 and ``end_dim`` to -1 when absent from
    ``kwargs``; both are normalized modulo the full rank (which counts the
    batch dim when it is implicit). Flattening the implicit batch dim is
    rejected with an assertion.
    """
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"flatten received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    # Full rank includes the batch dim, which is absent from input_val.shape
    # in implicit batch mode.
    num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
    start_dim = (kwargs["start_dim"] if "start_dim" in kwargs else 0) % num_dims
    end_dim = (kwargs["end_dim"] if "end_dim" in kwargs else -1) % num_dims

    if network.has_implicit_batch_dimension:
        assert start_dim != 0, "Can't flatten batch dimension when it's implicit."
        # Re-index both dims relative to the shape without the batch dim.
        start_dim -= 1
        end_dim -= 1

    layer = network.add_shuffle(input_val)
    layer.name = name

    # If there're dynamic shapes then we need to use shape layers
    # to figure out the final shape after flatten. We first slice
    # the input shape to three parts:
    #   1. dimensions before start_dim
    #   2. dimensions between start_dim and end_dim
    #   3. dimensions after end_dim
    # Part 1 and 3 might not exist if start_dim is 0 or end_dim is
    # last dim. Then we do a reduced multiplication over part 2 to
    # get flattened dim. Finally, we concatenate the three parts to
    # get the final shape.
    if has_dynamic_shape(input_val.shape):
        input_shape_layer = network.add_shape(input_val)
        input_shape_layer.name = f"{name}_orig_shape"

        final_shapes = []

        # Shapes before start_dim
        if start_dim > 0:
            prefix_shape_layer = network.add_slice(
                input_shape_layer.get_output(0),
                start=(0,),
                shape=(start_dim,),
                stride=(1,),
            )
            prefix_shape_layer.name = f"{name}_pre_shape"
            final_shapes.append(prefix_shape_layer.get_output(0))

        # Shapes in [start_dim, end_dim], multiplied together into one dim
        flatten_shape_layer = network.add_slice(
            input_shape_layer.get_output(0),
            start=(start_dim,),
            shape=(end_dim - start_dim + 1,),
            stride=(1,),
        )
        flatten_shape_layer.name = f"{name}_need_flatten"
        flatten_shape_layer = network.add_reduce(
            flatten_shape_layer.get_output(0),
            trt.ReduceOperation.PROD,
            axes=get_axes_for_reduce_op(0, False),
            keep_dims=True,
        )
        flatten_shape_layer.name = f"{name}_flatten_dim"
        final_shapes.append(flatten_shape_layer.get_output(0))

        # Shapes after end_dim
        if end_dim < len(input_val.shape) - 1:
            suffix_shape_layer = network.add_slice(
                input_shape_layer.get_output(0),
                start=(end_dim + 1,),
                shape=(len(input_val.shape) - end_dim - 1,),
                stride=(1,),
            )
            suffix_shape_layer.name = f"{name}_suffix_shape"
            final_shapes.append(suffix_shape_layer.get_output(0))

        final_shape_layer = network.add_concatenation(final_shapes)
        final_shape_layer.axis = 0
        final_shape_layer.name = f"{name}_final_shape"
        # Second shuffle input supplies the reshape dimensions at runtime.
        layer.set_input(1, final_shape_layer.get_output(0))
    else:
        # Static shape: compute the output dims directly in Python.
        final_shape = []
        flatten_dim = 1
        for i, s in enumerate(input_val.shape):
            if i >= start_dim and i <= end_dim:
                flatten_dim *= s
            elif i == end_dim + 1:
                # First dim past the flattened range: emit the product, then it.
                final_shape.append(flatten_dim)
                final_shape.append(s)
            else:
                final_shape.append(s)
        # The flattened range reached the last dim, so nothing emitted it above.
        if end_dim == len(input_val.shape) - 1:
            final_shape.append(flatten_dim)

        layer.reshape_dims = tuple(final_shape)

    return layer.get_output(0)
+
+
# In implicit batch dim mode, this value stands in for the batch dim whenever
# it is retrieved via size(). The value is deliberately nonsensical in the
# hope that any other use of it fails hard.
IMPLICIT_BATCH_DIM = -999
+
+
@tensorrt_converter(acc_ops.size)
def acc_ops_size(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.size``.

    For a fully static input shape a ``torch.Size`` is returned directly,
    with ``IMPLICIT_BATCH_DIM`` prepended when the batch dim is implicit. For
    a dynamic shape a shape layer named ``name`` is added and its output
    returned.
    """
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"size received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    if has_dynamic_shape(input_val.shape):
        shape_layer = network.add_shape(input_val)
        shape_layer.name = name
        return shape_layer.get_output(0)

    if network.has_implicit_batch_dimension:
        return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_val.shape))
    return torch.Size(input_val.shape)
+
+
@tensorrt_converter(acc_ops.batch_norm)
def acc_ops_batch_norm(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.batch_norm`` into a per-channel scale layer.

    The normalization is folded into ``y = x * scale + shift`` with
    ``scale = weight / sqrt(running_var + eps)`` and
    ``shift = bias - running_mean * scale``. A missing (``None``) ``weight``
    is treated as all ones and a missing ``bias`` as all zeros.

    Raises:
        RuntimeError: ``input`` is not a TensorRT ``ITensor``.
        AssertionError: the input is dynamic and its dim 1 is -1.
    """
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"BatchNorm2d received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    if has_dynamic_shape(input_val.shape):
        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for batch norm."

    running_var = to_numpy(kwargs["running_var"])
    running_mean = to_numpy(kwargs["running_mean"])
    weight = to_numpy(kwargs["weight"])
    bias = to_numpy(kwargs["bias"])

    # to_numpy passes None through; substitute the identity affine parameters
    # so the arithmetic below does not fail on None.
    if weight is None:
        weight = np.ones_like(running_var)
    if bias is None:
        bias = np.zeros_like(running_var)

    scale = weight / np.sqrt(running_var + kwargs["eps"])
    shift = bias - running_mean * scale
    power = np.ones_like(scale)

    layer = network.add_scale(input_val, trt.ScaleMode.CHANNEL, shift, scale, power)
    layer.name = name

    return layer.get_output(0)
+
+
@tensorrt_converter(acc_ops.softmax)
def acc_ops_softmax(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.softmax`` into a TensorRT softmax layer.

    ``dim`` may be ``None`` (a default is then chosen from the rank) or any
    int, including negative values, which are resolved against the full rank
    (counting the batch dim when it is implicit).

    Raises:
        RuntimeError: ``input`` is not a TensorRT ``ITensor``.
        AssertionError: the resolved dim is the implicit batch dim.
    """
    input_val = kwargs["input"]
    dim = kwargs["dim"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"softmax received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
    def get_softmax_dim(ndim):
        if ndim == 0 or ndim == 1 or ndim == 3:
            ret = 0
        else:
            ret = 1
        return ret

    implicit_batch = network.has_implicit_batch_dimension
    full_rank = len(input_val.shape) + (1 if implicit_batch else 0)

    if dim is None:
        dim = get_softmax_dim(full_rank)

    # Resolve negative dims before any check or shift so that `1 << dim`
    # below never sees a negative count.
    dim %= full_rank

    if implicit_batch:
        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
        dim -= 1

    layer = network.add_softmax(input_val)
    layer.axes = 1 << dim
    layer.name = name
    return layer.get_output(0)
+
+
@tensorrt_converter(acc_ops.relu)
def acc_ops_relu(network, target, args, kwargs, name):
    """Convert ``acc_ops.relu`` into an activation layer of type RELU."""
    return add_activation_layer(
        network, kwargs["input"], trt.ActivationType.RELU, name
    )
+
+
@tensorrt_converter(acc_ops.sin)
def acc_ops_sin(network, target, args, kwargs, name):
    """Convert ``acc_ops.sin`` into a unary layer with operation SIN."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.SIN, name)
+
+
@tensorrt_converter(acc_ops.cos)
def acc_ops_cos(network, target, args, kwargs, name):
    """Convert ``acc_ops.cos`` into a unary layer with operation COS."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.COS, name)
+
+
@tensorrt_converter(acc_ops.tan)
def acc_ops_tan(network, target, args, kwargs, name):
    """Convert ``acc_ops.tan`` into a unary layer with operation TAN."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.TAN, name)
+
+
@tensorrt_converter(acc_ops.sinh)
def acc_ops_sinh(network, target, args, kwargs, name):
    """Convert ``acc_ops.sinh`` into a unary layer with operation SINH."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.SINH, name)
+
+
@tensorrt_converter(acc_ops.cosh)
def acc_ops_cosh(network, target, args, kwargs, name):
    """Convert ``acc_ops.cosh`` into a unary layer with operation COSH."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.COSH, name)
+
+
@tensorrt_converter(acc_ops.tanh)
def acc_ops_tanh(network, target, args, kwargs, name):
    """Convert ``acc_ops.tanh`` into an activation layer of type TANH."""
    return add_activation_layer(
        network, kwargs["input"], trt.ActivationType.TANH, name
    )
+
+
@tensorrt_converter(acc_ops.asin)
def acc_ops_asin(network, target, args, kwargs, name):
    """Convert ``acc_ops.asin`` into a unary layer with operation ASIN."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.ASIN, name)
+
+
@tensorrt_converter(acc_ops.acos)
def acc_ops_acos(network, target, args, kwargs, name):
    """Convert ``acc_ops.acos`` into a unary layer with operation ACOS."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.ACOS, name)
+
+
@tensorrt_converter(acc_ops.atan)
def acc_ops_atan(network, target, args, kwargs, name):
    """Convert ``acc_ops.atan`` into a unary layer with operation ATAN."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.ATAN, name)
+
+
@tensorrt_converter(acc_ops.exp)
def acc_ops_exp(network, target, args, kwargs, name):
    """Convert ``acc_ops.exp`` into a unary layer with operation EXP."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.EXP, name)
+
+
@tensorrt_converter(acc_ops.log)
def acc_ops_log(network, target, args, kwargs, name):
    """Convert ``acc_ops.log`` into a unary layer with operation LOG."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.LOG, name)
+
+
@tensorrt_converter(acc_ops.sqrt)
def acc_ops_sqrt(network, target, args, kwargs, name):
    """Convert ``acc_ops.sqrt`` into a unary layer with operation SQRT."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.SQRT, name)
+
+
@tensorrt_converter(acc_ops.reciprocal)
def acc_ops_reciprocal(network, target, args, kwargs, name):
    """Convert ``acc_ops.reciprocal`` into a unary layer with operation RECIP."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.RECIP, name)
+
+
@tensorrt_converter(acc_ops.abs)
def acc_ops_abs(network, target, args, kwargs, name):
    """Convert ``acc_ops.abs`` into a unary layer with operation ABS."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.ABS, name)
+
+
@tensorrt_converter(acc_ops.neg)
def acc_ops_neg(network, target, args, kwargs, name):
    """Convert ``acc_ops.neg`` into a unary layer with operation NEG."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.NEG, name)
+
+
@tensorrt_converter(acc_ops.floor)
def acc_ops_floor(network, target, args, kwargs, name):
    """Convert ``acc_ops.floor`` into a unary layer with operation FLOOR."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.FLOOR, name)
+
+
@tensorrt_converter(acc_ops.ceil)
def acc_ops_ceil(network, target, args, kwargs, name):
    """Convert ``acc_ops.ceil`` into a unary layer with operation CEIL."""
    return add_unary_layer(network, kwargs["input"], trt.UnaryOperation.CEIL, name)
+
+
@tensorrt_converter(acc_ops.sum)
def acc_ops_sum(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.sum`` into a reduce layer with operation SUM.

    ``dim`` (optional) is an int or an iterable of ints; negative values are
    resolved against the full rank (counting the batch dim when it is
    implicit). When ``dim`` is absent or ``None`` all elements are summed,
    which is only supported with an explicit batch dim. ``keepdim`` defaults
    to False.

    Raises:
        RuntimeError: ``input`` is not a TensorRT ``ITensor``.
        AssertionError: a full reduction, or a reduction over dim 0, was
            requested while the batch dim is implicit.
    """
    input_val = kwargs["input"]
    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"sum received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    implicit_batch = network.has_implicit_batch_dimension
    full_rank = len(input_val.shape) + (1 if implicit_batch else 0)

    # If dim is specified, then we are computing reduced sum over certain dimensions.
    # Otherwise, we are doing summation over all elements, which is only supported in
    # explicit batch dimension.
    dim = kwargs.get("dim")
    if dim is None:
        assert (
            not implicit_batch
        ), "Do not support sum all the elements for implicit batch."
        dim = tuple(range(len(input_val.shape)))
    else:
        if isinstance(dim, int):
            dim = (dim,)
        # Resolve negative dims here: the axes helper has no rank to do so.
        dim = tuple(d + full_rank if d < 0 else d for d in dim)

    keepdim = kwargs.get("keepdim", False)
    layer = network.add_reduce(
        input_val,
        trt.ReduceOperation.SUM,
        get_axes_for_reduce_op(dim, implicit_batch),
        keepdim,
    )
    layer.name = name
    return layer.get_output(0)
+
+
@tensorrt_converter(acc_ops.max_pool2d)
def acc_ops_max_pool2d(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.max_pool2d`` into a pooling layer of type MAX.

    ``kernel_size``, ``stride``, ``padding`` and ``dilation`` may be single
    values or tuples. A ``stride`` of ``None`` (or an empty sequence) falls
    back to ``kernel_size``. Only ``dilation == (1, 1)`` is supported. With
    ``ceil_mode`` the padding mode is set to EXPLICIT_ROUND_UP.

    Raises:
        RuntimeError: ``input`` is not a TensorRT ``ITensor`` or the
            dilation is not (1, 1).
    """
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"MaxPool2d received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    kernel_size = process_attr(kwargs["kernel_size"], 2)

    raw_stride = kwargs["stride"]
    stride_unset = raw_stride is None or (
        isinstance(raw_stride, (list, tuple)) and len(raw_stride) == 0
    )
    # An unset stride means "same as the window", not (None, None).
    stride = kernel_size if stride_unset else process_attr(raw_stride, 2)

    padding = process_attr(kwargs["padding"], 2)
    dilation = process_attr(kwargs["dilation"], 2)
    ceil_mode = kwargs["ceil_mode"]

    if dilation != (1, 1):
        raise RuntimeError(
            f"Only support dilation=(1, 1) for maxpool, but got {dilation}"
        )

    layer = network.add_pooling(
        input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size
    )
    layer.stride = stride
    layer.padding = padding
    layer.name = name

    if ceil_mode:
        layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

    return layer.get_output(0)
+
+
@tensorrt_converter(acc_ops.squeeze)
def acc_ops_squeeze(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.squeeze`` into a shuffle (reshape) layer.

    ``dim`` is required and may be negative; it is resolved against the full
    rank (counting the batch dim when it is implicit). The dim is dropped
    from the output shape only if its size is 1.

    Raises:
        RuntimeError: ``input`` is not a TensorRT ``ITensor``.
        AssertionError: ``dim`` is missing/None, names the implicit batch
            dim, names a dynamic dim, or the input has more than one dynamic
            dim.
    """
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"squeeze received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    dim = kwargs["dim"] if "dim" in kwargs else None
    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
    # dim, which is a very rare case. For now we just claim not supporting dim=None.
    assert dim is not None, "We don't support dim=None right now."

    implicit_batch = network.has_implicit_batch_dimension
    if dim < 0:
        # Resolve against the full rank so the comparison with the
        # non-negative enumerate index below can match.
        dim += len(input_val.shape) + (1 if implicit_batch else 0)

    if implicit_batch:
        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
        dim -= 1

    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
    assert (
        len(get_dynamic_dims(input_val.shape)) <= 1
    ), "Currently more than one dynamic dim for input to squeeze is not supported."

    output_shape = [
        s for i, s in enumerate(input_val.shape) if not (i == dim and s == 1)
    ]
    layer = network.add_shuffle(input_val)
    layer.reshape_dims = tuple(output_shape)
    layer.name = name
    return layer.get_output(0)
+
+
@tensorrt_converter(acc_ops.add)
def acc_ops_add(network, target, args, kwargs, name):
    """Convert ``acc_ops.add`` into an elementwise SUM of ``input`` and ``other``."""
    lhs, rhs = kwargs["input"], kwargs["other"]
    return add_binary_elementwise_layer(
        network, lhs, rhs, trt.ElementWiseOperation.SUM, name
    )
+
+
@tensorrt_converter(acc_ops.sub)
def acc_ops_sub(network, target, args, kwargs, name):
    """Convert ``acc_ops.sub`` into an elementwise SUB of ``input`` and ``other``."""
    lhs, rhs = kwargs["input"], kwargs["other"]
    return add_binary_elementwise_layer(
        network, lhs, rhs, trt.ElementWiseOperation.SUB, name
    )
+
+
@tensorrt_converter(acc_ops.div)
def acc_ops_div(network, target, args, kwargs, name):
    """Convert ``acc_ops.div`` into an elementwise DIV of ``input`` by ``other``."""
    lhs, rhs = kwargs["input"], kwargs["other"]
    return add_binary_elementwise_layer(
        network, lhs, rhs, trt.ElementWiseOperation.DIV, name
    )
+
+
@tensorrt_converter(acc_ops.mul)
def acc_ops_mul(network, target, args, kwargs, name):
    """Convert ``acc_ops.mul`` into an elementwise PROD of ``input`` and ``other``."""
    lhs, rhs = kwargs["input"], kwargs["other"]
    return add_binary_elementwise_layer(
        network, lhs, rhs, trt.ElementWiseOperation.PROD, name
    )
+
+
@tensorrt_converter(acc_ops.min_two_tensors_input)
def acc_ops_min_two_tensors_input(network, target, args, kwargs, name):
    """Convert the two-tensor min op into an elementwise MIN of ``input`` and ``other``."""
    lhs, rhs = kwargs["input"], kwargs["other"]
    return add_binary_elementwise_layer(
        network, lhs, rhs, trt.ElementWiseOperation.MIN, name
    )
+
+
@tensorrt_converter(acc_ops.adaptive_avg_pool2d)
def acc_ops_adaptive_avg_pool2d(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.adaptive_avg_pool2d`` into a pooling layer of type AVERAGE.

    The last two input dims must be static and each must be an integer
    multiple of the matching ``output_size`` entry. ``output_size`` may be a
    single int (used for both dims) or a two-element sequence. Stride is
    ``input // output`` and the window is ``input - (output - 1) * stride``.

    Raises:
        RuntimeError: ``input`` is not a TensorRT ``ITensor`` or an input dim
            is not a multiple of the output dim.
        AssertionError: either of the last two input dims is dynamic.
    """
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"AdaptiveAvgPool2d received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    # Both spatial dims must be checked (the last one was checked twice before).
    assert (
        input_val.shape[-2] != -1 and input_val.shape[-1] != -1
    ), "AdaptiveAvgPool2d currently doesn't support dynamic shapes for last two dims."

    output_size = kwargs["output_size"]
    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    for input_dim, output_dim in zip(input_val.shape[-2:], output_size):
        if input_dim % output_dim != 0:
            raise RuntimeError(
                "For AdaptiveAvgPool, input dim has to be integer multiple of output dim. "
                f"Got input dim {input_dim}, output dim {output_dim}"
            )

    stride = (
        input_val.shape[-2] // output_size[0],
        input_val.shape[-1] // output_size[1],
    )
    kernel_size = (
        input_val.shape[-2] - (output_size[0] - 1) * stride[0],
        input_val.shape[-1] - (output_size[1] - 1) * stride[1],
    )
    layer = network.add_pooling(
        input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
    )
    layer.stride = stride
    layer.name = name

    return layer.get_output(0)
+
+
@tensorrt_converter(acc_ops.avg_pool2d)
def acc_ops_avg_pool2d(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.avg_pool2d`` into a pooling layer of type AVERAGE.

    ``kernel_size``, ``stride`` and ``padding`` may be single values or
    tuples. A ``stride`` of ``None`` (or an empty sequence) falls back to
    ``kernel_size``. ``count_include_pad`` is mapped to the inverse of
    ``average_count_excludes_padding``; with ``ceil_mode`` the padding mode
    is set to EXPLICIT_ROUND_UP. The layer is named ``name``.

    Raises:
        RuntimeError: ``input`` is not a TensorRT ``ITensor`` or
            ``divisor_override`` is set.
    """
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"AvgPool2d received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    kernel_size = process_attr(kwargs["kernel_size"], 2)

    raw_stride = kwargs["stride"]
    stride_unset = raw_stride is None or (
        isinstance(raw_stride, (list, tuple)) and len(raw_stride) == 0
    )
    # An unset stride means "same as the window", not (None, None).
    stride = kernel_size if stride_unset else process_attr(raw_stride, 2)

    padding = process_attr(kwargs["padding"], 2)
    ceil_mode = kwargs["ceil_mode"]
    count_include_pad = kwargs["count_include_pad"]
    divisor_override = kwargs["divisor_override"]

    if divisor_override:
        raise RuntimeError("TensorRT does not support divisor_override.")

    layer = network.add_pooling(
        input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
    )
    layer.stride = stride
    layer.padding = padding
    layer.average_count_excludes_padding = not count_include_pad
    # Name the layer like every other converter does; it was left unnamed.
    layer.name = name

    if ceil_mode:
        layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

    return layer.get_output(0)
+
+
@tensorrt_converter(acc_ops.reshape)
def acc_ops_reshape(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.reshape`` into a shuffle layer.

    The target shape is read from the ``shape`` field of ``acc_out_ty``; its
    first entry is dropped when the batch dim is implicit. If every entry is
    an int the shape is set statically; otherwise each entry is turned into a
    TensorRT tensor and their concatenation is fed to the shuffle layer as
    its second input.
    """
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"Reshape received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    target_shape = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "shape")
    if network.has_implicit_batch_dimension:
        target_shape = target_shape[1:]

    shuffle = network.add_shuffle(input_val)

    if all(isinstance(extent, int) for extent in target_shape):
        shuffle.reshape_dims = tuple(target_shape)
    else:
        # Convert all the dimensions to trt Tensors.
        dim_tensors = []

        for idx, extent in enumerate(target_shape):
            piece_name = f"{name}_{idx}"
            if not isinstance(extent, trt.tensorrt.ITensor):
                dim_tensors.append(get_trt_tensor(network, extent, piece_name))
                continue
            if len(extent.shape) == 0:
                # Give rank-0 tensors one dimension so they can be concatenated.
                extent = append_ones(network, extent, piece_name, 1)
            dim_tensors.append(extent)

        concat = network.add_concatenation(inputs=dim_tensors)
        concat.axis = 0
        concat.name = f"{name}_output_shape"
        shuffle.set_input(1, concat.get_output(0))

    shuffle.name = name
    return shuffle.get_output(0)
+
+
@tensorrt_converter(acc_ops.linear)
def acc_ops_linear(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.linear`` into shuffle -> fully connected -> shuffle.

    The input is reshaped with two trailing size-1 dims, passed through a
    fully connected layer built from ``weight`` and ``bias``, and reshaped
    back so that the last dim equals ``weight.shape[0]``. At most one dynamic
    dim is accepted and it must not be the last one.
    """
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"Linear received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    dynamic_dims = get_dynamic_dims(input_val.shape)
    assert len(dynamic_dims) < 2 and input_val.shape[-1] != -1, (
        "Currently we only support one dynmaic "
        "dim for linear and it can't be the last dim."
    )

    weight = kwargs["weight"]
    out_features = weight.shape[0]
    in_shape = tuple(input_val.shape)

    # Add two trailing unit dims in front of the fully connected layer.
    pre_shuffle = network.add_shuffle(input_val)
    pre_shuffle.reshape_dims = in_shape + (1, 1)
    pre_shuffle.name = f"{name}_pre_shuffle"

    fully_connected = network.add_fully_connected(
        input=pre_shuffle.get_output(0),
        num_outputs=out_features,
        kernel=to_numpy(weight),
        bias=to_numpy(kwargs["bias"]),
    )
    fully_connected.name = f"{name}_linear"

    # Drop the two unit dims again; the last dim is now the output feature count.
    post_shuffle = network.add_shuffle(fully_connected.get_output(0))
    post_shuffle.reshape_dims = tuple(input_val.shape[:-1]) + (out_features,)
    post_shuffle.name = f"{name}_post_shuffle"

    return post_shuffle.get_output(0)
+
+
def add_clamp(network, input, val, op):
    """
    Combine ``input`` elementwise (using ``op``) with a constant equal to ``val``.

    The constant has the same rank as ``input`` with every dim of size 1, and
    is built with the torch dtype that corresponds to ``input.dtype``.
    Returns the elementwise layer itself (not its output tensor).
    """
    unit_shape = (1,) * len(input.shape)  # broadcast all dimensions
    ones = torch.ones(unit_shape, dtype=torch_dtype_from_trt(input.dtype))
    bound = val * ones.cpu().numpy()

    bound_const = network.add_constant(unit_shape, bound)
    return network.add_elementwise(input, bound_const.get_output(0), op)
+
+
@tensorrt_converter(acc_ops.clamp)
def acc_ops_clamp(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.clamp``.

    A non-None ``min`` is applied first as an elementwise MAX, then a
    non-None ``max`` as an elementwise MIN on the result. If both bounds are
    None the input tensor is returned unchanged.
    """
    input_val = kwargs["input"]
    min_val = kwargs["min"]
    max_val = kwargs["max"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"Clamp received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    # (bound, op that enforces it, layer-name suffix), applied in this order.
    stages = (
        (min_val, trt.ElementWiseOperation.MAX, "_clamp_min"),
        (max_val, trt.ElementWiseOperation.MIN, "_clamp_max"),
    )

    result = input_val
    for bound, op, suffix in stages:
        if bound is None:
            continue
        clamp_layer = add_clamp(network, result, bound, op)
        clamp_layer.name = f"{name}{suffix}"
        result = clamp_layer.get_output(0)

    return result
+
+
@tensorrt_converter(acc_ops.getitem)
def acc_ops_getitem(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.getitem``.

    If ``input`` is not a TensorRT ``ITensor`` the subscript is evaluated
    directly with ``operator.getitem``. Otherwise a slice layer is added and,
    when the subscript contains ints or ``None``, a shuffle layer follows to
    drop the int-indexed dims and insert size-1 dims for ``None``.

    Slice bounds follow Python semantics: negative and out-of-range bounds
    are resolved/clamped against the dim size.

    Raises:
        AssertionError: the input has a dynamic shape.
        RuntimeError: the implicit batch dim is subscripted with anything
            other than ``slice(None, None, None)``.
    """
    input_val = kwargs["input"]
    slices = kwargs["idx"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        return operator.getitem(input_val, slices)

    assert not has_dynamic_shape(
        input_val.shape
    ), "Currently we don't support slicing tensor if it has dynamic shape."

    def num_slice_types(slices):
        """
        Gather the number of slice in getitem slices.
        """
        return sum(1 for s in slices if isinstance(s, (slice, int)))

    def slice_to_trt_params(py_slice, dim_size):
        """
        Convert python slice to TensorRT slice layer parameters.
        """
        # slice.indices resolves None/negative bounds and clamps to the dim,
        # so stop == dim_size no longer wraps around to 0.
        start, stop, stride = py_slice.indices(dim_size)
        size = len(range(start, stop, stride))
        return start, size, stride

    if not isinstance(slices, tuple):
        slices = (slices,)

    if network.has_implicit_batch_dimension:
        # Raise an error if it's trying to subscript batch dimension unless it's
        # slice(None, None, None).
        batch_subscript = slices[0]
        if batch_subscript != slice(None, None, None):
            raise RuntimeError(
                f"Can't subscript batch dimension when it's implicit. Got {slices}"
            )

        # Remove batch_dim subscript
        slices = slices[1:]

    # Replace ellipsis with expanded slices.
    # Compute the number of dim ellipsis represent.
    num_ellipsis = len(input_val.shape) - num_slice_types(slices)
    new_slices = []
    for s in slices:
        if s is Ellipsis:
            new_slices.extend([slice(None, None, None)] * num_ellipsis)
            num_ellipsis = 0
        else:
            new_slices.append(s)
    slices = new_slices

    # Build trt slice layer params
    start = []
    size = []
    stride = []

    i = 0
    for s in slices:
        # None inserts a dim later and consumes no input dim.
        if s is None:
            continue

        if isinstance(s, slice):
            params = slice_to_trt_params(s, input_val.shape[i])
            start.append(params[0])
            size.append(params[1])
            stride.append(params[2])
        else:
            # Integer index: take a single element; the dim is removed below.
            start.append(s % input_val.shape[i])
            size.append(1)
            stride.append(1)
        i += 1

    # Dims not mentioned by the subscript are taken in full.
    while i < len(input_val.shape):
        start.append(0)
        size.append(input_val.shape[i])
        stride.append(1)
        i += 1

    layer = network.add_slice(
        input=input_val,
        start=start,
        shape=size,
        stride=stride,
    )
    layer.name = name

    # Add shuffle layer to insert dimensions for 'None' and remove dimensions for 'int'.
    if any(not isinstance(s, slice) for s in slices):
        slice_out = layer.get_output(0)
        layer = network.add_shuffle(slice_out)
        final_shape = []
        original_idx = 0
        for s in slices:
            # If it's a slice, keep the dim.
            if isinstance(s, slice):
                final_shape.append(slice_out.shape[original_idx])
                original_idx += 1
            # If it's None, extend the dim.
            elif s is None:
                final_shape.append(1)
            # If it's a int, remove the dim.
            else:
                original_idx += 1
        layer.reshape_dims = tuple(final_shape) + tuple(slice_out.shape)[original_idx:]

    return layer.get_output(0)
+
+
@tensorrt_converter(acc_ops.cat)
def acc_ops_cat(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.cat`` into a concatenation layer.

    ``dim`` may be negative; it is resolved against the full rank of the
    first tensor (counting the batch dim when it is implicit) before being
    shifted by the implicit batch offset.

    Raises:
        RuntimeError: any entry of ``tensors`` is not a TensorRT ``ITensor``.
        AssertionError: the resolved dim is the implicit batch dim.
    """
    tensors = kwargs["tensors"]

    if any(not isinstance(t, trt.tensorrt.ITensor) for t in tensors):
        raise RuntimeError(
            f"cat received inputs {tensors} that is not part " "of the TensorRT region!"
        )

    batch_offset = 1 if network.has_implicit_batch_dimension else 0
    dim = kwargs["dim"]
    if dim < 0:
        # Without this, e.g. dim=-1 would become -2 after the batch shift.
        dim += len(tensors[0].shape) + batch_offset
    if batch_offset:
        assert dim != 0, "Can't concatenate along batch dimension when it's implicit."

    layer = network.add_concatenation(inputs=tensors)
    layer.axis = dim - batch_offset
    layer.name = name
    return layer.get_output(0)
+
+
@tensorrt_converter(acc_ops.transpose)
def acc_ops_transpose(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.transpose``.

    A ``torch.Tensor`` input is transposed eagerly and returned as a
    contiguous tensor; a TensorRT ``ITensor`` input gets a transpose layer.
    Any other input raises ``RuntimeError``.
    """
    input_val = kwargs["input"]
    dim_0 = kwargs["dim0"]
    dim_1 = kwargs["dim1"]

    # TODO: Remove this after enabling const folding in fx_acc
    if isinstance(input_val, torch.Tensor):
        return input_val.transpose(dim_0, dim_1).contiguous()

    if isinstance(input_val, trt.tensorrt.ITensor):
        return add_transpose_layer(network, input_val, dim_0, dim_1, name)

    raise RuntimeError(
        f"transpose received input {input_val} that is not part "
        "of the TensorRT region!"
    )
+
+
@tensorrt_converter(acc_ops.matmul)
def acc_ops_matmul(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.matmul`` into a matrix multiply layer.

    Both operands are converted to TensorRT tensors. A rank-1 operand is
    flagged as VECTOR and the rank difference used for broadcasting is
    adjusted accordingly; the other operand is padded with leading size-1
    dims as needed.
    """
    lhs_name, rhs_name = f"{name}_input", f"{name}_other"
    input_val = get_trt_tensor(network, kwargs["input"], lhs_name)
    other_val = get_trt_tensor(network, kwargs["other"], rhs_name)

    for operand in (input_val, other_val):
        if not isinstance(operand, trt.tensorrt.ITensor):
            raise RuntimeError(
                f"matmul received input {operand} that is not part " "of the TensorRT region!"
            )

    input_matrix_op = trt.MatrixOperation.NONE
    other_matrix_op = trt.MatrixOperation.NONE
    preset_diff = 0

    # A rank-1 operand is treated as a vector and skews the expected rank gap.
    if len(input_val.shape) == 1:
        input_matrix_op = trt.MatrixOperation.VECTOR
        preset_diff -= 1

    if len(other_val.shape) == 1:
        other_matrix_op = trt.MatrixOperation.VECTOR
        preset_diff += 1

    input_val, other_val = broadcast(
        network, input_val, other_val, lhs_name, rhs_name, preset_diff
    )
    matmul_layer = network.add_matrix_multiply(
        input_val, input_matrix_op, other_val, other_matrix_op
    )
    matmul_layer.name = name
    return matmul_layer.get_output(0)
+
+
@tensorrt_converter(acc_ops.sigmoid)
def acc_ops_sigmoid(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.sigmoid`` into an activation layer of type SIGMOID.

    Raises ``RuntimeError`` if ``input`` is not a TensorRT ``ITensor``.
    """
    input_val = kwargs["input"]

    if isinstance(input_val, trt.tensorrt.ITensor):
        sigmoid = network.add_activation(
            input=input_val, type=trt.ActivationType.SIGMOID
        )
        sigmoid.name = name
        return sigmoid.get_output(0)

    raise RuntimeError(
        f"Sigmoid received input {input_val} that is not part "
        "of the TensorRT region!"
    )
+
+
@tensorrt_converter(acc_ops.permute)
def acc_ops_permute(network, target, args, kwargs, name):
    """
    Convert ``acc_ops.permute`` into a shuffle layer with a transpose.

    Negative entries of ``permutation`` are resolved against the full rank
    (counting the batch dim when it is implicit). With an implicit batch dim
    the first entry must be 0; it is dropped and the rest shifted down by one.

    Raises:
        RuntimeError: ``input`` is not a TensorRT ``ITensor``.
        AssertionError: the permutation moves the implicit batch dim.
    """
    input_val = kwargs["input"]
    permutation = kwargs["permutation"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"permute received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    implicit_batch = network.has_implicit_batch_dimension
    full_rank = len(input_val.shape) + (1 if implicit_batch else 0)
    # Resolve negative indices so the shuffle layer only sees valid positions.
    permutation = [p + full_rank if p < 0 else p for p in permutation]

    if implicit_batch:
        assert permutation[0] == 0, "Can't permute batch dimension when it's implicit."
        permutation = [i - 1 for i in permutation[1:]]

    layer = network.add_shuffle(input_val)
    layer.second_transpose = tuple(permutation)
    layer.name = name
    return layer.get_output(0)
@tensorrt_converter(operator.add)
@tensorrt_converter(torch.add)
def add(network, target, args, kwargs, layer_name):
- if len(kwargs) != 0:
- raise RuntimeError(f"Add receives unsupported kwargs: {kwargs}!")
-
- assert len(args) == 2
- if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in args):
+ # operator.add
+ if len(kwargs) == 0:
+ lhs_val, rhs_val = args
+ else:
+ # torch.add
+ lhs_val, rhs_val = kwargs["input"], kwargs["other"]
+ assert kwargs["alpha"] == 1
+
+ if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in [lhs_val, rhs_val]):
raise RuntimeError("add() received an input that is not part of the TensorRT region!")
- lhs_val, rhs_val = args
layer = network.add_elementwise(lhs_val, rhs_val, trt.ElementWiseOperation.SUM)
layer.name = layer_name
@tensorrt_converter(torch.mul)
@tensorrt_converter(operator.mul)
def mul(network, target, args, kwargs, layer_name):
- assert len(args) == 2
- if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in args):
+ # operator.mul
+ if len(kwargs) == 0:
+ lhs_val, rhs_val = args
+ else:
+ # torch.mul
+ lhs_val, rhs_val = kwargs["input"], kwargs["other"]
+
+ if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in [lhs_val, rhs_val]):
raise RuntimeError('mul() received an input that is not part of the TensorRT region!')
- lhs_val, rhs_val = args
layer = network.add_elementwise(lhs_val, rhs_val, trt.ElementWiseOperation.PROD)
layer.name = layer_name
-import copy
import warnings
from typing import List, NamedTuple, Iterable, Any, Optional, Tuple
+import tensorrt as trt
import torch
import torch.fx
-import tensorrt as trt
-from torch.fx.experimental.normalize import NormalizeArgs
# Borrowed from torch2trt
def torch_dtype_to_trt(dtype):
- if trt.__version__ >= '7.0' and dtype == torch.bool:
+ if trt.__version__ >= "7.0" and dtype == torch.bool:
return trt.bool
elif dtype == torch.int8:
return trt.int8
def torch_dtype_from_trt(dtype):
if dtype == trt.int8:
return torch.int8
- elif trt.__version__ >= '7.0' and dtype == trt.bool:
+ elif trt.__version__ >= "7.0" and dtype == trt.bool:
return torch.bool
elif dtype == trt.int32:
return torch.int32
class TRTModule(torch.nn.Module):
- def __init__(self, engine=None, input_names=None, output_names=None, fp16_output=False):
+ def __init__(
+ self, engine=None, input_names=None, output_names=None, fp16_output=False
+ ):
super(TRTModule, self).__init__()
self._register_state_dict_hook(TRTModule._on_state_dict)
self.engine = engine
self.output_names = state_dict[prefix + "output_names"]
def forward(self, *inputs):
- assert len(inputs) == len(self.input_names), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
+ assert len(inputs) == len(
+ self.input_names
+ ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
batch_size = inputs[0].shape[0]
contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
bindings: List[Any] = [None] * (len(self.input_names) + len(self.output_names))
return tuple(outputs)
def enable_profiling(self):
- raise RuntimeError("Profiling is not supported right now because it requires calling"
- " execute() instead of execute_async().")
+ raise RuntimeError(
+ "Profiling is not supported right now because it requires calling"
+ " execute() instead of execute_async()."
+ )
if not self.context.profiler:
self.context.profiler = trt.Profiler()
def register_converter(converter):
CONVERTERS[key] = converter
return converter
+
return register_converter
has_batch_dim: Whether the shape includes batch dimension. Batch dimension has to be provided
if the engine want to run with dynamic shape.
"""
- shape : torch.Size
- dtype : torch.dtype
- device : torch.device = torch.device("cpu")
- shape_ranges : List[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]] = []
- has_batch_dim : bool = True
+
+ shape: torch.Size
+ dtype: torch.dtype
+ device: torch.device = torch.device("cpu")
+ shape_ranges: List[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]] = []
+ has_batch_dim: bool = True
@classmethod
def from_tensor(cls, tensor: torch.Tensor):
return dynamic_dims
-class BaseTRTInterpreter(torch.fx.Interpreter):
+def create_inputs_from_specs(input_specs):
+    """
+    Build one uninitialized tensor (`torch.empty`) per entry of `input_specs`.
+
+    Each spec unpacks to (shape, dtype, device, shape_ranges, has_batch_dim).
+    The concrete shape used is:
+      - the second shape of the first shape range when `shape` has dynamic dims;
+      - `shape` with a leading 1 prepended when it has no batch dimension;
+      - `shape` unchanged otherwise.
+    """
+
+    def _concrete_shape(spec_shape, ranges, batched):
+        if len(get_dynamic_dims(spec_shape)) > 0:
+            return ranges[0][1]
+        if batched:
+            return spec_shape
+        return (1,) + tuple(spec_shape)
+
+    return [
+        torch.empty(
+            _concrete_shape(shape, shape_ranges, has_batch_dim),
+            dtype=dtype,
+            device=device,
+        )
+        for shape, dtype, device, shape_ranges, has_batch_dim in input_specs
+    ]
+
+
+class TRTInterpreter(torch.fx.Interpreter):
def __init__(
self,
- module : torch.fx.GraphModule,
- input_specs : List[InputTensorSpec],
- explicit_batch_dimension : bool = False,
- logger_level=trt.Logger.WARNING
+ module: torch.fx.GraphModule,
+ input_specs: List[InputTensorSpec],
+ explicit_batch_dimension: bool = False,
+ logger_level=trt.Logger.WARNING,
):
super().__init__(module)
self.builder = trt.Builder(self.logger)
if explicit_batch_dimension:
- EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+ EXPLICIT_BATCH = 1 << (int)(
+ trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH
+ )
self.network = self.builder.create_network(EXPLICIT_BATCH)
else:
self.network = self.builder.create_network()
- self.optimization_profiles : Optional[List] = None
+ self.optimization_profiles: Optional[List] = None
self.input_specs = input_specs
self.input_specs_iter = 0
self.validate_input_specs()
def validate_input_specs(self):
for shape, dtpe, _, shape_ranges, has_batch_dim in self.input_specs:
if not self.network.has_implicit_batch_dimension:
- assert has_batch_dim, "It's required to specify batch dimension when it's explicit in TensorRT network."
+ assert (
+ has_batch_dim
+ ), "It's required to specify batch dimension when it's explicit in TensorRT network."
dynamic_dims = get_dynamic_dims(shape)
if len(dynamic_dims):
- assert not self.network.has_implicit_batch_dimension, "Can't have dynamic dim when " \
+ assert not self.network.has_implicit_batch_dimension, (
+ "Can't have dynamic dim when "
f"batch dim is implicit, got {shape}."
- assert len(shape_ranges), "shape_ranges must be provided when shape has dynamic dim."
+ )
+ assert len(
+ shape_ranges
+ ), "shape_ranges must be provided when shape has dynamic dim."
if self.optimization_profiles:
- assert len(shape_ranges) == len(self.optimization_profiles), "Number of optimization " \
- f"profiles {len(self.optimization_profiles)} doesn't match with the number of shape_range" \
+ assert len(shape_ranges) == len(self.optimization_profiles), (
+ "Number of optimization "
+ f"profiles {len(self.optimization_profiles)} doesn't match with the number of shape_range"
f" {len(shape_ranges)} provided."
+ )
else:
- self.optimization_profiles = [self.builder.create_optimization_profile() for _ in range(len(shape_ranges))]
+ self.optimization_profiles = [
+ self.builder.create_optimization_profile()
+ for _ in range(len(shape_ranges))
+ ]
for shape_range in shape_ranges:
- assert len(shape_range) == 3, f"Expect three elements in shape_range, got {len(shape_range)}"
- assert all(len(s) == len(shape) for s in shape_range), "Expect elements in shape_range" \
+ assert (
+ len(shape_range) == 3
+ ), f"Expect three elements in shape_range, got {len(shape_range)}"
+ assert all(len(s) == len(shape) for s in shape_range), (
+ "Expect elements in shape_range"
f" {shape_range} have the same number of dimension as the provided shape {len(shape)}"
+ )
for i in range(len(shape)):
if i in dynamic_dims:
- assert all(shape_range[j][i] <= shape_range[j + 1][i] for j in range(2)), "Expect dynamic dim" \
+ assert all(
+ shape_range[j][i] <= shape_range[j + 1][i]
+ for j in range(2)
+ ), (
+ "Expect dynamic dim"
f" {i} to have incremental value for shapes in shape_range {shape_range}."
+ )
else:
- assert all(s[i] == shape[i] for s in shape_range), f"Expect non dynamic dim {i} to be the same" \
+ assert all(s[i] == shape[i] for s in shape_range), (
+ f"Expect non dynamic dim {i} to be the same"
f" for all shapes in shape_range {shape_range}."
+ )
else:
- assert len(shape_ranges) == 0, "shape_ranges are provided for input that doesn't have dynamic dim."
+ assert (
+ len(shape_ranges) == 0
+ ), "shape_ranges are provided for input that doesn't have dynamic dim."
def run(
self,
max_workspace_size=1 << 25,
fp16_mode=True,
int8_mode=False,
- strict_type_constraints=True
+ strict_type_constraints=True,
):
# TODO hack, should check contents of args and remove fp16_mode probably
self.fp16_mode = fp16_mode
builder_config.add_optimization_profile(optimization_profile)
engine = self.builder.build_engine(self.network, builder_config)
- assert(engine)
+ assert engine
return engine, self._input_names, self._output_names
def run_node(self, n):
def placeholder(self, target, args, kwargs):
self._input_names.append(target)
- shape, dtype, _, shape_ranges, has_batch_dim = self.input_specs[self.input_specs_iter]
+ shape, dtype, _, shape_ranges, has_batch_dim = self.input_specs[
+ self.input_specs_iter
+ ]
self.input_specs_iter += 1
if self.network.has_implicit_batch_dimension:
assert self.optimization_profiles
self.optimization_profiles[i].set_shape(target, *shape_range)
- return self.network.add_input(name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype))
+ return self.network.add_input(
+ name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
+ )
def call_module(self, target, args, kwargs):
assert isinstance(target, str)
converter = CONVERTERS.get(type(submod))
if not converter:
- raise RuntimeError(f'Conversion of module of type {type(submod)} not currently supported!')
+ raise RuntimeError(
+ f"Conversion of module of type {type(submod)} not currently supported!"
+ )
return converter(self.network, submod, args, kwargs, self._cur_node_name)
converter = CONVERTERS.get(target)
if not converter:
- raise RuntimeError(f'Conversion of function {torch.typename(target)} not currently supported!')
+ raise RuntimeError(
+ f"Conversion of function {torch.typename(target)} not currently supported!"
+ )
return converter(self.network, target, args, kwargs, self._cur_node_name)
converter = CONVERTERS.get(target)
if not converter:
- raise RuntimeError(f'Conversion of method {target} not currently supported!')
+ raise RuntimeError(
+ f"Conversion of method {target} not currently supported!"
+ )
return converter(self.network, target, args, kwargs, self._cur_node_name)
outputs = args[0] if isinstance(args[0], tuple) else (args[0],)
if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
- raise RuntimeError('TensorRT requires all outputs to be Tensor!')
+ raise RuntimeError("TensorRT requires all outputs to be Tensor!")
for i, output in enumerate(outputs):
- name = f'output{i}'
+ name = f"output{i}"
output.name = name
self.network.mark_output(output)
if self.fp16_mode:
else:
output.dtype = trt.float32
self._output_names.append(name)
-
-
-class TRTInterpreter(BaseTRTInterpreter):
- """
- Use this for general case where there're PyTorch vanilla ops in the FX mdoule.
- """
- def __init__(self, module : torch.nn.Module, input_specs : List[InputTensorSpec], logger_level=trt.Logger.WARNING):
- # Preprocess the model
- if not isinstance(module, torch.fx.GraphModule):
- module = torch.fx.symbolic_trace(module)
- else:
- module = copy.deepcopy(module)
- module = module.cpu().float()
- module = NormalizeArgs(module).transform()
- super().__init__(module, input_specs, logger_level=logger_level)
--- /dev/null
+# type: ignore[]
+import inspect
+import re
+from typing import NamedTuple, Optional, Callable, Dict, List, Tuple, Union, Any, Set
+
+import torch.fx.experimental.fx_acc.acc_utils as acc_utils
+import torch
+import torch.fx
+from torch.fx.node import _get_qualified_name
+
+# Need to keep up-to-date with https://fburl.com/codesearch/7r2hhh53
+ALIAS_MAP = {
+ "input": ("input", "x", "a", "x1"),
+ "dim": ("dim", "axis"),
+ "keepdim": ("keepdim", "keepdims"),
+ "other": ("other", "x2"),
+}
+
+# Type used for arg replacement tuples. The list represents the argument signature of
+# some callable. Each item in the list is a tuple, where for each member of a tuple:
+# - The first member is union of either:
+# - A tuple of all potential alias kwarg str names of the source signature, or
+# - A tuple of a single str representing the single kwarg name allowed.
+# - The second member is the str name of the kwarg to map it to. This is either from the
+# signature of the acc_op, or for custom mapped nodes from the original unnormalized op.
+# - The third member is a bool representing whether this arg is optional, i.e. whether it
+# is allowed to not be present in the original input args.
+ArgReplacementTuplesType = List[Tuple[Tuple[str, ...], str, bool]]
+
+
+class NormalizationInfo(NamedTuple):
+    """
+    Holds normalization info for some FX node, where the FX node will be mapped either
+    via new_fn_target and arg_replacement_tuples, or via custom_mapping_fn.
+
+    If via new_fn_target and arg_replacement_tuples:
+      - new_fn_target is the target function to replace the original node with
+        (generally some function from acc_ops). It is None for entries registered
+        with a custom_mapping_fn only.
+
+      - arg_replacement_tuples describes how to map the original FX node's args/kwargs to
+        the new FX node. If set to None, then the kwargs are copied directly from the
+        original FX node. Else, this is list of three-member tuples, where each tuple
+        represents a mapping from either an arg or kwarg in the original FX node to the
+        kwarg it should be mapped to. If for ops registered with `register_acc_op` then
+        this is a mapping to the new FX node for the acc_op. Otherwise it is for some
+        op registered with `register_custom_acc_mapper_fn`, in which case this is a
+        mapping for the original input node so its args are normalized to kwargs before
+        being custom normalized to acc_ops. The third member of the tuple is a bool
+        representing whether this argument is optional; if False and the arg is not
+        present then an assertion will be thrown. The index of the tuple indicates where
+        the original arg is in node.args and the string name indicates which original
+        kwarg it is.
+
+    If via custom_mapping_fn, then custom_mapping_fn is some function that takes the
+    original FX node as input and returns the FX node that should replace it. This means
+    it was registered via `register_custom_acc_mapper_fn`.
+
+    kwargs_to_move_to_acc_out_ty, when not None, lists (original kwarg name, tensor
+    metadata field name) pairs; `move_kwargs_to_acc_out_ty` removes those kwargs and
+    packs their values into the `acc_out_ty` kwarg instead.
+
+    needs_shapes_for_normalization marks entries whose conversion `normalize` defers
+    (only rewriting args into kwargs) unless it is called with
+    expect_nodes_have_shapes=True.
+    """
+
+    # None when the entry is handled purely by custom_mapping_fn.
+    new_fn_target: Optional[Callable]
+    arg_replacement_tuples: Optional[ArgReplacementTuplesType]
+    custom_mapping_fn: Optional[Callable]
+    kwargs_to_move_to_acc_out_ty: Optional[List[Tuple[str, str]]]
+    needs_shapes_for_normalization: bool
+
+
+# Dict from (op, target) to NormalizationInfo for that op.
+_normalization_dict: Dict[Tuple[str, Union[str, Callable]], NormalizationInfo] = {}
+
+# Set of all the acc ops.
+_acc_ops: Set[Callable] = set()
+
+
+def _insert_fun(
+    op_and_target: Tuple[str, Union[str, Callable]],
+    arg_replacement_tuples: List[Tuple],
+    new_fn_target: Optional[Callable] = None,
+    custom_mapping_fn: Optional[Callable] = None,
+    kwargs_to_move_to_acc_out_ty: Optional[List[Tuple[str, str]]] = None,
+    needs_shapes_for_normalization=False,
+    allow_normalize_from_torch_package=False,
+):
+    """
+    Register a NormalizationInfo for `op_and_target` in `_normalization_dict`.
+
+    Args:
+        op_and_target: (op, target) key of the FX node to normalize. The target must
+            be callable for "call_function", a str for "call_method" and a type for
+            "call_module".
+        arg_replacement_tuples: two- or three-member tuples
+            (orig_kwarg, new_kwarg[, is_optional]); is_optional defaults to False and
+            orig_kwarg may be a single name or a tuple of names.
+        new_fn_target: acc op to map to, if any.
+        custom_mapping_fn: custom mapper to use instead, if any.
+        kwargs_to_move_to_acc_out_ty: stored as-is on the NormalizationInfo.
+        needs_shapes_for_normalization: stored as-is on the NormalizationInfo.
+        allow_normalize_from_torch_package: also register the entry under the
+            "<torch_package_>."-prefixed qualified name of the target.
+
+    Raises:
+        AssertionError: on a malformed target or tuple, or if `op_and_target` is
+            already registered.
+    """
+    if op_and_target[0] == "call_function":
+        assert callable(op_and_target[1])
+    elif op_and_target[0] == "call_method":
+        assert isinstance(op_and_target[1], str)
+    elif op_and_target[0] == "call_module":
+        assert isinstance(op_and_target[1], type)
+
+    # Finalize arg replacement tuples.
+    # 1. Check to see if they have the `is_optional` bool, and if not default it to
+    #    False.
+    # 2. Some kwargs might have aliases. e.g. "a", "x" and "x1" are aliases of "input".
+    #    Here we replace `orig_kwarg` with a tuple of all aliases if it has aliases.
+    final_arg_replacement_tuples = []
+    for arg_replacement_tuple in arg_replacement_tuples:
+        if len(arg_replacement_tuple) == 2:
+            orig_kwarg, new_kwarg, is_optional = *arg_replacement_tuple, False
+        else:
+            assert len(arg_replacement_tuple) == 3
+            orig_kwarg, new_kwarg, is_optional = arg_replacement_tuple
+
+        if not isinstance(orig_kwarg, tuple):
+            orig_kwarg = (orig_kwarg,)
+
+        # De-duplicate with an insertion-ordered dict rather than a set, so the
+        # resulting alias tuple has the same order on every run: the declared names
+        # first, then their aliases in ALIAS_MAP order. Lookup code picks the first
+        # alias that matches, so a hash-dependent order would be non-reproducible.
+        ordered_names = dict.fromkeys(orig_kwarg)
+        for k in orig_kwarg:
+            ordered_names.update(dict.fromkeys(ALIAS_MAP.get(k, ())))
+        final_arg_replacement_tuples.append(
+            (tuple(ordered_names), new_kwarg, is_optional)
+        )
+
+    assert (
+        op_and_target not in _normalization_dict
+    ), f"{op_and_target} is already registered for normalization."
+    norm_info = NormalizationInfo(
+        new_fn_target=new_fn_target,
+        arg_replacement_tuples=final_arg_replacement_tuples,
+        custom_mapping_fn=custom_mapping_fn,
+        kwargs_to_move_to_acc_out_ty=kwargs_to_move_to_acc_out_ty,
+        needs_shapes_for_normalization=needs_shapes_for_normalization,
+    )
+    _normalization_dict[op_and_target] = norm_info
+
+    # If allow_normalize_from_torch_package then add another entry to
+    # _normalization_dict where we look for the qualified name of the target with the
+    # torch_package module prefix. Note that we leave off any integer at the end of
+    # "<torch_package_>" in order to allow for whatever mangling index is used.
+    if allow_normalize_from_torch_package:
+        torch_package_op_and_target = (
+            op_and_target[0],
+            f"<torch_package_>.{_get_qualified_name(op_and_target[1])}",
+        )
+        _normalization_dict[torch_package_op_and_target] = norm_info
+
+
+def _get_dup_signature_tuples(fn: Callable) -> List[Tuple[str, str]]:
+    """
+    Return `(name, name)` for every parameter of `fn` (after `inspect.unwrap`),
+    in signature order. The result is usable as arg_replacement_tuples when the
+    source op and the acc op share parameter names.
+    """
+    parameter_names = inspect.signature(inspect.unwrap(fn)).parameters
+    return [(name, name) for name in parameter_names]
+
+
+def register_acc_op(acc_op: Callable):
+ """
+ Decorator that registers `acc_op` as an acc op by adding it to the
+ module-level `_acc_ops` set. Returns `acc_op` unchanged so it can be
+ stacked with other decorators.
+ """
+ _acc_ops.add(acc_op)
+ return acc_op
+
+
+def register_acc_op_mapping(
+    op_and_target: Tuple[str, Union[str, Callable]],
+    arg_replacement_tuples: Optional[
+        List[Tuple[Union[str, Tuple[str, ...]], str]]
+    ] = None,
+    kwargs_to_move_to_acc_out_ty: Optional[List[Tuple[str, str]]] = None,
+):
+    """
+    Decorator factory that maps a non-acc operator to the decorated acc operator.
+
+    Args:
+        op_and_target: (op, target) of the node that represents the non-acc operator.
+        arg_replacement_tuples: see the comment above `ArgReplacementTuplesType`.
+            When None, the acc op's own signature is used for both sides of the
+            mapping.
+        kwargs_to_move_to_acc_out_ty: kwargs to move out of the non-acc op's kwargs
+            and into acc_out_ty.
+
+    Returns:
+        A decorator that registers the mapping and returns the acc op unchanged.
+    """
+
+    def insert(new_fn_target: Callable):
+        # Fall back to a name-for-name mapping derived from the acc op's signature.
+        replacement_tuples = (
+            _get_dup_signature_tuples(new_fn_target)
+            if arg_replacement_tuples is None
+            else arg_replacement_tuples
+        )
+        _insert_fun(
+            op_and_target=op_and_target,
+            arg_replacement_tuples=replacement_tuples,
+            new_fn_target=new_fn_target,
+            kwargs_to_move_to_acc_out_ty=kwargs_to_move_to_acc_out_ty,
+        )
+        return new_fn_target
+
+    return insert
+
+
+def register_custom_acc_mapper_fn(
+    op_and_target: Tuple[str, Union[str, Callable]],
+    arg_replacement_tuples: List[Tuple[Union[str, Tuple[str, ...]], str]],
+    needs_shapes_for_normalization=False,
+    allow_normalize_from_torch_package=False,
+):
+    """
+    Decorator factory that registers the decorated function as the custom mapping
+    function for nodes matching `op_and_target`. The remaining arguments are
+    forwarded unchanged to `_insert_fun`. The decorated function is returned as-is.
+    """
+
+    def insert(custom_mapping_fn: Callable):
+        _insert_fun(
+            op_and_target,
+            arg_replacement_tuples,
+            custom_mapping_fn=custom_mapping_fn,
+            needs_shapes_for_normalization=needs_shapes_for_normalization,
+            allow_normalize_from_torch_package=allow_normalize_from_torch_package,
+        )
+        return custom_mapping_fn
+
+    return insert
+
+
+def move_kwargs_to_acc_out_ty(
+    node_or_normalization_info: Union[NormalizationInfo, torch.fx.Node],
+    new_kwargs: Dict[str, Any],
+):
+    """
+    Given `node_or_normalization_info` which is either NormalizationInfo for a node, or
+    a node to fetch NormalizationInfo for, check if kwargs_to_move_to_acc_out_ty exists
+    in the NormalizationInfo, and if so perform the move of kwargs to acc_out_ty.
+
+    `new_kwargs` is modified in place: each listed kwarg is removed and an
+    "acc_out_ty" entry is added. Nothing happens when the node has no registered
+    NormalizationInfo or when there is nothing to move.
+
+    Raises:
+        AssertionError: if the target acc op does not accept an `acc_out_ty` kwarg.
+        KeyError: if a kwarg listed for the move is missing from `new_kwargs`.
+    """
+    if isinstance(node_or_normalization_info, torch.fx.Node):
+        node = node_or_normalization_info
+        normalization_info = _normalization_dict.get((node.op, node.target))
+    else:
+        normalization_info = node_or_normalization_info
+
+    # The dict lookup above yields None for unregistered nodes; there is nothing to
+    # move in that case, so bail out instead of failing on attribute access.
+    if (
+        normalization_info is None
+        or normalization_info.kwargs_to_move_to_acc_out_ty is None
+    ):
+        return
+
+    assert acc_utils.is_acc_op_with_kwarg(
+        normalization_info.new_fn_target, "acc_out_ty"
+    )
+
+    # Build a dict representing the new TensorMetadata to use for acc_out_ty,
+    # and then remove the kwarg from the new_kwargs since it's passed in via
+    # acc_out_ty instead.
+    tmd_dict: Dict[str, Any] = {}
+    for (
+        orig_kwarg_name,
+        tmd_field_name,
+    ) in normalization_info.kwargs_to_move_to_acc_out_ty:
+        tmd_dict[tmd_field_name] = new_kwargs[orig_kwarg_name]
+        del new_kwargs[orig_kwarg_name]
+    # Note: only the fields listed in kwargs_to_move_to_acc_out_ty are provided, so
+    # the metadata built here is a partial spec used purely to pass specific values
+    # into the function. For example, for quantization we only need to provide
+    # dtype/q_scale/q_zero_point, but is_quantized and qscheme are not passed in.
+    new_kwargs["acc_out_ty"] = acc_utils.build_raw_tensor_meta(**tmd_dict)
+
+
+def get_normalized_kwargs(
+    node: torch.fx.Node, arg_replacement_tuples: ArgReplacementTuplesType
+):
+    """
+    Convert `node`'s args/kwargs into a single kwargs dict keyed by the new kwarg
+    names given in `arg_replacement_tuples`.
+
+    For the i-th tuple, the value is taken from the first alias present in
+    `node.kwargs`, else from `node.args[i]`; if neither exists the argument must be
+    marked optional and is simply omitted. A final tuple whose only alias is "*"
+    collects all remaining positional args into a list.
+    """
+    normalized = {}
+    last_idx = len(arg_replacement_tuples) - 1
+
+    for pos, (aliases, target_name, optional) in enumerate(arg_replacement_tuples):
+        # A lone "*" alias marks varargs: gather every remaining positional arg
+        # under the name of the last tuple and stop.
+        if len(aliases) == 1 and aliases[0] == "*":
+            assert pos == last_idx
+            varargs_name = arg_replacement_tuples[last_idx][1]
+            normalized[varargs_name] = list(node.args[last_idx:])
+            break
+
+        assert isinstance(aliases, tuple)
+        matched_alias = next(
+            (alias for alias in aliases if alias in node.kwargs),
+            None,
+        )
+
+        if matched_alias is not None:
+            normalized[target_name] = node.kwargs[matched_alias]
+        elif pos < len(node.args):
+            # Not passed by keyword, so it should sit at the same index in args.
+            normalized[target_name] = node.args[pos]
+        else:
+            # Verify the arg we're trying to normalize was optional.
+            assert optional
+
+    return normalized
+
+
+def normalize(mod: torch.fx.GraphModule, expect_nodes_have_shapes: bool = False):
+ """
+ Rewrite, in place, every node of `mod.graph` that has an entry in
+ `_normalization_dict`, either by replacing it with a call to the registered
+ acc op or by delegating to the registered custom mapping function.
+
+ Nodes whose entry has `needs_shapes_for_normalization` set are only converted
+ when `expect_nodes_have_shapes` is True; otherwise just their args/kwargs are
+ normalized and the node is left in the graph.
+
+ Asserts that at least one normalization has been registered. If converting a
+ node raises, the node is printed and the exception is re-raised. Dead code is
+ eliminated from the graph before returning. Returns None.
+ """
+ assert len(_normalization_dict) > 0
+ graph = mod.graph
+
+ # For "call_module" node we return _base_class_origin if it's a
+ # RewrittenModule, otherwise, return its type. For other nodes,
+ # we return node.target.
+ def get_target(mod: torch.fx.GraphModule, node: torch.fx.Node):
+ if node.op != "call_module":
+ return node.target
+
+ # Find the module that node.target points to
+ m = dict(mod.named_modules())[node.target]
+ return getattr(m, "_base_class_origin", type(m))
+
+ # Replaces `node` with its normalized equivalent and erases the original,
+ # unless a custom mapping function reports (by returning None) that it
+ # already handled the replacement itself.
+ def normalize_to_acc_op(
+ node: torch.fx.Node,
+ normalization_info: NormalizationInfo,
+ normalized_args: Tuple[Any, ...],
+ normalized_kwargs: Dict[str, Any],
+ ):
+ # If there's a custom mapping function then use it.
+ if normalization_info.custom_mapping_fn is not None:
+ # For custom mapping, the normalized_kwargs are used for the original op,
+ # i.e. *before* custom acc_ops normalization. Do that now.
+ node.args = normalized_args
+ node.kwargs = normalized_kwargs
+ new_node = normalization_info.custom_mapping_fn(node, mod)
+ # If a new node is returned then use it to replace the old node. Otherwise
+ # the custom mapping function did its own replacement, so return early.
+ if new_node is None:
+ return
+ else:
+ # If there's kwargs_to_move_to_acc_out_ty then use it to setup acc_out_ty in
+ # normalized_kwargs, and remove the kwarg from normalized_kwargs.
+ move_kwargs_to_acc_out_ty(normalization_info, normalized_kwargs)
+
+ # All acc ops are functions. Create a call to the correct acc_ops target using
+ # the normalized kwargs provided.
+ with graph.inserting_before(node):
+ new_node = graph.create_node(
+ "call_function",
+ normalization_info.new_fn_target,
+ args=normalized_args,
+ kwargs=normalized_kwargs,
+ name=node.name,
+ )
+ new_node.meta = node.meta.copy()
+
+ # Finally replace the original node with the normalized node.
+ node.replace_all_uses_with(new_node)
+ graph.erase_node(node)
+
+ for node in graph.nodes:
+ # These ops are never looked up in _normalization_dict.
+ if node.op in {"placeholder", "get_attr", "output"}:
+ continue
+
+ normalization_info = _normalization_dict.get((node.op, get_target(mod, node)))
+
+ # Also check if the torch_packaged version of the op was specified to be normalized.
+ if normalization_info is None and node.op == "call_function":
+ # Strip off the mangle_index suffix here before checking the map.
+ target = re.sub(
+ r"\A<torch_package_\d+>",
+ "<torch_package_>",
+ _get_qualified_name(node.target),
+ )
+ torch_package_op_and_target = (node.op, target)
+ normalization_info = _normalization_dict.get(torch_package_op_and_target)
+
+ # Nothing registered for this node: leave it untouched.
+ if normalization_info is None:
+ continue
+
+ # Get the normalized kwargs to be used by normalize_to_acc_op below. If
+ # normalization_info.arg_replacement_tuples is empty then assume the function
+ # signature must be left as is.
+ if len(normalization_info.arg_replacement_tuples) == 0:
+ normalized_args = node.args
+ normalized_kwargs = node.kwargs
+ else:
+ normalized_args = ()
+ normalized_kwargs = get_normalized_kwargs(
+ node, normalization_info.arg_replacement_tuples
+ )
+
+ if (
+ normalization_info.needs_shapes_for_normalization
+ and not expect_nodes_have_shapes
+ ):
+ # All nodes needing shapes for normalization should be custom mapped.
+ assert normalization_info.custom_mapping_fn is not None
+ # For custom mapping, the normalized_kwargs are used for the original op,
+ # i.e. *before* custom acc_ops normalization. Do that now so that whoever
+ # consumes the graph next (e.g. shape inference) can use kwargs safely.
+ node.args = normalized_args
+ node.kwargs = normalized_kwargs
+ continue
+
+ try:
+ normalize_to_acc_op(
+ node, normalization_info, normalized_args, normalized_kwargs
+ )
+ except Exception:
+ # Report which node failed, then let the original exception propagate.
+ print(f"Error during normalization for node: {node.format_node()}")
+ raise
+
+ # If there are any dead nodes left after normalization, eliminate them now.
+ mod.graph.eliminate_dead_code()
--- /dev/null
+# encoding: utf-8
+# type: ignore[]
+import operator
+
+import torch # isort:skip
+import torch.fx.experimental.fx_acc.acc_utils as acc_utils
+import torch.nn as nn
+from torch.fx.experimental.fx_acc.acc_normalizer import (
+ register_acc_op,
+ register_acc_op_mapping,
+ register_custom_acc_mapper_fn,
+)
+from torch.fx.passes.shape_prop import extract_tensor_metadata
+
+this_arg_is_optional = True
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.linear))
+@register_acc_op
+def linear(*, input, weight, bias):
+    """Keyword-only acc op wrapper around `nn.functional.linear`."""
+    return nn.functional.linear(input=input, weight=weight, bias=bias)
+
+
+@register_acc_op
+def quantized_linear(*, input, weight, bias, acc_out_ty=None):
+    """
+    Keyword-only acc op wrapper around `nn.quantized.functional.linear`.
+
+    The output scale and zero point are read from the "q_scale" and
+    "q_zero_point" fields of `acc_out_ty`, which must not be None.
+    """
+    assert acc_out_ty is not None
+    out_scale = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale")
+    out_zero_point = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point")
+    return nn.quantized.functional.linear(
+        input, weight, bias, out_scale, out_zero_point
+    )
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_method", "flatten"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("start_dim", "start_dim", this_arg_is_optional),
+        ("end_dim", "end_dim", this_arg_is_optional),
+    ],
+)
+@register_acc_op_mapping(op_and_target=("call_function", torch.flatten))
+@register_acc_op
+def flatten(*, input, start_dim=0, end_dim=-1):
+    """Keyword-only acc op wrapper around `torch.flatten`."""
+    return torch.flatten(input=input, start_dim=start_dim, end_dim=end_dim)
+
+
+@register_acc_op_mapping(
+    op_and_target=(
+        "call_method",
+        "squeeze",
+    ),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim", this_arg_is_optional),
+    ],
+)
+@register_acc_op
+def squeeze(*, input, dim=None):
+    """
+    Acc op for `Tensor.squeeze`: calls `input.squeeze()` when `dim` is None,
+    `input.squeeze(dim=dim)` otherwise.
+    """
+    return input.squeeze() if dim is None else input.squeeze(dim=dim)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.max_pool2d))
+@register_acc_op
+def max_pool2d(
+    *, input, kernel_size, stride, padding, dilation, ceil_mode, return_indices
+):
+    """Keyword-only acc op wrapper around `nn.functional.max_pool2d`."""
+    return nn.functional.max_pool2d(
+        input=input,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        ceil_mode=ceil_mode,
+        return_indices=return_indices,
+    )
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", nn.functional.adaptive_avg_pool2d)
+)
+@register_acc_op
+def adaptive_avg_pool2d(*, input, output_size):
+    """Keyword-only acc op wrapper around `nn.functional.adaptive_avg_pool2d`."""
+    return nn.functional.adaptive_avg_pool2d(input=input, output_size=output_size)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.avg_pool2d))
+@register_acc_op
+def avg_pool2d(
+    *,
+    input,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad,
+    divisor_override,
+):
+    """Keyword-only acc op wrapper around `nn.functional.avg_pool2d`."""
+    return nn.functional.avg_pool2d(
+        input=input,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        ceil_mode=ceil_mode,
+        count_include_pad=count_include_pad,
+        divisor_override=divisor_override,
+    )
+
+
+@register_acc_op
+def size(*, input):
+ """Acc op that returns `input.size()`."""
+ return input.size()
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", getattr),
+    arg_replacement_tuples=[],
+)
+def custom_getattr_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """
+    Map a call_function `getattr` node to acc ops. Only `getattr(tensor, "shape")`
+    is handled: it becomes a call to acc_ops `size` on the same input node,
+    inserted before `node` and carrying a copy of `node`'s metadata.
+
+    Raises:
+        AssertionError: if the input's recorded type is not torch.Tensor or the
+            attribute is not "shape".
+    """
+    # getattr only takes positional arguments, so read them from node.args.
+    target_obj = node.args[0]
+    attribute = node.args[1]
+    assert (
+        target_obj.meta["type"] == torch.Tensor
+    ), f"Expected torch.Tensor type for {target_obj.meta['type']}"
+    assert (
+        attribute == "shape"
+    ), f"Only supporting shape getattr for now, not {attribute}"
+
+    owning_graph = node.graph
+    with owning_graph.inserting_before(node):
+        replacement = owning_graph.call_function(size, kwargs={"input": target_obj})
+    replacement.meta = node.meta.copy()
+    return replacement
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "size"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim", this_arg_is_optional),
+    ],
+)
+def tensor_size_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """
+    Map `Tensor.size()` to acc_ops `size`, and `Tensor.size(dim)` to acc_ops
+    `size` followed by acc_ops `getitem` with `idx=dim`. The returned node
+    carries a copy of `node`'s metadata.
+    """
+    owning_graph = node.graph
+    wants_single_dim = "dim" in node.kwargs
+
+    with owning_graph.inserting_before(node):
+        size_node = owning_graph.call_function(
+            size, kwargs={"input": node.kwargs["input"]}
+        )
+
+        if wants_single_dim:
+            # The intermediate size node yields a torch.Size that is then indexed.
+            size_node.meta["type"] = torch.Size
+            result_node = owning_graph.call_function(
+                getitem, kwargs={"input": size_node, "idx": node.kwargs["dim"]}
+            )
+        else:
+            result_node = size_node
+
+        result_node.meta = node.meta.copy()
+    return result_node
+
+
+@register_acc_op_mapping(op_and_target=("call_function", operator.add))
+@register_acc_op_mapping(op_and_target=("call_method", "add"))
+@register_acc_op
+def add(*, input, other):
+ """Acc op that returns `input + other`."""
+ return input + other
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.unsqueeze))
+@register_acc_op
+def unsqueeze(*, input, dim):
+    """Keyword-only acc op wrapper around `torch.unsqueeze`."""
+    return torch.unsqueeze(input=input, dim=dim)
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.stack),
+    arg_replacement_tuples=[
+        ("tensors", "tensors"),
+        ("dim", "dim"),
+    ],
+)
+def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """
+    Map torch.stack to unsqueeze + cat: every input tensor gets an acc_ops
+    `unsqueeze` node at `dim`, and the results are joined by one acc_ops `cat`
+    node along the same `dim`. The cat node carries a copy of `node`'s metadata.
+    """
+    owning_graph = node.graph
+
+    def _unsqueezed(index, tensor):
+        expanded = owning_graph.create_node(
+            "call_function",
+            unsqueeze,
+            kwargs={"input": tensor, "dim": node.kwargs["dim"]},
+            name=f"{node.name}_unsqueeze_{index}",
+        )
+        expanded.meta["type"] = torch.Tensor
+        return expanded
+
+    with owning_graph.inserting_before(node):
+        expanded_nodes = [
+            _unsqueezed(index, tensor)
+            for index, tensor in enumerate(node.kwargs["tensors"])
+        ]
+        joined = owning_graph.create_node(
+            "call_function",
+            cat,
+            kwargs={"tensors": expanded_nodes, "dim": node.kwargs["dim"]},
+        )
+        joined.meta = node.meta.copy()
+    return joined
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.clamp))
+@register_acc_op_mapping(op_and_target=("call_method", "clamp"))
+@register_acc_op
+def clamp(*, input, min, max):
+    """Acc op wrapping torch.clamp with keyword-only arguments."""
+    return torch.clamp(input=input, min=min, max=max)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.cat))
+@register_acc_op
+def cat(*, tensors, dim):
+    """Acc op wrapping torch.cat with keyword-only arguments."""
+    return torch.cat(tensors=tensors, dim=dim)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.transpose))
+@register_acc_op_mapping(op_and_target=("call_method", "transpose"))
+@register_acc_op
+def transpose(*, input, dim0, dim1):
+    """
+    Acc op wrapping torch.transpose. Inputs with fewer than two dimensions
+    are returned unchanged.
+    """
+    if input.dim() >= 2:
+        return torch.transpose(input=input, dim0=dim0, dim1=dim1)
+    return input
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.softmax))
+@register_acc_op
+def softmax(*, input, dim, dtype):
+    """
+    Acc op wrapping torch.nn.functional.softmax. The `_stacklevel` argument
+    is ignored here.
+    """
+    return torch.nn.functional.softmax(input=input, dim=dim, dtype=dtype)
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.addmm),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("mat1", "mat1"),
+        ("mat2", "mat2"),
+        ("beta", "beta"),
+        ("alpha", "alpha"),
+    ],
+)
+def addmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """
+    Lower torch.addmm to acc_ops.matmul -> acc_ops.add. When alpha or beta is
+    not 1, an acc_ops.mul is inserted on the matmul result or on the input,
+    respectively.
+    """
+    graph = node.graph
+    alpha = node.kwargs["alpha"]
+    beta = node.kwargs["beta"]
+
+    with graph.inserting_before(node):
+        product = graph.create_node(
+            "call_function",
+            matmul,
+            kwargs={"input": node.kwargs["mat1"], "other": node.kwargs["mat2"]},
+            name=f"{node.name}_mm",
+        )
+        product.meta = node.meta.copy()
+
+        if alpha != 1:
+            # Scale the matmul result; the new node is named after the matmul node.
+            product = graph.create_node(
+                "call_function",
+                mul,
+                kwargs={"input": product, "other": alpha},
+                name=f"{product.name}_mul",
+            )
+            product.meta = node.meta.copy()
+
+        addend = node.kwargs["input"]
+        if beta != 1:
+            scaled_input = graph.create_node(
+                "call_function",
+                mul,
+                kwargs={"input": addend, "other": beta},
+                name=f"{node.name}_input_mul",
+            )
+            # The scaled input keeps the metadata of the unscaled input.
+            scaled_input.meta = addend.meta.copy()
+            addend = scaled_input
+
+        result = graph.create_node(
+            "call_function",
+            add,
+            kwargs={"input": product, "other": addend},
+            name=f"{node.name}_add",
+        )
+        result.meta = node.meta.copy()
+        return result
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.t),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "t"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+def t_mapper(node: torch.fx.Node, _: nn.Module):
+    """Lower torch.t / Tensor.t to acc_ops.transpose over dims 0 and 1."""
+    transpose_kwargs = {"input": node.kwargs["input"], "dim0": 0, "dim1": 1}
+    with node.graph.inserting_before(node):
+        transposed = node.graph.create_node(
+            "call_function", transpose, kwargs=transpose_kwargs, name=node.name
+        )
+        transposed.meta = node.meta.copy()
+        return transposed
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_method", "permute"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("*", "permutation"),
+    ],
+)
+@register_acc_op
+def permute(*, input, permutation):
+    """Acc op for Tensor.permute; `permutation` is unpacked into the call."""
+    permuted = input.permute(*permutation)
+    return permuted
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.square),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+def square_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """Lower torch.square to acc_ops.mul of the input with itself."""
+    operand = node.kwargs["input"]
+    mul_kwargs = {"input": operand, "other": operand}
+    with node.graph.inserting_before(node):
+        squared = node.graph.call_function(mul, kwargs=mul_kwargs)
+        squared.meta = node.meta.copy()
+        return squared
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.bmm),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("mat2", "other"),
+    ],
+)
+@register_acc_op_mapping(op_and_target=("call_function", torch.matmul))
+@register_acc_op
+def matmul(*, input, other):
+    """Acc op wrapping torch.matmul; torch.bmm is also mapped onto it."""
+    return torch.matmul(input=input, other=other)
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.min),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("other", "other", this_arg_is_optional),
+        ("dim", "dim", this_arg_is_optional),
+        ("keepdim", "keepdim", this_arg_is_optional),
+    ],
+)
+def custom_torch_min_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
+    """
+    Custom mapping for torch.min, which has three call forms with different
+    output types:
+
+    1. torch.min(input): tensor with the minimum value
+       -> min_single_tensor_input
+    2. torch.min(input, other): tensor with the coordinate-wise minimum
+       -> min_two_tensors_input
+    3. [Not supported] torch.min(input, dim, keepdim): (min_values, min_indices)
+
+    Raises:
+        AssertionError: if `dim` is present in the node's kwargs.
+    """
+    graph = node.graph
+    with graph.inserting_before(node):
+        # If dim is in kwargs, assert "Not Supported"
+        assert "dim" not in node.kwargs, "Currently not support dim in torch.min"
+
+        # Build the kwargs explicitly: the target acc ops only accept `input`
+        # (and `other`), so forwarding e.g. `other=None` or `keepdim` from the
+        # original node would make the new node uncallable.
+        other = node.kwargs.get("other")
+        if other is not None:
+            op_func = min_two_tensors_input
+            op_kwargs = {"input": node.kwargs["input"], "other": other}
+        else:
+            op_func = min_single_tensor_input
+            op_kwargs = {"input": node.kwargs["input"]}
+
+        new_node = graph.create_node(
+            "call_function",
+            op_func,
+            kwargs=op_kwargs,
+            name=node.name,
+        )
+        # Copy rather than alias, so the two nodes never share one meta dict.
+        new_node.meta = node.meta.copy()
+        return new_node
+
+
+@register_acc_op
+def min_single_tensor_input(*, input):
+    """Acc op for the single-argument form torch.min(input)."""
+    smallest = torch.min(input)
+    return smallest
+
+
+@register_acc_op
+def min_two_tensors_input(*, input, other):
+    """Acc op for the two-tensor form torch.min(input, other)."""
+    pairwise_min = torch.min(input, other)
+    return pairwise_min
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.ops.quantized.add),
+    arg_replacement_tuples=[
+        ("qa", "input"),
+        ("qb", "other"),
+        ("scale", "scale"),
+        ("zero_point", "zero_point"),
+    ],
+    kwargs_to_move_to_acc_out_ty=[
+        ("scale", "q_scale"),
+        ("zero_point", "q_zero_point"),
+    ],
+)
+@register_acc_op
+def quantized_add(*, input, other, acc_out_ty=None):
+    """
+    Acc op for torch.ops.quantized.add. The output scale and zero point are
+    read from the `q_scale` / `q_zero_point` fields of `acc_out_ty`.
+    """
+    assert acc_out_ty is not None
+    out_scale = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale")
+    out_zero_point = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point")
+    return torch.ops.quantized.add(input, other, out_scale, out_zero_point)
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.ops.quantized.mul),
+    arg_replacement_tuples=[
+        ("qa", "input"),
+        ("qb", "other"),
+        ("scale", "scale"),
+        ("zero_point", "zero_point"),
+    ],
+    kwargs_to_move_to_acc_out_ty=[
+        ("scale", "q_scale"),
+        ("zero_point", "q_zero_point"),
+    ],
+)
+@register_acc_op
+def quantized_mul(*, input, other, acc_out_ty=None):
+    """
+    Acc op for torch.ops.quantized.mul. The output scale and zero point are
+    read from the `q_scale` / `q_zero_point` fields of `acc_out_ty`.
+    """
+    assert acc_out_ty is not None
+    out_scale = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale")
+    out_zero_point = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point")
+    return torch.ops.quantized.mul(input, other, out_scale, out_zero_point)
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.quantize_per_tensor),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("scale", "scale"),
+        ("zero_point", "zero_point"),
+        ("dtype", "dtype"),
+    ],
+    kwargs_to_move_to_acc_out_ty=[
+        ("dtype", "dtype"),
+        ("scale", "q_scale"),
+        ("zero_point", "q_zero_point"),
+    ],
+)
+@register_acc_op
+def quantize_per_tensor(*, input, acc_out_ty=None):
+    """
+    Acc op for torch.quantize_per_tensor. Scale, zero point and dtype are read
+    from the `q_scale`, `q_zero_point` and `dtype` fields of `acc_out_ty`.
+    """
+    assert acc_out_ty is not None
+    scale = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale")
+    zero_point = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point")
+    qdtype = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype")
+    return torch.quantize_per_tensor(input, scale, zero_point, qdtype)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.dequantize))
+@register_acc_op_mapping(op_and_target=("call_method", "dequantize"))
+@register_acc_op
+def dequantize(*, input):
+    """Acc op wrapping torch.dequantize."""
+    dequantized = torch.dequantize(input)
+    return dequantized
+
+
+@register_acc_op_mapping(op_and_target=("call_function", operator.sub))
+@register_acc_op
+def sub(*, input, other):
+    """Acc op for subtraction: returns `input - other`."""
+    difference = input - other
+    return difference
+
+
+@register_acc_op_mapping(op_and_target=("call_function", operator.mul))
+@register_acc_op
+def mul(*, input, other):
+    """Acc op for multiplication: returns `input * other`."""
+    product = input * other
+    return product
+
+
+@register_acc_op_mapping(op_and_target=("call_function", operator.truediv))
+@register_acc_op
+def div(*, input, other):
+    """Acc op for true division: returns `input / other`."""
+    quotient = input / other
+    return quotient
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.relu))
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.relu),
+    arg_replacement_tuples=[("input", "input")],
+)
+@register_acc_op_mapping(
+    op_and_target=("call_method", "relu"),
+    arg_replacement_tuples=[("input", "input")],
+)
+@register_acc_op
+def relu(*, input, inplace=False):
+    """Acc op wrapping nn.functional.relu with keyword-only arguments."""
+    return nn.functional.relu(input=input, inplace=inplace)
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "sum"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim", this_arg_is_optional),
+        ("keepdim", "keepdim", this_arg_is_optional),
+        ("dtype", "dtype", this_arg_is_optional),
+    ],
+)
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.sum),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim", this_arg_is_optional),
+        ("keepdim", "keepdim", this_arg_is_optional),
+        ("dtype", "dtype", this_arg_is_optional),
+    ],
+)
+def add_sum_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -> torch.fx.Node:
+    """
+    Lower torch.sum / Tensor.sum to acc_ops.sum, wrapping an integer `dim`
+    into a one-element tuple.
+    """
+    forwarded = dict(node.kwargs)
+    if isinstance(forwarded.get("dim"), int):
+        forwarded["dim"] = (forwarded["dim"],)
+
+    with node.graph.inserting_before(node):
+        reduced = node.graph.call_function(sum, kwargs=forwarded)
+        reduced.meta = node.meta.copy()
+        return reduced
+
+
+@register_acc_op
+def sum(*, input, dim=None, keepdim=False, dtype=None):
+    """
+    Acc op for summation.
+
+    When `dim` is given, forwards to torch.sum with `dim`, `keepdim` and
+    `dtype`. When `dim` is None, sums over the whole input via
+    `input.sum(dtype=dtype)`.
+    """
+    # Test against None rather than truthiness: `dim=0` is a valid dimension
+    # and must not fall through to the reduce-everything branch.
+    if dim is not None:
+        return torch.sum(input=input, dim=dim, keepdim=keepdim, dtype=dtype)
+    return input.sum(dtype=dtype)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.sigmoid))
+@register_acc_op_mapping(op_and_target=("call_method", "sigmoid"))
+@register_acc_op
+def sigmoid(*, input):
+    """Acc op wrapping torch.sigmoid."""
+    return torch.sigmoid(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.sinh))
+@register_acc_op
+def sinh(*, input):
+    """Acc op wrapping torch.sinh."""
+    return torch.sinh(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.cosh))
+@register_acc_op
+def cosh(*, input):
+    """Acc op wrapping torch.cosh."""
+    return torch.cosh(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.tanh))
+@register_acc_op
+def tanh(*, input):
+    """Acc op wrapping torch.tanh."""
+    return torch.tanh(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.asin))
+@register_acc_op
+def asin(*, input):
+    """Acc op wrapping torch.asin."""
+    return torch.asin(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.acos))
+@register_acc_op
+def acos(*, input):
+    """Acc op wrapping torch.acos."""
+    return torch.acos(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.atan))
+@register_acc_op
+def atan(*, input):
+    """Acc op wrapping torch.atan."""
+    return torch.atan(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.exp))
+@register_acc_op
+def exp(*, input):
+    """Acc op wrapping torch.exp."""
+    return torch.exp(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.log))
+@register_acc_op
+def log(*, input):
+    """Acc op wrapping torch.log."""
+    return torch.log(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.sqrt))
+@register_acc_op
+def sqrt(*, input):
+    """Acc op wrapping torch.sqrt."""
+    return torch.sqrt(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.reciprocal))
+@register_acc_op
+def reciprocal(*, input):
+    """Acc op wrapping torch.reciprocal."""
+    return torch.reciprocal(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.abs))
+@register_acc_op
+def abs(*, input):
+    """Acc op wrapping torch.abs."""
+    return torch.abs(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.neg))
+@register_acc_op
+def neg(*, input):
+    """Acc op wrapping torch.neg."""
+    return torch.neg(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.floor))
+@register_acc_op
+def floor(*, input):
+    """Acc op wrapping torch.floor."""
+    return torch.floor(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.ceil))
+@register_acc_op
+def ceil(*, input):
+    """Acc op wrapping torch.ceil."""
+    return torch.ceil(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.conv2d))
+@register_acc_op
+def conv2d(*, input, weight, bias, stride, padding, dilation, groups):
+    """Acc op wrapping nn.functional.conv2d with keyword-only arguments."""
+    return nn.functional.conv2d(
+        input=input,
+        weight=weight,
+        bias=bias,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        groups=groups,
+    )
+
+
+@register_acc_op
+def quantized_conv2d(
+    *,
+    input,
+    weight,
+    bias,
+    stride,
+    padding,
+    dilation,
+    groups,
+    padding_mode,
+    acc_out_ty=None,
+):
+    """
+    Acc op for torch.nn.quantized.functional.conv2d. The output scale and
+    zero point are read from the `q_scale` / `q_zero_point` fields of
+    `acc_out_ty`.
+    """
+    assert acc_out_ty is not None
+    out_scale = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale")
+    out_zero_point = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point")
+    return torch.nn.quantized.functional.conv2d(
+        input,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        groups,
+        padding_mode,
+        out_scale,
+        out_zero_point,
+    )
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.batch_norm))
+@register_acc_op
+def batch_norm(
+    *, input, running_mean, running_var, weight, bias, training, momentum, eps
+):
+    """Acc op wrapping nn.functional.batch_norm with keyword-only arguments."""
+    return nn.functional.batch_norm(
+        input=input,
+        running_mean=running_mean,
+        running_var=running_var,
+        weight=weight,
+        bias=bias,
+        training=training,
+        momentum=momentum,
+        eps=eps,
+    )
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.layer_norm))
+@register_acc_op
+def layer_norm(*, input, normalized_shape, weight, bias, eps):
+    """Acc op wrapping nn.functional.layer_norm with keyword-only arguments."""
+    return nn.functional.layer_norm(
+        input=input,
+        normalized_shape=normalized_shape,
+        weight=weight,
+        bias=bias,
+        eps=eps,
+    )
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "split"),
+    arg_replacement_tuples=[
+        ("tensor", "input"),
+        ("split_size_or_sections", "split_size_or_sections"),
+        ("dim", "dim"),
+    ],
+)
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.split),
+    arg_replacement_tuples=[
+        ("tensor", "input"),
+        ("split_size_or_sections", "split_size_or_sections"),
+        ("dim", "dim"),
+    ],
+)
+def torch_split_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
+    """
+    Lower torch.split / Tensor.split.
+
+    An integer `split_size_or_sections` (a split size) maps to acc_ops.split.
+    Otherwise it is treated as a sequence of section lengths, and the node is
+    mapped to one acc_ops.slice_tensor per section plus acc_ops.tuple_construct.
+    """
+    graph = node.graph
+    source = node.kwargs["input"]
+    split_dim = node.kwargs["dim"]
+    size_or_sections = node.kwargs["split_size_or_sections"]
+
+    with graph.inserting_before(node):
+        if isinstance(size_or_sections, int):
+            split_node = graph.call_function(
+                split,
+                kwargs={
+                    "input": source,
+                    "split_size": size_or_sections,
+                    "dim": split_dim,
+                },
+            )
+            split_node.meta = node.meta.copy()
+            return split_node
+
+        # Sections: each slice starts where the previous one stopped.
+        offset = 0
+        pieces = []
+        for length in size_or_sections:
+            piece = graph.call_function(
+                slice_tensor,
+                kwargs={
+                    "input": source,
+                    "dims": (split_dim,),
+                    "starts": (offset,),
+                    "stops": (offset + length,),
+                    "steps": (1,),
+                },
+            )
+            piece.meta["type"] = torch.Tensor
+            pieces.append(piece)
+            offset += length
+
+        bundled = graph.call_function(
+            tuple_construct, kwargs={"tensors": tuple(pieces)}
+        )
+        bundled.meta = node.meta.copy()
+        return bundled
+
+
+@register_acc_op
+def split(*, input, split_size, dim):
+    """Acc op wrapping torch.split for the split-size form."""
+    chunks = torch.split(input, split_size, dim)
+    return chunks
+
+
+@register_acc_op
+def tuple_construct(*, tensors):
+    """Acc op that packs `tensors` into a tuple."""
+    packed = tuple(tensors)
+    return packed
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.ops.quantized.batch_norm2d),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("weight", "weight"),
+        ("bias", "bias"),
+        ("running_mean", "running_mean"),
+        ("running_var", "running_var"),
+        ("eps", "eps"),
+        ("scale", "scale"),
+        ("zero_point", "zero_point"),
+    ],
+    kwargs_to_move_to_acc_out_ty=[
+        ("scale", "q_scale"),
+        ("zero_point", "q_zero_point"),
+    ],
+)
+@register_acc_op
+def quantized_batch_norm2d(
+    *, input, running_mean, running_var, weight, bias, eps, acc_out_ty
+):
+    """
+    Acc op for torch.ops.quantized.batch_norm2d. The output scale and zero
+    point are read from the `q_scale` / `q_zero_point` fields of `acc_out_ty`.
+    """
+    out_scale = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale")
+    out_zero_point = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point")
+    return torch.ops.quantized.batch_norm2d(
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        eps,
+        out_scale,
+        out_zero_point,
+    )
+
+
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.embedding_bag))
+@register_acc_op
+def embedding_bag(
+    *,
+    input,
+    weight,
+    offsets,
+    max_norm,
+    norm_type,
+    scale_grad_by_freq,
+    mode,
+    sparse,
+    per_sample_weights,
+    include_last_offset,
+    padding_idx,
+):
+    """Acc op wrapping nn.functional.embedding_bag with keyword-only arguments."""
+    return nn.functional.embedding_bag(
+        input=input,
+        weight=weight,
+        offsets=offsets,
+        max_norm=max_norm,
+        norm_type=norm_type,
+        scale_grad_by_freq=scale_grad_by_freq,
+        mode=mode,
+        sparse=sparse,
+        per_sample_weights=per_sample_weights,
+        include_last_offset=include_last_offset,
+        padding_idx=padding_idx,
+    )
+
+
+@register_acc_op_mapping(
+    op_and_target=(
+        "call_function",
+        torch.ops.quantized.embedding_bag_byte_rowwise_offsets,
+    )
+)
+@register_acc_op
+def embedding_bag_byte_rowwise_offsets(
+    *,
+    weight,
+    input,
+    offsets,
+    scale_grad_by_freq,
+    mode,
+    pruned_weights,
+    per_sample_weights,
+    compressed_indices_mapping,
+    include_last_offset,
+):
+    """
+    Acc op wrapping torch.ops.quantized.embedding_bag_byte_rowwise_offsets;
+    every argument is forwarded by keyword.
+    """
+    return torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
+        weight=weight,
+        input=input,
+        offsets=offsets,
+        scale_grad_by_freq=scale_grad_by_freq,
+        mode=mode,
+        pruned_weights=pruned_weights,
+        per_sample_weights=per_sample_weights,
+        compressed_indices_mapping=compressed_indices_mapping,
+        include_last_offset=include_last_offset,
+    )
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.sin))
+@register_acc_op
+def sin(*, input):
+    """Acc op wrapping torch.sin."""
+    return torch.sin(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.cos))
+@register_acc_op
+def cos(*, input):
+    """Acc op wrapping torch.cos."""
+    return torch.cos(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.tan))
+@register_acc_op
+def tan(*, input):
+    """Acc op wrapping torch.tan."""
+    return torch.tan(input=input)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.topk))
+@register_acc_op
+def topk(*, input, k, dim, largest, sorted):
+    """Acc op wrapping torch.topk with keyword-only arguments."""
+    return torch.topk(input=input, k=k, dim=dim, largest=largest, sorted=sorted)
+
+
+@register_acc_op_mapping(op_and_target=("call_function", operator.getitem))
+@register_acc_op
+def getitem(*, input, idx):
+    """Acc op for subscription: returns `input[idx]`."""
+    selected = input[idx]
+    return selected
+
+
+@register_acc_op
+def slice_tensor(*, input, dims, starts, stops, steps):
+    """
+    Acc op that slices `input` along the given dimensions.
+
+    For each position `i`, dimension `dims[i]` is sliced with
+    `slice(starts[i], stops[i], steps[i])`. Dimensions not listed in `dims`
+    are taken in full.
+    """
+    # Start from the full slice on every dimension, then override the
+    # requested ones.
+    slices = [slice(None, None, None) for _ in range(input.dim())]
+    for idx, dim in enumerate(dims):
+        slices[dim] = slice(starts[idx], stops[idx], steps[idx])
+
+    # Index with a tuple: a multi-dimensional index must be a tuple, and
+    # passing a list of slices relies on deprecated sequence-as-index handling.
+    return input[tuple(slices)]
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.narrow),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim"),
+        ("start", "start"),
+        ("length", "length"),
+    ],
+)
+def custom_narrow_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
+    """
+    Lower torch.narrow to acc_ops.slice_tensor over a single dimension,
+    covering [start, start + length) with step 1.
+    """
+    begin = node.kwargs["start"]
+    end = begin + node.kwargs["length"]
+    slice_kwargs = {
+        "input": node.kwargs["input"],
+        "dims": (node.kwargs["dim"],),
+        "starts": (begin,),
+        "stops": (end,),
+        "steps": (1,),
+    }
+    with node.graph.inserting_before(node):
+        narrowed = node.graph.call_function(slice_tensor, kwargs=slice_kwargs)
+        narrowed.meta = node.meta.copy()
+        return narrowed
+
+
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.reshape),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("shape", "shape"),
+    ],
+    kwargs_to_move_to_acc_out_ty=[("shape", "shape")],
+)
+@register_acc_op_mapping(
+    op_and_target=("call_method", "view"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("*", "shape"),
+    ],
+    kwargs_to_move_to_acc_out_ty=[("shape", "shape")],
+)
+@register_acc_op
+def reshape(*, input, acc_out_ty=None):
+    """
+    Acc op for torch.reshape (Tensor.view is mapped here too). The target
+    shape is read from the `shape` field of `acc_out_ty`.
+    """
+    assert acc_out_ty is not None
+    target_shape = tuple(acc_utils.get_field_from_acc_out_ty(acc_out_ty, "shape"))
+    return torch.reshape(input, target_shape)
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "reshape"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("*", "shape"),
+    ],
+)
+def custom_tensor_reshape_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    """
+    Lower Tensor.reshape to acc_ops.reshape.
+
+    The call may look like `x.reshape(1, 2, 3)` or `x.reshape((1, 2, 3))`, so
+    when the first collected shape element is itself a tuple or list, that
+    element is used as the shape.
+    """
+    target_shape = node.kwargs["shape"]
+    first = target_shape[0]
+    if isinstance(first, (tuple, list)):
+        target_shape = first
+
+    reshape_kwargs = {
+        "input": node.kwargs["input"],
+        "acc_out_ty": acc_utils.build_raw_tensor_meta(shape=target_shape),
+    }
+    with node.graph.inserting_before(node):
+        reshaped = node.graph.call_function(reshape, kwargs=reshape_kwargs)
+        reshaped.meta = node.meta.copy()
+        return reshaped
+
+
+@register_acc_op
+def to_dtype(input, acc_out_ty=None):
+    """
+    Acc op that casts `input` to the dtype stored in the `dtype` field of
+    `acc_out_ty`.
+    """
+    assert acc_out_ty is not None, "valid acc_out_ty needed"
+    target_dtype = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype")
+    return input.to(dtype=target_dtype)
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "to"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dtype", "dtype"),
+    ],
+)
+def custom_tensor_to_mapper(node: torch.fx.Node, _: nn.Module):
+    """
+    Lower a dtype-only Tensor.to call to acc_ops.to_dtype.
+
+    Raises:
+        AssertionError: if no dtype is given, if a memory format other than
+            torch.preserve_format is requested, or if a device is given.
+    """
+    dest_dtype = node.kwargs["dtype"]
+    mem_format = node.kwargs.get("memory_format")
+    device = node.kwargs.get("device")
+    assert dest_dtype is not None
+    assert mem_format is None or mem_format == torch.preserve_format
+    assert device is None
+
+    new_kwargs = {
+        "input": node.kwargs["input"],
+        "acc_out_ty": acc_utils.build_raw_tensor_meta(dtype=dest_dtype),
+    }
+
+    with node.graph.inserting_before(node):
+        new_node = node.graph.create_node(
+            "call_function", to_dtype, kwargs=new_kwargs, name=node.name
+        )
+        # Copy rather than alias, so the two nodes never share one meta dict.
+        new_node.meta = node.meta.copy()
+        return new_node
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.add),
+    # Note that we may have aliases for inputs here due to issues with deterministically
+    # knowing the correct target that will be resolved by pytorch.
+    arg_replacement_tuples=[
+        (("input", "a"), "input"),
+        (("other", "b"), "other"),
+        ("alpha", "alpha", this_arg_is_optional),
+    ],
+)
+def custom_torch_add_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
+    """
+    Custom mapping for torch.add. Its `alpha` parameter scales the `other`
+    input, and that scaling is emitted as a separate acc_ops.mul node feeding
+    acc_ops.add.
+    """
+    graph = node.graph
+    with graph.inserting_before(node):
+        # If alpha is in kwargs check if we need to add a mul, and use correct kwargs.
+        if "alpha" in node.kwargs:
+            other_node = node.kwargs["other"]
+            # Add mul node only if it has a numerical impact, i.e. alpha != 1.0.
+            if node.kwargs["alpha"] != 1.0:
+                other_node = graph.create_node(
+                    "call_function",
+                    mul,
+                    kwargs={
+                        "input": node.kwargs["other"],
+                        "other": node.kwargs["alpha"],
+                    },
+                    name=node.name + "_mul_alpha",
+                )
+                # Each new node gets its own copy of the metadata; aliasing
+                # node.meta would make the mul and add nodes share one dict.
+                other_node.meta = node.meta.copy()
+            add_kwargs = {"input": node.kwargs["input"], "other": other_node}
+        else:
+            add_kwargs = node.kwargs
+
+        new_node = graph.create_node(
+            "call_function", add, kwargs=add_kwargs, name=node.name
+        )
+        new_node.meta = node.meta.copy()
+        return new_node
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_module", nn.quantized.Linear),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+def packed_quantized_linear_mapper(
+    node: torch.fx.Node, mod: nn.Module
+) -> torch.fx.Node:
+    """
+    Mapping from a quantized Linear module to acc_ops.quantized_linear. The
+    weight and bias are unpacked here, registered as buffers on `mod`, and
+    passed to the new node through get_attr nodes.
+    """
+    linear_module = dict(mod.named_modules())[node.target]
+    prefix = node.target.replace(".", "_")
+    weight_name = f"{prefix}_weight"
+    bias_name = f"{prefix}_bias"
+
+    # Fetch weight and bias once and reuse the results instead of calling the
+    # module accessors repeatedly.
+    weight = linear_module.weight()
+    bias = linear_module.bias()
+
+    # Store weight and bias in the main module
+    mod.register_buffer(weight_name, weight)
+    if bias is not None:
+        mod.register_buffer(bias_name, bias)
+
+    with node.graph.inserting_before(node):
+        # Insert get_attr nodes for weight and bias
+        get_weight = node.graph.get_attr(weight_name)
+        get_weight.meta["tensor_meta"] = extract_tensor_metadata(weight)
+
+        get_bias = None
+        if bias is not None:
+            get_bias = node.graph.get_attr(bias_name)
+            get_bias.meta["tensor_meta"] = extract_tensor_metadata(bias)
+
+        # Create kwargs for acc_op.quantized_linear
+        kwargs = {
+            "input": node.kwargs["input"],
+            "weight": get_weight,
+            "bias": get_bias,
+            "acc_out_ty": acc_utils.build_raw_tensor_meta(
+                q_scale=linear_module.scale, q_zero_point=linear_module.zero_point
+            ),
+        }
+
+        new_node = node.graph.call_function(quantized_linear, kwargs=kwargs)
+        # Copy rather than alias, so the two nodes never share one meta dict.
+        new_node.meta = node.meta.copy()
+        return new_node
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_module", nn.quantized.Conv2d),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+def packed_quantized_conv2d_mapper(
+    node: torch.fx.Node, mod: nn.Module
+) -> torch.fx.Node:
+    """
+    Mapping from a quantized Conv2d module to acc_ops.quantized_conv2d. All
+    parameters are unpacked here: weight and bias are registered as buffers on
+    `mod` and passed through get_attr nodes, and the remaining conv settings
+    are passed as plain kwargs.
+    """
+    conv_module = dict(mod.named_modules())[node.target]
+    prefix = node.target.replace(".", "_")
+    weight_name = f"{prefix}_weight"
+    bias_name = f"{prefix}_bias"
+
+    # Fetch weight and bias once and reuse the results instead of calling the
+    # module accessors repeatedly.
+    weight = conv_module.weight()
+    bias = conv_module.bias()
+
+    # Store weight and bias in the main module
+    mod.register_buffer(weight_name, weight)
+    if bias is not None:
+        mod.register_buffer(bias_name, bias)
+
+    with node.graph.inserting_before(node):
+        # Insert get_attr nodes for weight and bias
+        get_weight = node.graph.get_attr(weight_name)
+        get_weight.meta["tensor_meta"] = extract_tensor_metadata(weight)
+
+        get_bias = None
+        if bias is not None:
+            get_bias = node.graph.get_attr(bias_name)
+            get_bias.meta["tensor_meta"] = extract_tensor_metadata(bias)
+
+        # Create kwargs for acc_op.conv
+        kwargs = {
+            "input": node.kwargs["input"],
+            "weight": get_weight,
+            "bias": get_bias,
+            "stride": conv_module.stride,
+            "padding": conv_module.padding,
+            "dilation": conv_module.dilation,
+            "groups": conv_module.groups,
+            "padding_mode": conv_module.padding_mode,
+            "acc_out_ty": acc_utils.build_raw_tensor_meta(
+                q_scale=conv_module.scale, q_zero_point=conv_module.zero_point
+            ),
+        }
+
+        new_node = node.graph.call_function(quantized_conv2d, kwargs=kwargs)
+        # Copy rather than alias, so the two nodes never share one meta dict.
+        new_node.meta = node.meta.copy()
+        return new_node
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.ops.quantized.add_relu),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("other", "other"),
+        ("scale", "scale"),
+        ("zero_point", "zero_point"),
+    ],
+)
+def add_relu_unfuse_mapper(
+    node: torch.fx.Node, mod: torch.fx.GraphModule
+) -> torch.fx.Node:
+    """
+    Unfuse torch.ops.quantized.add_relu into acc_ops.quantized_add followed by
+    acc_ops.relu. The scale and zero point of the fused op are carried on the
+    add node through `acc_out_ty`.
+    """
+    with node.graph.inserting_before(node):
+        add_kwargs = {
+            "input": node.kwargs["input"],
+            "other": node.kwargs["other"],
+            "acc_out_ty": acc_utils.build_raw_tensor_meta(
+                q_scale=node.kwargs["scale"],
+                q_zero_point=node.kwargs["zero_point"],
+            ),
+        }
+        add_node = node.graph.call_function(quantized_add, kwargs=add_kwargs)
+        add_node.meta = node.meta.copy()
+
+        relu_node = node.graph.call_function(
+            relu, kwargs={"input": add_node, "inplace": False}
+        )
+        # Copy here as well, so the relu node does not alias node.meta.
+        relu_node.meta = node.meta.copy()
+        return relu_node
+
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_module", nn.intrinsic.quantized.ConvReLU2d),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+def packed_quantized_convrelu2d_mapper(
+    node: torch.fx.Node, mod: nn.Module
+) -> torch.fx.Node:
+    """
+    Mapping from a quantized ConvReLU2d module to a conv2d node followed by
+    acc_ops.relu. packed_quantized_conv2d_mapper unpacks the parameters and
+    builds the conv2d node, whose output is then fed into the relu node.
+    """
+
+    with node.graph.inserting_before(node):
+        # conv2d op
+        conv2d_node = packed_quantized_conv2d_mapper(node, mod)
+
+        # relu op
+        relu_node = node.graph.call_function(
+            relu, kwargs={"input": conv2d_node, "inplace": False}
+        )
+        # Give the relu node its own copy of the metadata so it does not share
+        # a dict with the original node.
+        relu_node.meta = node.meta.copy()
+        return relu_node
--- /dev/null
+# type: ignore[]
+import ast
+import builtins
+import copy
+import inspect
+import textwrap
+import warnings
+from types import FunctionType
+from typing import Dict, Optional, Any, Type, Tuple, Set, List
+
+import torch.fx.experimental.fx_acc.acc_normalizer as acc_normalizer
+import torch.fx.experimental.fx_acc.acc_ops # noqa: F401
+import torch
+import torch.nn as nn
+from torch._sources import normalize_source_lines
+from torch.fx import Graph, Tracer
+from torch.fx.experimental.normalize import NormalizeArgs
+from torch.fx.passes import shape_prop
+
+
+def _get_exception_wrapper_attr_name(exc_type: Type[Exception]) -> str:
+    """Return the attribute name under which the wrapper module for `exc_type` is stored."""
+    return "_conditional_exception_wrapper_" + exc_type.__name__
+
+
+class Acc_Rewriter(ast.NodeTransformer):
+    """
+    Take a FunctionType object representing a `forward` method, then
+    perform an AST rewrite to swap out nodes that are not symbolically
+    traceable with a callsite to the FX alternative.
+
+    To support swapping out an AST node, define a new `visit` method on
+    that node. For more details, see:
+    https://docs.python.org/3/library/ast.html#ast.NodeTransformer
+    """
+
+    def __init__(self):
+        super().__init__()
+        # Exception classes whose `if cond: raise Exc(...)` pattern was
+        # replaced during visit_If.
+        self.exceptions_rewritten: Set[Type[Exception]] = set()
+
+    def rewrite(self, fn: FunctionType) -> Tuple[FunctionType, Set[Type[Exception]]]:
+        """
+        Rewrite the source of `fn` and return the recompiled function together
+        with the set of exception classes rewritten by visit_If.
+        """
+        # Normalize the source lines
+        sourcelines, _ = inspect.getsourcelines(fn)
+        sourcelines = normalize_source_lines(sourcelines)
+        source = "".join(sourcelines)
+        normalized_str = textwrap.dedent(source)
+
+        # Rewrite the original AST
+        source_ast = ast.parse(normalized_str)
+        dest_ast = ast.fix_missing_locations(self.visit(source_ast))
+
+        # Pull out the compiled function from the newly-created Module.
+        # A copy of the globals is used so that exec does not modify
+        # fn.__globals__ itself.
+        code = compile(dest_ast, "", "exec")
+        globals_dict = copy.copy(fn.__globals__)
+        keys_before = set(globals_dict.keys())
+        exec(code, globals_dict)
+        new_keys = list(set(globals_dict.keys()) - keys_before)
+        assert len(new_keys) <= 1
+        fn_compiled = globals_dict[fn.__name__]
+
+        # Return the correct FunctionType object and the Exceptions that were
+        # rewritten during visit_If.
+        return fn_compiled, self.exceptions_rewritten
+
+    def visit_Assert(self, node: ast.Assert):
+        """
+        Swap out the Assert node (Python's `assert`) with a callsite to the
+        symbolically-traceable torch._assert function
+        """
+        # Create the Call node
+        n = ast.parse("torch._assert()", mode="eval")
+        assert isinstance(n, ast.Expression)
+        call_node = n.body
+        assert isinstance(call_node, ast.Call)
+        msg = node.msg if node.msg else ast.Constant(value="", kind=None)
+        call_node.args = [node.test, msg]
+
+        # Ensure that the new node conforms to the Python AST grammar
+        expr_wrapper = ast.Expr(value=call_node)
+
+        # Return the new Call node to signify that we want to use it as
+        # a replacement for the original _assert node
+        return ast.copy_location(expr_wrapper, node)
+
+    def visit_If(self, if_node: ast.If):
+        """
+        Swap out the pattern `If(x): Raise(y)` with a ConditionalExceptionWrapper
+        specialized for the specific exception y. The specialized
+        ConditionalExceptionWrapper module will be added in the RewrittenModule.
+        Only works with builtin Exceptions, as we assume the signature of the
+        init for the Exception is a string.
+
+        Any `if` that does not match the pattern is returned unchanged.
+        """
+        # NOTE(review): the unmatched paths return if_node without calling
+        # generic_visit, so statements nested inside such an `if` are not
+        # rewritten. Confirm whether that is intended.
+        raise_node = if_node.body[0]
+        if not isinstance(raise_node, ast.Raise):
+            return if_node
+
+        # Don't handle orelse for now.
+        # TODO: Move orelse to the body after calling ConditionalExceptionWrapper.
+        if len(if_node.orelse) != 0:
+            return if_node
+
+        def _reuse_loc(node):
+            return ast.copy_location(node, if_node)
+
+        # If the exception has a message then we expect the raise's exc to be a
+        # Call w/ a msg. Else if it's a exc Name then there's no msg to use.
+        node_for_exc = raise_node.exc
+        if isinstance(node_for_exc, ast.Name):
+            # E.g. `raise AssertionError`, i.e. without an exc_msg.
+            name_node_of_exc = node_for_exc
+            exc_msg = _reuse_loc(ast.Constant(None))
+        elif isinstance(node_for_exc, ast.Call):
+            # E.g. `raise AssertionError("error message")`
+            name_node_of_exc = node_for_exc.func
+            if not isinstance(name_node_of_exc, ast.Name):
+                return if_node
+            # Most assertions just take a single string arg, but some may not; skip
+            # handling such assertions for now.
+            if len(node_for_exc.args) != 1:
+                return if_node
+            exc_msg = node_for_exc.args[0]
+        else:
+            return if_node
+
+        # Convert what we expect is the name of the exception into its
+        # associated python class. The evaluated string is the identifier of
+        # an ast.Name node, so this is only a name lookup in this module's scope.
+        name_of_exc = name_node_of_exc.id
+        try:
+            exc_type = eval(name_of_exc)
+        except Exception:
+            return if_node
+
+        # Check that we actually have a builtin exception. The name may resolve
+        # to something that is not a class at all, in which case issubclass()
+        # would raise TypeError, so test for a class first. The module checked
+        # is the exception class's own; the module of its metaclass is
+        # "builtins" for every ordinary class and would never reject anything.
+        if (
+            not isinstance(exc_type, type)
+            or not issubclass(exc_type, Exception)
+            or getattr(exc_type, "__module__", None) != "builtins"
+        ):
+            return if_node
+
+        # We need a ConditionalExceptionWrapper specialized for every kind of
+        # exception, so add it to exceptions_rewritten to remember for later to
+        # add a specialized attr with it.
+        self.exceptions_rewritten.add(exc_type)
+
+        # From here we definitely should be able to do the replacement. Create a
+        # Call node to the ConditionalExceptionWrapper module we're replacing
+        # the If with, with args set as the If's condition and the string of the
+        # exception. The call to the self._conditional_exception_wrapper_*Error
+        # module is safe because the RewrittenModule will add it as an attr
+        # based on the returned exceptions_rewritten, and we assume we are
+        # currently modifying the AST of a method from a RewrittenModule.
+        exc_wrapper_node = ast.parse(
+            f"self.{_get_exception_wrapper_attr_name(exc_type)}()", mode="eval"
+        )
+        assert isinstance(exc_wrapper_node, ast.Expression)
+        exc_wrapper_call_node = exc_wrapper_node.body
+        assert isinstance(exc_wrapper_call_node, ast.Call)
+        exc_wrapper_call_node.args = [if_node.test, exc_msg]
+
+        # Ensure that the new node conforms to the Python AST grammar
+        expr_wrapper = _reuse_loc(ast.Expr(_reuse_loc(exc_wrapper_call_node)))
+
+        # Return the new node to signify that we want to use it as a replacement
+        # for the original `If x: Raise y` pattern.
+        return expr_wrapper
+
+
+class ConditionalExceptionWrapper(nn.Module):
+    """
+    Module that stands in for a conditionally raised exception after
+    rewriting. For example:
+
+    .. code-block:: python
+
+        if self.name != "x":
+            raise AssertionError(f"Name was not x: {self.name}")
+
+    becomes
+
+    .. code-block:: python
+
+        self._conditional_exception_wrapper_AssertionError(
+            self.name != "x", f"Name was not x: {self.name}"
+        )
+
+    The Exception class being wrapped is handed to __init__, whereas forward
+    receives the condition to test and the message for the exception.
+
+    """
+
+    # Flagged as impure so that DCE does not drop calls to this module.
+    _is_impure = True
+
+    def __init__(self, exc: Type[Exception]):
+        super().__init__()
+        self.exc = exc
+
+    def forward(self, cond: bool, msg: str):
+        # Nothing happens unless the guarded condition holds.
+        if not cond:
+            return
+        # With no message, raise the bare exception class; otherwise
+        # instantiate it with the message.
+        if msg is None:
+            raise self.exc
+        raise self.exc(msg)
+
+
+# Custom tracer that traces to the functional level and rewrites asserts and
+# exceptions.
+class AccRewritingTracer(Tracer):
+    """
+    Tracer that rewrites the root module with `_rewrite` and then symbolically
+    traces the rewritten module, treating a configurable set of module types
+    as leaves.
+    """
+
+    # Note: Treat ConditionalExceptionWrapper as a leaf so that we don't
+    # trace into it, because it contains control flow and raises an exception.
+    DEFAULT_LEAF_MODULE_LIST = {
+        ConditionalExceptionWrapper,
+        torch.nn.quantized.Linear,
+        torch.nn.quantized.Conv2d,
+        torch.nn.intrinsic.quantized.ConvReLU2d,
+    }
+
+    def is_leaf_module(self, m: nn.Module, mod_qual_name: str) -> bool:
+        """
+        Return whether `m` should not be traced into.
+
+        Rewritten modules carry `_base_class_origin`, so the lookup is done on
+        the original class when that attribute is present and on `type(m)`
+        otherwise.
+        """
+        return getattr(m, "_base_class_origin", type(m)) in self.leaf_module_list
+
+    def trace(
+        self,
+        root: nn.Module,
+        concrete_args: Optional[Dict[str, Any]] = None,
+        ast_rewriter_allow_list: Optional[Set] = None,
+        leaf_module_list: Optional[Set] = None,
+    ) -> Tuple[Graph, nn.Module]:
+        """
+        Rewrite `root`, trace the rewritten module, and return both.
+
+        Args:
+            root: The module to rewrite and trace.
+            concrete_args: Forwarded unchanged to `Tracer.trace`.
+            ast_rewriter_allow_list: Forwarded to `_rewrite` as its allow list.
+            leaf_module_list: Extra module types to treat as leaves for this
+                tracer, in addition to DEFAULT_LEAF_MODULE_LIST.
+
+        Returns:
+            A tuple of the traced Graph and the rewritten module.
+        """
+        rewritten = _rewrite(root, ast_rewriter_allow_list)
+        # Work on a copy: updating the class-level default set in place would
+        # leak this call's extra leaf modules into every later tracer.
+        self.leaf_module_list = set(self.DEFAULT_LEAF_MODULE_LIST)
+        if leaf_module_list:
+            self.leaf_module_list.update(leaf_module_list)
+        return super().trace(rewritten, concrete_args), rewritten
+
+
+# Module types that need their methods AST-rewritten to be supported for
+# tracing. `_rewrite` uses this set when the caller passes no allow list.
+DEFAULT_REWRITE_ALLOW_LIST = {
+ nn.BatchNorm1d,
+ nn.BatchNorm2d,
+ nn.BatchNorm3d,
+}
+
+
+def _rewrite(mod_to_rewrite: nn.Module, allow_list: Optional[Set] = None) -> nn.Module:
+    """
+    Build a rewritten copy of `mod_to_rewrite` and all of its submodules.
+
+    Every module is wrapped in a dynamically created `RewrittenModule`
+    subclass of its own class. For module types in the allow list, the
+    non-dunder methods are passed through `Acc_Rewriter`; for all other types
+    the methods are copied over unchanged.
+
+    Args:
+        mod_to_rewrite: Root of the module hierarchy to rewrite.
+        allow_list: Optional extra module types whose methods should be
+            rewritten. DEFAULT_REWRITE_ALLOW_LIST is always included. The
+            caller's set is not modified.
+
+    Returns:
+        The root of the new, rewritten module hierarchy.
+    """
+    if allow_list is None:
+        allow_list = DEFAULT_REWRITE_ALLOW_LIST
+    else:
+        # set.union returns a new set, so the result must be kept; without the
+        # rebinding the defaults would never be merged in.
+        allow_list = allow_list.union(DEFAULT_REWRITE_ALLOW_LIST)
+
+    # Rewrite this module's functions as well as all recursive modules'
+    # functions that are attrs of this module. Return the new, rewritten module
+    # hierarchy.
+    def rewrite_module(m: nn.Module):
+        base_class = type(m)
+
+        # Keep track of all the ConditionalExceptionWrappers that the
+        # Acc_Rewriter calls into in this module so we can add them in init
+        # below.
+        all_added_wrappers: Set[Type[Exception]] = set()
+
+        # Note: Make this a subclass of our base class.
+        class RewrittenModule(base_class):
+            # Keep track of the base_class so that symbolic tracing can
+            # determine what kind of module this originally was later on.
+            _base_class_origin = base_class
+            # Add suffix to qualname so it's easier to debug the origin of this module.
+            __qualname__ = f"{base_class.__qualname__}__AccRewrittenModule"
+
+            # Write all of the non-dunder or special methods from base_class
+            # into RewrittenModule. This loop runs in the class body, so
+            # vars() is the namespace of the class being defined.
+            for method_name in dir(base_class):
+                method = getattr(base_class, method_name)
+                if builtins.type(method) is not FunctionType:
+                    continue
+
+                # Always skip rewriting dunder methods, as they haven't (yet) been
+                # problematic, and modifying them has caused issues previously.
+                if method_name.startswith("__") and method_name.endswith("__"):
+                    continue
+
+                # Only rewrite those Modules explicitly in the allow_list.
+                if base_class not in allow_list:
+                    vars()[method_name] = method
+                else:
+                    vars()[method_name], added_wrappers = Acc_Rewriter().rewrite(method)
+                    all_added_wrappers.update(added_wrappers)
+
+            def __init__(self, orig):
+                nn.Module.__init__(self)
+                # Iterate over all added exception wrappers and add
+                # ConditionalExceptionWrapper attrs for each.
+                for exc_type in all_added_wrappers:
+                    wrapper_name = _get_exception_wrapper_attr_name(exc_type)
+                    assert not hasattr(self, wrapper_name)
+                    setattr(
+                        self,
+                        wrapper_name,
+                        ConditionalExceptionWrapper(exc_type),
+                    )
+                # Recursively rewrite and copy all module attrs of this module.
+                for k, v in orig.__dict__.items():
+                    if k == "_modules":
+                        for mod_k, mod_v in v.items():
+                            self._modules[mod_k] = rewrite_module(mod_v)
+                    else:
+                        self.__dict__[k] = v
+
+        # Add suffix to name so it's easier to debug the origin of this module.
+        RewrittenModule.__name__ = f"{base_class.__name__}__AccRewrittenModule"
+        return RewrittenModule(m)
+
+    return rewrite_module(mod_to_rewrite)
+
+
+def _remove_assertions(gm: torch.fx.GraphModule) -> bool:
+    """
+    Erase every `torch._assert` call_function node from the graph of gm,
+    unconditionally. Returns True if at least one node was erased.
+    """
+    removed_any = False
+    for node in gm.graph.nodes:
+        is_assert_call = node.op == "call_function" and node.target == torch._assert
+        if not is_assert_call:
+            continue
+        gm.graph.erase_node(node)
+        removed_any = True
+    return removed_any
+
+
+def _remove_exceptions(gm: torch.fx.GraphModule) -> bool:
+    """
+    Erase every call_module node in the graph of gm whose target submodule is
+    a ConditionalExceptionWrapper, unconditionally. Returns True if at least
+    one node was erased.
+    """
+    removed_any = False
+    for node in gm.graph.nodes:
+        if node.op != "call_module":
+            continue
+        submodule = gm.get_submodule(node.target)
+        if not isinstance(submodule, ConditionalExceptionWrapper):
+            continue
+        gm.graph.erase_node(node)
+        removed_any = True
+    return removed_any
+
+
+def trace(
+    mod: nn.Module,
+    sample_inputs: List[torch.Tensor],
+    remove_assertions: bool = True,
+    remove_exceptions: bool = True,
+    use_acc_normalization: bool = True,
+    ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None,
+    leaf_module_list: Optional[Set[Type[nn.Module]]] = None,
+) -> torch.fx.GraphModule:
+    """
+    Performs tracing and arg normalization specialized for accelerator lowering.
+
+    It first rewrites the AST of the module's methods (and all attr methods
+    recursively) to transform untraceable parts of the module to make them
+    traceable.
+
+    It then traces to the functional level so that optimizations and backend
+    accelerator importers have the ability to see and/or change inputs to each
+    op.
+
+    It then removes assertions and exception wrappers found during symbolic
+    tracing if requested based on remove_assertions and remove_exceptions.
+
+    Dead code is then eliminated, which will e.g. remove any nodes that were
+    only used by assertions or exceptions if they were removed.
+
+    It then performs normalization on args/kwargs, making default values
+    explicit.
+
+    Note that if `mod` is in training mode, a warning is emitted and
+    `mod.eval()` is called on the passed-in module itself before tracing.
+
+    Args:
+
+        mod (Module): The module to transform and trace.
+
+        sample_inputs (List[torch.Tensor]): Sample inputs with which to run
+            shape prop. They are unpacked as positional arguments.
+
+        remove_assertions (bool): Whether to remove assertion nodes from
+            the graph after symbolic tracing.
+
+        remove_exceptions (bool): Whether to remove exception wrapper nodes
+            from the graph after symbolic tracing.
+
+        use_acc_normalization (bool): Whether to use acc-specific
+            normalization to all acc_ops.
+
+        ast_rewriter_allow_list (Optional[Set[nn.Module]]): Optional allow list of
+            modules that need AST rewriting.
+
+        leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list where
+            modules will not be traced into.
+
+    Returns:
+
+        The traced, normalized and recompiled GraphModule.
+
+    """
+    if mod.training:
+        warnings.warn(
+            "acc_tracer does not currently support models for training."
+            " Calling eval on model before tracing."
+        )
+        mod.eval()
+
+    # Rewrite the module to make it symbolic traceable, and then trace it.
+    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
+        mod,
+        ast_rewriter_allow_list=ast_rewriter_allow_list,
+        leaf_module_list=leaf_module_list,
+    )
+
+    assert isinstance(rewritten_mod, nn.Module)
+    # Note: use the rewritten_mod here as the root. This is necessary because
+    # RewrittenModule includes a new module for the ConditionalExceptionWrapper.
+    traced = torch.fx.GraphModule(rewritten_mod, rewritten_graph)
+
+    # Now remove all assertions and exceptions if requested.
+    if remove_assertions:
+        _remove_assertions(traced)
+    if remove_exceptions:
+        _remove_exceptions(traced)
+
+    # Cleanup any dead code from the original module as well as resulting dead
+    # nodes after removing assertions and exceptions.
+    traced.graph.eliminate_dead_code()
+
+    # Now normalize args/kwargs to make default values visible. Leave args/kwargs as
+    # they were, since all-kwarg normalization is broken, and we don't need it anyway.
+    shape_prop.ShapeProp(traced).propagate(*sample_inputs)
+    traced = NormalizeArgs(traced, normalize_to_only_use_kwargs=False).transform()
+
+    # Normalize to acc-specialized wrappers for consistency across op naming and
+    # ensuring all kwarg usage.
+    if use_acc_normalization:
+        acc_normalizer.normalize(traced)
+
+    traced.recompile()
+
+    return traced
--- /dev/null
+import inspect
+import json
+from typing import Any, Tuple, Callable, Union, Dict
+
+import torch
+import torch.fx
+from torch.fx.experimental.graph_manipulation import (
+ serialize_module,
+)
+from torch.fx.graph_module import GraphModule
+from torch.fx.passes import graph_drawer
+from torch.fx.passes.shape_prop import TensorMetadata
+
+
+def is_acc_op(node_or_target: Union[Callable, torch.fx.Node]) -> bool:
+    """
+    Tell whether `node_or_target` refers to an acc_op.
+
+    A Node qualifies only if it is a call_function whose target's module name
+    contains "acc_ops". Anything else is taken to be the target itself and is
+    checked the same way.
+    """
+    if not isinstance(node_or_target, torch.fx.Node):
+        return "acc_ops" in node_or_target.__module__
+
+    # All acc_ops are call_functions, so any other kind of node is rejected.
+    if node_or_target.op != "call_function":
+        return False
+    return "acc_ops" in node_or_target.target.__module__
+
+
+def is_acc_op_with_kwarg(
+    node_or_target: Union[Callable, torch.fx.Node], kwarg: str
+) -> bool:
+    """
+    Tell whether `node_or_target` is an acc_op node (or an acc_op target)
+    whose unwrapped signature has a parameter named `kwarg`.
+    """
+    if not is_acc_op(node_or_target):
+        return False
+
+    if isinstance(node_or_target, torch.fx.Node):
+        target = node_or_target.target
+    else:
+        target = node_or_target
+    assert not isinstance(target, str)
+
+    # Inspect the innermost function so decorators don't hide the parameters.
+    signature = inspect.signature(inspect.unwrap(target))
+    return kwarg in signature.parameters
+
+
+def get_field_from_acc_out_ty(
+    acc_out_ty_or_dict: Union[Tuple, Dict[str, Any]], field: str
+):
+    """
+    Look up `field` by name in an acc_out_ty that is a plain tuple.
+
+    After tracing, NamedTuple inputs are converted to standard tuples, so
+    fields cannot be accessed by name directly; the position is taken from
+    `TensorMetadata._fields` instead. A dict argument is expected to hold the
+    tuple under the "acc_out_ty" key.
+    """
+    acc_out_ty = (
+        acc_out_ty_or_dict["acc_out_ty"]
+        if isinstance(acc_out_ty_or_dict, dict)
+        else acc_out_ty_or_dict
+    )
+    field_position = TensorMetadata._fields.index(field)
+    return acc_out_ty[field_position]
+
+
+def serialize_module_json_to_file(fx_module: GraphModule, fname: str):
+    """
+    Serialize `fx_module` with `serialize_module` and write the result to
+    `fname` as JSON indented by two spaces.
+    """
+    weights: Dict = {}
+    module_as_dict = serialize_module(fx_module, weights)
+    # Build the full JSON text before opening the file for writing.
+    json_text = json.dumps(module_as_dict, indent=2)
+    with open(fname, "w") as ofile:
+        ofile.write(json_text)
+
+
+def build_raw_tensor_meta(
+    shape=None,
+    dtype=None,
+    requires_grad=None,
+    stride=None,
+    memory_format=None,
+    is_quantized=None,
+    qscheme=None,
+    q_scale=None,
+    q_zero_point=None,
+):
+    """
+    Build a TensorMetadata from the given keyword values; every field that is
+    not supplied is set to None.
+    """
+    return TensorMetadata(
+        shape=shape,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        stride=stride,
+        memory_format=memory_format,
+        is_quantized=is_quantized,
+        qscheme=qscheme,
+        q_scale=q_scale,
+        q_zero_point=q_zero_point,
+    )
+
+
+def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph"):
+    """
+    Draw `traced` with FxGraphDrawer and write the main dot graph as an SVG.
+
+    ".svg" is appended to `fname` when it does not already end with it, and
+    the final file name is printed before the graph is drawn.
+    """
+    svg_fname = fname if fname.endswith(".svg") else fname + ".svg"
+    print(f"Writing FX graph to file: {svg_fname}")
+    drawer = graph_drawer.FxGraphDrawer(traced, figname)
+    drawer.get_main_dot_graph().write_svg(svg_fname)
user_targets = {
_get_qualified_name(
n.target
- ).replace("glow.fb.fx.oss_acc_tracer.", "").replace("glow.fb.fx.", ""): n
+ ).replace("torch.fx.experimental.fx_acc.", "").replace("glow.fb.fx.", ""): n
for n in node.users.keys()
}
if (