Wrap workaround for cpp custom types a bit prettier and add an example (#18791)
authorDmytro Dzhulgakov <dzhulgakov@fb.com>
Fri, 5 Apr 2019 18:14:11 +0000 (11:14 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 5 Apr 2019 18:20:13 +0000 (11:20 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18791

As a temporary demonstration on how to extend this hack further until custom C types are ready.

Reviewed By: jamesr66a

Differential Revision: D14742020

fbshipit-source-id: 0f2fd83ae56ab2abe16977a1829ed421e6abe74b

aten/src/ATen/cpp_custom_type_hack.h [new file with mode: 0644]
aten/src/ATen/native/QuantizedLinear.cpp

diff --git a/aten/src/ATen/cpp_custom_type_hack.h b/aten/src/ATen/cpp_custom_type_hack.h
new file mode 100644 (file)
index 0000000..211f1a0
--- /dev/null
@@ -0,0 +1,48 @@
+// WARNING! WARNING! WARNING!
+// This file is a temporary hack to enable development of pytorch quantization
+//
+// It's a stub for wrapping arbitrary cpp types in TorchScript. Proper
+// implementation (under development) is to use TorchScript custom types.
+// In the meantime, we abuse ByteTensor with custom deleter for this purpose.
+//
+// Template argument <T> has to be registered with CAFFE_KNOWN_TYPE mechanism.
+
+#include "ATen/ATen.h"
+
+namespace at {
+namespace cpp_custom_type_hack {
+
+template<typename T>
+T& cast(const Tensor& packed) {
+  AT_CHECK(
+      packed.scalar_type() == kByte, "Expected temporary cpp type wrapper");
+  AT_CHECK(
+      packed.storage().data_ptr().get_deleter() ==
+          caffe2::TypeMeta::Make<T>().deleteFn(),
+      "Expected temporary cpp type wrapper of type ",
+      caffe2::TypeMeta::TypeName<T>());
+  return *reinterpret_cast<T*>(packed.storage().data_ptr().get());
+}
+
+template<typename T>
+Tensor create(std::unique_ptr<T> ptr) {
+  // We store this instance away in a Tensor and register a deleter function
+  // so that we do not leak memory. On the other side, we pull out the storage's
+  // data_ptr and get the right typed pointer.
+  void* raw_ptr = ptr.release();
+  at::DataPtr at_ptr(
+      raw_ptr,
+      raw_ptr,
+      caffe2::TypeMeta::Make<T>().deleteFn(),
+      at::kCPU);
+
+  // size doesn't really matter, but we can align it to the actual size
+  // returning variables because one likely want to use this hack from python
+  auto retval = at::empty(
+      {sizeof(T)},
+      at::device(kCPU).dtype(at::kByte).is_variable(true).requires_grad(false));
+  retval.storage().set_data_ptr(std::move(at_ptr));
+  return retval;
+}
+}
+}
index a9c1b3d..58f64de 100644 (file)
@@ -1,6 +1,7 @@
 #include "ATen/ATen.h"
 #include "ATen/NativeFunctions.h"
 #include "ATen/WrapDimUtilsMulti.h"
+#include "ATen/cpp_custom_type_hack.h"
 
 #ifdef USE_FBGEMM
 #include "fbgemm/Fbgemm.h"
 #include <vector>
 
 #include <chrono>
+
+namespace caffe2 {
+#ifdef USE_FBGEMM
+// Required for cpp_custom_type_hack to work
+CAFFE_KNOWN_TYPE(fbgemm::PackBMatrix<int8_t>);
+#endif // USE_FBGEMM
+}
+
 namespace at {
 namespace native {
 
@@ -127,13 +136,12 @@ Tensor fbgemm_linear_int8_weight(
   auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));
 
   // Pull out the PackBMatrix instance from the owning tensor
-  auto* packB = reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(
-      packed.storage().data_ptr().get());
+  auto& packB = cpp_custom_type_hack::cast<fbgemm::PackBMatrix<int8_t>>(packed);
 
   // Do the GEMM
   fbgemm::fbgemmPacked(
       /*packA=*/packA,
-      /*packB=*/*packB,
+      /*packB=*/packB,
       /*C=*/output.data<float>(),
       /*C_buffer=*/buffer.data<int32_t>(),
       /*ldc=*/N,
@@ -233,7 +241,7 @@ Tensor fbgemm_pack_quantized_matrix(
   AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
   auto weight_contig = weight.contiguous();
   auto contiguous_ptr = weight_contig.data<int8_t>();
-  auto* ptr = new fbgemm::PackBMatrix<int8_t>(
+  auto ptr = std::make_unique<fbgemm::PackBMatrix<int8_t>>(
       /*trans=*/fbgemm::matrix_op_t::Transpose,
       /*nRow=*/K,
       /*nCol=*/N,
@@ -241,26 +249,7 @@ Tensor fbgemm_pack_quantized_matrix(
       /*ld=*/K,
       /*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
       /*groups=*/1);
-
-  // We store this instance away in a Tensor and register a deleter function
-  // so that we do not leak memory. On the other side, we pull out the storage's
-  // data_ptr and get the PackBMatrix's pointer.
-  at::DataPtr at_ptr(
-      ptr,
-      ptr,
-      [](void* ptr) {
-        fbgemm::PackBMatrix<int8_t>* typed_ptr =
-            reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(ptr);
-        delete typed_ptr;
-      },
-      at::kCPU);
-
-  auto retval = at::empty(
-      {sizeof(fbgemm::PackBMatrix<int8_t>)}, weight.options().dtype(at::kByte));
-
-  retval.storage().set_data_ptr(std::move(at_ptr));
-
-  return retval;
+  return cpp_custom_type_hack::create(std::move(ptr));
 }
 
 #else // USE_FBGEMM