[pytorch] add per_sample_weights support for embedding_bag_4bit_rowwise_offsets ...
author     Paul Johnson <johnsonpaul@fb.com>
Fri, 27 Aug 2021 00:28:35 +0000 (17:28 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 00:31:45 +0000 (17:31 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63605

Reviewed By: houseroad

Differential Revision: D30434664

fbshipit-source-id: eb4cbae3c705f9dec5c073a56f0f23daee353bc1

aten/src/ATen/native/quantized/cuda/embedding_bag.cu

index 6d44ce0..55b0b0d 100644
@@ -56,15 +56,15 @@ dequantize_intx(uint32_t packedVals, float2 scale_bias, uint8_t offset_bits) {
 
 template <uint8_t bits_per_dim>
 __forceinline__ __device__ void
-accumulate_packed_intx(float4* acc, uint32_t packedVals, float2 scale_bias) {
+accumulate_packed_intx(float4* acc, uint32_t packedVals, float2 scale_bias, float sample_weight) {
   constexpr uint8_t dims_per_byte = 8 / bits_per_dim;
   for (uint8_t i = 0; i < dims_per_byte; i++) {
     float4 res = dequantize_intx<bits_per_dim>(packedVals, scale_bias, 4 * bits_per_dim * i /* offset_bits */);
     // Accumulate in float32.
-    acc[i].x += res.x;
-    acc[i].y += res.y;
-    acc[i].z += res.z;
-    acc[i].w += res.w;
+    acc[i].x += (res.x * sample_weight);
+    acc[i].y += (res.y * sample_weight);
+    acc[i].z += (res.z * sample_weight);
+    acc[i].w += (res.w * sample_weight);
   }
 }
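
For reference, a minimal host-side C++ sketch of what the modified accumulate_packed_intx now computes in the 4-bit case: each packed dim is dequantized and scaled by the bag entry's per-sample weight before being accumulated in float32. The low-nibble-first ordering and the affine dequantization (q * scale + bias) are assumptions read off this diff, not copied from dequantize_intx.

    #include <cstdint>

    // Hedged reference (plain C++, not the CUDA kernel): accumulate one packed
    // uint32_t worth of 4-bit dims, scaled by the entry's per-sample weight.
    // Assumes low-nibble-first packing and dequantization as q * scale + bias.
    void accumulate_packed_int4_ref(
        float acc[8],          // float32 accumulators for the 8 packed dims
        uint32_t packed_vals,  // 8 quantized 4-bit values
        float scale,
        float bias,
        float sample_weight) { // 1.0f when no per_sample_weights are supplied
      for (int i = 0; i < 8; ++i) {
        const uint32_t q = (packed_vals >> (4 * i)) & 0xF;
        acc[i] += (static_cast<float>(q) * scale + bias) * sample_weight;
      }
    }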
 
@@ -77,7 +77,7 @@ __global__ void embedding_bag_nbits_rowwise_offsets_kernel(
     const PackedTensorAccessor32<index_t, 1, RestrictPtrTraits> indices,
     const PackedTensorAccessor32<index_t, 1, RestrictPtrTraits> offsets,
     const bool /* pruned_weights */,
-    const c10::optional<Tensor>& per_sample_weights_,
+    const PackedTensorAccessor32<float, 1, RestrictPtrTraits> per_sample_weights_,
     const c10::optional<Tensor>& compressed_indices_mapping,
     const bool include_last_offset,
     PackedTensorAccessor32<float, 2, RestrictPtrTraits> output) {
@@ -96,6 +96,8 @@ __global__ void embedding_bag_nbits_rowwise_offsets_kernel(
 
   const int32_t D_bytes = weight.size(1);
 
+  bool use_per_sample = per_sample_weights_.size(0) > 0;
+
   int64_t indices_start = offsets[t * B + b];
   int64_t indices_end;
   if (include_last_offset) {
@@ -124,6 +126,7 @@ __global__ void embedding_bag_nbits_rowwise_offsets_kernel(
     }
     for (int32_t l = indices_start; l < indices_end; ++l) {
       int64_t idx = indices[l];
+      float sample_weight = use_per_sample ? per_sample_weights_[l] : 1.0f;
       const uint8_t* __restrict__ row = &weights[idx * D_bytes];
       float2 scale_bias;
       if (fp32_scale_bias) {
@@ -138,7 +141,7 @@ __global__ void embedding_bag_nbits_rowwise_offsets_kernel(
 
       uint32_t v0 = reinterpret_cast<const uint32_t*>(&row[byte_offset])[0];
 
-      accumulate_packed_intx<bits_per_dim>(accumulator, v0, scale_bias);
+      accumulate_packed_intx<bits_per_dim>(accumulator, v0, scale_bias, sample_weight);
     }
 
 
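
Per bag, the kernel now computes a weighted sum over the bag's rows, with the weight falling back to 1.0f when the caller passed no per-sample weights (the use_per_sample check above keys off an empty tensor). A rough CPU reference of that semantics, written against already-dequantized float rows so the quantization details from the previous sketch are factored out:

    #include <cstdint>
    #include <vector>

    // Hedged CPU reference of one bag's output after this change. 'rows' holds
    // already-dequantized embedding rows; an empty per_sample_weights vector
    // reproduces the old unweighted behavior (weight == 1.0f for every entry).
    std::vector<float> bag_output_ref(
        const std::vector<std::vector<float>>& rows,
        const std::vector<int64_t>& indices,
        int64_t indices_start,
        int64_t indices_end,
        const std::vector<float>& per_sample_weights) {
      const bool use_per_sample = !per_sample_weights.empty();
      std::vector<float> out(rows.empty() ? 0 : rows[0].size(), 0.0f);
      for (int64_t l = indices_start; l < indices_end; ++l) {
        const float w = use_per_sample ? per_sample_weights[l] : 1.0f;
        const std::vector<float>& row = rows[indices[l]];
        for (size_t d = 0; d < row.size(); ++d) {
          out[d] += w * row[d];
        }
      }
      return out;
    }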
@@ -204,9 +207,11 @@ at::Tensor& embedding_bag_byte_impl(
   const int D = weight_sizes[1] - 8; // NB: -8 to account for scale and bias
   const int64_t M = offsets.sizes()[0];
   TORCH_CHECK(D % 4 == 0);
-  TORCH_CHECK(
-      !per_sample_weights_.has_value(),
-      "Per sample weights not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");
+  if(per_sample_weights_.has_value()) {
+      TORCH_CHECK(per_sample_weights_.value().scalar_type() == at::kFloat,
+              "Per sample weights expected scalar type ", at::kFloat, " but got ",
+              per_sample_weights_.value().scalar_type());
+  }
   TORCH_CHECK(
       !compressed_indices_mapping.has_value(),
       "Compressed indices mapping not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");
@@ -215,6 +220,13 @@ at::Tensor& embedding_bag_byte_impl(
 
   int64_t output_size = include_last_offset ? M - 1 : M;
 
+  at::Tensor sample_weights;
+  if (per_sample_weights_.has_value()) {
+      sample_weights = per_sample_weights_.value();
+  } else {
+      sample_weights = create_empty_from(output, kFloat);
+  }
+
   const std::vector<int64_t> shape = {output_size, D};
   at::native::resize_(output, shape, c10::nullopt);
   AT_DISPATCH_INDEX_TYPES(
@@ -228,7 +240,7 @@ at::Tensor& embedding_bag_byte_impl(
             indices.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
             offsets.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
             false /* pruned_weights */,
-            per_sample_weights_,
+            sample_weights.packed_accessor32<float, 1, RestrictPtrTraits>(),
             compressed_indices_mapping,
             include_last_offset,
             output.packed_accessor32<float, 2, RestrictPtrTraits>());
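
On the host side, the optional can no longer be forwarded directly because the kernel now takes a PackedTensorAccessor32, so a zero-length float tensor stands in for "no weights" and the kernel's size(0) > 0 check picks that up. create_empty_from is a helper local to this file; assuming it yields an empty tensor on the output's device, a rough equivalent of the selection using public libtorch APIs would be:

    #include <torch/torch.h>

    // Hedged sketch of the host-side selection: forward the user's float weights
    // when present, otherwise hand the kernel a zero-length float tensor on the
    // same device as the output so that size(0) > 0 is false on the device side.
    torch::Tensor select_sample_weights(
        const c10::optional<torch::Tensor>& per_sample_weights,
        const torch::Tensor& output) {
      if (per_sample_weights.has_value()) {
        TORCH_CHECK(
            per_sample_weights->scalar_type() == torch::kFloat,
            "Per sample weights expected scalar type Float but got ",
            per_sample_weights->scalar_type());
        return *per_sample_weights;
      }
      return torch::empty({0}, output.options().dtype(torch::kFloat));
    }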
@@ -377,9 +389,11 @@ at::Tensor& embedding_bag_4bit_impl(
   const int D = 2*(weight_sizes[1] - 4); // NB: -4 to account for scale and bias @fp16
   const int64_t M = offsets.sizes()[0];
   TORCH_CHECK(D % 8 == 0);
-  TORCH_CHECK(
-      !per_sample_weights_.has_value(),
-      "Per sample weights not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");
+  if(per_sample_weights_.has_value()) {
+      TORCH_CHECK(per_sample_weights_.value().scalar_type() == at::kFloat,
+              "Per sample weights expected scalar type ", at::kFloat, " but got ",
+              per_sample_weights_.value().scalar_type());
+  }
   TORCH_CHECK(
       !compressed_indices_mapping.has_value(),
       "Compressed indices mapping not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");
@@ -388,6 +402,13 @@ at::Tensor& embedding_bag_4bit_impl(
 
   int64_t output_size = include_last_offset ? M - 1 : M;
 
+  at::Tensor sample_weights;
+  if (per_sample_weights_.has_value()) {
+      sample_weights = per_sample_weights_.value();
+  } else {
+      sample_weights = create_empty_from(output, kFloat);
+  }
+
   const std::vector<int64_t> shape = {output_size, D};
   at::native::resize_(output, shape, c10::nullopt);
   AT_DISPATCH_INDEX_TYPES(
@@ -401,7 +422,7 @@ at::Tensor& embedding_bag_4bit_impl(
             indices.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
             offsets.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
             false /* pruned_weights */,
-            per_sample_weights_,
+            sample_weights.packed_accessor32<float, 1, RestrictPtrTraits>(),
             compressed_indices_mapping,
             include_last_offset,
             output.packed_accessor32<float, 2, RestrictPtrTraits>());
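
The 4-bit path mirrors the byte path; the differences are the row geometry and the fp16 scale/bias. A small sketch of the layout arithmetic implied by the checks above (the last 4 bytes of each row hold the fp16 scale and bias, every other byte packs two dims), which is an interpretation of this diff rather than documented layout:

    #include <cstdint>
    #include <stdexcept>

    // Hedged sketch of the 4-bit rowwise layout implied by embedding_bag_4bit_impl:
    // a row is D_bytes wide, its trailing 4 bytes store scale and bias as fp16,
    // and each remaining byte packs two 4-bit dims, so D = 2 * (D_bytes - 4).
    // The D % 8 == 0 requirement matches the kernel loading one uint32_t
    // (8 dims) per step.
    int64_t embedding_dim_from_4bit_row_bytes(int64_t D_bytes) {
      const int64_t D = 2 * (D_bytes - 4);
      if (D % 8 != 0) {
        throw std::runtime_error("expected the packed dim count to be a multiple of 8");
      }
      return D;
    }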