Add Quantized Backend (#18546)
authorJerry Zhang <jerryzh@fb.com>
Fri, 12 Apr 2019 19:47:39 +0000 (12:47 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 12 Apr 2019 19:55:49 +0000 (12:55 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18546

We'll expose all combinations of various ways of quantization in the top level dispatch key, that is we have AffineCPUTensor, PerChannelAffineCUDATensor, etc.

QTensor method added:
- is_quantized()
- item()

Differential Revision: D14637671

fbshipit-source-id: 346bc6ef404a570f0efd34e8793056ad3c7855f5

25 files changed:
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/function_wrapper.py
aten/src/ATen/gen.py
aten/src/ATen/native/Scalar.cpp
aten/src/ATen/native/TypeProperties.cpp
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native/quantized/QTensor.cpp
aten/src/ATen/preprocess_declarations.py
aten/src/ATen/quantized/Quantizer.cpp
aten/src/ATen/templates/SparseTypeDerived.cpp
aten/src/ATen/templates/Tensor.h
aten/src/ATen/templates/TensorMethods.h
aten/src/ATen/templates/Type.h
aten/src/ATen/templates/TypeDefault.h
aten/src/ATen/templates/TypeDerived.cpp
aten/src/ATen/templates/TypeDerived.h
aten/src/ATen/test/quantized_test.cpp
c10/core/Backend.h
c10/core/TensorImpl.h
c10/core/TensorTypeIdRegistration.cpp
c10/core/TensorTypeIdRegistration.h
test/test_torch.py
torch/csrc/autograd/python_variable.cpp

index 3ba8a57..d37722e 100644 (file)
@@ -252,6 +252,9 @@ class CAFFE2_API Tensor {
   /// Returns if a `Tensor` has sparse backend.
   bool is_sparse() const;
 
+  /// Returns if a `Tensor` has quantized backend.
+  bool is_quantized() const;
+
   /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
   /// TensorOptions.h.
   TensorOptions options() const;
index 367839d..0b3249d 100644 (file)
@@ -1351,6 +1351,15 @@ inline bool is_sparse(Tensor self) {
   return self.is_sparse();
 }
 
+inline bool Tensor::is_quantized() const {
+  // NB: this is not a native function to avoid dispatching overhead.
+  return impl_->is_quantized();
+}
+
+inline bool is_quantized(Tensor self) {
+  return self.is_quantized();
+}
+
 #define DEFINE_CAST(T, name, _)                  \
   template <>                                    \
   inline T* Tensor::data() const {               \
@@ -1375,7 +1384,4 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST)
 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(DEFINE_ITEM)
 #undef DEFINE_ITEM
 
-// TODO: after is_quantized() is implemented,
-// implement item() (returnning a float) for quantized Tensor
-
 } //namespace at
index 15d5efd..59f29fa 100644 (file)
@@ -84,6 +84,7 @@ enum class TypeID {
   SparseCUDALong,
   SparseCUDAShort,
   SparseCUDAQInt8,
+  QuantizedCPUQInt8,
   MSNPUBool,
   MSNPUByte,
   MSNPUChar,
@@ -124,6 +125,7 @@ struct CAFFE2_API Type {
   virtual bool is_cuda() const = 0;
   virtual bool is_hip() const = 0;
   virtual bool is_sparse() const = 0;
+  virtual bool is_quantized() const = 0;
   virtual bool is_distributed() const = 0;
   bool is_variable() const noexcept { return is_variable_; }
   bool is_undefined() const noexcept { return is_undefined_; }
index 307ff06..bb6003a 100644 (file)
@@ -1196,7 +1196,6 @@ def create_derived(backend_type_env, declarations):
     # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]]
     type_object_declarations = []
     type_object_definitions = []
-
     is_cuda = 'CUDA' in backend_type_env['Backend']
 
     def replace_with_null(argument):
index d8a3e50..08da16b 100644 (file)
@@ -174,11 +174,22 @@ generators = {
     },
 }
 
-backends = ['CPU', 'CUDA']
+backends = ['CPU', 'CUDA', 'QuantizedCPU']
 densities = ['Dense', 'Sparse', 'Mkldnn']  # TODO: layout instead of densities?
+def backend_to_devicetype(backend):
+    if backend == 'QuantizedCPU':
+        return 'CPU'
+    return backend
+
+quantized_backends = ['QuantizedCPU']
+
 extension_backends = ['MSNPU', 'XLA']
 
 # scalar_name, c_type, accreal, is_floating_type
+quantized_scalar_types = [
+    ('QInt8', 'qint8', 'QInt8AccrealNotDefined', 'Qint8IsFloatingTypeNotDefined'),
+]
+
 scalar_types = [
     ('Bool', 'bool', 'BoolAccrealNotDefined', False),
     ('Byte', 'uint8_t', 'Long', False),
@@ -189,8 +200,9 @@ scalar_types = [
     ('Long', 'int64_t', 'Long', False),
     ('Short', 'int16_t', 'Long', False),
     ('Half', 'Half', 'Double', True),
-    ('QInt8', 'qint8', 'Long', False),
-]
+] + quantized_scalar_types
+
+
 
 # shared environment for non-derived base classes Type.h Tensor.h Storage.h
 top_env = {
@@ -267,9 +279,8 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
     env['isFloatingType'] = is_floating_type
     env['isIntegralType'] = not is_floating_type
     env['Type'] = "{}{}{}Type".format(density_tag, backend, scalar_name)
-    env['DenseTensor'] = "{}{}Tensor".format(backend, scalar_name)
+    env['DeviceType'] = backend_to_devicetype(backend)
     env['Backend'] = density_tag + backend
-    env['DenseBackend'] = backend
     env['storage_tensor_headers'] = []
     if density != 'Sparse':
         env['storage_tensor_headers'] = ['#include <c10/core/TensorImpl.h>']
@@ -341,7 +352,7 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
     env['type_derived_method_definitions'] = definitions
 
     fm = file_manager
-    if env['DenseBackend'] == 'CUDA':
+    if env['DeviceType'] == 'CUDA':
         fm = cuda_file_manager
 
     if density != 'Sparse':
@@ -351,12 +362,12 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
     fm.write(env['Type'] + ".h", TYPE_DERIVED_H, env)
 
     type_register = TYPE_REGISTER.substitute(backend=env['Backend'], scalar_type=scalar_name, type_name=env['Type'])
-    if env['DenseBackend'] == 'CPU':
+    if env['DeviceType'] == 'CPU':
         top_env['cpu_type_registrations'].append(type_register)
         top_env['cpu_type_headers'].append(
             '#include "ATen/{}.h"'.format(env['Type']))
     else:
-        assert env['DenseBackend'] == 'CUDA'
+        assert env['DeviceType'] == 'CUDA'
         top_env['cuda_type_registrations'].append(type_register)
         top_env['cuda_type_headers'].append(
             '#include "ATen/{}.h"'.format(env['Type']))
@@ -366,7 +377,7 @@ def generate_type_extension_backend(backend, declarations):
     env = {}
     env['Type'] = "{}Type".format(backend)
     env['Backend'] = backend
-    env['DeviceType'] = backend
+    env['DeviceType'] = backend_to_devicetype(backend)
 
     declarations, definitions = function_wrapper.create_extension_backend(
         env, declarations)
@@ -426,7 +437,13 @@ def iterate_types():
                 if density == 'Sparse' and scalar_type[0] == 'Half':
                     # THS does not do half type yet.
                     continue
-                yield (backend, density, scalar_type)
+                if backend in quantized_backends:
+                    if density == 'Dense' and scalar_type in quantized_scalar_types:
+                            yield (backend, density, scalar_type)
+                    else:
+                        continue
+                else:
+                    yield (backend, density, scalar_type)
 
 
 ###################
index 918f4c3..c018b3a 100644 (file)
@@ -11,6 +11,8 @@ Scalar item(const Tensor& self) {
     if (self._nnz() == 0) return Scalar(0);
     if (self.is_coalesced()) return at::_local_scalar_dense(self._values());
     return at::_local_scalar_dense(self._values().sum());
+  } else if (self.is_quantized()) {
+    return self.dequantize().item();
   } else {
     return _local_scalar_dense(self);
   }
index ae1910f..1a3f7bd 100644 (file)
@@ -34,6 +34,10 @@ bool is_sparse(const Tensor& self) {
   return self.is_sparse();
 }
 
+bool is_quantized(const Tensor& self) {
+  return self.is_quantized();
+}
+
 Tensor type_as(const Tensor& self, const Tensor& other) {
   return self.toType(other.type());
 }
index a4b6469..609e179 100644 (file)
 
 - func: quantize_linear(Tensor self, float scale, int zero_point) -> Tensor
   variants: function, method
-  requires_tensor: True
+  dispatch:
+    CPU: quantize_linear_cpu
 
 - func: dequantize(Tensor self) -> Tensor
   variants: function, method
-  requires_tensor: True
+  dispatch:
+    QuantizedCPU: dequantize_quant
 
 - func: q_scale(Tensor self) -> Scalar
   variants: function, method
-  requires_tensor: True
+  dispatch:
+    QuantizedCPU: q_scale_quant
 
 - func: q_zero_point(Tensor self) -> Scalar
   variants: function, method
-  requires_tensor: True
+  dispatch:
+    QuantizedCPU: q_zero_point_quant
 
 # to(Device) must not exist because all constructors of Device also works for
 # TensorOptions. Otherwise, an ambiguity error is thrown.
index ef8e78b..5883e16 100644 (file)
@@ -7,22 +7,22 @@
 namespace at {
 namespace native {
 
-QTensor quantize_linear(const RealTensor& self, double scale, int64_t zero_point) {
+QTensor quantize_linear_cpu(const RealTensor& self, double scale, int64_t zero_point) {
   auto quantizer = make_per_tensor_affine_quantizer(scale, zero_point);
   return quantizer->quantize(self);
 }
 
-RealTensor dequantize(const QTensor& self) {
+RealTensor dequantize_quant(const QTensor& self) {
   return get_qtensorimpl(self)->quantizer()->dequantize(self);
 }
 
-Scalar q_scale(const QTensor& self) {
+Scalar q_scale_quant(const QTensor& self) {
   auto quantizer = get_qtensorimpl(self)->quantizer();
   AT_ASSERT(quantizer->qscheme() == kPerTensorAffine);
   return Scalar(static_cast<PerTensorAffineQuantizer*>(quantizer.get())->scale());
 }
 
-Scalar q_zero_point(const QTensor& self) {
+Scalar q_zero_point_quant(const QTensor& self) {
   auto quantizer = get_qtensorimpl(self)->quantizer();
   AT_ASSERT(quantizer->qscheme() == kPerTensorAffine);
   return Scalar(static_cast<PerTensorAffineQuantizer*>(quantizer.get())->zero_point());
index ed1a2ac..7953612 100644 (file)
@@ -23,7 +23,7 @@ type_map = {
 all_types = type_map['floating_point'] + type_map['integral']
 type_map['all'] = all_types
 
-all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA', 'MkldnnCPU']
+all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA', 'MkldnnCPU', 'QuantizedCPU']
 default_backends = ['CPU', 'CUDA']
 
 
index 1f30ae2..b4c6936 100644 (file)
@@ -42,8 +42,9 @@ inline QTensor new_qtensor(
       allocator->allocate(nelements * dtype.itemsize()),
       allocator,
       /*resizable=*/true);
+  // TODO: get TensorTypeId from quantizer
   auto tensor = detail::make_tensor<QTensorImpl>(
-      storage_impl, at::CPUTensorId(), quantizer);
+      storage_impl, at::QuantizedCPUTensorId(), quantizer);
   get_qtensorimpl(tensor)->set_sizes_contiguous(sizes);
   return tensor;
 }
index 43d9b86..322ce91 100644 (file)
@@ -28,7 +28,7 @@ $extra_cuda_headers
 namespace at {
 
 ${Type}::${Type}()
-  : ${DenseBackend}TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
+  : ${DeviceType}TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
 ScalarType ${Type}::scalarType() const {
   return ScalarType::${ScalarName};
 }
index ab9917e..b2cac27 100644 (file)
@@ -252,6 +252,9 @@ class CAFFE2_API Tensor {
   /// Returns if a `Tensor` has sparse backend.
   bool is_sparse() const;
 
+  /// Returns if a `Tensor` has quantized backend.
+  bool is_quantized() const;
+
   /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
   /// TensorOptions.h.
   TensorOptions options() const;
index 18b5e53..365b80c 100644 (file)
@@ -116,6 +116,15 @@ inline bool is_sparse(Tensor self) {
   return self.is_sparse();
 }
 
+inline bool Tensor::is_quantized() const {
+  // NB: this is not a native function to avoid dispatching overhead.
+  return impl_->is_quantized();
+}
+
+inline bool is_quantized(Tensor self) {
+  return self.is_quantized();
+}
+
 #define DEFINE_CAST(T, name, _)                  \
   template <>                                    \
   inline T* Tensor::data() const {               \
@@ -140,7 +149,4 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST)
 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(DEFINE_ITEM)
 #undef DEFINE_ITEM
 
-// TODO: after is_quantized() is implemented,
-// implement item() (returnning a float) for quantized Tensor
-
 } //namespace at
index ed965ca..65c3cb1 100644 (file)
@@ -66,6 +66,7 @@ struct CAFFE2_API Type {
   virtual bool is_cuda() const = 0;
   virtual bool is_hip() const = 0;
   virtual bool is_sparse() const = 0;
+  virtual bool is_quantized() const = 0;
   virtual bool is_distributed() const = 0;
   bool is_variable() const noexcept { return is_variable_; }
   bool is_undefined() const noexcept { return is_undefined_; }
index 4ffc1c2..5235190 100644 (file)
@@ -21,6 +21,9 @@ struct CAFFE2_API TypeDefault : public TypeExtendedInterface {
   bool is_sparse() const override {
     return backend() == Backend::SparseCPU || backend() == Backend::SparseCUDA || backend() == Backend::SparseHIP;
   }
+  bool is_quantized() const override {
+    return backend() == Backend::QuantizedCPU;
+  }
   bool is_distributed() const override {
     return false;
   }
index cd30ba1..ca675a9 100644 (file)
@@ -31,7 +31,7 @@ $extra_cuda_headers
 namespace at {
 
 ${Type}::${Type}()
-  : ${DenseBackend}TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
+  : ${DeviceType}TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
 
 ScalarType ${Type}::scalarType() const {
   return ScalarType::${ScalarName};
index f0b8ed6..b47f53c 100644 (file)
@@ -16,7 +16,7 @@ $extra_cuda_headers
 
 namespace at {
 
-struct ${Type} final : public ${DenseBackend}TypeDefault {
+struct ${Type} final : public ${DeviceType}TypeDefault {
   explicit ${Type}();
   virtual ScalarType scalarType() const override;
   virtual caffe2::TypeMeta typeMeta() const override;
index 6561448..dd99d22 100644 (file)
@@ -20,11 +20,7 @@ TEST(TestQTensor, QuantDequantAPIs) {
   Tensor qr = r.quantize_linear(scale, zero_point);
   ASSERT_EQ(qr.q_scale().to<float>(), scale);
   ASSERT_EQ(qr.q_zero_point().to<int32_t>(), zero_point);
-
-  // TODO: Uncomment when quantizer is ready
-  // auto* quantizer = static_cast<PerTensorAffineQuantizer*>(qr.quantizer());
-  // ASSERT_EQ(quantizer->scale(), scale);
-  // ASSERT_EQ(quantizer->zero_point(), zero_point);
+  ASSERT_TRUE(qr.is_quantized());
 
   // Check for correct quantization
   auto r_data = r.data<float>();
@@ -41,3 +37,11 @@ TEST(TestQTensor, QuantDequantAPIs) {
     ASSERT_EQ(r_data[i], rqr_data[i]);
   }
 }
+
+TEST(TestQTensor, Item) {
+  Tensor r = at::ones({1});
+  const float scale = 1;
+  const int32_t zero_point = 2;
+  Tensor qr = r.quantize_linear(scale, zero_point);
+  ASSERT_EQ(r.item().to<float>(), qr.item().to<float>());
+}
index 194c541..2fa5c47 100644 (file)
@@ -26,7 +26,7 @@ namespace c10 {
  * or "SparseCUDA"; backend in torch.backends is something like "MKL" or
  * "CUDNN".
  */
-enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, XLA, Undefined, MkldnnCPU, NumOptions };
+enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, XLA, QuantizedCPU, Undefined, MkldnnCPU, NumOptions };
 
 static inline Backend toSparse(Backend b) {
   switch (b) {
@@ -65,6 +65,8 @@ static inline Backend toDense(Backend b) {
       return Backend::CUDA;
     case Backend::SparseHIP:
       return Backend::HIP;
+    case Backend::QuantizedCPU:
+      return Backend::QuantizedCPU;
     default:
       throw std::runtime_error("Unknown backend");
   }
@@ -89,6 +91,8 @@ static inline Backend tensorTypeIdToBackend(TensorTypeId t) {
     return Backend::SparseHIP;
   } else if (t == MkldnnCPUTensorId()) {
     return Backend::MkldnnCPU;
+  } else if (t == QuantizedCPUTensorId()) {
+    return Backend::QuantizedCPU;
   } else if (t == UndefinedTensorId()) {
     return Backend::Undefined;
   } else {
@@ -116,6 +120,8 @@ static inline TensorTypeId backendToTensorTypeId(Backend b) {
       return SparseHIPTensorId();
     case Backend::MkldnnCPU:
       return MkldnnCPUTensorId();
+    case Backend::QuantizedCPU:
+      return QuantizedCPUTensorId();
     case Backend::Undefined:
       return UndefinedTensorId();
     default:
@@ -142,6 +148,7 @@ static inline DeviceType backendToDeviceType(Backend b) {
     case Backend::SparseHIP:
       return DeviceType::HIP;
     case Backend::MkldnnCPU:
+    case Backend::QuantizedCPU:
       return DeviceType::CPU;
     case Backend::Undefined:
       AT_ERROR("Undefined backend is not a valid device type");
@@ -169,6 +176,8 @@ static inline Backend backendToCPU(Backend b) {
       return Backend::CPU;
     case Backend::MkldnnCPU:
       return Backend::MkldnnCPU;
+    case Backend::QuantizedCPU:
+      return Backend::QuantizedCPU;
     case Backend::Undefined:
       return Backend::Undefined;
     default:
@@ -240,6 +249,8 @@ static inline const char* toString(Backend b) {
       return "SparseHIP";
     case Backend::MkldnnCPU:
       return "MkldnnCPU";
+    case Backend::QuantizedCPU:
+      return "QuantizedCPU";
     default:
       return "UNKNOWN_BACKEND";
   }
index 4600586..3a80a79 100644 (file)
@@ -397,6 +397,14 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return tid == SparseCPUTensorId() || tid == SparseCUDATensorId() || tid == SparseHIPTensorId();
   }
 
+  bool is_quantized() const {
+    // NB: This method is not virtual and avoid dispatches for performance reasons.
+    auto tid = type_id();
+    // NB: At the moment, variables have the same TensorTypeId as their
+    // corresponding tensor, but if this ever changes, we need to modify this.
+    return tid == QuantizedCPUTensorId();
+  }
+
   bool is_cuda() const {
     // NB: This method is not virtual and avoid dispatches for performance reasons.
     auto tid = type_id();
index 968d9a4..559aa0d 100644 (file)
@@ -71,5 +71,6 @@ C10_DEFINE_TENSOR_TYPE(SparseHIPTensorId);
 C10_DEFINE_TENSOR_TYPE(MSNPUTensorId);
 C10_DEFINE_TENSOR_TYPE(XLATensorId);
 C10_DEFINE_TENSOR_TYPE(MkldnnCPUTensorId);
+C10_DEFINE_TENSOR_TYPE(QuantizedCPUTensorId);
 
 } // namespace c10
index a7fc314..7726e37 100644 (file)
@@ -110,6 +110,7 @@ C10_DECLARE_TENSOR_TYPE(SparseHIPTensorId); // PyTorch only
 C10_DECLARE_TENSOR_TYPE(MSNPUTensorId); // PyTorch only
 C10_DECLARE_TENSOR_TYPE(XLATensorId); // PyTorch only
 C10_DECLARE_TENSOR_TYPE(MkldnnCPUTensorId);
+C10_DECLARE_TENSOR_TYPE(QuantizedCPUTensorId); // PyTorch only
 
 } // namespace c10
 
index de52cb7..d7e1ddb 100644 (file)
@@ -2661,9 +2661,15 @@ class _TestTorchMixin(object):
         qr = r.quantize_linear(scale, zero_point)
         self.assertEqual(qr.q_scale(), scale)
         self.assertEqual(qr.q_zero_point(), zero_point)
+        self.assertTrue(qr.is_quantized)
+        self.assertFalse(r.is_quantized)
         rqr = qr.dequantize()
         for i in range(num_elements):
             self.assertEqual(r[i], rqr[i])
+        # Testing item
+        r = torch.ones(1, dtype=torch.float)
+        qr = r.quantize_linear(scale, zero_point)
+        self.assertEqual(qr.item(), 1)
 
     @unittest.skipIf(torch.cuda.device_count() < 2, 'fewer than 2 GPUs detected')
     def test_device_guard(self):
index 54c8134..3357ca7 100644 (file)
@@ -392,6 +392,14 @@ PyObject *THPVariable_is_sparse(THPVariable *self)
   END_HANDLE_TH_ERRORS
 }
 
+PyObject *THPVariable_is_quantized(THPVariable *self)
+{
+  HANDLE_TH_ERRORS
+  auto& self_ = self->cdata;
+  return torch::autograd::utils::wrap(self_.is_quantized());
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject *THPVariable_dtype(THPVariable *self)
 {
   HANDLE_TH_ERRORS
@@ -431,6 +439,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
   {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
   {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
   {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
+  {"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},
   {"dtype", (getter)THPVariable_dtype, nullptr, nullptr, nullptr},
   {"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr},
   {"device", (getter)THPVariable_device, nullptr, nullptr, nullptr},