From 877e6f2be3e78258247fb969577cb86be392e90c Mon Sep 17 00:00:00 2001 From: Charles David Hernandez Date: Wed, 18 Aug 2021 13:30:35 -0700 Subject: [PATCH] Bugfix for fuse qconfig comparison (#63384) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63384 In some cases the changes to qconfig on module would cause the fusions to fail. This bugfix solves that problem by adding a qconfig_function_equality check that compares the functions within the qconfig rather than the modules the qconfigs are on. The comparison looks at the partial object within QConfig.activation/weight.p and compares args, keywords and func. This is necessary to do manually because partial doesn't have __eq__ implemented and so == reverts to is. Test Plan: python test/test_quantization.py TestFuseFx.test_problematic_fuse_example Imported from OSS Reviewed By: supriyar, ejguan Differential Revision: D30386264 fbshipit-source-id: 51e358c021c39d6f48dc12ad2a82b2838677b9de --- test/quantization/fx/test_quantize_fx.py | 32 ++++++++++++++++++++++++++++++++ torch/quantization/fx/prepare.py | 4 ++-- torch/quantization/qconfig.py | 17 +++++++++++++++++ 3 files changed, 51 insertions(+), 2 deletions(-) diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 2f5f7c4..bf15a06 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -314,6 +314,38 @@ class TestFuseFx(QuantizationTestCase): self.checkGraphModuleNodes(quantized, expected_node_list=node_list) + def test_problematic_fuse_example(self): + class LinearRelu(nn.Sequential): + def __init__(self): + super().__init__( + nn.Linear(5, 5), + nn.ReLU(), + ) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin_relu = LinearRelu() + self.linear = nn.Linear(5, 5) + + def forward(self, x): + x = self.lin_relu(x) + x = self.linear(x) + return x + + model = M().eval() + # these qconfigs somehow fail equality where 
default_qconfig does not + qconfig_dict = { + "": None, + "object_type": [ + (torch.nn.Linear, get_default_qconfig('fbgemm')), + (torch.nn.ReLU, get_default_qconfig('fbgemm')), + ], + } + m = prepare_fx(model, qconfig_dict) + + self.checkGraphModuleNodes(m, expected_node=ns.call_module(torch.nn.intrinsic.modules.fused.LinearReLU)) + def test_fuse_custom_config_dict_validity(self): r""" Verifies that if a user passes an invalid key or makes a typo when diff --git a/torch/quantization/fx/prepare.py b/torch/quantization/fx/prepare.py index 873d11a..23d1d40 100644 --- a/torch/quantization/fx/prepare.py +++ b/torch/quantization/fx/prepare.py @@ -15,7 +15,7 @@ from torch.fx.graph import ( ) from torch.fx.node import Argument -from ..qconfig import QConfigAny +from ..qconfig import QConfigAny, qconfig_function_equality from .qconfig_utils import ( convert_dict_to_ordered_dict, generate_qconfig_map, @@ -195,7 +195,7 @@ def update_qconfig_for_fusion( # Raise an error if the modules in the fused module have # different qconfigs specified in the qconfig_dict for op in ops: - if object_type_dict.get(op, None) != fused_qconfig: + if not qconfig_function_equality(object_type_dict.get(op, None), fused_qconfig): raise LookupError("During fusion, we need to specify the same " + f"qconfigs for both modules in {module_type}.") diff --git a/torch/quantization/qconfig.py b/torch/quantization/qconfig.py index 15eb174..01d67dd 100644 --- a/torch/quantization/qconfig.py +++ b/torch/quantization/qconfig.py @@ -209,3 +209,20 @@ def add_module_to_qconfig_obs_ctr( return QConfig(activation, weight) else: return QConfigDynamic(activation, weight) + + +def qconfig_function_equality(q1: QConfigAny, q2: QConfigAny): + # functools.partial has no __eq__ operator defined so '==' defaults to 'is' + def compare_partial(p1, p2): + same = p1.func == p2.func + same = same and p1.args == p2.args + return same and p1.keywords == p2.keywords + + if q1 is None or q2 is None: + return q1 == q2 + else: + 
assert q1 is not None and q2 is not None + try: + return compare_partial(q1.activation.p, q2.activation.p) and compare_partial(q1.weight.p, q2.weight.p) + except AttributeError: + return q1 == q2 -- 2.7.4