[quant][fx2trt] Add lowering support for reference linear/conv modules (#64368)
author Jerry Zhang <jerryzh@fb.com>
Sat, 11 Sep 2021 05:24:10 +0000 (22:24 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sat, 11 Sep 2021 05:25:27 +0000 (22:25 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64368

Test Plan:
python torch/fx/experimental/fx2trt/example/quantized_resnet_test.py

Imported from OSS

Reviewed By: 842974287

Differential Revision: D30708738

fbshipit-source-id: 88142b7ce43ed96093597112dab03a2d277de993
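
At a high level, the lowering change below handles the case where the TensorRT network is built with explicit precision: the quantized weight reaches the conv/linear converter as an ITensor (it goes through quantize + dequantize) instead of a constant, so the layer is created with placeholder weights and the real weight is bound afterwards. A minimal sketch of that pattern for the fully-connected case follows; the helper name is made up for illustration, and it only uses the TensorRT Python calls that appear in the converter diff below (trt.Weights, add_fully_connected, set_input):

    import tensorrt as trt

    def add_fc_with_itensor_weight(network, input_itensor, weight_itensor, bias, name):
        # The layer-creation API only accepts constant kernels, so start with
        # empty placeholder weights ...
        dummy_weight = trt.Weights()
        layer = network.add_fully_connected(
            input=input_itensor,
            num_outputs=weight_itensor.shape[0],
            kernel=dummy_weight,
            bias=bias,  # constant bias weights (e.g. a numpy array)
        )
        # ... then bind the quantize/dequantize'd weight ITensor as input 1.
        layer.set_input(1, weight_itensor)
        layer.name = name
        return layer.get_output(0)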

torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
torch/fx/experimental/fx2trt/example/quantized_resnet_test.py
torch/fx/experimental/fx2trt/fx2trt.py
torch/nn/quantized/_reference/modules/conv.py
torch/nn/quantized/_reference/modules/linear.py
torch/quantization/fx/quantization_patterns.py

torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
index e946a92..8f93493 100644
@@ -12,9 +12,10 @@ from torch.fx.experimental.fx2trt.fx2trt import (
     torch_dtype_from_trt,
     get_dynamic_dims,
 )
+from typing import Optional
 
 
-def to_numpy(tensor: torch.Tensor):
+def to_numpy(tensor: Optional[torch.Tensor]):
     """
     Convert a PyTorch Tensor to a Numpy Array.
     """
@@ -24,6 +25,8 @@ def to_numpy(tensor: torch.Tensor):
     if tensor.is_quantized:
         tensor = tensor.dequantize()
 
+    assert isinstance(tensor, torch.Tensor), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
+
     return tensor.cpu().detach().contiguous().numpy()
 
 
@@ -241,16 +244,36 @@ def acc_ops_conv2d(network, target, args, kwargs, name):
     if has_dynamic_shape(input_val.shape):
         assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
 
-    kernel = to_numpy(kwargs["weight"])
+    # For now we assume bias is a constant Tensor or None;
+    # passing bias as an ITensor is not supported by the
+    # TensorRT API right now.
     bias = to_numpy(kwargs["bias"])
 
-    layer = network.add_convolution(
-        input=input_val,
-        num_output_maps=kernel.shape[0],
-        kernel_shape=kernel.shape[2:],
-        kernel=kernel,
-        bias=bias,
-    )
+    if network.has_explicit_precision:
+        weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
+        weight_shape = tuple(kwargs["weight"].shape)
+        # create the layer with uninitialized placeholder weights and set the
+        # real ITensor weight afterwards via set_input
+        dummy_weight = trt.Weights()
+
+        layer = network.add_convolution(
+            input=input_val,
+            num_output_maps=weight.shape[0],
+            kernel_shape=weight.shape[2:],
+            kernel=dummy_weight,
+            bias=bias,
+        )
+
+        layer.set_input(1, weight)
+    else:
+        weight = to_numpy(kwargs["weight"])
+        layer = network.add_convolution(
+            input=input_val,
+            num_output_maps=weight.shape[0],
+            kernel_shape=weight.shape[2:],
+            kernel=weight,
+            bias=bias,
+        )
 
     layer.name = name
     layer.stride = kwargs["stride"]
@@ -1019,43 +1042,45 @@ def acc_ops_linear(network, target, args, kwargs, name):
         "dim for linear and it can't be the last dim."
     )
 
-    weight = kwargs["weight"]
-
-    # For quantization, weight here would be a trt tensor because it goes through
-    # quant + dequant. In this case, we need to use matmul + add because fully_connected
-    # can't take non-constant weight.
     # TODO: Need to benchmark the performance of lowering linear as fully_connected versus
     # lowering as matmul + add. TensorRT documentation suggests to always lower it as
     # matmul + add but we found in some cases this results in performance regression compared
     # with lowering to fully_connected layer.
-    if isinstance(weight, torch.Tensor):
-        layer = network.add_shuffle(input_val)
-        layer.reshape_dims = tuple(input_val.shape) + (1, 1)
-        layer.name = f"{name}_pre_shuffle"
+    layer = network.add_shuffle(input_val)
+    layer.reshape_dims = tuple(input_val.shape) + (1, 1)
+    layer.name = f"{name}_pre_shuffle"
+    bias = to_numpy(kwargs["bias"])
+
+    if network.has_explicit_precision:
+        weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
+        # create the layer with uninitialized placeholder weights and set the
+        # real ITensor weight afterwards via set_input
+        dummy_weight = trt.Weights()
 
         # add fully connected
         layer = network.add_fully_connected(
             input=layer.get_output(0),
-            num_outputs=kwargs["weight"].shape[0],
-            kernel=to_numpy(kwargs["weight"]),
-            bias=to_numpy(kwargs["bias"]),
+            num_outputs=weight.shape[0],
+            kernel=dummy_weight,
+            bias=bias,
         )
-        layer.name = f"{name}_linear"
-
-        # reshape back
-        layer = network.add_shuffle(layer.get_output(0))
-        layer.reshape_dims = tuple(input_val.shape[:-1]) + (kwargs["weight"].shape[0],)
-        layer.name = f"{name}_post_shuffle"
-
-        return layer.get_output(0)
+        layer.set_input(1, weight)
     else:
-        # add matrix multiply and add
-        output = add_matrix_multiply_layer(network, input_val, weight, f"{name}_linear_mm", transpose_other=True)
-        if kwargs["bias"] is not None:
-            return add_binary_elementwise_layer(network, output, kwargs["bias"], trt.ElementWiseOperation.SUM, f"{name}_linear_add")
-        else:
-            return output
+        weight = to_numpy(kwargs["weight"])
+        layer = network.add_fully_connected(
+            input=layer.get_output(0),
+            num_outputs=weight.shape[0],
+            kernel=weight,
+            bias=bias,
+        )
+    layer.name = f"{name}_linear"
 
+    # reshape back
+    layer = network.add_shuffle(layer.get_output(0))
+    layer.reshape_dims = tuple(input_val.shape[:-1]) + (kwargs["weight"].shape[0],)
+    layer.name = f"{name}_post_shuffle"
+
+    return layer.get_output(0)
 
 def add_clamp(network, input, val, op):
     acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
@@ -1307,7 +1332,8 @@ def acc_ops_permute(network, target, args, kwargs, name):
 
 @tensorrt_converter(acc_ops.quantize_per_tensor)
 def acc_ops_quantize_per_tensor(network, target, args, kwargs, name):
-    input_val = kwargs["input"]
+    input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
 
     if not isinstance(input_val, trt.tensorrt.ITensor):
         raise RuntimeError(f"{name} received input {input_val} that is not part "
torch/fx/experimental/fx2trt/example/quantized_resnet_test.py
index 140f4fb..415ce2d 100644
@@ -12,7 +12,8 @@ rn18 = models.resnet18().eval()
 def build_fp16_trt(rn18):
     rn18 = copy.deepcopy(rn18)
     rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)])  # type: ignore[attr-defined]
-    interp = TRTInterpreter(rn18, [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)])
+    interp = TRTInterpreter(
+        rn18, [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)])
     engine, input_names, output_names = interp.run(fp16_mode=True)
     return TRTModule(engine, input_names, output_names)
 
@@ -20,6 +21,7 @@ def build_fp16_trt(rn18):
 def build_int8_trt(rn18):
     rn18 = copy.deepcopy(rn18)
     data = torch.randn(1, 3, 224, 224)
+    # data = torch.randn(1, 32)
     # data = torch.randn(1, 64, 10, 10)
     # TensorRT only supports symmetric quantization
     qconfig = torch.quantization.QConfig(
@@ -32,12 +34,20 @@ def build_int8_trt(rn18):
     for _ in range(10):
         prepared(data)
     quantized_rn18 = convert_fx(prepared, is_reference=True)
+    ref_res = quantized_rn18(data)
     print("quantized model:", quantized_rn18)
 
     quantized_rn18 = acc_tracer.trace(quantized_rn18, [data])  # type: ignore[attr-defined]
-    interp = TRTInterpreter(quantized_rn18, [InputTensorSpec(data.shape[1:], torch.float, has_batch_dim=False)])
+    interp = TRTInterpreter(
+        quantized_rn18,
+        [InputTensorSpec(torch.Size([-1, *data.shape[1:]]), torch.float,
+                         shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))], has_batch_dim=True)],
+        explicit_batch_dimension=True, explicit_precision=True)
     engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True)
-    return TRTModule(engine, input_names, output_names)
+    trt_mod = TRTModule(engine, input_names, output_names)
+    trt_res = trt_mod(data.cuda())
+    print("result diff max", torch.max(ref_res - trt_res.cpu()))
+    return trt_mod
 
 @torch.no_grad()
 def build_int8_trt_implicit_quant(rn18):
@@ -54,6 +64,7 @@ def build_int8_trt_implicit_quant(rn18):
     for _ in range(10):
         prepared(data)
     quantized_rn18 = convert_fx(prepared, is_reference=True)
+    ref_res = quantized_rn18(data)
 
     # Build trt int8 model
     traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
@@ -62,27 +73,32 @@ def build_int8_trt_implicit_quant(rn18):
     interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]))
     engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
     trt_mod = TRTModule(engine, input_names, output_names)
+    trt_res = trt_mod(data.cuda())
+    print("result equal?", torch.equal(ref_res, trt_res))
     return trt_mod
 
 class M(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
+        self.linear = torch.nn.Linear(32, 46)
+        # self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
 
     def forward(self, x):
-        out = self.conv(x)
+        # out = self.conv(x)
+        out = self.linear(x)
+        # out = torch.nn.functional.relu(out)
+        # out += x
+        # out += out
         # out = torch.nn.functional.relu(out)
-        out += x
-        out += out
-        out = torch.nn.functional.relu(out)
         return out
 
 # rn18 = M().eval()
 # rn18 = rn18.layer1
 int8_trt = build_int8_trt(rn18)
-implicit_int8_trt = build_int8_trt_implicit_quant(rn18)
+implicit_int8_trt = build_int8_trt_implicit_quant(rn18)
 fp16_trt = build_fp16_trt(rn18)
 x = torch.randn(5, 3, 224, 224, device="cuda")
+# x = torch.randn(1, 32, device="cuda")
 rn18 = rn18.cuda()
 
 import time
@@ -102,12 +118,12 @@ for _ in range(NITER):
     torch.cuda.synchronize()
 print('trt int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
 
-torch.cuda.synchronize()
-s = time.time()
-for _ in range(NITER):
-    implicit_int8_trt(x)
-    torch.cuda.synchronize()
-print('trt implicit int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    implicit_int8_trt(x)
+    torch.cuda.synchronize()
+print('trt implicit int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
 
 torch.cuda.synchronize()
 s = time.time()
torch/fx/experimental/fx2trt/fx2trt.py
index 4c0b44c..af6daa8 100644
@@ -218,6 +218,7 @@ class TRTInterpreter(torch.fx.Interpreter):
         module: torch.fx.GraphModule,
         input_specs: List[InputTensorSpec],
         explicit_batch_dimension: bool = False,
+        explicit_precision: bool = False,
         logger_level=trt.Logger.WARNING,
     ):
         super().__init__(module)
@@ -225,13 +226,19 @@ class TRTInterpreter(torch.fx.Interpreter):
         self.logger = trt.Logger(logger_level)
         self.builder = trt.Builder(self.logger)
 
+        flag = 0
         if explicit_batch_dimension:
             EXPLICIT_BATCH = 1 << (int)(
                 trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH
             )
-            self.network = self.builder.create_network(EXPLICIT_BATCH)
-        else:
-            self.network = self.builder.create_network()
+            flag |= EXPLICIT_BATCH
+
+        if explicit_precision:
+            EXPLICIT_PRECISION = 1 << (int)(
+                trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION
+            )
+            flag |= EXPLICIT_PRECISION
+        self.network = self.builder.create_network(flag)
 
         missing_ops = self.validate_conversion()
         if missing_ops:
torch/nn/quantized/_reference/modules/conv.py
index 6b03bb0..df9ae95 100644
@@ -67,7 +67,6 @@ class _ConvNd(torch.nn.modules.conv._ConvNd):
         model
         """
         # supress mypy warning
-        assert isinstance(self.weight, torch.Tensor)
         assert isinstance(self.weight_scale, torch.Tensor)
         assert isinstance(self.weight_zero_point, torch.Tensor)
         assert isinstance(self.weight_axis, torch.Tensor)
torch/nn/quantized/_reference/modules/linear.py
index 1df5499..a472a04 100644
@@ -64,7 +64,6 @@ class Linear(nn.Linear):
         model
         """
         # supress mypy warning
-        assert isinstance(self.weight, torch.Tensor)
         assert isinstance(self.weight_scale, torch.Tensor)
         assert isinstance(self.weight_zero_point, torch.Tensor)
         assert isinstance(self.weight_axis, torch.Tensor)
torch/quantization/fx/quantization_patterns.py
index d90be90..7ae07bf 100644
@@ -190,6 +190,7 @@ class QuantizeHandler(ABC):
 # tuple (activation_dtype, weight_dtype, compute_dtype)
 # these are supported types for common binary ops like add/mul etc.
 all_dtypes = [
+    (torch.qint8, torch.qint8, None),
     (torch.quint8, torch.qint8, None),
     (torch.float16, torch.float16, None),
 ]
@@ -197,6 +198,7 @@ fp16_dtypes = [
     (torch.float16, torch.float16, None)
 ]
 int8_dtypes = [
+    (torch.qint8, torch.qint8, None),
     (torch.quint8, torch.qint8, None),
 ]
 binary_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {