From cec44aa574e06e8aa1096b62a7c6d7c4dda8a3f5 Mon Sep 17 00:00:00 2001
From: Supriya Rao
Date: Thu, 26 Aug 2021 21:05:56 -0700
Subject: [PATCH] [quant] Add op support for linear_relu_dynamic_fp16 (#63824)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63824

Add a fused operator implementation that will work with the quantization
fusion APIs. Once the FBGEMM FP16 kernel supports relu fusion natively, we
can remove the addition from the PT operator.

Test Plan:
python test/test_quantization.py

Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D30503514

fbshipit-source-id: 6bf3bd53f47ffaa3f1d178eaad8cc980a7f5258a
---
 .../ATen/native/quantized/cpu/qlinear_dynamic.cpp | 11 ++++++--
 aten/src/ATen/native/quantized/library.cpp        |  1 +
 test/quantization/core/test_quantized_op.py       | 32 ++++++++++++++++++++++
 3 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
index 23c6158..3331a03 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
@@ -451,8 +451,14 @@ class QLinearDynamicFp16 final {
     TORCH_CHECK(
         fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

-    TORCH_INTERNAL_ASSERT(!ReluFused);
-    return packed_weight->apply_dynamic(std::move(input));
+    auto output = packed_weight->apply_dynamic(std::move(input));
+
+    // Call the relu operator here until fp16 linear dynamic in FBGEMM
+    // supports it natively.
+    if (ReluFused) {
+      output.relu_();
+    }
+    return output;
   }
 #else // USE_FBGEMM
   static at::Tensor run(
@@ -471,6 +477,7 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_dynamic"), TORCH_FN(QLinearDynamicInt8<false>::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic"), TORCH_FN(QLinearDynamicInt8<true>::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_dynamic_fp16"), TORCH_FN(QLinearDynamicFp16<false>::run));
+  m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic_fp16"), TORCH_FN(QLinearDynamicFp16<true>::run));
 }

 TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp
index 8ead74f..3dcf75b 100644
--- a/aten/src/ATen/native/quantized/library.cpp
+++ b/aten/src/ATen/native/quantized/library.cpp
@@ -142,6 +142,7 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"));
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 86fe350..49b7c96 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -2782,6 +2782,38 @@ class TestDynamicQuantizedLinear(TestCase):
         self.assertEqual(Y_fp32, Y_fp32_ref,
                          msg="torch.ops.quantized.fbgemm_linear_dynamic results are off")

+    @skipIfNoFBGEMM
+    def test_qlinear_dynamic_fp16(self):
+
+        options = itertools.product(
+            (2, 4),         # batch_size
+            (4, 5, 12),     # input_channels
+            (4, 7, 8),      # output_channels
+            (True, False),  # use_bias
+            (True, False),  # use_relu
+        )
+        for batch_size, input_channels, output_channels, use_bias, use_relu in options:
+            qlinear_prepack = torch.ops.quantized.linear_prepack_fp16
+            if use_relu:
+                qlinear_dynamic = torch.ops.quantized.linear_relu_dynamic_fp16
+            else:
+                qlinear_dynamic = torch.ops.quantized.linear_dynamic_fp16
+
+            x = torch.randn(batch_size, input_channels)
+            w = torch.randn(output_channels, input_channels)
+            bias = torch.randn(output_channels) if use_bias else None
+
+            w_packed = qlinear_prepack(w, bias)
+            out = qlinear_dynamic(x, w_packed)
+
+            # qlinear_dynamic_fp16 uses FP32 activation tensors and FP16 weight tensors
+            # output is FP32
+            w_fp16 = w.to(torch.float16).to(torch.float32)
+            ref = F.linear(x, w_fp16, bias)
+            if use_relu:
+                ref.relu_()
+
+            self.assertEqual(out, ref)

 class TestDynamicQuantizedRNNOp(TestCase):
     """Tests the correctness of the dynamic quantized lstm/gru."""
-- 
2.7.4
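
Editor's note (not part of the patch): a minimal usage sketch of the operator this patch registers, mirroring the new test above. It assumes a PyTorch build with FBGEMM support; the shapes, tolerance, and variable names are illustrative only.

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 4)     # FP32 activations
    w = torch.randn(3, 4)     # FP32 weights, stored as FP16 after prepacking
    bias = torch.randn(3)

    # Pack the weights into the FP16 packed-params object consumed by the dynamic ops.
    w_packed = torch.ops.quantized.linear_prepack_fp16(w, bias)

    # Fused op added by this patch: dynamic FP16 linear followed by relu, FP32 output.
    y = torch.ops.quantized.linear_relu_dynamic_fp16(x, w_packed)

    # Reference: emulate FP16 weight rounding, then linear + relu in FP32.
    ref = F.linear(x, w.to(torch.float16).to(torch.float32), bias).relu_()
    print(torch.allclose(y, ref, atol=1e-3))  # tolerance chosen for illustration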