QTensor (#18230)
authorJerry Zhang <jerryzh@fb.com>
Wed, 3 Apr 2019 20:13:26 +0000 (13:13 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 20:17:11 +0000 (13:17 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18230

Implementing minimum qtensor API to unblock other workstreams in quantization

Changes:
- Added Quantizer which represents different quantization schemes
- Added qint8 as a data type for QTensor
- Added a new ScalarType QInt8
- Added QTensorImpl for QTensor
- Added following user facing APIs
  - quantize_linear(scale, zero_point)
  - dequantize()
  - q_scale()
  - q_zero_point()

Reviewed By: dzhulgakov

Differential Revision: D14524641

fbshipit-source-id: c1c0ae0978fb500d47cdb23fb15b747773429e6c

37 files changed:
aten/src/ATen/CMakeLists.txt
aten/src/ATen/DLConvertor.cpp
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/function_wrapper.py
aten/src/ATen/gen.py
aten/src/ATen/native/TensorFactories.cpp
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native/quantized/QTensor.cpp [new file with mode: 0644]
aten/src/ATen/native_parse.py
aten/src/ATen/preprocess_declarations.py
aten/src/ATen/quantized/CMakeLists.txt [new file with mode: 0644]
aten/src/ATen/quantized/QTensorImpl.cpp [new file with mode: 0644]
aten/src/ATen/quantized/QTensorImpl.h [new file with mode: 0644]
aten/src/ATen/quantized/Quantizer.cpp [new file with mode: 0644]
aten/src/ATen/quantized/Quantizer.h [new file with mode: 0644]
aten/src/ATen/templates/TensorMethods.h
aten/src/ATen/templates/Type.h
aten/src/ATen/test/CMakeLists.txt
aten/src/ATen/test/quantized_test.cpp [new file with mode: 0644]
aten/src/TH/THBlasUtils.h
c10/core/QScheme.h [new file with mode: 0644]
c10/core/Scalar.h
c10/core/ScalarType.h
c10/util/Half.cpp
c10/util/qint8.h [new file with mode: 0644]
c10/util/typeid.cpp
c10/util/typeid.h
caffe2/contrib/aten/aten_op_template.h
caffe2/operators/experimental/c10/cpu/cast_cpu.cc
docs/source/tensors.rst
test/test_torch.py
tools/autograd/gen_python_functions.py
tools/autograd/templates/variable_factories.h
torch/_tensor_docs.py
torch/csrc/utils/tensor_dtypes.cpp

index a1316e1..98af97d 100644 (file)
@@ -56,6 +56,7 @@ FILE(GLOB native_cpp "native/*.cpp")
 FILE(GLOB native_mkl_cpp "native/mkl/*.cpp")
 FILE(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
 FILE(GLOB native_sparse_cpp "native/sparse/*.cpp")
+FILE(GLOB native_quantized_cpp "native/quantized/*.cpp")
 
 FILE(GLOB native_cuda_cu "native/cuda/*.cu")
 FILE(GLOB native_cuda_cpp "native/cuda/*.cpp")
@@ -70,7 +71,8 @@ FILE(GLOB native_cudnn_hip_cpp "native/cudnn/hip/*.cpp")
 FILE(GLOB native_sparse_hip_hip "native/sparse/hip/*.hip")
 FILE(GLOB native_sparse_hip_cpp "native/sparse/hip/*.cpp")
 
-set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${generated_cpp} ${ATen_CPU_SRCS} ${cpu_kernel_cpp})
+add_subdirectory(quantized)
+set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${cpu_kernel_cpp})
 if(AT_MKL_ENABLED)
   set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
 endif()
@@ -435,10 +437,12 @@ set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE)
 set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE)
 set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} PARENT_SCOPE)
 set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE)
+set(ATen_QUANTIZED_SRCS ${ATen_QUANTIZED_SRCS} PARENT_SCOPE)
 set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE)
 set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE)
 set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE)
 set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE)
+set(ATen_QUANTIZED_TEST_SRCS ${ATen_QUANTIZED_TEST_SRCS} PARENT_SCOPE)
 set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE)
 set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE)
 set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE)
index de0b1b0..e2983eb 100644 (file)
@@ -39,6 +39,9 @@ static DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::Bool:
       dtype.code = DLDataTypeCode::kDLUInt;
       break;
+    case ScalarType::QInt8:
+      throw std::logic_error("QInt8 is not supported by dlpack");
+      break;
     case ScalarType::ComplexHalf:
       throw std::logic_error("ComplexHalf is not supported by dlpack");
     case ScalarType::ComplexFloat:
index add127c..8dbfb81 100644 (file)
@@ -566,6 +566,10 @@ class CAFFE2_API Tensor {
   std::vector<Tensor> unbind(int64_t dim=0) const;
   Tensor to_sparse(int64_t sparse_dim) const;
   Tensor to_sparse() const;
+  Tensor quantize_linear(double scale, int64_t zero_point) const;
+  Tensor dequantize() const;
+  Scalar q_scale() const;
+  Scalar q_zero_point() const;
   Tensor to(const TensorOptions & options, bool non_blocking=false, bool copy=false) const;
   Tensor to(Device device, ScalarType dtype, bool non_blocking=false, bool copy=false) const;
   Tensor to(ScalarType dtype, bool non_blocking=false, bool copy=false) const;
index 1605aff..9e50931 100644 (file)
@@ -769,6 +769,18 @@ inline Tensor Tensor::to_sparse(int64_t sparse_dim) const {
 inline Tensor Tensor::to_sparse() const {
     return type().to_sparse(*this);
 }
+inline Tensor Tensor::quantize_linear(double scale, int64_t zero_point) const {
+    return type().quantize_linear(*this, scale, zero_point);
+}
+inline Tensor Tensor::dequantize() const {
+    return type().dequantize(*this);
+}
+inline Scalar Tensor::q_scale() const {
+    return type().q_scale(*this);
+}
+inline Scalar Tensor::q_zero_point() const {
+    return type().q_zero_point(*this);
+}
 inline Tensor Tensor::to(const TensorOptions & options, bool non_blocking, bool copy) const {
     return type().to(*this, options, non_blocking, copy);
 }
@@ -1347,7 +1359,10 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST)
     return item().to##name();     \
   }
 
-AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ITEM)
+AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(DEFINE_ITEM)
 #undef DEFINE_ITEM
 
+// TODO: after is_quantized() is implemented,
+// implement item() (returnning a float) for quantized Tensor
+
 } //namespace at
index e50af83..9a1d6fc 100644 (file)
@@ -40,6 +40,7 @@ using TensorList = ArrayRef<Tensor>;
 
 class Context;
 struct Generator;
+struct Quantizer;
 
 static inline void noop_deleter(void*) {}
 
@@ -53,6 +54,7 @@ enum class TypeID {
   CPULong,
   CPUShort,
   CPUHalf,
+  CPUQInt8,
   SparseCPUBool,
   SparseCPUByte,
   SparseCPUChar,
@@ -61,6 +63,7 @@ enum class TypeID {
   SparseCPUInt,
   SparseCPULong,
   SparseCPUShort,
+  SparseCPUQInt8,
   CUDABool,
   CUDAByte,
   CUDAChar,
@@ -70,6 +73,7 @@ enum class TypeID {
   CUDALong,
   CUDAShort,
   CUDAHalf,
+  CUDAQInt8,
   SparseCUDABool,
   SparseCUDAByte,
   SparseCUDAChar,
@@ -78,6 +82,7 @@ enum class TypeID {
   SparseCUDAInt,
   SparseCUDALong,
   SparseCUDAShort,
+  SparseCUDAQInt8,
   MSNPUBool,
   MSNPUByte,
   MSNPUChar,
@@ -87,6 +92,7 @@ enum class TypeID {
   MSNPULong,
   MSNPUShort,
   MSNPUHalf,
+  MSNPUQInt8,
   XLABool,
   XLAByte,
   XLAChar,
@@ -96,6 +102,7 @@ enum class TypeID {
   XLALong,
   XLAShort,
   XLAHalf,
+  XLAQInt8,
   CPUComplexFloat,
   CPUComplexDouble,
   CUDAComplexFloat,
@@ -444,6 +451,10 @@ struct CAFFE2_API Type {
   virtual std::vector<Tensor> unbind(const Tensor & self, int64_t dim) const = 0;
   virtual Tensor to_sparse(const Tensor & self, int64_t sparse_dim) const = 0;
   virtual Tensor to_sparse(const Tensor & self) const = 0;
+  virtual Tensor quantize_linear(const Tensor & self, double scale, int64_t zero_point) const = 0;
+  virtual Tensor dequantize(const Tensor & self) const = 0;
+  virtual Scalar q_scale(const Tensor & self) const = 0;
+  virtual Scalar q_zero_point(const Tensor & self) const = 0;
   virtual Tensor to(const Tensor & self, const TensorOptions & options, bool non_blocking, bool copy) const = 0;
   virtual Tensor to(const Tensor & self, Device device, ScalarType dtype, bool non_blocking, bool copy) const = 0;
   virtual Tensor to(const Tensor & self, ScalarType dtype, bool non_blocking, bool copy) const = 0;
index 3694803..f50537b 100644 (file)
@@ -1564,6 +1564,9 @@ def create_derived(backend_type_env, declarations):
         # type: (FunctionOption) -> None
         pair = (backend_type_env['Backend'],
                 backend_type_env['ScalarName'])
+        # Skip generating TH code for QInt8
+        if pair[1] == 'QInt8':
+            return
         if pair in option['backend_type_pairs']:
             env = nested_dict(option, backend_type_env)
             body = emit_body(env, option)  # type: ignore
index 356e43f..9cb866b 100644 (file)
@@ -189,6 +189,7 @@ scalar_types = [
     ('Long', 'int64_t', 'Long', False),
     ('Short', 'int16_t', 'Long', False),
     ('Half', 'Half', 'Double', True),
+    ('QInt8', 'qint8', 'Long', False),
 ]
 
 # shared environment for non-derived base classes Type.h Tensor.h Storage.h
index 1cdcb42..3fc0ee2 100644 (file)
@@ -148,7 +148,7 @@ Tensor& empty_out(Tensor& result, IntArrayRef size) {
     return target_type.copy(self, non_blocking);                 \
   }
 
-AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CAST_OP)
+AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT(DEFINE_CAST_OP)
 
 #undef DEFINE_CAST_OP
 
@@ -732,7 +732,7 @@ Tensor tensor_cuda(ArrayRef<T> values, const TensorOptions& options) {
       return tensor_cpu(values, options);                           \
     }                                                               \
   }
-AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(TENSOR)
+AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(TENSOR)
 #undef TENSOR
 } // namespace native
 } // namespace at
index e1cd053..a72d9c4 100644 (file)
     CPU: dense_to_sparse
     CUDA: dense_to_sparse
 
+- func: quantize_linear(Tensor self, float scale, int zero_point) -> Tensor
+  matches_jit_signature: True
+  variants: function, method
+  requires_tensor: True
+
+- func: dequantize(Tensor self) -> Tensor
+  matches_jit_signature: True
+  variants: function, method
+  requires_tensor: True
+
+- func: q_scale(Tensor self) -> Scalar
+  matches_jit_signature: True
+  variants: function, method
+  requires_tensor: True
+
+- func: q_zero_point(Tensor self) -> Scalar
+  matches_jit_signature: True
+  variants: function, method
+  requires_tensor: True
+
 # to(Device) must not exist because all constructors of Device also works for
 # TensorOptions. Otherwise, an ambiguity error is thrown.
 # See NOTE [ TensorOptions Constructors ].
diff --git a/aten/src/ATen/native/quantized/QTensor.cpp b/aten/src/ATen/native/quantized/QTensor.cpp
new file mode 100644 (file)
index 0000000..ef8e78b
--- /dev/null
@@ -0,0 +1,36 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/quantized/Quantizer.h>
+#include <ATen/quantized/QTensorImpl.h>
+
+
+namespace at {
+namespace native {
+
+QTensor quantize_linear(const RealTensor& self, double scale, int64_t zero_point) {
+  auto quantizer = make_per_tensor_affine_quantizer(scale, zero_point);
+  return quantizer->quantize(self);
+}
+
+RealTensor dequantize(const QTensor& self) {
+  return get_qtensorimpl(self)->quantizer()->dequantize(self);
+}
+
+Scalar q_scale(const QTensor& self) {
+  auto quantizer = get_qtensorimpl(self)->quantizer();
+  AT_ASSERT(quantizer->qscheme() == kPerTensorAffine);
+  return Scalar(static_cast<PerTensorAffineQuantizer*>(quantizer.get())->scale());
+}
+
+Scalar q_zero_point(const QTensor& self) {
+  auto quantizer = get_qtensorimpl(self)->quantizer();
+  AT_ASSERT(quantizer->qscheme() == kPerTensorAffine);
+  return Scalar(static_cast<PerTensorAffineQuantizer*>(quantizer.get())->zero_point());
+}
+
+Quantizer* quantizer(const QTensor& self) {
+  return get_qtensorimpl(self)->quantizer().get();
+}
+
+} // namespace native
+} // namespace at
index 6217769..fa946b7 100644 (file)
@@ -343,7 +343,6 @@ def has_sparse_dispatches(dispatches):
             return True
     return False
 
-
 def parse_native_yaml(path):
     with open(path, 'r') as f:
         return yaml.load(f, Loader=Loader)
index 59dcf6b..5dac87d 100644 (file)
@@ -15,7 +15,8 @@ type_map = {
         'Short',
         'Int',
         'Long',
-        'Bool'
+        'Bool',
+        'QInt8',
     ],
 }
 
diff --git a/aten/src/ATen/quantized/CMakeLists.txt b/aten/src/ATen/quantized/CMakeLists.txt
new file mode 100644 (file)
index 0000000..1620919
--- /dev/null
@@ -0,0 +1,9 @@
+FILE(GLOB_RECURSE ATen_QUANTIZED_HEADERS "*.h")
+FILE(GLOB_RECURSE ATen_QUANTIZED_SRCS "*.cpp")
+FILE(GLOB_RECURSE ATen_QUANTIZED_TEST_SRCS "*_test.cpp")
+EXCLUDE(ATen_QUANTIZED_SRCS "${ATen_QUANTIZED_SRCS}" ${ATen_QUANTIZED_TEST_SRCS})
+
+# Pass to parent
+set(ATen_QUANTIZED_HEADERS ${ATen_QUANTIZED_HEADERS} PARENT_SCOPE)
+set(ATen_QUANTIZED_SRCS ${ATen_QUANTIZED_SRCS} PARENT_SCOPE)
+set(ATen_QUANTIZED_TEST_SRCS ${ATen_QUANTIZED_TEST_SRCS} PARENT_SCOPE)
diff --git a/aten/src/ATen/quantized/QTensorImpl.cpp b/aten/src/ATen/quantized/QTensorImpl.cpp
new file mode 100644 (file)
index 0000000..83f585e
--- /dev/null
@@ -0,0 +1,13 @@
+#include <ATen/quantized/QTensorImpl.h>
+
+namespace at {
+
+QTensorImpl::QTensorImpl(
+    Storage&& storage,
+    TensorTypeId type_id,
+    bool is_variable,
+    QuantizerPtr quantizer)
+    : TensorImpl(std::move(storage), type_id, is_variable),
+      quantizer_(quantizer) {}
+
+} // namespace at
diff --git a/aten/src/ATen/quantized/QTensorImpl.h b/aten/src/ATen/quantized/QTensorImpl.h
new file mode 100644 (file)
index 0000000..9f01246
--- /dev/null
@@ -0,0 +1,38 @@
+#pragma once
+
+#include <ATen/quantized/Quantizer.h>
+#include <c10/core/TensorImpl.h>
+#include <c10/util/Exception.h>
+
+namespace at {
+
+struct CAFFE2_API QTensorImpl : public c10::TensorImpl {
+ public:
+  QTensorImpl(
+      Storage&& storage,
+      TensorTypeId type_id,
+      bool is_variable,
+      QuantizerPtr quantizer);
+
+  // TODO: Expose in PyTorch Frontend
+  QuantizerPtr quantizer() {
+    return quantizer_;
+  }
+
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach() const override {
+    auto impl = c10::make_intrusive<QTensorImpl>(
+        Storage(storage()), type_id(), is_variable(), quantizer_);
+    impl->set_sizes_and_strides(sizes(), strides());
+    impl->storage_offset_ = storage_offset_;
+    impl->is_wrapped_number_ = is_wrapped_number_;
+    impl->reserved_ = reserved_;
+    impl->refresh_numel();
+    impl->refresh_contiguous();
+    return impl;
+  }
+
+ private:
+  QuantizerPtr quantizer_;
+};
+
+} // namespace at
diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp
new file mode 100644 (file)
index 0000000..eabc697
--- /dev/null
@@ -0,0 +1,106 @@
+#include <ATen/quantized/Quantizer.h>
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Type.h>
+#include <ATen/native/TensorFactories.h>
+#include <ATen/quantized/QTensorImpl.h>
+
+namespace at {
+
+QuantizerPtr make_per_tensor_affine_quantizer(
+    double scale,
+    int64_t zero_point) {
+  return c10::make_intrusive<PerTensorAffineQuantizer>(
+      static_cast<float>(scale), static_cast<uint8_t>(zero_point));
+}
+
+QTensorImpl* get_qtensorimpl(const QTensor& self) {
+  // TODO: remove this when Variable and Tensor are merged
+  AT_ASSERTM(
+      !self.is_variable(),
+      "_internal_get_QTensorImpl: should not be a variable");
+  // TODO: uncomment after is_quantized() is implemented
+  // AT_ASSERTM(self.is_quantized(), "_internal_get_QTensorImpl: not a quantized
+  // tensor");
+  return static_cast<QTensorImpl*>(self.unsafeGetTensorImpl());
+}
+
+inline QTensor new_qtensor(
+    IntArrayRef sizes,
+    const TensorOptions& options,
+    bool is_variable,
+    QuantizerPtr quantizer) {
+  AT_ASSERT(options.device().is_cpu());
+
+  native::check_size_nonnegative(sizes);
+  auto* allocator = at::getCPUAllocator();
+  int64_t nelements = at::prod_intlist(sizes);
+  auto dtype = options.dtype();
+  AT_ASSERT(isQIntType(typeMetaToScalarType(dtype)));
+  auto storage_impl = c10::make_intrusive<StorageImpl>(
+      dtype,
+      nelements,
+      allocator->allocate(nelements * dtype.itemsize()),
+      allocator,
+      /*resizable=*/true);
+  auto tensor = detail::make_tensor<QTensorImpl>(
+      storage_impl, at::CPUTensorId(), is_variable, quantizer);
+  get_qtensorimpl(tensor)->set_sizes_contiguous(sizes);
+  return tensor;
+}
+
+qint8 quantize_uint8(float scale, uint8_t zero_point, float value) {
+  const int32_t qmin = std::numeric_limits<uint8_t>::min();
+  const int32_t qmax = std::numeric_limits<uint8_t>::max();
+
+  // std::nearbyint results in nearest integer value according to the current
+  // rounding mode and the default rounding mode is rounds to even in half-way
+  // cases in most popular processor architectures like x86 and ARM. This is
+  // typically faster than an alternatives like std::round that rounds half-way
+  // cases away from zero, and can be consistent with SIMD implementations for
+  // example in x86 using _mm512_cvtps_epi32 or mm512_round_ps with
+  // _MM_FROUND_CUR_DIRECTION option that also follow the current rounding mode.
+  int32_t r = zero_point + static_cast<int32_t>(std::nearbyint(value / scale));
+  r = std::max(r, qmin);
+  r = std::min(r, qmax);
+  return static_cast<qint8>(r);
+}
+
+QTensor PerTensorAffineQuantizer::quantize(RealTensor tensor) {
+  IntArrayRef sizes = tensor.sizes();
+  // Here we need a std::intrusive_ptr<Quantizer>.. but actually "this" is the
+  // quantizer that can be reused, so I'm using intrusive_from_this here
+  AT_CHECK(
+      tensor.options().device() == kCPU,
+      "quantize only works for CPU backend right now.");
+  QTensor qv = new_qtensor(
+      sizes,
+      tensor.options().dtype(at::kQInt8),
+      tensor.is_variable(),
+      intrusive_from_this());
+  auto qvd = qv.data<qint8>();
+  tensor.contiguous();
+  const float* svd = tensor.data<float>();
+  for (int i = 0; i < tensor.numel(); ++i) {
+    qvd[i] = quantize_uint8(scale_, zero_point_, svd[i]);
+  }
+  return qv;
+}
+
+RealTensor PerTensorAffineQuantizer::dequantize(QTensor tensor) {
+  std::vector<int64_t> sizes = tensor.sizes().vec();
+  at::TensorOptions real_options = tensor.options().dtype(at::kFloat);
+
+  RealTensor rv = at::empty(sizes, real_options);
+  tensor.contiguous();
+  const auto* qvd = tensor.data<qint8>();
+  float* rvd = rv.data<float>();
+  for (auto i = 0; i < tensor.numel(); ++i) {
+    rvd[i] = (static_cast<uint32_t>(qvd[i].val_) - zero_point_) * scale_;
+  }
+  return rv;
+}
+
+Quantizer::~Quantizer() {}
+
+} // namespace at
diff --git a/aten/src/ATen/quantized/Quantizer.h b/aten/src/ATen/quantized/Quantizer.h
new file mode 100644 (file)
index 0000000..931d489
--- /dev/null
@@ -0,0 +1,251 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+#include <c10/core/QScheme.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/Exception.h>
+
+#include <cmath>
+#include <memory>
+
+// TODO: move to c10 namespace after we
+// unified caffe2::Tensor and at::Tensor
+
+namespace at {
+
+struct QTensorImpl;
+
+using QTensor = Tensor;
+using RealTensor = Tensor;
+
+struct Quantizer;
+using QuantizerPtr = c10::intrusive_ptr<Quantizer>;
+
+/**
+ * Quantizer is the class for storing all the information
+ * that's necessary to perform quantize and dequantize
+ * operation.
+ *
+ * We might have different types of quantization schemes and this is
+ * the base class for all quantizers.
+ *
+ * QTensorImpl will hold a pointer to Quantizer so that we can support
+ * different quantization schemes on Tensor.
+ *
+ * For example, the most common quantization scheme, Affine Quantization,
+ * requires scale and zero_point as parameters, we'll store scale and zero_point
+ * inside the instance and we can use it to quantize a float Tensor or
+ * dequantize a quantized Tensor.
+ *
+ * When you add new types of leaf Quantizer class, please also
+ * make sure to add a corresponding QScheme enum since
+ * they should have one to one mapping.
+ *
+ * Note about intrusive_ptr:
+ * QTensor holds an intrusive_ptr to Quantizer, and multiple Tensor can
+ * share the same Quantizer. Quantizer should be immutable.
+ */
+struct CAFFE2_API Quantizer : public c10::intrusive_ptr_target {
+  const QScheme qscheme_;
+  explicit Quantizer(QScheme qscheme) : qscheme_(qscheme) {}
+  virtual ~Quantizer();
+
+  // Copied from torch/csrc/jit/scope.h
+  QuantizerPtr intrusive_from_this() {
+    c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
+                                           // from a raw `this` pointer
+                                           // so we need to bump the refcount
+                                           // to account for this ownership
+    return c10::intrusive_ptr<Quantizer>::reclaim(this);
+  }
+
+  virtual QScheme qscheme() {
+    return qscheme_;
+  }
+
+  /**
+   * quantize a float Tensor into a quantized Tensor.
+   */
+  virtual QTensor quantize(RealTensor t) = 0;
+
+  /**
+   * dequantize a quantized Tensor into a float Tensor.
+   */
+  virtual RealTensor dequantize(QTensor t) = 0;
+};
+
+/**
+ * UniformQuantizer is the parent class for all uniform quantizers.
+ * These quantization scheme will map float value uniformly to
+ * the quantized value. For example, affine quantizer is
+ * the most commonly used scheme in this category.
+ */
+struct CAFFE2_API UniformQuantizer : public Quantizer {
+  explicit UniformQuantizer(QScheme qscheme) : Quantizer(qscheme) {}
+};
+
+/**
+ * NonUniformQuantizer is the parent class for all non-uniform quantizers.
+ * These quantization scheme may map float value non-uniformly to the quantized
+ * value. K-means quantization is a representative example in this category.
+ */
+struct CAFFE2_API NonUniformQuantizer : public Quantizer {
+  explicit NonUniformQuantizer(QScheme qscheme) : Quantizer(qscheme) {}
+};
+
+// There is also StochasticQuantizer which is uniform but not affine
+
+/**
+ * AffineQuantizer uses affine transformation to do quantization.
+ *
+ * For quantize:
+ * Y = clamp((X * scale + zero_point, min, max)
+ * For dequantize:
+ * X = (Y - zero_point) / scale
+ */
+struct CAFFE2_API AffineQuantizer : public UniformQuantizer {
+  explicit AffineQuantizer(QScheme qscheme) : UniformQuantizer(qscheme) {}
+};
+
+/**
+ * SymmetricQuantizer is similar to AffineQuantizer except that it
+ * does not have zero_point
+ *
+ * For quantize:
+ * Y = clamp(X * scale, min, max)
+ * For dequantize:
+ * X = Y / scale
+ */
+struct CAFFE2_API SymmetricQuantizer : public UniformQuantizer {
+  explicit SymmetricQuantizer(QScheme qscheme) : UniformQuantizer(qscheme) {}
+};
+
+/**
+ * PerTensorSymmetricQuantizer stores a single scale number which is
+ * used for quantizing all the values in the given Tensor
+ */
+struct CAFFE2_API PerTensorSymmetricQuantizer : public SymmetricQuantizer {
+  explicit PerTensorSymmetricQuantizer(float scale)
+      : SymmetricQuantizer(kPerTensorSymmetric), scale_(scale) {}
+  float scale_{1.0};
+};
+
+/**
+ * PerChannelSymmetricQuantizer stores a vector of scale number and
+ * applys symmetric quantization using different scales on each channel.
+ *
+ * Also note that per channel quantization is mostly applied to output channels
+ * of weights since per-input channel of weight quantization or per-channel
+ * quantization for activations can't be efficiently supported in most of
+ * processors since it requires each multiplication result within a single
+ * dot-product to have a different scale.
+ */
+struct CAFFE2_API PerChannelSymmetricQuantizer : public SymmetricQuantizer {
+  explicit PerChannelSymmetricQuantizer(
+      const std::vector<float>& scales,
+      const std::vector<int64_t>& axis)
+      : SymmetricQuantizer(kPerChannelSymmetric), scales_(scales), axis_(axis) {
+    AT_CHECK(
+        axis_.size() == 1,
+        "Per channel symmetric quantization in multiple axis is not supported yet.");
+  }
+
+  std::vector<float> scales() const {
+    return scales_;
+  }
+
+  std::vector<int64_t> axis() const {
+    return axis_;
+  }
+
+ private:
+  const std::vector<float> scales_;
+  const std::vector<int64_t> axis_;
+};
+
+/**
+ * PerTensorAffineQuantizer stores a scale and a zero_point, which is used for
+ * all the values in the Tensor.
+ */
+struct CAFFE2_API PerTensorAffineQuantizer : public AffineQuantizer {
+  explicit PerTensorAffineQuantizer(float scale, uint8_t zero_point)
+      : AffineQuantizer(kPerTensorAffine),
+        scale_(scale),
+        zero_point_(zero_point) {}
+
+  QTensor quantize(RealTensor tensor) override;
+  RealTensor dequantize(QTensor tensor) override;
+
+  float scale() const {
+    return scale_;
+  }
+
+  uint8_t zero_point() const {
+    return zero_point_;
+  }
+
+ private:
+  const float scale_;
+  const uint8_t zero_point_;
+};
+
+/**
+ * PerChannelAffineQuantizer is the same as PerTensorAffineQuantizer
+ * except that we have an independent scale and zero_point parameter
+ * for each channel.
+ */
+struct CAFFE2_API PerChannelAffineQuantizer : public AffineQuantizer {
+  explicit PerChannelAffineQuantizer(
+      const std::vector<float>& scales,
+      const std::vector<uint8_t>& zero_points,
+      const std::vector<int64_t>& axis)
+      : AffineQuantizer(kPerChannelAffine),
+        scales_(scales),
+        zero_points_(zero_points),
+        axis_(axis) {
+    AT_CHECK(
+        axis_.size() == 1,
+        "Per channel affine quantization in multiple axis is not supported yet.");
+  }
+
+  std::vector<float> scales() const {
+    return scales_;
+  }
+
+  std::vector<uint8_t> zero_points() const {
+    return zero_points_;
+  }
+
+  std::vector<int64_t> axis() const {
+    return axis_;
+  }
+
+ private:
+  const std::vector<float> scales_;
+  const std::vector<uint8_t> zero_points_;
+  const std::vector<int64_t> axis_;
+};
+
+// This is an internal utility function for getting at the QTensorImpl,
+// You should only use this for writing low level
+// setters/getters for QTensorImpl fields; otherwise, you should use
+// the low level setters/getters that were implemented using this.
+// This may be called repeatedly, so make sure it's pretty cheap.
+CAFFE2_API QTensorImpl* get_qtensorimpl(const QTensor& self);
+
+// Quantize a float value into a uint8 value given scale and zero_point
+CAFFE2_API qint8 quantize_uint8(float scale, uint8_t zero_point, float value);
+
+// double and int64_t are because of the native function API, we only have these
+// argument types right now in native functions
+CAFFE2_API QuantizerPtr
+make_per_tensor_affine_quantizer(double scale, int64_t zero_point);
+
+// Create a QTensor given arguments for normal Tensor and a quantizer
+QTensor new_qtensor(
+    IntArrayRef sizes,
+    const TensorOptions& options,
+    bool is_variable,
+    QuantizerPtr quantizer);
+
+} // namespace at
index 5dea983..bbaa19b 100644 (file)
@@ -133,7 +133,10 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST)
     return item().to##name();     \
   }
 
-AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ITEM)
+AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(DEFINE_ITEM)
 #undef DEFINE_ITEM
 
+// TODO: after is_quantized() is implemented,
+// implement item() (returnning a float) for quantized Tensor
+
 } //namespace at
index 52f9eb2..d1d2a28 100644 (file)
@@ -40,6 +40,7 @@ using TensorList = ArrayRef<Tensor>;
 
 class Context;
 struct Generator;
+struct Quantizer;
 
 static inline void noop_deleter(void*) {}
 
index 4e8b955..bcee7ef 100644 (file)
@@ -21,6 +21,7 @@ list(APPEND ATen_CPU_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp)
 
diff --git a/aten/src/ATen/test/quantized_test.cpp b/aten/src/ATen/test/quantized_test.cpp
new file mode 100644 (file)
index 0000000..6561448
--- /dev/null
@@ -0,0 +1,43 @@
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+#include <ATen/test/test_assert.h>
+#include <cmath>
+#include <iostream>
+#include <limits>
+#include <sstream>
+#include <type_traits>
+// For quantize_uint8
+#include <ATen/quantized/Quantizer.h>
+
+using namespace at;
+
+TEST(TestQTensor, QuantDequantAPIs) {
+  auto num_elements = 10;
+  Tensor r = at::ones({num_elements});
+  const float scale = 1.0;
+  const int32_t zero_point = 2;
+  Tensor qr = r.quantize_linear(scale, zero_point);
+  ASSERT_EQ(qr.q_scale().to<float>(), scale);
+  ASSERT_EQ(qr.q_zero_point().to<int32_t>(), zero_point);
+
+  // TODO: Uncomment when quantizer is ready
+  // auto* quantizer = static_cast<PerTensorAffineQuantizer*>(qr.quantizer());
+  // ASSERT_EQ(quantizer->scale(), scale);
+  // ASSERT_EQ(quantizer->zero_point(), zero_point);
+
+  // Check for correct quantization
+  auto r_data = r.data<float>();
+  auto qr_data = qr.data<qint8>();
+  for (auto i = 0; i < num_elements; ++i) {
+    ASSERT_EQ(
+        quantize_uint8(scale, zero_point, r_data[i]).val_, qr_data[i].val_);
+  }
+
+  // Check for correct dequantization
+  Tensor rqr = qr.dequantize();
+  auto rqr_data = rqr.data<float>();
+  for (auto i = 0; i < num_elements; ++i) {
+    ASSERT_EQ(r_data[i], rqr_data[i]);
+  }
+}
index de9ddea..86efc40 100644 (file)
@@ -16,7 +16,7 @@ inline void THBlas_axpy(int64_t n, T a, T *x, int64_t incx, T *y, int64_t incy);
     TH ## name ## Blas_axpy(n, a, x, incx, y, incy); \
   }
 
-AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(AXPY_SPECIALIZATION)
+AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(AXPY_SPECIALIZATION)
 
 
 template<typename T>
@@ -29,4 +29,4 @@ inline void THBlas_copy(int64_t n, T *x, int64_t incx, T *y, int64_t incy);
     TH ## name ## Blas_copy(n, x, incx, y, incy); \
   }
 
-AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(COPY_SPECIALIZATION)
+AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(COPY_SPECIALIZATION)
diff --git a/c10/core/QScheme.h b/c10/core/QScheme.h
new file mode 100644 (file)
index 0000000..4757a77
--- /dev/null
@@ -0,0 +1,24 @@
+#pragma once
+
+namespace c10 {
+
+/**
+ * QScheme is an enum that specifies the type of quantization. This has a one
+ * to one correspondence with Quantizer
+ * Please refer to ATen/core/Quantizer.h to see the Quantizers classes.
+ */
+enum class QScheme : uint8_t {
+  NO_QUANT,
+  PER_TENSOR_AFFINE,
+  PER_CHANNEL_AFFINE,
+  PER_TENSOR_SYMMETRIC,
+  PER_CHANNEL_SYMMETRIC
+};
+
+constexpr auto kNoQuant = QScheme::NO_QUANT;
+constexpr auto kPerTensorAffine = QScheme::PER_TENSOR_AFFINE;
+constexpr auto kPerChannelAffine = QScheme::PER_CHANNEL_AFFINE;
+constexpr auto kPerTensorSymmetric = QScheme::PER_TENSOR_SYMMETRIC;
+constexpr auto kPerChannelSymmetric = QScheme::PER_CHANNEL_SYMMETRIC;
+
+} // namespace c10
index 7952e86..23cd5a0 100644 (file)
@@ -31,7 +31,7 @@ class C10_API Scalar {
   // We can't set v in the initializer list using the
   // syntax v{ .member = ... } because it doesn't work on MSVC
 
-  AT_FORALL_SCALAR_TYPES(DEFINE_IMPLICIT_CTOR)
+  AT_FORALL_SCALAR_TYPES_EXCEPT_QINT(DEFINE_IMPLICIT_CTOR)
 
 #undef DEFINE_IMPLICIT_CTOR
 
@@ -71,7 +71,8 @@ class C10_API Scalar {
   }
 
   // TODO: Support ComplexHalf accessor
-  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR)
+  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(
+      DEFINE_ACCESSOR)
 
   // also support scalar.to<int64_t>();
   template <typename T>
