From 6cc8cc6e56431be75745fd6cff602ef587cf3e53 Mon Sep 17 00:00:00 2001 From: Jordan Fix Date: Thu, 9 Sep 2021 12:59:54 -0700 Subject: [PATCH] Add support for lowering info during serialize_module, and add padding/partial to it (#5810) Summary: Pull Request resolved: https://github.com/pytorch/glow/pull/5810 Pull Request resolved: https://github.com/pytorch/pytorch/pull/64725 - Any info added to the dict in node.meta["lowering_info"] will be added to the node_rep during serialization. - Use this to add annotations on placeholders that allow partial inputs and require padding. - Check for these annotations and set them in the NNPICompiledFunction as expected Test Plan: Validated working on inline_cvr in stack. Additionally existing fx_glow end to end tests should still pass. Reviewed By: 842974287 Differential Revision: D30824192 fbshipit-source-id: def64ef097aa35c337abb494415f7d437c6c7fa9 --- torch/fx/experimental/graph_manipulation.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/torch/fx/experimental/graph_manipulation.py b/torch/fx/experimental/graph_manipulation.py index 86fb128..86f7bc2 100644 --- a/torch/fx/experimental/graph_manipulation.py +++ b/torch/fx/experimental/graph_manipulation.py @@ -319,6 +319,15 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D node_rep["q_scale"] = tensor_meta.q_scale node_rep["q_zero_point"] = tensor_meta.q_zero_point + # Add all extra lowering_info that was provided in node.meta. + lowering_info = node.meta.get("lowering_info") + if lowering_info is not None: + overlapping_keys = node_rep.keys() & lowering_info.keys() + assert ( + len(overlapping_keys) == 0 + ), f"Overlap found between lowering_info and node_rep: {overlapping_keys}" + node_rep.update(lowering_info) + return node_rep # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules @@ -370,9 +379,9 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D # For quantized embedding tables we need to update the shape/type, # so we check if the users of this get_attr is a quantized EB and this is the weight for the EB. user_targets = { - _get_qualified_name( - n.target - ).replace("torch.fx.experimental.fx_acc.", "").replace("glow.fb.fx.", ""): n + _get_qualified_name(n.target) + .replace("torch.fx.experimental.fx_acc.", "") + .replace("glow.fb.fx.", ""): n for n in node.users.keys() } if ( -- 2.7.4