[pytorch][quant][oss] Support 2-bit embedding_bag op "embedding_bag_2bit_rowwise_offsets"
author    Shijun Kong <shijunk@fb.com>
          Thu, 26 Aug 2021 23:06:17 +0000 (16:06 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Thu, 26 Aug 2021 23:09:35 +0000 (16:09 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63658

Support 2-bit embedding_bag op "embedding_bag_2bit_rowwise_offsets"

Reviewed By: jingsh, supriyar

Differential Revision: D30454994

fbshipit-source-id: 7aa7bfe405c2ffff639d5658a35181036e162dc9

aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
aten/src/ATen/native/quantized/library.cpp
test/quantization/core/test_quantized_op.py
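
For orientation, a minimal usage sketch of the new op as exposed through the
schema added below. The shapes and values are illustrative, not part of the
commit; note the 2-bit test below constrains embedding_dim to multiples of 8:

    import torch

    # Pack fp32 weights into the 2-bit row-wise format, then run the new bag
    # lookup. Both ops are registered under the quantized namespace.
    weights = torch.rand(10, 16) + 1
    packed = torch.ops.quantized.embedding_bag_2bit_prepack(weights)

    indices = torch.tensor([0, 1, 2, 3], dtype=torch.int64)
    offsets = torch.tensor([0, 2, 4], dtype=torch.int64)  # two bags of two rows

    out = torch.ops.quantized.embedding_bag_2bit_rowwise_offsets(
        packed, indices, offsets, include_last_offset=True)
    print(out.shape)  # torch.Size([2, 16])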

diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
index 7adf05a..6aae3ba 100644
@@ -141,9 +141,10 @@ at::Tensor& embedding_lookup_fallback_impl(
 }
 
 template <typename IndexType, typename OffsetType>
-at::Tensor& embedding_bag_4bit_impl(
+at::Tensor& embedding_bag_nbit_impl(
     at::Tensor& output,
     const at::Tensor& weight,
+    const int bit_width,
     const at::Tensor& indices,
     const at::Tensor& offsets,
     bool pruned_weights,
@@ -174,8 +175,9 @@ at::Tensor& embedding_bag_4bit_impl(
 
   const auto weight_sizes = weight.sizes();
   const int64_t weight_size = weight_sizes[1];
+  int NUM_ELEM_PER_BYTE = 8 / bit_width;
   const int64_t D =
-      (weight_size - 4) * 2; // NB: 2-byte fp16 scale and 2-byte zero_offset
+      (weight_size - 2 * sizeof(at::Half)) * NUM_ELEM_PER_BYTE; // NB: 2-byte fp16 scale and 2-byte zero_offset
   const int64_t M = offsets.sizes()[0];
 
   int64_t output_size = M - 1;
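
Note: the hunk above recovers the unpacked dimension D from the packed row
width. Each byte now holds 8 / bit_width quantized elements, and every row
ends with a 2-byte fp16 scale plus a 2-byte fp16 zero_offset. A small sketch
of that arithmetic (assuming this trailing-scale row layout):

    # Assumed layout: data bytes first, then fp16 scale and fp16 zero_offset.
    def unpacked_dim(packed_row_bytes: int, bit_width: int) -> int:
        num_elem_per_byte = 8 // bit_width
        return (packed_row_bytes - 4) * num_elem_per_byte  # 4 == 2 * sizeof(fp16)

    assert unpacked_dim(12, 4) == 16  # 8 data bytes, 2 elements per byte
    assert unpacked_dim(12, 2) == 32  # 8 data bytes, 4 elements per byte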
@@ -211,7 +213,7 @@ at::Tensor& embedding_bag_4bit_impl(
   if (!pruned_weights || fallback_to_no_sparse) {
     // Generate the fbgemm kernel
     auto kernel = fbgemm::GenerateEmbeddingSpMDMNBit<IndexType, OffsetType>(
-        /*bit rate=*/4,
+        /*bit rate=*/bit_width,
         /*block size=*/block_size,
         /*has weights=*/per_sample_weights_.has_value(),
         /*normalize_by_lengths=*/false,
@@ -234,11 +236,13 @@ at::Tensor& embedding_bag_4bit_impl(
 
     TORCH_CHECK(
         success,
-        "FBGEMM GenerateEmbeddingSpMDMNBit kernel failed for 4-bit input");
+        "FBGEMM GenerateEmbeddingSpMDMNBit kernel failed for ",
+        bit_width,
+        "-bit input");
   } else {
     auto kernel =
         fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<IndexType, OffsetType>(
-            /*bit rate=*/4,
+            /*bit rate=*/bit_width,
             /*block_size=*/block_size,
             /*has weights=*/per_sample_weights_.has_value(),
             /*normalize_by_lengths=*/false,
@@ -260,11 +264,14 @@ at::Tensor& embedding_bag_4bit_impl(
         /*compressed_indices_table=*/compressed_indices_mapping_data);
     TORCH_CHECK(
         success,
-        "FBGEMM GenerateEmbeddingSpMDMNBitRowWiseSparse kernel failed for 4-bit input");
+        "FBGEMM GenerateEmbeddingSpMDMNBitRowWiseSparse kernel failed for ",
+        bit_width,
+        "-bit input");
   }
   return output;
 #else
-  return embedding_lookup_fallback_impl<IndexType, OffsetType, 4, 2>(
+  if (bit_width == 4) {
+    return embedding_lookup_fallback_impl<IndexType, OffsetType, 4, 2>(
       weight,
       indices,
       offsets,
@@ -275,6 +282,19 @@ at::Tensor& embedding_bag_4bit_impl(
       output_size,
       include_last_offset,
       (pruned_weights && !fallback_to_no_sparse));
+  }
+  // bit_width == 2
+  return embedding_lookup_fallback_impl<IndexType, OffsetType, 2, 4>(
+    weight,
+    indices,
+    offsets,
+    per_sample_weights_,
+    compressed_indices_mapping,
+    output,
+    D,
+    output_size,
+    include_last_offset,
+    (pruned_weights && !fallback_to_no_sparse));
 #endif
 }
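
Note: with FBGEMM unavailable, the op falls back to
embedding_lookup_fallback_impl, whose template parameters (bit width,
elements per byte) must be compile-time constants, hence the runtime branch
on bit_width. A rough NumPy rendition of what either path computes per bag;
the low-bits-first packing order and the scale * q + zero_offset
dequantization are assumptions about the row format, not taken from this diff:

    import numpy as np

    def embedding_bag_nbit_ref(packed, indices, offsets, bit_width):
        # packed: uint8 matrix; each row is data bytes + fp16 scale + fp16 zero_offset
        n = 8 // bit_width                 # elements per byte
        mask = (1 << bit_width) - 1
        D = (packed.shape[1] - 4) * n
        out = np.zeros((len(offsets) - 1, D), dtype=np.float32)
        for b in range(len(offsets) - 1):
            for idx in indices[offsets[b]:offsets[b + 1]]:
                row = packed[idx]
                scale, zero = np.frombuffer(row[-4:].tobytes(), dtype=np.float16)
                for j in range(D):
                    q = (row[j // n] >> (bit_width * (j % n))) & mask
                    out[b, j] += float(scale) * q + float(zero)
        return out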
 
@@ -519,9 +539,10 @@ at::Tensor& embedding_bag_byte_helper(
       is_embedding_op);
 }
 
-at::Tensor& embedding_bag_4bit_helper(
+at::Tensor& _embedding_bag_nbit_helper(
     at::Tensor& output,
     const at::Tensor& weight,
+    const int bit_width,
     const at::Tensor& indices,
     const c10::optional<at::Tensor>& offsets_in,
     bool pruned_weights,
@@ -530,6 +551,10 @@ at::Tensor& embedding_bag_4bit_helper(
     bool include_last_offset) {
   c10::MaybeOwned<at::Tensor> offsets;
   TORCH_CHECK(
+      bit_width == 4 || bit_width == 2,
+      "qembedding/qembedding_bag operator supports bit_width 2 or 4, got ",
+      bit_width);
+  TORCH_CHECK(
       indices.dim() == 1 || indices.dim() == 2,
       "qembedding/qembedding_bag operator supports 1 or 2d indices, got ",
       indices.dim());
@@ -539,14 +564,14 @@ at::Tensor& embedding_bag_4bit_helper(
   if (indices.dim() == 2) {
     TORCH_CHECK(
         !offsets_in.has_value(),
-        "embedding_bag_4bit operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences.");
+        "embedding_bag_4bit/embedding_bag_2bit operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences.");
 
     offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(
         0, indices.numel(), indices.sizes()[1], indices.scalar_type()));
   } else {
     TORCH_CHECK(
         offsets_in.has_value(),
-        "embedding_bag_4bit operator expects offsets to be set for 1D indices.");
+        "embedding_bag_4bit/embedding_bag_2bit operator expects offsets to be set for 1D indices.");
     offsets = c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());
   }
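
Note: for 2D indices the helper synthesizes offsets itself, treating each row
as one fixed-length bag; the at::arange call above is equivalent to:

    import torch

    indices = torch.randint(0, 10, (3, 4))                       # 3 bags of length 4
    offsets = torch.arange(0, indices.numel(), indices.size(1))  # tensor([0, 4, 8])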
 
@@ -568,9 +593,10 @@ at::Tensor& embedding_bag_4bit_helper(
   // Using helper function to support different type combination without the
   // need to cast, which can be additional performance overhead
   if (indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kInt) {
-    return embedding_bag_4bit_impl<int, int>(
+    return embedding_bag_nbit_impl<int, int>(
         output,
         weight,
+        bit_width,
         indices,
         *offsets,
         pruned_weights,
@@ -579,9 +605,10 @@ at::Tensor& embedding_bag_4bit_helper(
         include_last_offset);
   } else if (
       indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kLong) {
-    return embedding_bag_4bit_impl<int, int64_t>(
+    return embedding_bag_nbit_impl<int, int64_t>(
         output,
         weight,
+        bit_width,
         indices,
         *offsets,
         pruned_weights,
@@ -590,9 +617,10 @@ at::Tensor& embedding_bag_4bit_helper(
         include_last_offset);
   } else if (
       indices.scalar_type() == at::kLong && offsets->scalar_type() == at::kInt) {
-    return embedding_bag_4bit_impl<int64_t, int>(
+    return embedding_bag_nbit_impl<int64_t, int>(
         output,
         weight,
+        bit_width,
         indices,
         *offsets,
         pruned_weights,
@@ -600,9 +628,10 @@ at::Tensor& embedding_bag_4bit_helper(
         compressed_indices_mapping,
         include_last_offset);
   }
-  return embedding_bag_4bit_impl<int64_t, int64_t>(
+  return embedding_bag_nbit_impl<int64_t, int64_t>(
       output,
       weight,
+      bit_width,
       indices,
       *offsets,
       pruned_weights,
@@ -650,9 +679,10 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
   }
 
   auto output = at::empty({0}, packed_w.options().dtype(at::kFloat));
-  return embedding_bag_4bit_helper(
+  return _embedding_bag_nbit_helper(
     output,
     packed_w,
+    4,
     indices,
     offsets_in,
     pruned_weights,
@@ -709,9 +739,44 @@ Tensor& embedding_bag_4bit_rowwise_offsets_out(
         per_sample_weights_.value().scalar_type(),
         " instead")
   }
-  return embedding_bag_4bit_helper(
+  return _embedding_bag_nbit_helper(
+      output,
+      weight,
+      4,
+      indices,
+      offsets_in,
+      pruned_weights,
+      per_sample_weights_.has_value()
+          ? per_sample_weights_.value().to(at::kFloat)
+          : per_sample_weights_,
+      compressed_indices_mapping,
+      include_last_offset);
+}
+
+Tensor& embedding_bag_2bit_rowwise_offsets_out(
+    Tensor& output,
+    const Tensor& weight,
+    const Tensor& indices,
+    const c10::optional<Tensor>& offsets_in,
+    const bool /* scale_grad_by_freq */,
+    const int64_t /* mode */,
+    bool pruned_weights,
+    const c10::optional<Tensor>& per_sample_weights_,
+    const c10::optional<Tensor>& compressed_indices_mapping,
+    bool include_last_offset) {
+
+  if (per_sample_weights_.has_value()) {
+    TORCH_CHECK(
+        (per_sample_weights_.value().scalar_type() == at::kFloat ||
+         per_sample_weights_.value().scalar_type() == at::kHalf),
+        "Expect fp32 or fp16 weights, but found",
+        per_sample_weights_.value().scalar_type(),
+        " instead")
+  }
+  return _embedding_bag_nbit_helper(
       output,
       weight,
+      2,
       indices,
       offsets_in,
       pruned_weights,
@@ -784,6 +849,33 @@ Tensor embedding_bag_4bit_rowwise_offsets(
   return output;
 }
 
+Tensor embedding_bag_2bit_rowwise_offsets(
+    const Tensor& weight,
+    const Tensor& indices,
+    const c10::optional<Tensor>& offsets_in,
+    const bool /* scale_grad_by_freq */,
+    const int64_t /* mode */,
+    bool pruned_weights,
+    const c10::optional<Tensor>& per_sample_weights_,
+    const c10::optional<Tensor>& compressed_indices_mapping,
+    bool include_last_offset) {
+
+  auto output = create_empty_from(weight, at::kFloat);
+  embedding_bag_2bit_rowwise_offsets_out(
+    output,
+    weight,
+    indices,
+    offsets_in,
+    false, // unused scale_grad_by_freq
+    0, // unused mode
+    pruned_weights,
+    per_sample_weights_,
+    compressed_indices_mapping,
+    include_last_offset
+  );
+  return output;
+}
+
 template <int bit_rate>
 class QEmbeddingBag final {
  public:
@@ -869,6 +961,9 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_rowwise_offsets"),
       embedding_bag_4bit_rowwise_offsets);
+  m.impl(
+      TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_rowwise_offsets"),
+      embedding_bag_2bit_rowwise_offsets);
 }
 } // namespace
 } // namespace native
diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp
index 7cdb5cb..8ead74f 100644
@@ -128,6 +128,7 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool pruned_weights=False) -> Tensor"));
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 1821267..9243fe2 100644
@@ -3318,6 +3318,9 @@ class TestQuantizedEmbeddingOps(TestCase):
         if bit_rate == 4:
             pt_op = torch.ops.quantized.embedding_bag_4bit_rowwise_offsets
             pt_prepack_op = torch.ops.quantized.embedding_bag_4bit_prepack
+        elif bit_rate == 2:
+            pt_op = torch.ops.quantized.embedding_bag_2bit_rowwise_offsets
+            pt_prepack_op = torch.ops.quantized.embedding_bag_2bit_prepack
 
         weights = torch.from_numpy((np.random.random_sample((
             num_embeddings, embedding_dim)) + 1).astype(np.float32))
@@ -3483,6 +3486,33 @@ class TestQuantizedEmbeddingOps(TestCase):
                                                sparsity=sparsity,
                                                atol=0.1, rtol=1e-2)
 
+    """ Tests the correctness of the embedding_bag_2bit quantized operator """
+    @given(num_embeddings=st.integers(10, 100),
+           embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0),
+           num_offsets=st.integers(1, 20),
+           use_32bit_indices=st.booleans(),
+           use_32bit_offsets=st.booleans(),
+           enable_per_sample_weights=st.booleans(),
+           include_last_offset=st.booleans(),
+           fallback_to_no_sparse=st.booleans(),
+           sparsity=st.sampled_from([0.0, 0.5, 0.7]))
+    def test_embedding_bag_2bit(self, num_embeddings,
+                                embedding_dim, num_offsets,
+                                use_32bit_indices,
+                                use_32bit_offsets,
+                                enable_per_sample_weights,
+                                include_last_offset,
+                                fallback_to_no_sparse,
+                                sparsity):
+        self.embedding_bag_rowwise_offsets_run(2, num_embeddings,
+                                               embedding_dim, num_offsets,
+                                               use_32bit_indices, use_32bit_offsets,
+                                               enable_per_sample_weights,
+                                               include_last_offset,
+                                               fallback_to_no_sparse,
+                                               sparsity=sparsity,
+                                               atol=1.0, rtol=1e-1)
+
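
Note: the 2-bit test runs with much looser tolerances than the 4-bit one
(atol=1.0 vs 0.1) because four quantization levels leave a large per-element
rounding error that accumulates over each bag. A back-of-the-envelope check,
assuming row-wise affine quantization over each row's [min, max] range and
the test's weights drawn from [1, 2):

    bit_width = 2
    levels = 2 ** bit_width        # 4 representable values
    scale = 1.0 / (levels - 1)     # ~0.33 per step for a row range of ~1
    per_elem_err = scale / 2       # ~0.17 worst-case rounding error per element
    # Summing k rows per bag grows the worst case to ~k * 0.17, hence atol=1.0.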
     """ Tests the correctness of the quantized embedding lookup operator """
     @given(num_embeddings=st.integers(10, 100),
            embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0))