torch_dtype_from_trt,
get_dynamic_dims,
)
+from typing import Optional
-def to_numpy(tensor: torch.Tensor):
+def to_numpy(tensor: Optional[torch.Tensor]):
    """
    Convert a PyTorch Tensor to a Numpy Array.
    """
+    if tensor is None:
+        return tensor
+
+    assert isinstance(tensor, torch.Tensor), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
+
    if tensor.is_quantized:
        tensor = tensor.dequantize()
+
    return tensor.cpu().detach().contiguous().numpy()
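+# Quick illustrative examples of the contract above:
+#   to_numpy(None)                                # -> None
+#   to_numpy(torch.quantize_per_tensor(
+#       torch.randn(2, 2), 0.1, 0, torch.qint8))  # -> float32 ndarray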
if has_dynamic_shape(input_val.shape):
assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
- kernel = to_numpy(kwargs["weight"])
+    # For now we assume bias is a constant torch.Tensor or None; the TensorRT
+    # API doesn't currently accept an ITensor bias.
bias = to_numpy(kwargs["bias"])
- layer = network.add_convolution(
- input=input_val,
- num_output_maps=kernel.shape[0],
- kernel_shape=kernel.shape[2:],
- kernel=kernel,
- bias=bias,
- )
+    if network.has_explicit_precision:
+        weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
+        weight_shape = tuple(kwargs["weight"].shape)
+        # Build the layer with empty dummy weights and bind the real
+        # (possibly ITensor) weight via set_input below, since the kernel
+        # argument only accepts constant trt.Weights.
+        dummy_weight = trt.Weights()
+
+        layer = network.add_convolution(
+            input=input_val,
+            num_output_maps=weight_shape[0],
+            kernel_shape=weight_shape[2:],
+            kernel=dummy_weight,
+            bias=bias,
+        )
+
+        layer.set_input(1, weight)
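+        # NB: under explicit precision the weight typically comes out of
+        # quantize + dequantize nodes, i.e. it arrives as an ITensor rather
+        # than a constant, which is why it has to be bound through set_input.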
+ else:
+ weight = to_numpy(kwargs["weight"])
+ layer = network.add_convolution(
+ input=input_val,
+ num_output_maps=weight.shape[0],
+ kernel_shape=weight.shape[2:],
+ kernel=weight,
+ bias=bias,
+ )
layer.name = name
layer.stride = kwargs["stride"]
"dim for linear and it can't be the last dim."
)
- weight = kwargs["weight"]
-
- # For quantization, weight here would be a trt tensor because it goes through
- # quant + dequant. In this case, we need to use matmul + add because fully_connected
- # can't take non-constant weight.
# TODO: Need to benchmark the performance of lowering linear as fully_connected versus
# lowering as matmul + add. TensorRT documentation suggests always lowering it as
# matmul + add, but we found that in some cases this causes a performance regression
# compared with lowering to the fully_connected layer.
- if isinstance(weight, torch.Tensor):
- layer = network.add_shuffle(input_val)
- layer.reshape_dims = tuple(input_val.shape) + (1, 1)
- layer.name = f"{name}_pre_shuffle"
+ layer = network.add_shuffle(input_val)
+ layer.reshape_dims = tuple(input_val.shape) + (1, 1)
+ layer.name = f"{name}_pre_shuffle"
+ bias = to_numpy(kwargs["bias"])
+
+    if network.has_explicit_precision:
+        weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
+        weight_shape = tuple(kwargs["weight"].shape)
+        # Same dummy-weight trick as convolution: bind the real ITensor
+        # weight via set_input below.
+        dummy_weight = trt.Weights()
# add fully connected
layer = network.add_fully_connected(
input=layer.get_output(0),
- num_outputs=kwargs["weight"].shape[0],
- kernel=to_numpy(kwargs["weight"]),
- bias=to_numpy(kwargs["bias"]),
+            num_outputs=weight_shape[0],
+            kernel=dummy_weight,
+            bias=bias,
)
- layer.name = f"{name}_linear"
-
- # reshape back
- layer = network.add_shuffle(layer.get_output(0))
- layer.reshape_dims = tuple(input_val.shape[:-1]) + (kwargs["weight"].shape[0],)
- layer.name = f"{name}_post_shuffle"
-
- return layer.get_output(0)
+ layer.set_input(1, weight)
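+        # Input 1 of a fully connected layer is its kernel, so this swaps
+        # the dummy weights for the real (possibly quantized) ITensor weight.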
else:
- # add matrix multiply and add
- output = add_matrix_multiply_layer(network, input_val, weight, f"{name}_linear_mm", transpose_other=True)
- if kwargs["bias"] is not None:
- return add_binary_elementwise_layer(network, output, kwargs["bias"], trt.ElementWiseOperation.SUM, f"{name}_linear_add")
- else:
- return output
+ weight = to_numpy(kwargs["weight"])
+ layer = network.add_fully_connected(
+ input=layer.get_output(0),
+ num_outputs=weight.shape[0],
+ kernel=weight,
+ bias=bias,
+ )
+ layer.name = f"{name}_linear"
+ # reshape back
+ layer = network.add_shuffle(layer.get_output(0))
+ layer.reshape_dims = tuple(input_val.shape[:-1]) + (kwargs["weight"].shape[0],)
+ layer.name = f"{name}_post_shuffle"
+
+ return layer.get_output(0)
def add_clamp(network, input, val, op):
acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions
@tensorrt_converter(acc_ops.quantize_per_tensor)
def acc_ops_quantize_per_tensor(network, target, args, kwargs, name):
- input_val = kwargs["input"]
+ input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
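+    # get_trt_tensor is assumed to materialize a constant torch.Tensor input
+    # as a TensorRT constant layer output, so a plain tensor input also
+    # reaches the check below as an ITensor.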
+
if not isinstance(input_val, trt.tensorrt.ITensor):
raise RuntimeError(f"{name} received input {input_val} that is not part "
def build_fp16_trt(rn18):
rn18 = copy.deepcopy(rn18)
rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)]) # type: ignore[attr-defined]
- interp = TRTInterpreter(rn18, [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)])
+ interp = TRTInterpreter(
+ rn18, [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)])
engine, input_names, output_names = interp.run(fp16_mode=True)
return TRTModule(engine, input_names, output_names)
def build_int8_trt(rn18):
rn18 = copy.deepcopy(rn18)
data = torch.randn(1, 3, 224, 224)
+ # data = torch.randn(1, 32)
# data = torch.randn(1, 64, 10, 10)
# TensorRT only supports symmetric quantization
qconfig = torch.quantization.QConfig(
for _ in range(10):
prepared(data)
quantized_rn18 = convert_fx(prepared, is_reference=True)
+ ref_res = quantized_rn18(data)
print("quantized model:", quantized_rn18)
quantized_rn18 = acc_tracer.trace(quantized_rn18, [data]) # type: ignore[attr-defined]
- interp = TRTInterpreter(quantized_rn18, [InputTensorSpec(data.shape[1:], torch.float, has_batch_dim=False)])
+ interp = TRTInterpreter(
+ quantized_rn18,
+ [InputTensorSpec(torch.Size([-1, *data.shape[1:]]), torch.float,
+ shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))], has_batch_dim=True)],
+ explicit_batch_dimension=True, explicit_precision=True)
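+    # shape_ranges lists the (min, opt, max) shapes for the dynamic batch
+    # dim: the engine accepts batch sizes 1 through 10 and is tuned for 5.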
engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True)
- return TRTModule(engine, input_names, output_names)
+ trt_mod = TRTModule(engine, input_names, output_names)
+ trt_res = trt_mod(data.cuda())
+ print("result diff max", torch.max(ref_res - trt_res.cpu()))
+ return trt_mod
@torch.no_grad()
def build_int8_trt_implicit_quant(rn18):
for _ in range(10):
prepared(data)
quantized_rn18 = convert_fx(prepared, is_reference=True)
+ ref_res = quantized_rn18(data)
# Build trt int8 model
traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]))
engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
trt_mod = TRTModule(engine, input_names, output_names)
+ trt_res = trt_mod(data.cuda())
+ print("result equal?", torch.equal(ref_res, trt_res))
return trt_mod
class M(torch.nn.Module):
def __init__(self):
super().__init__()
- self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
+ self.linear = torch.nn.Linear(32, 46)
+ # self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
def forward(self, x):
- out = self.conv(x)
+ # out = self.conv(x)
+ out = self.linear(x)
+ # out += x
+ # out += out
# out = torch.nn.functional.relu(out)
- out += x
- out += out
- out = torch.nn.functional.relu(out)
return out
# rn18 = M().eval()
# rn18 = rn18.layer1
int8_trt = build_int8_trt(rn18)
-implicit_int8_trt = build_int8_trt_implicit_quant(rn18)
+# implicit_int8_trt = build_int8_trt_implicit_quant(rn18)
fp16_trt = build_fp16_trt(rn18)
x = torch.randn(5, 3, 224, 224, device="cuda")
+# x = torch.randn(1, 32, device="cuda")
rn18 = rn18.cuda()
import time
torch.cuda.synchronize()
print('trt int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
-torch.cuda.synchronize()
-s = time.time()
-for _ in range(NITER):
- implicit_int8_trt(x)
- torch.cuda.synchronize()
-print('trt implicit int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
+# torch.cuda.synchronize()
+# s = time.time()
+# for _ in range(NITER):
+# implicit_int8_trt(x)
+# torch.cuda.synchronize()
+# print('trt implicit int8 time (ms/iter)', (time.time() - s) / NITER * 1000)
torch.cuda.synchronize()
s = time.time()
module: torch.fx.GraphModule,
input_specs: List[InputTensorSpec],
explicit_batch_dimension: bool = False,
+ explicit_precision: bool = False,
logger_level=trt.Logger.WARNING,
):
super().__init__(module)
self.logger = trt.Logger(logger_level)
self.builder = trt.Builder(self.logger)
+ flag = 0
if explicit_batch_dimension:
EXPLICIT_BATCH = 1 << (int)(
trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH
)
- self.network = self.builder.create_network(EXPLICIT_BATCH)
- else:
- self.network = self.builder.create_network()
+ flag |= EXPLICIT_BATCH
+
+ if explicit_precision:
+ EXPLICIT_PRECISION = 1 << (int)(
+ trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION
+ )
+ flag |= EXPLICIT_PRECISION
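+        # The creation flags are OR'd bit positions; leaving flag at 0 keeps
+        # the default implicit-batch, implicit-precision network.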
+ self.network = self.builder.create_network(flag)
missing_ops = self.validate_conversion()
if missing_ops:
model
"""
# suppress mypy warning
- assert isinstance(self.weight, torch.Tensor)
assert isinstance(self.weight_scale, torch.Tensor)
assert isinstance(self.weight_zero_point, torch.Tensor)
assert isinstance(self.weight_axis, torch.Tensor)
model
"""
# suppress mypy warning
- assert isinstance(self.weight, torch.Tensor)
assert isinstance(self.weight_scale, torch.Tensor)
assert isinstance(self.weight_zero_point, torch.Tensor)
assert isinstance(self.weight_axis, torch.Tensor)
# tuple (activation_dtype, weight_dtype, compute_dtype)
# these are supported types for common binary ops like add/mul etc.
all_dtypes = [
+ (torch.qint8, torch.qint8, None),
(torch.quint8, torch.qint8, None),
(torch.float16, torch.float16, None),
]
(torch.float16, torch.float16, None)
]
int8_dtypes = [
+ (torch.qint8, torch.qint8, None),
(torch.quint8, torch.qint8, None),
]
binary_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {