+#include <ATen/native/quantized/cpu/qembeddingbag_prepack.h>
+
#include <c10/core/ScalarType.h>
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
namespace at {
namespace native {
-namespace {
// Note - This is a temporary pack function for embedding bag which quantizes
// and packs the float weight tensor. In the next step it will be replaced by a
// quantize and pack function once we support FP scale and FP zero_point.
//
// Python example examining a packed 8bit zero_point and scale (tail of the
// dequantized output):
// [[50. , 60.00000035],
//  [70. , 80.00000035]]])
-Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
+Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
// The "last" dimension of an N-dimensional batch of embedding bags is
// the quantization channel. E.g. a 2D embedding bag has [ row, col ]
// dimensions; a batch of embedding bags might have [ batch, row, col ].
const int32_t embedding_cols = weight_sizes[cols_dim];
// Add 8 bytes per row to store the FP32 scale and zero_point.
const int32_t output_columns = embedding_cols + 2 * sizeof(float);
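// e.g. embedding_cols = 64 -> output_columns = 64 + 8 = 72 bytes per row
// (64 uint8 values followed by a 4-byte scale and a 4-byte zero_point).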
- Tensor weight_contig = weight.contiguous(weight.suggest_memory_format());
+ const auto weight_contig = weight.expect_contiguous(weight.suggest_memory_format());
// Adjust output dimensions to account for FP32 scale and zero_points.
std::vector<int64_t> output_shape = weight_sizes.vec();
output_shape[cols_dim] = output_columns;
-
- // Allocate output packed weights
- auto output = at::empty(
- output_shape,
- weight_contig.options().dtype(at::kByte),
- weight_contig.suggest_memory_format());
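+  // Resize the caller-provided output in place so the same buffer can be
+  // reused across calls (e.g. by Static Runtime).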
+ at::native::resize_(output, output_shape, c10::nullopt);
auto* output_data = output.data_ptr<uint8_t>();
#ifdef USE_FBGEMM
}
#else
- const auto float_weight = weight_contig.scalar_type() == at::ScalarType::Half
- ? weight_contig.to(at::ScalarType::Float)
- : weight_contig;
- const auto weight_data = float_weight.data_ptr<float>();
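+  // Keep any Half->Float conversion alive in a named tensor; taking
+  // data_ptr() off a temporary would leave weight_data dangling.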
+  const auto float_weight = weight_contig->scalar_type() == at::ScalarType::Half
+      ? weight_contig->to(at::ScalarType::Float)
+      : *weight_contig;
+  const auto weight_data = float_weight.data_ptr<float>();
constexpr float kEpsilon = 1e-8f;
for (auto row: c10::irange(embedding_rows)) {
const float* input_row = weight_data + row * embedding_cols;
return output;
}
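+
+// Functional variant: allocates an empty byte tensor and fills it via the
+// out variant above.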
+Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
+ const auto weight_contig = weight.expect_contiguous(weight.suggest_memory_format());
+ auto output = at::detail::empty_cpu(
+ {0},
+ at::kByte,
+ weight_contig->layout(),
+ weight_contig->device(),
+ c10::nullopt,
+ c10::nullopt);
+ qembeddingbag_byte_prepack_out(output, weight);
+ return output;
+}
+
+namespace {
+
// TODO: Extend support to N-D batched embeddings, similar to qembeddingbag_byte_prepack
Tensor _qembeddingbag_nbit_prepack_helper(
const Tensor& weight,
std::vector<IValue> args3{c, 4};
testStaticRuntime(fmod_scalar, args2, args3);
}
+
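+// embedding_bag_byte_prepack_script is defined earlier in this file (elided
+// from this diff); a minimal sketch of the assumed graph, which just calls
+// the prepack op:
+//
+//   const auto embedding_bag_byte_prepack_script = R"JIT(
+//     def forward(self, a: Tensor):
+//         return torch.ops.quantized.embedding_bag_byte_prepack(a)
+//   )JIT";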
+TEST(StaticRuntime, QEmbeddingBagByteUnpack) {
+ auto a = torch::randn({8, 16}, at::ScalarType::Float);
+ auto b = torch::randn({8*2, 16*2}, at::ScalarType::Float);
+
+ testStaticRuntime(embedding_bag_byte_prepack_script, {a});
+  testStaticRuntime(embedding_bag_byte_prepack_script, {a}, {b});
+}
#include <ATen/native/layer_norm.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
+#include <ATen/native/quantized/cpu/qembeddingbag_prepack.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>
include_last_offset);
};
});
+
REGISTER_OPERATOR_FUNCTOR(
quantized::embedding_bag_4bit_rowwise_offsets,
embedding_bag_4bit_rowwise_offsets,
};
});
+REGISTER_OPERATOR_FUNCTOR(
+ quantized::embedding_bag_byte_prepack,
+ embedding_bag_byte_prepack,
+ [](Node* n) -> SROperator {
+ if (!n->matches(torch::schema(
+ "quantized::embedding_bag_byte_prepack(Tensor weight) -> Tensor"))) {
+ LogAndDumpSchema(n);
+ return nullptr;
+ }
+ return [](ProcessedNode* p_node) {
+ const auto& weight = p_node->Input(0).toTensor();
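+      // First invocation: no managed output exists yet, so allocate one via
+      // the functional variant.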
+ if (p_node->Output(0).isNone()) {
+ p_node->Output(0) = at::native::qembeddingbag_byte_prepack(weight);
+ return;
+ }
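+      // Subsequent invocations: resize to zero and reuse the managed output
+      // tensor through the out variant.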
+ auto& out_t = p_node->Output(0).toTensor();
+ fastResizeToZero(out_t);
+ at::native::qembeddingbag_byte_prepack_out(out_t, weight);
+ };
+ });
+
// The out variant takes precedence over native
REGISTER_OPERATOR_FUNCTOR(aten::narrow_copy, aten_narrow_copy, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(