@@ -114,6 +115,6 @@ inline T Scalar::to() {
   inline T Scalar::to<T>() {  \
     return to##name();        \
   }
-AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO)
+AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(DEFINE_TO)
 #undef DEFINE_TO
 } // namespace c10
index 11549e1..99e7975 100644 (file)
@@ -11,6 +11,9 @@
 
 namespace c10 {
 
+// TODO: check all usages of these macro and make sure
+// the use case makes sense for qint
+
 // NB: Order matters for this macro; it is relied upon in
 // _promoteTypesLookup and the serialization format.
 #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_)       \
@@ -25,7 +28,8 @@ namespace c10 {
   _(at::ComplexHalf, ComplexHalf, z) /* 8 */         \
   _(std::complex<float>, ComplexFloat, z) /* 9 */    \
   _(std::complex<double>, ComplexDouble, z) /* 10 */ \
-  _(bool, Bool, i) /* 11 */
+  _(bool, Bool, i) /* 11 */                          \
+  _(c10::qint8, QInt8, i) /* 12 */
 
 // If you want to support ComplexHalf for real, replace occurrences
 // of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX.  But
@@ -41,6 +45,20 @@ namespace c10 {
   _(double, Double, d)                                             \
   _(std::complex<float>, ComplexFloat, z)                          \
   _(std::complex<double>, ComplexDouble, z)                        \
+  _(bool, Bool, i)                                                 \
+  _(c10::qint8, QInt8, i)
+
+#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(_) \
+  _(uint8_t, Byte, i)                                                       \
+  _(int8_t, Char, i)                                                        \
+  _(int16_t, Short, i)                                                      \
+  _(int, Int, i)                                                            \
+  _(int64_t, Long, i)                                                       \
+  _(at::Half, Half, d)                                                      \
+  _(float, Float, d)                                                        \
+  _(double, Double, d)                                                      \
+  _(std::complex<float>, ComplexFloat, z)                                   \
+  _(std::complex<double>, ComplexDouble, z)                                 \
   _(bool, Bool, i)
 
 #define AT_FORALL_SCALAR_TYPES(_) \
@@ -51,17 +69,28 @@ namespace c10 {
   _(int64_t, Long, i)             \
   _(at::Half, Half, d)            \
   _(float, Float, d)              \
+  _(double, Double, d)            \
+  _(c10::qint8, QInt8, i)
+
+#define AT_FORALL_SCALAR_TYPES_EXCEPT_QINT(_) \
+  _(uint8_t, Byte, i)                         \
+  _(int8_t, Char, i)                          \
+  _(int16_t, Short, i)                        \
+  _(int, Int, i)                              \
+  _(int64_t, Long, i)                         \
+  _(at::Half, Half, d)                        \
+  _(float, Float, d)                          \
   _(double, Double, d)
 
-#define AT_FORALL_SCALAR_TYPES_AND_BOOL(_) \
-  _(uint8_t, Byte, i)                      \
-  _(int8_t, Char, i)                       \
-  _(int16_t, Short, i)                     \
-  _(int, Int, i)                           \
-  _(int64_t, Long, i)                      \
-  _(at::Half, Half, d)                     \
-  _(float, Float, d)                       \
-  _(double, Double, d)                     \
+#define AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT(_) \
+  _(uint8_t, Byte, i)                                  \
+  _(int8_t, Char, i)                                   \
+  _(int16_t, Short, i)                                 \
+  _(int, Int, i)                                       \
+  _(int64_t, Long, i)                                  \
+  _(at::Half, Half, d)                                 \
+  _(float, Float, d)                                   \
+  _(double, Double, d)                                 \
   _(bool, Bool, i)
 
 #define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \
@@ -71,6 +100,16 @@ namespace c10 {
   _(int, Int, i)                              \
   _(int64_t, Long, i)                         \
   _(float, Float, d)                          \
+  _(double, Double, d)                        \
+  _(c10::qint8, QInt8, i)
+
+#define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(_) \
+  _(uint8_t, Byte, i)                                  \
+  _(int8_t, Char, i)                                   \
+  _(int16_t, Short, i)                                 \
+  _(int, Int, i)                                       \
+  _(int64_t, Long, i)                                  \
+  _(float, Float, d)                                   \
   _(double, Double, d)
 
 enum class ScalarType : int8_t {
@@ -182,6 +221,11 @@ static inline bool isComplexType(ScalarType t) {
       t == ScalarType::ComplexDouble);
 }
 
+static inline bool isQIntType(ScalarType t) {
+  // Don't forget to extend this when adding new QInt types
+  return t == ScalarType::QInt8;
+}
+
 static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
   // This is generated according to NumPy's promote_types
   constexpr auto u1 = ScalarType::Byte;
@@ -202,6 +246,11 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
         "promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
   }
 
+  if (isQIntType(a) || isQIntType(b)) {
+    AT_ERROR(
+        "promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be");
+  }
+
   // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX
   // so that's why we have to add undefined as we are not sure what is the
   // corrent values for the type promotions in complex type cases.
index 76e36ea..24b7576 100644 (file)
@@ -1,5 +1,4 @@
 #include <c10/util/Half.h>
-
 #include <iostream>
 
 namespace c10 {
@@ -12,5 +11,4 @@ std::ostream& operator<<(std::ostream& out, const Half& value) {
   out << (float)value;
   return out;
 }
-
 } // namespace c10
diff --git a/c10/util/qint8.h b/c10/util/qint8.h
new file mode 100644 (file)
index 0000000..e6c3721
--- /dev/null
@@ -0,0 +1,16 @@
+#pragma once
+#include <cstdint>
+
+namespace c10 {
+
+/**
+ * This is the data type for quantized Tensors. Right now we only have
+ * qint8 which is for 8 bit Tensors, we might have 4 bit, 2 bit or 1 bit
+ * data types in the future.
+ */
+struct alignas(1) qint8 {
+  uint8_t val_;
+  explicit qint8(uint8_t val) : val_(val) {}
+};
+
+} // namespace c10
index 6253c83..ad62df5 100644 (file)
@@ -80,6 +80,8 @@ CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(
     26,
     detail::_guard_long_unique<std::vector<long>>);
 
-CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(27, _CaffeHighestPreallocatedTypeId)
+CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(27, c10::qint8);
+
+CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(28, _CaffeHighestPreallocatedTypeId)
 
 } // namespace caffe2
index 9cf9d92..8019714 100644 (file)
@@ -23,6 +23,7 @@
 #include "c10/util/Exception.h"
 #include "c10/util/Half.h"
 #include "c10/util/IdWrapper.h"
+#include "c10/util/qint8.h"
 
 #include "c10/util/Type.h"
 
@@ -622,5 +623,7 @@ CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(
     26,
     detail::_guard_long_unique<std::vector<long>>)
 
-CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(27, _CaffeHighestPreallocatedTypeId)
+CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(27, c10::qint8);
+
+CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(28, _CaffeHighestPreallocatedTypeId)
 } // namespace caffe2
index 9354cd7..b597084 100644 (file)
@@ -15,7 +15,7 @@ static std::unordered_map<std::string, int> op_to_key = {
 
 namespace caffe2 {
 
-using at::Half; // for AT_FORALL_SCALAR_TYPES_AND_BOOL
+using at::Half; // for AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT
 
 template <class Context>
 class ATenOp : public Operator<Context> {
@@ -47,7 +47,7 @@ private:
       case at::k##aten_name: \
         return TypeMeta::Make<ctype>();
     switch(st) {
-      AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CASE)
+      AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT(DEFINE_CASE)
       default:
         CAFFE_THROW("Unknown ATen Type");
     }
@@ -114,10 +114,10 @@ private:
     }
   }
 
-  // the AT_FORALL_SCALAR_TYPES_AND_BOOL macro just gives a 'i' or 'd' argument
-  // for each type to specify if it is stored as a integer or a double.
-  // We need this workaround here to extract the value in the scalar losslessly
-  // because in some cases like 'sum' Torch promotes float to double
+  // the AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT macro just gives a 'i' or
+  // 'd' argument for each type to specify if it is stored as a integer or a
+  // double. We need this workaround here to extract the value in the scalar
+  // losslessly because in some cases like 'sum' Torch promotes float to double
   // and will complain if we downcast it with toFloat, causing it
   // to lose precision
   double extract_d(const at::Scalar & s) {
@@ -134,8 +134,8 @@ private:
           auto value = extract_##native(scalar); \
           assignToValue<ctype>(dst, at::convert<ctype,decltype(value)>(value)); \
         } break;
-      AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CASE)
-      #undef DEFINE_CASE
+      AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT(DEFINE_CASE)
+#undef DEFINE_CASE
       default:
         CAFFE_THROW("Unknown ATen Type");
     }
