From 670853295addc9250d5829a26bd1065ff9bff159 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 14 Sep 2021 15:26:03 -0700
Subject: [PATCH] [quant][tensorrt] Add tensorrt backend config (#64623)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64623

The config api will change, but we'll add configs gradually for TensorRT to
unblock experimentation

Test Plan:
python torch/fx/experimental/fx2trt/example/unittests.py

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D30800474

fbshipit-source-id: 3c4640de1205a0f19b62943ab84f386d80394ec2
---
 torch/fx/experimental/fx2trt/example/unittests.py | 103 +++++++
 torch/quantization/fx/__init__.py                 |   2 +
 .../fx/backend_config_dict/__init__.py            |   1 +
 .../fx/backend_config_dict/tensorrt.py            |  15 +
 torch/quantization/fx/quantization_patterns.py    | 307 ++++++++++++++++++++-
 torch/quantization/quantize_fx.py                 |  19 +-
 6 files changed, 443 insertions(+), 4 deletions(-)
 create mode 100644 torch/fx/experimental/fx2trt/example/unittests.py
 create mode 100644 torch/quantization/fx/backend_config_dict/tensorrt.py

diff --git a/torch/fx/experimental/fx2trt/example/unittests.py b/torch/fx/experimental/fx2trt/example/unittests.py
new file mode 100644
index 0000000..d625675
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/example/unittests.py
@@ -0,0 +1,103 @@
+import torch
+from torch.quantization.quantize_fx import (
+    prepare_fx,
+    convert_fx,
+    get_tensorrt_backend_config_dict
+)
+import torch.fx.experimental.fx_acc.acc_tracer as acc_tracer
+from torch.fx.experimental.fx2trt.fx2trt import TRTInterpreter, InputTensorSpec, TRTModule
+from torch.testing._internal.common_quantization import QuantizationTestCase
+from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_quantization import NodeSpec as ns
+
+import unittest
+
+def lower_to_trt(model, sample_input, shape_ranges):
+    model = acc_tracer.trace(model, [sample_input])  # type: ignore[attr-defined]
+    interp = TRTInterpreter(
+        model,
+        [InputTensorSpec(
+            torch.Size([-1, *sample_input.shape[1:]]), torch.float,
+            shape_ranges=shape_ranges, has_batch_dim=True)],
+        explicit_batch_dimension=True, explicit_precision=True)
+    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True)
+    trt_mod = TRTModule(engine, input_names, output_names)
+    return trt_mod
+
+
+@unittest.skipIf(not TEST_CUDA, "gpu is not available.")
+class TestQuantizeFxTRT(QuantizationTestCase):
+    def test_conv(self):
+        class Conv2d(torch.nn.Module):
+            def __init__(self, *args):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(*args)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        conv2d_input = torch.rand(1, 3, 224, 224)
+        conv2d_module_args = (3, 3, 3)
+
+        m = Conv2d(*conv2d_module_args).eval()
+        qconfig = torch.quantization.QConfig(
+            activation=torch.quantization.observer.HistogramObserver.with_args(
+                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
+            ),
+            weight=torch.quantization.default_weight_observer
+        )
+        prepared = prepare_fx(m, {"": qconfig}, backend_config_dict=get_tensorrt_backend_config_dict())
+        # calibration
+        prepared(conv2d_input)
+        quantized = convert_fx(prepared, is_reference=True)
+        node_occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 1,
+            ns.call_method("dequantize"): 1
+        }
+        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
+        # lower to trt
+        trt_mod = lower_to_trt(quantized, conv2d_input, [((1, 3, 224, 224), (5, 3, 224, 224),
+                                                          (10, 3, 224, 224))])
+        # make sure it runs
+        trt_mod(conv2d_input.cuda())
+
+    def test_linear(self):
+        class LinearModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 10)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        linear_module_input = torch.rand(8, 5)
+
+        m = LinearModule().eval()
+        qconfig = torch.quantization.QConfig(
+            activation=torch.quantization.observer.HistogramObserver.with_args(
+                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
+            ),
+            weight=torch.quantization.default_weight_observer
+        )
+        prepared = prepare_fx(m, {"": qconfig}, backend_config_dict=get_tensorrt_backend_config_dict())
+        # calibration
+        prepared(linear_module_input)
+        quantized = convert_fx(prepared, is_reference=True)
+        node_occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 1,
+            ns.call_method("dequantize"): 1
+        }
+        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
+        # lower to trt
+        trt_mod = lower_to_trt(
+            quantized,
+            linear_module_input,
+            [((1, *linear_module_input.shape[1:]),
+              (5, *linear_module_input.shape[1:]),
+              (10, *linear_module_input.shape[1:]))])
+        # make sure it runs
+        trt_mod(linear_module_input.cuda())
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/torch/quantization/fx/__init__.py b/torch/quantization/fx/__init__.py
index 8ba1b6d..01ee7e8 100644
--- a/torch/quantization/fx/__init__.py
+++ b/torch/quantization/fx/__init__.py
@@ -1,3 +1,5 @@
 from .prepare import prepare
 from .convert import convert
 from .fuse import Fuser
+from .backend_config_dict import get_fbgemm_backend_config_dict
+from .backend_config_dict import get_tensorrt_backend_config_dict
diff --git a/torch/quantization/fx/backend_config_dict/__init__.py b/torch/quantization/fx/backend_config_dict/__init__.py
index edb2b95..c50e61a 100644
--- a/torch/quantization/fx/backend_config_dict/__init__.py
+++ b/torch/quantization/fx/backend_config_dict/__init__.py
@@ -1,4 +1,5 @@
 from .fbgemm import get_fbgemm_backend_config_dict
+from .tensorrt import get_tensorrt_backend_config_dict
 
 def validate_backend_config_dict(backend_config_dict):
     return "quant_patterns" in backend_config_dict
diff --git a/torch/quantization/fx/backend_config_dict/tensorrt.py b/torch/quantization/fx/backend_config_dict/tensorrt.py
new file mode 100644
index 0000000..932f491
--- /dev/null
+++ b/torch/quantization/fx/backend_config_dict/tensorrt.py
@@ -0,0 +1,15 @@
+import torch
+from ..quantization_patterns import ConvReLUQuantizeHandlerNew, LinearReLUQuantizeHandlerNew
+
+def get_tensorrt_backend_config_dict():
+    """ Get the backend config dictionary for tensorrt backend
+    NOTE: Current api will change in the future, it's just to unblock experimentation for
+    new backends, please don't use it right now.
+ """ + quant_patterns = { + torch.nn.Conv2d: ConvReLUQuantizeHandlerNew, + torch.nn.Linear: LinearReLUQuantizeHandlerNew + } + return { + "quant_patterns": quant_patterns + } diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 7ae07bf..df4c21a 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -90,7 +90,6 @@ class QuantizeHandler(ABC): return maybe_obs return None - def input_output_observed(self) -> bool: """ Returns True if the pattern matched to this qhandler could be @@ -1690,3 +1689,309 @@ class StandaloneModuleQuantizeHandler(QuantizeHandler): setattr(modules[parent_name], name, quantized_standalone_module) modules[str(node.target)] = quantized_standalone_module return quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs)) + + +class ConvReLUQuantizeHandlerNew(QuantizeHandler): + """ This is to unblock perf testing for TensorRT, this will be + changed in the future so don't depend on this. + """ + def __init__(self, node: Node, modules: Dict[str, torch.nn.Module]): + super().__init__(node, modules) + self.relu_node = None + if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \ + (node.op == 'call_module' and isinstance(modules[str(node.target)], torch.nn.ReLU)): + self.relu_node = node + node = node.args[0] # type: ignore[assignment] + self.conv_node = node + if node.op == "call_module": + self.conv = modules[str(self.conv_node.target)] + elif node.op == "call_function": + self.conv = node.target # type: ignore[assignment] + + def should_insert_observer_for_output( + self, + qconfig: Any, + model_is_training: bool, + ) -> bool: + return False + + def is_output_quantized(self, qconfig, is_reference): + return False + + def convert(self, + node: Node, + qconfig: QConfigAny, + modules: Dict[str, torch.nn.Module], + quantized_graph: Graph, + node_name_to_scope: Dict[str, Tuple[str, type]], + load_arg: Callable, + is_reference: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + assert is_reference, "ConvReLUQuantizeHandlerNew only works for the case when is_reference=True" + activation_int8_quantized = activation_is_int8_quantized(qconfig) + if self.conv_node.op == 'call_module': + output_activation_post_process = \ + self._maybe_get_last_node_only_observer(modules) + # note that relu should already be fused into conv module in the fusion step + assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \ + 'please make sure to run fusion before prepare' + # produce dequant - float_op - quant pattern + dtype = torch.float + if activation_int8_quantized: + dtype = activation_dtype(qconfig) + activation = load_arg(quantized=dtype)(self.conv_node.args[0]) + args = load_arg(quantized=torch.float)(self.conv_node.args) + # Get the float conv and attach quantization scheme and quantization + # parameters of weight to the module + # and qparam is a dictionary of + # {"qscheme": ..., "scale": ..., "zero_point": ...} for per tensor quantization or + # {"qscheme": ..., "scale": ..., "zero_point": ..., "axis": ...} for per channel quantization + float_conv = self.conv + fused_conv = None + if isinstance( + float_conv, + QAT_CONV_MODULE_CLASSES): + # case 1. 
converting qat conv module to + # a float conv module, we need to attch + # weight fake_quant to the conv module, + # weight fake_quant is assumed to be run during + # QAT so we don't need to run it again here + float_conv = self.conv.to_float() # type: ignore[operator] + # change qat conv to conv + parent_name, name = _parent_name(self.conv_node.target) + setattr(modules[parent_name], name, float_conv) + if isinstance(float_conv, torch.nn.intrinsic._FusedModule): + fused_conv = float_conv + float_conv = float_conv[0] + weight_post_process = self.conv.weight_fake_quant + else: + # case 2. converting a conv module/fused conv module + # to float conv module, we need to attach + # weight observer to the conv module and run it + # with conv weight + if isinstance(float_conv, torch.nn.intrinsic._FusedModule): + fused_conv = float_conv + float_conv = float_conv[0] # type: ignore[index] + assert qconfig is not None + weight_post_process = qconfig.weight() + # run weight observer + weight_post_process(float_conv.weight) # type: ignore[operator] + weight_qparams = get_qparam_dict(weight_post_process) + # hardcoded for now, TODO: expose the api to user, + # we can have a map from module to reference module + # and allow user to register new ones + qconv_cls = get_static_quant_module_class( + type(float_conv), is_reference=is_reference) + ref_conv = qconv_cls.from_float(float_conv, weight_qparams) # type: ignore[attr-defined] + # if the parent is a fused conv (Sequential), we can replace the first + # item to ref conv, otherwise we can update + # the conv instance in the module tree + if fused_conv is not None: + fused_conv[0] = ref_conv + else: + parent_name, name = _parent_name(self.conv_node.target) + setattr(modules[parent_name], name, ref_conv) + op_out = quantized_graph.create_node( + 'call_module', + self.conv_node.target, + args, {}) + # disabling quantize node for output for now, this will be controlled by the + # backend_config_dict in the final design + if output_activation_post_process: + op_out = quantize_node( + op_out, + output_activation_post_process, + node, + modules, + quantized_graph, + node_name_to_scope, + is_input=False) + return op_out + else: + assert self.conv_node.op == "call_function" + # make sure the input and weight are quantized to torch.quint8, torch.qint8, respectively + load_arg(quantized={0: torch.quint8, 1: torch.qint8})(self.conv_node.args) + args = load_arg(quantized=torch.float)(self.conv_node.args) + kwargs = load_arg(quantized=torch.float)(self.conv_node.kwargs) + op_out = quantized_graph.create_node( + "call_function", self.conv, args, kwargs) + if self.relu_node: + relu_args = [op_out] + relu_args.extend(load_arg(quantized=torch.float)(self.relu_node.args[1:])) + relu_kwargs = load_arg(quantized=torch.float)(self.relu_node.kwargs) + op_out = quantized_graph.create_node( + "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs) + + # disabling quantize node for output for now, this will be controlled by the + # backend_config_dict in the final design + # if activation_int8_quantized: + # root_module = modules[''] + # act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name + # act_post_process_node = self.relu_node if self.relu_node else self.conv_node + # activation_post_process = \ + # self._maybe_get_last_node_only_observer(modules) + # assert activation_post_process is not None + # return quantize_node( + # op_out, + # activation_post_process, + # act_post_process_node, + # modules, + # quantized_graph, + # 
node_name_to_scope, + # is_input=False) + # else: + # # output for dynamically quantized conv op is not quantized + # return op_out + return op_out + +class LinearReLUQuantizeHandlerNew(QuantizeHandler): + """ This is to unblock perf testing for TensorRT, this will be + changed in the future so don't depend on this. + """ + def __init__( + self, + node: Node, + modules: Dict[str, torch.nn.Module]): + super().__init__(node, modules) + self.relu_node = None + if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \ + (node.op == 'call_module' and isinstance(modules[str(node.target)], torch.nn.ReLU)): + self.relu_node = node + node = node.args[0] # type: ignore[assignment] + self.linear_node = node + if node.op == 'call_module': + self.linear = modules[str(self.linear_node.target)] + + def should_insert_observer_for_output( + self, + qconfig: Any, + model_is_training: bool, + ) -> bool: + return False + + def is_output_quantized(self, qconfig, is_reference): + return False + + def convert(self, + node: Node, + qconfig: QConfigAny, + modules: Dict[str, torch.nn.Module], + quantized_graph: Graph, + node_name_to_scope: Dict[str, Tuple[str, type]], + load_arg: Callable, + is_reference: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + assert is_reference, "LinearReLUQuantizeHandlerNew only works for the case when is_reference=True" + if convert_custom_config_dict is None: + convert_custom_config_dict = {} + dtypes = get_qconfig_dtypes(qconfig) + activation_int8_quantized = activation_is_int8_quantized(qconfig) + activation_statically_quantized = activation_is_statically_quantized(qconfig) + weight_dtype = dtypes[1] + if self.linear_node.op == 'call_module': + + output_activation_post_process = \ + self._maybe_get_last_node_only_observer(modules) + + # note that relu should already be fused into linear modul in the fusion step + assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \ + 'please make sure to run fusion before prepare' + # produce dequant - float_op - quant pattern + dtype = torch.float + if activation_int8_quantized: + dtype = activation_dtype(qconfig) + activation = load_arg(quantized=dtype)(self.linear_node.args[0]) + args = load_arg(quantized=torch.float)(self.linear_node.args) + # Get the float linear and attach qscheme and qparams + # the the module + float_linear = self.linear + fused_linear = None + if isinstance(float_linear, (torch.nn.qat.Linear, torch.nn.intrinsic.qat.LinearReLU)): + float_linear = float_linear.to_float() + # change qat linear to linear + parent_name, name = _parent_name(self.linear_node.target) + setattr(modules[parent_name], name, float_linear) + # Attach weight fake quant to the linear module + if isinstance(float_linear, torch.nn.intrinsic.LinearReLU): + fused_linear = float_linear + float_linear = float_linear[0] + weight_post_process = self.linear.weight_fake_quant + else: + if isinstance(float_linear, torch.nn.intrinsic.LinearReLU): + fused_linear = float_linear + float_linear = self.linear[0] # type: ignore[index] + # Attach the weight observer to the module + weight_post_process = qconfig.weight() # type: ignore[union-attr] + # Run weight observer + weight_post_process(float_linear.weight) # type: ignore[operator] + + weight_qparams = get_qparam_dict(weight_post_process) + # TODO: include the configuration in backend_config_dict + # we can have a map from module to reference module + # and allow user to register new ones + qlinear_cls = get_static_quant_module_class( + 
+                type(float_linear), is_reference=is_reference)
+            ref_linear = qlinear_cls.from_float(float_linear, weight_qparams)
+
+            # if the parent is a fused linear (Sequential), we can replace the first
+            # item with ref linear, otherwise we can update
+            # the linear instance in the module tree
+            if fused_linear is not None:
+                fused_linear[0] = ref_linear
+            else:
+                parent_name, name = _parent_name(self.linear_node.target)
+                setattr(modules[parent_name], name, ref_linear)
+            op_out = quantized_graph.create_node(
+                'call_module',
+                self.linear_node.target,
+                args, {})
+            if output_activation_post_process:
+                op_out = quantize_node(
+                    op_out,
+                    output_activation_post_process,
+                    node,
+                    modules,
+                    quantized_graph,
+                    node_name_to_scope,
+                    is_input=False)
+            return op_out
+        else:  # call_function
+            assert self.linear_node.op == 'call_function'
+            quantized_input_dtypes = [torch.float, torch.float]
+            if activation_int8_quantized:
+                quantized_input_dtypes[0] = torch.quint8
+            if weight_is_statically_quantized(qconfig):
+                quantized_input_dtypes[1] = torch.qint8
+            args = load_arg(quantized=quantized_input_dtypes)(self.linear_node.args)
+            args = load_arg(quantized=torch.float)(self.linear_node.args)
+            kwargs = load_arg(quantized=torch.float)(self.linear_node.kwargs)
+            op_out = quantized_graph.create_node(
+                "call_function", torch.nn.functional.linear, args, kwargs)
+            if self.relu_node:
+                relu_args = [op_out]
+                relu_args.extend(load_arg(quantized=torch.float)(self.relu_node.args[1:]))
+                relu_kwargs = load_arg(quantized=torch.float)(self.relu_node.kwargs)
+                op_out = quantized_graph.create_node(
+                    "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
+
+            return op_out
+            # TODO: enable later
+            # if activation_statically_quantized:
+            #     # quantize output for statically quantized linear op
+            #     root_module = modules['']
+            #     act_post_process_name = self.relu_node.name if self.relu_node else self.linear_node.name
+            #     act_post_process_node = self.relu_node if self.relu_node else self.linear_node
+            #     activation_post_process = \
+            #         self._maybe_get_last_node_only_observer(modules)
+            #     assert activation_post_process is not None
+            #     return quantize_node(
+            #         op_out,
+            #         activation_post_process,
+            #         act_post_process_node,
+            #         modules,
+            #         quantized_graph,
+            #         node_name_to_scope,
+            #         is_input=False)
+            # else:
+            #     # output for dynamically quantized linear op is not quantized
+            #     return op_out
diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py
index 2dd98ea..c6cb824 100644
--- a/torch/quantization/quantize_fx.py
+++ b/torch/quantization/quantize_fx.py
@@ -4,6 +4,8 @@ from torch.fx._symbolic_trace import Tracer
 from torch.fx.node import Target, Node, Argument
 from .fx import Fuser  # noqa: F401
 from .fx import prepare, convert  # noqa: F401
+from .fx import get_fbgemm_backend_config_dict  # noqa: F401
+from .fx import get_tensorrt_backend_config_dict  # noqa: F401
 from .fx.utils import graph_pretty_str  # noqa: F401
 from .fx.utils import get_custom_module_class_keys  # noqa: F401
 from .fx.graph_module import ObservedGraphModule, QuantizedGraphModule
@@ -139,7 +141,8 @@ class QuantizationTracer(Tracer):
         self.node_name_to_scope[node.name] = (self.scope.module_path, self.scope.module_type)
         return node
 
-def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
+def _prepare_fx(model: torch.nn.Module,
+                qconfig_dict: Any,
                 prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
                 equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
                 backend_config_dict: Optional[Dict[str, Any]] = None,
@@ -195,6 +198,7 @@ forward graph of the parent module,
         tracer.node_name_to_scope,
         prepare_custom_config_dict=prepare_custom_config_dict,
         equalization_qconfig_dict=equalization_qconfig_dict,
+        backend_config_dict=backend_config_dict,
         is_standalone_module=is_standalone_module)
 
     for attr_name in preserved_attributes:
@@ -428,7 +432,12 @@ def prepare_fx(
     torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
     assert not model.training, 'prepare_fx only works for models in ' + \
         'eval mode'
-    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, equalization_qconfig_dict, backend_config_dict)
+    return _prepare_fx(
+        model,
+        qconfig_dict,
+        prepare_custom_config_dict,
+        equalization_qconfig_dict,
+        backend_config_dict)
 
 def prepare_qat_fx(
     model: torch.nn.Module, qconfig_dict: Any,
@@ -467,7 +476,11 @@ def prepare_qat_fx(
     torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
     assert model.training, 'prepare_qat_fx only works for models in ' + \
         'train mode'
-    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, backend_config_dict)
+    return _prepare_fx(
+        model,
+        qconfig_dict,
+        prepare_custom_config_dict,
+        backend_config_dict=backend_config_dict)
 
 def _convert_fx(
     graph_module: GraphModule, is_reference: bool,
--
2.7.4
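
Usage sketch (not part of the diff above): the snippet below mirrors the flow exercised by the new unittests, using only API calls that appear in this patch; the tiny Sequential model and the calibration tensor are illustrative placeholders, and the resulting reference model would then go through a TensorRT lowering step such as the lower_to_trt helper defined in unittests.py.

    import torch
    from torch.quantization.quantize_fx import (
        prepare_fx,
        convert_fx,
        get_tensorrt_backend_config_dict
    )

    # placeholder float model and calibration data, chosen for illustration
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3)).eval()
    example_input = torch.rand(1, 3, 224, 224)

    # symmetric int8 activations + default weight observer, as in the tests
    qconfig = torch.quantization.QConfig(
        activation=torch.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
        weight=torch.quantization.default_weight_observer)

    # prepare with the TensorRT backend config, run calibration data through
    # the observers, then convert to a reference quantized model that the
    # fx2trt flow can consume
    prepared = prepare_fx(model, {"": qconfig},
                          backend_config_dict=get_tensorrt_backend_config_dict())
    prepared(example_input)
    quantized = convert_fx(prepared, is_reference=True)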