From: Jerry Zhang
Date: Wed, 1 Sep 2021 22:48:54 +0000 (-0700)
Subject: [quant][graphmode][fx] Add fbgemm backend_config_dict (#64288)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~503
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ed89937d2cbda8f4c5b67439b8b7b138cff42552;p=platform%2Fupstream%2Fpytorch.git

[quant][graphmode][fx] Add fbgemm backend_config_dict (#64288)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64288

This is just to set up the file structure and unblock experimentation.
The format of backend_config_dict will change in the future.

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: zou3519

Differential Revision: D30699457

fbshipit-source-id: 28211a4def05d34757850c045a36e311f54760fe
---

diff --git a/torch/quantization/fx/backend_config_dict/__init__.py b/torch/quantization/fx/backend_config_dict/__init__.py
new file mode 100644
index 0000000..edb2b95
--- /dev/null
+++ b/torch/quantization/fx/backend_config_dict/__init__.py
@@ -0,0 +1,4 @@
+from .fbgemm import get_fbgemm_backend_config_dict
+
+def validate_backend_config_dict(backend_config_dict):
+    return "quant_patterns" in backend_config_dict
diff --git a/torch/quantization/fx/backend_config_dict/fbgemm.py b/torch/quantization/fx/backend_config_dict/fbgemm.py
new file mode 100644
index 0000000..4f40b10
--- /dev/null
+++ b/torch/quantization/fx/backend_config_dict/fbgemm.py
@@ -0,0 +1,11 @@
+from ..pattern_utils import get_default_quant_patterns
+
+def get_fbgemm_backend_config_dict():
+    """ Get the backend config dictionary for the fbgemm backend.
+    NOTE: The current API will change in the future; it is only meant to unblock
+    experimentation for new backends, so please don't use it right now.
+    """
+    # TODO: add output_activation_post_process_map
+    return {
+        "quant_patterns": get_default_quant_patterns()
+    }
diff --git a/torch/quantization/fx/prepare.py b/torch/quantization/fx/prepare.py
index fb526d0..0b65e33 100644
--- a/torch/quantization/fx/prepare.py
+++ b/torch/quantization/fx/prepare.py
@@ -42,7 +42,6 @@ from .graph_module import (
 
 from .pattern_utils import (
     MatchResult,
-    get_default_quant_patterns,
     get_default_output_activation_post_process_map,
 )
@@ -84,6 +83,9 @@ from ..utils import (
     weight_dtype,
 )
 
+from .backend_config_dict import get_fbgemm_backend_config_dict
+from .backend_config_dict import validate_backend_config_dict
+
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool:
@@ -1140,6 +1142,10 @@ def prepare(
         prepare_custom_config_dict = {}
     if equalization_qconfig_dict is None:
         equalization_qconfig_dict = {}
+    if backend_config_dict is None:
+        backend_config_dict = get_fbgemm_backend_config_dict()
+
+    validate_backend_config_dict(backend_config_dict)
 
     additional_quant_patterns = \
         prepare_custom_config_dict.get("additional_quant_pattern", {})
@@ -1153,8 +1159,9 @@ def prepare(
     #   ((, ):
     #     ),
     # }
+    quant_patterns = backend_config_dict["quant_patterns"]
     patterns: Dict[Pattern, QuantizeHandler] = get_combined_dict(
-        get_default_quant_patterns(), additional_quant_patterns)
+        quant_patterns, additional_quant_patterns)
 
     convert_dict_to_ordered_dict(qconfig_dict)
     convert_dict_to_ordered_dict(equalization_qconfig_dict)
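
A minimal usage sketch of the two new helpers (not part of the commit), assuming a
PyTorch build that already includes this change; the API is explicitly experimental
and expected to change:

    # Illustrative only: exercises get_fbgemm_backend_config_dict() and
    # validate_backend_config_dict() as added in this commit.
    from torch.quantization.fx.backend_config_dict import (
        get_fbgemm_backend_config_dict,
        validate_backend_config_dict,
    )

    backend_config_dict = get_fbgemm_backend_config_dict()

    # validate_backend_config_dict only checks that "quant_patterns" is present.
    assert validate_backend_config_dict(backend_config_dict)

    # "quant_patterns" carries the same pattern -> QuantizeHandler mapping that
    # fx.prepare() previously obtained from get_default_quant_patterns().
    for pattern, handler in backend_config_dict["quant_patterns"].items():
        print(pattern, "->", handler)

With this commit, prepare() falls back to the fbgemm config when backend_config_dict
is None, so the default quantization behavior is unchanged.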