index 1fa560e..3f0b2f3 100644 (file)
@@ -80,7 +80,7 @@ void cast_op_cpu(
     int64_t to) {
   switch (input.scalar_type()) {
 #define CASE(ctype,name,_2) case ScalarType:: name : return cast_op_cpu_impl<ctype>(input, output, to);
-    AT_FORALL_SCALAR_TYPES_AND_BOOL(CASE)
+    AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT(CASE)
 #undef CASE
     default: throw std::runtime_error(string() + "Unsupported scalar type " + toString(input.scalar_type()));
   }
index b9ad547..68d6997 100644 (file)
@@ -205,6 +205,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: cumprod
    .. automethod:: cumsum
    .. automethod:: data_ptr
+   .. automethod:: dequantize
    .. automethod:: det
    .. automethod:: dense_dim
    .. automethod:: detach
@@ -354,6 +355,9 @@ view of a storage and defines numeric operations on it.
    .. automethod:: pstrf
    .. automethod:: put_
    .. automethod:: qr
+   .. automethod:: quantize_linear
+   .. automethod:: q_scale
+   .. automethod:: q_zero_point
    .. automethod:: random_
    .. automethod:: reciprocal
    .. automethod:: reciprocal_
index ffc8943..ef21ea6 100644 (file)
@@ -2653,6 +2653,18 @@ class _TestTorchMixin(object):
                 if TEST_NUMPY:
                     assertEqual('cuda:1', lambda: torch.tensor(np.random.randn(2, 3), device='cuda:1'))
 
+    def test_qtensor(self):
+        num_elements = 10
+        r = torch.ones(num_elements, dtype=torch.float)
+        scale = 1.0
+        zero_point = 2
+        qr = r.quantize_linear(scale, zero_point)
+        self.assertEqual(qr.q_scale(), scale)
+        self.assertEqual(qr.q_zero_point(), zero_point)
+        rqr = qr.dequantize()
+        for i in range(num_elements):
+            self.assertEqual(r[i], rqr[i])
+
     def test_to(self):
         def test_copy_behavior(t, non_blocking=False):
             self.assertIs(t, t.to(t, non_blocking=non_blocking))
index 9f6ccf8..58fb3b2 100644 (file)
@@ -149,7 +149,7 @@ SUPPORTED_RETURN_TYPES = {
     'std::tuple<Tensor,Tensor,Tensor,int64_t>',
     'std::tuple<Tensor,Tensor,double,int64_t>',
     'std::vector<Tensor>',
-    'Scalar', 'bool', 'int64_t', 'void*', 'void'
+    'Scalar', 'bool', 'int64_t', 'void*', 'void',
 }
 
 TENSOR_OPTIONS = CodeTemplate("""\
index 7fbcea2..29b1e55 100644 (file)
@@ -4,6 +4,7 @@
 
 #include <ATen/ATen.h>
 #include <c10/util/ArrayRef.h>
+#include <c10/util/qint8.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/jit/tracer.h>
 
index 55ff438..1b79a44 100644 (file)
@@ -671,6 +671,13 @@ data_ptr() -> int
 Returns the address of the first element of :attr:`self` tensor.
 """)
 
+add_docstr_all('dequantize',
+               r"""
+dequantize() -> Tensor
+
+Given a quantized Tensor, dequantize it and return the dequantized float Tensor.
+""")
+
 add_docstr_all('dense_dim',
                r"""
 dense_dim() -> int
@@ -1775,6 +1782,31 @@ qr() -> (Tensor, Tensor)
 See :func:`torch.qr`
 """)
 
+add_docstr_all('quantize_linear',
+               r"""
+quantize_linear(scale, zero_point) -> Tensor
+
+Quantize a float Tensor using affine quantization scheme with given scale and
+zero_point.
+returns the quantized Tensor.
+""")
+
+add_docstr_all('q_scale',
+               r"""
+q_scale() -> float
+
+Given a Tensor quantized by linear(affine) quantization,
+returns the scale of the underlying quantizer().
+""")
+
+add_docstr_all('q_zero_point',
+               r"""
+q_zero_point() -> int
+
+Given a Tensor quantized by linear(affine) quantization,
+returns the zero_point of the underlying quantizer().
+""")
+
 add_docstr_all('random_',
                r"""
 random_(from=0, to=None, *, generator=None) -> Tensor
index b165551..51c8dc5 100644 (file)
@@ -40,6 +40,8 @@ static std::pair<std::string, std::string> getDtypeNames(
       return std::make_pair("complex128", "");
     case at::ScalarType::Bool:
       return std::make_pair("bool", "");
+    case at::ScalarType::QInt8:
+      return std::make_pair("qint8", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }