Adding support for lowering 4Bit EmbeddingBag Operator (#5806)
author: Protonu Basu <protonu@fb.com>
Wed, 8 Sep 2021 14:11:38 +0000 (07:11 -0700)
committer: Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 14:13:16 +0000 (07:13 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/glow/pull/5806

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64001

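Add the 4-bit embedding bag operator (`embedding_bag_4bit_rowwise_offsets`) to acc_ops, and mark its weight as `acc.uint4fused` during module serialization.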

Test Plan: Let CI run.

Reviewed By: jfix71

Differential Revision: D30532824

fbshipit-source-id: bf476c9710477792aae202dacf64e23539c33bd9

torch/fx/experimental/fx_acc/acc_ops.py
torch/fx/experimental/graph_manipulation.py

index b10d35e..0f1c92a 100644 (file)
@@ -891,7 +891,7 @@ def embedding_bag(
 def embedding_bag_byte_rowwise_offsets(
     *,
     weight,
-    input,
+    indices,
     offsets,
     scale_grad_by_freq,
     mode,
@@ -902,6 +902,27 @@ def embedding_bag_byte_rowwise_offsets(
 ):
     return torch.ops.quantized.embedding_bag_byte_rowwise_offsets(**locals())
 
+@register_acc_op_mapping(
+    op_and_target=(
+        "call_function",
+        torch.ops.quantized.embedding_bag_4bit_rowwise_offsets,
+    )
+)
+@register_acc_op
+def embedding_bag_4bit_rowwise_offsets(
+    *,
+    weight,
+    indices,
+    offsets,
+    scale_grad_by_freq,
+    mode,
+    pruned_weights,
+    per_sample_weights,
+    compressed_indices_mapping,
+    include_last_offset,
+):
+    return torch.ops.quantized.embedding_bag_4bit_rowwise_offsets(**locals())
+
 
 @register_acc_op_mapping(op_and_target=("call_function", torch.sin))
 @register_acc_op
index 6daa000..86fb128 100644 (file)
@@ -385,6 +385,17 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D
                     == stripped_name
                 ):
                     weight[stripped_name]["dtype"] = "acc.uint8fused"
+                # Same as above, but for the 4 bit version.
+                if (
+                    "acc_ops.embedding_bag_4bit_rowwise_offsets" in user_targets
+                    and str(
+                        user_targets[
+                            "acc_ops.embedding_bag_4bit_rowwise_offsets"
+                        ].kwargs["weight"]
+                    )
+                    == stripped_name
+                ):
+                    weight[stripped_name]["dtype"] = "acc.uint4fused"
 
                 serialized_dict["weights"].update(weight)
             else: