[quant] Add op support for linear_relu_dynamic_fp16 (#63824)
authorSupriya Rao <supriyar@fb.com>
Fri, 27 Aug 2021 04:05:56 +0000 (21:05 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 04:12:04 +0000 (21:12 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63824

Add a fused operator implementation that will work with the quantization fusion APIs.
Once FBGEMM FP16 kernel supports relu fusion natively we can remove the addition from the PT operator.

Test Plan:
python test/test_quantization.py

Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D30503514

fbshipit-source-id: 6bf3bd53f47ffaa3f1d178eaad8cc980a7f5258a

aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
aten/src/ATen/native/quantized/library.cpp
test/quantization/core/test_quantized_op.py

index 23c6158..3331a03 100644 (file)
@@ -451,8 +451,14 @@ class QLinearDynamicFp16 final {
     TORCH_CHECK(
         fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
 
-    TORCH_INTERNAL_ASSERT(!ReluFused);
-    return packed_weight->apply_dynamic(std::move(input));
+    auto output = packed_weight->apply_dynamic(std::move(input));
+
+    // Call the relu operator here until fp16 linear dynamic in FBGEMM
+    // supports it natively.
+    if (ReluFused) {
+      output.relu_();
+    }
+    return output;
   }
 #else // USE_FBGEMM
   static at::Tensor run(
@@ -471,6 +477,7 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_dynamic"), TORCH_FN(QLinearDynamicInt8<false>::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic"), TORCH_FN(QLinearDynamicInt8<true>::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_dynamic_fp16"), TORCH_FN(QLinearDynamicFp16<false>::run));
+  m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic_fp16"), TORCH_FN(QLinearDynamicFp16<true>::run));
 }
 
 TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
index 8ead74f..3dcf75b 100644 (file)
@@ -142,6 +142,7 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"));
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
index 86fe350..49b7c96 100644 (file)
@@ -2782,6 +2782,38 @@ class TestDynamicQuantizedLinear(TestCase):
         self.assertEqual(Y_fp32, Y_fp32_ref,
                          msg="torch.ops.quantized.fbgemm_linear_dynamic results are off")
 
+    @skipIfNoFBGEMM
+    def test_qlinear_dynamic_fp16(self):
+
+        options = itertools.product(
+            (2, 4),         # batch_size
+            (4, 5, 12),     # input_channels
+            (4, 7, 8),      # output_channels
+            (True, False),  # use_bias
+            (True, False),  # use_relu
+        )
+        for batch_size, input_channels, output_channels, use_bias, use_relu in options:
+            qlinear_prepack = torch.ops.quantized.linear_prepack_fp16
+            if use_relu:
+                qlinear_dynamic = torch.ops.quantized.linear_relu_dynamic_fp16
+            else:
+                qlinear_dynamic = torch.ops.quantized.linear_dynamic_fp16
+
+            x = torch.randn(batch_size, input_channels)
+            w = torch.randn(output_channels, input_channels)
+            bias = torch.randn(output_channels) if use_bias else None
+
+            w_packed = qlinear_prepack(w, bias)
+            out = qlinear_dynamic(x, w_packed)
+
+            # qlinear_dynamic_fp16 uses FP32 activation tensors and FP16 weight tensors
+            # output is FP32
+            w_fp16 = w.to(torch.float16).to(torch.float32)
+            ref = F.linear(x, w_fp16, bias)
+            if use_relu:
+                ref.relu_()
+
+            self.assertEqual(out, ref)
 
 class TestDynamicQuantizedRNNOp(TestCase):
     """Tests the correctness of the dynamic quantized lstm/gru."""