[quant][graphmode][fx] Add fbgemm backend_config_dict (#64288)
author Jerry Zhang <jerryzh@fb.com>
Wed, 1 Sep 2021 22:48:54 +0000 (15:48 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 23:32:43 +0000 (16:32 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64288

This just sets up the file structure and unblocks experimentation.
The format of backend_config_dict will change in the future.
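
A rough sketch of the intended shape, using only the names added in this patch
(the keys and entry points are expected to change):

    from torch.quantization.fx.backend_config_dict import (
        get_fbgemm_backend_config_dict,
        validate_backend_config_dict,
    )

    backend_config_dict = get_fbgemm_backend_config_dict()
    # For now the dict only carries a "quant_patterns" entry (pattern ->
    # QuantizeHandler); prepare() falls back to this fbgemm dict when no
    # backend_config_dict is passed in.
    assert validate_backend_config_dict(backend_config_dict)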

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: zou3519

Differential Revision: D30699457

fbshipit-source-id: 28211a4def05d34757850c045a36e311f54760fe

torch/quantization/fx/backend_config_dict/__init__.py [new file with mode: 0644]
torch/quantization/fx/backend_config_dict/fbgemm.py [new file with mode: 0644]
torch/quantization/fx/prepare.py

diff --git a/torch/quantization/fx/backend_config_dict/__init__.py b/torch/quantization/fx/backend_config_dict/__init__.py
new file mode 100644 (file)
index 0000000..edb2b95
--- /dev/null
@@ -0,0 +1,4 @@
+from .fbgemm import get_fbgemm_backend_config_dict
+
+def validate_backend_config_dict(backend_config_dict):
+    return "quant_patterns" in backend_config_dict
diff --git a/torch/quantization/fx/backend_config_dict/fbgemm.py b/torch/quantization/fx/backend_config_dict/fbgemm.py
new file mode 100644 (file)
index 0000000..4f40b10
--- /dev/null
@@ -0,0 +1,11 @@
+from ..pattern_utils import get_default_quant_patterns
+
+def get_fbgemm_backend_config_dict():
+    """ Get the backend config dictionary for fbgemm backend
+    NOTE: Current api will change in the future, it's just to unblock experimentation for
+    new backends, please don't use it right now.
+    """
+    # TODO: add output_activation_post_process_map
+    return {
+        "quant_patterns": get_default_quant_patterns()
+    }
diff --git a/torch/quantization/fx/prepare.py b/torch/quantization/fx/prepare.py
index fb526d0..0b65e33 100644 (file)
@@ -42,7 +42,6 @@ from .graph_module import (
 
 from .pattern_utils import (
     MatchResult,
-    get_default_quant_patterns,
     get_default_output_activation_post_process_map,
 )
 
@@ -84,6 +83,9 @@ from ..utils import (
     weight_dtype,
 )
 
+from .backend_config_dict import get_fbgemm_backend_config_dict
+from .backend_config_dict import validate_backend_config_dict
+
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool:
@@ -1140,6 +1142,10 @@ def prepare(
         prepare_custom_config_dict = {}
     if equalization_qconfig_dict is None:
         equalization_qconfig_dict = {}
+    if backend_config_dict is None:
+        backend_config_dict = get_fbgemm_backend_config_dict()
+
+    validate_backend_config_dict(backend_config_dict)
 
     additional_quant_patterns = \
         prepare_custom_config_dict.get("additional_quant_pattern", {})
@@ -1153,8 +1159,9 @@ def prepare(
     #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
     #     <class 'torch.quantization.fx.quantize.Add'>),
     # }
+    quant_patterns = backend_config_dict["quant_patterns"]
     patterns: Dict[Pattern, QuantizeHandler] = get_combined_dict(
-        get_default_quant_patterns(), additional_quant_patterns)
+        quant_patterns, additional_quant_patterns)
 
     convert_dict_to_ordered_dict(qconfig_dict)
     convert_dict_to_ordered_dict(equalization_qconfig_dict)
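
For reference, the combination step layers the user-provided patterns on top of
the backend-provided ones. A sketch of the effective behavior, assuming
get_combined_dict keeps its copy-and-update semantics (the helper itself is not
part of this patch):

    combined = dict(quant_patterns)              # backend-provided base patterns
    combined.update(additional_quant_patterns)   # user-supplied patterns take precedence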