}
template <typename IndexType, typename OffsetType>
-at::Tensor& embedding_bag_4bit_impl(
+at::Tensor& embedding_bag_nbit_impl(
at::Tensor& output,
const at::Tensor& weight,
+ const int bit_width,
const at::Tensor& indices,
const at::Tensor& offsets,
bool pruned_weights,
const auto weight_sizes = weight.sizes();
const int64_t weight_size = weight_sizes[1];
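+ // Each byte of a packed row holds 8 / bit_width quantized elements
+ // (2 per byte for 4-bit, 4 per byte for 2-bit).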
+ const int NUM_ELEM_PER_BYTE = 8 / bit_width;
const int64_t D =
- (weight_size - 4) * 2; // NB: 2-byte fp16 scale and 2-byte zero_offset
+ (weight_size - 2 * sizeof(at::Half)) *
+ NUM_ELEM_PER_BYTE; // NB: 2-byte fp16 scale and 2-byte zero_offset
const int64_t M = offsets.sizes()[0];
int64_t output_size = M - 1;
if (!pruned_weights || fallback_to_no_sparse) {
// Generate the fbgemm kernel
auto kernel = fbgemm::GenerateEmbeddingSpMDMNBit<IndexType, OffsetType>(
- /*bit rate=*/4,
+ /*bit rate=*/bit_width,
/*block size=*/block_size,
/*has weights=*/per_sample_weights_.has_value(),
/*normalize_by_lengths=*/false,
TORCH_CHECK(
success,
- "FBGEMM GenerateEmbeddingSpMDMNBit kernel failed for 4-bit input");
+ "FBGEMM GenerateEmbeddingSpMDMNBit kernel failed for ",
+ bit_width,
+ "-bit input");
} else {
auto kernel =
fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<IndexType, OffsetType>(
- /*bit rate=*/4,
+ /*bit rate=*/bit_width,
/*block_size=*/block_size,
/*has weights=*/per_sample_weights_.has_value(),
/*normalize_by_lengths=*/false,
/*compressed_indices_table=*/compressed_indices_mapping_data);
TORCH_CHECK(
success,
- "FBGEMM GenerateEmbeddingSpMDMNBitRowWiseSparse kernel failed for 4-bit input");
+ "FBGEMM GenerateEmbeddingSpMDMNBitRowWiseSparse kernel failed for ",
+ bit_width,
+ "-bit input");
}
return output;
#else
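+ // Fallback when FBGEMM is unavailable: dispatch on bit_width to the scalar
+ // reference path. The trailing template arguments are the bit rate and the
+ // number of packed elements per byte.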
- return embedding_lookup_fallback_impl<IndexType, OffsetType, 4, 2>(
+ if (bit_width == 4) {
+ return embedding_lookup_fallback_impl<IndexType, OffsetType, 4, 2>(
weight,
indices,
offsets,
output_size,
include_last_offset,
(pruned_weights && !fallback_to_no_sparse));
+ }
+ // bit_width == 2
+ return embedding_lookup_fallback_impl<IndexType, OffsetType, 2, 4>(
+ weight,
+ indices,
+ offsets,
+ per_sample_weights_,
+ compressed_indices_mapping,
+ output,
+ D,
+ output_size,
+ include_last_offset,
+ (pruned_weights && !fallback_to_no_sparse));
#endif
}
is_embedding_op);
}
-at::Tensor& embedding_bag_4bit_helper(
+at::Tensor& _embedding_bag_nbit_helper(
at::Tensor& output,
const at::Tensor& weight,
+ const int bit_width,
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets_in,
bool pruned_weights,
bool include_last_offset) {
c10::MaybeOwned<at::Tensor> offsets;
TORCH_CHECK(
+ bit_width == 4 || bit_width == 2,
+ "qembedding/qembedding_bag operator supports bit_width 2 or 4, got ",
+ bit_width);
+ TORCH_CHECK(
indices.dim() == 1 || indices.dim() == 2,
"qembedding/qembedding_bag operator supports 1 or 2d indices, got ",
indices.dim());
if (indices.dim() == 2) {
TORCH_CHECK(
!offsets_in.has_value(),
- "embedding_bag_4bit operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences.");
+ "embedding_bag_4bit/embedding_bag_2bit operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences.");
offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(
0, indices.numel(), indices.sizes()[1], indices.scalar_type()));
} else {
TORCH_CHECK(
offsets_in.has_value(),
- "embedding_bag_4bit operator expects offsets to be set for 1D indices.");
+ "embedding_bag_4bit/embedding_bag_2bit operator expects offsets to be set for 1D indices.");
offsets = c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());
}
// Using helper function to support different type combination without the
// need to cast, which can be additional performance overhead
if (indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kInt) {
- return embedding_bag_4bit_impl<int, int>(
+ return embedding_bag_nbit_impl<int, int>(
output,
weight,
+ bit_width,
indices,
*offsets,
pruned_weights,
include_last_offset);
} else if (
indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kLong) {
- return embedding_bag_4bit_impl<int, int64_t>(
+ return embedding_bag_nbit_impl<int, int64_t>(
output,
weight,
+ bit_width,
indices,
*offsets,
pruned_weights,
include_last_offset);
} else if (
indices.scalar_type() == at::kLong && offsets->scalar_type() == at::kInt) {
- return embedding_bag_4bit_impl<int64_t, int>(
+ return embedding_bag_nbit_impl<int64_t, int>(
output,
weight,
+ bit_width,
indices,
*offsets,
pruned_weights,
compressed_indices_mapping,
include_last_offset);
}
- return embedding_bag_4bit_impl<int64_t, int64_t>(
+ return embedding_bag_nbit_impl<int64_t, int64_t>(
output,
weight,
+ bit_width,
indices,
*offsets,
pruned_weights,
}
auto output = at::empty({0}, packed_w.options().dtype(at::kFloat));
- return embedding_bag_4bit_helper(
+ return _embedding_bag_nbit_helper(
output,
packed_w,
+ 4,
indices,
offsets_in,
pruned_weights,
per_sample_weights_.value().scalar_type(),
" instead")
}
- return embedding_bag_4bit_helper(
+ return _embedding_bag_nbit_helper(
+ output,
+ weight,
+ 4,
+ indices,
+ offsets_in,
+ pruned_weights,
+ per_sample_weights_.has_value()
+ ? per_sample_weights_.value().to(at::kFloat)
+ : per_sample_weights_,
+ compressed_indices_mapping,
+ include_last_offset);
+}
+
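+// Out variant of the 2-bit rowwise-offsets lookup. Mirrors the 4-bit variant
+// above and forwards to the shared n-bit helper with bit_width = 2.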
+Tensor& embedding_bag_2bit_rowwise_offsets_out(
+ Tensor& output,
+ const Tensor& weight,
+ const Tensor& indices,
+ const c10::optional<Tensor>& offsets_in,
+ const bool /* scale_grad_by_freq */,
+ const int64_t /* mode */,
+ bool pruned_weights,
+ const c10::optional<Tensor>& per_sample_weights_,
+ const c10::optional<Tensor>& compressed_indices_mapping,
+ bool include_last_offset) {
+
+ if (per_sample_weights_.has_value()) {
+ TORCH_CHECK(
+ (per_sample_weights_.value().scalar_type() == at::kFloat ||
+ per_sample_weights_.value().scalar_type() == at::kHalf),
+ "Expect fp32 or fp16 weights, but found",
+ per_sample_weights_.value().scalar_type(),
+ " instead")
+ }
+ return _embedding_bag_nbit_helper(
output,
weight,
+ 2,
indices,
offsets_in,
pruned_weights,
return output;
}
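+// Functional variant: allocates an fp32 output tensor and delegates to the
+// out variant above.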
+Tensor embedding_bag_2bit_rowwise_offsets(
+ const Tensor& weight,
+ const Tensor& indices,
+ const c10::optional<Tensor>& offsets_in,
+ const bool /* scale_grad_by_freq */,
+ const int64_t /* mode */,
+ bool pruned_weights,
+ const c10::optional<Tensor>& per_sample_weights_,
+ const c10::optional<Tensor>& compressed_indices_mapping,
+ bool include_last_offset) {
+
+ auto output = create_empty_from(weight, at::kFloat);
+ embedding_bag_2bit_rowwise_offsets_out(
+ output,
+ weight,
+ indices,
+ offsets_in,
+ false, // unused scale_grad_by_freq
+ 0, // unused mode
+ pruned_weights,
+ per_sample_weights_,
+ compressed_indices_mapping,
+ include_last_offset);
+ return output;
+}
+
template <int bit_rate>
class QEmbeddingBag final {
public:
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_rowwise_offsets"),
embedding_bag_4bit_rowwise_offsets);
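+ // Register the CPU kernel for the new 2-bit rowwise-offsets operator.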
+ m.impl(
+ TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_rowwise_offsets"),
+ embedding_bag_2bit_rowwise_offsets);
}
} // namespace
} // namespace native
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
+ m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool pruned_weights=False) -> Tensor"));
if bit_rate == 4:
pt_op = torch.ops.quantized.embedding_bag_4bit_rowwise_offsets
pt_prepack_op = torch.ops.quantized.embedding_bag_4bit_prepack
+ elif bit_rate == 2:
+ pt_op = torch.ops.quantized.embedding_bag_2bit_rowwise_offsets
+ pt_prepack_op = torch.ops.quantized.embedding_bag_2bit_prepack
weights = torch.from_numpy((np.random.random_sample((
num_embeddings, embedding_dim)) + 1).astype(np.float32))
sparsity=sparsity,
atol=0.1, rtol=1e-2)
+ """ Tests the correctness of the embedding_bag_2bit quantized operator """
+ @given(num_embeddings=st.integers(10, 100),
+ embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0),
+ num_offsets=st.integers(1, 20),
+ use_32bit_indices=st.booleans(),
+ use_32bit_offsets=st.booleans(),
+ enable_per_sample_weights=st.booleans(),
+ include_last_offset=st.booleans(),
+ fallback_to_no_sparse=st.booleans(),
+ sparsity=st.sampled_from([0.0, 0.5, 0.7]))
+ def test_embedding_bag_2bit(self, num_embeddings,
+ embedding_dim, num_offsets,
+ use_32bit_indices,
+ use_32bit_offsets,
+ enable_per_sample_weights,
+ include_last_offset,
+ fallback_to_no_sparse,
+ sparsity):
+ self.embedding_bag_rowwise_offsets_run(2, num_embeddings,
+ embedding_dim, num_offsets,
+ use_32bit_indices, use_32bit_offsets,
+ enable_per_sample_weights,
+ include_last_offset,
+ fallback_to_no_sparse,
+ sparsity=sparsity,
+ atol=1.0, rtol=1e-1)
+
""" Tests the correctness of the quantized embedding lookup operator """
@given(num_embeddings=st.integers(10, 100),
embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0))