ns for fx: make layer types more readable (#64270)
authorVasiliy Kuznetsov <vasiliy@fb.com>
Tue, 31 Aug 2021 19:09:59 +0000 (12:09 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 19:31:34 +0000 (12:31 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64270

Before this PR, layer types were populated by doing
`str(module_instance)` and `str(function)`. This resulted
in moderately readable strings for modules, and poorly readable
strings for functions.

This PR switches the logic to use `torch.typename` utility instead.
The results are significantly more readable.

Example function type:

```
# before
'<built-in method linear of PyCapsule object at 0x7fe9b20ce7b0>'

# after
'torch._ops.quantized.PyCapsule.linear'
```

Example module type:

```
# before
"<class 'torch.nn.quantized.modules.conv.Conv2d'>"

# after
'torch.nn.quantized.modules.conv.Conv2d'
```

Test Plan:
Manually inspect NS results for modules and functions, verify they are
more readable.

Manually inspect NS results for modules and functions, verify they are
more readable.

Imported from OSS

Differential Revision:
D30669545
D30669545

Reviewed By: jerryzh168

Pulled By: vkuzo

fbshipit-source-id: 60959e5cafa0a4992b083bf99f5d8260f9acdac0

torch/quantization/ns/utils.py
torch/quantization/ns/weight_utils.py

index 678f60a..62397d0 100644 (file)
@@ -317,15 +317,15 @@ def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
 def get_target_type_str(node: Node, gm: GraphModule) -> str:
     """
     Returns a string representation of the type of the function or module
-    pointed to by this node, or '' for other op types.
+    pointed to by this node, or '' for other node types.
     """
     target_type = ""
     if node.op in ("call_function", "call_method"):
-        target_type = str(node.target)
+        target_type = torch.typename(node.target)
     elif node.op == "call_module":
         assert isinstance(node.target, str)
         target_mod = getattr_from_fqn(gm, node.target)
-        target_type = str(type(target_mod))
+        target_type = torch.typename(target_mod)
     return target_type
 
 
index 724cdc7..36e183e 100644 (file)
@@ -231,6 +231,8 @@ def extract_weight_from_node(
         op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()
 
     ref_node_type = get_target_type_str(node, gm)
+    # for extracting weights, these are always the same
+    prev_node_type = ref_node_type
 
     if node.op == 'call_function':
         function_mapping = op_to_type_to_weight_extraction_fn['call_function']
@@ -241,7 +243,7 @@ def extract_weight_from_node(
                     'type': res_type,
                     'values': [weight],
                     'prev_node_name': node.name,
-                    'prev_node_target_type': str(node.target),
+                    'prev_node_target_type': prev_node_type,
                     'ref_node_name': node.name,
                     'ref_node_target_type': ref_node_type,
                     'index_within_arg': 0,
@@ -261,7 +263,7 @@ def extract_weight_from_node(
                     'type': res_type,
                     'values': [weight],
                     'prev_node_name': node.name,
-                    'prev_node_target_type': str(type(mod)),
+                    'prev_node_target_type': prev_node_type,
                     'ref_node_name': node.name,
                     'ref_node_target_type': ref_node_type,
                     'index_within_arg': 0,