From 1c836e7bb9eafdb743376d09ccbd88ceaad76b05 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 12 Apr 2019 12:47:39 -0700 Subject: [PATCH] Add Quantized Backend (#18546) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18546 We'll expose all combinations of various ways of quantization in the top-level dispatch key; that is, we have AffineCPUTensor, PerChannelAffineCUDATensor, etc. QTensor methods added: - is_quantized() - item() Differential Revision: D14637671 fbshipit-source-id: 346bc6ef404a570f0efd34e8793056ad3c7855f5 --- aten/src/ATen/core/Tensor.h | 3 +++ aten/src/ATen/core/TensorMethods.h | 12 ++++++--- aten/src/ATen/core/Type.h | 2 ++ aten/src/ATen/function_wrapper.py | 1 - aten/src/ATen/gen.py | 37 +++++++++++++++++++-------- aten/src/ATen/native/Scalar.cpp | 2 ++ aten/src/ATen/native/TypeProperties.cpp | 4 +++ aten/src/ATen/native/native_functions.yaml | 12 ++++++--- aten/src/ATen/native/quantized/QTensor.cpp | 8 +++--- aten/src/ATen/preprocess_declarations.py | 2 +- aten/src/ATen/quantized/Quantizer.cpp | 3 ++- aten/src/ATen/templates/SparseTypeDerived.cpp | 2 +- aten/src/ATen/templates/Tensor.h | 3 +++ aten/src/ATen/templates/TensorMethods.h | 12 ++++++--- aten/src/ATen/templates/Type.h | 1 + aten/src/ATen/templates/TypeDefault.h | 3 +++ aten/src/ATen/templates/TypeDerived.cpp | 2 +- aten/src/ATen/templates/TypeDerived.h | 2 +- aten/src/ATen/test/quantized_test.cpp | 14 ++++++---- c10/core/Backend.h | 13 +++++++++- c10/core/TensorImpl.h | 8 ++++++ c10/core/TensorTypeIdRegistration.cpp | 1 + c10/core/TensorTypeIdRegistration.h | 1 + test/test_torch.py | 6 +++++ torch/csrc/autograd/python_variable.cpp | 9 +++++++ 25 files changed, 127 insertions(+), 36 deletions(-) diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h index 3ba8a57..d37722e 100644 --- a/aten/src/ATen/core/Tensor.h +++ b/aten/src/ATen/core/Tensor.h @@ -252,6 +252,9 @@ class CAFFE2_API Tensor { /// Returns if a `Tensor` has sparse backend. 
bool is_sparse() const; + /// Returns if a `Tensor` has quantized backend. + bool is_quantized() const; + /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in /// TensorOptions.h. TensorOptions options() const; diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h index 367839d..0b3249d 100644 --- a/aten/src/ATen/core/TensorMethods.h +++ b/aten/src/ATen/core/TensorMethods.h @@ -1351,6 +1351,15 @@ inline bool is_sparse(Tensor self) { return self.is_sparse(); } +inline bool Tensor::is_quantized() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_quantized(); +} + +inline bool is_quantized(Tensor self) { + return self.is_quantized(); +} + #define DEFINE_CAST(T, name, _) \ template <> \ inline T* Tensor::data() const { \ @@ -1375,7 +1384,4 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST) AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(DEFINE_ITEM) #undef DEFINE_ITEM -// TODO: after is_quantized() is implemented, -// implement item() (returnning a float) for quantized Tensor - } //namespace at diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h index 15d5efd..59f29fa 100644 --- a/aten/src/ATen/core/Type.h +++ b/aten/src/ATen/core/Type.h @@ -84,6 +84,7 @@ enum class TypeID { SparseCUDALong, SparseCUDAShort, SparseCUDAQInt8, + QuantizedCPUQInt8, MSNPUBool, MSNPUByte, MSNPUChar, @@ -124,6 +125,7 @@ struct CAFFE2_API Type { virtual bool is_cuda() const = 0; virtual bool is_hip() const = 0; virtual bool is_sparse() const = 0; + virtual bool is_quantized() const = 0; virtual bool is_distributed() const = 0; bool is_variable() const noexcept { return is_variable_; } bool is_undefined() const noexcept { return is_undefined_; } diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 307ff06..bb6003a 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ 
-1196,7 +1196,6 @@ def create_derived(backend_type_env, declarations): # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]] type_object_declarations = [] type_object_definitions = [] - is_cuda = 'CUDA' in backend_type_env['Backend'] def replace_with_null(argument): diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index d8a3e50..08da16b 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -174,11 +174,22 @@ generators = { }, } -backends = ['CPU', 'CUDA'] +backends = ['CPU', 'CUDA', 'QuantizedCPU'] densities = ['Dense', 'Sparse', 'Mkldnn'] # TODO: layout instead of densities? +def backend_to_devicetype(backend): + if backend == 'QuantizedCPU': + return 'CPU' + return backend + +quantized_backends = ['QuantizedCPU'] + extension_backends = ['MSNPU', 'XLA'] # scalar_name, c_type, accreal, is_floating_type +quantized_scalar_types = [ + ('QInt8', 'qint8', 'QInt8AccrealNotDefined', 'Qint8IsFloatingTypeNotDefined'), +] + scalar_types = [ ('Bool', 'bool', 'BoolAccrealNotDefined', False), ('Byte', 'uint8_t', 'Long', False), @@ -189,8 +200,9 @@ scalar_types = [ ('Long', 'int64_t', 'Long', False), ('Short', 'int16_t', 'Long', False), ('Half', 'Half', 'Double', True), - ('QInt8', 'qint8', 'Long', False), -] +] + quantized_scalar_types + + # shared environment for non-derived base classes Type.h Tensor.h Storage.h top_env = { @@ -267,9 +279,8 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations env['isFloatingType'] = is_floating_type env['isIntegralType'] = not is_floating_type env['Type'] = "{}{}{}Type".format(density_tag, backend, scalar_name) - env['DenseTensor'] = "{}{}Tensor".format(backend, scalar_name) + env['DeviceType'] = backend_to_devicetype(backend) env['Backend'] = density_tag + backend - env['DenseBackend'] = backend env['storage_tensor_headers'] = [] if density != 'Sparse': env['storage_tensor_headers'] = ['#include '] @@ -341,7 +352,7 @@ def generate_storage_type_and_tensor(backend, density, 
scalar_type, declarations env['type_derived_method_definitions'] = definitions fm = file_manager - if env['DenseBackend'] == 'CUDA': + if env['DeviceType'] == 'CUDA': fm = cuda_file_manager if density != 'Sparse': @@ -351,12 +362,12 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations fm.write(env['Type'] + ".h", TYPE_DERIVED_H, env) type_register = TYPE_REGISTER.substitute(backend=env['Backend'], scalar_type=scalar_name, type_name=env['Type']) - if env['DenseBackend'] == 'CPU': + if env['DeviceType'] == 'CPU': top_env['cpu_type_registrations'].append(type_register) top_env['cpu_type_headers'].append( '#include "ATen/{}.h"'.format(env['Type'])) else: - assert env['DenseBackend'] == 'CUDA' + assert env['DeviceType'] == 'CUDA' top_env['cuda_type_registrations'].append(type_register) top_env['cuda_type_headers'].append( '#include "ATen/{}.h"'.format(env['Type'])) @@ -366,7 +377,7 @@ def generate_type_extension_backend(backend, declarations): env = {} env['Type'] = "{}Type".format(backend) env['Backend'] = backend - env['DeviceType'] = backend + env['DeviceType'] = backend_to_devicetype(backend) declarations, definitions = function_wrapper.create_extension_backend( env, declarations) @@ -426,7 +437,13 @@ def iterate_types(): if density == 'Sparse' and scalar_type[0] == 'Half': # THS does not do half type yet. 
continue - yield (backend, density, scalar_type) + if backend in quantized_backends: + if density == 'Dense' and scalar_type in quantized_scalar_types: + yield (backend, density, scalar_type) + else: + continue + else: + yield (backend, density, scalar_type) ################### diff --git a/aten/src/ATen/native/Scalar.cpp b/aten/src/ATen/native/Scalar.cpp index 918f4c3..c018b3a 100644 --- a/aten/src/ATen/native/Scalar.cpp +++ b/aten/src/ATen/native/Scalar.cpp @@ -11,6 +11,8 @@ Scalar item(const Tensor& self) { if (self._nnz() == 0) return Scalar(0); if (self.is_coalesced()) return at::_local_scalar_dense(self._values()); return at::_local_scalar_dense(self._values().sum()); + } else if (self.is_quantized()) { + return self.dequantize().item(); } else { return _local_scalar_dense(self); } diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp index ae1910f..1a3f7bd 100644 --- a/aten/src/ATen/native/TypeProperties.cpp +++ b/aten/src/ATen/native/TypeProperties.cpp @@ -34,6 +34,10 @@ bool is_sparse(const Tensor& self) { return self.is_sparse(); } +bool is_quantized(const Tensor& self) { + return self.is_quantized(); +} + Tensor type_as(const Tensor& self, const Tensor& other) { return self.toType(other.type()); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a4b6469..609e179 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2489,19 +2489,23 @@ - func: quantize_linear(Tensor self, float scale, int zero_point) -> Tensor variants: function, method - requires_tensor: True + dispatch: + CPU: quantize_linear_cpu - func: dequantize(Tensor self) -> Tensor variants: function, method - requires_tensor: True + dispatch: + QuantizedCPU: dequantize_quant - func: q_scale(Tensor self) -> Scalar variants: function, method - requires_tensor: True + dispatch: + QuantizedCPU: q_scale_quant - func: q_zero_point(Tensor self) -> 
Scalar variants: function, method - requires_tensor: True + dispatch: + QuantizedCPU: q_zero_point_quant # to(Device) must not exist because all constructors of Device also works for # TensorOptions. Otherwise, an ambiguity error is thrown. diff --git a/aten/src/ATen/native/quantized/QTensor.cpp b/aten/src/ATen/native/quantized/QTensor.cpp index ef8e78b..5883e16 100644 --- a/aten/src/ATen/native/quantized/QTensor.cpp +++ b/aten/src/ATen/native/quantized/QTensor.cpp @@ -7,22 +7,22 @@ namespace at { namespace native { -QTensor quantize_linear(const RealTensor& self, double scale, int64_t zero_point) { +QTensor quantize_linear_cpu(const RealTensor& self, double scale, int64_t zero_point) { auto quantizer = make_per_tensor_affine_quantizer(scale, zero_point); return quantizer->quantize(self); } -RealTensor dequantize(const QTensor& self) { +RealTensor dequantize_quant(const QTensor& self) { return get_qtensorimpl(self)->quantizer()->dequantize(self); } -Scalar q_scale(const QTensor& self) { +Scalar q_scale_quant(const QTensor& self) { auto quantizer = get_qtensorimpl(self)->quantizer(); AT_ASSERT(quantizer->qscheme() == kPerTensorAffine); return Scalar(static_cast(quantizer.get())->scale()); } -Scalar q_zero_point(const QTensor& self) { +Scalar q_zero_point_quant(const QTensor& self) { auto quantizer = get_qtensorimpl(self)->quantizer(); AT_ASSERT(quantizer->qscheme() == kPerTensorAffine); return Scalar(static_cast(quantizer.get())->zero_point()); diff --git a/aten/src/ATen/preprocess_declarations.py b/aten/src/ATen/preprocess_declarations.py index ed1a2ac..7953612 100644 --- a/aten/src/ATen/preprocess_declarations.py +++ b/aten/src/ATen/preprocess_declarations.py @@ -23,7 +23,7 @@ type_map = { all_types = type_map['floating_point'] + type_map['integral'] type_map['all'] = all_types -all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA', 'MkldnnCPU'] +all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA', 'MkldnnCPU', 'QuantizedCPU'] default_backends = ['CPU', 
'CUDA'] diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp index 1f30ae2..b4c6936 100644 --- a/aten/src/ATen/quantized/Quantizer.cpp +++ b/aten/src/ATen/quantized/Quantizer.cpp @@ -42,8 +42,9 @@ inline QTensor new_qtensor( allocator->allocate(nelements * dtype.itemsize()), allocator, /*resizable=*/true); + // TODO: get TensorTypeId from quantizer auto tensor = detail::make_tensor( - storage_impl, at::CPUTensorId(), quantizer); + storage_impl, at::QuantizedCPUTensorId(), quantizer); get_qtensorimpl(tensor)->set_sizes_contiguous(sizes); return tensor; } diff --git a/aten/src/ATen/templates/SparseTypeDerived.cpp b/aten/src/ATen/templates/SparseTypeDerived.cpp index 43d9b86..322ce91 100644 --- a/aten/src/ATen/templates/SparseTypeDerived.cpp +++ b/aten/src/ATen/templates/SparseTypeDerived.cpp @@ -28,7 +28,7 @@ $extra_cuda_headers namespace at { ${Type}::${Type}() - : ${DenseBackend}TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {} + : ${DeviceType}TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {} ScalarType ${Type}::scalarType() const { return ScalarType::${ScalarName}; } diff --git a/aten/src/ATen/templates/Tensor.h b/aten/src/ATen/templates/Tensor.h index ab9917e..b2cac27 100644 --- a/aten/src/ATen/templates/Tensor.h +++ b/aten/src/ATen/templates/Tensor.h @@ -252,6 +252,9 @@ class CAFFE2_API Tensor { /// Returns if a `Tensor` has sparse backend. bool is_sparse() const; + /// Returns if a `Tensor` has quantized backend. + bool is_quantized() const; + /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in /// TensorOptions.h. 
TensorOptions options() const; diff --git a/aten/src/ATen/templates/TensorMethods.h b/aten/src/ATen/templates/TensorMethods.h index 18b5e53..365b80c 100644 --- a/aten/src/ATen/templates/TensorMethods.h +++ b/aten/src/ATen/templates/TensorMethods.h @@ -116,6 +116,15 @@ inline bool is_sparse(Tensor self) { return self.is_sparse(); } +inline bool Tensor::is_quantized() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_quantized(); +} + +inline bool is_quantized(Tensor self) { + return self.is_quantized(); +} + #define DEFINE_CAST(T, name, _) \ template <> \ inline T* Tensor::data() const { \ @@ -140,7 +149,4 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST) AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(DEFINE_ITEM) #undef DEFINE_ITEM -// TODO: after is_quantized() is implemented, -// implement item() (returnning a float) for quantized Tensor - } //namespace at diff --git a/aten/src/ATen/templates/Type.h b/aten/src/ATen/templates/Type.h index ed965ca..65c3cb1 100644 --- a/aten/src/ATen/templates/Type.h +++ b/aten/src/ATen/templates/Type.h @@ -66,6 +66,7 @@ struct CAFFE2_API Type { virtual bool is_cuda() const = 0; virtual bool is_hip() const = 0; virtual bool is_sparse() const = 0; + virtual bool is_quantized() const = 0; virtual bool is_distributed() const = 0; bool is_variable() const noexcept { return is_variable_; } bool is_undefined() const noexcept { return is_undefined_; } diff --git a/aten/src/ATen/templates/TypeDefault.h b/aten/src/ATen/templates/TypeDefault.h index 4ffc1c2..5235190 100644 --- a/aten/src/ATen/templates/TypeDefault.h +++ b/aten/src/ATen/templates/TypeDefault.h @@ -21,6 +21,9 @@ struct CAFFE2_API TypeDefault : public TypeExtendedInterface { bool is_sparse() const override { return backend() == Backend::SparseCPU || backend() == Backend::SparseCUDA || backend() == Backend::SparseHIP; } + bool is_quantized() const override { + return backend() == 
Backend::QuantizedCPU; + } bool is_distributed() const override { return false; } diff --git a/aten/src/ATen/templates/TypeDerived.cpp b/aten/src/ATen/templates/TypeDerived.cpp index cd30ba1..ca675a9 100644 --- a/aten/src/ATen/templates/TypeDerived.cpp +++ b/aten/src/ATen/templates/TypeDerived.cpp @@ -31,7 +31,7 @@ $extra_cuda_headers namespace at { ${Type}::${Type}() - : ${DenseBackend}TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {} + : ${DeviceType}TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {} ScalarType ${Type}::scalarType() const { return ScalarType::${ScalarName}; diff --git a/aten/src/ATen/templates/TypeDerived.h b/aten/src/ATen/templates/TypeDerived.h index f0b8ed6..b47f53c 100644 --- a/aten/src/ATen/templates/TypeDerived.h +++ b/aten/src/ATen/templates/TypeDerived.h @@ -16,7 +16,7 @@ $extra_cuda_headers namespace at { -struct ${Type} final : public ${DenseBackend}TypeDefault { +struct ${Type} final : public ${DeviceType}TypeDefault { explicit ${Type}(); virtual ScalarType scalarType() const override; virtual caffe2::TypeMeta typeMeta() const override; diff --git a/aten/src/ATen/test/quantized_test.cpp b/aten/src/ATen/test/quantized_test.cpp index 6561448..dd99d22 100644 --- a/aten/src/ATen/test/quantized_test.cpp +++ b/aten/src/ATen/test/quantized_test.cpp @@ -20,11 +20,7 @@ TEST(TestQTensor, QuantDequantAPIs) { Tensor qr = r.quantize_linear(scale, zero_point); ASSERT_EQ(qr.q_scale().to(), scale); ASSERT_EQ(qr.q_zero_point().to(), zero_point); - - // TODO: Uncomment when quantizer is ready - // auto* quantizer = static_cast(qr.quantizer()); - // ASSERT_EQ(quantizer->scale(), scale); - // ASSERT_EQ(quantizer->zero_point(), zero_point); + ASSERT_TRUE(qr.is_quantized()); // Check for correct quantization auto r_data = r.data(); @@ -41,3 +37,11 @@ TEST(TestQTensor, QuantDequantAPIs) { ASSERT_EQ(r_data[i], rqr_data[i]); } } + +TEST(TestQTensor, Item) { + Tensor r = at::ones({1}); + const 
float scale = 1; + const int32_t zero_point = 2; + Tensor qr = r.quantize_linear(scale, zero_point); + ASSERT_EQ(r.item().to(), qr.item().to()); +} diff --git a/c10/core/Backend.h b/c10/core/Backend.h index 194c541..2fa5c47 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -26,7 +26,7 @@ namespace c10 { * or "SparseCUDA"; backend in torch.backends is something like "MKL" or * "CUDNN". */ -enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, XLA, Undefined, MkldnnCPU, NumOptions }; +enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, XLA, QuantizedCPU, Undefined, MkldnnCPU, NumOptions }; static inline Backend toSparse(Backend b) { switch (b) { @@ -65,6 +65,8 @@ static inline Backend toDense(Backend b) { return Backend::CUDA; case Backend::SparseHIP: return Backend::HIP; + case Backend::QuantizedCPU: + return Backend::QuantizedCPU; default: throw std::runtime_error("Unknown backend"); } @@ -89,6 +91,8 @@ static inline Backend tensorTypeIdToBackend(TensorTypeId t) { return Backend::SparseHIP; } else if (t == MkldnnCPUTensorId()) { return Backend::MkldnnCPU; + } else if (t == QuantizedCPUTensorId()) { + return Backend::QuantizedCPU; } else if (t == UndefinedTensorId()) { return Backend::Undefined; } else { @@ -116,6 +120,8 @@ static inline TensorTypeId backendToTensorTypeId(Backend b) { return SparseHIPTensorId(); case Backend::MkldnnCPU: return MkldnnCPUTensorId(); + case Backend::QuantizedCPU: + return QuantizedCPUTensorId(); case Backend::Undefined: return UndefinedTensorId(); default: @@ -142,6 +148,7 @@ static inline DeviceType backendToDeviceType(Backend b) { case Backend::SparseHIP: return DeviceType::HIP; case Backend::MkldnnCPU: + case Backend::QuantizedCPU: return DeviceType::CPU; case Backend::Undefined: AT_ERROR("Undefined backend is not a valid device type"); @@ -169,6 +176,8 @@ static inline Backend backendToCPU(Backend b) { return Backend::CPU; case Backend::MkldnnCPU: return 
Backend::MkldnnCPU; + case Backend::QuantizedCPU: + return Backend::QuantizedCPU; case Backend::Undefined: return Backend::Undefined; default: @@ -240,6 +249,8 @@ static inline const char* toString(Backend b) { return "SparseHIP"; case Backend::MkldnnCPU: return "MkldnnCPU"; + case Backend::QuantizedCPU: + return "QuantizedCPU"; default: return "UNKNOWN_BACKEND"; } diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 4600586..3a80a79 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -397,6 +397,14 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return tid == SparseCPUTensorId() || tid == SparseCUDATensorId() || tid == SparseHIPTensorId(); } + bool is_quantized() const { + // NB: This method is not virtual and avoids dispatches for performance reasons. + auto tid = type_id(); + // NB: At the moment, variables have the same TensorTypeId as their + // corresponding tensor, but if this ever changes, we need to modify this. + return tid == QuantizedCPUTensorId(); + } + bool is_cuda() const { // NB: This method is not virtual and avoid dispatches for performance reasons. 
auto tid = type_id(); diff --git a/c10/core/TensorTypeIdRegistration.cpp b/c10/core/TensorTypeIdRegistration.cpp index 968d9a4..559aa0d 100644 --- a/c10/core/TensorTypeIdRegistration.cpp +++ b/c10/core/TensorTypeIdRegistration.cpp @@ -71,5 +71,6 @@ C10_DEFINE_TENSOR_TYPE(SparseHIPTensorId); C10_DEFINE_TENSOR_TYPE(MSNPUTensorId); C10_DEFINE_TENSOR_TYPE(XLATensorId); C10_DEFINE_TENSOR_TYPE(MkldnnCPUTensorId); +C10_DEFINE_TENSOR_TYPE(QuantizedCPUTensorId); } // namespace c10 diff --git a/c10/core/TensorTypeIdRegistration.h b/c10/core/TensorTypeIdRegistration.h index a7fc314..7726e37 100644 --- a/c10/core/TensorTypeIdRegistration.h +++ b/c10/core/TensorTypeIdRegistration.h @@ -110,6 +110,7 @@ C10_DECLARE_TENSOR_TYPE(SparseHIPTensorId); // PyTorch only C10_DECLARE_TENSOR_TYPE(MSNPUTensorId); // PyTorch only C10_DECLARE_TENSOR_TYPE(XLATensorId); // PyTorch only C10_DECLARE_TENSOR_TYPE(MkldnnCPUTensorId); +C10_DECLARE_TENSOR_TYPE(QuantizedCPUTensorId); // PyTorch only } // namespace c10 diff --git a/test/test_torch.py b/test/test_torch.py index de52cb7..d7e1ddb 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -2661,9 +2661,15 @@ class _TestTorchMixin(object): qr = r.quantize_linear(scale, zero_point) self.assertEqual(qr.q_scale(), scale) self.assertEqual(qr.q_zero_point(), zero_point) + self.assertTrue(qr.is_quantized) + self.assertFalse(r.is_quantized) rqr = qr.dequantize() for i in range(num_elements): self.assertEqual(r[i], rqr[i]) + # Testing item + r = torch.ones(1, dtype=torch.float) + qr = r.quantize_linear(scale, zero_point) + self.assertEqual(qr.item(), 1) @unittest.skipIf(torch.cuda.device_count() < 2, 'fewer than 2 GPUs detected') def test_device_guard(self): diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 54c8134..3357ca7 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -392,6 +392,14 @@ PyObject *THPVariable_is_sparse(THPVariable *self) 
END_HANDLE_TH_ERRORS } +PyObject *THPVariable_is_quantized(THPVariable *self) +{ + HANDLE_TH_ERRORS + auto& self_ = self->cdata; + return torch::autograd::utils::wrap(self_.is_quantized()); + END_HANDLE_TH_ERRORS +} + static PyObject *THPVariable_dtype(THPVariable *self) { HANDLE_TH_ERRORS @@ -431,6 +439,7 @@ static struct PyGetSetDef THPVariable_properties[] = { {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr}, {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr}, {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr}, + {"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr}, {"dtype", (getter)THPVariable_dtype, nullptr, nullptr, nullptr}, {"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr}, {"device", (getter)THPVariable_device, nullptr, nullptr, nullptr}, -- 2.7.4