Direct FBGEMM integraton into ATen (#13777)
authorJames Reed <jamesreed@fb.com>
Fri, 21 Dec 2018 18:32:57 +0000 (10:32 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 21 Dec 2018 18:35:51 +0000 (10:35 -0800)
Summary:
This PR implements infrastructure for post-processing a model to apply int8 quantization to its `nn.Linear` modules. Highlights of the implementation:

1) Inputs and outputs are `float` (quantized and packed internally), but the weight is quantized and packed ahead of time for efficiency. This implementation performs well in small-batch size GEMM calls. It should not be considered a general-purpose quantized GEMM kernel.
2) Weight packing is dependent on machine architecture (e.g. vector register width), so it is done just-in-time. Concretely, it is done on model load for the weights and it is done during operator execution for the input value.
3) Biases are unquantized
4) We fail loudly if we are attempting to run this on a machine that does not support FBGEMM. This is because we do not want a model's numerics to differ based on which machine it is run on. A model containing these FBGEMM ops *must* be run with FBGEMM

The API can be seen in the added test case. Highlights are:
1) `torch.jit.quantized.quantize_linear_modules` walks the module hierarchy of the passed-in Module and replaces all `nn.Linear` modules with a new `QuantizedLinear` module, which encapsulates the behavior described above.
2) `_pack()` and `_unpack()` script methods are present on `QuantizedLinear` modules. These methods should be called before serialization and after deserialization, respectively. This ensures that the weight matrix is properly packed for the running machine's architecture. Note that in the long term, we would like to move toward a more Pickle-style serialization technique, rather than having these explicit methods that mutate member values. This is blocked on being able to assign attributes in a ScriptMethod, among other things.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13777

Differential Revision: D13383276

Pulled By: jamesr66a

fbshipit-source-id: 00f29c9f34544add2b90107e3cf55a287802c344

CMakeLists.txt
aten/src/ATen/native/QuantizedLinear.cpp [new file with mode: 0644]
aten/src/ATen/native/native_functions.yaml
cmake/Dependencies.cmake
test/test_jit.py
tools/autograd/gen_python_functions.py
torch/csrc/autograd/utils/wrap_outputs.h
torch/jit/quantized.py [new file with mode: 0644]
ubsan.supp

index a29cd82..92201b6 100644 (file)
@@ -198,6 +198,10 @@ include(ExternalProject)
 # ---[ Dependencies
 include(cmake/Dependencies.cmake)
 
+if(USE_FBGEMM)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_FBGEMM")
+endif()
+
 # ---[ Whitelist file if whitelist is specified
 include(cmake/Whitelist.cmake)
 
diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp
new file mode 100644 (file)
index 0000000..64cd4b0
--- /dev/null
@@ -0,0 +1,308 @@
+#include "ATen/ATen.h"
+#include "ATen/NativeFunctions.h"
+#include "ATen/WrapDimUtilsMulti.h"
+
+#ifdef USE_FBGEMM
+#include "fbgemm/Fbgemm.h"
+#include "fbgemm/QuantUtils.h"
+#endif // USE_FBGEMM
+
+#include <array>
+#include <cctype>
+#include <cmath>
+#include <cstddef>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <chrono>
+namespace at {
+namespace native {
+
+#ifdef USE_FBGEMM
+
+Tensor fbgemm_linear_int8_weight(
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& packed,
+    const Tensor& col_offsets,
+    Scalar weight_scale,
+    Scalar weight_zero_point,
+    const Tensor& bias) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
+
+  // We call contiguous on `input` and `weight` here because these APIs all
+  // expect row-major tensor buffers.
+  auto* input_ptr = input.contiguous().data<float>();
+  auto* weight_ptr = weight.contiguous().data<int8_t>();
+
+  AT_ASSERT(input.dim() >= 2);
+  int64_t M = 1;
+  for (size_t i = 0; i < input.dim() - 1; ++i) {
+    M *= input.size(i);
+  }
+  int64_t K = input.size(input.dim() - 1);
+  AT_ASSERT(weight.dim() == 2);
+  AT_ASSERT(K == weight.size(1));
+  auto N = weight.size(0);
+  AT_ASSERT(bias.dim() == 1);
+  AT_ASSERT(bias.size(0) == N);
+  AT_ASSERT(weight_scale.isFloatingPoint());
+  AT_ASSERT(weight_zero_point.isIntegral());
+
+  // Calculate statistics for quantization of the input Tensor
+  float x_min, x_max;
+  fbgemm::FindMinMax(
+      /*m=*/input_ptr,
+      /*min=*/&x_min,
+      /*max=*/&x_max,
+      /*len=*/input.numel());
+
+  // Input tensor is quantized as 8-bit unsigned values
+  static constexpr int precision = 8;
+  static constexpr bool is_signed = false;
+
+  // Calculate scale and zero point for quantization of input tensor
+  auto q_params = fbgemm::ChooseQuantizationParams(
+      /*min=*/x_min,
+      /*max=*/x_max,
+      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
+      /*qmax=*/is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
+      /*preserve_sparsity=*/false);
+
+  q_params.precision = precision;
+
+  // This operation does the following:
+  // 1) Quantizes the input matrix given the statistics we've calculated above
+  // 2) Creates a "row buffer" vector with offset values that must be added
+  //    to the integer matrix multiplication operation to ensure correctness
+  // 3) Packs the resulting quantized matrix into vector-register and cache
+  //    friendly tiles.
+  //
+  //  Note this is not executed eagerly, but rather within the fbgemmPacked call
+  //  below.
+  fbgemm::PackAWithQuantRowOffset<uint8_t> packA(
+      /*trans=*/fbgemm::matrix_op_t::NoTranspose,
+      /*nRow=*/M,
+      /*nCol=*/K,
+      /*smat=*/input_ptr,
+      /*ld=*/K,
+      /*pmat=*/nullptr, // packA manages ownership of `pmat`
+      /*scale=*/q_params.scale,
+      /*zero_pt=*/q_params.zero_point);
+
+  // ReQuantizeForFloat requires pointers to the scale and zero point values,
+  // since in the case of rowwise quantization these will be arrays rather than
+  // scalars. But in this case, we're doing whole-tensor quantization so we just
+  // pass a pointer to the scale values (and internally ReQuantizeFor Float
+  // won't index past 0
+  float weight_scale_float = static_cast<float>(weight_scale.to<double>());
+  int32_t weight_zero_point_int32 =
+      static_cast<int32_t>(weight_zero_point.to<int64_t>());
+
+  // This is the end of the pipeline, pass the resulting matrix through
+  fbgemm::DoNothing<float, float> doNothingObj{};
+
+  // After the uint8 * int8 matrix multiplication is performed, this operation
+  // does:
+  //  1) Add in row and column offsets to the rows and columns, respectively
+  //  2) Dequantize the results into floating point
+  //  3) Add in the bias term
+  fbgemm::ReQuantizeForFloat<false /* FUSE_RELU*/> outputProcObj(
+      /*nextop=*/doNothingObj,
+      /*Aq_scale=*/q_params.scale,
+      /*Bq_scale=*/&weight_scale_float,
+      /*Aq_zero_point=*/q_params.zero_point,
+      /*Bq_zero_point=*/&weight_zero_point_int32,
+      /*row_offsets=*/packA.getRowOffsetBuffer(),
+      /*col_offsets=*/col_offsets.data<int32_t>(),
+      /*bias=*/bias.contiguous().data<float>(),
+      /*ncol=*/N);
+
+  // Allocate output Tensor and a buffer for fbgemmPacked to use
+  auto output = at::zeros_like(bias).to(at::kFloat).expand({M, N}).contiguous();
+  auto buffer = at::zeros_like(output).to(at::kInt).contiguous();
+
+  // Pull out the PackBMatrix instance from the owning tensor
+  auto* packB = reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(
+      packed.storage().data_ptr().get());
+
+  // Do the GEMM
+  fbgemm::fbgemmPacked(
+      /*packA=*/packA,
+      /*packB=*/*packB,
+      /*C=*/output.data<float>(),
+      /*C_buffer=*/buffer.data<int32_t>(),
+      /*ldc=*/N,
+      /*outProcess=*/outputProcObj,
+      /*thread_id=*/0,
+      /*num_threads=*/1);
+
+  // The resulting matrix here is 2-D, let's view it with the original
+  // left hand dimensions of the input.
+  std::vector<int64_t> out_sizes = input.sizes().vec();
+  out_sizes.back() = N;
+  return output.view(out_sizes);
+}
+
+namespace {
+// Calculate the column offsets
+// Note this includes the sum of the columns as well as the scalar term
+// B_zero_point * K, whereas the row_offsets created by PackAWithQuantRowOffset
+// is only the sum of the A rows.
+void calc_col_offsets_transpose(
+    int K,
+    int N,
+    const int8_t* Bint8,
+    int32_t B_zero_point,
+    int32_t* col_offsets) {
+  for (size_t i = 0; i < N; ++i) {
+    int32_t sum = 0;
+    for (size_t j = 0; j < K; ++j) {
+      sum += Bint8[i * K + j];
+    }
+    col_offsets[i] = sum - B_zero_point * K;
+  }
+}
+} // namespace
+
+std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
+    const Tensor& weight) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
+  auto weight_contig = weight.contiguous();
+
+  // Calculate weight statistics
+  float w_min, w_max;
+  fbgemm::FindMinMax(
+      /*m=*/weight_contig.data<float>(),
+      /*min=*/&w_min,
+      /*max=*/&w_max,
+      /*len=*/weight_contig.numel());
+
+  // Choose parameters for quantizing the weight as 8-bit signed integer
+  static constexpr bool is_signed = true;
+  static constexpr int precision = 8;
+  auto q_params = fbgemm::ChooseQuantizationParams(
+      /*min=*/w_min,
+      /*max=*/w_max,
+      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
+      /*qmax=*/is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
+      /*preserve_sparsity=*/false);
+
+  q_params.precision = precision;
+
+  auto quantized = at::zeros_like(weight_contig).to(at::kChar).contiguous();
+  fbgemm::Quantize<int8_t>(
+      /*src=*/weight_contig.data<float>(),
+      /*dst=*/quantized.data<int8_t>(),
+      /*len=*/weight_contig.numel(),
+      /*qparams=*/q_params);
+
+  // Calculate column offsets of the weight and store them away in a tensor.
+  // Similarly to quantization, this can be done once and cached.
+  auto col_offsets =
+      at::zeros_like(quantized).sum({1}).to(at::kInt).contiguous();
+  calc_col_offsets_transpose(
+      /*K=*/quantized.size(1),
+      /*N=*/quantized.size(0),
+      /*Bint8=*/quantized.data<int8_t>(),
+      /*B_zero_point=*/q_params.zero_point,
+      /*col_offsets=*/col_offsets.data<int32_t>());
+
+  return std::make_tuple(
+      quantized, col_offsets, q_params.scale, q_params.zero_point);
+}
+
+bool fbgemm_is_cpu_supported() {
+  return fbgemm::fbgemmSupportedCPU();
+}
+
+Tensor fbgemm_pack_quantized_matrix(
+    const Tensor& weight,
+    int64_t K,
+    int64_t N) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
+  auto contiguous_ptr = weight.contiguous().data<int8_t>();
+  auto* ptr = new fbgemm::PackBMatrix<int8_t>(
+      /*trans=*/fbgemm::matrix_op_t::Transpose,
+      /*nRow=*/K,
+      /*nCol=*/N,
+      /*smat=*/contiguous_ptr,
+      /*ld=*/K,
+      /*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
+      /*groups=*/1);
+
+  // We store this instance away in a Tensor and register a deleter function
+  // so that we do not leak memory. On the other side, we pull out the storage's
+  // data_ptr and get the PackBMatrix's pointer.
+  at::DataPtr at_ptr(
+      ptr,
+      ptr,
+      [](void* ptr) {
+        fbgemm::PackBMatrix<int8_t>* typed_ptr =
+            reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(ptr);
+        delete typed_ptr;
+      },
+      at::kCPU);
+
+  auto retval = at::empty(
+      {sizeof(fbgemm::PackBMatrix<int8_t>)}, weight.options().dtype(at::kByte));
+
+  retval.storage().set_data_ptr(std::move(at_ptr));
+
+  return retval;
+}
+
+#else // USE_FBGEMM
+
+Tensor fbgemm_linear_int8_weight(
+    const Tensor& /*input*/,
+    const Tensor& /*weight*/,
+    const Tensor& /*packed*/,
+    const Tensor& /*col_offsets*/,
+    Scalar /*weight_scale*/,
+    Scalar /*weight_zero_point*/,
+    const Tensor& /*bias*/) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  AT_ASSERTM(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+}
+
+std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
+    const Tensor& /*weight*/) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  AT_ASSERTM(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+}
+
+Tensor fbgemm_pack_quantized_matrix(
+    const Tensor& /*input*/,
+    int64_t /*K*/,
+    int64_t /*N*/) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  AT_ASSERTM(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+}
+
+bool fbgemm_is_cpu_supported() {
+  return false;
+}
+
+#endif // USE_FBGEMM
+}
+} // namespace at
index 2028132..400497a 100644 (file)
 
 - func: linear(Tensor input, Tensor weight, Tensor? bias={}) -> Tensor
 
