--- /dev/null
+import torch
+from torch.quantization.quantize_fx import (
+ prepare_fx,
+ convert_fx,
+ get_tensorrt_backend_config_dict
+)
+import torch.fx.experimental.fx_acc.acc_tracer as acc_tracer
+from torch.fx.experimental.fx2trt.fx2trt import TRTInterpreter, InputTensorSpec, TRTModule
+from torch.testing._internal.common_quantization import QuantizationTestCase
+from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_quantization import NodeSpec as ns
+
+import unittest
+
+def lower_to_trt(model, sample_input, shape_ranges):
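+    """Lower a (reference) quantized model to TensorRT: trace it with
+    acc_tracer, build an int8 TensorRT engine with TRTInterpreter using a
+    dynamic batch dimension described by shape_ranges, and wrap the engine
+    in a TRTModule that can be called like a regular nn.Module.
+    """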
+ model = acc_tracer.trace(model, [sample_input]) # type: ignore[attr-defined]
+ interp = TRTInterpreter(
+ model,
+ [InputTensorSpec(
+ torch.Size([-1, *sample_input.shape[1:]]), torch.float,
+ shape_ranges=shape_ranges, has_batch_dim=True)],
+ explicit_batch_dimension=True, explicit_precision=True)
+ engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True)
+ trt_mod = TRTModule(engine, input_names, output_names)
+ return trt_mod
+
+
+@unittest.skipIf(not TEST_CUDA, "gpu is not available.")
+class TestQuantizeFxTRT(QuantizationTestCase):
+ def test_conv(self):
+ class Conv2d(torch.nn.Module):
+ def __init__(self, *args):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(*args)
+
+ def forward(self, x):
+ return self.conv(x)
+
+ conv2d_input = torch.rand(1, 3, 224, 224)
+ conv2d_module_args = (3, 3, 3)
+
+ m = Conv2d(*conv2d_module_args).eval()
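+        # TensorRT int8 kernels expect symmetric quantization, so use a
+        # symmetric per-tensor qint8 observer for activations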
+ qconfig = torch.quantization.QConfig(
+ activation=torch.quantization.observer.HistogramObserver.with_args(
+ qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
+ ),
+ weight=torch.quantization.default_weight_observer
+ )
+ prepared = prepare_fx(m, {"": qconfig}, backend_config_dict=get_tensorrt_backend_config_dict())
+ # calibration
+ prepared(conv2d_input)
+ quantized = convert_fx(prepared, is_reference=True)
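+        # the reference quantized graph should contain exactly one
+        # quantize_per_tensor / dequantize pair around the conv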
+ node_occurrence = {
+ ns.call_function(torch.quantize_per_tensor): 1,
+ ns.call_method("dequantize"): 1
+ }
+ self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
+ # lower to trt
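+        # each shape_ranges entry is a (min, opt, max) shape tuple for the dynamic batch dimension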
+ trt_mod = lower_to_trt(quantized, conv2d_input, [((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))])
+ # make sure it runs
+ trt_mod(conv2d_input.cuda())
+
+ def test_linear(self):
+ class LinearModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(5, 10)
+
+ def forward(self, x):
+ return self.linear(x)
+
+ linear_module_input = torch.rand(8, 5)
+
+ m = LinearModule().eval()
+ qconfig = torch.quantization.QConfig(
+ activation=torch.quantization.observer.HistogramObserver.with_args(
+ qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
+ ),
+ weight=torch.quantization.default_weight_observer
+ )
+ prepared = prepare_fx(m, {"": qconfig}, backend_config_dict=get_tensorrt_backend_config_dict())
+ # calibration
+ prepared(linear_module_input)
+ quantized = convert_fx(prepared, is_reference=True)
+ node_occurrence = {
+ ns.call_function(torch.quantize_per_tensor): 1,
+ ns.call_method("dequantize"): 1
+ }
+ self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
+ # lower to trt
+ trt_mod = lower_to_trt(
+ quantized,
+ linear_module_input,
+ [((1, *linear_module_input.shape[1:]),
+ (5, *linear_module_input.shape[1:]),
+ (10, *linear_module_input.shape[1:]))])
+ # make sure it runs
+ trt_mod(linear_module_input.cuda())
+
+if __name__ == '__main__':
+ run_tests()
return maybe_obs
return None
-
def input_output_observed(self) -> bool:
"""
Returns True if the pattern matched to this qhandler could be
setattr(modules[parent_name], name, quantized_standalone_module)
modules[str(node.target)] = quantized_standalone_module
return quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs))
+
+
+class ConvReLUQuantizeHandlerNew(QuantizeHandler):
+    """ This is to unblock perf testing for TensorRT; it will be
+    changed in the future, so don't depend on it.
+ """
+ def __init__(self, node: Node, modules: Dict[str, torch.nn.Module]):
+ super().__init__(node, modules)
+ self.relu_node = None
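+        # this handler is constructed on the last node of the matched pattern,
+        # so if that node is a relu, step back to its input to find the conv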
+ if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
+ (node.op == 'call_module' and isinstance(modules[str(node.target)], torch.nn.ReLU)):
+ self.relu_node = node
+ node = node.args[0] # type: ignore[assignment]
+ self.conv_node = node
+ if node.op == "call_module":
+ self.conv = modules[str(self.conv_node.target)]
+ elif node.op == "call_function":
+ self.conv = node.target # type: ignore[assignment]
+
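+    # output observation/quantization is disabled for now; in the final design
+    # this will be controlled by backend_config_dict (see the comments in convert below)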
+ def should_insert_observer_for_output(
+ self,
+ qconfig: Any,
+ model_is_training: bool,
+ ) -> bool:
+ return False
+
+ def is_output_quantized(self, qconfig, is_reference):
+ return False
+
+ def convert(self,
+ node: Node,
+ qconfig: QConfigAny,
+ modules: Dict[str, torch.nn.Module],
+ quantized_graph: Graph,
+ node_name_to_scope: Dict[str, Tuple[str, type]],
+ load_arg: Callable,
+ is_reference: bool = False,
+ convert_custom_config_dict: Dict[str, Any] = None) -> Node:
+ assert is_reference, "ConvReLUQuantizeHandlerNew only works for the case when is_reference=True"
+ activation_int8_quantized = activation_is_int8_quantized(qconfig)
+ if self.conv_node.op == 'call_module':
+ output_activation_post_process = \
+ self._maybe_get_last_node_only_observer(modules)
+            # note that relu should already be fused into the conv module in the fusion step
+            assert self.relu_node is None, 'conv module and relu fusion was not executed, ' \
+                'please make sure to run fusion before prepare'
+ # produce dequant - float_op - quant pattern
+ dtype = torch.float
+ if activation_int8_quantized:
+ dtype = activation_dtype(qconfig)
+ activation = load_arg(quantized=dtype)(self.conv_node.args[0])
+ args = load_arg(quantized=torch.float)(self.conv_node.args)
+            # Get the float conv and attach the quantization scheme and quantization
+            # parameters of the weight to the module.
+            # The weight qparams are a dictionary of
+            # {"qscheme": ..., "scale": ..., "zero_point": ...} for per tensor quantization or
+            # {"qscheme": ..., "scale": ..., "zero_point": ..., "axis": ...} for per channel quantization
+ float_conv = self.conv
+ fused_conv = None
+ if isinstance(
+ float_conv,
+ QAT_CONV_MODULE_CLASSES):
+                # case 1. converting a qat conv module to
+                # a float conv module: we need to attach
+                # the weight fake_quant to the conv module;
+                # the weight fake_quant is assumed to have been run during
+                # QAT, so we don't need to run it again here
+ float_conv = self.conv.to_float() # type: ignore[operator]
+ # change qat conv to conv
+ parent_name, name = _parent_name(self.conv_node.target)
+ setattr(modules[parent_name], name, float_conv)
+ if isinstance(float_conv, torch.nn.intrinsic._FusedModule):
+ fused_conv = float_conv
+ float_conv = float_conv[0]
+ weight_post_process = self.conv.weight_fake_quant
+ else:
+                # case 2. converting a conv module/fused conv module
+                # to a float conv module: we need to attach
+                # a weight observer to the conv module and run it
+                # on the conv weight
+ if isinstance(float_conv, torch.nn.intrinsic._FusedModule):
+ fused_conv = float_conv
+ float_conv = float_conv[0] # type: ignore[index]
+ assert qconfig is not None
+ weight_post_process = qconfig.weight()
+ # run weight observer
+ weight_post_process(float_conv.weight) # type: ignore[operator]
+ weight_qparams = get_qparam_dict(weight_post_process)
+            # hardcoded for now, TODO: expose the api to the user,
+            # we can have a map from module to reference module
+            # and allow users to register new ones
+ qconv_cls = get_static_quant_module_class(
+ type(float_conv), is_reference=is_reference)
+ ref_conv = qconv_cls.from_float(float_conv, weight_qparams) # type: ignore[attr-defined]
+ # if the parent is a fused conv (Sequential), we can replace the first
+            # item with the ref conv, otherwise we can update
+ # the conv instance in the module tree
+ if fused_conv is not None:
+ fused_conv[0] = ref_conv
+ else:
+ parent_name, name = _parent_name(self.conv_node.target)
+ setattr(modules[parent_name], name, ref_conv)
+ op_out = quantized_graph.create_node(
+ 'call_module',
+ self.conv_node.target,
+ args, {})
+ # disabling quantize node for output for now, this will be controlled by the
+ # backend_config_dict in the final design
+ if output_activation_post_process:
+ op_out = quantize_node(
+ op_out,
+ output_activation_post_process,
+ node,
+ modules,
+ quantized_graph,
+ node_name_to_scope,
+ is_input=False)
+ return op_out
+ else:
+ assert self.conv_node.op == "call_function"
+ # make sure the input and weight are quantized to torch.quint8, torch.qint8, respectively
+ load_arg(quantized={0: torch.quint8, 1: torch.qint8})(self.conv_node.args)
+ args = load_arg(quantized=torch.float)(self.conv_node.args)
+ kwargs = load_arg(quantized=torch.float)(self.conv_node.kwargs)
+ op_out = quantized_graph.create_node(
+ "call_function", self.conv, args, kwargs)
+ if self.relu_node:
+ relu_args = [op_out]
+ relu_args.extend(load_arg(quantized=torch.float)(self.relu_node.args[1:]))
+ relu_kwargs = load_arg(quantized=torch.float)(self.relu_node.kwargs)
+ op_out = quantized_graph.create_node(
+ "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
+
+ # disabling quantize node for output for now, this will be controlled by the
+ # backend_config_dict in the final design
+ # if activation_int8_quantized:
+ # root_module = modules['']
+ # act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
+ # act_post_process_node = self.relu_node if self.relu_node else self.conv_node
+ # activation_post_process = \
+ # self._maybe_get_last_node_only_observer(modules)
+ # assert activation_post_process is not None
+ # return quantize_node(
+ # op_out,
+ # activation_post_process,
+ # act_post_process_node,
+ # modules,
+ # quantized_graph,
+ # node_name_to_scope,
+ # is_input=False)
+ # else:
+ # # output for dynamically quantized conv op is not quantized
+ # return op_out
+ return op_out
+
+class LinearReLUQuantizeHandlerNew(QuantizeHandler):
+    """ This is to unblock perf testing for TensorRT; it will be
+    changed in the future, so don't depend on it.
+ """
+ def __init__(
+ self,
+ node: Node,
+ modules: Dict[str, torch.nn.Module]):
+ super().__init__(node, modules)
+ self.relu_node = None
+ if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
+ (node.op == 'call_module' and isinstance(modules[str(node.target)], torch.nn.ReLU)):
+ self.relu_node = node
+ node = node.args[0] # type: ignore[assignment]
+ self.linear_node = node
+ if node.op == 'call_module':
+ self.linear = modules[str(self.linear_node.target)]
+
+ def should_insert_observer_for_output(
+ self,
+ qconfig: Any,
+ model_is_training: bool,
+ ) -> bool:
+ return False
+
+ def is_output_quantized(self, qconfig, is_reference):
+ return False
+
+ def convert(self,
+ node: Node,
+ qconfig: QConfigAny,
+ modules: Dict[str, torch.nn.Module],
+ quantized_graph: Graph,
+ node_name_to_scope: Dict[str, Tuple[str, type]],
+ load_arg: Callable,
+ is_reference: bool = False,
+ convert_custom_config_dict: Dict[str, Any] = None) -> Node:
+ assert is_reference, "LinearReLUQuantizeHandlerNew only works for the case when is_reference=True"
+ if convert_custom_config_dict is None:
+ convert_custom_config_dict = {}
+ dtypes = get_qconfig_dtypes(qconfig)
+ activation_int8_quantized = activation_is_int8_quantized(qconfig)
+ activation_statically_quantized = activation_is_statically_quantized(qconfig)
+ weight_dtype = dtypes[1]
+ if self.linear_node.op == 'call_module':
+
+ output_activation_post_process = \
+ self._maybe_get_last_node_only_observer(modules)
+
+            # note that relu should already be fused into the linear module in the fusion step
+            assert self.relu_node is None, 'linear module and relu fusion was not executed, ' \
+                'please make sure to run fusion before prepare'
+ # produce dequant - float_op - quant pattern
+ dtype = torch.float
+ if activation_int8_quantized:
+ dtype = activation_dtype(qconfig)
+ activation = load_arg(quantized=dtype)(self.linear_node.args[0])
+ args = load_arg(quantized=torch.float)(self.linear_node.args)
+            # Get the float linear and attach the qscheme and qparams
+            # to the module
+ float_linear = self.linear
+ fused_linear = None
+ if isinstance(float_linear, (torch.nn.qat.Linear, torch.nn.intrinsic.qat.LinearReLU)):
+ float_linear = float_linear.to_float()
+ # change qat linear to linear
+ parent_name, name = _parent_name(self.linear_node.target)
+ setattr(modules[parent_name], name, float_linear)
+ # Attach weight fake quant to the linear module
+ if isinstance(float_linear, torch.nn.intrinsic.LinearReLU):
+ fused_linear = float_linear
+ float_linear = float_linear[0]
+ weight_post_process = self.linear.weight_fake_quant
+ else:
+ if isinstance(float_linear, torch.nn.intrinsic.LinearReLU):
+ fused_linear = float_linear
+ float_linear = self.linear[0] # type: ignore[index]
+ # Attach the weight observer to the module
+ weight_post_process = qconfig.weight() # type: ignore[union-attr]
+ # Run weight observer
+ weight_post_process(float_linear.weight) # type: ignore[operator]
+
+ weight_qparams = get_qparam_dict(weight_post_process)
+ # TODO: include the configuration in backend_config_dict
+ # we can have a map from module to reference module
+            # and allow users to register new ones
+ qlinear_cls = get_static_quant_module_class(
+ type(float_linear), is_reference=is_reference)
+ ref_linear = qlinear_cls.from_float(float_linear, weight_qparams)
+
+ # if the parent is a fused linear (Sequential), we can replace the first
+            # item with the ref linear, otherwise we can update
+ # the linear instance in the module tree
+ if fused_linear is not None:
+ fused_linear[0] = ref_linear
+ else:
+ parent_name, name = _parent_name(self.linear_node.target)
+ setattr(modules[parent_name], name, ref_linear)
+ op_out = quantized_graph.create_node(
+ 'call_module',
+ self.linear_node.target,
+ args, {})
+ if output_activation_post_process:
+ op_out = quantize_node(
+ op_out,
+ output_activation_post_process,
+ node,
+ modules,
+ quantized_graph,
+ node_name_to_scope,
+ is_input=False)
+ return op_out
+ else: # call_function
+ assert self.linear_node.op == 'call_function'
+ quantized_input_dtypes = [torch.float, torch.float]
+ if activation_int8_quantized:
+ quantized_input_dtypes[0] = torch.quint8
+ if weight_is_statically_quantized(qconfig):
+ quantized_input_dtypes[1] = torch.qint8
+            # make sure the input and weight are loaded with the right quantized dtypes
+            load_arg(quantized=quantized_input_dtypes)(self.linear_node.args)
+ args = load_arg(quantized=torch.float)(self.linear_node.args)
+ kwargs = load_arg(quantized=torch.float)(self.linear_node.kwargs)
+ op_out = quantized_graph.create_node(
+ "call_function", torch.nn.functional.linear, args, kwargs)
+ if self.relu_node:
+ relu_args = [op_out]
+ relu_args.extend(load_arg(quantized=torch.float)(self.relu_node.args[1:]))
+ relu_kwargs = load_arg(quantized=torch.float)(self.relu_node.kwargs)
+ op_out = quantized_graph.create_node(
+ "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
+
+ return op_out
+ # TODO: enable later
+ # if activation_statically_quantized:
+ # # quantize output for statically quantized linear op
+ # root_module = modules['']
+ # act_post_process_name = self.relu_node.name if self.relu_node else self.linear_node.name
+ # act_post_process_node = self.relu_node if self.relu_node else self.linear_node
+ # activation_post_process = \
+ # self._maybe_get_last_node_only_observer(modules)
+ # assert activation_post_process is not None
+ # return quantize_node(
+ # op_out,
+ # activation_post_process,
+ # act_post_process_node,
+ # modules,
+ # quantized_graph,
+ # node_name_to_scope,
+ # is_input=False)
+ # else:
+ # # output for dynamically quantized linear op is not quantized
+ # return op_out