From: Shiyan Deng Date: Sun, 15 Aug 2021 18:52:20 +0000 (-0700) Subject: Move fx2trt and oss_acc_tracer to oss (#63101) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~1011 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8e0998ca70ed9b878d0f11d435bbe51bb95aa1ca;p=platform%2Fupstream%2Fpytorch.git Move fx2trt and oss_acc_tracer to oss (#63101) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63101 Move internal fx2trt to torch/fx/experimental/fx2trt and merge the two TRT interpreter we have right now. cc: mortzur as this might affect uru exporting script. Move oss_acc_tracer to torch/fx/experimental/fx_acc. Test Plan: CI Reviewed By: jerryzh168 Differential Revision: D30257909 fbshipit-source-id: 4e374965fbf88d72e91844d9e9b6ff9b98f467d1 --- diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 87bf7ec..00f3201 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -568,7 +568,7 @@ class TestFXExperimental(JitTestCase): node_to_partition_id = {} partition_to_logical_devices = {} count = 0 - GraphManipulation.get_size_of_all_nodes(traced, [a]) + graph_manipulation.get_size_of_all_nodes(traced, [a]) for node in traced.graph.nodes: if node.op not in {"placeholder", "get_attr", "output"}: node_to_partition_id[node] = count diff --git a/torch/fx/experimental/fx2trt/converters/__init__.py b/torch/fx/experimental/fx2trt/converters/__init__.py index a73d459..4c62156 100644 --- a/torch/fx/experimental/fx2trt/converters/__init__.py +++ b/torch/fx/experimental/fx2trt/converters/__init__.py @@ -8,3 +8,4 @@ from .maxpool import * # noqa: F403 from .mul import * # noqa: F403 from .transformation import * # noqa: F403 from .quantization import * # noqa: F403 +from .acc_ops_converters import * # noqa: F403 diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py new file mode 100644 index 0000000..b85ab40 --- /dev/null +++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py @@ -0,0 +1,1080 @@ +# type: ignore[attr-defined] +import math +import operator + +import torch.fx.experimental.fx_acc.acc_ops as acc_ops +import torch.fx.experimental.fx_acc.acc_utils as acc_utils +import numpy as np +import tensorrt as trt +import torch +from torch.fx.experimental.fx2trt.fx2trt import ( + tensorrt_converter, + torch_dtype_from_trt, + get_dynamic_dims, +) + + +def to_numpy(tensor: torch.Tensor): + """ + Convert a PyTorch Tensor to a Numpy Array. + """ + if tensor is None: + return tensor + + if tensor.is_quantized: + tensor = tensor.dequantize() + + return tensor.cpu().detach().contiguous().numpy() + + +def has_dynamic_shape(shape): + return any(s == -1 for s in shape) + + +def get_axes_for_reduce_op(dim, has_implicit_batch_dimension): + if isinstance(dim, int): + dim = (dim,) + + if has_implicit_batch_dimension: + assert 0 not in dim, "Can't reduce over batch dimension when it's implicit." + + axes = 0 + for d in dim: + axes |= 1 << (d - (1 if has_implicit_batch_dimension else 0)) + + return axes + + +def create_constant(network, tensor, name): + if isinstance(tensor, int): + tensor = torch.IntTensor([tensor]) + + if isinstance(tensor, float): + tensor = torch.Tensor([tensor]) + + shape = tuple(tensor.shape) + + # Remove all preceding 1s as they can be re-inserted later during broadcasting. 
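+ # For example, a constant of shape (1, 1, 64) is registered as a (64,) constant
+ # here; broadcast()/append_ones() below re-insert the leading 1s whenever it is
+ # combined with a higher-rank tensor.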
+ num_preceding_ones = 0 + for j in range(len(shape)): + if int(shape[j]) == 1: + num_preceding_ones += 1 + else: + break + + # If shape is all 1s, we want last digit. + shape = shape[num_preceding_ones:] if num_preceding_ones < len(shape) else (1,) + constant = network.add_constant(shape, to_numpy(tensor)) + constant.name = name + return constant.get_output(0) + + +def get_trt_tensor(network, input_val, name): + if isinstance(input_val, (torch.Tensor, int, float)): + return create_constant(network, input_val, name) + elif not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"Received input {input_val} of name {name} that " + "is not part of the TensorRT region!" + ) + else: + return input_val + + +def append_ones(network, input, name, num_prepend_ones): + layer = network.add_shuffle(input) + + if has_dynamic_shape(input.shape): + input_shape_layer = network.add_shape(input) + input_shape_layer.name = f"{name}_broadcast_orig_shape" + prepend_shape_layer = network.add_constant( + (num_prepend_ones,), np.ones((num_prepend_ones,), dtype=np.int32) + ) + prepend_shape_layer.name = f"{name}_broadcast_prepend_ones" + reshape_dim_layer = network.add_concatenation( + [prepend_shape_layer.get_output(0), input_shape_layer.get_output(0)] + ) + reshape_dim_layer.axis = 0 + reshape_dim_layer.name = f"{name}_broadcast_final_shape" + layer.set_input(1, reshape_dim_layer.get_output(0)) + else: + layer.reshape_dims = (1,) * num_prepend_ones + tuple(input.shape) + + layer.name = name + return layer.get_output(0) + + +def broadcast(network, a, b, a_name, b_name, preset_diff=0): + a_shape = tuple(a.shape) + b_shape = tuple(b.shape) + + diff = len(a_shape) - len(b_shape) - preset_diff + if diff > 0: + b = append_ones(network, b, f"{b_name}_broadcast", diff) + elif diff < 0: + a = append_ones(network, a, f"{a_name}_broadcast", -diff) + + return a, b + + +def add_binary_elementwise_layer(network, lhs_val, rhs_val, op_type, name): + lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs") + rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs") + lhs_val, rhs_val = broadcast( + network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" + ) + layer = network.add_elementwise(lhs_val, rhs_val, op_type) + layer.name = name + return layer.get_output(0) + + +def add_unary_layer(network, input_val, operation_type, name): + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"{operation_type} received input {input_val} that is not part " + "of the TensorRT region!" + ) + layer = network.add_unary(input_val, operation_type) + layer.name = name + return layer.get_output(0) + + +def add_activation_layer(network, input_val, operation_type, name): + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"{operation_type} received input {input_val} that is not part " + "of the TensorRT region!" + ) + layer = network.add_activation(input_val, operation_type) + layer.name = name + return layer.get_output(0) + + +def add_transpose_layer( + network, input_val, dim_0, dim_1, name, ignore_implicit_batch=False +): + """Adds a transpose layer to the TensorRT network + Args: + network: TensorRT Network object + input_val: tensorrt.ITensor + dim_0, dim_1: dimensions for transpose, e.g. 
dim_0=1, dim_1=0 means transpose + the first two dimensions + name: Name of the layer + ignore_implicit_batch: activations might have implicit batch, but weights do + not, when this is True, we'll ignore the implicit batch and use the dimension + argument as is + Returns: + output TensorRT ITensor from the transpose layer + """ + if not ignore_implicit_batch and network.has_implicit_batch_dimension: + assert ( + dim_0 != 0 and dim_1 != 0 + ), "It's not allowed to call transpose on non-constant when batch dim is implicit!" + dim_0 -= 1 + dim_1 -= 1 + + permutation = list(range(len(input_val.shape))) + permutation[dim_0] = dim_1 + permutation[dim_1] = dim_0 + + layer = network.add_shuffle(input_val) + layer.second_transpose = tuple(permutation) + layer.name = name + return layer.get_output(0) + + +def process_attr(val, num_elem): + if not isinstance(val, tuple): + val = (val,) * num_elem + return val + + +@tensorrt_converter(acc_ops.conv2d) +def acc_ops_conv2d(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"Conv2d received input {input_val} that is not part " + "of the TensorRT region!" + ) + + if has_dynamic_shape(input_val.shape): + assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution." + + kernel = to_numpy(kwargs["weight"]) + bias = to_numpy(kwargs["bias"]) + + layer = network.add_convolution( + input=input_val, + num_output_maps=kernel.shape[0], + kernel_shape=kernel.shape[2:], + kernel=kernel, + bias=bias, + ) + + layer.name = name + layer.stride = kwargs["stride"] + layer.padding = kwargs["padding"] + layer.dilation = kwargs["dilation"] + if kwargs["groups"] is not None: + layer.num_groups = kwargs["groups"] + + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.flatten) +def acc_ops_flatten(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"flatten received input {input_val} that is not part " + "of the TensorRT region!" + ) + + num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) + start_dim = (kwargs["start_dim"] if "start_dim" in kwargs else 0) % num_dims + end_dim = (kwargs["end_dim"] if "end_dim" in kwargs else -1) % num_dims + + if network.has_implicit_batch_dimension: + assert start_dim != 0, "Can't flatten batch dimension when it's implicit." + start_dim -= 1 + end_dim -= 1 + + layer = network.add_shuffle(input_val) + layer.name = name + + # If there're dynamic shapes then we need to use shape layers + # to figure out the final shape after flatten. We first slice + # the input shape to three parts: + # 1. dimensions before start_dim + # 2. dimensions between start_dim and end_dim + # 3. dimensions after end_dim + # Part 1 and 3 might not exist if start_dim is 0 or end_dim is + # last dim. Then we do a reduced multiplication over part 2 to + # get flattened dim. Finally, we concatenate the three parts to + # get the final shape. 
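+ # For example, flattening an (N, C, H, W) input with start_dim=1 and end_dim=3
+ # gives part 1 = (N,), part 2 reduced by PROD to (C*H*W,), and an empty part 3,
+ # so the final shape is (N, C*H*W).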
+ if has_dynamic_shape(input_val.shape): + input_shape_layer = network.add_shape(input_val) + input_shape_layer.name = f"{name}_orig_shape" + + final_shapes = [] + + # Shapes before start_dim + if start_dim > 0: + prefix_shape_layer = network.add_slice( + input_shape_layer.get_output(0), + start=(0,), + shape=(start_dim,), + stride=(1,), + ) + prefix_shape_layer.name = f"{name}_pre_shape" + final_shapes.append(prefix_shape_layer.get_output(0)) + + flatten_shape_layer = network.add_slice( + input_shape_layer.get_output(0), + start=(start_dim,), + shape=(end_dim - start_dim + 1,), + stride=(1,), + ) + flatten_shape_layer.name = f"{name}_need_flatten" + flatten_shape_layer = network.add_reduce( + flatten_shape_layer.get_output(0), + trt.ReduceOperation.PROD, + axes=get_axes_for_reduce_op(0, False), + keep_dims=True, + ) + flatten_shape_layer.name = f"{name}_flatten_dim" + final_shapes.append(flatten_shape_layer.get_output(0)) + + # Shapes after start_dim + if end_dim < len(input_val.shape) - 1: + suffix_shape_layer = network.add_slice( + input_shape_layer.get_output(0), + start=(end_dim + 1,), + shape=(len(input_val.shape) - end_dim - 1,), + stride=(1,), + ) + suffix_shape_layer.name = f"{name}_suffix_shape" + final_shapes.append(suffix_shape_layer.get_output(0)) + + final_shape_layer = network.add_concatenation(final_shapes) + final_shape_layer.axis = 0 + final_shape_layer.name = f"{name}_final_shape" + layer.set_input(1, final_shape_layer.get_output(0)) + else: + final_shape = [] + flatten_dim = 1 + for i, s in enumerate(input_val.shape): + if i >= start_dim and i <= end_dim: + flatten_dim *= s + elif i == end_dim + 1: + final_shape.append(flatten_dim) + final_shape.append(s) + else: + final_shape.append(s) + if end_dim == len(input_val.shape) - 1: + final_shape.append(flatten_dim) + + layer.reshape_dims = tuple(final_shape) + + return layer.get_output(0) + + +# For implicit batch dim mode, we use this to represent batch dim if we +# ever trying to retrieve it via size() and we hope it will fail hard if +# it's used somewhere else. +IMPLICIT_BATCH_DIM = -999 + + +@tensorrt_converter(acc_ops.size) +def acc_ops_size(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"size received input {input_val} that is not part " + "of the TensorRT region!" + ) + + if not has_dynamic_shape(input_val.shape): + if network.has_implicit_batch_dimension: + return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_val.shape)) + return torch.Size(input_val.shape) + + layer = network.add_shape(input_val) + layer.name = name + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.batch_norm) +def acc_ops_batch_norm(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"BatchNorm2d received input {input_val} that is not part " + "of the TensorRT region!" + ) + + if has_dynamic_shape(input_val.shape): + assert input_val.shape[1] != -1, "Channel dim can't be dynamic for batch norm." 
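+ # Batch norm is folded into a per-channel TensorRT scale layer computing
+ # (x * scale + shift) ** power with power fixed to 1, where
+ # scale = weight / sqrt(running_var + eps) and shift = bias - running_mean * scale.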
+ + scale = to_numpy(kwargs["weight"]) / np.sqrt( + to_numpy(kwargs["running_var"]) + kwargs["eps"] + ) + bias = ( + to_numpy(kwargs["bias"]) + - to_numpy(kwargs["running_mean"]) * scale + ) + power = np.ones_like(scale) + + layer = network.add_scale(input_val, trt.ScaleMode.CHANNEL, bias, scale, power) + layer.name = name + + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.softmax) +def acc_ops_softmax(network, target, args, kwargs, name): + input_val = kwargs["input"] + dim = kwargs["dim"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"softmax received input {input_val} that is not part " + "of the TensorRT region!" + ) + + # Used to get dim when dim is None. Copied from PyTorch softmax implementation. + def get_softmax_dim(ndim): + if ndim == 0 or ndim == 1 or ndim == 3: + ret = 0 + else: + ret = 1 + return ret + + if dim is None: + dim = get_softmax_dim( + len(input_val.shape) + if not network.has_implicit_batch_dimension + else len(input_val.shape) + 1 + ) + + if network.has_implicit_batch_dimension: + assert dim != 0, "Can't apply softmax on batch dimension when it's implicit." + dim = (dim % (len(input_val.shape) + 1)) - 1 + + layer = network.add_softmax(input_val) + layer.axes = 1 << dim + layer.name = name + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.relu) +def acc_ops_relu(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.ActivationType.RELU + return add_activation_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.sin) +def acc_ops_sin(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.SIN + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.cos) +def acc_ops_cos(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.COS + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.tan) +def acc_ops_tan(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.TAN + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.sinh) +def acc_ops_sinh(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.SINH + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.cosh) +def acc_ops_cosh(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.COSH + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.tanh) +def acc_ops_tanh(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.ActivationType.TANH + return add_activation_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.asin) +def acc_ops_asin(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.ASIN + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.acos) +def acc_ops_acos(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.ACOS + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.atan) +def acc_ops_atan(network, target, args, kwargs, name): + input_val = kwargs["input"] + 
operation_type = trt.UnaryOperation.ATAN + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.exp) +def acc_ops_exp(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.EXP + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.log) +def acc_ops_log(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.LOG + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.sqrt) +def acc_ops_sqrt(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.SQRT + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.reciprocal) +def acc_ops_reciprocal(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.RECIP + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.abs) +def acc_ops_abs(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.ABS + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.neg) +def acc_ops_neg(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.NEG + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.floor) +def acc_ops_floor(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.FLOOR + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.ceil) +def acc_ops_ceil(network, target, args, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.CEIL + return add_unary_layer(network, input_val, operation_type, name) + + +@tensorrt_converter(acc_ops.sum) +def acc_ops_sum(network, target, args, kwargs, name): + input_val = kwargs["input"] + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"sum received input {input_val} that is not part " + "of the TensorRT region!" + ) + + # If dim is specified, then we are computing reduced sum over certain dimensions. + # Otherwise, we are dong summation over all elements, which is only supported in + # explicit batch dimension. + if "dim" not in kwargs: + assert ( + not network.has_implicit_batch_dimension + ), "Do not support sum all the elements for implicit batch." + dim = range(0, len(input_val.shape)) + else: + dim = kwargs["dim"] + + keepdim = False if "keepdim" not in kwargs else kwargs["keepdim"] + layer = network.add_reduce( + input_val, + trt.ReduceOperation.SUM, + get_axes_for_reduce_op(dim, network.has_implicit_batch_dimension), + keepdim, + ) + layer.name = name + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.max_pool2d) +def acc_ops_max_pool2d(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"MaxPool2d received input {input_val} that is not part " + "of the TensorRT region!" 
+ ) + + kernel_size = process_attr(kwargs["kernel_size"], 2) + stride = process_attr(kwargs["stride"], 2) + padding = process_attr(kwargs["padding"], 2) + dilation = process_attr(kwargs["dilation"], 2) + ceil_mode = kwargs["ceil_mode"] + + if dilation != (1, 1): + raise RuntimeError( + f"Only support dilation=(1, 1) for maxpool, but got {dilation}" + ) + + layer = network.add_pooling( + input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size + ) + layer.stride = stride + layer.padding = padding + layer.name = name + + if ceil_mode: + layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP + + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.squeeze) +def acc_ops_squeeze(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"squeeze received input {input_val} that is not part " + "of the TensorRT region!" + ) + + dim = kwargs["dim"] if "dim" in kwargs else None + # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic + # dim, which is a very rare case. For now we just claim not supporting dim=None. + assert dim is not None, "We don't support dim=None right now." + + if network.has_implicit_batch_dimension: + assert dim != 0, "We don't support squeeze batch dim when it's implicit." + dim -= 1 + + assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim." + assert ( + len(get_dynamic_dims(input_val.shape)) <= 1 + ), "Currently more than one dynamic dim for input to squeeze is not supported." + + output_shape = [] + for i, s in enumerate(input_val.shape): + if i == dim and s == 1: + continue + output_shape.append(s) + layer = network.add_shuffle(input_val) + layer.reshape_dims = tuple(output_shape) + layer.name = name + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.add) +def acc_ops_add(network, target, args, kwargs, name): + return add_binary_elementwise_layer( + network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.SUM, name + ) + + +@tensorrt_converter(acc_ops.sub) +def acc_ops_sub(network, target, args, kwargs, name): + return add_binary_elementwise_layer( + network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.SUB, name + ) + + +@tensorrt_converter(acc_ops.div) +def acc_ops_div(network, target, args, kwargs, name): + return add_binary_elementwise_layer( + network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.DIV, name + ) + + +@tensorrt_converter(acc_ops.mul) +def acc_ops_mul(network, target, args, kwargs, name): + return add_binary_elementwise_layer( + network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.PROD, name + ) + + +@tensorrt_converter(acc_ops.min_two_tensors_input) +def acc_ops_min_two_tensors_input(network, target, args, kwargs, name): + return add_binary_elementwise_layer( + network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.MIN, name + ) + + +@tensorrt_converter(acc_ops.adaptive_avg_pool2d) +def acc_ops_adaptive_avg_pool2d(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"AdaptiveAvgPool2d received input {input_val} that is not part " + "of the TensorRT region!" + ) + + assert ( + input_val.shape[-1] != -1 and input_val.shape[-1] != -1 + ), "AdaptiveAvgPool2d currently doesn't support dynamic shapes for last two dims." 
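+ # Adaptive average pooling is emulated with a regular average pooling layer,
+ # e.g. a 6x6 input with output_size=(2, 2) uses stride=(3, 3) and
+ # kernel_size=(3, 3); this is why each input dim must be an integer multiple
+ # of the corresponding output dim.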
+ output_size = kwargs["output_size"] + + for input_dim, output_dim in zip(input_val.shape[-2:], output_size): + if input_dim % output_dim != 0: + raise RuntimeError( + "For AdaptiveAvgPool, input dim has to be integer multiple of output dim." + f"Got input dim {input_dim}, output dim {output_dim}" + ) + + stride = ( + input_val.shape[-2] // output_size[0], + input_val.shape[-1] // output_size[1], + ) + kernel_size = ( + input_val.shape[-2] - (output_size[0] - 1) * stride[0], + input_val.shape[-1] - (output_size[1] - 1) * stride[1], + ) + layer = network.add_pooling( + input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size + ) + layer.stride = stride + layer.name = name + + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.avg_pool2d) +def acc_ops_avg_pool2d(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"AvgPool2d received input {input_val} that is not part " + "of the TensorRT region!" + ) + + kernel_size = process_attr(kwargs["kernel_size"], 2) + stride = process_attr(kwargs["stride"], 2) + padding = process_attr(kwargs["padding"], 2) + ceil_mode = kwargs["ceil_mode"] + count_include_pad = kwargs["count_include_pad"] + divisor_override = kwargs["divisor_override"] + + if divisor_override: + raise RuntimeError("TensorRT does not support divisor_override.") + + layer = network.add_pooling( + input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size + ) + layer.stride = stride + layer.padding = padding + layer.average_count_excludes_padding = False if count_include_pad else True + + if ceil_mode: + layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP + + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.reshape) +def acc_ops_reshape(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"Reshape received input {input_val} that is not part " + "of the TensorRT region!" + ) + + shape = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "shape") + if network.has_implicit_batch_dimension: + shape = shape[1:] + + layer = network.add_shuffle(input_val) + + if all(isinstance(s, int) for s in shape): + layer.reshape_dims = tuple(shape) + else: + # Convert all the dimensions to trt Tensors. + trt_shape = [] + + for i, s in enumerate(shape): + if isinstance(s, trt.tensorrt.ITensor): + if len(s.shape) == 0: + s = append_ones(network, s, f"{name}_{i}", 1) + trt_shape.append(s) + else: + trt_shape.append(get_trt_tensor(network, s, f"{name}_{i}")) + + shape_layer = network.add_concatenation(inputs=trt_shape) + shape_layer.axis = 0 + shape_layer.name = f"{name}_output_shape" + layer.set_input(1, shape_layer.get_output(0)) + + layer.name = name + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.linear) +def acc_ops_linear(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"Linear received input {input_val} that is not part " + "of the TensorRT region!" + ) + + dynamic_dims = get_dynamic_dims(input_val.shape) + assert len(dynamic_dims) < 2 and input_val.shape[-1] != -1, ( + "Currently we only support one dynmaic " + "dim for linear and it can't be the last dim." 
+ ) + + layer = network.add_shuffle(input_val) + layer.reshape_dims = tuple(input_val.shape) + (1, 1) + layer.name = f"{name}_pre_shuffle" + + # add fully connected + layer = network.add_fully_connected( + input=layer.get_output(0), + num_outputs=kwargs["weight"].shape[0], + kernel=to_numpy(kwargs["weight"]), + bias=to_numpy(kwargs["bias"]), + ) + layer.name = f"{name}_linear" + + # reshape back + layer = network.add_shuffle(layer.get_output(0)) + layer.reshape_dims = tuple(input_val.shape[:-1]) + (kwargs["weight"].shape[0],) + layer.name = f"{name}_post_shuffle" + + return layer.get_output(0) + + +def add_clamp(network, input, val, op): + acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions + acc_ops_clamp_tensor = ( + val + * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)) + .cpu() + .numpy() + ) + acc_ops_clamp_trt = network.add_constant(acc_ops_clamp_shape, acc_ops_clamp_tensor) + layer = network.add_elementwise(input, acc_ops_clamp_trt.get_output(0), op) + + return layer + + +@tensorrt_converter(acc_ops.clamp) +def acc_ops_clamp(network, target, args, kwargs, name): + input_val = kwargs["input"] + min_val = kwargs["min"] + max_val = kwargs["max"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"Clamp received input {input_val} that is not part " + "of the TensorRT region!" + ) + + if min_val is not None: + clamp_min_layer = add_clamp( + network, input_val, min_val, trt.ElementWiseOperation.MAX + ) + clamp_min_layer.name = f"{name}_clamp_min" + input_val = clamp_min_layer.get_output(0) + if max_val is not None: + clamp_max_layer = add_clamp( + network, input_val, max_val, trt.ElementWiseOperation.MIN + ) + clamp_max_layer.name = f"{name}_clamp_max" + input_val = clamp_max_layer.get_output(0) + + return input_val + + +@tensorrt_converter(acc_ops.getitem) +def acc_ops_getitem(network, target, args, kwargs, name): + input_val = kwargs["input"] + slices = kwargs["idx"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + return operator.getitem(input_val, slices) + + assert not has_dynamic_shape( + input_val.shape + ), "Currently we don't support slicing tensor if it has dynamic shape." + + def num_slice_types(slices): + """ + Gather the number of slice in getitem slices. + """ + num_slice = 0 + for s in slices: + if isinstance(s, slice) or isinstance(s, int): + num_slice += 1 + return num_slice + + def slice_to_trt_params(py_slice, dim_size): + """ + Convert python slice to TensorRT slice layer parameters. + """ + start = (py_slice.start % dim_size) if py_slice.start else 0 + stride = py_slice.step if py_slice.step else 1 + stop = (py_slice.stop % dim_size) if py_slice.stop else dim_size + size = math.ceil((stop - start) * 1.0 / stride) + return start, size, stride + + if not isinstance(slices, tuple): + slices = (slices,) + + if network.has_implicit_batch_dimension: + # Raise an error if it's trying to subscript batch dimension unless it's + # slice(None, None, None). + batch_subscript = slices[0] + if batch_subscript != slice(None, None, None): + raise RuntimeError( + f"Can't subscript batch dimension when it's implicit. Got {slices}" + ) + + # Remove batch_dim subscript + slices = slices[1:] + + # Replace ellipsis with expanded slices. + # Compute the number of dim ellipsis represent. 
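+ # For example, with explicit batch on a 4-d tensor, x[1, ..., 2] expands to the
+ # equivalent of x[1, :, :, 2], i.e. the Ellipsis is replaced by two
+ # slice(None, None, None) entries.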
+ num_ellipsis = len(input_val.shape) - num_slice_types(slices) + new_slices = [] + for s in slices: + if s == Ellipsis: + while num_ellipsis > 0: + new_slices.append(slice(None, None, None)) + num_ellipsis -= 1 + else: + new_slices.append(s) + slices = new_slices + + # Build trt slice layer params + start = [] + size = [] + stride = [] + + i = 0 + for s in slices: + if s is None: + continue + + if isinstance(s, slice): + params = slice_to_trt_params(s, input_val.shape[i]) + start.append(params[0]) + size.append(params[1]) + stride.append(params[2]) + else: + start.append(s % input_val.shape[i]) + size.append(1) + stride.append(1) + i += 1 + + while i < len(input_val.shape): + start.append(0) + size.append(input_val.shape[i]) + stride.append(1) + i += 1 + + layer = network.add_slice( + input=input_val, + start=start, + shape=size, + stride=stride, + ) + layer.name = name + + # Add shuffle layer to insert dimensions for 'None' and remove dimensions for 'int'. + if any(not isinstance(s, slice) for s in slices): + slice_out = layer.get_output(0) + layer = network.add_shuffle(slice_out) + final_shape = [] + original_idx = 0 + for s in slices: + # If it's a slice, keep the dim. + if isinstance(s, slice): + final_shape.append(slice_out.shape[original_idx]) + original_idx += 1 + # If it's None, extend the dim. + elif s is None: + final_shape.append(1) + # If it's a int, remove the dim. + else: + original_idx += 1 + layer.reshape_dims = tuple(final_shape) + tuple(slice_out.shape)[original_idx:] + + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.cat) +def acc_ops_cat(network, target, args, kwargs, name): + tensors = kwargs["tensors"] + + if any(not isinstance(t, trt.tensorrt.ITensor) for t in tensors): + raise RuntimeError( + f"cat received inputs {tensors} that is not part " "of the TensorRT region!" + ) + + layer = network.add_concatenation(inputs=tensors) + layer.axis = kwargs["dim"] - (1 if network.has_implicit_batch_dimension else 0) + layer.name = name + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.transpose) +def acc_ops_transpose(network, target, args, kwargs, name): + input_val, dim_0, dim_1 = kwargs["input"], kwargs["dim0"], kwargs["dim1"] + + # TODO: Remove this after enabling const folding in fx_acc + if isinstance(input_val, torch.Tensor): + return input_val.transpose(dim_0, dim_1).contiguous() + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"transpose received input {input_val} that is not part " + "of the TensorRT region!" + ) + + return add_transpose_layer(network, input_val, dim_0, dim_1, name) + + +@tensorrt_converter(acc_ops.matmul) +def acc_ops_matmul(network, target, args, kwargs, name): + input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input") + other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other") + + for i in [input_val, other_val]: + if not isinstance(i, trt.tensorrt.ITensor): + raise RuntimeError( + f"matmul received input {i} that is not part " "of the TensorRT region!" 
+ ) + + input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE + preset_diff = 0 + + if len(input_val.shape) == 1: + preset_diff -= 1 + input_matrix_op = trt.MatrixOperation.VECTOR + + if len(other_val.shape) == 1: + preset_diff += 1 + other_matrix_op = trt.MatrixOperation.VECTOR + + input_val, other_val = broadcast( + network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff + ) + layer = network.add_matrix_multiply( + input_val, input_matrix_op, other_val, other_matrix_op + ) + layer.name = name + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.sigmoid) +def acc_ops_sigmoid(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"Sigmoid received input {input_val} that is not part " + "of the TensorRT region!" + ) + + layer = network.add_activation(input=input_val, type=trt.ActivationType.SIGMOID) + layer.name = name + return layer.get_output(0) + + +@tensorrt_converter(acc_ops.permute) +def acc_ops_permute(network, target, args, kwargs, name): + input_val = kwargs["input"] + permutation = kwargs["permutation"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"permute received input {input_val} that is not part " + "of the TensorRT region!" + ) + + if network.has_implicit_batch_dimension: + assert permutation[0] == 0, "Can't permute batch dimension when it's implicit." + permutation = [i - 1 for i in permutation[1:]] + + layer = network.add_shuffle(input_val) + layer.second_transpose = tuple(permutation) + layer.name = name + return layer.get_output(0) diff --git a/torch/fx/experimental/fx2trt/converters/add.py b/torch/fx/experimental/fx2trt/converters/add.py index 22ad9dc..f8a201b 100644 --- a/torch/fx/experimental/fx2trt/converters/add.py +++ b/torch/fx/experimental/fx2trt/converters/add.py @@ -8,14 +8,17 @@ from .helper_functions import get_dyn_range, mark_as_int8_layer @tensorrt_converter(operator.add) @tensorrt_converter(torch.add) def add(network, target, args, kwargs, layer_name): - if len(kwargs) != 0: - raise RuntimeError(f"Add receives unsupported kwargs: {kwargs}!") - - assert len(args) == 2 - if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in args): + # operator.add + if len(kwargs) == 0: + lhs_val, rhs_val = args + else: + # torch.add + lhs_val, rhs_val = kwargs["input"], kwargs["other"] + assert kwargs["alpha"] == 1 + + if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in [lhs_val, rhs_val]): raise RuntimeError("add() received an input that is not part of the TensorRT region!") - lhs_val, rhs_val = args layer = network.add_elementwise(lhs_val, rhs_val, trt.ElementWiseOperation.SUM) layer.name = layer_name diff --git a/torch/fx/experimental/fx2trt/converters/mul.py b/torch/fx/experimental/fx2trt/converters/mul.py index 169bcf2..bb1838a 100644 --- a/torch/fx/experimental/fx2trt/converters/mul.py +++ b/torch/fx/experimental/fx2trt/converters/mul.py @@ -8,11 +8,16 @@ from .helper_functions import get_dyn_range, mark_as_int8_layer @tensorrt_converter(torch.mul) @tensorrt_converter(operator.mul) def mul(network, target, args, kwargs, layer_name): - assert len(args) == 2 - if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in args): + # operator.mul + if len(kwargs) == 0: + lhs_val, rhs_val = args + else: + # torch.mul + lhs_val, rhs_val = kwargs["input"], kwargs["other"] + + if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in [lhs_val, rhs_val]): raise RuntimeError('mul() received an input 
that is not part of the TensorRT region!') - lhs_val, rhs_val = args layer = network.add_elementwise(lhs_val, rhs_val, trt.ElementWiseOperation.PROD) layer.name = layer_name diff --git a/torch/fx/experimental/fx2trt/fx2trt.py b/torch/fx/experimental/fx2trt/fx2trt.py index 38059cd..6054586 100644 --- a/torch/fx/experimental/fx2trt/fx2trt.py +++ b/torch/fx/experimental/fx2trt/fx2trt.py @@ -1,16 +1,14 @@ -import copy import warnings from typing import List, NamedTuple, Iterable, Any, Optional, Tuple +import tensorrt as trt import torch import torch.fx -import tensorrt as trt -from torch.fx.experimental.normalize import NormalizeArgs # Borrowed from torch2trt def torch_dtype_to_trt(dtype): - if trt.__version__ >= '7.0' and dtype == torch.bool: + if trt.__version__ >= "7.0" and dtype == torch.bool: return trt.bool elif dtype == torch.int8: return trt.int8 @@ -27,7 +25,7 @@ def torch_dtype_to_trt(dtype): def torch_dtype_from_trt(dtype): if dtype == trt.int8: return torch.int8 - elif trt.__version__ >= '7.0' and dtype == trt.bool: + elif trt.__version__ >= "7.0" and dtype == trt.bool: return torch.bool elif dtype == trt.int32: return torch.int32 @@ -40,7 +38,9 @@ def torch_dtype_from_trt(dtype): class TRTModule(torch.nn.Module): - def __init__(self, engine=None, input_names=None, output_names=None, fp16_output=False): + def __init__( + self, engine=None, input_names=None, output_names=None, fp16_output=False + ): super(TRTModule, self).__init__() self._register_state_dict_hook(TRTModule._on_state_dict) self.engine = engine @@ -78,7 +78,9 @@ class TRTModule(torch.nn.Module): self.output_names = state_dict[prefix + "output_names"] def forward(self, *inputs): - assert len(inputs) == len(self.input_names), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}." + assert len(inputs) == len( + self.input_names + ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}." batch_size = inputs[0].shape[0] contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs] bindings: List[Any] = [None] * (len(self.input_names) + len(self.output_names)) @@ -119,8 +121,10 @@ class TRTModule(torch.nn.Module): return tuple(outputs) def enable_profiling(self): - raise RuntimeError("Profiling is not supported right now because it requires calling" - " execute() instead of execute_async().") + raise RuntimeError( + "Profiling is not supported right now because it requires calling" + " execute() instead of execute_async()." + ) if not self.context.profiler: self.context.profiler = trt.Profiler() @@ -132,6 +136,7 @@ def tensorrt_converter(key): def register_converter(converter): CONVERTERS[key] = converter return converter + return register_converter @@ -157,11 +162,12 @@ class InputTensorSpec(NamedTuple): has_batch_dim: Whether the shape includes batch dimension. Batch dimension has to be provided if the engine want to run with dynamic shape. 
""" - shape : torch.Size - dtype : torch.dtype - device : torch.device = torch.device("cpu") - shape_ranges : List[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]] = [] - has_batch_dim : bool = True + + shape: torch.Size + dtype: torch.dtype + device: torch.device = torch.device("cpu") + shape_ranges: List[Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]] = [] + has_batch_dim: bool = True @classmethod def from_tensor(cls, tensor: torch.Tensor): @@ -182,13 +188,27 @@ def get_dynamic_dims(shape): return dynamic_dims -class BaseTRTInterpreter(torch.fx.Interpreter): +def create_inputs_from_specs(input_specs): + inputs = [] + + for shape, dtype, device, shape_ranges, has_batch_dim in input_specs: + if len(get_dynamic_dims(shape)): + shape = shape_ranges[0][1] + elif not has_batch_dim: + shape = (1,) + tuple(shape) + + inputs.append(torch.empty(shape, dtype=dtype, device=device)) + + return inputs + + +class TRTInterpreter(torch.fx.Interpreter): def __init__( self, - module : torch.fx.GraphModule, - input_specs : List[InputTensorSpec], - explicit_batch_dimension : bool = False, - logger_level=trt.Logger.WARNING + module: torch.fx.GraphModule, + input_specs: List[InputTensorSpec], + explicit_batch_dimension: bool = False, + logger_level=trt.Logger.WARNING, ): super().__init__(module) @@ -196,12 +216,14 @@ class BaseTRTInterpreter(torch.fx.Interpreter): self.builder = trt.Builder(self.logger) if explicit_batch_dimension: - EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + EXPLICIT_BATCH = 1 << (int)( + trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH + ) self.network = self.builder.create_network(EXPLICIT_BATCH) else: self.network = self.builder.create_network() - self.optimization_profiles : Optional[List] = None + self.optimization_profiles: Optional[List] = None self.input_specs = input_specs self.input_specs_iter = 0 self.validate_input_specs() @@ -212,35 +234,59 @@ class BaseTRTInterpreter(torch.fx.Interpreter): def validate_input_specs(self): for shape, dtpe, _, shape_ranges, has_batch_dim in self.input_specs: if not self.network.has_implicit_batch_dimension: - assert has_batch_dim, "It's required to specify batch dimension when it's explicit in TensorRT network." + assert ( + has_batch_dim + ), "It's required to specify batch dimension when it's explicit in TensorRT network." dynamic_dims = get_dynamic_dims(shape) if len(dynamic_dims): - assert not self.network.has_implicit_batch_dimension, "Can't have dynamic dim when " \ + assert not self.network.has_implicit_batch_dimension, ( + "Can't have dynamic dim when " f"batch dim is implicit, got {shape}." - assert len(shape_ranges), "shape_ranges must be provided when shape has dynamic dim." + ) + assert len( + shape_ranges + ), "shape_ranges must be provided when shape has dynamic dim." if self.optimization_profiles: - assert len(shape_ranges) == len(self.optimization_profiles), "Number of optimization " \ - f"profiles {len(self.optimization_profiles)} doesn't match with the number of shape_range" \ + assert len(shape_ranges) == len(self.optimization_profiles), ( + "Number of optimization " + f"profiles {len(self.optimization_profiles)} doesn't match with the number of shape_range" f" {len(shape_ranges)} provided." 
+ ) else: - self.optimization_profiles = [self.builder.create_optimization_profile() for _ in range(len(shape_ranges))] + self.optimization_profiles = [ + self.builder.create_optimization_profile() + for _ in range(len(shape_ranges)) + ] for shape_range in shape_ranges: - assert len(shape_range) == 3, f"Expect three elements in shape_range, got {len(shape_range)}" - assert all(len(s) == len(shape) for s in shape_range), "Expect elements in shape_range" \ + assert ( + len(shape_range) == 3 + ), f"Expect three elements in shape_range, got {len(shape_range)}" + assert all(len(s) == len(shape) for s in shape_range), ( + "Expect elements in shape_range" f" {shape_range} have the same number of dimension as the provided shape {len(shape)}" + ) for i in range(len(shape)): if i in dynamic_dims: - assert all(shape_range[j][i] <= shape_range[j + 1][i] for j in range(2)), "Expect dynamic dim" \ + assert all( + shape_range[j][i] <= shape_range[j + 1][i] + for j in range(2) + ), ( + "Expect dynamic dim" f" {i} to have incremental value for shapes in shape_range {shape_range}." + ) else: - assert all(s[i] == shape[i] for s in shape_range), f"Expect non dynamic dim {i} to be the same" \ + assert all(s[i] == shape[i] for s in shape_range), ( + f"Expect non dynamic dim {i} to be the same" f" for all shapes in shape_range {shape_range}." + ) else: - assert len(shape_ranges) == 0, "shape_ranges are provided for input that doesn't have dynamic dim." + assert ( + len(shape_ranges) == 0 + ), "shape_ranges are provided for input that doesn't have dynamic dim." def run( self, @@ -248,7 +294,7 @@ class BaseTRTInterpreter(torch.fx.Interpreter): max_workspace_size=1 << 25, fp16_mode=True, int8_mode=False, - strict_type_constraints=True + strict_type_constraints=True, ): # TODO hack, should check contents of args and remove fp16_mode probably self.fp16_mode = fp16_mode @@ -279,7 +325,7 @@ class BaseTRTInterpreter(torch.fx.Interpreter): builder_config.add_optimization_profile(optimization_profile) engine = self.builder.build_engine(self.network, builder_config) - assert(engine) + assert engine return engine, self._input_names, self._output_names def run_node(self, n): @@ -288,7 +334,9 @@ class BaseTRTInterpreter(torch.fx.Interpreter): def placeholder(self, target, args, kwargs): self._input_names.append(target) - shape, dtype, _, shape_ranges, has_batch_dim = self.input_specs[self.input_specs_iter] + shape, dtype, _, shape_ranges, has_batch_dim = self.input_specs[ + self.input_specs_iter + ] self.input_specs_iter += 1 if self.network.has_implicit_batch_dimension: @@ -299,7 +347,9 @@ class BaseTRTInterpreter(torch.fx.Interpreter): assert self.optimization_profiles self.optimization_profiles[i].set_shape(target, *shape_range) - return self.network.add_input(name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)) + return self.network.add_input( + name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype) + ) def call_module(self, target, args, kwargs): assert isinstance(target, str) @@ -307,7 +357,9 @@ class BaseTRTInterpreter(torch.fx.Interpreter): converter = CONVERTERS.get(type(submod)) if not converter: - raise RuntimeError(f'Conversion of module of type {type(submod)} not currently supported!') + raise RuntimeError( + f"Conversion of module of type {type(submod)} not currently supported!" 
+ ) return converter(self.network, submod, args, kwargs, self._cur_node_name) @@ -315,7 +367,9 @@ class BaseTRTInterpreter(torch.fx.Interpreter): converter = CONVERTERS.get(target) if not converter: - raise RuntimeError(f'Conversion of function {torch.typename(target)} not currently supported!') + raise RuntimeError( + f"Conversion of function {torch.typename(target)} not currently supported!" + ) return converter(self.network, target, args, kwargs, self._cur_node_name) @@ -324,7 +378,9 @@ class BaseTRTInterpreter(torch.fx.Interpreter): converter = CONVERTERS.get(target) if not converter: - raise RuntimeError(f'Conversion of method {target} not currently supported!') + raise RuntimeError( + f"Conversion of method {target} not currently supported!" + ) return converter(self.network, target, args, kwargs, self._cur_node_name) @@ -333,10 +389,10 @@ class BaseTRTInterpreter(torch.fx.Interpreter): outputs = args[0] if isinstance(args[0], tuple) else (args[0],) if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs): - raise RuntimeError('TensorRT requires all outputs to be Tensor!') + raise RuntimeError("TensorRT requires all outputs to be Tensor!") for i, output in enumerate(outputs): - name = f'output{i}' + name = f"output{i}" output.name = name self.network.mark_output(output) if self.fp16_mode: @@ -344,18 +400,3 @@ class BaseTRTInterpreter(torch.fx.Interpreter): else: output.dtype = trt.float32 self._output_names.append(name) - - -class TRTInterpreter(BaseTRTInterpreter): - """ - Use this for general case where there're PyTorch vanilla ops in the FX mdoule. - """ - def __init__(self, module : torch.nn.Module, input_specs : List[InputTensorSpec], logger_level=trt.Logger.WARNING): - # Preprocess the model - if not isinstance(module, torch.fx.GraphModule): - module = torch.fx.symbolic_trace(module) - else: - module = copy.deepcopy(module) - module = module.cpu().float() - module = NormalizeArgs(module).transform() - super().__init__(module, input_specs, logger_level=logger_level) diff --git a/torch/fx/experimental/fx_acc/__init__.py b/torch/fx/experimental/fx_acc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torch/fx/experimental/fx_acc/acc_normalizer.py b/torch/fx/experimental/fx_acc/acc_normalizer.py new file mode 100644 index 0000000..a5d116a --- /dev/null +++ b/torch/fx/experimental/fx_acc/acc_normalizer.py @@ -0,0 +1,401 @@ +# type: ignore[] +import inspect +import re +from typing import NamedTuple, Optional, Callable, Dict, List, Tuple, Union, Any, Set + +import torch.fx.experimental.fx_acc.acc_utils as acc_utils +import torch +import torch.fx +from torch.fx.node import _get_qualified_name + +# Need to keep up-to-date with https://fburl.com/codesearch/7r2hhh53 +ALIAS_MAP = { + "input": ("input", "x", "a", "x1"), + "dim": ("dim", "axis"), + "keepdim": ("keepdim", "keepdims"), + "other": ("other", "x2"), +} + +# Type used for arg replacement tuples. The list represents the argument signature of +# some callable. Each item in the list is a tuple, where for each member of a tuple: +# - The first member is union of either: +# - A tuple of all potential alias kwarg str names of the source signature, or +# - A tuple of a single str representing the single kwarg name allowed. +# - The second member is the str name of the kwarg to map it to. This is either from the +# signature of the acc_op, or for custom mapped nodes from the original unnormalized op. +# - The third member is a bool representing whether this arg is optional, i.e. 
whether it +# is allowed to not be present in the original input args. +ArgReplacementTuplesType = List[Tuple[Tuple[str, ...], str, bool]] + + +class NormalizationInfo(NamedTuple): + """ + Holds normalization info for some FX node, where the FX node will be mapped either + via new_fn_target and arg_replacement_tuples, or via custom_mapping_fn. + + If via new_fn_target and arg_replacement_tuples: + - new_fn_target is the target function to replace the original node with + (generally some function from acc_ops). + + - arg_replacement_tuples describes how to map the original FX node's args/kwargs to + the new FX node. If set to None, then the kwargs are copied directly from the + original FX node. Else, this is list of three-member tuples, where each tuple + represents a mapping from either an arg or kwarg in the original FX node to the + kwarg it should be mapped to. If for ops registered with `register_acc_op` then + this is a mapping to the the new FX node for the acc_op. Otherwise it is for some + op registered with `register_custom_acc_mapper_fn`, in which case this is a + mapping for the original input node so its args are normalized to kwargs before + being custom normalized to acc_ops. The third member of the tuple is a bool + representing whether this argument is optional; if False and the arg is not + present then an assertion will be thrown. The index of the tuple indicates where + the original arg is in node.args and the string name indicates which original + kwarg it is. + + If via custom_mapping_fn, then custom_mapping_fn is some function that takes the + original FX node as input and returns the FX node that should replace it. This means + it was registered via `register_custom_acc_mapper_fn`. + """ + + new_fn_target: Callable + arg_replacement_tuples: Optional[ArgReplacementTuplesType] + custom_mapping_fn: Optional[Callable] + kwargs_to_move_to_acc_out_ty: Optional[Optional[List[Tuple[str, str]]]] + needs_shapes_for_normalization: bool + + +# Dict from (op, target) to NormalizationInfo for that op. +_normalization_dict: Dict[Tuple[str, Union[str, Callable]], NormalizationInfo] = {} + +# Set of all the acc ops. +_acc_ops: Set[Callable] = set() + + +def _insert_fun( + op_and_target: Tuple[str, Union[str, Callable]], + arg_replacement_tuples: List[Tuple], + new_fn_target: Optional[Callable] = None, + custom_mapping_fn: Optional[Callable] = None, + kwargs_to_move_to_acc_out_ty: Optional[Optional[List[Tuple[str, str]]]] = None, + needs_shapes_for_normalization=False, + allow_normalize_from_torch_package=False, +): + if op_and_target[0] == "call_function": + assert callable(op_and_target[1]) + elif op_and_target[0] == "call_method": + assert isinstance(op_and_target[1], str) + elif op_and_target[0] == "call_module": + assert isinstance(op_and_target[1], type) + + # Finalize arg replacement tuples. + # 1. Check to see if they have the `is_optional` bool, and if not defaulting it to + # False. + # 2. Some kwargs might have aliases. e.g. "a", "x" and "x1" are aliases of "input". + # Here we replace `orig_kwarg` with a tuple of all aliases if it has aliases. + final_arg_replacement_tuples = [] + for arg_replacement_tuple in arg_replacement_tuples: + if len(arg_replacement_tuple) == 2: + orig_kwarg, new_kwarg, is_optional = *arg_replacement_tuple, False + else: + assert len(arg_replacement_tuple) == 3 + orig_kwarg, new_kwarg, is_optional = arg_replacement_tuple + + if not isinstance(orig_kwarg, tuple): + orig_kwarg = (orig_kwarg,) + + # Use set to avoid duplicates. 
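+ # For example, an original kwarg name of "input" expands via ALIAS_MAP to
+ # {"input", "x", "a", "x1"}, so any of those aliases in the source op's kwargs
+ # is matched and normalized to the acc_op's "input" kwarg.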
+ orig_kwarg_set = set(orig_kwarg) + + for k in orig_kwarg: + if k in ALIAS_MAP: + orig_kwarg_set.update(ALIAS_MAP[k]) + final_arg_replacement_tuples.append( + (tuple(orig_kwarg_set), new_kwarg, is_optional) + ) + + assert op_and_target not in _normalization_dict.keys() + norm_info = NormalizationInfo( + new_fn_target=new_fn_target, + arg_replacement_tuples=final_arg_replacement_tuples, + custom_mapping_fn=custom_mapping_fn, + kwargs_to_move_to_acc_out_ty=kwargs_to_move_to_acc_out_ty, + needs_shapes_for_normalization=needs_shapes_for_normalization, + ) + _normalization_dict[op_and_target] = norm_info + + # If allow_normalize_from_torch_package then add another entry to + # _normalization_dict where we look for the qualified name of the target with the + # torch_package module prefix. Note that we leave off any integer at the end of + # "" in order to allow for whatever mangling index is used. + if allow_normalize_from_torch_package: + torch_package_op_and_target = ( + op_and_target[0], + f".{_get_qualified_name(op_and_target[1])}", + ) + _normalization_dict[torch_package_op_and_target] = norm_info + + +def _get_dup_signature_tuples(fn: Callable) -> List[Tuple[str, str]]: + """ + Helper that inspects the arg signature of `fn` and returns a list of tuples, where + each tuple is a pair of duplicated names which is used for arg_replacement_tuples. + """ + sig_tuples: List[Tuple[str, str]] = [] + for param in inspect.signature(inspect.unwrap(fn)).parameters: + sig_tuples.append((param, param)) + return sig_tuples + + +def register_acc_op(acc_op: Callable): + """ + For a new acc op, add this as decorator to register it. + """ + _acc_ops.add(acc_op) + return acc_op + + +def register_acc_op_mapping( + op_and_target: Tuple[str, Union[str, Callable]], + arg_replacement_tuples: Optional[ + List[Tuple[Union[str, Tuple[str, ...]], str]] + ] = None, + kwargs_to_move_to_acc_out_ty: Optional[List[Tuple[str, str]]] = None, +): + """ + Use this decorator to map a non-acc operator to an acc operator. + + Args: + op_and_target: A tuple that contains op and target of the node that represents the non-acc operator. + arg_replacement_tuples: Please refer to the comment on above for `ArgReplacementTuplesType`. + kwargs_to_move_to_acc_out_ty: The kwargs we want to move out from the non-acc op kwargs to acc_out_ty. + """ + + def insert(new_fn_target: Callable): + # If arg_replacement_tuples is None then assume we use the same signature for + # the acc_op and the original op. 
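+ # For example, mapping ("call_function", torch.flatten) to the acc_ops flatten
+ # with no explicit tuples reuses flatten's own signature, yielding
+ # [("input", "input"), ("start_dim", "start_dim"), ("end_dim", "end_dim")].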
+ if arg_replacement_tuples is None: + final_arg_replacement_tuples = _get_dup_signature_tuples(new_fn_target) + else: + final_arg_replacement_tuples = arg_replacement_tuples + + _insert_fun( + op_and_target=op_and_target, + new_fn_target=new_fn_target, + arg_replacement_tuples=final_arg_replacement_tuples, + kwargs_to_move_to_acc_out_ty=kwargs_to_move_to_acc_out_ty, + ) + return new_fn_target + + return insert + + +def register_custom_acc_mapper_fn( + op_and_target: Tuple[str, Union[str, Callable]], + arg_replacement_tuples: List[Tuple[Union[str, Tuple[str, ...]], str]], + needs_shapes_for_normalization=False, + allow_normalize_from_torch_package=False, +): + def insert(custom_mapping_fn: Callable): + _insert_fun( + op_and_target=op_and_target, + custom_mapping_fn=custom_mapping_fn, + arg_replacement_tuples=arg_replacement_tuples, + needs_shapes_for_normalization=needs_shapes_for_normalization, + allow_normalize_from_torch_package=allow_normalize_from_torch_package, + ) + return custom_mapping_fn + + return insert + + +def move_kwargs_to_acc_out_ty( + node_or_normalization_info: Union[NormalizationInfo, torch.fx.Node], + new_kwargs: Dict[str, Any], +): + """ + Given `node_or_normalization_info` which is either NormalizationInfo for a node, or + a node to fetch NormalizationInfo for, check if kwargs_to_move_to_acc_out_ty exists + in the NormalizationInfo, and if so perform the move of kwargs to acc_out_ty. + """ + if isinstance(node_or_normalization_info, torch.fx.Node): + node = node_or_normalization_info + normalization_info = _normalization_dict.get((node.op, node.target)) + else: + normalization_info = node_or_normalization_info + + if normalization_info.kwargs_to_move_to_acc_out_ty is None: + return + + assert acc_utils.is_acc_op_with_kwarg( + normalization_info.new_fn_target, "acc_out_ty" + ) + + # Build a dict representing the new TensorMetadata to use for acc_out_ty, + # and then remove the kwarg from the new_kwargs since it's passed in via + # acc_out_ty instead. + tmd_dict: Dict[str, Any] = {} + for ( + orig_kwarg_name, + tmd_field_name, + ) in normalization_info.kwargs_to_move_to_acc_out_ty: + tmd_dict[tmd_field_name] = new_kwargs[orig_kwarg_name] + del new_kwargs[orig_kwarg_name] + # Note: allow_partial_spec here because we are only using the tensor metadata tuple + # here to pass specific values into the function. For example, for quantization we + # only need to provide dtype/q_scale/q_zero_point, but is_quantized and qscheme are + # not passed in. + new_kwargs["acc_out_ty"] = acc_utils.build_raw_tensor_meta(**tmd_dict) + + +def get_normalized_kwargs( + node: torch.fx.Node, arg_replacement_tuples: ArgReplacementTuplesType +): + new_kwargs = {} + final_arg_is_varg = False + for i, replacement_tuple in enumerate(arg_replacement_tuples): + orig_kwargs_names, new_kwarg_name, is_optional = replacement_tuple + + # Check if this is a varg and if so break/process the rest outside the loop. + if len(orig_kwargs_names) == 1 and orig_kwargs_names[0] == "*": + assert i == len(arg_replacement_tuples) - 1 + final_arg_is_varg = True + break + + # If nothing is found in node.kwargs it means the kwarg is in node.arg + # or it's optional. In this case, we set orig_kwargs_name to None. + assert isinstance(orig_kwargs_names, tuple) + orig_kwargs_name = next( + (key for key in orig_kwargs_names if key in node.kwargs), + None, + ) + + # If can't find in node.kwargs then it should be in the i index + # of node.args. 
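+ # For example, for an op registered with [("input", "input"), ("other", "other")],
+ # a call traced as op(x, y) carries both operands positionally in node.args,
+ # so they are picked up by index here rather than from node.kwargs.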
+ if orig_kwargs_name is None: + if i < len(node.args): + new_kwargs[new_kwarg_name] = node.args[i] + else: + # Verify the arg we're trying to normalize was optional. + assert is_optional + else: + new_kwargs[new_kwarg_name] = node.kwargs[orig_kwargs_name] + + # If using var args then process the rest of the args now. + if final_arg_is_varg: + var_arg_idx = len(arg_replacement_tuples) - 1 + new_kwarg_name = arg_replacement_tuples[var_arg_idx][1] + rest_of_args = [] + for i in range(var_arg_idx, len(node.args)): + rest_of_args.append(node.args[i]) + new_kwargs[new_kwarg_name] = rest_of_args + + return new_kwargs + + +def normalize(mod: torch.fx.GraphModule, expect_nodes_have_shapes: bool = False): + assert len(_normalization_dict) > 0 + graph = mod.graph + + # For "call_module" node we return _base_class_origin if it's a + # RewrittenModule, otherwise, return its type. For other nodes, + # we return node.target. + def get_target(mod: torch.fx.GraphModule, node: torch.fx.Node): + if node.op != "call_module": + return node.target + + # Find the module that node.target points to + m = dict(mod.named_modules())[node.target] + return getattr(m, "_base_class_origin", type(m)) + + def normalize_to_acc_op( + node: torch.fx.Node, + normalization_info: NormalizationInfo, + normalized_args: Tuple[Any, ...], + normalized_kwargs: Dict[str, Any], + ): + # If there's a custom mapping function then use it. + if normalization_info.custom_mapping_fn is not None: + # For custom mapping, the normalized_kwargs are used for the original op, + # i.e. *before* custom acc_ops normalization. Do that now. + node.args = normalized_args + node.kwargs = normalized_kwargs + new_node = normalization_info.custom_mapping_fn(node, mod) + # If a new node is returned then use it to replace the old node. Otherwise + # the custom mapping function did its own replacement, so return early. + if new_node is None: + return + else: + # If there's kwargs_to_move_to_acc_out_ty then use it to setup acc_out_ty in + # normalized_kwargs, and remove the kwarg from normalized_kwargs. + move_kwargs_to_acc_out_ty(normalization_info, normalized_kwargs) + + # All acc ops are functions. Create a call to the correct acc_ops target using + # the normalized kwargs provided. + with graph.inserting_before(node): + new_node = graph.create_node( + "call_function", + normalization_info.new_fn_target, + args=normalized_args, + kwargs=normalized_kwargs, + name=node.name, + ) + new_node.meta = node.meta.copy() + + # Finally replace the original node with the normalized node. + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + for node in graph.nodes: + if node.op in {"placeholder", "get_attr", "output"}: + continue + + normalization_info = _normalization_dict.get((node.op, get_target(mod, node))) + + # Also check if the torch_packaged version of the op was specified to be normalized. + if normalization_info is None and node.op == "call_function": + # Strip off the mangle_index suffix here before checking the map. + target = re.sub( + r"\A", + "", + _get_qualified_name(node.target), + ) + torch_package_op_and_target = (node.op, target) + normalization_info = _normalization_dict.get(torch_package_op_and_target) + + if normalization_info is None: + continue + + # Get the normalized kwargs to be used by normalize_to_acc_op below. If + # normalization_info.arg_replacement_tuples is empty then assume the function + # signature must be left as is. 
+ if len(normalization_info.arg_replacement_tuples) == 0: + normalized_args = node.args + normalized_kwargs = node.kwargs + else: + normalized_args = () + normalized_kwargs = get_normalized_kwargs( + node, normalization_info.arg_replacement_tuples + ) + + if ( + normalization_info.needs_shapes_for_normalization + and not expect_nodes_have_shapes + ): + # All nodes needing shapes for normalization should be custom mapped. + assert normalization_info.custom_mapping_fn is not None + # For custom mapping, the normalized_kwargs are used for the original op, + # i.e. *before* custom acc_ops normalization. Do that now so that whoever + # consumes the graph next (e.g. shape inference) can use kwargs safely. + node.args = normalized_args + node.kwargs = normalized_kwargs + continue + + try: + normalize_to_acc_op( + node, normalization_info, normalized_args, normalized_kwargs + ) + except Exception: + print(f"Error during normalization for node: {node.format_node()}") + raise + + # If there are any dead nodes left after normalization, eliminate them now. + mod.graph.eliminate_dead_code() diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py new file mode 100644 index 0000000..bc4dfb3 --- /dev/null +++ b/torch/fx/experimental/fx_acc/acc_ops.py @@ -0,0 +1,1176 @@ +# encoding: utf-8 +# type: ignore[] +import operator + +import torch # isort:skip +import torch.fx.experimental.fx_acc.acc_utils as acc_utils +import torch.nn as nn +from torch.fx.experimental.fx_acc.acc_normalizer import ( + register_acc_op, + register_acc_op_mapping, + register_custom_acc_mapper_fn, +) +from torch.fx.passes.shape_prop import extract_tensor_metadata + +this_arg_is_optional = True + + +@register_acc_op_mapping(op_and_target=("call_function", nn.functional.linear)) +@register_acc_op +def linear(*, input, weight, bias): + return nn.functional.linear(**locals()) + + +@register_acc_op +def quantized_linear(*, input, weight, bias, acc_out_ty=None): + assert acc_out_ty is not None + return nn.quantized.functional.linear( + input, + weight, + bias, + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"), + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"), + ) + + +@register_acc_op_mapping( + op_and_target=("call_method", "flatten"), + arg_replacement_tuples=[ + ("input", "input"), + ("start_dim", "start_dim", this_arg_is_optional), + ("end_dim", "end_dim", this_arg_is_optional), + ], +) +@register_acc_op_mapping(op_and_target=("call_function", torch.flatten)) +@register_acc_op +def flatten(*, input, start_dim=0, end_dim=-1): + return torch.flatten(**locals()) + + +@register_acc_op_mapping( + op_and_target=( + "call_method", + "squeeze", + ), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim", this_arg_is_optional), + ], +) +@register_acc_op +def squeeze(*, input, dim=None): + if dim is None: + return input.squeeze() + return input.squeeze(dim=dim) + + +@register_acc_op_mapping(op_and_target=("call_function", nn.functional.max_pool2d)) +@register_acc_op +def max_pool2d( + *, input, kernel_size, stride, padding, dilation, ceil_mode, return_indices +): + return nn.functional.max_pool2d(**locals()) + + +@register_acc_op_mapping( + op_and_target=("call_function", nn.functional.adaptive_avg_pool2d) +) +@register_acc_op +def adaptive_avg_pool2d(*, input, output_size): + return nn.functional.adaptive_avg_pool2d(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", nn.functional.avg_pool2d)) +@register_acc_op +def avg_pool2d( + *, + input, + 
kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, +): + return nn.functional.avg_pool2d(**locals()) + + +@register_acc_op +def size(*, input): + return input.size() + + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", getattr), + arg_replacement_tuples=[], +) +def custom_getattr_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: + """ + Custom function for mapping a call_function getattr to other ops. Currently only + supports loading a getattr called on a torch.Tensor with attr name "shape", which is + supported by mapping it to acc_ops.size(). + """ + # Have to use args here since getattr forces positional args. + input_obj = node.args[0] + attr_name = node.args[1] + assert ( + input_obj.meta["type"] == torch.Tensor + ), f"Expected torch.Tensor type for {input_obj.meta['type']}" + assert ( + attr_name == "shape" + ), f"Only supporting shape getattr for now, not {attr_name}" + with node.graph.inserting_before(node): + size_node = node.graph.call_function(size, kwargs={"input": input_obj}) + size_node.meta = node.meta.copy() + return size_node + + +@register_custom_acc_mapper_fn( + op_and_target=("call_method", "size"), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim", this_arg_is_optional), + ], +) +def tensor_size_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: + """ + Mapping from Tensor.size() to acc_ops.size. We map size() to acc_ops.size directly + and map size(dim) to acc_ops.size + acc_ops.getitem. + """ + + with node.graph.inserting_before(node): + size_node = node.graph.call_function( + size, kwargs={"input": node.kwargs["input"]} + ) + + if "dim" not in node.kwargs: + size_node.meta = node.meta.copy() + return size_node + + size_node.meta["type"] = torch.Size + getitem_node = node.graph.call_function( + getitem, kwargs={"input": size_node, "idx": node.kwargs["dim"]} + ) + getitem_node.meta = node.meta.copy() + return getitem_node + + +@register_acc_op_mapping(op_and_target=("call_function", operator.add)) +@register_acc_op_mapping(op_and_target=("call_method", "add")) +@register_acc_op +def add(*, input, other): + return input + other + + +@register_acc_op_mapping(op_and_target=("call_function", torch.unsqueeze)) +@register_acc_op +def unsqueeze(*, input, dim): + return torch.unsqueeze(**locals()) + + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.stack), + arg_replacement_tuples=[ + ("tensors", "tensors"), + ("dim", "dim"), + ], +) +def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: + """ + Map torch.stack to unsqueeze + cat. 
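+
+    For example, torch.stack([a, b], dim=0) is rewritten (roughly) into the
+    equivalent of:
+
+        cat(tensors=[unsqueeze(input=a, dim=0), unsqueeze(input=b, dim=0)], dim=0)
+
+    (Illustrative sketch; the actual nodes and names are created below.)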
+ """ + with node.graph.inserting_before(node): + inputs = node.kwargs["tensors"] + unsqueeze_nodes = [] + for i, t in enumerate(inputs): + new_node = node.graph.create_node( + "call_function", + unsqueeze, + kwargs={"input": t, "dim": node.kwargs["dim"]}, + name=f"{node.name}_unsqueeze_{i}", + ) + new_node.meta["type"] = torch.Tensor + unsqueeze_nodes.append(new_node) + cat_node = node.graph.create_node( + "call_function", + cat, + kwargs={"tensors": unsqueeze_nodes, "dim": node.kwargs["dim"]}, + ) + cat_node.meta = node.meta.copy() + return cat_node + + +@register_acc_op_mapping(op_and_target=("call_function", torch.clamp)) +@register_acc_op_mapping(op_and_target=("call_method", "clamp")) +@register_acc_op +def clamp(*, input, min, max): + return torch.clamp(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.cat)) +@register_acc_op +def cat(*, tensors, dim): + return torch.cat(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.transpose)) +@register_acc_op_mapping(op_and_target=("call_method", "transpose")) +@register_acc_op +def transpose(*, input, dim0, dim1): + if input.dim() < 2: + return input + return torch.transpose(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.softmax)) +@register_acc_op +def softmax(*, input, dim, dtype): + """ + _stacklevel are ignored here. + """ + return torch.nn.functional.softmax(**locals()) + + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.addmm), + arg_replacement_tuples=[ + ("input", "input"), + ("mat1", "mat1"), + ("mat2", "mat2"), + ("beta", "beta"), + ("alpha", "alpha"), + ], +) +def addmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: + """ + Mapping from torch.addmm to acc_ops.mm -> acc_ops.add, if alpha or beta is not 1 + then we also insert acc_ops.mul to the right place. 
+ """ + with node.graph.inserting_before(node): + mm_kwargs = {"input": node.kwargs["mat1"], "other": node.kwargs["mat2"]} + mm_node = node.graph.create_node( + "call_function", matmul, kwargs=mm_kwargs, name=f"{node.name}_mm" + ) + mm_node.meta = node.meta.copy() + + if node.kwargs["alpha"] != 1: + mul_kwargs = {"input": mm_node, "other": node.kwargs["alpha"]} + mm_node = node.graph.create_node( + "call_function", mul, kwargs=mul_kwargs, name=f"{mm_node.name}_mul" + ) + mm_node.meta = node.meta.copy() + + input_node = node.kwargs["input"] + if node.kwargs["beta"] != 1: + mul_kwargs = {"input": input_node, "other": node.kwargs["beta"]} + new_input_node = node.graph.create_node( + "call_function", mul, kwargs=mul_kwargs, name=f"{node.name}_input_mul" + ) + new_input_node.meta = input_node.meta.copy() + input_node = new_input_node + + add_kwargs = {"input": mm_node, "other": input_node} + add_node = node.graph.create_node( + "call_function", add, kwargs=add_kwargs, name=f"{node.name}_add" + ) + add_node.meta = node.meta.copy() + return add_node + + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.t), + arg_replacement_tuples=[ + ("input", "input"), + ], +) +@register_custom_acc_mapper_fn( + op_and_target=("call_method", "t"), + arg_replacement_tuples=[ + ("input", "input"), + ], +) +def t_mapper(node: torch.fx.Node, _: nn.Module): + with node.graph.inserting_before(node): + new_node = node.graph.create_node( + "call_function", + transpose, + kwargs={"input": node.kwargs["input"], "dim0": 0, "dim1": 1}, + name=node.name, + ) + new_node.meta = node.meta.copy() + return new_node + + +@register_acc_op_mapping( + op_and_target=("call_method", "permute"), + arg_replacement_tuples=[ + ("input", "input"), + ("*", "permutation"), + ], +) +@register_acc_op +def permute(*, input, permutation): + return input.permute(*permutation) + + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.square), + arg_replacement_tuples=[ + ("input", "input"), + ], +) +def square_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: + input_node = node.kwargs["input"] + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + mul, kwargs={"input": input_node, "other": input_node} + ) + new_node.meta = node.meta.copy() + return new_node + + +@register_acc_op_mapping( + op_and_target=("call_function", torch.bmm), + arg_replacement_tuples=[ + ("input", "input"), + ("mat2", "other"), + ], +) +@register_acc_op_mapping(op_and_target=("call_function", torch.matmul)) +@register_acc_op +def matmul(*, input, other): + return torch.matmul(**locals()) + + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.min), + arg_replacement_tuples=[ + ("input", "input"), + ("other", "other", this_arg_is_optional), + ("dim", "dim", this_arg_is_optional), + ("keepdim", "keepdim", this_arg_is_optional), + ], +) +def custom_torch_min_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node: + """ + Add custom mapping for torch.min because torch.min has three input types, where each yields a different output type: + 1. torch.min(input); Output: tensor wih the minimum number + 2. 
torch.min(input, other); Output: tensor with coordinate-wise min value + 3[Not Supported] torch.min(input, dim, keepdim); Output:(min_values,and min_indices) + """ + with node.graph.inserting_before(node): + # If dim is in kwargs, assert "Not Supported" + assert "dim" not in node.kwargs, "Currently not support dim in torch.min" + + if "other" in node.kwargs and node.kwargs["other"] is not None: + # If kwargs[other] is a valid tensor, call min_two_tensors_input, + op_func = min_two_tensors_input + else: + # Otherwise, call min_single_tensor_input + op_func = min_single_tensor_input + + new_node = node.graph.create_node( + "call_function", + op_func, + kwargs=node.kwargs, + name=node.name, + ) + new_node.meta = node.meta + return new_node + + +@register_acc_op +def min_single_tensor_input(*, input): + return torch.min(input) + + +@register_acc_op +def min_two_tensors_input(*, input, other): + return torch.min(input, other) + + +@register_acc_op_mapping( + op_and_target=("call_function", torch.ops.quantized.add), + arg_replacement_tuples=[ + ("qa", "input"), + ("qb", "other"), + ("scale", "scale"), + ("zero_point", "zero_point"), + ], + kwargs_to_move_to_acc_out_ty=[ + ("scale", "q_scale"), + ("zero_point", "q_zero_point"), + ], +) +@register_acc_op +def quantized_add(*, input, other, acc_out_ty=None): + assert acc_out_ty is not None + return torch.ops.quantized.add( + input, + other, + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"), + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"), + ) + + +@register_acc_op_mapping( + op_and_target=("call_function", torch.ops.quantized.mul), + arg_replacement_tuples=[ + ("qa", "input"), + ("qb", "other"), + ("scale", "scale"), + ("zero_point", "zero_point"), + ], + kwargs_to_move_to_acc_out_ty=[ + ("scale", "q_scale"), + ("zero_point", "q_zero_point"), + ], +) +@register_acc_op +def quantized_mul(*, input, other, acc_out_ty=None): + assert acc_out_ty is not None + return torch.ops.quantized.mul( + input, + other, + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"), + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"), + ) + + +@register_acc_op_mapping( + op_and_target=("call_function", torch.quantize_per_tensor), + arg_replacement_tuples=[ + ("input", "input"), + ("scale", "scale"), + ("zero_point", "zero_point"), + ("dtype", "dtype"), + ], + kwargs_to_move_to_acc_out_ty=[ + ("dtype", "dtype"), + ("scale", "q_scale"), + ("zero_point", "q_zero_point"), + ], +) +@register_acc_op +def quantize_per_tensor(*, input, acc_out_ty=None): + assert acc_out_ty is not None + return torch.quantize_per_tensor( + input, + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"), + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"), + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype"), + ) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.dequantize)) +@register_acc_op_mapping(op_and_target=("call_method", "dequantize")) +@register_acc_op +def dequantize(*, input): + return torch.dequantize(input) + + +@register_acc_op_mapping(op_and_target=("call_function", operator.sub)) +@register_acc_op +def sub(*, input, other): + return input - other + + +@register_acc_op_mapping(op_and_target=("call_function", operator.mul)) +@register_acc_op +def mul(*, input, other): + return input * other + + +@register_acc_op_mapping(op_and_target=("call_function", operator.truediv)) +@register_acc_op +def div(*, input, other): + return input / other + + 
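+# For example (illustrative), after acc_normalizer.normalize runs, a traced node
+# such as
+#
+#     call_function  operator.truediv  args=(x, y)
+#
+# is replaced with a keyword-only call to the acc_ops.div op defined above:
+#
+#     call_function  acc_ops.div  kwargs={"input": x, "other": y}
+
+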
+@register_acc_op_mapping(op_and_target=("call_function", nn.functional.relu)) +@register_acc_op_mapping( + op_and_target=("call_function", torch.relu), + arg_replacement_tuples=[("input", "input")], +) +@register_acc_op_mapping( + op_and_target=("call_method", "relu"), + arg_replacement_tuples=[("input", "input")], +) +@register_acc_op +def relu(*, input, inplace=False): + return nn.functional.relu(**locals()) + + +@register_custom_acc_mapper_fn( + op_and_target=("call_method", "sum"), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim", this_arg_is_optional), + ("keepdim", "keepdim", this_arg_is_optional), + ("dtype", "dtype", this_arg_is_optional), + ], +) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.sum), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim", this_arg_is_optional), + ("keepdim", "keepdim", this_arg_is_optional), + ("dtype", "dtype", this_arg_is_optional), + ], +) +def add_sum_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -> torch.fx.Node: + with node.graph.inserting_before(node): + sum_kwargs = dict(node.kwargs) + if "dim" in sum_kwargs and isinstance(sum_kwargs["dim"], int): + sum_kwargs["dim"] = (sum_kwargs["dim"],) + sum_node = node.graph.call_function(sum, kwargs=sum_kwargs) + sum_node.meta = node.meta.copy() + return sum_node + + +@register_acc_op +def sum(*, input, dim=None, keepdim=False, dtype=None): + if dim: + return torch.sum(**locals()) + else: + return input.sum(dtype=dtype) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.sigmoid)) +@register_acc_op_mapping(op_and_target=("call_method", "sigmoid")) +@register_acc_op +def sigmoid(*, input): + return torch.sigmoid(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.sinh)) +@register_acc_op +def sinh(*, input): + return torch.sinh(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.cosh)) +@register_acc_op +def cosh(*, input): + return torch.cosh(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.tanh)) +@register_acc_op +def tanh(*, input): + return torch.tanh(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.asin)) +@register_acc_op +def asin(*, input): + return torch.asin(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.acos)) +@register_acc_op +def acos(*, input): + return torch.acos(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.atan)) +@register_acc_op +def atan(*, input): + return torch.atan(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.exp)) +@register_acc_op +def exp(*, input): + return torch.exp(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.log)) +@register_acc_op +def log(*, input): + return torch.log(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.sqrt)) +@register_acc_op +def sqrt(*, input): + return torch.sqrt(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.reciprocal)) +@register_acc_op +def reciprocal(*, input): + return torch.reciprocal(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.abs)) +@register_acc_op +def abs(*, input): + return torch.abs(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.neg)) +@register_acc_op +def neg(*, input): + return torch.neg(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.floor)) 
+@register_acc_op +def floor(*, input): + return torch.floor(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.ceil)) +@register_acc_op +def ceil(*, input): + return torch.ceil(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.conv2d)) +@register_acc_op +def conv2d(*, input, weight, bias, stride, padding, dilation, groups): + return nn.functional.conv2d(**locals()) + + +@register_acc_op +def quantized_conv2d( + *, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + padding_mode, + acc_out_ty=None, +): + assert acc_out_ty is not None + return torch.nn.quantized.functional.conv2d( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + padding_mode, + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"), + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"), + ) + + +@register_acc_op_mapping(op_and_target=("call_function", nn.functional.batch_norm)) +@register_acc_op +def batch_norm( + *, input, running_mean, running_var, weight, bias, training, momentum, eps +): + return nn.functional.batch_norm(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", nn.functional.layer_norm)) +@register_acc_op +def layer_norm(*, input, normalized_shape, weight, bias, eps): + return nn.functional.layer_norm(**locals()) + + +@register_custom_acc_mapper_fn( + op_and_target=("call_method", "split"), + arg_replacement_tuples=[ + ("tensor", "input"), + ("split_size_or_sections", "split_size_or_sections"), + ("dim", "dim"), + ], +) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.split), + arg_replacement_tuples=[ + ("tensor", "input"), + ("split_size_or_sections", "split_size_or_sections"), + ("dim", "dim"), + ], +) +def torch_split_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node: + """ + If split_size_or_sections is sections, map the node to slice_tensors + + tuple_construct. Otherwise, if split_size_or_sections is split_size, + map the node to acc_ops.split. 
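+
+    For example (sketch), x.split(2, dim=1) maps to acc_ops.split(input=x,
+    split_size=2, dim=1), while x.split([2, 3], dim=1) maps to two slice_tensor
+    nodes (slicing 0:2 and 2:5 along dim 1) followed by a tuple_construct node.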
+ """ + split_size_or_sections = node.kwargs["split_size_or_sections"] + with node.graph.inserting_before(node): + if isinstance(split_size_or_sections, int): + new_kwargs = { + "input": node.kwargs["input"], + "split_size": split_size_or_sections, + "dim": node.kwargs["dim"], + } + new_node = node.graph.call_function(split, kwargs=new_kwargs) + new_node.meta = node.meta.copy() + return new_node + + start = 0 + slice_nodes = [] + for i in split_size_or_sections: + new_kwargs = { + "input": node.kwargs["input"], + "dims": (node.kwargs["dim"],), + "starts": (start,), + "stops": (start + i,), + "steps": (1,), + } + new_node = node.graph.call_function(slice_tensor, kwargs=new_kwargs) + new_node.meta["type"] = torch.Tensor + slice_nodes.append(new_node) + start += i + + new_node = node.graph.call_function( + tuple_construct, kwargs={"tensors": tuple(slice_nodes)} + ) + new_node.meta = node.meta.copy() + return new_node + + +@register_acc_op +def split(*, input, split_size, dim): + return torch.split(input, split_size, dim) + + +@register_acc_op +def tuple_construct(*, tensors): + return tuple(tensors) + + +@register_acc_op_mapping( + op_and_target=("call_function", torch.ops.quantized.batch_norm2d), + arg_replacement_tuples=[ + ("input", "input"), + ("weight", "weight"), + ("bias", "bias"), + ("running_mean", "running_mean"), + ("running_var", "running_var"), + ("eps", "eps"), + ("scale", "scale"), + ("zero_point", "zero_point"), + ], + kwargs_to_move_to_acc_out_ty=[ + ("scale", "q_scale"), + ("zero_point", "q_zero_point"), + ], +) +@register_acc_op +def quantized_batch_norm2d( + *, input, running_mean, running_var, weight, bias, eps, acc_out_ty +): + return torch.ops.quantized.batch_norm2d( + input, + weight, + bias, + running_mean, + running_var, + eps, + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"), + acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"), + ) + + +@register_acc_op_mapping(op_and_target=("call_function", nn.functional.embedding_bag)) +@register_acc_op +def embedding_bag( + *, + input, + weight, + offsets, + max_norm, + norm_type, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + return nn.functional.embedding_bag(**locals()) + + +@register_acc_op_mapping( + op_and_target=( + "call_function", + torch.ops.quantized.embedding_bag_byte_rowwise_offsets, + ) +) +@register_acc_op +def embedding_bag_byte_rowwise_offsets( + *, + weight, + input, + offsets, + scale_grad_by_freq, + mode, + pruned_weights, + per_sample_weights, + compressed_indices_mapping, + include_last_offset, +): + return torch.ops.quantized.embedding_bag_byte_rowwise_offsets(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.sin)) +@register_acc_op +def sin(*, input): + return torch.sin(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.cos)) +@register_acc_op +def cos(*, input): + return torch.cos(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.tan)) +@register_acc_op +def tan(*, input): + return torch.tan(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.topk)) +@register_acc_op +def topk(*, input, k, dim, largest, sorted): + return torch.topk(**locals()) + + +@register_acc_op_mapping(op_and_target=("call_function", operator.getitem)) +@register_acc_op +def getitem(*, input, idx): + return input[idx] + + +@register_acc_op +def slice_tensor(*, input, dims, starts, stops, steps): + slices = [None for _ in 
range(input.dim())] + + # For all provided dims, extract out a slice for starts/stops/steps. + for idx, dim in enumerate(dims): + slices[dim] = slice(starts[idx], stops[idx], steps[idx]) + + # For all unspecified dims, default to the full slice. + for idx, s in enumerate(slices): + if s is None: + slices[idx] = slice(None, None, None) + + return input[slices] + + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.narrow), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim"), + ("start", "start"), + ("length", "length"), + ], +) +def custom_narrow_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node: + kwargs = { + "input": node.kwargs["input"], + "dims": (node.kwargs["dim"],), + "starts": (node.kwargs["start"],), + "stops": (node.kwargs["start"] + node.kwargs["length"],), + "steps": (1,), + } + with node.graph.inserting_before(node): + new_node = node.graph.call_function(slice_tensor, kwargs=kwargs) + new_node.meta = node.meta.copy() + return new_node + + +@register_acc_op_mapping( + op_and_target=("call_function", torch.reshape), + arg_replacement_tuples=[ + ("input", "input"), + ("shape", "shape"), + ], + kwargs_to_move_to_acc_out_ty=[("shape", "shape")], +) +@register_acc_op_mapping( + op_and_target=("call_method", "view"), + arg_replacement_tuples=[ + ("input", "input"), + ("*", "shape"), + ], + kwargs_to_move_to_acc_out_ty=[("shape", "shape")], +) +@register_acc_op +def reshape(*, input, acc_out_ty=None): + assert acc_out_ty is not None + return torch.reshape( + input, tuple(acc_utils.get_field_from_acc_out_ty(acc_out_ty, "shape")) + ) + + +@register_custom_acc_mapper_fn( + op_and_target=("call_method", "reshape"), + arg_replacement_tuples=[ + ("input", "input"), + ("*", "shape"), + ], +) +def custom_tensor_reshape_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: + """ + For Tensor.reshape node, args could be (input, 1, 2, 3) or (input, (1, 2, 3)). + Here we do some special handling with the `shape` arg in order to map it to + acc_ops.reshape. It also handles the case when `shape` is a list instead of + tuple. 
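+
+    For example, both t.reshape(2, 3) and t.reshape((2, 3)) normalize to an
+    acc_ops.reshape call on t whose acc_out_ty carries shape (2, 3)
+    (illustrative; see the body below).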
+ """ + input_node = node.kwargs["input"] + shape = node.kwargs["shape"] + + if isinstance(shape[0], (tuple, list)): + shape = shape[0] + + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + reshape, + kwargs={ + "input": input_node, + "acc_out_ty": acc_utils.build_raw_tensor_meta(shape=shape), + }, + ) + new_node.meta = node.meta.copy() + return new_node + + +@register_acc_op +def to_dtype(input, acc_out_ty=None): + assert acc_out_ty is not None, "valid acc_out_ty needed" + return input.to(dtype=acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype")) + + +@register_custom_acc_mapper_fn( + op_and_target=("call_method", "to"), + arg_replacement_tuples=[ + ("input", "input"), + ("dtype", "dtype"), + ], +) +def custom_tensor_to_mapper(node: torch.fx.Node, _: nn.Module): + dest_dtype = node.kwargs["dtype"] + mem_format = node.kwargs.get("memory_format") + device = node.kwargs.get("device") + assert dest_dtype is not None + assert mem_format is None or mem_format == torch.preserve_format + assert device is None + + new_kwargs = { + "input": node.kwargs["input"], + "acc_out_ty": acc_utils.build_raw_tensor_meta(dtype=dest_dtype), + } + + with node.graph.inserting_before(node): + new_node = node.graph.create_node( + "call_function", to_dtype, kwargs=new_kwargs, name=node.name + ) + new_node.meta = node.meta + return new_node + + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.add), + # Note that we may have aliases for inputs here due to issues with deterministically + # knowing the correct target that will be resolved by pytorch. + arg_replacement_tuples=[ + (("input", "a"), "input"), + (("other", "b"), "other"), + ("alpha", "alpha", this_arg_is_optional), + ], +) +def custom_torch_add_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node: + """ + Add custom mapping for torch.add because it has an `alpha` parameter which scales + the `other` input, and we want to make that mul a separate node. + """ + with node.graph.inserting_before(node): + # If alpha is in kwargs check if we need to add a mul, and use correct kwargs. + if "alpha" in node.kwargs: + # Add mul node only if it has a numerical impact, i.e. alpha != 1.0. + if node.kwargs["alpha"] != 1.0: + other_node = node.graph.create_node( + "call_function", + mul, + kwargs={ + "input": node.kwargs["other"], + "other": node.kwargs["alpha"], + }, + name=node.name + "_mul_alpha", + ) + other_node.meta = node.meta + else: + other_node = node.kwargs["other"] + add_kwargs = {"input": node.kwargs["input"], "other": other_node} + else: + add_kwargs = node.kwargs + + new_node = node.graph.create_node( + "call_function", add, kwargs=add_kwargs, name=node.name + ) + new_node.meta = node.meta + return new_node + + +@register_custom_acc_mapper_fn( + op_and_target=("call_module", nn.quantized.Linear), + arg_replacement_tuples=[ + ("input", "input"), + ], +) +def packed_quantized_linear_mapper( + node: torch.fx.Node, mod: nn.Module +) -> torch.fx.Node: + """ + Mapping from quantized_linear module to acc_op.linear. We unpack weight and bias + in this mapper and pass them directly to linear node. 
+ """ + linear_module = dict(mod.named_modules())[node.target] + prefix = node.target.replace(".", "_") + weight_name = f"{prefix}_weight" + bias_name = f"{prefix}_bias" + + # Store weight and bias in the main module + mod.register_buffer(weight_name, linear_module.weight()) + if linear_module.bias() is not None: + mod.register_buffer(bias_name, linear_module.bias()) + + with node.graph.inserting_before(node): + # Insert get_attr nodes for weight and bias + get_weight = node.graph.get_attr(weight_name) + get_weight.meta["tensor_meta"] = extract_tensor_metadata(linear_module.weight()) + + get_bias = None + if linear_module.bias() is not None: + get_bias = node.graph.get_attr(bias_name) + get_bias.meta["tensor_meta"] = extract_tensor_metadata(linear_module.bias()) + + # Create kwargs for acc_op.quantized_linear + kwargs = { + "input": node.kwargs["input"], + "weight": get_weight, + "bias": get_bias, + "acc_out_ty": acc_utils.build_raw_tensor_meta( + q_scale=linear_module.scale, q_zero_point=linear_module.zero_point + ), + } + + new_node = node.graph.call_function(quantized_linear, kwargs=kwargs) + new_node.meta = node.meta + return new_node + + +@register_custom_acc_mapper_fn( + op_and_target=("call_module", nn.quantized.Conv2d), + arg_replacement_tuples=[ + ("input", "input"), + ], +) +def packed_quantized_conv2d_mapper( + node: torch.fx.Node, mod: nn.Module +) -> torch.fx.Node: + """ + Mapping from quantzed Conv2d module to acc_op.conv. We unpack all the parameters + in this mapper and pass them directly to conv2d node. + """ + conv_module = dict(mod.named_modules())[node.target] + prefix = node.target.replace(".", "_") + weight_name = f"{prefix}_weight" + bias_name = f"{prefix}_bias" + + # Store weight and bias in the main module + mod.register_buffer(weight_name, conv_module.weight()) + if conv_module.bias() is not None: + mod.register_buffer(bias_name, conv_module.bias()) + + with node.graph.inserting_before(node): + # Insert get_attr nodes for weight and bias + get_weight = node.graph.get_attr(weight_name) + get_weight.meta["tensor_meta"] = extract_tensor_metadata(conv_module.weight()) + + get_bias = None + if conv_module.bias() is not None: + get_bias = node.graph.get_attr(bias_name) + get_bias.meta["tensor_meta"] = extract_tensor_metadata(conv_module.bias()) + + # Create kwargs for acc_op.conv + kwargs = { + "input": node.kwargs["input"], + "weight": get_weight, + "bias": get_bias, + "stride": conv_module.stride, + "padding": conv_module.padding, + "dilation": conv_module.dilation, + "groups": conv_module.groups, + "padding_mode": conv_module.padding_mode, + "acc_out_ty": acc_utils.build_raw_tensor_meta( + q_scale=conv_module.scale, q_zero_point=conv_module.zero_point + ), + } + + new_node = node.graph.call_function(quantized_conv2d, kwargs=kwargs) + new_node.meta = node.meta + return new_node + + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.ops.quantized.add_relu), + arg_replacement_tuples=[ + ("input", "input"), + ("other", "other"), + ("scale", "scale"), + ("zero_point", "zero_point"), + ], +) +def add_relu_unfuse_mapper( + node: torch.fx.Node, mod: torch.fx.GraphModule +) -> torch.fx.Node: + with node.graph.inserting_before(node): + add_kwargs = { + "input": node.kwargs["input"], + "other": node.kwargs["other"], + "acc_out_ty": acc_utils.build_raw_tensor_meta( + q_scale=node.kwargs["scale"], + q_zero_point=node.kwargs["zero_point"], + ), + } + add_node = node.graph.call_function(quantized_add, kwargs=add_kwargs) + add_node.meta = node.meta.copy() + + 
relu_node = node.graph.call_function( + relu, kwargs={"input": add_node, "inplace": False} + ) + relu_node.meta = node.meta + return relu_node + + +@register_custom_acc_mapper_fn( + op_and_target=("call_module", nn.intrinsic.quantized.ConvReLU2d), + arg_replacement_tuples=[ + ("input", "input"), + ], +) +def packed_quantized_convrelu2d_mapper( + node: torch.fx.Node, mod: nn.Module +) -> torch.fx.Node: + """ + Mapping from quantized ConvReLU2d module to acc_op.relu. We use packed_quantized_conv2d_mapper to unpack all the parameters + in this mapper and pass the returned conv2d node directly to relu node. + """ + + with node.graph.inserting_before(node): + # conv2d op + conv2d_node = packed_quantized_conv2d_mapper(node, mod) + + # relu op + relu_node = node.graph.call_function( + relu, kwargs={"input": conv2d_node, "inplace": False} + ) + relu_node.meta = node.meta + return relu_node diff --git a/torch/fx/experimental/fx_acc/acc_tracer.py b/torch/fx/experimental/fx_acc/acc_tracer.py new file mode 100644 index 0000000..1f6a365 --- /dev/null +++ b/torch/fx/experimental/fx_acc/acc_tracer.py @@ -0,0 +1,433 @@ +# type: ignore[] +import ast +import builtins +import copy +import inspect +import textwrap +import warnings +from types import FunctionType +from typing import Dict, Optional, Any, Type, Tuple, Set, List + +import torch.fx.experimental.fx_acc.acc_normalizer as acc_normalizer +import torch.fx.experimental.fx_acc.acc_ops # noqa: F401 +import torch +import torch.nn as nn +from torch._sources import normalize_source_lines +from torch.fx import Graph, Tracer +from torch.fx.experimental.normalize import NormalizeArgs +from torch.fx.passes import shape_prop + + +def _get_exception_wrapper_attr_name(exc_type: Type[Exception]) -> str: + return f"_conditional_exception_wrapper_{exc_type.__name__}" + + +class Acc_Rewriter(ast.NodeTransformer): + """ + Take a FunctionType object representing a `forward` method, then + perform an AST rewrite to swap out nodes that are not symbolically + traceable with a callsite to the FX alternative. + + To support swapping out an AST node, define a new `visit` method on + that node. For more details, see: + https://docs.python.org/3/library/ast.html#ast.NodeTransformer + """ + + def __init__(self): + super().__init__() + self.exceptions_rewritten: Set[Type[Exception]] = set() + + def rewrite(self, fn: FunctionType) -> Tuple[FunctionType, Set[Type[Exception]]]: + + # Normalize the source lines + sourcelines, _ = inspect.getsourcelines(fn) + sourcelines = normalize_source_lines(sourcelines) + source = "".join(sourcelines) + normalized_str = textwrap.dedent(source) + + # Rewrite the original AST + source_ast = ast.parse(normalized_str) + dest_ast = ast.fix_missing_locations(self.visit(source_ast)) + + # Pull out the compiled function from the newly-created Module + code = compile(dest_ast, "", "exec") + globals_dict = copy.copy(fn.__globals__) + keys_before = set(globals_dict.keys()) + exec(code, globals_dict) + new_keys = list(set(globals_dict.keys()) - keys_before) + assert len(new_keys) <= 1 + fn_compiled = globals_dict[fn.__name__] + + # Return the correct FunctionType object and the Exceptions that were + # rewritten during visit_If. 
+ return fn_compiled, self.exceptions_rewritten + + def visit_Assert(self, node: ast.Assert): + """ + Swap out the Assert node (Python's `assert`) with a callsite to the + symbolically-traceable torch._assert function + """ + # Create the Call node + n = ast.parse("torch._assert()", mode="eval") + assert isinstance(n, ast.Expression) + call_node = n.body + assert isinstance(call_node, ast.Call) + msg = node.msg if node.msg else ast.Constant(value="", kind=None) + call_node.args = [node.test, msg] + + # Ensure that the new node conforms to the Python AST grammar + expr_wrapper = ast.Expr(value=call_node) + + # Return the new Call node to signify that we want to use it as + # a replacement for the original _assert node + return ast.copy_location(expr_wrapper, node) + + def visit_If(self, if_node: ast.If): + """ + Swap out the pattern `If(x): Raise(y)` with a ConditionalExceptionWrapper + specialized for the specific exception y. The specialized + ConditionalExceptionWrapper module will be added in the RewrittenModule. + Only works with builtin Exceptions, as we assume the signature of the + init for the Exception is a string. + """ + raise_node = if_node.body[0] + if not isinstance(raise_node, ast.Raise): + return if_node + + # Don't handle orelse for now. + # TODO: Move orelse to the body after calling ConditionalExceptionWrapper. + if len(if_node.orelse) != 0: + return if_node + + def _reuse_loc(node): + return ast.copy_location(node, if_node) + + # If the exception has a message then we expect the raise's exc to be a + # Call w/ a msg. Else if it's a exc Name then there's no msg to use. + node_for_exc = raise_node.exc + if isinstance(node_for_exc, ast.Name): + # E.g. `raise AssertionError`, i.e. without an exc_msg. + name_node_of_exc = node_for_exc + exc_msg = _reuse_loc(ast.Constant(None)) + elif isinstance(node_for_exc, ast.Call): + # E.g. `raise AssertionError("error message")` + name_node_of_exc = node_for_exc.func + if not isinstance(name_node_of_exc, ast.Name): + return if_node + # Most assertions just take a single string arg, but some may not; skip + # handling such assertions for now. + if len(node_for_exc.args) != 1: + return if_node + exc_msg = node_for_exc.args[0] + else: + return if_node + + # Convert what we expect is the name of the exception into its + # associated python class. + name_of_exc = name_node_of_exc.id + try: + exc_type = eval(name_of_exc) + except Exception: + return if_node + + # Check that we actually have a builtin exception. + if ( + not issubclass(exc_type, Exception) + or getattr(getattr(exc_type, "__class__", None), "__module__", None) + != "builtins" + ): + return if_node + + # We need a ConditionalExceptionWrapper specialized for every kind of + # exception, so add it to exceptions_rewritten to remember for later to + # add a specialized attr with it. + self.exceptions_rewritten.add(exc_type) + + # From here we definitely should be able to do the replacement. Create a + # Call node to the ConditionalExceptionWrapper module we're replacing + # the If with, with args set as the If's condition and the string of the + # exception. The call to the self._conditional_exception_wrapper_*Error + # module is safe because the RewrittenModule will add it as an attr + # based on the returned exceptions_rewritten, and we assume we are + # currently modifying the AST of a method from a RewrittenModule. 
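+        # For example, `if x < 0: raise ValueError("negative")` becomes roughly
+        # `self._conditional_exception_wrapper_ValueError(x < 0, "negative")`.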
+ exc_wrapper_node = ast.parse( + f"self.{_get_exception_wrapper_attr_name(exc_type)}()", mode="eval" + ) + assert isinstance(exc_wrapper_node, ast.Expression) + exc_wrapper_call_node = exc_wrapper_node.body + assert isinstance(exc_wrapper_call_node, ast.Call) + exc_wrapper_call_node.args = [if_node.test, exc_msg] + + # Ensure that the new node conforms to the Python AST grammar + expr_wrapper = _reuse_loc(ast.Expr(_reuse_loc(exc_wrapper_call_node))) + + # Return the new node to signify that we want to use it as a replacement + # for the original `If x: Raise y` pattern. + return expr_wrapper + + +class ConditionalExceptionWrapper(nn.Module): + """ + This wrapper class is used to wrap conditional raising of exceptions during + rewriting. For example: + + .. code-block:: python + + if self.name != "x": + raise AssertionError(f"Name was not x: {self.name}") + + Is rewritten into + + .. code-block:: python + + self._conditional_exception_wrapper_AssertionError( + self.name != "x", f"Name was not x: {self.name}" + ) + + Note that __init__ takes the Exception class that it is wrapping, while + forward takes the condition to check and the message for the exception. + + """ + + # Mark as impure so that calls to it will not be removed during DCE. + _is_impure = True + + def __init__(self, exc: Type[Exception]): + super().__init__() + self.exc = exc + + def forward(self, cond: bool, msg: str): + if cond: + raise self.exc if msg is None else self.exc(msg) + + +# Custom tracer that traces to the functional level and rewrites asserts and +# exceptions. +class AccRewritingTracer(Tracer): + # Note: Treat ConditionalExceptionWrapper as a leaf so that we don't + # trace into it, because it contains control flow and raises an exception. + DEFAULT_LEAF_MODULE_LIST = { + ConditionalExceptionWrapper, + torch.nn.quantized.Linear, + torch.nn.quantized.Conv2d, + torch.nn.intrinsic.quantized.ConvReLU2d, + } + + def is_leaf_module(self, m: nn.Module, mod_qual_name: str) -> bool: + return getattr(m, "_base_class_origin", type(m)) in self.leaf_module_list + + def trace( + self, + root: nn.Module, + concrete_args: Optional[Dict[str, Any]] = None, + ast_rewriter_allow_list: Optional[Set] = None, + leaf_module_list: Optional[Set] = None, + ) -> Tuple[Graph, nn.Module]: + rewritten = _rewrite(root, ast_rewriter_allow_list) + self.leaf_module_list = self.DEFAULT_LEAF_MODULE_LIST + if leaf_module_list: + self.leaf_module_list.update(leaf_module_list) + return super().trace(rewritten, concrete_args), rewritten + + +# List of modules that need rewriting to be supported for tracing. +DEFAULT_REWRITE_ALLOW_LIST = { + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, +} + + +def _rewrite(mod_to_rewrite: nn.Module, allow_list: Optional[Set] = None) -> nn.Module: + if allow_list is None: + allow_list = DEFAULT_REWRITE_ALLOW_LIST + else: + allow_list.union(DEFAULT_REWRITE_ALLOW_LIST) + + # Rewrite this module's functions as well as all recursive modules' + # functions that are attrs of this moodule. Return the new, rewritten module + # hierarchy. + def rewrite_module(m: nn.Module): + base_class = type(m) + + # Keep track of all the ConditionalExceptionWrappers that the + # Acc_Rewriter calls into in this module so we can add them in init + # below. + all_added_wrappers: Set[Type[Exception]] = set() + + # Note: Make this a subclass of our base class. + class RewrittenModule(base_class): + # Keep track of the base_class so that symbolic tracing can + # determine what kind of module this originally was later on. 
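+            # (Both AccRewritingTracer.is_leaf_module and the acc_normalizer's
+            # get_target helper read this attribute.)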
+ _base_class_origin = base_class + # Add suffix to qualname so it's easier to debug the origin of this module. + __qualname__ = f"{base_class.__qualname__}__AccRewrittenModule" + + # Write all of the non-dunder or special methods from base_class + # into RewrittenModule. + for method_name in dir(base_class): + method = getattr(base_class, method_name) + if builtins.type(method) is not FunctionType: + continue + + # Always skip rewriting dunder methods, as they haven't (yet) been + # problematic, and modifying them has caused issues previously. + if method_name.startswith("__") and method_name.endswith("__"): + continue + + # Only rewrite those Modules explicitly in the allow_list. + if base_class not in allow_list: + vars()[method_name] = method + else: + vars()[method_name], added_wrappers = Acc_Rewriter().rewrite(method) + all_added_wrappers.update(added_wrappers) + + def __init__(self, orig): + nn.Module.__init__(self) + # Iterate over all added exception wrappers and add + # ConditionalExceptionWrapper attrs for each. + for exc_type in all_added_wrappers: + wrapper_name = _get_exception_wrapper_attr_name(exc_type) + assert not hasattr(self, wrapper_name) + setattr( + self, + wrapper_name, + ConditionalExceptionWrapper(exc_type), + ) + # Recursively rewrite and copy all module attrs of this module. + for k, v in orig.__dict__.items(): + if k == "_modules": + for mod_k, mod_v in v.items(): + self._modules[mod_k] = rewrite_module(mod_v) + else: + self.__dict__[k] = v + + # Add suffix to name so it's easier to debug the origin of this module. + RewrittenModule.__name__ = f"{base_class.__name__}__AccRewrittenModule" + return RewrittenModule(m) + + return rewrite_module(mod_to_rewrite) + + +def _remove_assertions(gm: torch.fx.GraphModule) -> bool: + """ + Unconditionally removes all assertions found in GraphModule gm. + Returns whether the graph is modified. + """ + changed = False + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch._assert: + gm.graph.erase_node(node) + changed = True + return changed + + +def _remove_exceptions(gm: torch.fx.GraphModule) -> bool: + """ + Unconditionally removes all call_modules to ConditionalExceptionWrappers + found in GraphModule gm. Returns whether the graph is modified. + """ + changed = False + for node in gm.graph.nodes: + if node.op == "call_module" and isinstance( + gm.get_submodule(node.target), ConditionalExceptionWrapper + ): + gm.graph.erase_node(node) + changed = True + return changed + + +def trace( + mod: nn.Module, + sample_inputs: List[torch.Tensor], + remove_assertions: bool = True, + remove_exceptions: bool = True, + use_acc_normalization: bool = True, + ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None, + leaf_module_list: Optional[Set[Type[nn.Module]]] = None, +) -> torch.fx.GraphModule: + """ + Performs tracing and arg normalization specialized for accelerator lowering. + + It first rewrites the AST of the module's methods (and all attr methods + recursively) to transform un-tracable parts of the module to make them + traceable. + + It then traces to the functional level so that optimizations and backend + accelerator importers have the ability to see and/or change inputs to each + op. + + It then removes assertions and exception wrappers found during symbolic + tracing if requested based on remove_assertions and remove_exceptions + + Dead code is then eliminated, which will e.g. remove any nodes that were + only used by assertions or exceptions if they were removed. 
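+
+    For example, an assertion like assert x.dim() == 2 in the original forward
+    is rewritten into a torch._assert call during tracing, and that node (plus
+    anything feeding only it) is removed here when remove_assertions is True.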
+ + It then performs normalization on args/kwargs, aligning any arg that can be + moved to kwarg to be so, and then making default values explicit. + + Args: + + mod (Module): The module to transform and trace. + + sample_inputs (Tuple[Union[torch.Tensor, List[torch.Tensor]]]): + Sample inputs with which to run shape prop. + + remove_assertions (bool): Whether to remove assertion nodes from + the graph after symbolic tracing. + + remove_exceptions (bool): Whether to remove exception wrapper nodes + from the graph after symbolic tracing. + + use_acc_normalization (bool): Whether to use acc-specific + normalization to all acc_ops. + + ast_rewriter_allow_list (Optional[Set[nn.Module]]): Optional allow list of + modules that need AST rewriting. + + leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list where + modules will not be traced into. + + """ + if mod.training: + warnings.warn( + "acc_tracer does not support currently support models for training." + " Calling eval on model before tracing." + ) + mod.eval() + + # Rewrite the module to make it symbolic traceable, and then trace it. + rewritten_graph, rewritten_mod = AccRewritingTracer().trace( + mod, + ast_rewriter_allow_list=ast_rewriter_allow_list, + leaf_module_list=leaf_module_list, + ) + + assert isinstance(rewritten_mod, nn.Module) + # Note: use the rewritten_mod here as the root. This is necessary because + # RewrittenModule includes a new module for the ConditionalExceptionWrapper. + traced = torch.fx.GraphModule(rewritten_mod, rewritten_graph) + + # Now remove all assertions and exceptions if requested. + if remove_assertions: + _remove_assertions(traced) + if remove_exceptions: + _remove_exceptions(traced) + + # Cleanup any dead code from the original module as well as resulting dead + # nodes after removing assertions and exceptions. + traced.graph.eliminate_dead_code() + + # Now normalize args/kwargs to make default values visible. Leave args/kwargs as + # they were, since all-kwarg normalization is broken, and we don't need it anyway. + shape_prop.ShapeProp(traced).propagate(*sample_inputs) + traced = NormalizeArgs(traced, normalize_to_only_use_kwargs=False).transform() + + # Normalize to acc-specialized wrappers for consistency across op naming and + # ensuring all kwarg usage. + if use_acc_normalization: + acc_normalizer.normalize(traced) + + traced.recompile() + + return traced diff --git a/torch/fx/experimental/fx_acc/acc_utils.py b/torch/fx/experimental/fx_acc/acc_utils.py new file mode 100644 index 0000000..a77f3fa --- /dev/null +++ b/torch/fx/experimental/fx_acc/acc_utils.py @@ -0,0 +1,91 @@ +import inspect +import json +from typing import Any, Tuple, Callable, Union, Dict + +import torch +import torch.fx +from torch.fx.experimental.graph_manipulation import ( + serialize_module, +) +from torch.fx.graph_module import GraphModule +from torch.fx.passes import graph_drawer +from torch.fx.passes.shape_prop import TensorMetadata + + +def is_acc_op(node_or_target: Union[Callable, torch.fx.Node]) -> bool: + """ + Returns whether `node_or_target` is an acc_op. If it's a node, then checks whether + it's a call_function target is from the acc_ops module. Otherwise it's already + the target, which is similarly checked to see if it's from the acc_ops module. + """ + if isinstance(node_or_target, torch.fx.Node): + # All acc_ops are call_functions. 
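+        # (call_method, call_module and get_attr nodes therefore never count as
+        # acc ops.)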
+ if node_or_target.op != "call_function": + return False + target = node_or_target.target + else: + target = node_or_target + return "acc_ops" in target.__module__ + + +def is_acc_op_with_kwarg( + node_or_target: Union[Callable, torch.fx.Node], kwarg: str +) -> bool: + """ + Helper that inspects `node_or_target` and returns whether it is an acc_op node + (or a target for an acc_op) that has an arg signature that includes `kwarg`. + """ + if not is_acc_op(node_or_target): + return False + + target = ( + node_or_target.target + if isinstance(node_or_target, torch.fx.Node) + else node_or_target + ) + assert not isinstance(target, str) + return kwarg in inspect.signature(inspect.unwrap(target)).parameters + + +def get_field_from_acc_out_ty( + acc_out_ty_or_dict: Union[Tuple, Dict[str, Any]], field: str +): + """ + After tracing NamedTuple inputs are converted to standard tuples, so we cannot + access them by name directly. Use this helper instead. + """ + if isinstance(acc_out_ty_or_dict, dict): + acc_out_ty = acc_out_ty_or_dict["acc_out_ty"] + else: + acc_out_ty = acc_out_ty_or_dict + return acc_out_ty[TensorMetadata._fields.index(field)] + + +def serialize_module_json_to_file(fx_module: GraphModule, fname: str): + weights: Dict = {} + serialized_json = json.dumps(serialize_module(fx_module, weights), indent=2) + with open(fname, "w") as ofile: + ofile.write(serialized_json) + + +def build_raw_tensor_meta( + shape=None, + dtype=None, + requires_grad=None, + stride=None, + memory_format=None, + is_quantized=None, + qscheme=None, + q_scale=None, + q_zero_point=None, +): + return TensorMetadata(**locals()) + + +def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph"): + if not fname.endswith(".svg"): + fname = fname + ".svg" + print(f"Writing FX graph to file: {fname}") + g = graph_drawer.FxGraphDrawer(traced, figname) + x = g.get_main_dot_graph() + x.write_svg(fname) diff --git a/torch/fx/experimental/graph_manipulation.py b/torch/fx/experimental/graph_manipulation.py index 7f96fcf..9d0af53 100644 --- a/torch/fx/experimental/graph_manipulation.py +++ b/torch/fx/experimental/graph_manipulation.py @@ -372,7 +372,7 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D user_targets = { _get_qualified_name( n.target - ).replace("glow.fb.fx.oss_acc_tracer.", "").replace("glow.fb.fx.", ""): n + ).replace("torch.fx.experimental.fx_acc.", "").replace("glow.fb.fx.", ""): n for n in node.users.keys() } if (