+- func: fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor
+
+- func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, double, int64_t)
+
+- func: fbgemm_pack_quantized_matrix(Tensor input, int64_t K, int64_t N) -> Tensor
+
+- func: fbgemm_is_cpu_supported() -> bool
+
 - func: linspace(Scalar start, Scalar end, int64_t steps=100, TensorOptions options={}) -> Tensor
 
 - func: linspace_out(Tensor result, Scalar start, Scalar end, int64_t steps=100) -> Tensor
index 001835c..697bddf 100644 (file)
@@ -350,6 +350,9 @@ endif()
 if(USE_FBGEMM)
   set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
   include_directories(SYSTEM "${CAFFE2_THIRD_PARTY_ROOT}")
+  caffe2_update_option(USE_FBGEMM ON)
+else()
+  caffe2_update_option(USE_FBGEMM OFF)
 endif()
 
 
index 07f3b8f..7dab0b5 100644 (file)
@@ -3,6 +3,7 @@ import torch
 import torch.jit
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.jit.quantized
 from contextlib import contextmanager
 from itertools import product, chain
 import torch.jit.frontend
@@ -8006,6 +8007,43 @@ a")
 
             traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
 
+    # These tests don't work because UBSAN has a false positive about accessing
+    # out of bounds on a dynamically sized struct internal to asmjit
+    if not TEST_WITH_UBSAN and torch.fbgemm_is_cpu_supported():
+        def test_int8_quantization_module(self):
+            K1, N1 = 2, 2
+
+            class FooBar(torch.nn.Module):
+                def __init__(self):
+                    super(FooBar, self).__init__()
+                    self.linear1 = torch.nn.Linear(K1, N1).float()
+
+                def forward(self, x):
+                    x = self.linear1(x)
+                    return x
+
+            fb = FooBar()
+            fb.linear1.weight = torch.nn.Parameter(
+                torch.tensor([[-150, 100], [100, -150]], dtype=torch.float), requires_grad=False)
+            fb.linear1.bias = torch.nn.Parameter(torch.zeros_like(fb.linear1.bias), requires_grad=False)
+            fb_ref = FooBar()
+            fb_ref.linear1.weight = torch.nn.Parameter(fb.linear1.weight.clone(), requires_grad=False)
+            fb_ref.linear1.bias = torch.nn.Parameter(fb.linear1.bias.clone(), requires_grad=False)
+            torch.jit.quantized.quantize_linear_modules(fb)
+
+            x = (torch.rand(1, K1).float() - 0.5) / 10.0
+            traced = torch.jit.trace(fb, (x,))
+            traced.apply(lambda s: s._pack() if s._has_method('_pack') else None)
+            fb = self.getExportImportCopy(traced)
+            traced.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
+
+            fb.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
+
+            x = torch.tensor([[100, -150]], dtype=torch.float)
+            y = fb(x)
+            y_ref = fb_ref(x)
+            torch.testing.assert_allclose(y, y_ref, rtol=0.0001, atol=1e-3)
+
     def checkTracerWarning(self, *args, **kwargs):
         with warnings.catch_warnings(record=True) as warns:
             torch.jit.trace(*args, **kwargs)
