From: Paul Johnson
Date: Fri, 27 Aug 2021 00:28:35 +0000 (-0700)
Subject: [pytorch] add per_sample_weights support for embedding_bag_4bit_rowwise_offsets ...
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~669
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=0c9dce90ed6a12d81b0e769b76d6b0c282326823;p=platform%2Fupstream%2Fpytorch.git

[pytorch] add per_sample_weights support for embedding_bag_4bit_rowwise_offsets (#63605)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63605

Reviewed By: houseroad

Differential Revision: D30434664

fbshipit-source-id: eb4cbae3c705f9dec5c073a56f0f23daee353bc1
---

diff --git a/aten/src/ATen/native/quantized/cuda/embedding_bag.cu b/aten/src/ATen/native/quantized/cuda/embedding_bag.cu
index 6d44ce0..55b0b0d 100644
--- a/aten/src/ATen/native/quantized/cuda/embedding_bag.cu
+++ b/aten/src/ATen/native/quantized/cuda/embedding_bag.cu
@@ -56,15 +56,15 @@ dequantize_intx(uint32_t packedVals, float2 scale_bias, uint8_t offset_bits) {
 
 template
 __forceinline__ __device__ void
-accumulate_packed_intx(float4* acc, uint32_t packedVals, float2 scale_bias) {
+accumulate_packed_intx(float4* acc, uint32_t packedVals, float2 scale_bias, float sample_weight) {
   constexpr uint8_t dims_per_byte = 8 / bits_per_dim;
   for (uint8_t i = 0; i < dims_per_byte; i++) {
     float4 res = dequantize_intx(packedVals, scale_bias, 4 * bits_per_dim * i /* offset_bits */);
     // Accumulate in float32.
-    acc[i].x += res.x;
-    acc[i].y += res.y;
-    acc[i].z += res.z;
-    acc[i].w += res.w;
+    acc[i].x += (res.x * sample_weight);
+    acc[i].y += (res.y * sample_weight);
+    acc[i].z += (res.z * sample_weight);
+    acc[i].w += (res.w * sample_weight);
   }
 }
 
@@ -77,7 +77,7 @@ __global__ void embedding_bag_nbits_rowwise_offsets_kernel(
     const PackedTensorAccessor32 indices,
     const PackedTensorAccessor32 offsets,
     const bool /* pruned_weights */,
-    const c10::optional& per_sample_weights_,
+    const PackedTensorAccessor32 per_sample_weights_,
     const c10::optional& compressed_indices_mapping,
     const bool include_last_offset,
     PackedTensorAccessor32 output) {
@@ -96,6 +96,8 @@ __global__ void embedding_bag_nbits_rowwise_offsets_kernel(
 
   const int32_t D_bytes = weight.size(1);
 
+  bool use_per_sample = per_sample_weights_.size(0) > 0;
+
   int64_t indices_start = offsets[t * B + b];
   int64_t indices_end;
   if (include_last_offset) {
@@ -124,6 +126,7 @@ __global__ void embedding_bag_nbits_rowwise_offsets_kernel(
   }
   for (int32_t l = indices_start; l < indices_end; ++l) {
     int64_t idx = indices[l];
+    float sample_weight = use_per_sample ? per_sample_weights_[l] : 1.0f;
     const uint8_t* __restrict__ row = &weights[idx * D_bytes];
     float2 scale_bias;
     if (fp32_scale_bias) {
@@ -138,7 +141,7 @@ __global__ void embedding_bag_nbits_rowwise_offsets_kernel(
 
       uint32_t v0 = reinterpret_cast(&row[byte_offset])[0];
 
-      accumulate_packed_intx(accumulator, v0, scale_bias);
+      accumulate_packed_intx(accumulator, v0, scale_bias, sample_weight);
     }
 
@@ -204,9 +207,11 @@ at::Tensor& embedding_bag_byte_impl(
   const int D = weight_sizes[1] - 8; // NB: -8 to account for scale and bias
   const int64_t M = offsets.sizes()[0];
   TORCH_CHECK(D % 4 == 0);
-  TORCH_CHECK(
-      !per_sample_weights_.has_value(),
-      "Per sample weights not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");
+  if(per_sample_weights_.has_value()) {
+    TORCH_CHECK(per_sample_weights_.value().scalar_type() == at::kFloat,
+      "Per sample weights expected scalar type ", at::kFloat, " but got ",
+      per_sample_weights_.value().scalar_type());
+  }
   TORCH_CHECK(
       !compressed_indices_mapping.has_value(),
       "Compressed indices mapping not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");
@@ -215,6 +220,13 @@ at::Tensor& embedding_bag_byte_impl(
 
   int64_t output_size = include_last_offset ? M - 1 : M;
 
+  at::Tensor sample_weights;
+  if (per_sample_weights_.has_value()) {
+    sample_weights = per_sample_weights_.value();
+  } else {
+    sample_weights = create_empty_from(output, kFloat);
+  }
+
   const std::vector shape = {output_size, D};
   at::native::resize_(output, shape, c10::nullopt);
   AT_DISPATCH_INDEX_TYPES(
@@ -228,7 +240,7 @@ at::Tensor& embedding_bag_byte_impl(
             indices.packed_accessor32(),
             offsets.packed_accessor32(),
             false /* pruned_weights */,
-            per_sample_weights_,
+            sample_weights.packed_accessor32(),
             compressed_indices_mapping,
             include_last_offset,
             output.packed_accessor32());
@@ -377,9 +389,11 @@ at::Tensor& embedding_bag_4bit_impl(
   const int D = 2*(weight_sizes[1] - 4); // NB: -4 to account for scale and bias @fp16
   const int64_t M = offsets.sizes()[0];
   TORCH_CHECK(D % 8 == 0);
-  TORCH_CHECK(
-      !per_sample_weights_.has_value(),
-      "Per sample weights not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");
+  if(per_sample_weights_.has_value()) {
+    TORCH_CHECK(per_sample_weights_.value().scalar_type() == at::kFloat,
+      "Per sample weights expected scalar type ", at::kFloat, " but got ",
+      per_sample_weights_.value().scalar_type());
+  }
   TORCH_CHECK(
       !compressed_indices_mapping.has_value(),
       "Compressed indices mapping not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");
@@ -388,6 +402,13 @@ at::Tensor& embedding_bag_4bit_impl(
 
   int64_t output_size = include_last_offset ? M - 1 : M;
 
+  at::Tensor sample_weights;
+  if (per_sample_weights_.has_value()) {
+    sample_weights = per_sample_weights_.value();
+  } else {
+    sample_weights = create_empty_from(output, kFloat);
+  }
+
   const std::vector shape = {output_size, D};
   at::native::resize_(output, shape, c10::nullopt);
   AT_DISPATCH_INDEX_TYPES(
@@ -401,7 +422,7 @@ at::Tensor& embedding_bag_4bit_impl(
             indices.packed_accessor32(),
             offsets.packed_accessor32(),
             false /* pruned_weights */,
-            per_sample_weights_,
+            sample_weights.packed_accessor32(),
             compressed_indices_mapping,
             include_last_offset,
             output.packed_accessor32());
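
Usage sketch (not part of the patch above; a minimal example assuming the existing op schema
quantized::embedding_bag_4bit_rowwise_offsets(weight, indices, offsets, scale_grad_by_freq,
mode, pruned_weights, per_sample_weights, compressed_indices_mapping, include_last_offset)
and the quantized::embedding_bag_4bit_prepack packing op; shapes and values are illustrative only):

    import torch

    # Pack a float table into the 4-bit rowwise format (per-row fp16 scale/bias).
    weight = torch.randn(10, 16)  # embedding dim must satisfy D % 8 == 0 on the CUDA path
    packed = torch.ops.quantized.embedding_bag_4bit_prepack(weight)

    indices = torch.tensor([0, 1, 2, 5, 9], dtype=torch.int64)
    offsets = torch.tensor([0, 2, 5], dtype=torch.int64)  # two bags; last offset included
    # One float32 weight per looked-up index; each row is scaled before the bag-wise sum.
    per_sample_weights = torch.tensor([0.5, 1.0, 2.0, 0.25, 1.5])

    out = torch.ops.quantized.embedding_bag_4bit_rowwise_offsets(
        packed.cuda(),
        indices.cuda(),
        offsets.cuda(),
        False,                      # scale_grad_by_freq (unused)
        0,                          # mode: sum
        False,                      # pruned_weights
        per_sample_weights.cuda(),  # previously had to be None on CUDA
        None,                       # compressed_indices_mapping
        True,                       # include_last_offset
    )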