[fx_acc] Add mapper for torch.log1p (#63792)
authorShiyan Deng <dsy842974287@fb.com>
Tue, 24 Aug 2021 00:41:38 +0000 (17:41 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 24 Aug 2021 00:48:59 +0000 (17:48 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63792

Map `torch.log1p` to `acc_ops.add` + `acc_ops.log`.

Test Plan: buck test mode/opt glow/fb/fx/oss_acc_tracer:test_acc_tracer -- test_log1p

Reviewed By: wushirong

Differential Revision: D30491706

fbshipit-source-id: bcbeddf06131113185d2019cfd7cf5e9193a8a78

torch/fx/experimental/fx_acc/acc_ops.py

index 7c95206..0c0965a 100644 (file)
@@ -509,6 +509,21 @@ def div(*, input, other):
 def relu(*, input, inplace=False):
     return nn.functional.relu(**locals())
 
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.log1p),
+    arg_replacement_tuples=[
+        ("input", "input"),
+    ],
+)
+def torch_log1p_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node:
+    with node.graph.inserting_before(node):
+        add_kwargs = {"input": node.kwargs["input"], "other": 1}
+        add_node = node.graph.call_function(add, kwargs=add_kwargs)
+        add_node.meta = node.meta.copy()
+        log_kwargs = {"input": add_node}
+        log_node = node.graph.call_function(log, kwargs=log_kwargs)
+        log_node.meta = node.meta.copy()
+        return log_node
 
 @register_custom_acc_mapper_fn(
     op_and_target=("call_method", "sum"),