From b616132403697a043fd9691693e40f407e77063a Mon Sep 17 00:00:00 2001 From: Protonu Basu Date: Wed, 8 Sep 2021 07:11:38 -0700 Subject: [PATCH] Adding support for lowering 4Bit EmbeddingBag Operator (#5806) Summary: Pull Request resolved: https://github.com/pytorch/glow/pull/5806 Pull Request resolved: https://github.com/pytorch/pytorch/pull/64001 Add 4 bit embeddingbag operator in acc_ops. Test Plan: Let CI run. Reviewed By: jfix71 Differential Revision: D30532824 fbshipit-source-id: bf476c9710477792aae202dacf64e23539c33bd9 --- torch/fx/experimental/fx_acc/acc_ops.py | 23 ++++++++++++++++++++++- torch/fx/experimental/graph_manipulation.py | 11 +++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py index b10d35e..0f1c92a 100644 --- a/torch/fx/experimental/fx_acc/acc_ops.py +++ b/torch/fx/experimental/fx_acc/acc_ops.py @@ -891,7 +891,7 @@ def embedding_bag( def embedding_bag_byte_rowwise_offsets( *, weight, - input, + indices, offsets, scale_grad_by_freq, mode, @@ -902,6 +902,27 @@ def embedding_bag_byte_rowwise_offsets( ): return torch.ops.quantized.embedding_bag_byte_rowwise_offsets(**locals()) +@register_acc_op_mapping( + op_and_target=( + "call_function", + torch.ops.quantized.embedding_bag_4bit_rowwise_offsets, + ) +) +@register_acc_op +def embedding_bag_4bit_rowwise_offsets( + *, + weight, + indices, + offsets, + scale_grad_by_freq, + mode, + pruned_weights, + per_sample_weights, + compressed_indices_mapping, + include_last_offset, +): + return torch.ops.quantized.embedding_bag_4bit_rowwise_offsets(**locals()) + @register_acc_op_mapping(op_and_target=("call_function", torch.sin)) @register_acc_op diff --git a/torch/fx/experimental/graph_manipulation.py b/torch/fx/experimental/graph_manipulation.py index 6daa000..86fb128 100644 --- a/torch/fx/experimental/graph_manipulation.py +++ b/torch/fx/experimental/graph_manipulation.py @@ -385,6 +385,17 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D == stripped_name ): weight[stripped_name]["dtype"] = "acc.uint8fused" + # Same as above, but for the 4 bit version. + if ( + "acc_ops.embedding_bag_4bit_rowwise_offsets" in user_targets + and str( + user_targets[ + "acc_ops.embedding_bag_4bit_rowwise_offsets" + ].kwargs["weight"] + ) + == stripped_name + ): + weight[stripped_name]["dtype"] = "acc.uint4fused" serialized_dict["weights"].update(weight) else: -- 2.7.4