Back out "Revert D30384746: [fx2trt] Add a test for quantized resnet18" (#63973)
authorJerry Zhang <jerryzh@fb.com>
Thu, 26 Aug 2021 00:50:48 +0000 (17:50 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 00:52:22 +0000 (17:52 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63973

Original commit changeset: b93235323e22

Test Plan: buck run mode/opt -c python.package_style=inplace caffe2:fx2trt_quantized_resnet_test

Reviewed By: 842974287

Differential Revision: D30546036

fbshipit-source-id: 2c8302456f072d04da00cf9ad97aa8304bc5e43e

torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
torch/fx/experimental/fx2trt/example/quantized_resnet_test.py [new file with mode: 0644]

index 566359b..33a817d 100644 (file)
@@ -1300,15 +1300,11 @@ def acc_ops_quantize_per_tensor(network, target, args, kwargs, name):
     if q_zero_point != 0:
         raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")
 
-    # temporarily set q_scale to 1 to make sure the q_scale is different
-    # for quantize and dequantize to avoid the error
-    # TODO: follow up with nvidia TensorRT team to repro and fix the problem
-    q_scale = 1
     scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32)))
     scale_layer.name = input_val.name + ".quant.scale"
     scale = scale_layer.get_output(0)
-    assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
-    "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
+    assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
+    "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
     layer = network.add_quantize(input=input_val, scale=scale)
     layer.axis = 0
     layer.name = input_val.name + ".quant"
@@ -1316,9 +1312,6 @@ def acc_ops_quantize_per_tensor(network, target, args, kwargs, name):
 
 @tensorrt_converter(acc_ops.dequantize)
 def acc_ops_dequantize(network, target, args, kwargs, name):
-    """
-    Currently just a no-op.
-    """
     input_val = kwargs["input"]
 
     if not isinstance(input_val, trt.tensorrt.ITensor):
@@ -1339,8 +1332,8 @@ def acc_ops_dequantize(network, target, args, kwargs, name):
     scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([q_scale], dtype=np.float32)))
     scale_layer.name = input_val.name + ".dequant.scale"
     scale = scale_layer.get_output(0)
-    assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in "
-    "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
+    assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in "
+    "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
     layer = network.add_dequantize(input=input_val, scale=scale)
     layer.name = input_val.name + ".dequant"
     layer.axis = 0
diff --git a/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py b/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py
new file mode 100644 (file)
index 0000000..140f4fb
--- /dev/null
@@ -0,0 +1,117 @@
+import torch.fx
+import torchvision.models as models
+from torch.fx.experimental.fx2trt.fx2trt import TRTInterpreter, InputTensorSpec, TRTModule
+from torch.quantization.quantize_fx import prepare_fx, convert_fx
+import torch.fx.experimental.fx_acc.acc_tracer as acc_tracer
+import copy
+from torch.fx.passes import shape_prop
+from torch.fx.experimental.normalize import NormalizeArgs
+
+rn18 = models.resnet18().eval()
+
+def build_fp16_trt(rn18):
+    rn18 = copy.deepcopy(rn18)
+    rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)])  # type: ignore[attr-defined]
+    interp = TRTInterpreter(rn18, [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)])
+    engine, input_names, output_names = interp.run(fp16_mode=True)
+    return TRTModule(engine, input_names, output_names)
+
+@torch.no_grad()
+def build_int8_trt(rn18):
+    rn18 = copy.deepcopy(rn18)
+    data = torch.randn(1, 3, 224, 224)
+    # data = torch.randn(1, 64, 10, 10)
+    # TensorRT only supports symmetric quantization
+    qconfig = torch.quantization.QConfig(
+        activation=torch.quantization.observer.HistogramObserver.with_args(
+            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
+        ),
+        weight=torch.quantization.default_weight_observer
+    )
+    prepared = prepare_fx(rn18, {"": qconfig})
+    for _ in range(10):
+        prepared(data)
+    quantized_rn18 = convert_fx(prepared, is_reference=True)
+    print("quantized model:", quantized_rn18)
+
+    quantized_rn18 = acc_tracer.trace(quantized_rn18, [data])  # type: ignore[attr-defined]
+    interp = TRTInterpreter(quantized_rn18, [InputTensorSpec(data.shape[1:], torch.float, has_batch_dim=False)])
+    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True)
+    return TRTModule(engine, input_names, output_names)
+
+@torch.no_grad()
+def build_int8_trt_implicit_quant(rn18):
+    rn18 = copy.deepcopy(rn18)
+    data = torch.randn(1, 3, 224, 224)
+    # Quantization
+    qconfig = torch.quantization.QConfig(
+        activation=torch.quantization.observer.HistogramObserver.with_args(
+            qscheme=torch.per_tensor_symmetric, reduce_range=True
+        ),
+        weight=torch.quantization.default_per_channel_weight_observer
+    )
+    prepared = prepare_fx(rn18, {"": qconfig})
+    for _ in range(10):
+        prepared(data)
+    quantized_rn18 = convert_fx(prepared, is_reference=True)
+
+    # Build trt int8 model
+    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
+    shape_prop.ShapeProp(traced_rn18).propagate(data)
+    traced_rn18 = NormalizeArgs(traced_rn18).transform()
+    interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]))
+    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
+    trt_mod = TRTModule(engine, input_names, output_names)
+    return trt_mod
+
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
+
+    def forward(self, x):
+        out = self.conv(x)
+        # out = torch.nn.functional.relu(out)
+        out += x
+        out += out
+        out = torch.nn.functional.relu(out)
+        return out
+
+# rn18 = M().eval()
+# rn18 = rn18.layer1
+int8_trt = build_int8_trt(rn18)
+implicit_int8_trt = build_int8_trt_implicit_quant(rn18)
+fp16_trt = build_fp16_trt(rn18)
+x = torch.randn(5, 3, 224, 224, device="cuda")
+rn18 = rn18.cuda()
+
+import time
+NITER = 100
+
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    fp16_trt(x)
+    torch.cuda.synchronize()
+print('trt fp16 time (ms/iter)', (time.time() - s) / NITER * 1000)
+
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    int8_trt(x)
+    torch.cuda.synchronize()
+print('trt int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
+
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    implicit_int8_trt(x)
+    torch.cuda.synchronize()
+print('trt implicit int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
+
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    rn18(x)
+    torch.cuda.synchronize()
+print('PyTorch time (ms/iter)', (time.time() - s) / NITER * 1000)