From 839eaa2e91556ecd4532596b4fef18a1c3f6e1c1 Mon Sep 17 00:00:00 2001
From: Linbin Yu
Date: Wed, 25 Aug 2021 00:42:03 -0700
Subject: [PATCH] Revert D30384746: [fx2trt] Add a test for quantized resnet18

Test Plan: revert-hammer

Differential Revision:
D30384746 (https://github.com/pytorch/pytorch/commit/10dfa58eba055a1bbc1cc89df033cd2815cbb403)

Original commit changeset: 1a8638777116

fbshipit-source-id: b93235323e229b391f5456f6e3543988062dd0d4
---
 .../fx2trt/converters/acc_ops_converters.py   |  15 ++-
 .../fx2trt/example/quantized_resnet_test.py   | 117 ------------------
 2 files changed, 11 insertions(+), 121 deletions(-)
 delete mode 100644 torch/fx/experimental/fx2trt/example/quantized_resnet_test.py

diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
index 33a817d4cc..566359bf2a 100644
--- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
+++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
@@ -1300,11 +1300,15 @@ def acc_ops_quantize_per_tensor(network, target, args, kwargs, name):
     if q_zero_point != 0:
         raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")

+    # temporarily set q_scale to 1 to make sure the q_scale is different
+    # for quantize and dequantize to avoid the error
+    # TODO: follow up with nvidia TensorRT team to repro and fix the problem
+    q_scale = 1
     scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32)))
     scale_layer.name = input_val.name + ".quant.scale"
     scale = scale_layer.get_output(0)
-    # assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
-    # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
+    assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
+    "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
     layer = network.add_quantize(input=input_val, scale=scale)
     layer.axis = 0
     layer.name = input_val.name + ".quant"
@@ -1312,6 +1316,9 @@ def acc_ops_dequantize(network, target, args, kwargs, name):

 @tensorrt_converter(acc_ops.dequantize)
 def acc_ops_dequantize(network, target, args, kwargs, name):
+    """
+    Currently just a no-op.
+ """ input_val = kwargs["input"] if not isinstance(input_val, trt.tensorrt.ITensor): @@ -1332,8 +1339,8 @@ def acc_ops_dequantize(network, target, args, kwargs, name): scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([q_scale], dtype=np.float32))) scale_layer.name = input_val.name + ".dequant.scale" scale = scale_layer.get_output(0) - # assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in " - # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__ + assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in " + "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__ layer = network.add_dequantize(input=input_val, scale=scale) layer.name = input_val.name + ".dequant" layer.axis = 0 diff --git a/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py b/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py deleted file mode 100644 index 39553dfd9d..0000000000 --- a/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py +++ /dev/null @@ -1,117 +0,0 @@ -import torch.fx -import torchvision.models as models -from torch.fx.experimental.fx2trt.fx2trt import TRTInterpreter, InputTensorSpec, TRTModule -from torch.quantization.quantize_fx import prepare_fx, convert_fx -import torch.fx.experimental.fx_acc.acc_tracer as acc_tracer -import copy -from torch.fx.passes import shape_prop -from torch.fx.experimental.normalize import NormalizeArgs - -rn18 = models.resnet18().eval() - -def build_fp16_trt(rn18): - rn18 = copy.deepcopy(rn18) - rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)]) - interp = TRTInterpreter(rn18, [InputTensorSpec([3, 224, 224], torch.float, has_batch_dim=False)]) - engine, input_names, output_names = interp.run(fp16_mode=True) - return TRTModule(engine, input_names, output_names) - -@torch.no_grad() -def build_int8_trt(rn18): - rn18 = copy.deepcopy(rn18) - data = torch.randn(1, 3, 224, 224) - # data = torch.randn(1, 64, 10, 10) - # TensorRT only supports symmetric quantization - qconfig = torch.quantization.QConfig( - activation=torch.quantization.observer.HistogramObserver.with_args( - qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 - ), - weight=torch.quantization.default_weight_observer - ) - prepared = prepare_fx(rn18, {"": qconfig}) - for _ in range(10): - prepared(data) - quantized_rn18 = convert_fx(prepared, is_reference=True) - print("quantized model:", quantized_rn18) - - quantized_rn18 = acc_tracer.trace(quantized_rn18, [data]) - interp = TRTInterpreter(quantized_rn18, [InputTensorSpec(data.shape[1:], torch.float, has_batch_dim=False)]) - engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True) - return TRTModule(engine, input_names, output_names) - -@torch.no_grad() -def build_int8_trt_implicit_quant(rn18): - rn18 = copy.deepcopy(rn18) - data = torch.randn(1, 3, 224, 224) - # Quantization - qconfig = torch.quantization.QConfig( - activation=torch.quantization.observer.HistogramObserver.with_args( - qscheme=torch.per_tensor_symmetric, reduce_range=True - ), - weight=torch.quantization.default_per_channel_weight_observer - ) - prepared = prepare_fx(rn18, {"": qconfig}) - for _ in range(10): - prepared(data) - quantized_rn18 = convert_fx(prepared, is_reference=True) - - # Build trt int8 model - traced_rn18 = torch.fx.symbolic_trace(quantized_rn18) - shape_prop.ShapeProp(traced_rn18).propagate(data) - traced_rn18 = NormalizeArgs(traced_rn18).transform() - interp = TRTInterpreter(traced_rn18, 
-    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
-    trt_mod = TRTModule(engine, input_names, output_names)
-    return trt_mod
-
-class M(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
-
-    def forward(self, x):
-        out = self.conv(x)
-        # out = torch.nn.functional.relu(out)
-        out += x
-        out += out
-        out = torch.nn.functional.relu(out)
-        return out
-
-# rn18 = M().eval()
-# rn18 = rn18.layer1
-int8_trt = build_int8_trt(rn18)
-implicit_int8_trt = build_int8_trt_implicit_quant(rn18)
-fp16_trt = build_fp16_trt(rn18)
-x = torch.randn(5, 3, 224, 224, device="cuda")
-rn18 = rn18.cuda()
-
-import time
-NITER = 100
-
-torch.cuda.synchronize()
-s = time.time()
-for _ in range(NITER):
-    fp16_trt(x)
-    torch.cuda.synchronize()
-print('trt fp16 time (ms/iter)', (time.time() - s) / NITER * 1000)
-
-torch.cuda.synchronize()
-s = time.time()
-for _ in range(NITER):
-    int8_trt(x)
-    torch.cuda.synchronize()
-print('trt int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
-
-torch.cuda.synchronize()
-s = time.time()
-for _ in range(NITER):
-    implicit_int8_trt(x)
-    torch.cuda.synchronize()
-print('trt implicit int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
-
-torch.cuda.synchronize()
-s = time.time()
-for _ in range(NITER):
-    rn18(x)
-    torch.cuda.synchronize()
-print('PyTorch time (ms/iter)', (time.time() - s) / NITER * 1000)
-- 
2.34.1