Add support for lowering info during serialize_module, and add padding/partial to...
authorJordan Fix <jfix@fb.com>
Thu, 9 Sep 2021 19:59:54 +0000 (12:59 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 20:01:28 +0000 (13:01 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/glow/pull/5810

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64725

- Any info added to the dict in node.meta["lowering_info"] will be added to the node_rep during serialization.
- Use this to add annotations on placeholders that allow partial inputs and require padding.
- Check for these annotations and set them in the NNPICompiledFunction as expected

Test Plan: Validated working on inline_cvr in stack. Additionally existing fx_glow end to end tests should still pass.

Reviewed By: 842974287

Differential Revision: D30824192

fbshipit-source-id: def64ef097aa35c337abb494415f7d437c6c7fa9

torch/fx/experimental/graph_manipulation.py

index 86fb128..86f7bc2 100644 (file)
@@ -319,6 +319,15 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D
                 node_rep["q_scale"] = tensor_meta.q_scale
                 node_rep["q_zero_point"] = tensor_meta.q_zero_point
 
+        # Add all extra lowering_info that was provided in node.meta.
+        lowering_info = node.meta.get("lowering_info")
+        if lowering_info is not None:
+            overlapping_keys = node_rep.keys() & lowering_info.keys()
+            assert (
+                len(overlapping_keys) == 0
+            ), f"Overlap found between lowering_info and node_rep: {overlapping_keys}"
+            node_rep.update(lowering_info)
+
         return node_rep
 
     # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
@@ -370,9 +379,9 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D
                 # For quantized embedding tables we need to update the shape/type,
                 # so we check if the users of this get_attr is a quantized EB and this is the weight for the EB.
                 user_targets = {
-                    _get_qualified_name(
-                        n.target
-                    ).replace("torch.fx.experimental.fx_acc.", "").replace("glow.fb.fx.", ""): n
+                    _get_qualified_name(n.target)
+                    .replace("torch.fx.experimental.fx_acc.", "")
+                    .replace("glow.fb.fx.", ""): n
                     for n in node.users.keys()
                 }
                 if (