[Static Runtime] Add out variant of quantized::embedding_bag_byte_prepack (#64081)
author Don Jang <djang@fb.com>
Fri, 27 Aug 2021 17:42:50 +0000 (10:42 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 17:53:23 +0000 (10:53 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64081

This change adds an out variant of `quantized::embedding_bag_byte_prepack`.
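
For context, a minimal usage sketch of the two entry points (signatures as declared in the new `qembeddingbag_prepack.h`; allocating the destination as an empty byte tensor mirrors what the functional variant does internally in this diff and is an illustration, not a requirement of the API):

```cpp
#include <ATen/native/quantized/cpu/qembeddingbag_prepack.h>
#include <torch/torch.h>

void prepack_example() {
  // A small FP32 weight; 16 columns pack into 16 + 2 * sizeof(float) = 24 bytes per row.
  at::Tensor weight = torch::randn({8, 16}, at::kFloat);

  // Functional variant: allocates and returns the packed tensor on every call.
  at::Tensor packed = at::native::qembeddingbag_byte_prepack(weight);

  // Out variant: fills a caller-provided tensor, which is what lets
  // Static Runtime reuse the output buffer across iterations.
  at::Tensor out = at::empty({0}, weight.options().dtype(at::kByte));
  at::native::qembeddingbag_byte_prepack_out(out, weight);
}
```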

Test Plan:
- Added `StaticRuntime.QEmbeddingBagByteUnpack`.

- Observed Static Runtime switching to the out variant for the node:

```
V0824 13:38:49.723708 1322143 impl.cpp:1394] Switch to out variant for node: %2 : Tensor = quantized::embedding_bag_byte_prepack(%input)
```
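
The out variant sizes its own destination, so for reference: the packed layout it writes is the one documented in `qembeddingbag_prepack.cpp`, where each row keeps its uint8-quantized values and appends an FP32 scale and zero_point (8 extra bytes per row), e.g. an `[8, 16]` float weight packs into an `[8, 24]` byte tensor. A rough per-row sketch of that scheme, assuming the standard fused 8-bit rowwise math (an illustration, not the actual kernel):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Packs one FP32 row into [uint8 data | FP32 scale | FP32 zero_point (row min)].
std::vector<uint8_t> pack_row_byte(const std::vector<float>& row) {
  const float min = *std::min_element(row.begin(), row.end());
  const float max = *std::max_element(row.begin(), row.end());
  const float range = max - min;
  const float scale = range / 255.0f;
  // The small epsilon guards against division by zero for constant rows,
  // mirroring kEpsilon in the non-FBGEMM fallback path.
  const float inverse_scale = 255.0f / (range + 1e-8f);

  std::vector<uint8_t> packed(row.size() + 2 * sizeof(float));
  for (size_t i = 0; i < row.size(); ++i) {
    packed[i] = static_cast<uint8_t>(std::lrintf((row[i] - min) * inverse_scale));
  }
  std::memcpy(packed.data() + row.size(), &scale, sizeof(float));
  std::memcpy(packed.data() + row.size() + sizeof(float), &min, sizeof(float));
  return packed;
}
```

Dequantizing with `q * scale + min` reproduces the small round-trip error visible in the example values in the source comments (e.g. 60.00000035).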

Reviewed By: hlu1

Differential Revision: D30504216

fbshipit-source-id: 1d9d428e77a15bcc7da373d65e7ffabaf9c6caf2

aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.h [new file with mode: 0644]
benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/ops.cpp

diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
index 5d9abce..614e274 100644
@@ -1,3 +1,5 @@
+#include <ATen/native/quantized/cpu/qembeddingbag_prepack.h>
+
 #include <c10/core/ScalarType.h>
 #include <ATen/ATen.h>
 #include <ATen/Parallel.h>
@@ -122,7 +124,6 @@ c10::intrusive_ptr<EmbeddingPackedParamsBase> PackedEmbeddingBagWeight::prepack(
 
 namespace at {
 namespace native {
-namespace {
 
 // Note - This is a temporary pack function for embedding bag which quantizes
 // and packs the float weight tensor. In the next step it will be replaced by a
@@ -184,7 +185,7 @@ namespace {
 //
 //        [[50.        , 60.00000035],
 //         [70.        , 80.00000035]]])
-Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
+Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
   // The "last" dimension of an N-Dimensioned batch of embedding bags is
   // quantization channel. E.g. for a 2D embedding bag, this has
   // [ row, col ] dimensions, for batched of embedding bags, dimensions might be
@@ -208,17 +209,12 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
   const int32_t embedding_cols = weight_sizes[cols_dim];
   // Add 8 bytes per column to store FP32 scale and zero_point per row.
   const int32_t output_columns = embedding_cols + 2 * sizeof(float);
-  Tensor weight_contig = weight.contiguous(weight.suggest_memory_format());
+  const auto weight_contig = weight.expect_contiguous(weight.suggest_memory_format());
 
   // Adjust output dimensions to account for FP32 scale and zero_points.
   std::vector<int64_t> output_shape = weight_sizes.vec();
   output_shape[cols_dim] = output_columns;
-
-  // Allocate output packed weights
-  auto output = at::empty(
-      output_shape,
-      weight_contig.options().dtype(at::kByte),
-      weight_contig.suggest_memory_format());
+  at::native::resize_(output, output_shape, c10::nullopt);
   auto* output_data = output.data_ptr<uint8_t>();
 
 #ifdef USE_FBGEMM
@@ -246,10 +242,9 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
   }
 
 #else
-  const auto float_weight = weight_contig.scalar_type() == at::ScalarType::Half
-    ? weight_contig.to(at::ScalarType::Float)
-    : weight_contig;
-  const auto weight_data = float_weight.data_ptr<float>();
+  const auto weight_data = weight_contig->scalar_type() == at::ScalarType::Half
+    ? weight_contig->to(at::ScalarType::Float).data_ptr<float>()
+    : weight_contig->data_ptr<float>();
   constexpr float kEpsilon = 1e-8f;
   for (auto row: c10::irange(embedding_rows)) {
     const float* input_row = weight_data + row * embedding_cols;
@@ -276,6 +271,21 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
   return output;
 }
 
+Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
+  const auto weight_contig = weight.expect_contiguous(weight.suggest_memory_format());
+  auto output = at::detail::empty_cpu(
+      {0},
+      at::kByte,
+      weight_contig->layout(),
+      weight_contig->device(),
+      c10::nullopt,
+      c10::nullopt);
+  qembeddingbag_byte_prepack_out(output, weight);
+  return output;
+}
+
+namespace {
+
 // TODO: Extend support to N-D batched embeddings, similar to qembeddingbag_byte_prepack
 Tensor _qembeddingbag_nbit_prepack_helper(
     const Tensor& weight,
diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.h b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.h
new file mode 100644
index 0000000..c52cbae
--- /dev/null
@@ -0,0 +1,11 @@
+#include <ATen/ATen.h>
+
+namespace at {
+namespace native {
+
+Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight);
+
+Tensor qembeddingbag_byte_prepack(const Tensor& weight);
+
+} // namespace native
+} // namespace at
diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h
index 477b191..bcc975b 100644
@@ -772,3 +772,11 @@ const auto fmod_scalar = R"JIT(
   def forward(self, a: Tensor, b: int):
       return torch.fmod(a, b).clone()
 )JIT";
+
+const std::string embedding_bag_byte_prepack_script = R"IR(
+  graph(%input: Tensor):
+      %none : None = prim::Constant()
+      %output: Tensor = quantized::embedding_bag_byte_prepack(%input)
+      %res: Tensor = aten::clone(%output, %none)
+      return (%res)
+)IR";
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index bd213c7..1e987a9 100644
@@ -1257,3 +1257,11 @@ TEST(StaticRuntime, IndividualOps_FmodScalar) {
   std::vector<IValue> args3{c, 4};
   testStaticRuntime(fmod_scalar, args2, args3);
 }
+
+TEST(StaticRuntime, QEmbeddingBagByteUnpack) {
+  auto a = torch::randn({8, 16}, at::ScalarType::Float);
+  auto b = torch::randn({8*2, 16*2}, at::ScalarType::Float);
+
+  testStaticRuntime(embedding_bag_byte_prepack_script, {a});
+  testStaticRuntime(embedding_bag_byte_prepack_script, {a}, {b});
+}
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index 36f796f..f171d28 100644
@@ -14,6 +14,7 @@
 #include <ATen/native/layer_norm.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qembeddingbag.h>
+#include <ATen/native/quantized/cpu/qembeddingbag_prepack.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/runtime/static/impl.h>
@@ -761,6 +762,7 @@ REGISTER_OPERATOR_FUNCTOR(
             include_last_offset);
       };
     });
+
 REGISTER_OPERATOR_FUNCTOR(
     quantized::embedding_bag_4bit_rowwise_offsets,
     embedding_bag_4bit_rowwise_offsets,
@@ -799,6 +801,27 @@ REGISTER_OPERATOR_FUNCTOR(
       };
     });
 
+REGISTER_OPERATOR_FUNCTOR(
+    quantized::embedding_bag_byte_prepack,
+    embedding_bag_byte_prepack,
+    [](Node* n) -> SROperator {
+      if (!n->matches(torch::schema(
+              "quantized::embedding_bag_byte_prepack(Tensor weight) -> Tensor"))) {
+        LogAndDumpSchema(n);
+        return nullptr;
+      }
+      return [](ProcessedNode* p_node) {
+        const auto& weight = p_node->Input(0).toTensor();
+        if (p_node->Output(0).isNone()) {
+          p_node->Output(0) = at::native::qembeddingbag_byte_prepack(weight);
+          return;
+        }
+        auto& out_t = p_node->Output(0).toTensor();
+        fastResizeToZero(out_t);
+        at::native::qembeddingbag_byte_prepack_out(out_t, weight);
+      };
+    });
+
 // The out variant takes precedence over native
 REGISTER_OPERATOR_FUNCTOR(aten::narrow_copy, aten_narrow_copy, [](Node* n) -> SROperator {
   if (!n->matches(torch::schema(