Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62861
This PR adds a lower_to_native_backend function that lowers a quantized reference model
to a model that uses fbgemm/qnnpack ops. We'll gradually add support for more operator
patterns and remove the fbgemm/qnnpack-specific handling in quantization_patterns.py.
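For illustration, the intended end-to-end flow (a rough sketch mirroring the new test below; import paths assume torch.quantization.quantize_fx):

import torch
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

m = prepare_fx(M().eval(), {"": default_qconfig})
# ... run calibration data through m here ...
m = convert_fx(m)  # convert now lowers the reference model to fbgemm/qnnpack ops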
Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
Imported from OSS
Reviewed By: vkuzo
Differential Revision: D30165828
fbshipit-source-id: de1149cd7e7c1840c17c251cd4d35004afd015b7
if n.target == "lstm":
self.assertEqual(type(n.args[1]), tuple)
+ def test_lowering(self):
+ class M(torch.nn.Module):
+ def forward(self, x):
+ return torch.nn.functional.relu(x)
+
+ m = M().eval()
+ m = prepare_fx(m, {"": default_qconfig})
+ m_copy = copy.deepcopy(m)
+ m = convert_fx(m)
+ m_ref = convert_fx(m_copy, is_reference=True)
+ node_occurrence = {
+ ns.call_function(torch.quantize_per_tensor): 1,
+ ns.call_method("dequantize"): 1
+ }
+ node_occurrence_ref = {
+ ns.call_function(torch.quantize_per_tensor): 2,
+ ns.call_method("dequantize"): 2
+ }
+
+ self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
+ self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
+
@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops
--- /dev/null
+from torch.fx import subgraph_rewriter
+from .graph_module import QuantizedGraphModule
+from .quantized_fusion_patterns_and_replacements import get_fbgemm_patterns_and_replacements
+
+def _lower_to_native_backend(model: QuantizedGraphModule) -> QuantizedGraphModule:
+ """ Lower a quantized reference model (with reference quantized operator patterns)
+    to the native backend in PyTorch (fbgemm/qnnpack); both backends share the same
+    operator signatures, so they can be lowered with the same function
+ """
+ module_dict = dict(model.named_modules())
+ for pattern, replacement in get_fbgemm_patterns_and_replacements():
+ subgraph_rewriter.replace_pattern(model, pattern, replacement)
+ model.graph.lint()
+ return model
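For context, the lowering pass is just torch.fx subgraph rewriting; below is a minimal standalone sketch (not part of this PR) of what a single replace_pattern call does to the reference relu pattern:

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class RefRelu(torch.nn.Module):
    # reference pattern: dequantize -> relu -> quantize_per_tensor
    def forward(self, x, scale, zero_point):
        x = x.dequantize()
        x = torch.nn.functional.relu(x)
        return torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)

def pattern(x, scale, zero_point):
    x = x.dequantize()
    x = torch.nn.functional.relu(x)
    return torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)

def replacement(x, scale, zero_point):
    return torch.nn.functional.relu(x)

gm = symbolic_trace(RefRelu())
subgraph_rewriter.replace_pattern(gm, pattern, replacement)
gm.graph.lint()
# gm now calls relu directly on the quantized input; the surrounding
# dequantize/quantize_per_tensor pair has been removed from the graph.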
activation_dtype,
)
+from .lower_to_fbgemm import lower_to_fbgemm
+
# weight prepacking ops
WEIGHT_PREPACK_OPS = {
torch._ops.ops.quantized.linear_prepack,
model = QuantizedGraphModule(model, act_post_process_removed_graph, preserved_attributes)
if not is_reference:
model = fold_weight(model, node_name_to_scope)
+ model = lower_to_fbgemm(model)
return model
--- /dev/null
+from ._lower_to_native_backend import _lower_to_native_backend
+from .graph_module import QuantizedGraphModule
+
+def lower_to_fbgemm(model: QuantizedGraphModule) -> QuantizedGraphModule:
+ """ Lower a quantized reference model (with reference quantized operator patterns)
+ to fbgemm
+ """
+ return _lower_to_native_backend(model)
--- /dev/null
+from ._lower_to_native_backend import _lower_to_native_backend
+from .graph_module import QuantizedGraphModule
+
+def lower_to_qnnpack(model: QuantizedGraphModule) -> QuantizedGraphModule:
+ """ Lower a quantized reference model (with reference quantized operator patterns)
+ to qnnpack
+ """
+ return _lower_to_native_backend(model)
load_arg: Callable,
is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
- if is_reference:
+ # always produce reference pattern for relu
+ is_relu = node.op == "call_function" and node.target == torch.nn.functional.relu
+ if is_reference or is_relu:
# when activation dtype is torch.float, the node does not require
# observation
# e.g. dynamic quantization or weight_only quantization
--- /dev/null
+import torch
+
+def relu_inplace_pattern(x, scale, zero_point):
+ x = x.dequantize()
+ x = torch.nn.functional.relu(x, inplace=True)
+ x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
+ return x
+
+def relu_non_inplace_pattern(x, scale, zero_point):
+ x = x.dequantize()
+ x = torch.nn.functional.relu(x, inplace=False)
+ x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
+ return x
+
+def relu_replacement(x, scale, zero_point):
+ x = torch.nn.functional.relu(x)
+ return x
+
+
+def _get_all_patterns_and_replacements():
+ return [
+ (relu_inplace_pattern, relu_replacement),
+ (relu_non_inplace_pattern, relu_replacement)
+ ]
+
+
+def get_fbgemm_patterns_and_replacements():
+ return _get_all_patterns_and_replacements()
+
+def get_qnnpack_patterns_and_replacements():
+ return _get_all_patterns_and_replacements()
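The replacement can drop the dequantize/quantize pair because relu has a native quantized kernel; a quick standalone check (not part of this PR):

import torch

x = torch.randn(2, 3)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
qy = torch.nn.functional.relu(qx)  # runs directly on the quantized tensor
assert qy.is_quantized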