fix copy.deepcopy on LinearPackedParams (#64367)
author Vasiliy Kuznetsov <vasiliy@fb.com>
Thu, 2 Sep 2021 13:12:07 +0000 (06:12 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 13:30:42 +0000 (06:30 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64367

This is the same change as https://github.com/pytorch/pytorch/pull/56154,
but for quantized linear. It fixes the behavior of `copy.deepcopy` on
these modules. Before this PR, copied instances of `LinearPackedParams`
were not properly initialized, and inspecting them raised errors about a
missing `_modules` attribute. After this PR, inspecting and using the
copies works.
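
For illustration, a minimal standalone sketch of the behavior this PR
fixes (not part of the diff itself; requires a build with a quantized
engine such as FBGEMM):

```
import copy
import torch

qlinear = torch.nn.quantized.Linear(5, 10)
qlinear_copy = copy.deepcopy(qlinear)
# Before this PR, inspecting the copy raised errors about a missing
# `_modules` attribute; after this PR the copy behaves like the original.
print(qlinear_copy)
```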

Test Plan:
```
python test/test_quantization.py TestStaticQuantizedModule.test_linear_api
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D30702667

fbshipit-source-id: 38c26d1e72663416eeb989985b77ffc2052c12b9

test/quantization/core/test_quantized_module.py
torch/nn/quantized/modules/linear.py

test/quantization/core/test_quantized_module.py
index b0bc782..51e6217 100644
@@ -103,8 +103,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         zero_point = 3
         qlinear = class_map[use_fused](in_features, out_features)
 
-        qlinear_copy = qlinear  # deepcopy does not work right now
-        # qlinear_copy = copy.deepcopy(qlinear)
+        qlinear_copy = copy.deepcopy(qlinear)
         self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True)
         # Run module with default-initialized parameters.
         # This tests that the constructor is correct.
@@ -155,15 +154,16 @@ class TestStaticQuantizedModule(QuantizationTestCase):
                          linear_unpack(loaded_qlinear._packed_params._packed_params))
         self.assertEqual(qlinear.scale, loaded_qlinear.scale)
         self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
-        # make sure loaded_qlinear has the same dir as qlinear since
-        # scripting the module will add __overloads__ to __dict__
-        self.checkScriptable(loaded_qlinear, [[X_q]], check_save_load=True)
+        # scripting adds __overloads__ to __dict__, so we script a deepcopy
+        # here to keep the dir() comparison on the next line valid
+        self.checkScriptable(copy.deepcopy(loaded_qlinear), [[X_q]], check_save_load=True)
         self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
         self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
         self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
         Z_q2 = loaded_qlinear(X_q)
         self.assertEqual(Z_q, Z_q2)
 
+        # Test serialization
         b = io.BytesIO()
         torch.save(qlinear, b)
         b.seek(0)
@@ -172,6 +172,25 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         self.assertEqual(qlinear.scale, loaded.scale)
         self.assertEqual(qlinear.zero_point, loaded.zero_point)
 
+        # Test copy and deepcopy
+        copied_linear = copy.copy(qlinear)
+        self.assertEqual(copied_linear.bias(), qlinear.bias())
+        self.assertEqual(copied_linear.scale, qlinear.scale)
+        self.assertEqual(copied_linear.zero_point,
+                         qlinear.zero_point)
+        Y_copied = copied_linear(X_q)
+        np.testing.assert_array_almost_equal(
+            Z_q.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)
+
+        deepcopied_linear = copy.deepcopy(qlinear)
+        self.assertEqual(deepcopied_linear.bias(), qlinear.bias())
+        self.assertEqual(deepcopied_linear.scale, qlinear.scale)
+        self.assertEqual(deepcopied_linear.zero_point,
+                         qlinear.zero_point)
+        Y_deepcopied = deepcopied_linear(X_q)
+        np.testing.assert_array_almost_equal(
+            Z_q.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)
+
         # Test JIT
         self.checkScriptable(qlinear, [[X_q]], check_save_load=True)
 
torch/nn/quantized/modules/linear.py
index 4abd211..4df7751 100644
@@ -94,6 +94,16 @@ class LinearPackedParams(torch.nn.Module):
         self.set_weight_bias(state[0], state[1])
         self.training = state[2]
 
+    def __deepcopy__(self, memo):
+        new_instance = type(self).__new__(type(self))
+        torch.nn.Module.__init__(new_instance)
+        state = self.__getstate__()
+        new_instance.__setstate__(state)
+        return new_instance
+
+    def __copy__(self):
+        return self.__deepcopy__({})
+
     def __repr__(self):
         return self._weight_bias().__repr__()
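
The `__deepcopy__` added above builds the copy with `__new__`, runs
`torch.nn.Module.__init__` to set up the standard module bookkeeping
(`_modules`, `_parameters`, `_buffers`, ...), and then restores the packed
weight through the existing `__getstate__`/`__setstate__` pair. A minimal
usage sketch (standalone, not part of the diff; assumes a quantized
engine such as FBGEMM is available):

```
import copy

from torch.nn.quantized.modules.linear import LinearPackedParams

# LinearPackedParams wraps an opaque packed-weight object; __getstate__
# unpacks it into plain tensors and __setstate__ repacks them, which is
# the machinery __deepcopy__ reuses.
params = LinearPackedParams()  # default dtype is torch.qint8

params_deep = copy.deepcopy(params)  # properly initialized copy
params_shallow = copy.copy(params)   # __copy__ delegates to __deepcopy__({})

print(params_deep)  # __repr__ unpacks and shows (weight, bias)
```

Routing `__copy__` through `__deepcopy__` means even a "shallow" copy gets
its own repacked weight, keeping both copy protocols on the same
unpack/repack path instead of sharing the underlying packed object.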