Moving getattr_from_fqn to torch.quantization.utils (#63107)
author     Charles David Hernandez <cdhernandez@fb.com>
           Fri, 13 Aug 2021 03:57:54 +0000 (20:57 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Fri, 13 Aug 2021 03:59:01 +0000 (20:59 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63107

Moving this function to torch.quantization.utils because its functionality is useful outside of the Numeric Suite (ns).
ghstack-source-id: 135727260

Test Plan: buck test //caffe2/test:quantization_fx mode/dev-nosan --keep-going --config client.id=nuclide --show-full-output -- suite

Reviewed By: supriyar

Differential Revision: D30260735

fbshipit-source-id: 58deabdd0f3b03b0ee7ee92be0548a0945084d65

torch/quantization/ns/graph_matcher.py
torch/quantization/ns/pattern_utils.py
torch/quantization/ns/utils.py
torch/quantization/utils.py
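
As a quick reference for the new location, a minimal sketch of calling the moved helper (the toy nn.Sequential model below is illustrative, not from the PR):

    import torch.nn as nn
    from torch.quantization.utils import getattr_from_fqn

    # nn.Sequential registers its children under the names "0", "1", ...
    m = nn.Sequential(nn.Sequential(nn.Linear(2, 2)))

    # "0.0.weight" resolves to m[0][0].weight by applying getattr per dotted part
    w = getattr_from_fqn(m, "0.0.weight")
    assert w is m[0][0].weight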

torch/quantization/ns/graph_matcher.py
index 5d94cdc..51ae2ef 100644
@@ -7,7 +7,7 @@ toq = torch.ops.quantized
 from torch.fx import GraphModule
 from torch.fx.graph import Graph, Node
 
-from .utils import getattr_from_fqn
+from torch.quantization.utils import getattr_from_fqn
 from .ns_types import NSSubgraph, NSNodeTargetType
 from .mappings import (
     get_base_name_to_sets_of_related_ops,
torch/quantization/ns/pattern_utils.py
index 9217f44..7a80786 100644
@@ -6,7 +6,7 @@ toq = torch.ops.quantized
 from torch.fx import GraphModule
 from torch.fx.graph import Node
 
-from .utils import getattr_from_fqn
+from torch.quantization.utils import getattr_from_fqn
 from .ns_types import NSNodeTargetType
 from torch.quantization.fx.pattern_utils import get_default_quant_patterns
 from torch.quantization import (
torch/quantization/ns/utils.py
index 397c783..678f60a 100644
@@ -3,31 +3,23 @@ import operator
 
 import torch
 import torch.nn as nn
-import torch.nn.quantized as nnq
 import torch.nn.intrinsic.quantized as nniq
+import torch.nn.quantized as nnq
+
 toq = torch.ops.quantized
+from typing import Tuple, Callable, Dict, Set, List, Optional, Union
+
 from torch.fx import GraphModule
 from torch.fx.graph import Node
-from torch.quantization.quantize import is_activation_post_process
 from torch.quantization import (
     ObserverBase,
     FakeQuantizeBase,
 )
+from torch.quantization.utils import getattr_from_fqn
+from torch.quantization.quantize import is_activation_post_process
 
 from .ns_types import NSNodeTargetType, NSResultsType
 
-from typing import Any, Tuple, Callable, Dict, Set, List, Optional, Union
-
-def getattr_from_fqn(gm: GraphModule, fqn: str) -> Any:
-    """
-    Given a gm and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz.
-    """
-    fqn_parts = fqn.split(".")
-    cur_val = gm
-    for part in fqn_parts:
-        cur_val = getattr(cur_val, part)
-    return cur_val
-
 # TODO(future PR): consider deleting this enum and using the torch types
 # directly.  This might be tricky because it is not a one to one mapping.
 class NodeInputOrOutputType(enum.Enum):
@@ -51,16 +43,16 @@ def get_node_first_input_and_output_type(
 ) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
 
     # TODO(future PR): clean this up
-    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map['funs_io_type_fp32']
-    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map['funs_io_type_fp16']
-    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map['funs_io_type_int8']
-    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['funs_io_type_fp32_or_int8']
-    MODS_IO_TYPE_FP32 = node_type_to_io_type_map['mods_io_type_fp32']
-    MODS_IO_TYPE_INT8 = node_type_to_io_type_map['mods_io_type_int8']
-    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['mods_io_type_fp32_or_int8']
-    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['meths_io_type_fp32_or_int8']
-
-    if node.op == 'call_function':
+    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
+    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
+    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
+    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
+    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
+    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
+    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
+    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]
+
+    if node.op == "call_function":
         if node.target in FUNS_IO_TYPE_FP32:
             return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
         if node.target in FUNS_IO_TYPE_FP16:
@@ -68,12 +60,15 @@ def get_node_first_input_and_output_type(
         elif node.target in FUNS_IO_TYPE_INT8:
             return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
         elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
-            return (NodeInputOrOutputType.FP32_OR_INT8, NodeInputOrOutputType.FP32_OR_INT8)
+            return (
+                NodeInputOrOutputType.FP32_OR_INT8,
+                NodeInputOrOutputType.FP32_OR_INT8,
+            )
         else:
             return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
 
-    elif node.op == 'call_module':
-        assert node.op == 'call_module'
+    elif node.op == "call_module":
+        assert node.op == "call_module"
         assert isinstance(node.target, str)
         mod = getattr_from_fqn(gm, node.target)
         if isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase)):  # type: ignore[arg-type]
@@ -81,9 +76,12 @@ def get_node_first_input_and_output_type(
             # type of the preceding node.
             first_arg = node.args[0]
             assert isinstance(first_arg, Node)
-            _prev_node_input_type, prev_node_output_type = \
-                get_node_first_input_and_output_type(
-                    first_arg, gm, logger_cls, node_type_to_io_type_map)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                first_arg, gm, logger_cls, node_type_to_io_type_map
+            )
             return (prev_node_output_type, prev_node_output_type)
         is_known_fp32_input_module = any(
             isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32  # type: ignore[arg-type]
@@ -99,46 +97,60 @@ def get_node_first_input_and_output_type(
         elif is_known_int8_input_module:
             return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
         elif is_known_fp32_or_int8_input_module:
-            return (NodeInputOrOutputType.FP32_OR_INT8, NodeInputOrOutputType.FP32_OR_INT8)
+            return (
+                NodeInputOrOutputType.FP32_OR_INT8,
+                NodeInputOrOutputType.FP32_OR_INT8,
+            )
         else:
             return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
 
-    elif node.op == 'call_method':
-        if node.target == 'dequantize':
+    elif node.op == "call_method":
+        if node.target == "dequantize":
             # Dequantize is a special node because it allows multiple input types.
             # So, we look up the output type of the previous node and return that
             # as the input type of this node instance.
             prev_node = node.args[0]
             assert isinstance(prev_node, Node)
-            _prev_node_input_type, prev_node_output_type = \
-                get_node_first_input_and_output_type(
-                    prev_node, gm, logger_cls, node_type_to_io_type_map)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                prev_node, gm, logger_cls, node_type_to_io_type_map
+            )
             return (prev_node_output_type, NodeInputOrOutputType.FP32)
 
-        elif node.target == 'to':
+        elif node.target == "to":
             # to is a special node because it allows multiple input types.
             # So, we look up the output type of the previous node and return that
             # as the input type of this node instance. We also look up the target
             # of to and return the correct output type.
             prev_node = node.args[0]
             assert isinstance(prev_node, Node)
-            _prev_node_input_type, prev_node_output_type = \
-                get_node_first_input_and_output_type(
-                    prev_node, gm, logger_cls, node_type_to_io_type_map)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                prev_node, gm, logger_cls, node_type_to_io_type_map
+            )
 
             cur_node_dtype_target = node.args[1]
-            assert cur_node_dtype_target is torch.float16, \
-                f"{cur_node_dtype_target} handling needs to be added"
+            assert (
+                cur_node_dtype_target is torch.float16
+            ), f"{cur_node_dtype_target} handling needs to be added"
 
             return (prev_node_output_type, NodeInputOrOutputType.FP16)
 
         elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
-            return (NodeInputOrOutputType.FP32_OR_INT8, NodeInputOrOutputType.FP32_OR_INT8)
+            return (
+                NodeInputOrOutputType.FP32_OR_INT8,
+                NodeInputOrOutputType.FP32_OR_INT8,
+            )
 
         return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
     else:
         return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
 
+
 def get_node_input_qparams(
     node: Node,
     gm: GraphModule,
@@ -153,7 +165,7 @@ def get_node_input_qparams(
     if not isinstance(prev_node, Node):
         return None
 
-    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['mods_io_type_fp32_or_int8']
+    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
 
     def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
         scale_node, zp_node = node.args[scale_arg_idx], node.args[zp_arg_idx]
@@ -163,7 +175,7 @@ def get_node_input_qparams(
         zp_obj = getattr_from_fqn(gm, zp_node.target)
         return (scale_obj, zp_obj)
 
-    if prev_node.op == 'call_function':
+    if prev_node.op == "call_function":
 
         # quantize - read the args directly
         if prev_node.target == torch.quantize_per_tensor:
@@ -175,7 +187,7 @@ def get_node_input_qparams(
         # TODO(future PR): handle more functionals
         # TODO(future PR): handle functional ops which inherit qparams from input
 
-    elif prev_node.op == 'call_module':
+    elif prev_node.op == "call_module":
 
         # get type of the module
         assert isinstance(prev_node.target, str)
@@ -207,7 +219,7 @@ def get_node_input_qparams(
                 nniq.ConvReLU2d,
                 nniq.ConvReLU3d,
                 nniq.LinearReLU,
-            )
+            ),
         ):
             return (module_obj.scale, module_obj.zero_point)  # type: ignore[return-value]
 
@@ -215,11 +227,11 @@ def get_node_input_qparams(
             isinstance(module_obj, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
         )
         if is_known_fp32_or_int8_input_module:
-            return get_node_input_qparams(
-                prev_node, gm, node_type_to_io_type_map)
+            return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map)
 
     return None
 
+
 def return_first_non_observer_node(
     node: Node,
     gm: GraphModule,
@@ -233,7 +245,7 @@ def return_first_non_observer_node(
     graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
     graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
     """
-    if node.op == 'call_module':
+    if node.op == "call_module":
         node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
         if is_activation_post_process(node_obj):
             assert len(node.args) == 1
@@ -248,6 +260,7 @@ def return_first_non_observer_node(
                 node = node.args[0]
     return node
 
+
 def get_number_of_non_param_args(
     node: Node,
     gm: GraphModule,
@@ -265,7 +278,7 @@ def get_number_of_non_param_args(
 
     Returns 2, because both x and hid are non-param args.
     """
-    if node.op == 'call_module':
+    if node.op == "call_module":
         node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
         if isinstance(node_obj, nn.LSTM):
             return 2
@@ -273,6 +286,7 @@ def get_number_of_non_param_args(
     # default is 1
     return 1
 
+
 def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
     """
     Returns the indices of args of the node which we should attach
@@ -287,12 +301,10 @@ def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
     """
     if len(node.args) == 0:
         return []
-    if (
-        node.op == 'call_function' and (
-            # TODO(future PR): use relationship map instead of hardcoding
-            node.target in (torch.add, torch.ops.quantized.add, operator.add) or
-            node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
-        )
+    if node.op == "call_function" and (
+        # TODO(future PR): use relationship map instead of hardcoding
+        node.target in (torch.add, torch.ops.quantized.add, operator.add)
+        or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
     ):
         result = []
         for i in range(2):
@@ -301,20 +313,22 @@ def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
         return result
     return [0]
 
+
 def get_target_type_str(node: Node, gm: GraphModule) -> str:
     """
     Returns a string representation of the type of the function or module
     pointed to by this node, or '' for other op types.
     """
-    target_type = ''
-    if node.op in ('call_function', 'call_method'):
+    target_type = ""
+    if node.op in ("call_function", "call_method"):
         target_type = str(node.target)
-    elif node.op == 'call_module':
+    elif node.op == "call_module":
         assert isinstance(node.target, str)
         target_mod = getattr_from_fqn(gm, node.target)
         target_type = str(type(target_mod))
     return target_type
 
+
 def rekey_logger_info_on_node_name_of_model(
     results: NSResultsType,
     model_name: str,
@@ -344,7 +358,7 @@ def rekey_logger_info_on_node_name_of_model(
             for cur_model_name, list_of_results in model_name_to_results.items():
                 if cur_model_name == model_name:
                     assert len(list_of_results)
-                    new_layer_name = list_of_results[0]['ref_node_name']
+                    new_layer_name = list_of_results[0]["ref_node_name"]
                 else:
                     continue
         if new_layer_name is not None:
@@ -367,11 +381,10 @@ def maybe_add_missing_fqns(results: NSResultsType) -> None:
     # Check in the first result to find any model with fqn entries defined.
     model_name_with_fqns = None
     for layer_name, result_type_to_results in results.items():
-        for result_type, model_name_to_results in \
-                result_type_to_results.items():
+        for result_type, model_name_to_results in result_type_to_results.items():
             for model_name, model_results in model_name_to_results.items():
                 if len(model_results) > 0:
-                    if model_results[0]['fqn'] is not None:
+                    if model_results[0]["fqn"] is not None:
                         model_name_with_fqns = model_name
                         break
             break
@@ -379,22 +392,23 @@ def maybe_add_missing_fqns(results: NSResultsType) -> None:
 
     if model_name_with_fqns:
         for layer_name, result_type_to_results in results.items():
-            for result_type, model_name_to_results in \
-                    result_type_to_results.items():
+            for result_type, model_name_to_results in result_type_to_results.items():
                 ref_model_results = model_name_to_results[model_name_with_fqns]
                 for model_name, model_results in model_name_to_results.items():
                     if model_name == model_name_with_fqns:
                         continue
                     for i in range(len(model_results)):
-                        fqn = ref_model_results[i]['fqn']
-                        model_results[i]['fqn'] = fqn
+                        fqn = ref_model_results[i]["fqn"]
+                        model_results[i]["fqn"] = fqn
+
 
 def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
     def inner(*args, **kwargs):
         a0, a1, *a_other = args
 
-        if (isinstance(a0, tuple) and isinstance(a1, tuple)) or \
-                (isinstance(a0, list) and isinstance(a1, list)):
+        if (isinstance(a0, tuple) and isinstance(a1, tuple)) or (
+            isinstance(a0, list) and isinstance(a1, list)
+        ):
             results = []
             for el0, el1 in zip(a0, a1):
                 new_args = (el0, el1, *a_other)
@@ -413,18 +427,22 @@ def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
 
         new_args = (a0, a1, *a_other)
         return f(*new_args, **kwargs)
+
     return inner
 
+
 @maybe_dequantize_first_two_tensor_args_and_handle_tuples
 def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     Ps = torch.norm(x)
     Pn = torch.norm(x - y)
     return 20 * torch.log10(Ps / Pn)
 
+
 @maybe_dequantize_first_two_tensor_args_and_handle_tuples
 def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return torch.sqrt(((x - y) ** 2).sum() / (x ** 2).sum())
 
+
 @maybe_dequantize_first_two_tensor_args_and_handle_tuples
 def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     # For convolutions, the shape of the quantized weight has one additional
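
As context for the metrics above: compute_sqnr returns the signal-to-quantization-noise ratio in decibels, 20 * log10(norm(x) / norm(x - y)), and the maybe_dequantize_first_two_tensor_args_and_handle_tuples decorator, per its name, dequantizes the first two tensor args when they are quantized. A minimal usage sketch on fp32 tensors (import path inferred from this file's location in the PR):

    import torch
    from torch.quantization.ns.utils import compute_sqnr

    x = torch.randn(4, 4)
    y = x + 0.01 * torch.randn(4, 4)  # x plus small noise, standing in for quantization error
    print(compute_sqnr(x, y))         # large positive dB value when y closely tracks x
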
torch/quantization/utils.py
index 5b24c1d..9c5198b 100644
@@ -2,10 +2,10 @@
 Utils shared by different modes of quantization (eager/graph)
 """
 import warnings
-
+import functools
 import torch
 from .quant_type import QuantType, quant_type_to_str
-from typing import Tuple
+from typing import Tuple, Any
 
 def get_combined_dict(default_dict, additional_dict):
     d = default_dict.copy()
@@ -21,6 +21,12 @@ def is_per_channel(qscheme):
                        torch.per_channel_affine_float_qparams,
                        torch.per_channel_symmetric]
 
+def getattr_from_fqn(obj: Any, fqn: str) -> Any:
+    """
+    Given an obj and an fqn such as "foo.bar.baz", returns obj.foo.bar.baz.
+    """
+    return functools.reduce(getattr, fqn.split("."), obj)
+
 def get_qparam_dict(observer_or_fake_quant):
     qscheme = observer_or_fake_quant.qscheme if hasattr(observer_or_fake_quant, "qscheme") else None
     dtype = observer_or_fake_quant.dtype
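
A note on the new implementation: functools.reduce(getattr, fqn.split("."), obj) folds getattr over the dotted parts and is behaviorally identical to the explicit loop removed from torch/quantization/ns/utils.py:

    import functools

    def getattr_from_fqn_loop(obj, fqn):
        # Same behavior as functools.reduce(getattr, fqn.split("."), obj)
        cur = obj
        for part in fqn.split("."):
            cur = getattr(cur, part)
        return cur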