@@ -9112,7 +9150,7 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
         self._test_reinforcement_learning(self, device='cuda', test_export_import=False)
 
     @staticmethod
-    def _test_snli(self, device, check_export_import=True):
+    def _test_snli(self, device, check_export_import=True, quantized=False):
         class Bottle(nn.Module):
 
             def forward(self, input):
@@ -9199,13 +9237,26 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
         premise = torch.LongTensor(48, 128).random_(0, 100).to(device)
         hypothesis = torch.LongTensor(24, 128).random_(0, 100).to(device)
 
-        self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
-                        inputs_require_grads=False, export_import=check_export_import)
+        if quantized:
+            snli = SNLIClassifier(Config()).cpu()
+            torch.jit.quantized.quantize_linear_modules(snli)
+            # we don't do export/import checks because we would need to call
+            # _pack/_unpack
+            self.checkTrace(snli, (premise, hypothesis), inputs_require_grads=False,
+                            export_import=False)
+        else:
+            self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
+                            inputs_require_grads=False, export_import=check_export_import)
 
     @skipIfRocm
     def test_snli(self):
         self._test_snli(self, device='cpu')
 
+    if not TEST_WITH_UBSAN and torch.fbgemm_is_cpu_supported():
+        @skipIfRocm
+        def test_snli_quantized(self):
+            self._test_snli(self, device='cpu', quantized=True)
+
     @skipIfRocm
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_snli_cuda(self):
@@ -9308,7 +9359,7 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
                         export_import=False)
 
     @staticmethod
-    def _test_vae(self, device, check_export_import=True):
+    def _test_vae(self, device, check_export_import=True, quantized=False):
         class VAE(nn.Module):
             def __init__(self):
                 super(VAE, self).__init__()
