From: Shirong Wu Date: Wed, 8 Sep 2021 21:29:33 +0000 (-0700) Subject: Add plugin for linalg norm operation (#64611) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~365 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=cc0565326ce098236507b13c2bf2d5b48f64fba3;p=platform%2Fupstream%2Fpytorch.git Add plugin for linalg norm operation (#64611) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64611 Add plugin for torch.linalg.norm, this plugin correctly only support norm operation without batch_size change, so vector input or matrix input with dim including '0' is not supported with this plugin. Test Plan: Unit test Reviewed By: 842974287 Differential Revision: D30525958 fbshipit-source-id: 0d66b60a390bb6235166e5a80390090d0acf691a --- diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py index 0f1c92a..ccd8ec8 100644 --- a/torch/fx/experimental/fx_acc/acc_ops.py +++ b/torch/fx/experimental/fx_acc/acc_ops.py @@ -763,6 +763,12 @@ def torch_argmin_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Nod """ return argmin_max_mapper_impl(node, largest=False) +@register_acc_op_mapping(op_and_target=("call_function", torch.linalg.norm)) +@register_acc_op +def linalg_norm(*, input, ord, dim, keepdim): + return torch.linalg.norm(**locals()) + + @register_custom_acc_mapper_fn( op_and_target=("call_method", "split"), arg_replacement_tuples=[