    if q_zero_point != 0:
        raise RuntimeError(f"Only zero_point == 0 is supported, got {q_zero_point}")
-    # temporarily set q_scale to 1 to make sure the q_scale is different
-    # for quantize and dequantize to avoid the error
-    # TODO: follow up with nvidia TensorRT team to repro and fix the problem
-    q_scale = 1
    scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32)))
    scale_layer.name = input_val.name + ".quant.scale"
    scale = scale_layer.get_output(0)
-    assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
-    "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
+    # assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
+    # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
    layer = network.add_quantize(input=input_val, scale=scale)
    layer.axis = 0
    layer.name = input_val.name + ".quant"
    return layer.get_output(0)

@tensorrt_converter(acc_ops.dequantize)
def acc_ops_dequantize(network, target, args, kwargs, name):
- """
- Currently just a no-op.
- """
input_val = kwargs["input"]
if not isinstance(input_val, trt.tensorrt.ITensor):
scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([q_scale], dtype=np.float32)))
scale_layer.name = input_val.name + ".dequant.scale"
scale = scale_layer.get_output(0)
-    assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in "
-    "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
+    # assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in "
+    # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
    layer = network.add_dequantize(input=input_val, scale=scale)
    layer.name = input_val.name + ".dequant"
    layer.axis = 0
    return layer.get_output(0)
--- /dev/null
+import torch.fx
+import torchvision.models as models
+from torch.fx.experimental.fx2trt.fx2trt import TRTInterpreter, InputTensorSpec, TRTModule
+from torch.quantization.quantize_fx import prepare_fx, convert_fx
+import torch.fx.experimental.fx_acc.acc_tracer as acc_tracer
+import copy
+from torch.fx.passes import shape_prop
+from torch.fx.experimental.normalize import NormalizeArgs
+
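+# Compare TensorRT fp16, explicit int8, and implicit int8 engines for
+# ResNet-18 against eager-mode PyTorch. The int8 paths exercise the
+# acc_ops.quantize/dequantize converters patched above.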
+rn18 = models.resnet18().eval()
+
+def build_fp16_trt(rn18):
+    rn18 = copy.deepcopy(rn18)
+    rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)])  # type: ignore[attr-defined]
+    interp = TRTInterpreter(rn18, [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)])
+    engine, input_names, output_names = interp.run(fp16_mode=True)
+    return TRTModule(engine, input_names, output_names)
+
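+# Explicit quantization: FX graph mode post-training quantization with
+# is_reference=True keeps quantize/dequantize ops in the graph, which the
+# converters above lower to TensorRT IQuantizeLayer/IDequantizeLayer via
+# network.add_quantize/add_dequantize.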
+@torch.no_grad()
+def build_int8_trt(rn18):
+    rn18 = copy.deepcopy(rn18)
+    data = torch.randn(1, 3, 224, 224)
+    # data = torch.randn(1, 64, 10, 10)
+    # TensorRT only supports symmetric quantization
+    qconfig = torch.quantization.QConfig(
+        activation=torch.quantization.observer.HistogramObserver.with_args(
+            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
+        ),
+        weight=torch.quantization.default_weight_observer
+    )
+    prepared = prepare_fx(rn18, {"": qconfig})
+    for _ in range(10):
+        prepared(data)
+    quantized_rn18 = convert_fx(prepared, is_reference=True)
+    print("quantized model:", quantized_rn18)
+
+    quantized_rn18 = acc_tracer.trace(quantized_rn18, [data])  # type: ignore[attr-defined]
+    interp = TRTInterpreter(quantized_rn18, [InputTensorSpec(data.shape[1:], torch.float, has_batch_dim=False)])
+    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True)
+    return TRTModule(engine, input_names, output_names)
+
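+# Implicit quantization: trace the reference quantized model with plain
+# torch.fx and let TensorRT select int8 kernels itself (int8_mode=True);
+# strict_type_constraints asks TensorRT to honor the requested layer precisions.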
+@torch.no_grad()
+def build_int8_trt_implicit_quant(rn18):
+    rn18 = copy.deepcopy(rn18)
+    data = torch.randn(1, 3, 224, 224)
+    # Quantization
+    qconfig = torch.quantization.QConfig(
+        activation=torch.quantization.observer.HistogramObserver.with_args(
+            qscheme=torch.per_tensor_symmetric, reduce_range=True
+        ),
+        weight=torch.quantization.default_per_channel_weight_observer
+    )
+    prepared = prepare_fx(rn18, {"": qconfig})
+    for _ in range(10):
+        prepared(data)
+    quantized_rn18 = convert_fx(prepared, is_reference=True)
+
+    # Build trt int8 model
+    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
+    shape_prop.ShapeProp(traced_rn18).propagate(data)
+    traced_rn18 = NormalizeArgs(traced_rn18).transform()
+    interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]))
+    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
+    trt_mod = TRTModule(engine, input_names, output_names)
+    return trt_mod
+
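+# Toy module for debugging the converters in place of the full ResNet-18
+# (swap it in via the commented-out lines below).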
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
+
+    def forward(self, x):
+        out = self.conv(x)
+        # out = torch.nn.functional.relu(out)
+        out += x
+        out += out
+        out = torch.nn.functional.relu(out)
+        return out
+
+# rn18 = M().eval()
+# rn18 = rn18.layer1
+int8_trt = build_int8_trt(rn18)
+implicit_int8_trt = build_int8_trt_implicit_quant(rn18)
+fp16_trt = build_fp16_trt(rn18)
+x = torch.randn(5, 3, 224, 224, device="cuda")
+rn18 = rn18.cuda()
+
+import time
+NITER = 100
+
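+# Wall-clock benchmark: synchronize inside each loop so the GPU work from one
+# iteration finishes before the next, then report average ms per iteration.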
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    fp16_trt(x)
+    torch.cuda.synchronize()
+print('trt fp16 time (ms/iter)', (time.time() - s) / NITER * 1000)
+
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    int8_trt(x)
+    torch.cuda.synchronize()
+print('trt int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
+
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    implicit_int8_trt(x)
+    torch.cuda.synchronize()
+print('trt implicit int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
+
+torch.cuda.synchronize()
+s = time.time()
+for _ in range(NITER):
+    rn18(x)
+    torch.cuda.synchronize()
+print('PyTorch time (ms/iter)', (time.time() - s) / NITER * 1000)