--- /dev/null
+// WARNING! WARNING! WARNING!
+// This file is a temporary hack to enable development of pytorch quantization
+//
+// It's a stub for wrapping arbitrary cpp types in TorchScript. Proper
+// implementation (under development) is to use TorchScript custom types.
+// In the meantime, we abuse ByteTensor with custom deleter for this purpose.
+//
+// Template argument <T> has to be registered with CAFFE_KNOWN_TYPE mechanism.
+
+#include "ATen/ATen.h"
+
+namespace at {
+namespace cpp_custom_type_hack {
+
+template<typename T>
+T& cast(const Tensor& packed) {
+ AT_CHECK(
+ packed.scalar_type() == kByte, "Expected temporary cpp type wrapper");
+ AT_CHECK(
+ packed.storage().data_ptr().get_deleter() ==
+ caffe2::TypeMeta::Make<T>().deleteFn(),
+ "Expected temporary cpp type wrapper of type ",
+ caffe2::TypeMeta::TypeName<T>());
+ return *reinterpret_cast<T*>(packed.storage().data_ptr().get());
+}
+
+template<typename T>
+Tensor create(std::unique_ptr<T> ptr) {
+ // We store this instance away in a Tensor and register a deleter function
+ // so that we do not leak memory. On the other side, we pull out the storage's
+ // data_ptr and get the right typed pointer.
+ void* raw_ptr = ptr.release();
+ at::DataPtr at_ptr(
+ raw_ptr,
+ raw_ptr,
+ caffe2::TypeMeta::Make<T>().deleteFn(),
+ at::kCPU);
+
+ // size doesn't really matter, but we can align it to the actual size
+ // returning variables because one likely want to use this hack from python
+ auto retval = at::empty(
+ {sizeof(T)},
+ at::device(kCPU).dtype(at::kByte).is_variable(true).requires_grad(false));
+ retval.storage().set_data_ptr(std::move(at_ptr));
+ return retval;
+}
+}
+}
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtilsMulti.h"
+#include "ATen/cpp_custom_type_hack.h"
#ifdef USE_FBGEMM
#include "fbgemm/Fbgemm.h"
#include <vector>
#include <chrono>
+
+namespace caffe2 {
+#ifdef USE_FBGEMM
+// Required for cpp_custom_type_hack to work
+CAFFE_KNOWN_TYPE(fbgemm::PackBMatrix<int8_t>);
+#endif // USE_FBGEMM
+}
+
namespace at {
namespace native {
auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));
// Pull out the PackBMatrix instance from the owning tensor
- auto* packB = reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(
- packed.storage().data_ptr().get());
+ auto& packB = cpp_custom_type_hack::cast<fbgemm::PackBMatrix<int8_t>>(packed);
// Do the GEMM
fbgemm::fbgemmPacked(
/*packA=*/packA,
- /*packB=*/*packB,
+ /*packB=*/packB,
/*C=*/output.data<float>(),
/*C_buffer=*/buffer.data<int32_t>(),
/*ldc=*/N,
AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
auto weight_contig = weight.contiguous();
auto contiguous_ptr = weight_contig.data<int8_t>();
- auto* ptr = new fbgemm::PackBMatrix<int8_t>(
+ auto ptr = std::make_unique<fbgemm::PackBMatrix<int8_t>>(
/*trans=*/fbgemm::matrix_op_t::Transpose,
/*nRow=*/K,
/*nCol=*/N,
/*ld=*/K,
/*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
/*groups=*/1);
-
- // We store this instance away in a Tensor and register a deleter function
- // so that we do not leak memory. On the other side, we pull out the storage's
- // data_ptr and get the PackBMatrix's pointer.
- at::DataPtr at_ptr(
- ptr,
- ptr,
- [](void* ptr) {
- fbgemm::PackBMatrix<int8_t>* typed_ptr =
- reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(ptr);
- delete typed_ptr;
- },
- at::kCPU);
-
- auto retval = at::empty(
- {sizeof(fbgemm::PackBMatrix<int8_t>)}, weight.options().dtype(at::kByte));
-
- retval.storage().set_data_ptr(std::move(at_ptr));
-
- return retval;
+ return cpp_custom_type_hack::create(std::move(ptr));
}
#else // USE_FBGEMM