Add plugin for linalg norm operation (#64611)
authorShirong Wu <shirong@fb.com>
Wed, 8 Sep 2021 21:29:33 +0000 (14:29 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 21:33:20 +0000 (14:33 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64611

Add plugin for torch.linalg.norm, this plugin correctly only support norm operation without batch_size change, so vector input or matrix input with dim including '0' is not supported with this plugin.

Test Plan: Unit test

Reviewed By: 842974287

Differential Revision: D30525958

fbshipit-source-id: 0d66b60a390bb6235166e5a80390090d0acf691a

torch/fx/experimental/fx_acc/acc_ops.py

index 0f1c92a..ccd8ec8 100644 (file)
@@ -763,6 +763,12 @@ def torch_argmin_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Nod
     """
     return argmin_max_mapper_impl(node, largest=False)
 
+@register_acc_op_mapping(op_and_target=("call_function", torch.linalg.norm))
+@register_acc_op
+def linalg_norm(*, input, ord, dim, keepdim):
+    return torch.linalg.norm(**locals())
+
+
 @register_custom_acc_mapper_fn(
     op_and_target=("call_method", "split"),
     arg_replacement_tuples=[