torch.ao migration: numeric suite, eager and fx (#64817)
authorVasiliy Kuznetsov <vasiliy@fb.com>
Sun, 12 Sep 2021 18:59:44 +0000 (11:59 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sun, 12 Sep 2021 19:00:45 +0000 (12:00 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64817

This migrates `torch.quantization._numeric_suite` to `torch.ao.ns._numeric_suite`, and `torch.quantization._numeric_suite_fx` to `torch.ao.ns._numeric_suite_fx`.

1. move the files
```
HG: move eager mode
hg mv caffe2/torch/quantization/_numeric_suite.py caffe2/torch/ao/ns/
HG: move fx
hg mv caffe2/torch/quantization/_numeric_suite_fx.py caffe2/torch/ao/ns/
hg mv caffe2/torch/quantization/ns/* caffe2/torch/ao/ns/fx/
```

2. create new versions of `_numeric_suite.py` and `_numeric_suite_fx.py` with
imports

3. update all FB callsites

Test Plan: buck test mode/dev //caffe2/test:quantization

Reviewed By: z-a-f

Differential Revision: D30867538

fbshipit-source-id: 120ee830434ca490c1183a187a518eebcbbaf22c

19 files changed:
test/quantization/eager/test_bias_correction_eager.py
test/quantization/eager/test_numeric_suite_eager.py
test/quantization/fx/test_numeric_suite_fx.py
torch/ao/ns/__init__.py [moved from torch/quantization/ns/__init__.py with 100% similarity]
torch/ao/ns/_numeric_suite.py [new file with mode: 0644]
torch/ao/ns/_numeric_suite_fx.py [new file with mode: 0644]
torch/ao/ns/fx/__init__.py [new file with mode: 0644]
torch/ao/ns/fx/graph_matcher.py [moved from torch/quantization/ns/graph_matcher.py with 100% similarity]
torch/ao/ns/fx/graph_passes.py [moved from torch/quantization/ns/graph_passes.py with 99% similarity]
torch/ao/ns/fx/mappings.py [moved from torch/quantization/ns/mappings.py with 100% similarity]
torch/ao/ns/fx/ns_types.py [moved from torch/quantization/ns/ns_types.py with 100% similarity]
torch/ao/ns/fx/pattern_utils.py [moved from torch/quantization/ns/pattern_utils.py with 100% similarity]
torch/ao/ns/fx/utils.py [moved from torch/quantization/ns/utils.py with 100% similarity]
torch/ao/ns/fx/weight_utils.py [moved from torch/quantization/ns/weight_utils.py with 100% similarity]
torch/quantization/_correct_bias.py
torch/quantization/_numeric_suite.py
torch/quantization/_numeric_suite_fx.py
torch/quantization/fx/_equalize.py
torch/testing/_internal/common_quantization.py

index aeb024a..a8c7289 100644 (file)
@@ -5,7 +5,7 @@ from torch.testing._internal.common_quantization import skipIfNoFBGEMM
 
 from torch.quantization import default_qconfig
 from torch.quantization import QuantWrapper
-import torch.quantization._numeric_suite as ns
+import torch.ao.ns._numeric_suite as ns
 
 from torch.quantization._correct_bias import (
     _supported_modules,
index a2d00ca..a33592d 100644 (file)
@@ -10,7 +10,7 @@ from torch.quantization import (
     quantize,
     quantize_dynamic,
 )
-from torch.quantization._numeric_suite import (
+from torch.ao.ns._numeric_suite import (
     OutputLogger,
     Shadow,
     ShadowLogger,
index 3e627f5..e5024b0 100644 (file)
@@ -34,29 +34,29 @@ from torch.quantization.quantization_mappings import (
 from torch.testing._internal.common_quantization import NodeSpec as ns
 from torch.quantization.fx.pattern_utils import get_default_quant_patterns
 import torch.quantization.fx.quantization_patterns as qp
-from torch.quantization.ns.pattern_utils import (
+from torch.ao.ns.fx.pattern_utils import (
     get_type_a_related_to_b,
 )
-from torch.quantization.ns.graph_matcher import (
+from torch.ao.ns.fx.graph_matcher import (
     get_matching_subgraph_pairs,
     GraphMatchingException,
 )
-from torch.quantization.ns.utils import (
+from torch.ao.ns.fx.utils import (
     compute_sqnr,
     compute_normalized_l2_error,
     compute_cosine_similarity,
 )
-from torch.quantization.ns.mappings import (
+from torch.ao.ns.fx.mappings import (
     get_node_type_to_io_type_map,
     get_unmatchable_types_map,
     get_base_name_to_sets_of_related_ops,
     get_base_name_for_op,
     add_op_to_sets_of_related_ops,
 )
-from torch.quantization.ns.weight_utils import (
+from torch.ao.ns.fx.weight_utils import (
     get_op_to_type_to_weight_extraction_fn,
 )
-from torch.quantization._numeric_suite_fx import (
+from torch.ao.ns._numeric_suite_fx import (
     extract_weights,
     _extract_weights_impl,
     add_loggers,
@@ -1634,7 +1634,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
         op_to_type_to_weight_extraction_fn = \
             get_op_to_type_to_weight_extraction_fn()
         op_to_type_to_weight_extraction_fn['call_function'][_wrapped_linear] = \
-            torch.quantization.ns.weight_utils.get_linear_fun_weight
+            torch.ao.ns.fx.weight_utils.get_linear_fun_weight
 
         # test compare weights
         results = extract_weights(
diff --git a/torch/ao/ns/_numeric_suite.py b/torch/ao/ns/_numeric_suite.py
new file mode 100644 (file)
index 0000000..aa648e7
--- /dev/null
@@ -0,0 +1,486 @@
+import torch
+import torch.nn as nn
+import torch.nn.quantized as nnq
+import torch.nn.quantized.dynamic as nnqd
+from torch.quantization import prepare
+from typing import Dict, List, Optional, Any, Union, Callable, Set
+
+from torch.quantization.quantization_mappings import (
+    get_default_compare_output_module_list,
+)
+
+NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
+    nnqd.Linear,
+    nnq.Linear,
+    nnqd.LSTM,
+    nn.LSTM,
+}
+
+
+def _find_match(
+    str_list: Union[Dict[str, Any], List[str]], key_str: str,
+    postfix: str,
+) -> Optional[str]:
+    split_str = key_str.split(".")
+    if split_str[-1] == postfix:
+        match_string = "".join(key_str.split(".")[0:-1])
+        for s2 in str_list:
+            pattern1 = "".join(s2.split(".")[0:-1])
+            pattern2 = "".join(s2.split(".")[0:-2])
+            if match_string == pattern1:
+                return s2
+            if match_string == pattern2:
+                return s2
+
+        # For matching "fc.weight" and "fc._packed_params._packed_params"
+        if postfix == "_packed_params":
+            match_string = "".join(key_str.split(".")[0:-2])
+            if len(match_string) == 0:
+                return None
+            for s2 in str_list:
+                pattern1 = "".join(s2.split(".")[0:-1])
+                pattern2 = "".join(s2.split(".")[0:-2])
+                if match_string == pattern1:
+                    return s2
+                if match_string == pattern2:
+                    return s2
+        return None
+    else:
+        return None
+
+
+def compare_weights(
+    float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
+) -> Dict[str, Dict[str, torch.Tensor]]:
+    r"""Compare the weights of the float module with its corresponding quantized
+    module. Return a dict with key corresponding to module names and each entry being
+    a dictionary with two keys 'float' and 'quantized', containing the float and
+    quantized weights. This dict can be used to compare and compute the quantization
+    error of the weights of float and quantized models.
+
+    Example usage:
+        wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())
+        for key in wt_compare_dict:
+            print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))
+
+    Args:
+        float_dict: state dict of the float model
+        quantized_dict: state dict of the quantized model
+
+    Return:
+        weight_dict: dict with key corresponding to module names and each entry being
+        a dictionary with two keys 'float' and 'quantized', containing the float and
+        quantized weights
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
+    weight_dict: Dict[str, Dict] = {}
+    for key in quantized_dict:
+        match_key = _find_match(float_dict, key, "weight")
+        if match_key is not None:
+            weight_dict[key] = {}
+            weight_dict[key]["float"] = float_dict[match_key]
+            weight_dict[key]["quantized"] = quantized_dict[key]
+            continue
+
+        # For matching "fc.weight" and "fc._packed_params._packed_params"
+        match_key = _find_match(float_dict, key, "_packed_params")
+        if match_key is not None:
+            weight_dict[key] = {}
+            weight_dict[key]["float"] = float_dict[match_key]
+            weight_dict[key]["quantized"] = quantized_dict[key][0]
+
+        # For LSTM
+        split_str = key.split(".")
+        if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
+            layer = split_str[-2]
+            module_name = ".".join(split_str[:-3])
+            float_weight_ih_key = module_name + ".weight_ih_l" + layer
+            float_weight_hh_key = module_name + ".weight_hh_l" + layer
+            if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
+                weight_dict[key] = {}
+                weight_dict[key]["float"] = float_dict[float_weight_ih_key]
+                weight_dict[key]["quantized"] = (
+                    quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
+                )
+                weight_dict[key]["float"] = float_dict[float_weight_hh_key]
+                weight_dict[key]["quantized"] = (
+                    quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
+                )
+
+    return weight_dict
+
+
+def _get_logger_dict_helper(
+    mod: nn.Module, target_dict: Dict[str, Any],
+    prefix: str = "",
+) -> None:
+    r"""This is the helper function for get_logger_dict
+
+    Args:
+        mod: module we want to save all logger stats
+        prefix: prefix for the current module
+        target_dict: the dictionary used to save all logger stats
+    """
+
+    def get_prefix(prefix):
+        return prefix if prefix == "" else prefix + "."
+
+    for name, child in mod.named_children():
+        if isinstance(child, Logger):
+            target_dict[get_prefix(prefix) + "stats"] = child.stats
+            break
+
+    for name, child in mod.named_children():
+        module_prefix = get_prefix(prefix) + name if prefix else name
+        _get_logger_dict_helper(child, target_dict, module_prefix)
+
+
+def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
+    r"""Traverse the modules and save all logger stats into target dict.
+    This is mainly used for quantization accuracy debug.
+
+    Type of loggers supported:
+        ShadowLogger: used to log the outputs of the quantized module and its
+            matching float shadow module,
+        OutputLogger: used to log the outputs of the modules
+
+    Args:
+        mod: module we want to save all logger stats
+        prefix: prefix for the current module
+
+    Return:
+        target_dict: the dictionary used to save all logger stats
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")
+
+    target_dict: Dict[str, Dict] = {}
+    _get_logger_dict_helper(mod, target_dict, prefix)
+    return target_dict
+
+
+class Logger(nn.Module):
+    r"""Base class for stats logging
+    """
+
+    def __init__(self):
+        super(Logger, self).__init__()
+        self.stats = {}
+        # We only insert observer if the op is quantized with static quantization,
+        # which is identified by activation_observer.dtype == quint8.  This is needed
+        # when attaching Logger as observer for FX mode
+        self.dtype = torch.quint8
+
+    def forward(self, x):
+        pass
+
+
+class ShadowLogger(Logger):
+    r"""Class used in Shadow module to record the outputs of the original and
+    shadow modules.
+    """
+
+    def __init__(self):
+        super(ShadowLogger, self).__init__()
+        self.stats["float"] = []
+        self.stats["quantized"] = []
+
+    def forward(self, x, y):
+        if len(x) > 1:
+            x = x[0]
+        if len(y) > 1:
+            y = y[0]
+        self.stats["quantized"].append(x.detach())
+        self.stats["float"].append(y.detach())
+
+
+class OutputLogger(Logger):
+    r"""Class used to log the outputs of the module
+    """
+
+    def __init__(self):
+        super(OutputLogger, self).__init__()
+        self.stats["tensor_val"] = []
+
+
+    def forward(self, x):
+        self.stats["tensor_val"].append(x)
+        return x
+
+
+def _convert_tuple_to_list(t: Any) -> Any:
+    return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t
+
+
+def _dequantize_tensor_list(t: Any) -> Any:
+    return (
+        list(_dequantize_tensor_list(x) for x in t)
+        if type(t) is list
+        else t.dequantize()
+        if t.is_quantized
+        else t
+    )
+
+
+class Shadow(nn.Module):
+    r"""Shadow module attaches the float module to its matching quantized module
+    as the shadow. Then it uses Logger module to process the outputs of both
+    modules.
+
+    Args:
+        q_module: module quantized from float_module that we want to shadow
+        float_module: float module used to shadow q_module
+        logger_cls: type of logger used to process the outputs of q_module and
+            float_module. ShadowLogger or custom loggers can be used.
+    """
+
+    def __init__(self, q_module, float_module, logger_cls):
+        super(Shadow, self).__init__()
+        self.orig_module = q_module
+        self.shadow_module = float_module
+        self.dequant = nnq.DeQuantize()
+        self.logger = logger_cls()
+
+    def forward(self, *x) -> torch.Tensor:
+        xl = _convert_tuple_to_list(x)
+        output = self.orig_module(*xl)
+        xl_float = _dequantize_tensor_list(xl)
+        shadow_output = self.shadow_module(*xl_float)
+        self.logger(output, shadow_output)
+        return output
+
+    def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        output = self.orig_module.add(x, y)
+        x = x.dequantize()
+        y = y.dequantize()
+        shadow_output = self.shadow_module.add(x, y)
+        self.logger(output, shadow_output)
+        return output
+
+    def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
+        output = self.orig_module.add_scalar(x, y)
+        x = x.dequantize()
+        shadow_output = self.shadow_module.add_scalar(x, y)
+        self.logger(output, shadow_output)
+        return output
+
+    def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        output = self.orig_module.mul(x, y)
+        x = x.dequantize()
+        y = y.dequantize()
+        shadow_output = self.shadow_module.mul(x, y)
+        self.logger(output, shadow_output)
+        return output
+
+    def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
+        output = self.orig_module.mul_scalar(x, y)
+        x = x.dequantize()
+        shadow_output = self.shadow_module.mul_scalar(x, y)
+        self.logger(output, shadow_output)
+        return output
+
+    def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
+        output = self.orig_module.cat(x, dim)
+        x = [y.dequantize() for y in x]
+        shadow_output = self.shadow_module.cat(x, dim)
+        self.logger(output, shadow_output)
+        return output
+
+    def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        output = self.orig_module.add_relu(x, y)
+        x = x.dequantize()
+        y = y.dequantize()
+        shadow_output = self.shadow_module.add_relu(x, y)
+        self.logger(output, shadow_output)
+        return output
+
+
+def prepare_model_with_stubs(
+    float_module: nn.Module, q_module: nn.Module,
+    module_swap_list: Set[type], logger_cls: Callable,
+) -> None:
+    r"""Prepare the model by attaching the float module to its matching quantized
+    module as the shadow if the float module type is in module_swap_list.
+
+    Example usage:
+        prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
+        q_model(data)
+        ob_dict = get_logger_dict(q_model)
+
+    Args:
+        float_module: float module used to generate the q_module
+        q_module: module quantized from float_module
+        module_swap_list: list of float module types to attach the shadow
+        logger_cls: type of logger to be used in shadow module to process the outputs of
+            quantized module and its float shadow module
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")
+
+    float_module_children = {}
+    for name, mod in float_module.named_children():
+        float_module_children[name] = mod
+
+    reassign = {}
+    for name, mod in q_module.named_children():
+
+        if name not in float_module_children:
+            continue
+
+        float_mod = float_module_children[name]
+
+        if type(float_mod) not in module_swap_list:
+            prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls)
+
+        # Insert shadow module only if the module is not of the same type as
+        # the floating point module
+        if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod):
+            reassign[name] = Shadow(mod, float_mod, logger_cls)
+
+    for key, value in reassign.items():
+        q_module._modules[key] = value
+
+def _is_identical_module_type(mod1, mod2):
+    # Compare if two modules have the same dtype
+    mod1_module_types = [type(mod) for mod in mod1.modules()]
+    mod2_module_types = [type(mod) for mod in mod2.modules()]
+    return mod1_module_types == mod2_module_types
+
+
+
+def compare_model_stub(
+    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
+    *data, logger_cls=ShadowLogger
+) -> Dict[str, Dict]:
+    r"""Compare quantized module in a model with its floating point counterpart,
+    feeding both of them the same input. Return a dict with key corresponding to
+    module names and each entry being a dictionary with two keys 'float' and
+    'quantized', containing the output tensors of quantized and its matching
+    float shadow module. This dict can be used to compare and compute the module
+    level quantization error.
+
+    This function first call prepare_model_with_stubs() to swap the quantized
+    module that we want to compare with the Shadow module, which takes quantized
+    module, corresponding float module and logger as input, and creates a forward
+    path inside to make the float module to shadow quantized module sharing the
+    same input. The logger can be customizable, default logger is ShadowLogger
+    and it will save the outputs of the quantized module and float module that
+    can be used to compute the module level quantization error.
+
+    Example usage:
+        module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
+        ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
+        for key in ob_dict:
+            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
+
+    Args:
+        float_model: float model used to generate the q_model
+        q_model: model quantized from float_model
+        module_swap_list: list of float module types at which shadow modules will
+            be attached.
+        data: input data used to run the prepared q_model
+        logger_cls: type of logger to be used in shadow module to process the outputs of
+            quantized module and its float shadow module
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
+    prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
+    q_model(*data)
+    ob_dict = get_logger_dict(q_model)
+    return ob_dict
+
+
+def get_matching_activations(
+    float_module: nn.Module, q_module: nn.Module,
+) -> Dict[str, Dict[str, torch.Tensor]]:
+    r"""Find the matching activation between float and quantized modules.
+
+    Args:
+        float_module: float module used to generate the q_module
+        q_module: module quantized from float_module
+
+    Return:
+        act_dict: dict with key corresponding to quantized module names and each
+        entry being a dictionary with two keys 'float' and 'quantized', containing
+        the matching float and quantized activations
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
+    float_dict = get_logger_dict(float_module)
+    quantized_dict = get_logger_dict(q_module)
+    act_dict: Dict[str, Dict] = {}
+    for key in quantized_dict:
+        match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
+        if match_key is not None:
+            act_dict[key] = {}
+            act_dict[key]["float"] = float_dict[match_key]["tensor_val"]
+            act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"]
+    return act_dict
+
+
+def prepare_model_outputs(
+    float_module: nn.Module,
+    q_module: nn.Module,
+    logger_cls=OutputLogger,
+    allow_list=None
+) -> None:
+    r"""Prepare the model by attaching the logger to both float module
+    and quantized module if they are in the allow_list.
+
+    Args:
+        float_module: float module used to generate the q_module
+        q_module: module quantized from float_module
+        logger_cls: type of logger to be attached to float_module and q_module
+        allow_list: list of module types to attach logger
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
+    if allow_list is None:
+        allow_list = get_default_compare_output_module_list()
+
+    qconfig_debug = torch.quantization.QConfig(activation=logger_cls, weight=None)
+    float_module.qconfig = qconfig_debug  # type: ignore[assignment]
+    prepare(float_module, inplace=True, allow_list=allow_list)
+    q_module.qconfig = qconfig_debug  # type: ignore[assignment]
+    prepare(
+        q_module,
+        inplace=True,
+        allow_list=allow_list,
+        observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
+    )
+
+
+def compare_model_outputs(
+    float_model: nn.Module,
+    q_model: nn.Module,
+    *data,
+    logger_cls=OutputLogger,
+    allow_list=None
+) -> Dict[str, Dict[str, torch.Tensor]]:
+    r"""Compare output activations between float and quantized models at
+    corresponding locations for the same input. Return a dict with key corresponding
+    to quantized module names and each entry being a dictionary with two keys
+    'float' and 'quantized', containing the activations of quantized model and
+    float model at matching locations. This dict can be used to compare and
+    compute the propagation quantization error.
+
+    Example usage:
+        act_compare_dict = compare_model_outputs(float_model, qmodel, data)
+        for key in act_compare_dict:
+            print(key, compute_error(act_compare_dict[key]['float'], act_compare_dict[key]['quantized'].dequantize()))
+
+    Args:
+        float_model: float model used to generate the q_model
+        q_model: model quantized from float_model
+        data: input data used to run the prepared float_model and q_model
+        logger_cls: type of logger to be attached to float_module and q_module
+        allow_list: list of module types to attach logger
+
+    Return:
+        act_compare_dict: dict with key corresponding to quantized module names
+        and each entry being a dictionary with two keys 'float' and 'quantized',
+        containing the matching float and quantized activations
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
+    if allow_list is None:
+        allow_list = get_default_compare_output_module_list()
+    prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
+    float_model(*data)
+    q_model(*data)
+    act_compare_dict = get_matching_activations(float_model, q_model)
+    return act_compare_dict
diff --git a/torch/ao/ns/_numeric_suite_fx.py b/torch/ao/ns/_numeric_suite_fx.py
new file mode 100644 (file)
index 0000000..812a389
--- /dev/null
@@ -0,0 +1,513 @@
+import collections
+
+import torch
+import torch.nn as nn
+import torch.quantization.quantize_fx as quantize_fx
+from torch.fx import GraphModule
+from torch.fx.graph import Node
+from torch.ao.ns.fx.mappings import (
+    get_base_name_to_sets_of_related_ops,
+)
+from torch.ao.ns.fx.graph_matcher import (
+    get_matching_subgraph_pairs,
+    get_type_a_related_to_b,
+)
+
+from .fx.weight_utils import (
+    extract_weight_from_node,
+)
+
+from .fx.graph_passes import (
+    add_loggers_to_model,
+    create_a_shadows_b,
+)
+
+from .fx.utils import (
+    rekey_logger_info_on_node_name_of_model,
+    maybe_add_missing_fqns,
+    get_target_type_str,
+)
+
+from .fx.ns_types import (
+    NSSingleResultValuesType,
+    NSResultsType,
+    NSNodeTargetType,
+)
+
+from typing import Dict, Tuple, Callable, List, Optional, Set
+
+RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+
+class OutputLogger(nn.Module):
+    stats: List[torch.Tensor]
+    stats_rnn: List[RNNReturnType]
+
+    def __init__(
+        self,
+        ref_node_name: str,
+        prev_node_name: str,
+        model_name: str,
+        ref_name: str,
+        prev_node_target_type: str,
+        ref_node_target_type: str,
+        results_type: str,
+        index_within_arg: int,
+        index_of_arg: int,
+        fqn: Optional[str],
+    ):
+        super().__init__()
+        self.stats: List[torch.Tensor] = []
+        self.stats_rnn: List[RNNReturnType] = []
+
+        # name of the node which was responsible for adding this logger
+        # Note:
+        # - if we are logging node outputs, this is the same as prev_node_name
+        # - if we are logging node inputs, this is the name of the node
+        #   whose input this logger is logging.
+        #
+        # example, where logger1 is logging input of op1 and logger2 is logging
+        #    the output of op1:
+        #
+        #  x1 -> logger1 -> op1 -> logger2 -> x2
+        #
+        # in this example,
+        #   - logger1's prev_node_name is x1 and ref_node_name is op1
+        #   - logger2's prev_node_name is op1 and ref_node_name is op1
+        self.ref_node_name = ref_node_name
+        # name of the node whose output this Logger is capturing
+        self.prev_node_name = prev_node_name
+
+        # name of the model from which the node originated from
+        self.model_name = model_name
+        # reference name, used to match loggers from separate models
+        # to each other
+        self.ref_name = ref_name
+        # type of the target of the node whose output this logger is logging
+        self.prev_node_target_type = prev_node_target_type
+        # type of the target of the node which was respondible for adding this
+        # logger
+        self.ref_node_target_type = ref_node_target_type
+        # what kind of values are inside of stats
+        self.results_type = results_type
+        # index of this node within the arg of the input/output node
+        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
+        self.index_within_arg = index_within_arg
+        # index of this node within the args of the input/output node
+        # for example, in add(x1, x2), x2 would have index_of_arg == 1
+        self.index_of_arg = index_of_arg
+        # fully qualified name
+        self.fqn = fqn
+
+    # Note: cannot annotate the type of x because TorchScript does not support
+    #   the Union type.
+    def forward(self, x):
+        if isinstance(x, torch.Tensor):
+            self.stats.append(x.detach())
+        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
+            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
+            self.stats_rnn.append(new_res)
+        return x
+
+    def __repr__(self):
+        return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
+prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
+ref_node_target_type={self.ref_node_target_type}
+results_type={self.results_type}, index_within_arg={self.index_within_arg},
+index_of_arg={self.index_of_arg}, fqn={self.fqn})"""
+
+
+class NSTracer(quantize_fx.QuantizationTracer):
+    """
+    Just like a regular tracer, but treats observers and fake_quantize
+    modules as leaf modules.
+    """
+    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
+        if isinstance(m, torch.quantization.ObserverBase):
+            return True
+        elif isinstance(m, torch.quantization.FakeQuantizeBase):
+            return True
+        return super().is_leaf_module(m, module_qualified_name)
+
+
+def _extract_weights_one_model(
+    model_name: str,
+    model: GraphModule,
+    nodes_and_names_to_instrument: List[Tuple[Node, str]],
+    results: NSResultsType,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
+) -> None:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
+    for node, ref_name in nodes_and_names_to_instrument:
+        res_type = NSSingleResultValuesType.WEIGHT.value
+        extracted_weight = extract_weight_from_node(
+            node, model, op_to_type_to_weight_extraction_fn)
+        if extracted_weight:
+            if ref_name not in results:
+                results[ref_name] = {res_type: {}}
+            results[ref_name][res_type][model_name] = [extracted_weight]
+
+
+def _extract_weights_impl(
+    model_name_a: str,
+    gm_a: GraphModule,
+    model_name_b: str,
+    gm_b: GraphModule,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
+) -> NSResultsType:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
+    matched_subgraph_pairs = get_matching_subgraph_pairs(
+        gm_a, gm_b, base_name_to_sets_of_related_ops,
+        unmatchable_types_map)
+
+    # split the subgraph pairs into one data structure for each model
+    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
+    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
+    for match_name, match in matched_subgraph_pairs.items():
+        subgraph_a, subgraph_b = match
+        nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
+        nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))
+
+    # populate the results, one model at a time
+    results: NSResultsType = {}
+    _extract_weights_one_model(
+        model_name_a, gm_a, nodes_and_names_to_instrument_a, results,
+        op_to_type_to_weight_extraction_fn)
+    _extract_weights_one_model(
+        model_name_b, gm_b, nodes_and_names_to_instrument_b, results,
+        op_to_type_to_weight_extraction_fn)
+
+    # fill in missing fqn entries
+    maybe_add_missing_fqns(results)
+
+    # rekey on names of nodes in gm_b
+    results = rekey_logger_info_on_node_name_of_model(results, model_name_b)
+
+    return results
+
+
+def extract_weights(
+    model_name_a: str,
+    model_a: nn.Module,
+    model_name_b: str,
+    model_b: nn.Module,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
+) -> NSResultsType:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
+    if base_name_to_sets_of_related_ops is None:
+        base_name_to_sets_of_related_ops = \
+            get_base_name_to_sets_of_related_ops()
+    type_a_related_to_b = \
+        get_type_a_related_to_b(base_name_to_sets_of_related_ops)
+
+    # TODO(future PR): expose these
+    skipped_module_names: List[str] = []
+    skipped_module_classes: List[Callable] = []
+    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
+    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
+    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
+    if hasattr(model_a, '_node_name_to_scope'):
+        gm_a._node_name_to_scope = model_a._node_name_to_scope
+    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
+    if hasattr(model_b, '_node_name_to_scope'):
+        gm_b._node_name_to_scope = model_b._node_name_to_scope
+    return _extract_weights_impl(
+        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
+        unmatchable_types_map, op_to_type_to_weight_extraction_fn)
+
+
+def _add_loggers_one_model(
+    model_name: str,
+    model: GraphModule,
+    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
+    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
+    logger_cls: Callable,
+) -> nn.Module:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")
+
+    # TODO(future PR): do not observe nodes we do not care
+    #   about (both fp32, denylist, etc)
+    node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
+    node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
+    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
+        node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
+    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
+        node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)
+
+    model = add_loggers_to_model(
+        model, node_to_instrument_inputs_to_ref_name,
+        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
+    return model
+
+
+def _add_loggers_impl(
+    name_a: str,
+    gm_a: GraphModule,
+    name_b: str,
+    gm_b: GraphModule,
+    logger_cls: Callable,
+    should_log_inputs: bool,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> Tuple[nn.Module, nn.Module]:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
+    matched_subgraph_pairs = get_matching_subgraph_pairs(
+        gm_a, gm_b,
+        base_name_to_sets_of_related_ops, unmatchable_types_map)
+    nodes_and_names_to_instrument_inputs_a = []
+    nodes_and_names_to_instrument_inputs_b = []
+    nodes_and_names_to_instrument_outputs_a = []
+    nodes_and_names_to_instrument_outputs_b = []
+    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
+        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
+        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
+        # Note: for matching inputs we use start_node, such as observing
+        # the input of linear in linear-relu
+        if should_log_inputs:
+            nodes_and_names_to_instrument_inputs_a.append(
+                (subgraph_a.start_node, match_name, ref_node_type_a))
+            nodes_and_names_to_instrument_inputs_b.append(
+                (subgraph_b.start_node, match_name, ref_node_type_b))
+        # Note: for matching activations we always use end_node,
+        # such as observing the output of relu in linear-relu
+        nodes_and_names_to_instrument_outputs_a.append(
+            (subgraph_a.end_node, match_name, ref_node_type_a))
+        nodes_and_names_to_instrument_outputs_b.append(
+            (subgraph_b.end_node, match_name, ref_node_type_b))
+
+    new_model_a = _add_loggers_one_model(
+        name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
+        nodes_and_names_to_instrument_outputs_a, logger_cls)
+    new_model_b = _add_loggers_one_model(
+        name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
+        nodes_and_names_to_instrument_outputs_b, logger_cls)
+    return (new_model_a, new_model_b)
+
+
+def add_loggers(
+    name_a: str,
+    model_a: nn.Module,
+    name_b: str,
+    model_b: nn.Module,
+    logger_cls: Callable,
+    should_log_inputs : bool = False,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> Tuple[nn.Module, nn.Module]:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
+    # TODO(future PR): expose these
+    skipped_module_names: List[str] = []
+    skipped_module_classes: List[Callable] = []
+    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
+    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
+    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
+    if hasattr(model_a, '_node_name_to_scope'):
+        gm_a._node_name_to_scope = model_a._node_name_to_scope
+    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
+    if hasattr(model_b, '_node_name_to_scope'):
+        gm_b._node_name_to_scope = model_b._node_name_to_scope
+    return _add_loggers_impl(
+        name_a, gm_a, name_b, gm_b, logger_cls,
+        should_log_inputs=should_log_inputs,
+        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
+        unmatchable_types_map=unmatchable_types_map)
+
+
+def _extract_logger_info_one_model(
+    model: nn.Module,
+    results: NSResultsType,
+    logger_cls: Callable,
+) -> None:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
+    for gm_name, mod in model.named_modules():
+        # TODO(future PR): better check when scripted
+        is_logger = (
+            isinstance(mod, logger_cls)  # type: ignore[arg-type]
+            or (
+                isinstance(mod, torch.jit.RecursiveScriptModule)
+                and mod.original_name == 'OutputLogger'
+            )
+        )
+        if is_logger:
+            key = mod.ref_name
+            if key not in results:
+                results[key] = {}
+            assert mod.model_name not in results[key], \
+                f"{mod.model_name} is already present in results"
+            if mod.results_type not in results[key]:
+                results[key][mod.results_type] = {}
+            if mod.model_name not in results[key][mod.results_type]:
+                results[key][mod.results_type][mod.model_name] = []
+            stats_to_use = mod.stats
+            if len(mod.stats_rnn) > 0:
+                stats_to_use = mod.stats_rnn
+            results[key][mod.results_type][mod.model_name].append({
+                'type': mod.results_type,
+                'values': stats_to_use,
+                'ref_node_name': mod.ref_node_name,
+                'ref_node_target_type': mod.ref_node_target_type,
+                'prev_node_name': mod.prev_node_name,
+                'prev_node_target_type': mod.prev_node_target_type,
+                'index_within_arg': mod.index_within_arg,
+                'index_of_arg': mod.index_of_arg,
+                'fqn': mod.fqn,
+            })
+            # ensure the list stays sorted
+            results[key][mod.results_type][mod.model_name].sort(
+                key=lambda res:
+                f"{res['index_of_arg']}:{res['index_within_arg']}"
+            )
+
+
+# TODO(future PR): align on naming
+# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
+def extract_logger_info(
+    model_a: nn.Module,
+    model_b: nn.Module,
+    logger_cls: Callable,
+    model_name_to_use_for_layer_names: str,
+) -> NSResultsType:
+    """
+    Same thing as ns.extract_logger_info, but for models prepared with
+    this module.
+
+    TODO(future PR): real docblock
+
+    Output format: NSResultsType
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
+    results: NSResultsType = {}
+    for model in (model_a, model_b):
+        _extract_logger_info_one_model(model, results, logger_cls)
+    # fill in missing fqn entries
+    maybe_add_missing_fqns(results)
+    # rekey on the name of model b
+    results = rekey_logger_info_on_node_name_of_model(
+        results, model_name_to_use_for_layer_names)
+    return results
+
+
+def _add_shadow_loggers_impl(
+    name_a: str,
+    gm_a: GraphModule,
+    name_b: str,
+    gm_b: GraphModule,
+    logger_cls: Callable,
+    should_log_inputs: bool,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> nn.Module:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
+    matched_subgraph_pairs = get_matching_subgraph_pairs(
+        gm_a, gm_b, base_name_to_sets_of_related_ops,
+        unmatchable_types_map)
+    gm_a_shadows_b = create_a_shadows_b(
+        name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
+        should_log_inputs=should_log_inputs,
+        node_type_to_io_type_map=node_type_to_io_type_map)
+    return gm_a_shadows_b
+
+
+def add_shadow_loggers(
+    name_a: str,
+    model_a: nn.Module,
+    name_b: str,
+    model_b: nn.Module,
+    logger_cls: Callable,
+    should_log_inputs: bool = False,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> nn.Module:
+    """
+    Same thing as add_loggers, but for an `a_shadows_b` model.
+    TODO(future PR): real docblock
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
+    # TODO(future PR): expose these
+    skipped_module_names: List[str] = []
+    skipped_module_classes: List[Callable] = []
+    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
+    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
+    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
+    if hasattr(model_a, '_node_name_to_scope'):
+        gm_a._node_name_to_scope = model_a._node_name_to_scope
+    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
+    if hasattr(model_b, '_node_name_to_scope'):
+        gm_b._node_name_to_scope = model_b._node_name_to_scope
+    return _add_shadow_loggers_impl(
+        name_a, gm_a, name_b, gm_b, logger_cls,
+        should_log_inputs=should_log_inputs,
+        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
+        node_type_to_io_type_map=node_type_to_io_type_map,
+        unmatchable_types_map=unmatchable_types_map)
+
+
+def extract_shadow_logger_info(
+    model_a_shadows_b: nn.Module,
+    logger_cls: Callable,
+    model_name_to_use_for_layer_names: str,
+) -> NSResultsType:
+    """
+    Same thing as extract_logger_info, but for an `a_shadows_b` model.
+    TODO(future PR): real docblock
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
+    results: NSResultsType = collections.defaultdict(dict)
+    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
+    # fill in missing fqn entries
+    maybe_add_missing_fqns(results)
+    # rekey on the name of model b
+    results = rekey_logger_info_on_node_name_of_model(
+        results, model_name_to_use_for_layer_names)
+    return dict(results)
+
+
+def extend_logger_results_with_comparison(
+    results: NSResultsType,
+    model_name_1: str,
+    model_name_2: str,
+    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    comparison_name: str,
+) -> None:
+    """
+    Compares the logged values from `model_name_2` against the corresponding
+    values in `model_name_1`, using `comparison_fn`. Records the result
+    in `model_name_2`'s results under `comparison_name`.
+    """
+    for _, results_type_to_results in results.items():
+        for _, model_name_to_results in results_type_to_results.items():
+            assert model_name_1 in model_name_to_results, \
+                f"{model_name_1} not found in results"
+            assert model_name_2 in model_name_to_results, \
+                f"{model_name_2} not found in results"
+
+            results_1 = model_name_to_results[model_name_1]
+            results_2 = model_name_to_results[model_name_2]
+
+            for result_2 in results_2:
+                index_within_arg_2 = result_2['index_within_arg']
+                index_of_arg_2 = result_2['index_of_arg']
+                # find corresponding result_1
+                result_1 = None
+                for cur_result_1 in results_1:
+                    index_within_arg_1 = cur_result_1['index_within_arg']
+                    index_of_arg_1 = cur_result_1['index_of_arg']
+                    if (
+                        (index_within_arg_1 == index_within_arg_2) and
+                        (index_of_arg_1 == index_of_arg_2)
+                    ):
+                        result_1 = cur_result_1
+                        break
+                assert result_1 is not None
+
+                values_1 = result_1['values']
+                values_2 = result_2['values']
+                result_2[comparison_name] = []
+                for value_1, value_2 in zip(values_1, values_2):
+                    comparison_result = comparison_fn(value_1, value_2)
+                    result_2[comparison_name].append(comparison_result)
diff --git a/torch/ao/ns/fx/__init__.py b/torch/ao/ns/fx/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
similarity index 99%
rename from torch/quantization/ns/graph_passes.py
rename to torch/ao/ns/fx/graph_passes.py
index 51eb6c2..05dc0a9 100644 (file)
@@ -19,7 +19,7 @@ from .ns_types import (
     NSSubgraph,
     NSNodeTargetType,
 )
-from torch.quantization.ns.mappings import (
+from torch.ao.ns.fx.mappings import (
     get_node_type_to_io_type_map,
 )
 from torch.quantization.quantize import is_activation_post_process
index d3ffd06..ed373e2 100644 (file)
@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.nn.quantized as nnq
 
 import torch.quantization
-import torch.quantization._numeric_suite as ns
+import torch.ao.ns._numeric_suite as ns
 
 _supported_modules = {nn.Linear, nn.Conv2d}
 _supported_modules_quantized = {nnq.Linear, nnq.Conv2d}
index c6337c7..c5a7848 100644 (file)
-import torch
-import torch.nn as nn
-import torch.nn.quantized as nnq
-import torch.nn.quantized.dynamic as nnqd
-from torch.quantization import prepare
-from typing import Dict, List, Optional, Any, Union, Callable, Set
-
-from .quantization_mappings import (
-    get_default_compare_output_module_list,
+# flake8: noqa: F401
+r"""
+This file is in the process of migration to `torch/ao/quantization`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please, add it to the
+`torch/ao/ns/_numeric_suite.py`, while adding an import statement
+here.
+"""
+
+from torch.ao.ns._numeric_suite import (
+    NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
+    _find_match,
+    compare_weights,
+    _get_logger_dict_helper,
+    get_logger_dict,
+    Logger,
+    ShadowLogger,
+    OutputLogger,
+    _convert_tuple_to_list,
+    _dequantize_tensor_list,
+    Shadow,
+    prepare_model_with_stubs,
+    _is_identical_module_type,
+    compare_model_stub,
+    get_matching_activations,
+    prepare_model_outputs,
+    compare_model_outputs,
 )
-
-NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
-    nnqd.Linear,
-    nnq.Linear,
-    nnqd.LSTM,
-    nn.LSTM,
-}
-
-
-def _find_match(
-    str_list: Union[Dict[str, Any], List[str]], key_str: str,
-    postfix: str,
-) -> Optional[str]:
-    split_str = key_str.split(".")
-    if split_str[-1] == postfix:
-        match_string = "".join(key_str.split(".")[0:-1])
-        for s2 in str_list:
-            pattern1 = "".join(s2.split(".")[0:-1])
-            pattern2 = "".join(s2.split(".")[0:-2])
-            if match_string == pattern1:
-                return s2
-            if match_string == pattern2:
-                return s2
-
-        # For matching "fc.weight" and "fc._packed_params._packed_params"
-        if postfix == "_packed_params":
-            match_string = "".join(key_str.split(".")[0:-2])
-            if len(match_string) == 0:
-                return None
-            for s2 in str_list:
-                pattern1 = "".join(s2.split(".")[0:-1])
-                pattern2 = "".join(s2.split(".")[0:-2])
-                if match_string == pattern1:
-                    return s2
-                if match_string == pattern2:
-                    return s2
-        return None
-    else:
-        return None
-
-
-def compare_weights(
-    float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
-) -> Dict[str, Dict[str, torch.Tensor]]:
-    r"""Compare the weights of the float module with its corresponding quantized
-    module. Return a dict with key corresponding to module names and each entry being
-    a dictionary with two keys 'float' and 'quantized', containing the float and
-    quantized weights. This dict can be used to compare and compute the quantization
-    error of the weights of float and quantized models.
-
-    Example usage:
-        wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())
-        for key in wt_compare_dict:
-            print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))
-
-    Args:
-        float_dict: state dict of the float model
-        quantized_dict: state dict of the quantized model
-
-    Return:
-        weight_dict: dict with key corresponding to module names and each entry being
-        a dictionary with two keys 'float' and 'quantized', containing the float and
-        quantized weights
-    """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
-    weight_dict: Dict[str, Dict] = {}
-    for key in quantized_dict:
-        match_key = _find_match(float_dict, key, "weight")
-        if match_key is not None:
-            weight_dict[key] = {}
-            weight_dict[key]["float"] = float_dict[match_key]
-            weight_dict[key]["quantized"] = quantized_dict[key]
-            continue
-
-        # For matching "fc.weight" and "fc._packed_params._packed_params"
-        match_key = _find_match(float_dict, key, "_packed_params")
-        if match_key is not None:
-            weight_dict[key] = {}
-            weight_dict[key]["float"] = float_dict[match_key]
-            weight_dict[key]["quantized"] = quantized_dict[key][0]
-
-        # For LSTM
-        split_str = key.split(".")
-        if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
-            layer = split_str[-2]
-            module_name = ".".join(split_str[:-3])
-            float_weight_ih_key = module_name + ".weight_ih_l" + layer
-            float_weight_hh_key = module_name + ".weight_hh_l" + layer
-            if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
-                weight_dict[key] = {}
-                weight_dict[key]["float"] = float_dict[float_weight_ih_key]
-                weight_dict[key]["quantized"] = (
-                    quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
-                )
-                weight_dict[key]["float"] = float_dict[float_weight_hh_key]
-                weight_dict[key]["quantized"] = (
-                    quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
-                )
-
-    return weight_dict
-
-
-def _get_logger_dict_helper(
-    mod: nn.Module, target_dict: Dict[str, Any],
-    prefix: str = "",
-) -> None:
-    r"""This is the helper function for get_logger_dict
-
-    Args:
-        mod: module we want to save all logger stats
-        prefix: prefix for the current module
-        target_dict: the dictionary used to save all logger stats
-    """
-
-    def get_prefix(prefix):
-        return prefix if prefix == "" else prefix + "."
-
-    for name, child in mod.named_children():
-        if isinstance(child, Logger):
-            target_dict[get_prefix(prefix) + "stats"] = child.stats
-            break
-
-    for name, child in mod.named_children():
-        module_prefix = get_prefix(prefix) + name if prefix else name
-        _get_logger_dict_helper(child, target_dict, module_prefix)
-
-
-def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
-    r"""Traverse the modules and save all logger stats into target dict.
-    This is mainly used for quantization accuracy debug.
-
-    Type of loggers supported:
-        ShadowLogger: used to log the outputs of the quantized module and its
-            matching float shadow module,
-        OutputLogger: used to log the outputs of the modules
-
-    Args:
-        mod: module we want to save all logger stats
-        prefix: prefix for the current module
-
-    Return:
-        target_dict: the dictionary used to save all logger stats
-    """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")
-
-    target_dict: Dict[str, Dict] = {}
-    _get_logger_dict_helper(mod, target_dict, prefix)
-    return target_dict
-
-
-class Logger(nn.Module):
-    r"""Base class for stats logging
-    """
-
-    def __init__(self):
-        super(Logger, self).__init__()
-        self.stats = {}
-        # We only insert observer if the op is quantized with static quantization,
-        # which is identified by activation_observer.dtype == quint8.  This is needed
-        # when attaching Logger as observer for FX mode
-        self.dtype = torch.quint8
-
-    def forward(self, x):
-        pass
-
-
-class ShadowLogger(Logger):
-    r"""Class used in Shadow module to record the outputs of the original and
-    shadow modules.
-    """
-
-    def __init__(self):
-        super(ShadowLogger, self).__init__()
-        self.stats["float"] = []
-        self.stats["quantized"] = []
-
-    def forward(self, x, y):
-        if len(x) > 1:
-            x = x[0]
-        if len(y) > 1:
-            y = y[0]
-        self.stats["quantized"].append(x.detach())
-        self.stats["float"].append(y.detach())
-
-
-class OutputLogger(Logger):
-    r"""Class used to log the outputs of the module
-    """
-
-    def __init__(self):
-        super(OutputLogger, self).__init__()
-        self.stats["tensor_val"] = []
-
-
-    def forward(self, x):
-        self.stats["tensor_val"].append(x)
-        return x
-
-
-def _convert_tuple_to_list(t: Any) -> Any:
-    return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t
-
-
-def _dequantize_tensor_list(t: Any) -> Any:
-    return (
-        list(_dequantize_tensor_list(x) for x in t)
-        if type(t) is list
-        else t.dequantize()
-        if t.is_quantized
-        else t
-    )
-
-
-class Shadow(nn.Module):
-    r"""Shadow module attaches the float module to its matching quantized module
-    as the shadow. Then it uses Logger module to process the outputs of both
-    modules.
-
-    Args:
-        q_module: module quantized from float_module that we want to shadow
-        float_module: float module used to shadow q_module
-        logger_cls: type of logger used to process the outputs of q_module and
-            float_module. ShadowLogger or custom loggers can be used.
-    """
-
-    def __init__(self, q_module, float_module, logger_cls):
-        super(Shadow, self).__init__()
-        self.orig_module = q_module
-        self.shadow_module = float_module
-        self.dequant = nnq.DeQuantize()
-        self.logger = logger_cls()
-
-    def forward(self, *x) -> torch.Tensor:
-        xl = _convert_tuple_to_list(x)
-        output = self.orig_module(*xl)
-        xl_float = _dequantize_tensor_list(xl)
-        shadow_output = self.shadow_module(*xl_float)
-        self.logger(output, shadow_output)
-        return output
-
-    def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        output = self.orig_module.add(x, y)
-        x = x.dequantize()
-        y = y.dequantize()
-        shadow_output = self.shadow_module.add(x, y)
-        self.logger(output, shadow_output)
-        return output
-
-    def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
-        output = self.orig_module.add_scalar(x, y)
-        x = x.dequantize()
-        shadow_output = self.shadow_module.add_scalar(x, y)
-        self.logger(output, shadow_output)
-        return output
-
-    def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        output = self.orig_module.mul(x, y)
-        x = x.dequantize()
-        y = y.dequantize()
-        shadow_output = self.shadow_module.mul(x, y)
-        self.logger(output, shadow_output)
-        return output
-
-    def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
-        output = self.orig_module.mul_scalar(x, y)
-        x = x.dequantize()
-        shadow_output = self.shadow_module.mul_scalar(x, y)
-        self.logger(output, shadow_output)
-        return output
-
-    def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
-        output = self.orig_module.cat(x, dim)
-        x = [y.dequantize() for y in x]
-        shadow_output = self.shadow_module.cat(x, dim)
-        self.logger(output, shadow_output)
-        return output
-
-    def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        output = self.orig_module.add_relu(x, y)
-        x = x.dequantize()
-        y = y.dequantize()
-        shadow_output = self.shadow_module.add_relu(x, y)
-        self.logger(output, shadow_output)
-        return output
-
-
-def prepare_model_with_stubs(
-    float_module: nn.Module, q_module: nn.Module,
-    module_swap_list: Set[type], logger_cls: Callable,
-) -> None:
-    r"""Prepare the model by attaching the float module to its matching quantized
-    module as the shadow if the float module type is in module_swap_list.
-
-    Example usage:
-        prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
-        q_model(data)
-        ob_dict = get_logger_dict(q_model)
-
-    Args:
-        float_module: float module used to generate the q_module
-        q_module: module quantized from float_module
-        module_swap_list: list of float module types to attach the shadow
-        logger_cls: type of logger to be used in shadow module to process the outputs of
-            quantized module and its float shadow module
-    """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")
-
-    float_module_children = {}
-    for name, mod in float_module.named_children():
-        float_module_children[name] = mod
-
-    reassign = {}
-    for name, mod in q_module.named_children():
-
-        if name not in float_module_children:
-            continue
-
-        float_mod = float_module_children[name]
-
-        if type(float_mod) not in module_swap_list:
-            prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls)
-
-        # Insert shadow module only if the module is not of the same type as
-        # the floating point module
-        if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod):
-            reassign[name] = Shadow(mod, float_mod, logger_cls)
-
-    for key, value in reassign.items():
-        q_module._modules[key] = value
-
-def _is_identical_module_type(mod1, mod2):
-    # Compare if two modules have the same dtype
-    mod1_module_types = [type(mod) for mod in mod1.modules()]
-    mod2_module_types = [type(mod) for mod in mod2.modules()]
-    return mod1_module_types == mod2_module_types
-
-
-
-def compare_model_stub(
-    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
-    *data, logger_cls=ShadowLogger
-) -> Dict[str, Dict]:
-    r"""Compare quantized module in a model with its floating point counterpart,
-    feeding both of them the same input. Return a dict with key corresponding to
-    module names and each entry being a dictionary with two keys 'float' and
-    'quantized', containing the output tensors of quantized and its matching
-    float shadow module. This dict can be used to compare and compute the module
-    level quantization error.
-
-    This function first call prepare_model_with_stubs() to swap the quantized
-    module that we want to compare with the Shadow module, which takes quantized
-    module, corresponding float module and logger as input, and creates a forward
-    path inside to make the float module to shadow quantized module sharing the
-    same input. The logger can be customizable, default logger is ShadowLogger
-    and it will save the outputs of the quantized module and float module that
-    can be used to compute the module level quantization error.
-
-    Example usage:
-        module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
-        ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
-        for key in ob_dict:
-            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
-
-    Args:
-        float_model: float model used to generate the q_model
-        q_model: model quantized from float_model
-        module_swap_list: list of float module types at which shadow modules will
-            be attached.
-        data: input data used to run the prepared q_model
-        logger_cls: type of logger to be used in shadow module to process the outputs of
-            quantized module and its float shadow module
-    """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
-    prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
-    q_model(*data)
-    ob_dict = get_logger_dict(q_model)
-    return ob_dict
-
-
-def get_matching_activations(
-    float_module: nn.Module, q_module: nn.Module,
-) -> Dict[str, Dict[str, torch.Tensor]]:
-    r"""Find the matching activation between float and quantized modules.
-
-    Args:
-        float_module: float module used to generate the q_module
-        q_module: module quantized from float_module
-
-    Return:
-        act_dict: dict with key corresponding to quantized module names and each
-        entry being a dictionary with two keys 'float' and 'quantized', containing
-        the matching float and quantized activations
-    """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
-    float_dict = get_logger_dict(float_module)
-    quantized_dict = get_logger_dict(q_module)
-    act_dict: Dict[str, Dict] = {}
-    for key in quantized_dict:
-        match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
-        if match_key is not None:
-            act_dict[key] = {}
-            act_dict[key]["float"] = float_dict[match_key]["tensor_val"]
-            act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"]
-    return act_dict
-
-
-def prepare_model_outputs(
-    float_module: nn.Module,
-    q_module: nn.Module,
-    logger_cls=OutputLogger,
-    allow_list=None
-) -> None:
-    r"""Prepare the model by attaching the logger to both float module
-    and quantized module if they are in the allow_list.
-
-    Args:
-        float_module: float module used to generate the q_module
-        q_module: module quantized from float_module
-        logger_cls: type of logger to be attached to float_module and q_module
-        allow_list: list of module types to attach logger
-    """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
-    if allow_list is None:
-        allow_list = get_default_compare_output_module_list()
-
-    qconfig_debug = torch.quantization.QConfig(activation=logger_cls, weight=None)
-    float_module.qconfig = qconfig_debug  # type: ignore[assignment]
-    prepare(float_module, inplace=True, allow_list=allow_list)
-    q_module.qconfig = qconfig_debug  # type: ignore[assignment]
-    prepare(
-        q_module,
-        inplace=True,
-        allow_list=allow_list,
-        observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
-    )
-
-
-def compare_model_outputs(
-    float_model: nn.Module,
-    q_model: nn.Module,
-    *data,
-    logger_cls=OutputLogger,
-    allow_list=None
-) -> Dict[str, Dict[str, torch.Tensor]]:
-    r"""Compare output activations between float and quantized models at
-    corresponding locations for the same input. Return a dict with key corresponding
-    to quantized module names and each entry being a dictionary with two keys
-    'float' and 'quantized', containing the activations of quantized model and
-    float model at matching locations. This dict can be used to compare and
-    compute the propagation quantization error.
-
-    Example usage:
-        act_compare_dict = compare_model_outputs(float_model, qmodel, data)
-        for key in act_compare_dict:
-            print(key, compute_error(act_compare_dict[key]['float'], act_compare_dict[key]['quantized'].dequantize()))
-
-    Args:
-        float_model: float model used to generate the q_model
-        q_model: model quantized from float_model
-        data: input data used to run the prepared float_model and q_model
-        logger_cls: type of logger to be attached to float_module and q_module
-        allow_list: list of module types to attach logger
-
-    Return:
-        act_compare_dict: dict with key corresponding to quantized module names
-        and each entry being a dictionary with two keys 'float' and 'quantized',
-        containing the matching float and quantized activations
-    """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
-    if allow_list is None:
-        allow_list = get_default_compare_output_module_list()
-    prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
-    float_model(*data)
-    q_model(*data)
-    act_compare_dict = get_matching_activations(float_model, q_model)
-    return act_compare_dict
index 0e3e313..991b847 100644 (file)
-import collections
-
-import torch
-import torch.nn as nn
-import torch.quantization.quantize_fx as quantize_fx
-from torch.fx import GraphModule
-from torch.fx.graph import Node
-from torch.quantization.ns.mappings import (
-    get_base_name_to_sets_of_related_ops,
-)
-from torch.quantization.ns.graph_matcher import (
-    get_matching_subgraph_pairs,
-    get_type_a_related_to_b,
-)
-
-from .ns.weight_utils import (
-    extract_weight_from_node,
-)
-
-from .ns.graph_passes import (
-    add_loggers_to_model,
-    create_a_shadows_b,
-)
-
-from .ns.utils import (
-    rekey_logger_info_on_node_name_of_model,
-    maybe_add_missing_fqns,
-    get_target_type_str,
+# flake8: noqa: F401
+r"""
+This file is in the process of migration to `torch/ao/quantization`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please, add it to the
+`torch/ao/ns/_numeric_suite_fx.py`, while adding an import statement
+here.
+"""
+
+from torch.ao.ns._numeric_suite_fx import (
+    RNNReturnType,
+    OutputLogger,
+    NSTracer,
+    _extract_weights_one_model,
+    _extract_weights_impl,
+    extract_weights,
+    _add_loggers_one_model,
+    _add_loggers_impl,
+    add_loggers,
+    _extract_logger_info_one_model,
+    extract_logger_info,
+    _add_shadow_loggers_impl,
+    add_shadow_loggers,
+    extract_shadow_logger_info,
+    extend_logger_results_with_comparison,
 )
-
-from .ns.ns_types import (
-    NSSingleResultValuesType,
-    NSResultsType,
-    NSNodeTargetType,
-)
-
-from typing import Dict, Tuple, Callable, List, Optional, Set
-
-RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-
-class OutputLogger(nn.Module):
-    stats: List[torch.Tensor]
-    stats_rnn: List[RNNReturnType]
-
-    def __init__(
-        self,
-        ref_node_name: str,
-        prev_node_name: str,
-        model_name: str,
-        ref_name: str,
-        prev_node_target_type: str,
-        ref_node_target_type: str,
-        results_type: str,
-        index_within_arg: int,
-        index_of_arg: int,
-        fqn: Optional[str],
-    ):
-        super().__init__()
-        self.stats: List[torch.Tensor] = []
-        self.stats_rnn: List[RNNReturnType] = []
-
-        # name of the node which was responsible for adding this logger
-        # Note:
-        # - if we are logging node outputs, this is the same as prev_node_name
-        # - if we are logging node inputs, this is the name of the node
-        #   whose input this logger is logging.
-        #
-        # example, where logger1 is logging input of op1 and logger2 is logging
-        #    the output of op1:
-        #
-        #  x1 -> logger1 -> op1 -> logger2 -> x2
-        #
-        # in this example,
-        #   - logger1's prev_node_name is x1 and ref_node_name is op1
-        #   - logger2's prev_node_name is op1 and ref_node_name is op1
-        self.ref_node_name = ref_node_name
-        # name of the node whose output this Logger is capturing
-        self.prev_node_name = prev_node_name
-
-        # name of the model from which the node originated from
-        self.model_name = model_name
-        # reference name, used to match loggers from separate models
-        # to each other
-        self.ref_name = ref_name
-        # type of the target of the node whose output this logger is logging
-        self.prev_node_target_type = prev_node_target_type
-        # type of the target of the node which was respondible for adding this
-        # logger
-        self.ref_node_target_type = ref_node_target_type
-        # what kind of values are inside of stats
-        self.results_type = results_type
-        # index of this node within the arg of the input/output node
-        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
-        self.index_within_arg = index_within_arg
-        # index of this node within the args of the input/output node
-        # for example, in add(x1, x2), x2 would have index_of_arg == 1
-        self.index_of_arg = index_of_arg
-        # fully qualified name
-        self.fqn = fqn
-
-    # Note: cannot annotate the type of x because TorchScript does not support
-    #   the Union type.
-    def forward(self, x):
-        if isinstance(x, torch.Tensor):
-            self.stats.append(x.detach())
-        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
-            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
-            self.stats_rnn.append(new_res)
-        return x
-
-    def __repr__(self):
-        return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
-prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
-ref_node_target_type={self.ref_node_target_type}
-results_type={self.results_type}, index_within_arg={self.index_within_arg},
-index_of_arg={self.index_of_arg}, fqn={self.fqn})"""
-
-
-class NSTracer(quantize_fx.QuantizationTracer):
-    """
-    Just like a regular tracer, but treats observers and fake_quantize
-    modules as leaf modules.
-    """
-    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
-        if isinstance(m, torch.quantization.ObserverBase):
-            return True
-        elif isinstance(m, torch.quantization.FakeQuantizeBase):
-            return True
-        return super().is_leaf_module(m, module_qualified_name)
-
-
-def _extract_weights_one_model(
-    model_name: str,
-    model: GraphModule,
-    nodes_and_names_to_instrument: List[Tuple[Node, str]],
-    results: NSResultsType,
-    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
-) -> None:
-    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
-    for node, ref_name in nodes_and_names_to_instrument:
-        res_type = NSSingleResultValuesType.WEIGHT.value
-        extracted_weight = extract_weight_from_node(
-            node, model, op_to_type_to_weight_extraction_fn)
-        if extracted_weight:
-            if ref_name not in results:
-                results[ref_name] = {res_type: {}}
-            results[ref_name][res_type][model_name] = [extracted_weight]
-
-
-def _extract_weights_impl(
-    model_name_a: str,
-    gm_a: GraphModule,
-    model_name_b: str,
-    gm_b: GraphModule,
-    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
-) -> NSResultsType:
-    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
-    matched_subgraph_pairs = get_matching_subgraph_pairs(
-        gm_a, gm_b, base_name_to_sets_of_related_ops,
-        unmatchable_types_map)
-
-    # split the subgraph pairs into one data structure for each model
-    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
-    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
-    for match_name, match in matched_subgraph_pairs.items():
-        subgraph_a, subgraph_b = match
-        nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
-        nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))
-
-    # populate the results, one model at a time
-    results: NSResultsType = {}
-    _extract_weights_one_model(
-        model_name_a, gm_a, nodes_and_names_to_instrument_a, results,
-        op_to_type_to_weight_extraction_fn)
-    _extract_weights_one_model(
-        model_name_b, gm_b, nodes_and_names_to_instrument_b, results,
-        op_to_type_to_weight_extraction_fn)
-
-    # fill in missing fqn entries
-    maybe_add_missing_fqns(results)
-
-    # rekey on names of nodes in gm_b
-    results = rekey_logger_info_on_node_name_of_model(results, model_name_b)
-
-    return results
-
-
-def extract_weights(
-    model_name_a: str,
-    model_a: nn.Module,
-    model_name_b: str,
-    model_b: nn.Module,
-    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
-) -> NSResultsType:
-    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
-    if base_name_to_sets_of_related_ops is None:
-        base_name_to_sets_of_related_ops = \
-            get_base_name_to_sets_of_related_ops()
-    type_a_related_to_b = \
-        get_type_a_related_to_b(base_name_to_sets_of_related_ops)
-
-    # TODO(future PR): expose these
-    skipped_module_names: List[str] = []
-    skipped_module_classes: List[Callable] = []
-    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
-    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
-    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
-    if hasattr(model_a, '_node_name_to_scope'):
-        gm_a._node_name_to_scope = model_a._node_name_to_scope
-    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
-    if hasattr(model_b, '_node_name_to_scope'):
-        gm_b._node_name_to_scope = model_b._node_name_to_scope
-    return _extract_weights_impl(
-        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
-        unmatchable_types_map, op_to_type_to_weight_extraction_fn)
-
-
-def _add_loggers_one_model(
-    model_name: str,
-    model: GraphModule,
-    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
-    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
-    logger_cls: Callable,
-) -> nn.Module:
-    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")
-
-    # TODO(future PR): do not observe nodes we do not care
-    #   about (both fp32, denylist, etc)
-    node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
-    node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
-    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
-        node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
-    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
-        node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)
-
-    model = add_loggers_to_model(
-        model, node_to_instrument_inputs_to_ref_name,
-        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
-    return model
-
-
-def _add_loggers_impl(
-    name_a: str,
-    gm_a: GraphModule,
-    name_b: str,
-    gm_b: GraphModule,
-    logger_cls: Callable,
-    should_log_inputs: bool,
-    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-) -> Tuple[nn.Module, nn.Module]:
-    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
-    matched_subgraph_pairs = get_matching_subgraph_pairs(
-        gm_a, gm_b,
-        base_name_to_sets_of_related_ops, unmatchable_types_map)
-    nodes_and_names_to_instrument_inputs_a = []
-    nodes_and_names_to_instrument_inputs_b = []
-    nodes_and_names_to_instrument_outputs_a = []
-    nodes_and_names_to_instrument_outputs_b = []
-    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
-        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
-        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
-        # Note: for matching inputs we use start_node, such as observing
-        # the input of linear in linear-relu
-        if should_log_inputs:
-            nodes_and_names_to_instrument_inputs_a.append(
-                (subgraph_a.start_node, match_name, ref_node_type_a))
-            nodes_and_names_to_instrument_inputs_b.append(
-                (subgraph_b.start_node, match_name, ref_node_type_b))
-        # Note: for matching activations we always use end_node,
-        # such as observing the output of relu in linear-relu
-        nodes_and_names_to_instrument_outputs_a.append(
-            (subgraph_a.end_node, match_name, ref_node_type_a))
-        nodes_and_names_to_instrument_outputs_b.append(
-            (subgraph_b.end_node, match_name, ref_node_type_b))
-
-    new_model_a = _add_loggers_one_model(
-        name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
-        nodes_and_names_to_instrument_outputs_a, logger_cls)
-    new_model_b = _add_loggers_one_model(
-        name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
-        nodes_and_names_to_instrument_outputs_b, logger_cls)
-    return (new_model_a, new_model_b)
-
-
-def add_loggers(
-    name_a: str,
-    model_a: nn.Module,
-    name_b: str,
-    model_b: nn.Module,
-    logger_cls: Callable,
-    should_log_inputs : bool = False,
-    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-) -> Tuple[nn.Module, nn.Module]:
-    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
-    # TODO(future PR): expose these
-    skipped_module_names: List[str] = []
-    skipped_module_classes: List[Callable] = []
-    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
-    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
-    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
-    if hasattr(model_a, '_node_name_to_scope'):
-        gm_a._node_name_to_scope = model_a._node_name_to_scope
-    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
-    if hasattr(model_b, '_node_name_to_scope'):
-        gm_b._node_name_to_scope = model_b._node_name_to_scope
-    return _add_loggers_impl(
-        name_a, gm_a, name_b, gm_b, logger_cls,
-        should_log_inputs=should_log_inputs,
-        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
-        unmatchable_types_map=unmatchable_types_map)
-
-
-def _extract_logger_info_one_model(
-    model: nn.Module,
-    results: NSResultsType,
-    logger_cls: Callable,
-) -> None:
-    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
-    for gm_name, mod in model.named_modules():
-        # TODO(future PR): better check when scripted
-        is_logger = (
-            isinstance(mod, logger_cls)  # type: ignore[arg-type]
-            or (
-                isinstance(mod, torch.jit.RecursiveScriptModule)
-                and mod.original_name == 'OutputLogger'
-            )
-        )
-        if is_logger:
-            key = mod.ref_name
-            if key not in results:
-                results[key] = {}
-            assert mod.model_name not in results[key], \
-                f"{mod.model_name} is already present in results"
-            if mod.results_type not in results[key]:
-                results[key][mod.results_type] = {}
-            if mod.model_name not in results[key][mod.results_type]:
-                results[key][mod.results_type][mod.model_name] = []
-            stats_to_use = mod.stats
-            if len(mod.stats_rnn) > 0:
-                stats_to_use = mod.stats_rnn
-            results[key][mod.results_type][mod.model_name].append({
-                'type': mod.results_type,
-                'values': stats_to_use,
-                'ref_node_name': mod.ref_node_name,
-                'ref_node_target_type': mod.ref_node_target_type,
-                'prev_node_name': mod.prev_node_name,
-                'prev_node_target_type': mod.prev_node_target_type,
-                'index_within_arg': mod.index_within_arg,
-                'index_of_arg': mod.index_of_arg,
-                'fqn': mod.fqn,
-            })
-            # ensure the list stays sorted
-            results[key][mod.results_type][mod.model_name].sort(
-                key=lambda res:
-                f"{res['index_of_arg']}:{res['index_within_arg']}"
-            )
-
-
-# TODO(future PR): align on naming
-# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
-def extract_logger_info(
-    model_a: nn.Module,
-    model_b: nn.Module,
-    logger_cls: Callable,
-    model_name_to_use_for_layer_names: str,
-) -> NSResultsType:
-    """
-    Same thing as ns.extract_logger_info, but for models prepared with
-    this module.
-
-    TODO(future PR): real docblock
-
-    Output format: NSResultsType
-    """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
-    results: NSResultsType = {}
-    for model in (model_a, model_b):
-        _extract_logger_info_one_model(model, results, logger_cls)
-    # fill in missing fqn entries
-    maybe_add_missing_fqns(results)
-    # rekey on the name of model b
-    results = rekey_logger_info_on_node_name_of_model(
-        results, model_name_to_use_for_layer_names)
-    return results
-
-
-def _add_shadow_loggers_impl(
-    name_a: str,
-    gm_a: GraphModule,
-    name_b: str,
-    gm_b: GraphModule,
-    logger_cls: Callable,
-    should_log_inputs: bool,
-    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-) -> nn.Module:
-    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
-    matched_subgraph_pairs = get_matching_subgraph_pairs(
-        gm_a, gm_b, base_name_to_sets_of_related_ops,
-        unmatchable_types_map)
-    gm_a_shadows_b = create_a_shadows_b(
-        name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
-        should_log_inputs=should_log_inputs,
-        node_type_to_io_type_map=node_type_to_io_type_map)
-    return gm_a_shadows_b
-
-
-def add_shadow_loggers(
-    name_a: str,
-    model_a: nn.Module,
-    name_b: str,
-    model_b: nn.Module,
-    logger_cls: Callable,
-    should_log_inputs: bool = False,
-    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
-) -> nn.Module:
-    """
-    Same thing as add_loggers, but for an `a_shadows_b` model.
-    TODO(future PR): real docblock
-    """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
-    # TODO(future PR): expose these
-    skipped_module_names: List[str] = []
-    skipped_module_classes: List[Callable] = []
-    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
-    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
-    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
-    if hasattr(model_a, '_node_name_to_scope'):
-        gm_a._node_name_to_scope = model_a._node_name_to_scope
-    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
-    if hasattr(model_b, '_node_name_to_scope'):
-        gm_b._node_name_to_scope = model_b._node_name_to_scope
-    return _add_shadow_loggers_impl(
-        name_a, gm_a, name_b, gm_b, logger_cls,
-        should_log_inputs=should_log_inputs,
-        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
-        node_type_to_io_type_map=node_type_to_io_type_map,
-        unmatchable_types_map=unmatchable_types_map)
-
-
-def extract_shadow_logger_info(
-    model_a_shadows_b: nn.Module,
-    logger_cls: Callable,
-    model_name_to_use_for_layer_names: str,
-) -> NSResultsType:
-    """
-    Same thing as extract_logger_info, but for an `a_shadows_b` model.
-    TODO(future PR): real docblock
-    """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
-    results: NSResultsType = collections.defaultdict(dict)
-    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
-    # fill in missing fqn entries
-    maybe_add_missing_fqns(results)
-    # rekey on the name of model b
-    results = rekey_logger_info_on_node_name_of_model(
-        results, model_name_to_use_for_layer_names)
-    return dict(results)
-
-
-def extend_logger_results_with_comparison(
-    results: NSResultsType,
-    model_name_1: str,
-    model_name_2: str,
-    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
-    comparison_name: str,
-) -> None:
-    """
-    Compares the logged values from `model_name_2` against the corresponding
-    values in `model_name_1`, using `comparison_fn`. Records the result
-    in `model_name_2`'s results under `comparison_name`.
-    """
-    for _, results_type_to_results in results.items():
-        for _, model_name_to_results in results_type_to_results.items():
-            assert model_name_1 in model_name_to_results, \
-                f"{model_name_1} not found in results"
-            assert model_name_2 in model_name_to_results, \
-                f"{model_name_2} not found in results"
-
-            results_1 = model_name_to_results[model_name_1]
-            results_2 = model_name_to_results[model_name_2]
-
-            for result_2 in results_2:
-                index_within_arg_2 = result_2['index_within_arg']
-                index_of_arg_2 = result_2['index_of_arg']
-                # find corresponding result_1
-                result_1 = None
-                for cur_result_1 in results_1:
-                    index_within_arg_1 = cur_result_1['index_within_arg']
-                    index_of_arg_1 = cur_result_1['index_of_arg']
-                    if (
-                        (index_within_arg_1 == index_within_arg_2) and
-                        (index_of_arg_1 == index_of_arg_2)
-                    ):
-                        result_1 = cur_result_1
-                        break
-                assert result_1 is not None
-
-                values_1 = result_1['values']
-                values_2 = result_2['values']
-                result_2[comparison_name] = []
-                for value_1, value_2 in zip(values_1, values_2):
-                    comparison_result = comparison_fn(value_1, value_2)
-                    result_2[comparison_name].append(comparison_result)
index 231f8df..71fbf29 100644 (file)
@@ -742,8 +742,8 @@ def get_layer_sqnr_dict(model_a: nn.Module, model_b: nn.Module, x: torch.Tensor)
         model_b: A quantized model
         x: Inputs to use during calibration
     """
-    import torch.quantization._numeric_suite_fx as ns
-    from torch.quantization.ns.mappings import get_unmatchable_types_map
+    import torch.ao.ns._numeric_suite_fx as ns
+    from torch.ao.ns.fx.mappings import get_unmatchable_types_map
 
     unmatchable_types_map = get_unmatchable_types_map()
     unmatchable_types_map["funs_unmatchable"].add(torch.mul)
@@ -766,7 +766,7 @@ def get_layer_sqnr_dict(model_a: nn.Module, model_b: nn.Module, x: torch.Tensor)
     ns.extend_logger_results_with_comparison(
         activation_comparison_dict,
         'fp32', 'int8',
-        torch.quantization.ns.utils.compute_sqnr, 'sqnr'
+        torch.ao.ns.fx.utils.compute_sqnr, 'sqnr'
     )
 
     # Construct a dictionary mapping layer names to the SQNR values
index 77512f7..33e758c 100644 (file)
@@ -33,7 +33,7 @@ try:
         prepare_qat_fx,
         convert_fx,
     )
-    from torch.quantization.ns.ns_types import NSSingleResultValuesType, NSSubgraph
+    from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph
     from torch.fx.graph import Node
     from torch.fx import GraphModule
     HAS_FX = True