[quant][fix] Fix quantization for sub_scalar (#64603)
author Jerry Zhang <jerryzh@fb.com>
Fri, 10 Sep 2021 00:17:01 +0000 (17:17 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 00:18:31 +0000 (17:18 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64603

We'll insert an observer only when both the operator and the dtype are supported.
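
For example, with the default int8 qconfig, sub with a scalar is not in
binary_op_supported_dtypes, so no observer is inserted for it and it runs in
fp32, while add with a scalar stays quantized. A minimal sketch mirroring the
test below (assuming the usual FX graph mode quantization imports):

    import torch
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    class M(torch.nn.Module):
        def forward(self, x):
            x = x + 1  # add with scalar: supported, stays quantized
            x = x - 1  # sub with scalar: unsupported, runs in fp32
            return x

    m = prepare_fx(M().eval(), {"": default_qconfig})
    m = convert_fx(m)
    # The converted graph has a single quantize_per_tensor / dequantize
    # pair: the tensor is dequantized before sub, which executes in fp32.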

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_sub_scalar

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D30797025

fbshipit-source-id: a77c21e2749405534fc245374cf33a0657a3d2c8

test/quantization/fx/test_quantize_fx.py
torch/quantization/fx/quantization_patterns.py

index 9682da1..f2f665d 100644
@@ -3023,6 +3023,22 @@ class TestQuantizeFx(QuantizationTestCase):
             result_ref = m_ref(data)
             self.assertTrue(torch.equal(result, result_ref))
 
+    def test_sub_scalar(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                x = x + 1
+                x = x - 1
+                return x
+
+        m = M().eval()
+        m = prepare_fx(m, {"": default_qconfig})
+        m = convert_fx(m)
+        occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 1,
+            ns.call_method("dequantize"): 1
+        }
+        self.checkGraphModuleNodes(m, expected_node_occurrence=occurrence)
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
@@ -4146,6 +4162,9 @@ class TestQuantizeFxOps(QuantizationTestCase):
             module, functional, qconfig, is_reference, node_list)
 
     def test_bmm_int_reference(self):
+        """ int8 is not supported for bmm so we won't produce reference
+            pattern for it
+        """
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -4163,10 +4182,7 @@ class TestQuantizeFxOps(QuantizationTestCase):
             ns.call_function(torch.quantize_per_tensor),
             ns.call_function(torch.quantize_per_tensor),
             ns.call_method('dequantize'),
-            ns.call_method('dequantize'),
             ns.call_function(torch.bmm),
-            ns.call_function(torch.quantize_per_tensor),
-            ns.call_method('dequantize'),
         ]
 
         m = M().eval()
index 418cae1..d90be90 100644
@@ -356,6 +356,9 @@ class BinaryOpQuantizeHandler(QuantizeHandler):
         the pattern matched to this QuantizeHandler instance during the
         prepare step.
         """
+        dtypes = get_qconfig_dtypes(qconfig)
+        if not (self.binary_op in binary_op_supported_dtypes and dtypes in binary_op_supported_dtypes[self.binary_op]):
+            return False
         if self.num_tensor_args == 1:
             return True
         elif self.all_node_args_are_tensors and self.input_output_observed():
@@ -396,7 +399,9 @@ class BinaryOpQuantizeHandler(QuantizeHandler):
 
         if is_reference:
             act_dtype = activation_dtype(qconfig)
-            if act_dtype == torch.float:
+            dtypes = get_qconfig_dtypes(qconfig)
+            if act_dtype == torch.float or \
+               not (self.binary_op in binary_op_supported_dtypes and dtypes in binary_op_supported_dtypes[self.binary_op]):
                 return quantized_graph.node_copy(node, load_arg(quantized=torch.float))
             else:
                 if self.num_tensor_args == 2: