From 01b8162d00bfb0844a3f8a165d49907e51a16add Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 25 Aug 2021 17:50:48 -0700
Subject: [PATCH] Back out "Revert D30384746: [fx2trt] Add a test for quantized resnet18" (#63973)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63973

Original commit changeset: b93235323e22

Test Plan: buck run mode/opt -c python.package_style=inplace caffe2:fx2trt_quantized_resnet_test

Reviewed By: 842974287

Differential Revision: D30546036

fbshipit-source-id: 2c8302456f072d04da00cf9ad97aa8304bc5e43e
---
 .../fx2trt/converters/acc_ops_converters.py |  15 +--
 .../fx2trt/example/quantized_resnet_test.py | 117 +++++++++++++++++++++
 2 files changed, 121 insertions(+), 11 deletions(-)
 create mode 100644 torch/fx/experimental/fx2trt/example/quantized_resnet_test.py

diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
index 566359b..33a817d 100644
--- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
+++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
@@ -1300,15 +1300,11 @@ def acc_ops_quantize_per_tensor(network, target, args, kwargs, name):
 
     if q_zero_point != 0:
         raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")
-    # temporarily set q_scale to 1 to make sure the q_scale is different
-    # for quantize and dequantize to avoid the error
-    # TODO: follow up with nvidia TensorRT team to repro and fix the problem
-    q_scale = 1
     scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32)))
     scale_layer.name = input_val.name + ".quant.scale"
     scale = scale_layer.get_output(0)
-    assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
-    "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
+    # assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
+    # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
     layer = network.add_quantize(input=input_val, scale=scale)
     layer.axis = 0
     layer.name = input_val.name + ".quant"
@@ -1316,9 +1312,6 @@ def acc_ops_dequantize(network, target, args, kwargs, name):
-    """
-    Currently just a no-op.
-    """
     input_val = kwargs["input"]
 
     if not isinstance(input_val, trt.tensorrt.ITensor):
@@ -1339,8 +1332,8 @@ def acc_ops_dequantize(network, target, args, kwargs, name):
     scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([q_scale], dtype=np.float32)))
     scale_layer.name = input_val.name + ".dequant.scale"
     scale = scale_layer.get_output(0)
-    assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in "
-    "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
+    # assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in "
+    # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
     layer = network.add_dequantize(input=input_val, scale=scale)
     layer.name = input_val.name + ".dequant"
     layer.axis = 0
diff --git a/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py b/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py
new file mode 100644
index 0000000..140f4fb
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py
@@ -0,0 +1,117 @@
+import torch.fx
+import torchvision.models as models
+from torch.fx.experimental.fx2trt.fx2trt import TRTInterpreter, InputTensorSpec, TRTModule
+from torch.quantization.quantize_fx import prepare_fx, convert_fx
+import torch.fx.experimental.fx_acc.acc_tracer as acc_tracer
+import copy
+from torch.fx.passes import shape_prop
+from torch.fx.experimental.normalize import NormalizeArgs
+
+rn18 = models.resnet18().eval()
+
+def build_fp16_trt(rn18):
+    rn18 = copy.deepcopy(rn18)
+    rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)])  # type: ignore[attr-defined]
+    interp = TRTInterpreter(rn18, [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)])
+    engine, input_names, output_names = interp.run(fp16_mode=True)
+    return TRTModule(engine, input_names, output_names)
+
+@torch.no_grad()
+def build_int8_trt(rn18):
+    rn18 = copy.deepcopy(rn18)
+    data = torch.randn(1, 3, 224, 224)
+    # data = torch.randn(1, 64, 10, 10)
+    # TensorRT only supports symmetric quantization
+    qconfig = torch.quantization.QConfig(
+        activation=torch.quantization.observer.HistogramObserver.with_args(
+            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
+        ),
+        weight=torch.quantization.default_weight_observer
+    )
+    prepared = prepare_fx(rn18, {"": qconfig})
+    for _ in range(10):
+        prepared(data)
+    quantized_rn18 = convert_fx(prepared, is_reference=True)
+    print("quantized model:", quantized_rn18)
+
+    quantized_rn18 = acc_tracer.trace(quantized_rn18, [data])  # type: ignore[attr-defined]
+    interp = TRTInterpreter(quantized_rn18, [InputTensorSpec(data.shape[1:], torch.float, has_batch_dim=False)])
+    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True)
+    return TRTModule(engine, input_names, output_names)
+
+@torch.no_grad()
+def build_int8_trt_implicit_quant(rn18):
+    rn18 = copy.deepcopy(rn18)
+    data = torch.randn(1, 3, 224, 224)
+    # Quantization
+    qconfig = torch.quantization.QConfig(
+        activation=torch.quantization.observer.HistogramObserver.with_args(
+            qscheme=torch.per_tensor_symmetric, reduce_range=True
+        ),
+        weight=torch.quantization.default_per_channel_weight_observer
+    )
+    prepared = prepare_fx(rn18, {"": qconfig})
+    for _ in range(10):
+        prepared(data)
+    quantized_rn18 = convert_fx(prepared, is_reference=True)
+
+    # Build trt int8 model
+    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
+    shape_prop.ShapeProp(traced_rn18).propagate(data)
+    traced_rn18 = NormalizeArgs(traced_rn18).transform()
+    interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]))
+    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
+    trt_mod = TRTModule(engine, input_names, output_names)
+    return trt_mod
+
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
+
+    def forward(self, x):
+        out = self.conv(x)
+        # out = torch.nn.functional.relu(out)
+        out += x
+        out += out
+        out = torch.nn.functional.relu(out)
+        return out
+
+# rn18 = M().eval()
+# rn18 = rn18.layer1
+int8_trt = build_int8_trt(rn18)
+implicit_int8_trt = build_int8_trt_implicit_quant(rn18)
+fp16_trt = build_fp16_trt(rn18)
+x = torch.randn(5, 3, 224, 224, device="cuda")
+rn18 = rn18.cuda()
+
+import time
+NITER = 100
+
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    fp16_trt(x)
+    torch.cuda.synchronize()
+print('trt fp16 time (ms/iter)', (time.time() - s) / NITER * 1000)
+
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    int8_trt(x)
+    torch.cuda.synchronize()
+print('trt int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
+
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    implicit_int8_trt(x)
+    torch.cuda.synchronize()
+print('trt implicit int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
+
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    rn18(x)
+    torch.cuda.synchronize()
+print('PyTorch time (ms/iter)', (time.time() - s) / NITER * 1000)
-- 
2.7.4