From 294db0603fef315c8f6ac95e30f8ce6b5cce2b5a Mon Sep 17 00:00:00 2001
From: Supriya Rao
Date: Thu, 26 Aug 2021 21:05:56 -0700
Subject: [PATCH] [quant] Add support for linear_relu fusion for FP16 dynamic quant (#63826)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63826

Support the conversion of the intrinsic LinearReLU module to the quantized dynamic LinearReLU module.
Verify that the support works for both linear module and functional linear fusion.

Test Plan:
python test/test_quantization.py test_dynamic_with_fusion

Imported from OSS

Reviewed By: iramazanli

Differential Revision: D30503513

fbshipit-source-id: 70446797e9670dfef7341cba2047183d6f88b70f
---
 test/quantization/fx/test_quantize_fx.py       | 33 ++++++++++++++--------
 .../quantized/dynamic/modules/linear_relu.py   |  7 ++---
 torch/quantization/fx/quantization_patterns.py |  9 +++---
 3 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index cdf2e7b..762919e 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -2922,18 +2922,24 @@ class TestQuantizeFx(QuantizationTestCase):
                 return x
 
         model = M().eval()
-        qconfig = {
-            "": default_dynamic_qconfig,
+
+        dynamic_quantized_ops = {
+            float16_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic_fp16,
+            default_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic
         }
-        m = prepare_fx(model, qconfig)
-        m = convert_fx(m)
-        m(torch.rand(5, 5))
-        node_list = [
-            ns.call_module(nniqd.LinearReLU),
-            ns.call_module(nniqd.LinearReLU),
-            ns.call_function(torch.ops.quantized.linear_relu_dynamic),
-        ]
-        self.checkGraphModuleNodes(m, expected_node_list=node_list)
+        for config in [float16_dynamic_qconfig, default_dynamic_qconfig]:
+            qconfig = {
+                "": config
+            }
+            m = prepare_fx(model, qconfig)
+            m = convert_fx(m)
+            m(torch.rand(5, 5))
+            node_list = [
+                ns.call_module(nniqd.LinearReLU),
+                ns.call_module(nniqd.LinearReLU),
+                ns.call_function(dynamic_quantized_ops[config]),
+            ]
+            self.checkGraphModuleNodes(m, expected_node_list=node_list)
 
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
@@ -3089,7 +3095,10 @@ class TestQuantizeFxOps(QuantizationTestCase):
         if is_reference:
             qlinear_fun = ns.call_function(torch.nn.functional.linear)
         else:
-            qlinear_fun = ns.call_function(torch.ops.quantized.linear_dynamic_fp16)
+            if has_relu:
+                qlinear_fun = ns.call_function(torch.ops.quantized.linear_relu_dynamic_fp16)
+            else:
+                qlinear_fun = ns.call_function(torch.ops.quantized.linear_dynamic_fp16)
         prepare_node_occurrence = {
             # weight
             ns.call_module(torch.quantization.PlaceholderObserver): 1
diff --git a/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py
index 04c4c95..c30b3109 100644
--- a/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py
+++ b/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py
@@ -31,10 +31,9 @@ class LinearReLU(nnqd.Linear):
             # TODO check if we should set reduce_rage = True by default here
             Y = torch.ops.quantized.linear_relu_dynamic(
                 x, self._packed_params._packed_params, reduce_range=True)
-        # TODO Support this in a later PR
-        # elif self._packed_params.dtype == torch.float16:
-        #     Y = torch.ops.quantized.linear_relu_dynamic_fp16(
-        #         x, self._packed_params._packed_params)
+        elif self._packed_params.dtype == torch.float16:
+            Y = torch.ops.quantized.linear_relu_dynamic_fp16(
+                x, self._packed_params._packed_params)
         else:
             raise RuntimeError('Unsupported dtype on dynamic quantized linear relu!')
         return Y.to(x.dtype)
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index b7c39ca..6362961 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -1027,8 +1027,11 @@ class LinearReLUQuantizeHandler(QuantizeHandler):
                     qlinear_op = torch.ops.quantized.linear_relu_dynamic
                 else:
                     qlinear_op = torch.ops.quantized.linear_dynamic
-            else:  # TODO add support for fp16 + relu fusion in a later PR
-                qlinear_op = torch.ops.quantized.linear_dynamic_fp16
+            else:
+                if self.relu_node:
+                    qlinear_op = torch.ops.quantized.linear_relu_dynamic_fp16
+                else:
+                    qlinear_op = torch.ops.quantized.linear_dynamic_fp16
             linear_input = load_arg(quantized=torch.float)(self.linear_node.args[0])
             qlinear_args = (linear_input, packed_weight)  # type: ignore[assignment]
@@ -1038,8 +1041,6 @@ class LinearReLUQuantizeHandler(QuantizeHandler):
             # TODO: may need to change the key to Node regenerate the map in each transformation,
             # since we might not be able to rely on the name
             node_name_to_scope[op_out.name] = node_name_to_scope[self.linear_node.name]
-            if self.relu_node and weight_dtype is not torch.qint8:
-                op_out = quantized_graph.create_node("call_function", torch.nn.functional.relu, (op_out,), {})
             return op_out
         else:
             assert dtypes == (torch.float16, torch.float16, None)
--
2.7.4
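
Usage sketch (not part of the patch): the snippet below shows how the FP16 dynamic linear_relu fusion added here would be exercised through FX graph mode quantization, assuming the prepare_fx/convert_fx signature of this PyTorch version (model plus qconfig dict). The toy module M, its layer sizes, and the variable names are illustrative assumptions; float16_dynamic_qconfig, prepare_fx, and convert_fx are the same APIs used in the test above.

import torch
import torch.nn as nn
from torch.quantization import float16_dynamic_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# Illustrative module: Linear followed by ReLU, eligible for linear_relu fusion.
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 5)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

model = M().eval()
# The empty-key qconfig dict applies the FP16 dynamic qconfig to the whole
# model, mirroring the test added in this patch.
qconfig_dict = {"": float16_dynamic_qconfig}
prepared = prepare_fx(model, qconfig_dict)
quantized = convert_fx(prepared)
# With this patch, the converted graph is expected to call
# torch.ops.quantized.linear_relu_dynamic_fp16 instead of an unfused
# linear_dynamic_fp16 followed by relu.
out = quantized(torch.rand(5, 5))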