@@ -9340,13 +9391,26 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
                 z = self.reparameterize(mu, logvar)
                 return self.decode(z), mu, logvar
 
-        # eval() is present because randn_like makes this nondeterministic
-        self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
-                        export_import=check_export_import)
+        if quantized:
+            vae = VAE().to(device).eval()
+            torch.jit.quantized.quantize_linear_modules(vae)
+            # We don't do export/import checks because we would need to call
+            # _unpack and _pack
+            self.checkTrace(vae, (torch.rand(128, 1, 28, 28, device=device),),
+                            export_import=False, allow_unused=True,
+                            inputs_require_grads=False)
+        else:
+            # eval() is present because randn_like makes this nondeterministic
+            self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
+                            export_import=check_export_import)
 
     def test_vae(self):
         self._test_vae(self, device='cpu')
 
+    if not TEST_WITH_UBSAN and torch.fbgemm_is_cpu_supported():
+        def test_vae_quantized(self):
+            self._test_vae(self, device='cpu', quantized=True)
+
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_vae_cuda(self):
         # XXX: export_import on CUDA modules doesn't work (#11480)
index 2c9a9e7..88e2037 100644 (file)
@@ -123,6 +123,7 @@ ${name}(${py_formal_args})""")
 # to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
 SUPPORTED_RETURN_TYPES = {
     'Tensor', 'std::tuple<Tensor,Tensor>',
+    'std::tuple<Tensor,Tensor,double,int64_t>',
     'std::tuple<Tensor,Tensor,Tensor>',
     'std::tuple<Tensor,Tensor,Tensor,Tensor>',
     'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
index 1c89d60..cc77775 100644 (file)
@@ -95,6 +95,16 @@ inline PyObject* wrap(at::Scalar scalar) {
   return wrap(scalar_to_tensor(scalar));
 }
 
+inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, float, int64_t> tensors) {
+  auto r = THPObjectPtr{PyTuple_New(4)};
+  if (!r) throw python_error();
+  PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors))));
+  PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors))));
+  PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors))));
+  PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors))));
+  return r.release();
+}
+
 inline PyObject* wrap(THPDtype *dtype) {
   Py_INCREF(dtype);
   return (PyObject*)dtype;
diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py
new file mode 100644 (file)
index 0000000..4eb3a91
--- /dev/null
@@ -0,0 +1,54 @@
+import torch
+import copy
+
+
+class QuantizedLinear(torch.jit.ScriptModule):
+    __constants__ = ['scale', 'zero_point']
+
+    def __init__(self, other):
+        super(QuantizedLinear, self).__init__()
+        self.in_features = other.in_features
+        self.out_features = other.out_features
+        # Quantize weight and discard the original
+        self.weight, self.col_offsets, self.scale, self.zero_point = torch.fbgemm_linear_quantize_weight(
+            other.weight.clone().float())
+        self.weight = torch.nn.Parameter(self.weight, requires_grad=False)
+        self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False)
+        assert other.bias is not None, 'QuantizedLinear requires a bias'
+        self.bias = torch.nn.Parameter(other.bias.clone().float())
+
+        self.register_buffer(
+            'packed_tensor_ptr',
+            torch.fbgemm_pack_quantized_matrix(self.weight.clone(), self.weight.size(1), self.weight.size(0)))
+
+    @torch.jit.script_method
+    def _unpack(self):
+        self.packed_tensor_ptr.set_(
+            torch.fbgemm_pack_quantized_matrix(
+                self.weight, self.weight.size(1), self.weight.size(0)))
+
+    @torch.jit.script_method
+    def _pack(self):
+        self.packed_tensor_ptr.set_(
+            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
+
+    @torch.jit.script_method
+    def forward(self, input):
+        out = torch.fbgemm_linear_int8_weight(
+            input.float(), self.weight, self.packed_tensor_ptr, self.col_offsets,
+            self.scale, self.zero_point, self.bias)
+        return out.type_as(input)
+
+    def extra_repr(self):
+        repr = 'in_features={in_features}, out_features={out_features}, ' \
+               'scale={scale}, zero_point={zero_point}'.format(**self.__dict__)
+        return repr
+
+
+def quantize_linear_modules(module):
+    for name, mod in module.named_modules():
+        if mod is module:
+            continue
+        if isinstance(mod, torch.nn.Linear):
+            setattr(module, name, QuantizedLinear(mod))
+        quantize_linear_modules(mod)
index 68583a3..f1579d3 100644 (file)
@@ -1 +1,2 @@
 vptr:libtorch.so
+bounds:asmjit::Zone::_alloc