from torch.quantization import default_qconfig
from torch.quantization import QuantWrapper
-import torch.quantization._numeric_suite as ns
+import torch.ao.ns._numeric_suite as ns
from torch.quantization._correct_bias import (
_supported_modules,
quantize,
quantize_dynamic,
)
-from torch.quantization._numeric_suite import (
+from torch.ao.ns._numeric_suite import (
OutputLogger,
Shadow,
ShadowLogger,
from torch.testing._internal.common_quantization import NodeSpec as ns
from torch.quantization.fx.pattern_utils import get_default_quant_patterns
import torch.quantization.fx.quantization_patterns as qp
-from torch.quantization.ns.pattern_utils import (
+from torch.ao.ns.fx.pattern_utils import (
get_type_a_related_to_b,
)
-from torch.quantization.ns.graph_matcher import (
+from torch.ao.ns.fx.graph_matcher import (
get_matching_subgraph_pairs,
GraphMatchingException,
)
-from torch.quantization.ns.utils import (
+from torch.ao.ns.fx.utils import (
compute_sqnr,
compute_normalized_l2_error,
compute_cosine_similarity,
)
-from torch.quantization.ns.mappings import (
+from torch.ao.ns.fx.mappings import (
get_node_type_to_io_type_map,
get_unmatchable_types_map,
get_base_name_to_sets_of_related_ops,
get_base_name_for_op,
add_op_to_sets_of_related_ops,
)
-from torch.quantization.ns.weight_utils import (
+from torch.ao.ns.fx.weight_utils import (
get_op_to_type_to_weight_extraction_fn,
)
-from torch.quantization._numeric_suite_fx import (
+from torch.ao.ns._numeric_suite_fx import (
extract_weights,
_extract_weights_impl,
add_loggers,
op_to_type_to_weight_extraction_fn = \
get_op_to_type_to_weight_extraction_fn()
op_to_type_to_weight_extraction_fn['call_function'][_wrapped_linear] = \
- torch.quantization.ns.weight_utils.get_linear_fun_weight
+ torch.ao.ns.fx.weight_utils.get_linear_fun_weight
# test compare weights
results = extract_weights(
--- /dev/null
+import torch
+import torch.nn as nn
+import torch.nn.quantized as nnq
+import torch.nn.quantized.dynamic as nnqd
+from torch.quantization import prepare
+from typing import Dict, List, Optional, Any, Union, Callable, Set
+
+from torch.quantization.quantization_mappings import (
+ get_default_compare_output_module_list,
+)
+
+# Module types handed to ``prepare`` as ``observer_non_leaf_module_list`` when
+# ``prepare_model_outputs`` attaches loggers to the quantized module.
+NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
+ nnqd.Linear,
+ nnq.Linear,
+ nnqd.LSTM,
+ nn.LSTM,
+}
+
+
+def _find_match(
+    str_list: Union[Dict[str, Any], List[str]], key_str: str,
+    postfix: str,
+) -> Optional[str]:
+    r"""Find the entry of ``str_list`` that corresponds to ``key_str``.
+
+    ``key_str`` is only considered when its last dot-separated component
+    equals ``postfix``. The key with its last component removed (remaining
+    components concatenated without a separator) is compared against every
+    candidate with its last one or last two components removed; the first
+    candidate that agrees is returned.
+
+    When ``postfix`` is "_packed_params" and that first pass finds nothing,
+    the search is repeated with the last *two* components of ``key_str``
+    removed, so that "fc._packed_params._packed_params" can pair with
+    "fc.weight".
+
+    Args:
+        str_list: candidate names (iterating a dict yields its keys)
+        key_str: name to find a partner for
+        postfix: required last component of ``key_str``
+
+    Return:
+        The matching candidate, or None when there is none.
+    """
+    parts = key_str.split(".")
+    if parts[-1] != postfix:
+        return None
+
+    def _scan(target: str) -> Optional[str]:
+        # First candidate whose name, minus one or two trailing components,
+        # equals ``target``.
+        for candidate in str_list:
+            pieces = candidate.split(".")
+            if target == "".join(pieces[0:-1]):
+                return candidate
+            if target == "".join(pieces[0:-2]):
+                return candidate
+        return None
+
+    found = _scan("".join(parts[0:-1]))
+    if found is not None:
+        return found
+
+    # For matching "fc.weight" and "fc._packed_params._packed_params"
+    if postfix == "_packed_params":
+        shorter = "".join(parts[0:-2])
+        if len(shorter) == 0:
+            return None
+        return _scan(shorter)
+    return None
+
+
+def compare_weights(
+    float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
+) -> Dict[str, Dict[str, torch.Tensor]]:
+    r"""Compare the weights of the float module with its corresponding quantized
+    module. Return a dict with key corresponding to module names and each entry being
+    a dictionary with two keys 'float' and 'quantized', containing the float and
+    quantized weights. This dict can be used to compare and compute the quantization
+    error of the weights of float and quantized models.
+
+    Example usage:
+        wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())
+        for key in wt_compare_dict:
+            print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))
+
+    Args:
+        float_dict: state dict of the float model
+        quantized_dict: state dict of the quantized model
+
+    Return:
+        weight_dict: dict with key corresponding to module names and each entry being
+        a dictionary with two keys 'float' and 'quantized', containing the float and
+        quantized weights
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
+    weight_dict: Dict[str, Dict] = {}
+    for key in quantized_dict:
+        match_key = _find_match(float_dict, key, "weight")
+        if match_key is not None:
+            weight_dict[key] = {}
+            weight_dict[key]["float"] = float_dict[match_key]
+            weight_dict[key]["quantized"] = quantized_dict[key]
+            continue
+
+        # For matching "fc.weight" and "fc._packed_params._packed_params"
+        match_key = _find_match(float_dict, key, "_packed_params")
+        if match_key is not None:
+            weight_dict[key] = {}
+            weight_dict[key]["float"] = float_dict[match_key]
+            weight_dict[key]["quantized"] = quantized_dict[key][0]
+
+        # For LSTM: keys shaped like "<module>._all_weight_values.<layer>.param".
+        # The length guard keeps short keys that merely end in "param" from
+        # raising IndexError on the split_str[-3] lookup.
+        split_str = key.split(".")
+        if (
+            len(split_str) >= 3
+            and split_str[-1] == "param"
+            and split_str[-3] == "_all_weight_values"
+        ):
+            layer = split_str[-2]
+            module_name = ".".join(split_str[:-3])
+            float_weight_ih_key = module_name + ".weight_ih_l" + layer
+            float_weight_hh_key = module_name + ".weight_hh_l" + layer
+            if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
+                weight_dict[key] = {}
+                weight_dict[key]["float"] = float_dict[float_weight_ih_key]
+                weight_dict[key]["quantized"] = (
+                    quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
+                )
+                # NOTE(review): the hh pair below overwrites the ih pair stored
+                # just above, so only the hh weights end up in the result.
+                # Kept as-is because reporting both would change the keys
+                # callers see -- confirm the intended output format.
+                weight_dict[key]["float"] = float_dict[float_weight_hh_key]
+                weight_dict[key]["quantized"] = (
+                    quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
+                )
+
+    return weight_dict
+
+
+def _get_logger_dict_helper(
+    mod: nn.Module, target_dict: Dict[str, Any],
+    prefix: str = "",
+) -> None:
+    r"""Recursive worker behind get_logger_dict.
+
+    If ``mod`` has a direct child that is a Logger, the stats of the first
+    such child are stored under "<prefix>.stats" (or just "stats" when the
+    prefix is empty). Every direct child is then visited recursively with
+    its name appended to the prefix.
+
+    Args:
+        mod: module whose logger stats are collected
+        target_dict: dictionary that receives the stats; mutated in place
+        prefix: dotted name of ``mod`` within the traversal
+    """
+    dotted = prefix if prefix == "" else prefix + "."
+
+    first_logger = next(
+        (child for _, child in mod.named_children() if isinstance(child, Logger)),
+        None,
+    )
+    if first_logger is not None:
+        target_dict[dotted + "stats"] = first_logger.stats
+
+    for child_name, child in mod.named_children():
+        _get_logger_dict_helper(child, target_dict, dotted + child_name)
+
+
+def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
+    r"""Walk ``mod`` and gather the stats of every logger found.
+    Mainly intended for debugging quantization accuracy.
+
+    Logger types this is used with:
+        ShadowLogger: logs the outputs of a quantized module and of its
+        matching float shadow module,
+        OutputLogger: logs the outputs of a module
+
+    Args:
+        mod: module whose logger stats are collected
+        prefix: name prefix to use for ``mod`` itself
+
+    Return:
+        dictionary mapping "<module path>.stats" to the logger stats
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")
+
+    collected: Dict[str, Dict] = {}
+    _get_logger_dict_helper(mod, collected, prefix)
+    return collected
+
+
+class Logger(nn.Module):
+    r"""Common base for the stats-recording modules in this file.
+
+    Holds a ``stats`` dict that subclasses fill in; the base ``forward``
+    records nothing.
+    """
+
+    def __init__(self):
+        super().__init__()
+        # Observers are only inserted for statically quantized ops, which are
+        # recognised by activation_observer.dtype == quint8. Advertising
+        # quint8 here lets a Logger be attached as an observer in FX mode.
+        self.dtype = torch.quint8
+        self.stats = {}
+
+    def forward(self, x):
+        return None
+
+
+class ShadowLogger(Logger):
+    r"""Logger used inside Shadow: records the output of the original
+    (quantized) module under "quantized" and the output of the float shadow
+    module under "float".
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.stats["float"] = []
+        self.stats["quantized"] = []
+
+    def forward(self, x, y):
+        # Anything with more than one element along its first axis is reduced
+        # to its first element before being recorded.
+        # NOTE(review): presumably meant for tuple outputs; it applies to
+        # tensors with a leading dimension > 1 as well -- confirm intent.
+        x = x[0] if len(x) > 1 else x
+        y = y[0] if len(y) > 1 else y
+        self.stats["quantized"].append(x.detach())
+        self.stats["float"].append(y.detach())
+
+
+class OutputLogger(Logger):
+    r"""Logger that records every value passed through it under
+    "tensor_val" and hands the value back unchanged.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.stats["tensor_val"] = []
+
+    def forward(self, x):
+        recorded = self.stats["tensor_val"]
+        recorded.append(x)
+        return x
+
+
+def _convert_tuple_to_list(t: Any) -> Any:
+    r"""Recursively turn tuples into lists; any other value (including a
+    list) is returned untouched.
+    """
+    if type(t) is not tuple:
+        return t
+    return [_convert_tuple_to_list(item) for item in t]
+
+
+def _dequantize_tensor_list(t: Any) -> Any:
+    r"""Recursively dequantize a (possibly nested) list of tensors.
+
+    Lists are rebuilt element by element; a quantized tensor is replaced by
+    its dequantized form; a non-quantized tensor is returned as-is.
+    """
+    if type(t) is list:
+        return [_dequantize_tensor_list(item) for item in t]
+    if t.is_quantized:
+        return t.dequantize()
+    return t
+
+
+class Shadow(nn.Module):
+    r"""Wraps a quantized module together with the float module it was
+    quantized from. Each call runs the quantized module, runs the float
+    module on dequantized copies of the same inputs, passes both outputs to
+    the logger, and returns the quantized output.
+
+    Args:
+        q_module: module quantized from float_module that we want to shadow
+        float_module: float module used to shadow q_module
+        logger_cls: type of logger used to process the outputs of q_module and
+            float_module. ShadowLogger or custom loggers can be used.
+    """
+
+    def __init__(self, q_module, float_module, logger_cls):
+        super().__init__()
+        self.orig_module = q_module
+        self.shadow_module = float_module
+        self.dequant = nnq.DeQuantize()
+        self.logger = logger_cls()
+
+    def forward(self, *x) -> torch.Tensor:
+        inputs = _convert_tuple_to_list(x)
+        q_out = self.orig_module(*inputs)
+        float_inputs = _dequantize_tensor_list(inputs)
+        fp_out = self.shadow_module(*float_inputs)
+        self.logger(q_out, fp_out)
+        return q_out
+
+    def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        q_out = self.orig_module.add(x, y)
+        x_fp = x.dequantize()
+        y_fp = y.dequantize()
+        fp_out = self.shadow_module.add(x_fp, y_fp)
+        self.logger(q_out, fp_out)
+        return q_out
+
+    def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
+        q_out = self.orig_module.add_scalar(x, y)
+        x_fp = x.dequantize()
+        fp_out = self.shadow_module.add_scalar(x_fp, y)
+        self.logger(q_out, fp_out)
+        return q_out
+
+    def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        q_out = self.orig_module.mul(x, y)
+        x_fp = x.dequantize()
+        y_fp = y.dequantize()
+        fp_out = self.shadow_module.mul(x_fp, y_fp)
+        self.logger(q_out, fp_out)
+        return q_out
+
+    def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
+        q_out = self.orig_module.mul_scalar(x, y)
+        x_fp = x.dequantize()
+        fp_out = self.shadow_module.mul_scalar(x_fp, y)
+        self.logger(q_out, fp_out)
+        return q_out
+
+    def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
+        q_out = self.orig_module.cat(x, dim)
+        float_inputs = [t.dequantize() for t in x]
+        fp_out = self.shadow_module.cat(float_inputs, dim)
+        self.logger(q_out, fp_out)
+        return q_out
+
+    def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        q_out = self.orig_module.add_relu(x, y)
+        x_fp = x.dequantize()
+        y_fp = y.dequantize()
+        fp_out = self.shadow_module.add_relu(x_fp, y_fp)
+        self.logger(q_out, fp_out)
+        return q_out
+
+
+def prepare_model_with_stubs(
+    float_module: nn.Module, q_module: nn.Module,
+    module_swap_list: Set[type], logger_cls: Callable,
+) -> None:
+    r"""Attach float modules to their quantized counterparts as shadows.
+
+    Children of ``q_module`` are paired by name with children of
+    ``float_module``. A pair whose float type is in ``module_swap_list`` is
+    replaced in ``q_module`` by a Shadow wrapper, unless both sides are made
+    of identical module types. Pairs whose float type is not in the list are
+    descended into recursively.
+
+    Example usage:
+        prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
+        q_model(data)
+        ob_dict = get_logger_dict(q_model)
+
+    Args:
+        float_module: float module used to generate the q_module
+        q_module: module quantized from float_module; modified in place
+        module_swap_list: list of float module types to attach the shadow
+        logger_cls: type of logger to be used in shadow module to process the outputs of
+            quantized module and its float shadow module
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")
+
+    float_children = dict(float_module.named_children())
+
+    # Replacements are collected first and applied after the loop so the
+    # children are not swapped while they are being iterated.
+    replacements = {}
+    for child_name, q_child in q_module.named_children():
+        if child_name not in float_children:
+            continue
+
+        float_child = float_children[child_name]
+
+        if type(float_child) not in module_swap_list:
+            prepare_model_with_stubs(float_child, q_child, module_swap_list, logger_cls)
+        elif not _is_identical_module_type(q_child, float_child):
+            # Shadow only when the quantized module really differs in type
+            # from the floating point module.
+            replacements[child_name] = Shadow(q_child, float_child, logger_cls)
+
+    for child_name, shadow in replacements.items():
+        q_module._modules[child_name] = shadow
+
+def _is_identical_module_type(mod1, mod2):
+    r"""Return True when both modules consist of the same sequence of module
+    types, as enumerated by ``modules()`` (the module itself followed by its
+    descendants).
+    """
+    types_of_mod1 = list(map(type, mod1.modules()))
+    types_of_mod2 = list(map(type, mod2.modules()))
+    return types_of_mod1 == types_of_mod2
+
+
+
+def compare_model_stub(
+    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
+    *data, logger_cls=ShadowLogger
+) -> Dict[str, Dict]:
+    r"""Compare quantized modules in a model with their float counterparts
+    while feeding both the same input.
+
+    prepare_model_with_stubs() is called first to wrap the selected quantized
+    modules in Shadow modules, each of which runs the quantized module and
+    its float module on the same input and passes both outputs to a logger.
+    ``q_model`` is then run on ``data`` and the logger stats are collected.
+    With the default ShadowLogger each entry of the result holds the keys
+    'float' and 'quantized', which can be used to compute the module level
+    quantization error.
+
+    Example usage:
+        module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
+        ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
+        for key in ob_dict:
+            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
+
+    Args:
+        float_model: float model used to generate the q_model
+        q_model: model quantized from float_model; modified in place
+        module_swap_list: list of float module types at which shadow modules will
+            be attached.
+        data: input data used to run the prepared q_model
+        logger_cls: type of logger to be used in shadow module to process the outputs of
+            quantized module and its float shadow module
+
+    Return:
+        dict of logger stats keyed by module name
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
+    prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
+    q_model(*data)
+    return get_logger_dict(q_model)
+
+
+def get_matching_activations(
+    float_module: nn.Module, q_module: nn.Module,
+) -> Dict[str, Dict[str, torch.Tensor]]:
+    r"""Pair up the activations logged in the float and quantized modules.
+
+    Args:
+        float_module: float module used to generate the q_module
+        q_module: module quantized from float_module
+
+    Return:
+        act_dict: dict with key corresponding to quantized module names and each
+        entry being a dictionary with two keys 'float' and 'quantized', containing
+        the matching float and quantized activations
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
+    float_dict = get_logger_dict(float_module)
+    quantized_dict = get_logger_dict(q_module)
+    # Candidates are searched in reverse-sorted order of the float keys.
+    float_keys = sorted(float_dict, reverse=True)
+
+    act_dict: Dict[str, Dict] = {}
+    for q_key, q_stats in quantized_dict.items():
+        float_key = _find_match(float_keys, q_key, "stats")
+        if float_key is None:
+            continue
+        act_dict[q_key] = {
+            "float": float_dict[float_key]["tensor_val"],
+            "quantized": q_stats["tensor_val"],
+        }
+    return act_dict
+
+
+def prepare_model_outputs(
+    float_module: nn.Module,
+    q_module: nn.Module,
+    logger_cls=OutputLogger,
+    allow_list=None
+) -> None:
+    r"""Attach loggers to both the float module and the quantized module
+    for the module types named in ``allow_list``.
+
+    Both modules are given a debug qconfig whose activation is ``logger_cls``
+    and are then passed through ``prepare`` in place.
+
+    Args:
+        float_module: float module used to generate the q_module
+        q_module: module quantized from float_module
+        logger_cls: type of logger to be attached to float_module and q_module
+        allow_list: list of module types to attach logger; defaults to
+            get_default_compare_output_module_list()
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
+    if allow_list is None:
+        allow_list = get_default_compare_output_module_list()
+
+    debug_qconfig = torch.quantization.QConfig(activation=logger_cls, weight=None)
+
+    float_module.qconfig = debug_qconfig  # type: ignore[assignment]
+    prepare(float_module, inplace=True, allow_list=allow_list)
+
+    q_module.qconfig = debug_qconfig  # type: ignore[assignment]
+    q_prepare_kwargs = dict(
+        inplace=True,
+        allow_list=allow_list,
+        observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
+    )
+    prepare(q_module, **q_prepare_kwargs)
+
+
+def compare_model_outputs(
+    float_model: nn.Module,
+    q_model: nn.Module,
+    *data,
+    logger_cls=OutputLogger,
+    allow_list=None
+) -> Dict[str, Dict[str, torch.Tensor]]:
+    r"""Compare output activations of the float and quantized models at
+    corresponding locations for the same input.
+
+    Loggers are attached to both models, both are run on ``data``, and the
+    logged activations are paired up. The result can be used to compare and
+    compute the propagation quantization error.
+
+    Example usage:
+        act_compare_dict = compare_model_outputs(float_model, qmodel, data)
+        for key in act_compare_dict:
+            print(key, compute_error(act_compare_dict[key]['float'], act_compare_dict[key]['quantized'].dequantize()))
+
+    Args:
+        float_model: float model used to generate the q_model
+        q_model: model quantized from float_model
+        data: input data used to run the prepared float_model and q_model
+        logger_cls: type of logger to be attached to float_module and q_module
+        allow_list: list of module types to attach logger
+
+    Return:
+        act_compare_dict: dict with key corresponding to quantized module names
+        and each entry being a dictionary with two keys 'float' and 'quantized',
+        containing the matching float and quantized activations
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
+    if allow_list is None:
+        allow_list = get_default_compare_output_module_list()
+    prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
+
+    # Run both prepared models so their loggers fill up, float model first.
+    float_model(*data)
+    q_model(*data)
+
+    return get_matching_activations(float_model, q_model)
--- /dev/null
+import collections
+
+import torch
+import torch.nn as nn
+import torch.quantization.quantize_fx as quantize_fx
+from torch.fx import GraphModule
+from torch.fx.graph import Node
+from torch.ao.ns.fx.mappings import (
+ get_base_name_to_sets_of_related_ops,
+)
+from torch.ao.ns.fx.graph_matcher import (
+ get_matching_subgraph_pairs,
+ get_type_a_related_to_b,
+)
+
+from .fx.weight_utils import (
+ extract_weight_from_node,
+)
+
+from .fx.graph_passes import (
+ add_loggers_to_model,
+ create_a_shadows_b,
+)
+
+from .fx.utils import (
+ rekey_logger_info_on_node_name_of_model,
+ maybe_add_missing_fqns,
+ get_target_type_str,
+)
+
+from .fx.ns_types import (
+ NSSingleResultValuesType,
+ NSResultsType,
+ NSNodeTargetType,
+)
+
+from typing import Dict, Tuple, Callable, List, Optional, Set
+
+# Type of the entries kept in ``OutputLogger.stats_rnn``: a tensor paired
+# with a 2-tuple of tensors.
+RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+
+# Pass-through module that records the values flowing through it, together
+# with metadata describing where in which model it was inserted.
+class OutputLogger(nn.Module):
+ stats: List[torch.Tensor]
+ stats_rnn: List[RNNReturnType]
+
+ def __init__(
+ self,
+ ref_node_name: str,
+ prev_node_name: str,
+ model_name: str,
+ ref_name: str,
+ prev_node_target_type: str,
+ ref_node_target_type: str,
+ results_type: str,
+ index_within_arg: int,
+ index_of_arg: int,
+ fqn: Optional[str],
+ ):
+ super().__init__()
+ # values recorded by forward(): plain tensors go to stats,
+ # (tensor, (tensor, tensor)) shaped tuples go to stats_rnn
+ self.stats: List[torch.Tensor] = []
+ self.stats_rnn: List[RNNReturnType] = []
+
+ # name of the node which was responsible for adding this logger
+ # Note:
+ # - if we are logging node outputs, this is the same as prev_node_name
+ # - if we are logging node inputs, this is the name of the node
+ # whose input this logger is logging.
+ #
+ # example, where logger1 is logging input of op1 and logger2 is logging
+ # the output of op1:
+ #
+ # x1 -> logger1 -> op1 -> logger2 -> x2
+ #
+ # in this example,
+ # - logger1's prev_node_name is x1 and ref_node_name is op1
+ # - logger2's prev_node_name is op1 and ref_node_name is op1
+ self.ref_node_name = ref_node_name
+ # name of the node whose output this Logger is capturing
+ self.prev_node_name = prev_node_name
+
+ # name of the model from which the node originated from
+ self.model_name = model_name
+ # reference name, used to match loggers from separate models
+ # to each other
+ self.ref_name = ref_name
+ # type of the target of the node whose output this logger is logging
+ self.prev_node_target_type = prev_node_target_type
+ # type of the target of the node which was responsible for adding this
+ # logger
+ self.ref_node_target_type = ref_node_target_type
+ # what kind of values are inside of stats
+ self.results_type = results_type
+ # index of this node within the arg of the input/output node
+ # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
+ self.index_within_arg = index_within_arg
+ # index of this node within the args of the input/output node
+ # for example, in add(x1, x2), x2 would have index_of_arg == 1
+ self.index_of_arg = index_of_arg
+ # fully qualified name
+ self.fqn = fqn
+
+ # Note: cannot annotate the type of x because TorchScript does not support
+ # the Union type.
+ def forward(self, x):
+ # A tensor is detached and appended to stats. A 2-tuple whose second
+ # element has length 2 is detached element-wise and appended to
+ # stats_rnn. Any other input is not recorded. The input is always
+ # returned unchanged.
+ if isinstance(x, torch.Tensor):
+ self.stats.append(x.detach())
+ elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
+ new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
+ self.stats_rnn.append(new_res)
+ return x
+
+ # The repr lists the stored metadata; prev_node_target_type is not
+ # included in it.
+ def __repr__(self):
+ return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
+prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
+ref_node_target_type={self.ref_node_target_type}
+results_type={self.results_type}, index_within_arg={self.index_within_arg},
+index_of_arg={self.index_of_arg}, fqn={self.fqn})"""
+
+
+class NSTracer(quantize_fx.QuantizationTracer):
+    """
+    QuantizationTracer variant that additionally treats observer and
+    fake_quantize modules as leaf modules; every other module is decided by
+    the parent tracer.
+    """
+    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
+        always_leaf = (
+            torch.quantization.ObserverBase,
+            torch.quantization.FakeQuantizeBase,
+        )
+        if isinstance(m, always_leaf):
+            return True
+        return super().is_leaf_module(m, module_qualified_name)
+
+
+def _extract_weights_one_model(
+    model_name: str,
+    model: GraphModule,
+    nodes_and_names_to_instrument: List[Tuple[Node, str]],
+    results: NSResultsType,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
+) -> None:
+    """
+    Extract the weight of every listed node of ``model`` and store it in
+    ``results`` under results[ref_name][<weight results type>][model_name]
+    as a one-element list. Nodes for which extraction yields a falsy value
+    are skipped. ``results`` is mutated in place.
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
+    weight_type = NSSingleResultValuesType.WEIGHT.value
+    for node, ref_name in nodes_and_names_to_instrument:
+        weight = extract_weight_from_node(
+            node, model, op_to_type_to_weight_extraction_fn)
+        if not weight:
+            continue
+        per_ref = results.setdefault(ref_name, {weight_type: {}})
+        per_ref[weight_type][model_name] = [weight]
+
+
+def _extract_weights_impl(
+    model_name_a: str,
+    gm_a: GraphModule,
+    model_name_b: str,
+    gm_b: GraphModule,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
+) -> NSResultsType:
+    """
+    Match subgraphs between the two graph modules, extract the weights of the
+    base op node of each matched subgraph from both models, and return the
+    results rekeyed on the node names of model b.
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
+    matched_subgraph_pairs = get_matching_subgraph_pairs(
+        gm_a, gm_b, base_name_to_sets_of_related_ops,
+        unmatchable_types_map)
+
+    # one (base op node, match name) list per model
+    instrument_a: List[Tuple[Node, str]] = []
+    instrument_b: List[Tuple[Node, str]] = []
+    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
+        instrument_a.append((subgraph_a.base_op_node, match_name))
+        instrument_b.append((subgraph_b.base_op_node, match_name))
+
+    # fill the shared results dict, model a first and then model b
+    results: NSResultsType = {}
+    for name, gm, instrument in (
+        (model_name_a, gm_a, instrument_a),
+        (model_name_b, gm_b, instrument_b),
+    ):
+        _extract_weights_one_model(
+            name, gm, instrument, results,
+            op_to_type_to_weight_extraction_fn)
+
+    # fill in missing fqn entries
+    maybe_add_missing_fqns(results)
+
+    # rekey on names of nodes in gm_b
+    return rekey_logger_info_on_node_name_of_model(results, model_name_b)
+
+
+def extract_weights(
+    model_name_a: str,
+    model_a: nn.Module,
+    model_name_b: str,
+    model_b: nn.Module,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
+) -> NSResultsType:
+    """
+    Trace both models with NSTracer and extract the weights of their matching
+    subgraphs.
+
+    Args:
+        model_name_a: name to use for model a in the results
+        model_a: first model
+        model_name_b: name to use for model b in the results
+        model_b: second model
+        base_name_to_sets_of_related_ops: optional override of the related-op
+            sets used for matching; defaults to
+            get_base_name_to_sets_of_related_ops()
+        unmatchable_types_map: optional override, forwarded to the matcher
+        op_to_type_to_weight_extraction_fn: optional override, forwarded to
+            the weight extraction
+
+    Return:
+        NSResultsType as produced by _extract_weights_impl
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
+    if base_name_to_sets_of_related_ops is None:
+        base_name_to_sets_of_related_ops = \
+            get_base_name_to_sets_of_related_ops()
+    # The type_a_related_to_b mapping that used to be computed here was never
+    # read, so that computation has been dropped.
+
+    # TODO(future PR): expose these
+    skipped_module_names: List[str] = []
+    skipped_module_classes: List[Callable] = []
+    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
+    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
+    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
+    # carry the scope information over to the traced module when present
+    if hasattr(model_a, '_node_name_to_scope'):
+        gm_a._node_name_to_scope = model_a._node_name_to_scope
+    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
+    if hasattr(model_b, '_node_name_to_scope'):
+        gm_b._node_name_to_scope = model_b._node_name_to_scope
+    return _extract_weights_impl(
+        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
+        unmatchable_types_map, op_to_type_to_weight_extraction_fn)
+
+
+def _add_loggers_one_model(
+    model_name: str,
+    model: GraphModule,
+    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
+    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
+    logger_cls: Callable,
+) -> nn.Module:
+    """
+    Turn the (node, ref_name, ref_node_type) lists into node-keyed mappings
+    and hand them to add_loggers_to_model, returning the model it produces.
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")
+
+    # TODO(future PR): do not observe nodes we do not care
+    # about (both fp32, denylist, etc)
+    inputs_by_node: Dict[Node, Tuple[str, str]] = {
+        node: (ref_name, ref_node_type)
+        for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs
+    }
+    outputs_by_node: Dict[Node, Tuple[str, str]] = {
+        node: (ref_name, ref_node_type)
+        for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs
+    }
+
+    return add_loggers_to_model(
+        model, inputs_by_node, outputs_by_node, logger_cls, model_name)
+
+
+def _add_loggers_impl(
+    name_a: str,
+    gm_a: GraphModule,
+    name_b: str,
+    gm_b: GraphModule,
+    logger_cls: Callable,
+    should_log_inputs: bool,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> Tuple[nn.Module, nn.Module]:
+    """
+    Match subgraphs between the two graph modules and add loggers to both at
+    the matched locations. Outputs are always logged at each subgraph's end
+    node; inputs are logged at its start node when ``should_log_inputs`` is
+    set. Returns the instrumented (model a, model b) pair.
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
+    matched_subgraph_pairs = get_matching_subgraph_pairs(
+        gm_a, gm_b,
+        base_name_to_sets_of_related_ops, unmatchable_types_map)
+
+    inputs_a = []
+    inputs_b = []
+    outputs_a = []
+    outputs_b = []
+    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
+        type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
+        type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
+        if should_log_inputs:
+            # inputs are matched on start_node, e.g. the input of linear in
+            # linear-relu
+            inputs_a.append((subgraph_a.start_node, match_name, type_a))
+            inputs_b.append((subgraph_b.start_node, match_name, type_b))
+        # activations are always matched on end_node, e.g. the output of
+        # relu in linear-relu
+        outputs_a.append((subgraph_a.end_node, match_name, type_a))
+        outputs_b.append((subgraph_b.end_node, match_name, type_b))
+
+    new_model_a = _add_loggers_one_model(
+        name_a, gm_a, inputs_a, outputs_a, logger_cls)
+    new_model_b = _add_loggers_one_model(
+        name_b, gm_b, inputs_b, outputs_b, logger_cls)
+    return (new_model_a, new_model_b)
+
+
+def add_loggers(
+    name_a: str,
+    model_a: nn.Module,
+    name_b: str,
+    model_b: nn.Module,
+    logger_cls: Callable,
+    should_log_inputs : bool = False,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> Tuple[nn.Module, nn.Module]:
+    """
+    Trace both models with NSTracer and add loggers at their matching
+    subgraphs. Returns the instrumented (model a, model b) pair.
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
+    # TODO(future PR): expose these
+    skipped_module_names: List[str] = []
+    skipped_module_classes: List[Callable] = []
+
+    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
+    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
+
+    graph_a = tracer_a.trace(model_a)
+    gm_a = GraphModule(model_a, graph_a)
+    # carry the scope information over to the traced module when present
+    if hasattr(model_a, '_node_name_to_scope'):
+        gm_a._node_name_to_scope = model_a._node_name_to_scope
+
+    graph_b = tracer_b.trace(model_b)
+    gm_b = GraphModule(model_b, graph_b)
+    if hasattr(model_b, '_node_name_to_scope'):
+        gm_b._node_name_to_scope = model_b._node_name_to_scope
+
+    return _add_loggers_impl(
+        name_a, gm_a, name_b, gm_b, logger_cls,
+        should_log_inputs=should_log_inputs,
+        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
+        unmatchable_types_map=unmatchable_types_map)
+
+
+def _extract_logger_info_one_model(
+    model: nn.Module,
+    results: NSResultsType,
+    logger_cls: Callable,
+) -> None:
+    """
+    Collect the recorded values of every logger inside ``model`` into
+    ``results`` (mutated in place).
+
+    A module counts as a logger when it is an instance of ``logger_cls`` or a
+    scripted module whose original name is 'OutputLogger'. Each logger adds
+    one entry to results[ref_name][results_type][model_name]; that list is
+    kept ordered by (index_of_arg, index_within_arg).
+
+    Args:
+        model: model to scan for loggers
+        results: results dictionary to fill
+        logger_cls: logger class to look for
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
+    for gm_name, mod in model.named_modules():
+        # TODO(future PR): better check when scripted
+        is_logger = (
+            isinstance(mod, logger_cls)  # type: ignore[arg-type]
+            or (
+                isinstance(mod, torch.jit.RecursiveScriptModule)
+                and mod.original_name == 'OutputLogger'
+            )
+        )
+        if not is_logger:
+            continue
+
+        key = mod.ref_name
+        if key not in results:
+            results[key] = {}
+        assert mod.model_name not in results[key], \
+            f"{mod.model_name} is already present in results"
+        if mod.results_type not in results[key]:
+            results[key][mod.results_type] = {}
+        if mod.model_name not in results[key][mod.results_type]:
+            results[key][mod.results_type][mod.model_name] = []
+        # RNN-style tuple values take precedence when any were recorded
+        stats_to_use = mod.stats
+        if len(mod.stats_rnn) > 0:
+            stats_to_use = mod.stats_rnn
+        results[key][mod.results_type][mod.model_name].append({
+            'type': mod.results_type,
+            'values': stats_to_use,
+            'ref_node_name': mod.ref_node_name,
+            'ref_node_target_type': mod.ref_node_target_type,
+            'prev_node_name': mod.prev_node_name,
+            'prev_node_target_type': mod.prev_node_target_type,
+            'index_within_arg': mod.index_within_arg,
+            'index_of_arg': mod.index_of_arg,
+            'fqn': mod.fqn,
+        })
+        # Ensure the list stays sorted. The key is a numeric tuple: a
+        # formatted "a:b" string would sort lexicographically and put index
+        # 10 ahead of index 2.
+        results[key][mod.results_type][mod.model_name].sort(
+            key=lambda res: (res['index_of_arg'], res['index_within_arg'])
+        )
+
+
+# TODO(future PR): align on naming
+# this is the equivalent of just the comparison extraction part of `ns.compare_model_outputs`
+def extract_logger_info(
+    model_a: nn.Module,
+    model_b: nn.Module,
+    logger_cls: Callable,
+    model_name_to_use_for_layer_names: str,
+) -> NSResultsType:
+    """
+    Gathers the logger results of `model_a` and `model_b` into a single
+    results dictionary, fills in missing fqn entries, and rekeys the
+    dictionary using `model_name_to_use_for_layer_names`.
+
+    TODO(future PR): real docblock
+
+    Output format: NSResultsType
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
+    collected: NSResultsType = {}
+    # model_a is processed first, then model_b; both write into `collected`
+    _extract_logger_info_one_model(model_a, collected, logger_cls)
+    _extract_logger_info_one_model(model_b, collected, logger_cls)
+    # fill in missing fqn entries
+    maybe_add_missing_fqns(collected)
+    # rekey on the node names of the requested model
+    return rekey_logger_info_on_node_name_of_model(
+        collected, model_name_to_use_for_layer_names)
+
+
+def _add_shadow_loggers_impl(
+    name_a: str,
+    gm_a: GraphModule,
+    name_b: str,
+    gm_b: GraphModule,
+    logger_cls: Callable,
+    should_log_inputs: bool,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> nn.Module:
+    """
+    Matches the subgraphs of `gm_a` and `gm_b`, then returns the module
+    built by `create_a_shadows_b` from those matches.
+    TODO(future PR): real docblock
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
+    subgraph_pairs = get_matching_subgraph_pairs(
+        gm_a, gm_b, base_name_to_sets_of_related_ops,
+        unmatchable_types_map)
+    return create_a_shadows_b(
+        name_a, gm_a, name_b, gm_b, subgraph_pairs, logger_cls,
+        should_log_inputs=should_log_inputs,
+        node_type_to_io_type_map=node_type_to_io_type_map)
+
+
+def add_shadow_loggers(
+    name_a: str,
+    model_a: nn.Module,
+    name_b: str,
+    model_b: nn.Module,
+    logger_cls: Callable,
+    should_log_inputs: bool = False,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> nn.Module:
+    """
+    Traces `model_a` and `model_b` with `NSTracer`, wraps each trace in a
+    `GraphModule`, and forwards the two graph modules to
+    `_add_shadow_loggers_impl` to build the `a_shadows_b` model.
+    TODO(future PR): real docblock
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
+    # TODO(future PR): expose these
+    skipped_module_names: List[str] = []
+    skipped_module_classes: List[Callable] = []
+    # both tracers are created before any tracing happens
+    tracers = [
+        NSTracer(skipped_module_names, skipped_module_classes)
+        for _ in range(2)
+    ]
+    graph_modules = []
+    for tracer, model in zip(tracers, (model_a, model_b)):
+        gm = GraphModule(model, tracer.trace(model))
+        # carry the scope mapping over when the source model provides one
+        if hasattr(model, '_node_name_to_scope'):
+            gm._node_name_to_scope = model._node_name_to_scope
+        graph_modules.append(gm)
+    gm_a, gm_b = graph_modules
+    return _add_shadow_loggers_impl(
+        name_a, gm_a, name_b, gm_b, logger_cls,
+        should_log_inputs=should_log_inputs,
+        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
+        node_type_to_io_type_map=node_type_to_io_type_map,
+        unmatchable_types_map=unmatchable_types_map)
+
+
+def extract_shadow_logger_info(
+    model_a_shadows_b: nn.Module,
+    logger_cls: Callable,
+    model_name_to_use_for_layer_names: str,
+) -> NSResultsType:
+    """
+    Gathers the logger results of a single `a_shadows_b` model, fills in
+    missing fqn entries, and rekeys the results using
+    `model_name_to_use_for_layer_names`. The result is returned as a plain
+    dict.
+    TODO(future PR): real docblock
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
+    collected: NSResultsType = collections.defaultdict(dict)
+    _extract_logger_info_one_model(model_a_shadows_b, collected, logger_cls)
+    # fill in missing fqn entries
+    maybe_add_missing_fqns(collected)
+    # rekey on the node names of the requested model
+    rekeyed = rekey_logger_info_on_node_name_of_model(
+        collected, model_name_to_use_for_layer_names)
+    return dict(rekeyed)
+
+
+def extend_logger_results_with_comparison(
+    results: NSResultsType,
+    model_name_1: str,
+    model_name_2: str,
+    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    comparison_name: str,
+) -> None:
+    """
+    For every entry logged for `model_name_2`, finds the entry of
+    `model_name_1` with the same `index_within_arg` and `index_of_arg`,
+    applies `comparison_fn` to their values pairwise, and stores the list
+    of outputs on the `model_name_2` entry under `comparison_name`.
+    `results` is modified in place.
+    """
+    for results_type_to_results in results.values():
+        for model_name_to_results in results_type_to_results.values():
+            assert model_name_1 in model_name_to_results, \
+                f"{model_name_1} not found in results"
+            assert model_name_2 in model_name_to_results, \
+                f"{model_name_2} not found in results"
+
+            entries_1 = model_name_to_results[model_name_1]
+            entries_2 = model_name_to_results[model_name_2]
+
+            for entry_2 in entries_2:
+                within_arg = entry_2['index_within_arg']
+                of_arg = entry_2['index_of_arg']
+                # first entry of model 1 logged at the same position
+                entry_1 = next(
+                    (
+                        candidate for candidate in entries_1
+                        if candidate['index_within_arg'] == within_arg
+                        and candidate['index_of_arg'] == of_arg
+                    ),
+                    None,
+                )
+                assert entry_1 is not None
+
+                # the list is attached before it is filled, so outputs
+                # computed before a failing comparison stay recorded
+                entry_2[comparison_name] = []
+                entry_2[comparison_name].extend(
+                    comparison_fn(value_1, value_2)
+                    for value_1, value_2
+                    in zip(entry_1['values'], entry_2['values'])
+                )
NSSubgraph,
NSNodeTargetType,
)
-from torch.quantization.ns.mappings import (
+from torch.ao.ns.fx.mappings import (
get_node_type_to_io_type_map,
)
from torch.quantization.quantize import is_activation_post_process
import torch.nn.quantized as nnq
import torch.quantization
-import torch.quantization._numeric_suite as ns
+import torch.ao.ns._numeric_suite as ns
_supported_modules = {nn.Linear, nn.Conv2d}
_supported_modules_quantized = {nnq.Linear, nnq.Conv2d}
-import torch
-import torch.nn as nn
-import torch.nn.quantized as nnq
-import torch.nn.quantized.dynamic as nnqd
-from torch.quantization import prepare
-from typing import Dict, List, Optional, Any, Union, Callable, Set
-
-from .quantization_mappings import (
- get_default_compare_output_module_list,
+# flake8: noqa: F401
+r"""
+This file is in the process of migration to `torch/ao/ns`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please add it to
+`torch/ao/ns/_numeric_suite.py` and add an import statement
+here.
+"""
+
+from torch.ao.ns._numeric_suite import (
+ NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
+ _find_match,
+ compare_weights,
+ _get_logger_dict_helper,
+ get_logger_dict,
+ Logger,
+ ShadowLogger,
+ OutputLogger,
+ _convert_tuple_to_list,
+ _dequantize_tensor_list,
+ Shadow,
+ prepare_model_with_stubs,
+ _is_identical_module_type,
+ compare_model_stub,
+ get_matching_activations,
+ prepare_model_outputs,
+ compare_model_outputs,
)
-
-NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
- nnqd.Linear,
- nnq.Linear,
- nnqd.LSTM,
- nn.LSTM,
-}
-
-
-def _find_match(
- str_list: Union[Dict[str, Any], List[str]], key_str: str,
- postfix: str,
-) -> Optional[str]:
- split_str = key_str.split(".")
- if split_str[-1] == postfix:
- match_string = "".join(key_str.split(".")[0:-1])
- for s2 in str_list:
- pattern1 = "".join(s2.split(".")[0:-1])
- pattern2 = "".join(s2.split(".")[0:-2])
- if match_string == pattern1:
- return s2
- if match_string == pattern2:
- return s2
-
- # For matching "fc.weight" and "fc._packed_params._packed_params"
- if postfix == "_packed_params":
- match_string = "".join(key_str.split(".")[0:-2])
- if len(match_string) == 0:
- return None
- for s2 in str_list:
- pattern1 = "".join(s2.split(".")[0:-1])
- pattern2 = "".join(s2.split(".")[0:-2])
- if match_string == pattern1:
- return s2
- if match_string == pattern2:
- return s2
- return None
- else:
- return None
-
-
-def compare_weights(
- float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
-) -> Dict[str, Dict[str, torch.Tensor]]:
- r"""Compare the weights of the float module with its corresponding quantized
- module. Return a dict with key corresponding to module names and each entry being
- a dictionary with two keys 'float' and 'quantized', containing the float and
- quantized weights. This dict can be used to compare and compute the quantization
- error of the weights of float and quantized models.
-
- Example usage:
- wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())
- for key in wt_compare_dict:
- print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))
-
- Args:
- float_dict: state dict of the float model
- quantized_dict: state dict of the quantized model
-
- Return:
- weight_dict: dict with key corresponding to module names and each entry being
- a dictionary with two keys 'float' and 'quantized', containing the float and
- quantized weights
- """
- torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
- weight_dict: Dict[str, Dict] = {}
- for key in quantized_dict:
- match_key = _find_match(float_dict, key, "weight")
- if match_key is not None:
- weight_dict[key] = {}
- weight_dict[key]["float"] = float_dict[match_key]
- weight_dict[key]["quantized"] = quantized_dict[key]
- continue
-
- # For matching "fc.weight" and "fc._packed_params._packed_params"
- match_key = _find_match(float_dict, key, "_packed_params")
- if match_key is not None:
- weight_dict[key] = {}
- weight_dict[key]["float"] = float_dict[match_key]
- weight_dict[key]["quantized"] = quantized_dict[key][0]
-
- # For LSTM
- split_str = key.split(".")
- if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
- layer = split_str[-2]
- module_name = ".".join(split_str[:-3])
- float_weight_ih_key = module_name + ".weight_ih_l" + layer
- float_weight_hh_key = module_name + ".weight_hh_l" + layer
- if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
- weight_dict[key] = {}
- weight_dict[key]["float"] = float_dict[float_weight_ih_key]
- weight_dict[key]["quantized"] = (
- quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
- )
- weight_dict[key]["float"] = float_dict[float_weight_hh_key]
- weight_dict[key]["quantized"] = (
- quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
- )
-
- return weight_dict
-
-
-def _get_logger_dict_helper(
- mod: nn.Module, target_dict: Dict[str, Any],
- prefix: str = "",
-) -> None:
- r"""This is the helper function for get_logger_dict
-
- Args:
- mod: module we want to save all logger stats
- prefix: prefix for the current module
- target_dict: the dictionary used to save all logger stats
- """
-
- def get_prefix(prefix):
- return prefix if prefix == "" else prefix + "."
-
- for name, child in mod.named_children():
- if isinstance(child, Logger):
- target_dict[get_prefix(prefix) + "stats"] = child.stats
- break
-
- for name, child in mod.named_children():
- module_prefix = get_prefix(prefix) + name if prefix else name
- _get_logger_dict_helper(child, target_dict, module_prefix)
-
-
-def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
- r"""Traverse the modules and save all logger stats into target dict.
- This is mainly used for quantization accuracy debug.
-
- Type of loggers supported:
- ShadowLogger: used to log the outputs of the quantized module and its
- matching float shadow module,
- OutputLogger: used to log the outputs of the modules
-
- Args:
- mod: module we want to save all logger stats
- prefix: prefix for the current module
-
- Return:
- target_dict: the dictionary used to save all logger stats
- """
- torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")
-
- target_dict: Dict[str, Dict] = {}
- _get_logger_dict_helper(mod, target_dict, prefix)
- return target_dict
-
-
-class Logger(nn.Module):
- r"""Base class for stats logging
- """
-
- def __init__(self):
- super(Logger, self).__init__()
- self.stats = {}
- # We only insert observer if the op is quantized with static quantization,
- # which is identified by activation_observer.dtype == quint8. This is needed
- # when attaching Logger as observer for FX mode
- self.dtype = torch.quint8
-
- def forward(self, x):
- pass
-
-
-class ShadowLogger(Logger):
- r"""Class used in Shadow module to record the outputs of the original and
- shadow modules.
- """
-
- def __init__(self):
- super(ShadowLogger, self).__init__()
- self.stats["float"] = []
- self.stats["quantized"] = []
-
- def forward(self, x, y):
- if len(x) > 1:
- x = x[0]
- if len(y) > 1:
- y = y[0]
- self.stats["quantized"].append(x.detach())
- self.stats["float"].append(y.detach())
-
-
-class OutputLogger(Logger):
- r"""Class used to log the outputs of the module
- """
-
- def __init__(self):
- super(OutputLogger, self).__init__()
- self.stats["tensor_val"] = []
-
-
- def forward(self, x):
- self.stats["tensor_val"].append(x)
- return x
-
-
-def _convert_tuple_to_list(t: Any) -> Any:
- return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t
-
-
-def _dequantize_tensor_list(t: Any) -> Any:
- return (
- list(_dequantize_tensor_list(x) for x in t)
- if type(t) is list
- else t.dequantize()
- if t.is_quantized
- else t
- )
-
-
-class Shadow(nn.Module):
- r"""Shadow module attaches the float module to its matching quantized module
- as the shadow. Then it uses Logger module to process the outputs of both
- modules.
-
- Args:
- q_module: module quantized from float_module that we want to shadow
- float_module: float module used to shadow q_module
- logger_cls: type of logger used to process the outputs of q_module and
- float_module. ShadowLogger or custom loggers can be used.
- """
-
- def __init__(self, q_module, float_module, logger_cls):
- super(Shadow, self).__init__()
- self.orig_module = q_module
- self.shadow_module = float_module
- self.dequant = nnq.DeQuantize()
- self.logger = logger_cls()
-
- def forward(self, *x) -> torch.Tensor:
- xl = _convert_tuple_to_list(x)
- output = self.orig_module(*xl)
- xl_float = _dequantize_tensor_list(xl)
- shadow_output = self.shadow_module(*xl_float)
- self.logger(output, shadow_output)
- return output
-
- def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
- output = self.orig_module.add(x, y)
- x = x.dequantize()
- y = y.dequantize()
- shadow_output = self.shadow_module.add(x, y)
- self.logger(output, shadow_output)
- return output
-
- def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
- output = self.orig_module.add_scalar(x, y)
- x = x.dequantize()
- shadow_output = self.shadow_module.add_scalar(x, y)
- self.logger(output, shadow_output)
- return output
-
- def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
- output = self.orig_module.mul(x, y)
- x = x.dequantize()
- y = y.dequantize()
- shadow_output = self.shadow_module.mul(x, y)
- self.logger(output, shadow_output)
- return output
-
- def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
- output = self.orig_module.mul_scalar(x, y)
- x = x.dequantize()
- shadow_output = self.shadow_module.mul_scalar(x, y)
- self.logger(output, shadow_output)
- return output
-
- def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
- output = self.orig_module.cat(x, dim)
- x = [y.dequantize() for y in x]
- shadow_output = self.shadow_module.cat(x, dim)
- self.logger(output, shadow_output)
- return output
-
- def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
- output = self.orig_module.add_relu(x, y)
- x = x.dequantize()
- y = y.dequantize()
- shadow_output = self.shadow_module.add_relu(x, y)
- self.logger(output, shadow_output)
- return output
-
-
-def prepare_model_with_stubs(
- float_module: nn.Module, q_module: nn.Module,
- module_swap_list: Set[type], logger_cls: Callable,
-) -> None:
- r"""Prepare the model by attaching the float module to its matching quantized
- module as the shadow if the float module type is in module_swap_list.
-
- Example usage:
- prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
- q_model(data)
- ob_dict = get_logger_dict(q_model)
-
- Args:
- float_module: float module used to generate the q_module
- q_module: module quantized from float_module
- module_swap_list: list of float module types to attach the shadow
- logger_cls: type of logger to be used in shadow module to process the outputs of
- quantized module and its float shadow module
- """
- torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")
-
- float_module_children = {}
- for name, mod in float_module.named_children():
- float_module_children[name] = mod
-
- reassign = {}
- for name, mod in q_module.named_children():
-
- if name not in float_module_children:
- continue
-
- float_mod = float_module_children[name]
-
- if type(float_mod) not in module_swap_list:
- prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls)
-
- # Insert shadow module only if the module is not of the same type as
- # the floating point module
- if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod):
- reassign[name] = Shadow(mod, float_mod, logger_cls)
-
- for key, value in reassign.items():
- q_module._modules[key] = value
-
-def _is_identical_module_type(mod1, mod2):
- # Compare if two modules have the same dtype
- mod1_module_types = [type(mod) for mod in mod1.modules()]
- mod2_module_types = [type(mod) for mod in mod2.modules()]
- return mod1_module_types == mod2_module_types
-
-
-
-def compare_model_stub(
- float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
- *data, logger_cls=ShadowLogger
-) -> Dict[str, Dict]:
- r"""Compare quantized module in a model with its floating point counterpart,
- feeding both of them the same input. Return a dict with key corresponding to
- module names and each entry being a dictionary with two keys 'float' and
- 'quantized', containing the output tensors of quantized and its matching
- float shadow module. This dict can be used to compare and compute the module
- level quantization error.
-
- This function first call prepare_model_with_stubs() to swap the quantized
- module that we want to compare with the Shadow module, which takes quantized
- module, corresponding float module and logger as input, and creates a forward
- path inside to make the float module to shadow quantized module sharing the
- same input. The logger can be customizable, default logger is ShadowLogger
- and it will save the outputs of the quantized module and float module that
- can be used to compute the module level quantization error.
-
- Example usage:
- module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
- ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
- for key in ob_dict:
- print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
-
- Args:
- float_model: float model used to generate the q_model
- q_model: model quantized from float_model
- module_swap_list: list of float module types at which shadow modules will
- be attached.
- data: input data used to run the prepared q_model
- logger_cls: type of logger to be used in shadow module to process the outputs of
- quantized module and its float shadow module
- """
- torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
- prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
- q_model(*data)
- ob_dict = get_logger_dict(q_model)
- return ob_dict
-
-
-def get_matching_activations(
- float_module: nn.Module, q_module: nn.Module,
-) -> Dict[str, Dict[str, torch.Tensor]]:
- r"""Find the matching activation between float and quantized modules.
-
- Args:
- float_module: float module used to generate the q_module
- q_module: module quantized from float_module
-
- Return:
- act_dict: dict with key corresponding to quantized module names and each
- entry being a dictionary with two keys 'float' and 'quantized', containing
- the matching float and quantized activations
- """
- torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
- float_dict = get_logger_dict(float_module)
- quantized_dict = get_logger_dict(q_module)
- act_dict: Dict[str, Dict] = {}
- for key in quantized_dict:
- match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
- if match_key is not None:
- act_dict[key] = {}
- act_dict[key]["float"] = float_dict[match_key]["tensor_val"]
- act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"]
- return act_dict
-
-
-def prepare_model_outputs(
- float_module: nn.Module,
- q_module: nn.Module,
- logger_cls=OutputLogger,
- allow_list=None
-) -> None:
- r"""Prepare the model by attaching the logger to both float module
- and quantized module if they are in the allow_list.
-
- Args:
- float_module: float module used to generate the q_module
- q_module: module quantized from float_module
- logger_cls: type of logger to be attached to float_module and q_module
- allow_list: list of module types to attach logger
- """
- torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
- if allow_list is None:
- allow_list = get_default_compare_output_module_list()
-
- qconfig_debug = torch.quantization.QConfig(activation=logger_cls, weight=None)
- float_module.qconfig = qconfig_debug # type: ignore[assignment]
- prepare(float_module, inplace=True, allow_list=allow_list)
- q_module.qconfig = qconfig_debug # type: ignore[assignment]
- prepare(
- q_module,
- inplace=True,
- allow_list=allow_list,
- observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
- )
-
-
-def compare_model_outputs(
- float_model: nn.Module,
- q_model: nn.Module,
- *data,
- logger_cls=OutputLogger,
- allow_list=None
-) -> Dict[str, Dict[str, torch.Tensor]]:
- r"""Compare output activations between float and quantized models at
- corresponding locations for the same input. Return a dict with key corresponding
- to quantized module names and each entry being a dictionary with two keys
- 'float' and 'quantized', containing the activations of quantized model and
- float model at matching locations. This dict can be used to compare and
- compute the propagation quantization error.
-
- Example usage:
- act_compare_dict = compare_model_outputs(float_model, qmodel, data)
- for key in act_compare_dict:
- print(key, compute_error(act_compare_dict[key]['float'], act_compare_dict[key]['quantized'].dequantize()))
-
- Args:
- float_model: float model used to generate the q_model
- q_model: model quantized from float_model
- data: input data used to run the prepared float_model and q_model
- logger_cls: type of logger to be attached to float_module and q_module
- allow_list: list of module types to attach logger
-
- Return:
- act_compare_dict: dict with key corresponding to quantized module names
- and each entry being a dictionary with two keys 'float' and 'quantized',
- containing the matching float and quantized activations
- """
- torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
- if allow_list is None:
- allow_list = get_default_compare_output_module_list()
- prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
- float_model(*data)
- q_model(*data)
- act_compare_dict = get_matching_activations(float_model, q_model)
- return act_compare_dict
-import collections
-
-import torch
-import torch.nn as nn
-import torch.quantization.quantize_fx as quantize_fx
-from torch.fx import GraphModule
-from torch.fx.graph import Node
-from torch.quantization.ns.mappings import (
- get_base_name_to_sets_of_related_ops,
-)
-from torch.quantization.ns.graph_matcher import (
- get_matching_subgraph_pairs,
- get_type_a_related_to_b,
-)
-
-from .ns.weight_utils import (
- extract_weight_from_node,
-)
-
-from .ns.graph_passes import (
- add_loggers_to_model,
- create_a_shadows_b,
-)
-
-from .ns.utils import (
- rekey_logger_info_on_node_name_of_model,
- maybe_add_missing_fqns,
- get_target_type_str,
+# flake8: noqa: F401
+r"""
+This file is in the process of migration to `torch/ao/ns`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please add it to
+`torch/ao/ns/_numeric_suite_fx.py` and add an import statement
+here.
+"""
+
+from torch.ao.ns._numeric_suite_fx import (
+ RNNReturnType,
+ OutputLogger,
+ NSTracer,
+ _extract_weights_one_model,
+ _extract_weights_impl,
+ extract_weights,
+ _add_loggers_one_model,
+ _add_loggers_impl,
+ add_loggers,
+ _extract_logger_info_one_model,
+ extract_logger_info,
+ _add_shadow_loggers_impl,
+ add_shadow_loggers,
+ extract_shadow_logger_info,
+ extend_logger_results_with_comparison,
)
-
-from .ns.ns_types import (
- NSSingleResultValuesType,
- NSResultsType,
- NSNodeTargetType,
-)
-
-from typing import Dict, Tuple, Callable, List, Optional, Set
-
-RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-
-class OutputLogger(nn.Module):
- stats: List[torch.Tensor]
- stats_rnn: List[RNNReturnType]
-
- def __init__(
- self,
- ref_node_name: str,
- prev_node_name: str,
- model_name: str,
- ref_name: str,
- prev_node_target_type: str,
- ref_node_target_type: str,
- results_type: str,
- index_within_arg: int,
- index_of_arg: int,
- fqn: Optional[str],
- ):
- super().__init__()
- self.stats: List[torch.Tensor] = []
- self.stats_rnn: List[RNNReturnType] = []
-
- # name of the node which was responsible for adding this logger
- # Note:
- # - if we are logging node outputs, this is the same as prev_node_name
- # - if we are logging node inputs, this is the name of the node
- # whose input this logger is logging.
- #
- # example, where logger1 is logging input of op1 and logger2 is logging
- # the output of op1:
- #
- # x1 -> logger1 -> op1 -> logger2 -> x2
- #
- # in this example,
- # - logger1's prev_node_name is x1 and ref_node_name is op1
- # - logger2's prev_node_name is op1 and ref_node_name is op1
- self.ref_node_name = ref_node_name
- # name of the node whose output this Logger is capturing
- self.prev_node_name = prev_node_name
-
- # name of the model from which the node originated from
- self.model_name = model_name
- # reference name, used to match loggers from separate models
- # to each other
- self.ref_name = ref_name
- # type of the target of the node whose output this logger is logging
- self.prev_node_target_type = prev_node_target_type
- # type of the target of the node which was respondible for adding this
- # logger
- self.ref_node_target_type = ref_node_target_type
- # what kind of values are inside of stats
- self.results_type = results_type
- # index of this node within the arg of the input/output node
- # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
- self.index_within_arg = index_within_arg
- # index of this node within the args of the input/output node
- # for example, in add(x1, x2), x2 would have index_of_arg == 1
- self.index_of_arg = index_of_arg
- # fully qualified name
- self.fqn = fqn
-
- # Note: cannot annotate the type of x because TorchScript does not support
- # the Union type.
- def forward(self, x):
- if isinstance(x, torch.Tensor):
- self.stats.append(x.detach())
- elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
- new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
- self.stats_rnn.append(new_res)
- return x
-
- def __repr__(self):
- return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
-prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
-ref_node_target_type={self.ref_node_target_type}
-results_type={self.results_type}, index_within_arg={self.index_within_arg},
-index_of_arg={self.index_of_arg}, fqn={self.fqn})"""
-
-
-class NSTracer(quantize_fx.QuantizationTracer):
- """
- Just like a regular tracer, but treats observers and fake_quantize
- modules as leaf modules.
- """
- def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
- if isinstance(m, torch.quantization.ObserverBase):
- return True
- elif isinstance(m, torch.quantization.FakeQuantizeBase):
- return True
- return super().is_leaf_module(m, module_qualified_name)
-
-
-def _extract_weights_one_model(
- model_name: str,
- model: GraphModule,
- nodes_and_names_to_instrument: List[Tuple[Node, str]],
- results: NSResultsType,
- op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
-) -> None:
- torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
- for node, ref_name in nodes_and_names_to_instrument:
- res_type = NSSingleResultValuesType.WEIGHT.value
- extracted_weight = extract_weight_from_node(
- node, model, op_to_type_to_weight_extraction_fn)
- if extracted_weight:
- if ref_name not in results:
- results[ref_name] = {res_type: {}}
- results[ref_name][res_type][model_name] = [extracted_weight]
-
-
-def _extract_weights_impl(
- model_name_a: str,
- gm_a: GraphModule,
- model_name_b: str,
- gm_b: GraphModule,
- base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
- unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
- op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
-) -> NSResultsType:
- torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
- matched_subgraph_pairs = get_matching_subgraph_pairs(
- gm_a, gm_b, base_name_to_sets_of_related_ops,
- unmatchable_types_map)
-
- # split the subgraph pairs into one data structure for each model
- nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
- nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
- for match_name, match in matched_subgraph_pairs.items():
- subgraph_a, subgraph_b = match
- nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
- nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))
-
- # populate the results, one model at a time
- results: NSResultsType = {}
- _extract_weights_one_model(
- model_name_a, gm_a, nodes_and_names_to_instrument_a, results,
- op_to_type_to_weight_extraction_fn)
- _extract_weights_one_model(
- model_name_b, gm_b, nodes_and_names_to_instrument_b, results,
- op_to_type_to_weight_extraction_fn)
-
- # fill in missing fqn entries
- maybe_add_missing_fqns(results)
-
- # rekey on names of nodes in gm_b
- results = rekey_logger_info_on_node_name_of_model(results, model_name_b)
-
- return results
-
-
-def extract_weights(
- model_name_a: str,
- model_a: nn.Module,
- model_name_b: str,
- model_b: nn.Module,
- base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
- unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
- op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
-) -> NSResultsType:
- torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
- if base_name_to_sets_of_related_ops is None:
- base_name_to_sets_of_related_ops = \
- get_base_name_to_sets_of_related_ops()
- type_a_related_to_b = \
- get_type_a_related_to_b(base_name_to_sets_of_related_ops)
-
- # TODO(future PR): expose these
- skipped_module_names: List[str] = []
- skipped_module_classes: List[Callable] = []
- tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
- tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
- gm_a = GraphModule(model_a, tracer_a.trace(model_a))
- if hasattr(model_a, '_node_name_to_scope'):
- gm_a._node_name_to_scope = model_a._node_name_to_scope
- gm_b = GraphModule(model_b, tracer_b.trace(model_b))
- if hasattr(model_b, '_node_name_to_scope'):
- gm_b._node_name_to_scope = model_b._node_name_to_scope
- return _extract_weights_impl(
- model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
- unmatchable_types_map, op_to_type_to_weight_extraction_fn)
-
-
-def _add_loggers_one_model(
- model_name: str,
- model: GraphModule,
- nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
- nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
- logger_cls: Callable,
-) -> nn.Module:
- torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")
-
- # TODO(future PR): do not observe nodes we do not care
- # about (both fp32, denylist, etc)
- node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
- node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
- for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
- node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
- for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
- node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)
-
- model = add_loggers_to_model(
- model, node_to_instrument_inputs_to_ref_name,
- node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
- return model
-
-
-def _add_loggers_impl(
- name_a: str,
- gm_a: GraphModule,
- name_b: str,
- gm_b: GraphModule,
- logger_cls: Callable,
- should_log_inputs: bool,
- base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
- unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-) -> Tuple[nn.Module, nn.Module]:
- torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
- matched_subgraph_pairs = get_matching_subgraph_pairs(
- gm_a, gm_b,
- base_name_to_sets_of_related_ops, unmatchable_types_map)
- nodes_and_names_to_instrument_inputs_a = []
- nodes_and_names_to_instrument_inputs_b = []
- nodes_and_names_to_instrument_outputs_a = []
- nodes_and_names_to_instrument_outputs_b = []
- for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
- ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
- ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
- # Note: for matching inputs we use start_node, such as observing
- # the input of linear in linear-relu
- if should_log_inputs:
- nodes_and_names_to_instrument_inputs_a.append(
- (subgraph_a.start_node, match_name, ref_node_type_a))
- nodes_and_names_to_instrument_inputs_b.append(
- (subgraph_b.start_node, match_name, ref_node_type_b))
- # Note: for matching activations we always use end_node,
- # such as observing the output of relu in linear-relu
- nodes_and_names_to_instrument_outputs_a.append(
- (subgraph_a.end_node, match_name, ref_node_type_a))
- nodes_and_names_to_instrument_outputs_b.append(
- (subgraph_b.end_node, match_name, ref_node_type_b))
-
- new_model_a = _add_loggers_one_model(
- name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
- nodes_and_names_to_instrument_outputs_a, logger_cls)
- new_model_b = _add_loggers_one_model(
- name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
- nodes_and_names_to_instrument_outputs_b, logger_cls)
- return (new_model_a, new_model_b)
-
-
-def add_loggers(
- name_a: str,
- model_a: nn.Module,
- name_b: str,
- model_b: nn.Module,
- logger_cls: Callable,
- should_log_inputs : bool = False,
- base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
- unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-) -> Tuple[nn.Module, nn.Module]:
- torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
- # TODO(future PR): expose these
- skipped_module_names: List[str] = []
- skipped_module_classes: List[Callable] = []
- tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
- tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
- gm_a = GraphModule(model_a, tracer_a.trace(model_a))
- if hasattr(model_a, '_node_name_to_scope'):
- gm_a._node_name_to_scope = model_a._node_name_to_scope
- gm_b = GraphModule(model_b, tracer_b.trace(model_b))
- if hasattr(model_b, '_node_name_to_scope'):
- gm_b._node_name_to_scope = model_b._node_name_to_scope
- return _add_loggers_impl(
- name_a, gm_a, name_b, gm_b, logger_cls,
- should_log_inputs=should_log_inputs,
- base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
- unmatchable_types_map=unmatchable_types_map)
-
-
-def _extract_logger_info_one_model(
- model: nn.Module,
- results: NSResultsType,
- logger_cls: Callable,
-) -> None:
- torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
- for gm_name, mod in model.named_modules():
- # TODO(future PR): better check when scripted
- is_logger = (
- isinstance(mod, logger_cls) # type: ignore[arg-type]
- or (
- isinstance(mod, torch.jit.RecursiveScriptModule)
- and mod.original_name == 'OutputLogger'
- )
- )
- if is_logger:
- key = mod.ref_name
- if key not in results:
- results[key] = {}
- assert mod.model_name not in results[key], \
- f"{mod.model_name} is already present in results"
- if mod.results_type not in results[key]:
- results[key][mod.results_type] = {}
- if mod.model_name not in results[key][mod.results_type]:
- results[key][mod.results_type][mod.model_name] = []
- stats_to_use = mod.stats
- if len(mod.stats_rnn) > 0:
- stats_to_use = mod.stats_rnn
- results[key][mod.results_type][mod.model_name].append({
- 'type': mod.results_type,
- 'values': stats_to_use,
- 'ref_node_name': mod.ref_node_name,
- 'ref_node_target_type': mod.ref_node_target_type,
- 'prev_node_name': mod.prev_node_name,
- 'prev_node_target_type': mod.prev_node_target_type,
- 'index_within_arg': mod.index_within_arg,
- 'index_of_arg': mod.index_of_arg,
- 'fqn': mod.fqn,
- })
- # ensure the list stays sorted
- results[key][mod.results_type][mod.model_name].sort(
- key=lambda res:
- f"{res['index_of_arg']}:{res['index_within_arg']}"
- )
-
-
-# TODO(future PR): align on naming
-# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
-def extract_logger_info(
- model_a: nn.Module,
- model_b: nn.Module,
- logger_cls: Callable,
- model_name_to_use_for_layer_names: str,
-) -> NSResultsType:
- """
- Same thing as ns.extract_logger_info, but for models prepared with
- this module.
-
- TODO(future PR): real docblock
-
- Output format: NSResultsType
- """
- torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
- results: NSResultsType = {}
- for model in (model_a, model_b):
- _extract_logger_info_one_model(model, results, logger_cls)
- # fill in missing fqn entries
- maybe_add_missing_fqns(results)
- # rekey on the name of model b
- results = rekey_logger_info_on_node_name_of_model(
- results, model_name_to_use_for_layer_names)
- return results
-
-
-def _add_shadow_loggers_impl(
- name_a: str,
- gm_a: GraphModule,
- name_b: str,
- gm_b: GraphModule,
- logger_cls: Callable,
- should_log_inputs: bool,
- base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
- node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
- unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-) -> nn.Module:
- torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
- matched_subgraph_pairs = get_matching_subgraph_pairs(
- gm_a, gm_b, base_name_to_sets_of_related_ops,
- unmatchable_types_map)
- gm_a_shadows_b = create_a_shadows_b(
- name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
- should_log_inputs=should_log_inputs,
- node_type_to_io_type_map=node_type_to_io_type_map)
- return gm_a_shadows_b
-
-
-def add_shadow_loggers(
- name_a: str,
- model_a: nn.Module,
- name_b: str,
- model_b: nn.Module,
- logger_cls: Callable,
- should_log_inputs: bool = False,
- base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
- node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
- unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-) -> nn.Module:
- """
- Same thing as add_loggers, but for an `a_shadows_b` model.
- TODO(future PR): real docblock
- """
- torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
- # TODO(future PR): expose these
- skipped_module_names: List[str] = []
- skipped_module_classes: List[Callable] = []
- tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
- tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
- gm_a = GraphModule(model_a, tracer_a.trace(model_a))
- if hasattr(model_a, '_node_name_to_scope'):
- gm_a._node_name_to_scope = model_a._node_name_to_scope
- gm_b = GraphModule(model_b, tracer_b.trace(model_b))
- if hasattr(model_b, '_node_name_to_scope'):
- gm_b._node_name_to_scope = model_b._node_name_to_scope
- return _add_shadow_loggers_impl(
- name_a, gm_a, name_b, gm_b, logger_cls,
- should_log_inputs=should_log_inputs,
- base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
- node_type_to_io_type_map=node_type_to_io_type_map,
- unmatchable_types_map=unmatchable_types_map)
-
-
-def extract_shadow_logger_info(
- model_a_shadows_b: nn.Module,
- logger_cls: Callable,
- model_name_to_use_for_layer_names: str,
-) -> NSResultsType:
- """
- Same thing as extract_logger_info, but for an `a_shadows_b` model.
- TODO(future PR): real docblock
- """
- torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
- results: NSResultsType = collections.defaultdict(dict)
- _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
- # fill in missing fqn entries
- maybe_add_missing_fqns(results)
- # rekey on the name of model b
- results = rekey_logger_info_on_node_name_of_model(
- results, model_name_to_use_for_layer_names)
- return dict(results)
-
-
-def extend_logger_results_with_comparison(
- results: NSResultsType,
- model_name_1: str,
- model_name_2: str,
- comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
- comparison_name: str,
-) -> None:
- """
- Compares the logged values from `model_name_2` against the corresponding
- values in `model_name_1`, using `comparison_fn`. Records the result
- in `model_name_2`'s results under `comparison_name`.
- """
- for _, results_type_to_results in results.items():
- for _, model_name_to_results in results_type_to_results.items():
- assert model_name_1 in model_name_to_results, \
- f"{model_name_1} not found in results"
- assert model_name_2 in model_name_to_results, \
- f"{model_name_2} not found in results"
-
- results_1 = model_name_to_results[model_name_1]
- results_2 = model_name_to_results[model_name_2]
-
- for result_2 in results_2:
- index_within_arg_2 = result_2['index_within_arg']
- index_of_arg_2 = result_2['index_of_arg']
- # find corresponding result_1
- result_1 = None
- for cur_result_1 in results_1:
- index_within_arg_1 = cur_result_1['index_within_arg']
- index_of_arg_1 = cur_result_1['index_of_arg']
- if (
- (index_within_arg_1 == index_within_arg_2) and
- (index_of_arg_1 == index_of_arg_2)
- ):
- result_1 = cur_result_1
- break
- assert result_1 is not None
-
- values_1 = result_1['values']
- values_2 = result_2['values']
- result_2[comparison_name] = []
- for value_1, value_2 in zip(values_1, values_2):
- comparison_result = comparison_fn(value_1, value_2)
- result_2[comparison_name].append(comparison_result)
model_b: A quantized model
x: Inputs to use during calibration
"""
- import torch.quantization._numeric_suite_fx as ns
- from torch.quantization.ns.mappings import get_unmatchable_types_map
+ import torch.ao.ns._numeric_suite_fx as ns
+ from torch.ao.ns.fx.mappings import get_unmatchable_types_map
unmatchable_types_map = get_unmatchable_types_map()
unmatchable_types_map["funs_unmatchable"].add(torch.mul)
ns.extend_logger_results_with_comparison(
activation_comparison_dict,
'fp32', 'int8',
- torch.quantization.ns.utils.compute_sqnr, 'sqnr'
+ torch.ao.ns.fx.utils.compute_sqnr, 'sqnr'
)
# Construct a dictionary mapping layer names to the SQNR values
prepare_qat_fx,
convert_fx,
)
- from torch.quantization.ns.ns_types import NSSingleResultValuesType, NSSubgraph
+ from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph
from torch.fx.graph import Node
from torch.fx import GraphModule
HAS_FX = True