if q_zero_point != 0:
raise RuntimeError(f"Only zero_point == 0 is supported, got {q_zero_point}")
+ # temporarily set q_scale to 1 so that the quantize and dequantize layers
+ # end up with different scales, which works around the error
+ # TODO: follow up with the NVIDIA TensorRT team to repro and fix the problem
+ q_scale = 1
scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32)))
scale_layer.name = input_val.name + ".quant.scale"
scale = scale_layer.get_output(0)
- # assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
- # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
+ assert trt.__version__ > "8.0", (
+     "Explicit quantize op is only supported in "
+     "TensorRT 8.0 or above, current TensorRT version: " + trt.__version__
+ )
layer = network.add_quantize(input=input_val, scale=scale)
layer.axis = 0
layer.name = input_val.name + ".quant"
@tensorrt_converter(acc_ops.dequantize)
def acc_ops_dequantize(network, target, args, kwargs, name):
+ """
+ Currently just a no-op.
+ """
input_val = kwargs["input"]
if not isinstance(input_val, trt.tensorrt.ITensor):
scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([q_scale], dtype=np.float32)))
scale_layer.name = input_val.name + ".dequant.scale"
scale = scale_layer.get_output(0)
- # assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in "
- # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
+ assert trt.__version__ > "8.0", (
+     "Explicit dequantize op is only supported in "
+     "TensorRT 8.0 or above, current TensorRT version: " + trt.__version__
+ )
layer = network.add_dequantize(input=input_val, scale=scale)
layer.name = input_val.name + ".dequant"
layer.axis = 0
+++ /dev/null
-import torch.fx
-import torchvision.models as models
-from torch.fx.experimental.fx2trt.fx2trt import TRTInterpreter, InputTensorSpec, TRTModule
-from torch.quantization.quantize_fx import prepare_fx, convert_fx
-import torch.fx.experimental.fx_acc.acc_tracer as acc_tracer
-import copy
-from torch.fx.passes import shape_prop
-from torch.fx.experimental.normalize import NormalizeArgs
-
-rn18 = models.resnet18().eval()
-
-def build_fp16_trt(rn18):
- rn18 = copy.deepcopy(rn18)
- rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)])
- interp = TRTInterpreter(rn18, [InputTensorSpec([3, 224, 224], torch.float, has_batch_dim=False)])
- engine, input_names, output_names = interp.run(fp16_mode=True)
- return TRTModule(engine, input_names, output_names)
-
-@torch.no_grad()
-def build_int8_trt(rn18):
- rn18 = copy.deepcopy(rn18)
- data = torch.randn(1, 3, 224, 224)
- # data = torch.randn(1, 64, 10, 10)
- # TensorRT only supports symmetric quantization
- qconfig = torch.quantization.QConfig(
- activation=torch.quantization.observer.HistogramObserver.with_args(
- qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
- ),
- weight=torch.quantization.default_weight_observer
- )
- prepared = prepare_fx(rn18, {"": qconfig})
- for _ in range(10):
- prepared(data)
- quantized_rn18 = convert_fx(prepared, is_reference=True)
- print("quantized model:", quantized_rn18)
-
- quantized_rn18 = acc_tracer.trace(quantized_rn18, [data])
- interp = TRTInterpreter(quantized_rn18, [InputTensorSpec(data.shape[1:], torch.float, has_batch_dim=False)])
- engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True)
- return TRTModule(engine, input_names, output_names)
-
-@torch.no_grad()
-def build_int8_trt_implicit_quant(rn18):
- rn18 = copy.deepcopy(rn18)
- data = torch.randn(1, 3, 224, 224)
- # Quantization
- qconfig = torch.quantization.QConfig(
- activation=torch.quantization.observer.HistogramObserver.with_args(
- qscheme=torch.per_tensor_symmetric, reduce_range=True
- ),
- weight=torch.quantization.default_per_channel_weight_observer
- )
- prepared = prepare_fx(rn18, {"": qconfig})
- for _ in range(10):
- prepared(data)
- quantized_rn18 = convert_fx(prepared, is_reference=True)
-
- # Build trt int8 model
- traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
- shape_prop.ShapeProp(traced_rn18).propagate(data)
- traced_rn18 = NormalizeArgs(traced_rn18).transform()
- interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]))
- engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
- trt_mod = TRTModule(engine, input_names, output_names)
- return trt_mod
-
-class M(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
-
- def forward(self, x):
- out = self.conv(x)
- # out = torch.nn.functional.relu(out)
- out += x
- out += out
- out = torch.nn.functional.relu(out)
- return out
-
-# rn18 = M().eval()
-# rn18 = rn18.layer1
-int8_trt = build_int8_trt(rn18)
-implicit_int8_trt = build_int8_trt_implicit_quant(rn18)
-fp16_trt = build_fp16_trt(rn18)
-x = torch.randn(5, 3, 224, 224, device="cuda")
-rn18 = rn18.cuda()
-
-import time
-NITER = 100
-
-torch.cuda.synchronize()
-s = time.time()
-for _ in range(NITER):
- fp16_trt(x)
- torch.cuda.synchronize()
-print('trt fp16 time (ms/iter)', (time.time() - s) / NITER * 1000)
-
-torch.cuda.synchronize()
-s = time.time()
-for _ in range(NITER):
- int8_trt(x)
- torch.cuda.synchronize()
-print('trt int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
-
-torch.cuda.synchronize()
-s = time.time()
-for _ in range(NITER):
- implicit_int8_trt(x)
- torch.cuda.synchronize()
-print('trt implicit int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
-
-torch.cuda.synchronize()
-s = time.time()
-for _ in range(NITER):
- rn18(x)
- torch.cuda.synchronize()
-print('PyTorch time (ms/iter)', (time.time() - s) / NITER * 1000)