template <uint8_t bits_per_dim>
__forceinline__ __device__ void
-accumulate_packed_intx(float4* acc, uint32_t packedVals, float2 scale_bias) {
+accumulate_packed_intx(float4* acc, uint32_t packedVals, float2 scale_bias, float sample_weight) {
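+ // sample_weight scales every dequantized value before it is accumulated,
+ // so each row contributes in proportion to its per-sample weight.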
constexpr uint8_t dims_per_byte = 8 / bits_per_dim;
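+ // One 32-bit word packs 4 * dims_per_byte quantized values; each iteration
+ // dequantizes four of them into a float4.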
for (uint8_t i = 0; i < dims_per_byte; i++) {
float4 res = dequantize_intx<bits_per_dim>(packedVals, scale_bias, 4 * bits_per_dim * i /* offset_bits */);
// Accumulate in float32.
- acc[i].x += res.x;
- acc[i].y += res.y;
- acc[i].z += res.z;
- acc[i].w += res.w;
+ acc[i].x += (res.x * sample_weight);
+ acc[i].y += (res.y * sample_weight);
+ acc[i].z += (res.z * sample_weight);
+ acc[i].w += (res.w * sample_weight);
}
}
const PackedTensorAccessor32<index_t, 1, RestrictPtrTraits> indices,
const PackedTensorAccessor32<index_t, 1, RestrictPtrTraits> offsets,
const bool /* pruned_weights */,
- const c10::optional<Tensor>& per_sample_weights_,
+ const PackedTensorAccessor32<float, 1, RestrictPtrTraits> per_sample_weights_,
const c10::optional<Tensor>& compressed_indices_mapping,
const bool include_last_offset,
PackedTensorAccessor32<float, 2, RestrictPtrTraits> output) {
const int32_t D_bytes = weight.size(1);
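+ // An empty per-sample-weights accessor (size 0) means unweighted pooling.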
+ bool use_per_sample = per_sample_weights_.size(0) > 0;
+
int64_t indices_start = offsets[t * B + b];
int64_t indices_end;
if (include_last_offset) {
}
for (int32_t l = indices_start; l < indices_end; ++l) {
int64_t idx = indices[l];
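+ // Use the caller-provided weight for this index, or 1.0f for plain sum pooling.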
+ float sample_weight = use_per_sample ? per_sample_weights_[l] : 1.0f;
const uint8_t* __restrict__ row = &weights[idx * D_bytes];
float2 scale_bias;
if (fp32_scale_bias) {
uint32_t v0 = reinterpret_cast<const uint32_t*>(&row[byte_offset])[0];
- accumulate_packed_intx<bits_per_dim>(accumulator, v0, scale_bias);
+ accumulate_packed_intx<bits_per_dim>(accumulator, v0, scale_bias, sample_weight);
}
const int D = weight_sizes[1] - 8; // NB: -8 to account for scale and bias
const int64_t M = offsets.sizes()[0];
TORCH_CHECK(D % 4 == 0);
- TORCH_CHECK(
- !per_sample_weights_.has_value(),
- "Per sample weights not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");
+ if (per_sample_weights_.has_value()) {
+   TORCH_CHECK(per_sample_weights_.value().scalar_type() == at::kFloat,
+       "Per sample weights expected scalar type ", at::kFloat, " but got ",
+       per_sample_weights_.value().scalar_type());
+ }
TORCH_CHECK(
!compressed_indices_mapping.has_value(),
"Compressed indices mapping not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");
int64_t output_size = include_last_offset ? M - 1 : M;
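+ // The kernel takes a packed accessor rather than an optional, so pass an
+ // empty float tensor when no per-sample weights are given; the kernel
+ // detects this via size(0) == 0 and falls back to a weight of 1.0f.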
+ at::Tensor sample_weights;
+ if (per_sample_weights_.has_value()) {
+   sample_weights = per_sample_weights_.value();
+ } else {
+   sample_weights = create_empty_from(output, kFloat);
+ }
+
const std::vector<int64_t> shape = {output_size, D};
at::native::resize_(output, shape, c10::nullopt);
AT_DISPATCH_INDEX_TYPES(
indices.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
offsets.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
false /* pruned_weights */,
- per_sample_weights_,
+ sample_weights.packed_accessor32<float, 1, RestrictPtrTraits>(),
compressed_indices_mapping,
include_last_offset,
output.packed_accessor32<float, 2, RestrictPtrTraits>());
const int D = 2*(weight_sizes[1] - 4); // NB: -4 to account for scale and bias @fp16
const int64_t M = offsets.sizes()[0];
TORCH_CHECK(D % 8 == 0);
- TORCH_CHECK(
- !per_sample_weights_.has_value(),
- "Per sample weights not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");
+ if (per_sample_weights_.has_value()) {
+   TORCH_CHECK(per_sample_weights_.value().scalar_type() == at::kFloat,
+       "Per sample weights expected scalar type ", at::kFloat, " but got ",
+       per_sample_weights_.value().scalar_type());
+ }
TORCH_CHECK(
!compressed_indices_mapping.has_value(),
"Compressed indices mapping not yet implemented for embedding_bag_byte_rowwise_offsets_cuda");
int64_t output_size = include_last_offset ? M - 1 : M;
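+ // As in the byte path above, an empty tensor stands in for absent per-sample weights.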
+ at::Tensor sample_weights;
+ if (per_sample_weights_.has_value()) {
+   sample_weights = per_sample_weights_.value();
+ } else {
+   sample_weights = create_empty_from(output, kFloat);
+ }
+
const std::vector<int64_t> shape = {output_size, D};
at::native::resize_(output, shape, c10::nullopt);
AT_DISPATCH_INDEX_TYPES(
indices.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
offsets.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
false /* pruned_weights */,
- per_sample_weights_,
+ sample_weights.packed_accessor32<float, 1, RestrictPtrTraits>(),
compressed_indices_mapping,
include_last_offset,
output.packed_accessor32<float, 2, RestrictPtrTraits>());