From: Shijun Kong
Date: Thu, 26 Aug 2021 23:06:17 +0000 (-0700)
Subject: [pytorch][quant][oss] Support 2-bit embedding_bag op "embedding_bag_2bit_rowwise_offs...
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~673
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=dfa35ab3e710848353aa1d313c5d9127ed2ef745;p=platform%2Fupstream%2Fpytorch.git

[pytorch][quant][oss] Support 2-bit embedding_bag op "embedding_bag_2bit_rowwise_offsets" (#63658)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63658

Support 2-bit embedding_bag op "embedding_bag_2bit_rowwise_offsets"

Reviewed By: jingsh, supriyar

Differential Revision: D30454994

fbshipit-source-id: 7aa7bfe405c2ffff639d5658a35181036e162dc9
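For illustration, a minimal sketch of driving the new op from Python,
mirroring the unit test added below. The op and prepack names come from
this diff; the tensor values are made up, and embedding_dim is kept a
multiple of 8 so each row packs evenly into 2-bit values:

    import torch

    # 10 embeddings of dimension 16 (a multiple of 8).
    weights = torch.randn(10, 16, dtype=torch.float32)

    # Row-wise 2-bit quantization. Each packed row carries a 2-byte fp16
    # scale and a 2-byte zero point, so the kernel recovers the embedding
    # width as D = (packed_row_bytes - 4) * 4, i.e. 4 elements per byte.
    packed = torch.ops.quantized.embedding_bag_2bit_prepack(weights)

    indices = torch.tensor([0, 1, 2, 5], dtype=torch.int32)
    offsets = torch.tensor([0, 2, 4], dtype=torch.int32)

    # Two bags: indices[0:2] and indices[2:4].
    out = torch.ops.quantized.embedding_bag_2bit_rowwise_offsets(
        packed, indices, offsets,
        pruned_weights=False,
        per_sample_weights=None,
        compressed_indices_mapping=None,
        include_last_offset=True)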
---

diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
index 7adf05a..6aae3ba 100644
--- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
@@ -141,9 +141,10 @@ at::Tensor& embedding_lookup_fallback_impl(
 }
 
 template <typename IndexType, typename OffsetType>
-at::Tensor& embedding_bag_4bit_impl(
+at::Tensor& embedding_bag_nbit_impl(
     at::Tensor& output,
     const at::Tensor& weight,
+    const int bit_width,
     const at::Tensor& indices,
     const at::Tensor& offsets,
     bool pruned_weights,
@@ -174,8 +175,9 @@ at::Tensor& embedding_bag_4bit_impl(
 
   const auto weight_sizes = weight.sizes();
   const int64_t weight_size = weight_sizes[1];
+  int NUM_ELEM_PER_BYTE = 8 / bit_width;
   const int64_t D =
-      (weight_size - 4) * 2; // NB: 2-byte fp16 scale and 2-byte zero_offset
+      (weight_size - 2 * sizeof(at::Half)) * NUM_ELEM_PER_BYTE; // NB: 2-byte fp16 scale and 2-byte zero_offset
   const int64_t M = offsets.sizes()[0];
 
   int64_t output_size = M - 1;
@@ -211,7 +213,7 @@ at::Tensor& embedding_bag_4bit_impl(
   if (!pruned_weights || fallback_to_no_sparse) {
     // Generate the fbgemm kernel
     auto kernel = fbgemm::GenerateEmbeddingSpMDMNBit<IndexType, OffsetType>(
-        /*bit rate=*/4,
+        /*bit rate=*/bit_width,
         /*block size=*/block_size,
         /*has weights=*/per_sample_weights_.has_value(),
         /*normalize_by_lengths=*/false,
@@ -234,11 +236,13 @@
 
     TORCH_CHECK(
         success,
-        "FBGEMM GenerateEmbeddingSpMDMNBit kernel failed for 4-bit input");
+        "FBGEMM GenerateEmbeddingSpMDMNBit kernel failed for ",
+        bit_width,
+        "-bit input");
   } else {
     auto kernel = fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<IndexType, OffsetType>(
-        /*bit rate=*/4,
+        /*bit rate=*/bit_width,
         /*block_size=*/block_size,
         /*has weights=*/per_sample_weights_.has_value(),
         /*normalize_by_lengths=*/false,
@@ -260,11 +264,14 @@
         /*compressed_indices_table=*/compressed_indices_mapping_data);
     TORCH_CHECK(
         success,
-        "FBGEMM GenerateEmbeddingSpMDMNBitRowWiseSparse kernel failed for 4-bit input");
+        "FBGEMM GenerateEmbeddingSpMDMNBitRowWiseSparse kernel failed for ",
+        bit_width,
+        "-bit input");
   }
   return output;
 #else
-  return embedding_lookup_fallback_impl<IndexType, OffsetType, 4, 2>(
+  if (bit_width == 4) {
+    return embedding_lookup_fallback_impl<IndexType, OffsetType, 4, 2>(
       weight,
       indices,
       offsets,
       per_sample_weights_,
       compressed_indices_mapping,
       output,
       D,
       output_size,
       include_last_offset,
@@ -275,6 +282,19 @@
       (pruned_weights && !fallback_to_no_sparse));
+  }
+  // bit_width == 2
+  return embedding_lookup_fallback_impl<IndexType, OffsetType, 2, 4>(
+      weight,
+      indices,
+      offsets,
+      per_sample_weights_,
+      compressed_indices_mapping,
+      output,
+      D,
+      output_size,
+      include_last_offset,
+      (pruned_weights && !fallback_to_no_sparse));
 #endif
 }
 
@@ -519,9 +539,10 @@ at::Tensor& embedding_bag_byte_helper(
       is_embedding_op);
 }
 
-at::Tensor& embedding_bag_4bit_helper(
+at::Tensor& _embedding_bag_nbit_helper(
     at::Tensor& output,
     const at::Tensor& weight,
+    const int bit_width,
     const at::Tensor& indices,
     const c10::optional<at::Tensor>& offsets_in,
     bool pruned_weights,
@@ -530,6 +551,10 @@ at::Tensor& embedding_bag_4bit_helper(
     bool include_last_offset) {
   c10::MaybeOwned<at::Tensor> offsets;
   TORCH_CHECK(
+      bit_width == 4 || bit_width == 2,
+      "qembedding/qembedding_bag operator supports bit_width 2 or 4, got ",
+      bit_width);
+  TORCH_CHECK(
       indices.dim() == 1 || indices.dim() == 2,
       "qembedding/qembedding_bag operator supports 1 or 2d indices, got ",
       indices.dim());
@@ -539,14 +564,14 @@ at::Tensor& embedding_bag_4bit_helper(
   if (indices.dim() == 2) {
     TORCH_CHECK(
         !offsets_in.has_value(),
-        "embedding_bag_4bit operator: input is 2D, then offsets has to be None, as input is treated as a mini-batch of fixed length sequences.");
+        "embedding_bag_4bit/embedding_bag_2bit operator: input is 2D, then offsets has to be None, as input is treated as a mini-batch of fixed length sequences.");
     offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(
         0, indices.numel(), indices.sizes()[1], indices.scalar_type()));
   } else {
     TORCH_CHECK(
         offsets_in.has_value(),
-        "embedding_bag_4bit operator expects offsets to be set for 1D indices.");
+        "embedding_bag_4bit/embedding_bag_2bit operator expects offsets to be set for 1D indices.");
     offsets = c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());
   }
 
@@ -568,9 +593,10 @@ at::Tensor& embedding_bag_4bit_helper(
   // Using helper function to support different type combination without the
   // need to cast, which can be additional performance overhead
   if (indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kInt) {
-    return embedding_bag_4bit_impl<int, int>(
+    return embedding_bag_nbit_impl<int, int>(
         output,
         weight,
+        bit_width,
         indices,
         *offsets,
         pruned_weights,
@@ -579,9 +605,10 @@ at::Tensor& embedding_bag_4bit_helper(
         include_last_offset);
   } else if (
       indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kLong) {
-    return embedding_bag_4bit_impl<int, int64_t>(
+    return embedding_bag_nbit_impl<int, int64_t>(
         output,
         weight,
+        bit_width,
         indices,
         *offsets,
         pruned_weights,
@@ -590,9 +617,10 @@ at::Tensor& embedding_bag_4bit_helper(
         include_last_offset);
   } else if (
       indices.scalar_type() == at::kLong && offsets->scalar_type() == at::kInt) {
-    return embedding_bag_4bit_impl<int64_t, int>(
+    return embedding_bag_nbit_impl<int64_t, int>(
         output,
         weight,
+        bit_width,
         indices,
         *offsets,
         pruned_weights,
@@ -600,9 +628,10 @@ at::Tensor& embedding_bag_4bit_helper(
       compressed_indices_mapping,
       include_last_offset);
   }
-  return embedding_bag_4bit_impl<int64_t, int64_t>(
+  return embedding_bag_nbit_impl<int64_t, int64_t>(
      output,
      weight,
+      bit_width,
      indices,
      *offsets,
      pruned_weights,
@@ -650,9 +679,10 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
   }
 
   auto output = at::empty({0}, packed_w.options().dtype(at::kFloat));
-  return embedding_bag_4bit_helper(
+  return _embedding_bag_nbit_helper(
       output,
       packed_w,
+      4,
       indices,
       offsets_in,
       pruned_weights,
@@ -709,9 +739,44 @@ Tensor& embedding_bag_4bit_rowwise_offsets_out(
         per_sample_weights_.value().scalar_type(),
         " instead")
   }
-  return embedding_bag_4bit_helper(
+  return _embedding_bag_nbit_helper(
+      output,
+      weight,
+      4,
+      indices,
+      offsets_in,
+      pruned_weights,
+      per_sample_weights_.has_value()
+          ? per_sample_weights_.value().to(at::kFloat)
+          : per_sample_weights_,
+      compressed_indices_mapping,
+      include_last_offset);
+}
+
+Tensor& embedding_bag_2bit_rowwise_offsets_out(
+    Tensor& output,
+    const Tensor& weight,
+    const Tensor& indices,
+    const c10::optional<Tensor>& offsets_in,
+    const bool /* scale_grad_by_freq */,
+    const int64_t /* mode */,
+    bool pruned_weights,
+    const c10::optional<Tensor>& per_sample_weights_,
+    const c10::optional<Tensor>& compressed_indices_mapping,
+    bool include_last_offset) {
+
+  if (per_sample_weights_.has_value()) {
+    TORCH_CHECK(
+        (per_sample_weights_.value().scalar_type() == at::kFloat ||
+         per_sample_weights_.value().scalar_type() == at::kHalf),
+        "Expect fp32 or fp16 weights, but found ",
+        per_sample_weights_.value().scalar_type(),
+        " instead")
+  }
+  return _embedding_bag_nbit_helper(
       output,
       weight,
+      2,
       indices,
       offsets_in,
       pruned_weights,
@@ -784,6 +849,33 @@ Tensor embedding_bag_4bit_rowwise_offsets(
   return output;
 }
 
+Tensor embedding_bag_2bit_rowwise_offsets(
+    const Tensor& weight,
+    const Tensor& indices,
+    const c10::optional<Tensor>& offsets_in,
+    const bool /* scale_grad_by_freq */,
+    const int64_t /* mode */,
+    bool pruned_weights,
+    const c10::optional<Tensor>& per_sample_weights_,
+    const c10::optional<Tensor>& compressed_indices_mapping,
+    bool include_last_offset) {
+
+  auto output = create_empty_from(weight, at::kFloat);
+  embedding_bag_2bit_rowwise_offsets_out(
+      output,
+      weight,
+      indices,
+      offsets_in,
+      false, // unused scale_grad_by_freq
+      0, // unused mode
+      pruned_weights,
+      per_sample_weights_,
+      compressed_indices_mapping,
+      include_last_offset
+  );
+  return output;
+}
+
 template <int bit_rate>
 class QEmbeddingBag final {
  public:
@@ -869,6 +961,9 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_rowwise_offsets"),
       embedding_bag_4bit_rowwise_offsets);
+  m.impl(
+      TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_rowwise_offsets"),
+      embedding_bag_2bit_rowwise_offsets);
 }
 } // namespace
 } // namespace native
diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp
index 7cdb5cb..8ead74f 100644
--- a/aten/src/ATen/native/quantized/library.cpp
+++ b/aten/src/ATen/native/quantized/library.cpp
@@ -128,6 +128,7 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool pruned_weights=False) -> Tensor"));
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 1821267..9243fe2 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -3318,6 +3318,9 @@ class TestQuantizedEmbeddingOps(TestCase):
         if bit_rate == 4:
             pt_op = torch.ops.quantized.embedding_bag_4bit_rowwise_offsets
             pt_prepack_op = torch.ops.quantized.embedding_bag_4bit_prepack
+        elif bit_rate == 2:
+            pt_op = torch.ops.quantized.embedding_bag_2bit_rowwise_offsets
+            pt_prepack_op = torch.ops.quantized.embedding_bag_2bit_prepack
 
         weights = torch.from_numpy((np.random.random_sample((
             num_embeddings, embedding_dim)) + 1).astype(np.float32))
@@ -3483,6 +3486,33 @@ class TestQuantizedEmbeddingOps(TestCase):
                                                sparsity=sparsity, atol=0.1, rtol=1e-2)
 
+    """ Tests the correctness of the embedding_bag_2bit quantized operator """
+    @given(num_embeddings=st.integers(10, 100),
+           embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0),
+           num_offsets=st.integers(1, 20),
+           use_32bit_indices=st.booleans(),
+           use_32bit_offsets=st.booleans(),
+           enable_per_sample_weights=st.booleans(),
+           include_last_offset=st.booleans(),
+           fallback_to_no_sparse=st.booleans(),
+           sparsity=st.sampled_from([0.0, 0.5, 0.7]))
+    def test_embedding_bag_2bit(self, num_embeddings,
+                                embedding_dim, num_offsets,
+                                use_32bit_indices,
+                                use_32bit_offsets,
+                                enable_per_sample_weights,
+                                include_last_offset,
+                                fallback_to_no_sparse,
+                                sparsity):
+        self.embedding_bag_rowwise_offsets_run(2, num_embeddings,
+                                               embedding_dim, num_offsets,
+                                               use_32bit_indices, use_32bit_offsets,
+                                               enable_per_sample_weights,
+                                               include_last_offset,
+                                               fallback_to_no_sparse,
+                                               sparsity=sparsity,
+                                               atol=1.0, rtol=1e-1)
+
     """ Tests the correctness of the quantized embedding lookup operator """
     @given(num_embeddings=st.integers(10, 100),
            embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0))
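As a sanity check on the packed-row arithmetic that embedding_bag_nbit_impl
now uses for any bit width, here is the D computation from the patch worked
through for a 2-bit row (a sketch; the byte counts follow from the 2-byte
fp16 scale and 2-byte zero point noted in the code):

    # For a 2-bit row of embedding_dim D = 16:
    #   payload = 16 / 4 = 4 bytes (4 elements per byte),
    #   plus 2-byte fp16 scale and 2-byte zero point -> weight_size = 8.
    bit_width = 2
    NUM_ELEM_PER_BYTE = 8 // bit_width              # 4
    weight_size = 8                                 # bytes per packed row
    D = (weight_size - 2 * 2) * NUM_ELEM_PER_BYTE   # (8 - 4) * 4 == 16
    assert D == 16

With bit_width = 4 the same expression reduces to the old hard-coded
(weight_size - 4) * 2, which is why the 4-bit path is unchanged in behavior.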