[quant][fx] Ensure qconfig works for QAT with multiple modules (#63343)
author Supriya Rao <supriyar@fb.com>
Tue, 17 Aug 2021 18:39:16 +0000 (11:39 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 17 Aug 2021 18:40:51 +0000 (11:40 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63343

The previous implementation had a bug where we inserted new entries into an ordered dict while iterating over it, which raises a RuntimeError.
This fixes it by iterating over a shallow copy of the dict and applying the updates to the original.
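
A minimal sketch of the failure mode and the fix, with generic dicts standing in for the real qconfig structures:

    from collections import OrderedDict

    mappings = {"float_type": "qat_type"}   # stands in for the QAT module mappings
    qconfigs = OrderedDict([("float_type", "qconfig")])

    # Buggy pattern -- inserting keys into the dict being iterated:
    #     for k, v in qconfigs.items():
    #         if k in mappings:
    #             qconfigs[mappings[k]] = v   # RuntimeError: OrderedDict mutated during iteration

    # Fixed pattern -- iterate over a shallow copy, mutate the original:
    for k, v in qconfigs.copy().items():
        if k in mappings:
            qconfigs[mappings[k]] = v

    assert qconfigs == {"float_type": "qconfig", "qat_type": "qconfig"}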

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_qconfig_qat_module_type

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D30346116

fbshipit-source-id: 0e33dad1163e8bff3fd363bfd04de8f7114d7a3a

test/quantization/fx/test_quantize_fx.py
torch/quantization/fx/prepare.py

diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index bcebb6e..2f93f49 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -1076,25 +1076,22 @@ class TestQuantizeFx(QuantizationTestCase):
         self.checkGraphModuleNodes(m, expected_node_list=node_list)
 
     def test_qconfig_qat_module_type(self):
-        class Linear(torch.nn.Module):
+        class LinearRelu(nn.Sequential):
             def __init__(self):
-                super().__init__()
-                self.w = torch.ones(5, 5)
-                self.b = torch.zeros(5)
-
-            def forward(self, x):
-                return torch.nn.functional.linear(x, self.w, self.b)
-
+                super().__init__(
+                    nn.Linear(5, 5),
+                    nn.ReLU(),
+                )
 
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.mods1 = torch.nn.Sequential(
-                    torch.nn.Linear(5, 5),
-                )
+                self.lin_relu = LinearRelu()
+                self.linear = nn.Linear(5, 5)
 
             def forward(self, x):
-                x = self.mods1(x)
+                x = self.lin_relu(x)
+                x = self.linear(x)
                 return x
 
         model = M().train()
@@ -1103,6 +1100,7 @@ class TestQuantizeFx(QuantizationTestCase):
             "": None,
             "object_type": [
                 (torch.nn.Linear, default_qat_qconfig),
+                (torch.nn.ReLU, default_qat_qconfig),
             ],
         }
         m = prepare_qat_fx(model, qconfig_dict)
@@ -1111,6 +1109,7 @@ class TestQuantizeFx(QuantizationTestCase):
         m(torch.rand(5, 5))
         node_list = [
             ns.call_function(torch.quantize_per_tensor),
+            ns.call_module(nniq.LinearReLU),
             ns.call_module(nnq.Linear),
             ns.call_method("dequantize"),
         ]
diff --git a/torch/quantization/fx/prepare.py b/torch/quantization/fx/prepare.py
index 055b984..ab13748 100644
--- a/torch/quantization/fx/prepare.py
+++ b/torch/quantization/fx/prepare.py
@@ -163,7 +163,8 @@ def update_qconfig_for_qat(
     all_qat_mappings = get_combined_dict(
         get_default_qat_module_mappings(), additional_qat_module_mapping)
     object_type_dict = qconfig_dict.get("object_type", None)
-    for k, v in object_type_dict.items():
+    new_object_type_dict = object_type_dict.copy()
+    for k, v in new_object_type_dict.items():
         if k in all_qat_mappings:
             object_type_dict[all_qat_mappings[k]] = v
     return qconfig_dict
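
For context, the fixed loop in update_qconfig_for_qat propagates a qconfig registered for a float module type onto its QAT counterpart, so the config still applies after prepare_qat_fx swaps modules. A sketch under stated assumptions: the reduced mapping below is hypothetical (the real one comes from get_default_qat_module_mappings()), and the string value stands in for a real qconfig object.

    import torch.nn as nn
    import torch.nn.qat as nnqat

    all_qat_mappings = {nn.Linear: nnqat.Linear}   # hypothetical reduced mapping

    # By this point the user's "object_type" list has been normalized into
    # an ordered dict keyed on module type.
    object_type_dict = {nn.Linear: "qconfig"}

    for k, v in object_type_dict.copy().items():
        if k in all_qat_mappings:
            object_type_dict[all_qat_mappings[k]] = v

    # The entry for nn.Linear now also covers the swapped-in QAT type.
    assert object_type_dict == {nn.Linear: "qconfig", nnqat.Linear: "qconfig"}

In the updated test, this lets the qconfig registered for torch.nn.Linear apply both to the standalone nn.Linear (converted to nnq.Linear) and, together with the nn.ReLU entry, to the fused submodule that converts to nniq.LinearReLU.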