Summary:
Pull Request resolved: https://github.com/pytorch/glow/pull/5806
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64001
Add 4 bit embeddingbag operator in acc_ops.
Test Plan: Let CI run.
Reviewed By: jfix71
Differential Revision:
D30532824
fbshipit-source-id:
bf476c9710477792aae202dacf64e23539c33bd9
def embedding_bag_byte_rowwise_offsets(
*,
weight,
- input,
+ indices,
offsets,
scale_grad_by_freq,
mode,
):
return torch.ops.quantized.embedding_bag_byte_rowwise_offsets(**locals())
+@register_acc_op_mapping(
+ op_and_target=(
+ "call_function",
+ torch.ops.quantized.embedding_bag_4bit_rowwise_offsets,
+ )
+)
+@register_acc_op
+def embedding_bag_4bit_rowwise_offsets(
+ *,
+ weight,
+ indices,
+ offsets,
+ scale_grad_by_freq,
+ mode,
+ pruned_weights,
+ per_sample_weights,
+ compressed_indices_mapping,
+ include_last_offset,
+):
+ return torch.ops.quantized.embedding_bag_4bit_rowwise_offsets(**locals())
+
@register_acc_op_mapping(op_and_target=("call_function", torch.sin))
@register_acc_op
== stripped_name
):
weight[stripped_name]["dtype"] = "acc.uint8fused"
+ # Same as above, but for the 4 bit version.
+ if (
+ "acc_ops.embedding_bag_4bit_rowwise_offsets" in user_targets
+ and str(
+ user_targets[
+ "acc_ops.embedding_bag_4bit_rowwise_offsets"
+ ].kwargs["weight"]
+ )
+ == stripped_name
+ ):
+ weight[stripped_name]["dtype"] = "acc.uint4fused"
serialized_dict["weights"].update(weight)
else: