[quant][graphmode][fx] Make maxpool and flatten produce the reference pattern (#63501)
authorJerry Zhang <jerryzh@fb.com>
Wed, 25 Aug 2021 04:28:40 +0000 (21:28 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 25 Aug 2021 04:31:01 +0000 (21:31 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63501

Currently some of the ops are considered as working with both float and quantized input,
so we may have things like "quant - some_op - dequant" this might not work well with the backend,
we may consider change everything to produce "quant - dequant - some_op - quant - dequant" instead
in the future, this PR fixes it for maxpool and flatten only to unblock resnet benchmarking on TensorRT

Test Plan:
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: mruberry

Differential Revision: D30402788

fbshipit-source-id: 892c5ff6552775070e2c1453f65846590fb12735

torch/quantization/fx/quantization_patterns.py

index 1a7d714..09ca190 100644 (file)
@@ -1447,6 +1447,9 @@ class FixedQParamsOpQuantizeHandler(QuantizeHandler):
 @register_quant_pattern(torch.nn.AvgPool3d)
 @register_quant_pattern(torch.nn.Dropout)
 @register_quant_pattern(torch.nn.Hardtanh)
+@register_quant_pattern(torch.nn.MaxPool1d)
+@register_quant_pattern(torch.nn.MaxPool2d)
+@register_quant_pattern(torch.nn.MaxPool3d)
 @register_quant_pattern(torch.nn.ReLU)
 @register_quant_pattern(torch.nn.ReLU6)
 @register_quant_pattern(torch.adaptive_avg_pool1d)
@@ -1456,12 +1459,16 @@ class FixedQParamsOpQuantizeHandler(QuantizeHandler):
 @register_quant_pattern(torch.nn.functional.hardtanh)
 @register_quant_pattern(torch.nn.functional.hardtanh_)
 @register_quant_pattern(torch.nn.functional.interpolate)
+@register_quant_pattern(torch.nn.functional.max_pool1d)
+@register_quant_pattern(torch.nn.functional.max_pool2d)
+@register_quant_pattern(torch.nn.functional.max_pool3d)
 @register_quant_pattern(torch.nn.functional.relu)
 @register_quant_pattern(torch.nn.functional.relu6)
 @register_quant_pattern(torch.avg_pool1d)
 @register_quant_pattern(torch._C._nn.avg_pool2d)
 @register_quant_pattern(torch._C._nn.avg_pool3d)
 @register_quant_pattern(torch.clamp)
+@register_quant_pattern(torch.flatten)
 @register_quant_pattern(torch.max)
 @register_quant_pattern(torch.mean)
 @register_quant_pattern(torch.min)
@@ -1556,15 +1563,8 @@ class CustomModuleQuantizeHandler(QuantizeHandler):
         # module attribute like module._QUANTIZED_INPUT_INDEXES
         return quantized_graph.node_copy(node, load_arg(quantized=None))
 
-@register_quant_pattern(torch.nn.MaxPool1d)
-@register_quant_pattern(torch.nn.MaxPool2d)
-@register_quant_pattern(torch.nn.MaxPool3d)
 @register_quant_pattern(torch.nn.Identity)
-@register_quant_pattern(torch.nn.functional.max_pool1d)
-@register_quant_pattern(torch.nn.functional.max_pool2d)
-@register_quant_pattern(torch.nn.functional.max_pool3d)
 @register_quant_pattern(torch.chunk)
-@register_quant_pattern(torch.flatten)
 @register_quant_pattern(torch.transpose)
 @register_quant_pattern(torch.repeat_interleave)
 @register_quant_pattern(torch.sort)