[acc_normalizer] Improve error when kwarg normalization fails (#64408)
authorJordan Fix <jfix@fb.com>
Wed, 8 Sep 2021 22:30:28 +0000 (15:30 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 22:33:32 +0000 (15:33 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64408

att

Test Plan: NFC

Reviewed By: protonu

Differential Revision: D30716392

fbshipit-source-id: e1c3bb1afcd5363a9d502549d8a46b90226be40c

torch/fx/experimental/fx_acc/acc_normalizer.py

index a5d116a..66e83a5 100644 (file)
@@ -372,9 +372,16 @@ def normalize(mod: torch.fx.GraphModule, expect_nodes_have_shapes: bool = False)
             normalized_kwargs = node.kwargs
         else:
             normalized_args = ()
-            normalized_kwargs = get_normalized_kwargs(
-                node, normalization_info.arg_replacement_tuples
-            )
+            try:
+                normalized_kwargs = get_normalized_kwargs(
+                    node, normalization_info.arg_replacement_tuples
+                )
+            except Exception:
+                print(
+                    f"Error during kwarg normalization for: {node.format_node()}; "
+                    f"arg_replacement_tuples={normalization_info.arg_replacement_tuples}"
+                )
+                raise
 
         if (
             normalization_info.needs_shapes_for_normalization