Extend 2Dim embedding bag benchmarking to include 3Dim benchmarks (#64647)
authorEddie Ren <edwardren@fb.com>
Fri, 10 Sep 2021 19:31:27 +0000 (12:31 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 23:49:02 +0000 (16:49 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64647

Add support for benchmarking 8-bit quantization of N-dim batched embeddings. Currently this only works for 3-dim embeddings; generalizing from 3-dim to N-dim still requires further design work.

Test Plan: ```buck run //caffe2/benchmarks/operator_benchmark/pt:qembedding_pack_test```

Reviewed By: jingsh

Differential Revision: D30770085

fbshipit-source-id: 26659020f3458991592065a05366bde0f060494e

benchmarks/operator_benchmark/pt/qembedding_pack_test.py

index f9a3aaf..d55c5fc 100644 (file)
@@ -15,6 +15,13 @@ embeddingbag_conversion_long_configs = op_bench.cross_product_configs(
     tags=('long',)
 )
 
+embeddingbag_conversion_three_dim_configs = op_bench.cross_product_configs(
+    num_embeddings=(80,),
+    embedding_dim=(128, 256, 512),
+    batch_size=(10,),
+    tags=('short',)
+)
+
 conversion_ops = op_bench.op_list(
     attrs=(
         ('qembeddingbag_byte_prepack', torch.ops.quantized.embedding_bag_byte_prepack),
@@ -44,6 +51,16 @@ class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase):
     def forward(self, weight):
         return self.op_func(weight)
 
+class EmbeddingBagThreeDimFloatToFusedBase(op_bench.TorchBenchmarkBase):
+    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
+        self.inputs = {
+            "weight": torch.rand(batch_size, num_embeddings, embedding_dim, dtype=torch.float) + 1
+        }
+        self.op_func = op_func
+
+    def forward(self, weight):
+        return self.op_func(weight)
+
 class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase):
     def init(self, num_embeddings, embedding_dim, op_func):
         weight = torch.randn(num_embeddings, embedding_dim + 8, dtype=torch.float)
@@ -55,6 +72,16 @@ class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase):
     def forward(self, packed_weight):
         return self.op_func(packed_weight)
 
+class EmbeddingBagThreeDimFusedToFloatBase(op_bench.TorchBenchmarkBase):
+    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
+        weight = torch.randn(batch_size, num_embeddings, embedding_dim + 8, dtype=torch.float)
+        self.inputs = {
+            "packed_weight": weight.to(torch.uint8)
+        }
+        self.op_func = op_func
+
+    def forward(self, packed_weight):
+        return self.op_func(packed_weight)
 
 op_bench.generate_pt_tests_from_op_list(conversion_ops,
                                         embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
@@ -62,6 +89,12 @@ op_bench.generate_pt_tests_from_op_list(conversion_ops,
 op_bench.generate_pt_tests_from_op_list(unpack_ops,
                                         embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
                                         EmbeddingBagFusedToFloatBase)
+op_bench.generate_pt_tests_from_op_list(conversion_ops,
+                                        embeddingbag_conversion_three_dim_configs,
+                                        EmbeddingBagThreeDimFloatToFusedBase)
+op_bench.generate_pt_tests_from_op_list(unpack_ops,
+                                        embeddingbag_conversion_three_dim_configs,
+                                        EmbeddingBagThreeDimFusedToFloatBase)
 
 if __name__ == "__main__":
     op_bench.benchmark_runner.main()