From: Jerry Zhang
Date: Wed, 3 Apr 2019 20:13:26 +0000 (-0700)
Subject: QTensor (#18230)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~440
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=dfcd7b0185d479f186ddb100c761f4df3495f8e8;p=platform%2Fupstream%2Fpytorch.git

QTensor (#18230)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18230

Implement a minimal QTensor API to unblock the other quantization workstreams.

Changes:
- Added Quantizer, which represents the different quantization schemes
- Added qint8 as a data type for QTensor
- Added a new ScalarType, QInt8
- Added QTensorImpl for QTensor
- Added the following user-facing APIs
  - quantize_linear(scale, zero_point)
  - dequantize()
  - q_scale()
  - q_zero_point()

Reviewed By: dzhulgakov

Differential Revision: D14524641

fbshipit-source-id: c1c0ae0978fb500d47cdb23fb15b747773429e6c
---
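For orientation before the diff, here is a minimal round-trip through the new user-facing API. This is an illustrative sketch against a build containing this patch (it mirrors the C++ test added below in quantized_test.cpp; error handling omitted):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Only per-tensor affine quantization to qint8 is wired up in this
  // patch, and only on the CPU backend.
  at::Tensor r = at::ones({10});

  // Quantize with a fixed scale and zero_point...
  at::Tensor q = r.quantize_linear(/*scale=*/1.0, /*zero_point=*/2);

  // ...inspect the quantization parameters...
  std::cout << "scale: " << q.q_scale().to<float>()
            << ", zero_point: " << q.q_zero_point().to<int32_t>() << "\n";

  // ...and map back to float.
  at::Tensor rq = q.dequantize();
  std::cout << rq << "\n";
  return 0;
}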
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index a1316e1..98af97d 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -56,6 +56,7 @@ FILE(GLOB native_cpp "native/*.cpp")
 FILE(GLOB native_mkl_cpp "native/mkl/*.cpp")
 FILE(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
 FILE(GLOB native_sparse_cpp "native/sparse/*.cpp")
+FILE(GLOB native_quantized_cpp "native/quantized/*.cpp")
 FILE(GLOB native_cuda_cu "native/cuda/*.cu")
 FILE(GLOB native_cuda_cpp "native/cuda/*.cpp")
@@ -70,7 +71,8 @@ FILE(GLOB native_cudnn_hip_cpp "native/cudnn/hip/*.cpp")
 FILE(GLOB native_sparse_hip_hip "native/sparse/hip/*.hip")
 FILE(GLOB native_sparse_hip_cpp "native/sparse/hip/*.cpp")
 
-set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${generated_cpp} ${ATen_CPU_SRCS} ${cpu_kernel_cpp})
+add_subdirectory(quantized)
+set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${cpu_kernel_cpp})
 if(AT_MKL_ENABLED)
   set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
 endif()
@@ -435,10 +437,12 @@ set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE)
 set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE)
 set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} PARENT_SCOPE)
 set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE)
+set(ATen_QUANTIZED_SRCS ${ATen_QUANTIZED_SRCS} PARENT_SCOPE)
 set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE)
 set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE)
 set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE)
 set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE)
+set(ATen_QUANTIZED_TEST_SRCS ${ATen_QUANTIZED_TEST_SRCS} PARENT_SCOPE)
 set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE)
 set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE)
 set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE)
diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp
index de0b1b0..e2983eb 100644
--- a/aten/src/ATen/DLConvertor.cpp
+++ b/aten/src/ATen/DLConvertor.cpp
@@ -39,6 +39,9 @@ static DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::Bool:
       dtype.code = DLDataTypeCode::kDLUInt;
       break;
+    case ScalarType::QInt8:
+      throw std::logic_error("QInt8 is not supported by dlpack");
+      break;
     case ScalarType::ComplexHalf:
       throw std::logic_error("ComplexHalf is not supported by dlpack");
     case ScalarType::ComplexFloat:
diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index add127c..8dbfb81 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -566,6 +566,10 @@ class CAFFE2_API Tensor {
   std::vector<Tensor> unbind(int64_t dim=0) const;
   Tensor to_sparse(int64_t sparse_dim) const;
   Tensor to_sparse() const;
+  Tensor quantize_linear(double scale, int64_t zero_point) const;
+  Tensor dequantize() const;
+  Scalar q_scale() const;
+  Scalar q_zero_point() const;
   Tensor to(const TensorOptions & options, bool non_blocking=false, bool copy=false) const;
   Tensor to(Device device, ScalarType dtype, bool non_blocking=false, bool copy=false) const;
   Tensor to(ScalarType dtype, bool non_blocking=false, bool copy=false) const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 1605aff..9e50931 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -769,6 +769,18 @@ inline Tensor Tensor::to_sparse(int64_t sparse_dim) const {
 inline Tensor Tensor::to_sparse() const {
   return type().to_sparse(*this);
 }
+inline Tensor Tensor::quantize_linear(double scale, int64_t zero_point) const {
+  return type().quantize_linear(*this, scale, zero_point);
+}
+inline Tensor Tensor::dequantize() const {
+  return type().dequantize(*this);
+}
+inline Scalar Tensor::q_scale() const {
+  return type().q_scale(*this);
+}
+inline Scalar Tensor::q_zero_point() const {
+  return type().q_zero_point(*this);
+}
 inline Tensor Tensor::to(const TensorOptions & options, bool non_blocking, bool copy) const {
   return type().to(*this, options, non_blocking, copy);
 }
@@ -1347,7 +1359,10 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST)
     return item().to##name(); \
   }
 
-AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ITEM)
+AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(DEFINE_ITEM)
 #undef DEFINE_ITEM
 
+// TODO: after is_quantized() is implemented,
+// implement item() (returning a float) for quantized Tensor
+
 } //namespace at
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index e50af83..9a1d6fc 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -40,6 +40,7 @@ using TensorList = ArrayRef<Tensor>;
 
 class Context;
 struct Generator;
+struct Quantizer;
 
 static inline void noop_deleter(void*) {}
 
@@ -53,6 +54,7 @@ enum class TypeID {
   CPULong,
   CPUShort,
   CPUHalf,
+  CPUQInt8,
   SparseCPUBool,
   SparseCPUByte,
   SparseCPUChar,
@@ -61,6 +63,7 @@
   SparseCPUInt,
   SparseCPULong,
   SparseCPUShort,
+  SparseCPUQInt8,
   CUDABool,
   CUDAByte,
   CUDAChar,
@@ -70,6 +73,7 @@
   CUDALong,
   CUDAShort,
   CUDAHalf,
+  CUDAQInt8,
   SparseCUDABool,
   SparseCUDAByte,
   SparseCUDAChar,
@@ -78,6 +82,7 @@
   SparseCUDAInt,
   SparseCUDALong,
   SparseCUDAShort,
+  SparseCUDAQInt8,
   MSNPUBool,
   MSNPUByte,
   MSNPUChar,
@@ -87,6 +92,7 @@
   MSNPULong,
   MSNPUShort,
   MSNPUHalf,
+  MSNPUQInt8,
   XLABool,
   XLAByte,
   XLAChar,
@@ -96,6 +102,7 @@
   XLALong,
   XLAShort,
   XLAHalf,
+  XLAQInt8,
   CPUComplexFloat,
   CPUComplexDouble,
   CUDAComplexFloat,
@@ -444,6 +451,10 @@ struct CAFFE2_API Type {
   virtual std::vector<Tensor> unbind(const Tensor & self, int64_t dim) const = 0;
   virtual Tensor to_sparse(const Tensor & self, int64_t sparse_dim) const = 0;
   virtual Tensor to_sparse(const Tensor & self) const = 0;
+  virtual Tensor quantize_linear(const Tensor & self, double scale, int64_t zero_point) const = 0;
+  virtual Tensor dequantize(const Tensor & self) const = 0;
+  virtual Scalar q_scale(const Tensor & self) const = 0;
+  virtual Scalar q_zero_point(const Tensor & self) const = 0;
   virtual Tensor to(const Tensor & self, const TensorOptions & options, bool non_blocking, bool copy) const = 0;
   virtual Tensor to(const Tensor & self, Device device, ScalarType dtype, bool non_blocking, bool copy) const = 0;
   virtual Tensor to(const Tensor & self, ScalarType dtype, bool non_blocking, bool copy) const = 0;
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 3694803..f50537b 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -1564,6 +1564,9 @@ def create_derived(backend_type_env, declarations):
         # type: (FunctionOption) -> None
         pair = (backend_type_env['Backend'],
                 backend_type_env['ScalarName'])
+        # Skip generating TH code for QInt8
+        if pair[1] == 'QInt8':
+            return
         if pair in option['backend_type_pairs']:
             env = nested_dict(option, backend_type_env)
             body = emit_body(env, option)  # type: ignore
diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py
index 356e43f..9cb866b 100644
--- a/aten/src/ATen/gen.py
+++ b/aten/src/ATen/gen.py
@@ -189,6 +189,7 @@ scalar_types = [
     ('Long', 'int64_t', 'Long', False),
     ('Short', 'int16_t', 'Long', False),
     ('Half', 'Half', 'Double', True),
+    ('QInt8', 'qint8', 'Long', False),
 ]
 
 # shared environment for non-derived base classes Type.h Tensor.h Storage.h
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 1cdcb42..3fc0ee2 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -148,7 +148,7 @@ Tensor& empty_out(Tensor& result, IntArrayRef size) {
     return target_type.copy(self, non_blocking); \
   }
 
-AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CAST_OP)
+AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT(DEFINE_CAST_OP)
 
 #undef DEFINE_CAST_OP
 
@@ -732,7 +732,7 @@ Tensor tensor_cuda(ArrayRef<T> values, const TensorOptions& options) {
     return tensor_cpu(values, options); \
   } \
 }
-AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(TENSOR)
+AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(TENSOR)
 #undef TENSOR
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e1cd053..a72d9c4 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2990,6 +2990,26 @@
     CPU: dense_to_sparse
     CUDA: dense_to_sparse
 
+- func: quantize_linear(Tensor self, float scale, int zero_point) -> Tensor
+  matches_jit_signature: True
+  variants: function, method
+  requires_tensor: True
+
+- func: dequantize(Tensor self) -> Tensor
+  matches_jit_signature: True
+  variants: function, method
+  requires_tensor: True
+
+- func: q_scale(Tensor self) -> Scalar
+  matches_jit_signature: True
+  variants: function, method
+  requires_tensor: True
+
+- func: q_zero_point(Tensor self) -> Scalar
+  matches_jit_signature: True
+  variants: function, method
+  requires_tensor: True
+
 # to(Device) must not exist because all constructors of Device also work for
 # TensorOptions. Otherwise, an ambiguity error is thrown.
 # See NOTE [ TensorOptions Constructors ].
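The arithmetic behind these four signatures is plain affine quantization: q = clamp(round(x / scale) + zero_point, qmin, qmax) on the way in, and x' = (q - zero_point) * scale on the way out. A standalone sketch of that mapping, independent of ATen (the uint8 range mirrors the qint8 storage used by this patch):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// q = clamp(round(x / scale) + zero_point, 0, 255)
uint8_t quantize(float x, float scale, uint8_t zero_point) {
  int32_t q = zero_point + static_cast<int32_t>(std::nearbyint(x / scale));
  return static_cast<uint8_t>(std::min(255, std::max(0, q)));
}

// x' = (q - zero_point) * scale
float dequantize(uint8_t q, float scale, uint8_t zero_point) {
  return (static_cast<int32_t>(q) - zero_point) * scale;
}

int main() {
  const float scale = 0.5f;
  const uint8_t zero_point = 128;
  for (float x : {-64.0f, -0.3f, 0.0f, 0.25f, 63.5f}) {
    uint8_t q = quantize(x, scale, zero_point);
    std::printf("x=%7.2f -> q=%3u -> x'=%7.2f\n",
                x, q, dequantize(q, scale, zero_point));
  }
  return 0;
}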
diff --git a/aten/src/ATen/native/quantized/QTensor.cpp b/aten/src/ATen/native/quantized/QTensor.cpp
new file mode 100644
index 0000000..ef8e78b
--- /dev/null
+++ b/aten/src/ATen/native/quantized/QTensor.cpp
@@ -0,0 +1,36 @@
+#include
+#include
+#include
+#include
+
+
+namespace at {
+namespace native {
+
+QTensor quantize_linear(const RealTensor& self, double scale, int64_t zero_point) {
+  auto quantizer = make_per_tensor_affine_quantizer(scale, zero_point);
+  return quantizer->quantize(self);
+}
+
+RealTensor dequantize(const QTensor& self) {
+  return get_qtensorimpl(self)->quantizer()->dequantize(self);
+}
+
+Scalar q_scale(const QTensor& self) {
+  auto quantizer = get_qtensorimpl(self)->quantizer();
+  AT_ASSERT(quantizer->qscheme() == kPerTensorAffine);
+  return Scalar(static_cast<PerTensorAffineQuantizer*>(quantizer.get())->scale());
+}
+
+Scalar q_zero_point(const QTensor& self) {
+  auto quantizer = get_qtensorimpl(self)->quantizer();
+  AT_ASSERT(quantizer->qscheme() == kPerTensorAffine);
+  return Scalar(static_cast<PerTensorAffineQuantizer*>(quantizer.get())->zero_point());
+}
+
+Quantizer* quantizer(const QTensor& self) {
+  return get_qtensorimpl(self)->quantizer().get();
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native_parse.py b/aten/src/ATen/native_parse.py
index 6217769..fa946b7 100644
--- a/aten/src/ATen/native_parse.py
+++ b/aten/src/ATen/native_parse.py
@@ -343,7 +343,6 @@ def has_sparse_dispatches(dispatches):
             return True
     return False
 
-
 def parse_native_yaml(path):
     with open(path, 'r') as f:
         return yaml.load(f, Loader=Loader)
diff --git a/aten/src/ATen/preprocess_declarations.py b/aten/src/ATen/preprocess_declarations.py
index 59dcf6b..5dac87d 100644
--- a/aten/src/ATen/preprocess_declarations.py
+++ b/aten/src/ATen/preprocess_declarations.py
@@ -15,7 +15,8 @@ type_map = {
         'Short',
         'Int',
         'Long',
-        'Bool'
+        'Bool',
+        'QInt8',
     ],
 }
diff --git a/aten/src/ATen/quantized/CMakeLists.txt b/aten/src/ATen/quantized/CMakeLists.txt
new file mode 100644
index 0000000..1620919
--- /dev/null
+++ b/aten/src/ATen/quantized/CMakeLists.txt
@@ -0,0 +1,9 @@
+FILE(GLOB_RECURSE ATen_QUANTIZED_HEADERS "*.h")
+FILE(GLOB_RECURSE ATen_QUANTIZED_SRCS "*.cpp")
+FILE(GLOB_RECURSE ATen_QUANTIZED_TEST_SRCS "*_test.cpp")
+EXCLUDE(ATen_QUANTIZED_SRCS "${ATen_QUANTIZED_SRCS}" ${ATen_QUANTIZED_TEST_SRCS})
+
+# Pass to parent
+set(ATen_QUANTIZED_HEADERS ${ATen_QUANTIZED_HEADERS} PARENT_SCOPE)
+set(ATen_QUANTIZED_SRCS ${ATen_QUANTIZED_SRCS} PARENT_SCOPE)
+set(ATen_QUANTIZED_TEST_SRCS ${ATen_QUANTIZED_TEST_SRCS} PARENT_SCOPE)
diff --git a/aten/src/ATen/quantized/QTensorImpl.cpp b/aten/src/ATen/quantized/QTensorImpl.cpp
new file mode 100644
index 0000000..83f585e
--- /dev/null
+++ b/aten/src/ATen/quantized/QTensorImpl.cpp
@@ -0,0 +1,13 @@
+#include <ATen/quantized/QTensorImpl.h>
+
+namespace at {
+
+QTensorImpl::QTensorImpl(
+    Storage&& storage,
+    TensorTypeId type_id,
+    bool is_variable,
+    QuantizerPtr quantizer)
+    : TensorImpl(std::move(storage), type_id, is_variable),
+      quantizer_(quantizer) {}
+
+} // namespace at
diff --git a/aten/src/ATen/quantized/QTensorImpl.h b/aten/src/ATen/quantized/QTensorImpl.h
new file mode 100644
index 0000000..9f01246
--- /dev/null
+++ b/aten/src/ATen/quantized/QTensorImpl.h
@@ -0,0 +1,38 @@
+#pragma once
+
+#include
+#include
+#include
+
+namespace at {
+
+struct CAFFE2_API QTensorImpl : public c10::TensorImpl {
+ public:
+  QTensorImpl(
+      Storage&& storage,
+      TensorTypeId type_id,
+      bool is_variable,
+      QuantizerPtr quantizer);
+
+  // TODO: Expose in PyTorch Frontend
+  QuantizerPtr quantizer() {
+    return quantizer_;
+  }
+
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach() const override {
+    auto impl = c10::make_intrusive<QTensorImpl>(
+        Storage(storage()), type_id(), is_variable(), quantizer_);
+    impl->set_sizes_and_strides(sizes(), strides());
+    impl->storage_offset_ = storage_offset_;
+    impl->is_wrapped_number_ = is_wrapped_number_;
+    impl->reserved_ = reserved_;
+    impl->refresh_numel();
+    impl->refresh_contiguous();
+    return impl;
+  }
+
+ private:
+  QuantizerPtr quantizer_;
+};
+
+} // namespace at
diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp
new file mode 100644
index 0000000..eabc697
--- /dev/null
+++ b/aten/src/ATen/quantized/Quantizer.cpp
@@ -0,0 +1,106 @@
+#include <ATen/quantized/Quantizer.h>
+#include <ATen/quantized/QTensorImpl.h>
+#include
+#include
+#include
+#include
+
+namespace at {
+
+QuantizerPtr make_per_tensor_affine_quantizer(
+    double scale,
+    int64_t zero_point) {
+  return c10::make_intrusive<PerTensorAffineQuantizer>(
+      static_cast<float>(scale), static_cast<uint8_t>(zero_point));
+}
+
+QTensorImpl* get_qtensorimpl(const QTensor& self) {
+  // TODO: remove this when Variable and Tensor are merged
+  AT_ASSERTM(
+      !self.is_variable(),
+      "_internal_get_QTensorImpl: should not be a variable");
+  // TODO: uncomment after is_quantized() is implemented
+  // AT_ASSERTM(self.is_quantized(), "_internal_get_QTensorImpl: not a quantized
+  // tensor");
+  return static_cast<QTensorImpl*>(self.unsafeGetTensorImpl());
+}
+
+inline QTensor new_qtensor(
+    IntArrayRef sizes,
+    const TensorOptions& options,
+    bool is_variable,
+    QuantizerPtr quantizer) {
+  AT_ASSERT(options.device().is_cpu());
+
+  native::check_size_nonnegative(sizes);
+  auto* allocator = at::getCPUAllocator();
+  int64_t nelements = at::prod_intlist(sizes);
+  auto dtype = options.dtype();
+  AT_ASSERT(isQIntType(typeMetaToScalarType(dtype)));
+  auto storage_impl = c10::make_intrusive<StorageImpl>(
+      dtype,
+      nelements,
+      allocator->allocate(nelements * dtype.itemsize()),
+      allocator,
+      /*resizable=*/true);
+  auto tensor = detail::make_tensor<QTensorImpl>(
+      storage_impl, at::CPUTensorId(), is_variable, quantizer);
+  get_qtensorimpl(tensor)->set_sizes_contiguous(sizes);
+  return tensor;
+}
+
+qint8 quantize_uint8(float scale, uint8_t zero_point, float value) {
+  const int32_t qmin = std::numeric_limits<uint8_t>::min();
+  const int32_t qmax = std::numeric_limits<uint8_t>::max();
+
+  // std::nearbyint returns the nearest integer value according to the current
+  // rounding mode; the default mode rounds to even in half-way cases on
+  // most popular processor architectures like x86 and ARM. This is
+  // typically faster than alternatives like std::round, which rounds half-way
+  // cases away from zero, and it can be consistent with SIMD implementations,
+  // for example in x86 using _mm512_cvtps_epi32, or _mm512_round_ps with the
+  // _MM_FROUND_CUR_DIRECTION option, which also follow the current rounding
+  // mode.
+  int32_t r = zero_point + static_cast<int32_t>(std::nearbyint(value / scale));
+  r = std::max(r, qmin);
+  r = std::min(r, qmax);
+  return static_cast<qint8>(r);
+}
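The rounding claim in the comment above is easy to check: under the default FE_TONEAREST mode, std::nearbyint breaks halfway cases to even, while std::round breaks them away from zero. A small standalone demonstration:

#include <cmath>
#include <cstdio>

int main() {
  // The default floating-point rounding mode is FE_TONEAREST, under which
  // std::nearbyint sends halfway cases to the nearest even integer, while
  // std::round always sends them away from zero.
  for (double v : {0.5, 1.5, 2.5, -0.5, -1.5}) {
    std::printf("v=%5.1f  nearbyint=%5.1f  round=%5.1f\n",
                v, std::nearbyint(v), std::round(v));
  }
  // nearbyint: 0, 2, 2, -0, -2    round: 1, 2, 3, -1, -2
  return 0;
}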
+
+QTensor PerTensorAffineQuantizer::quantize(RealTensor tensor) {
+  IntArrayRef sizes = tensor.sizes();
+  // Here we need a std::intrusive_ptr<Quantizer>.. but actually "this" is the
+  // quantizer that can be reused, so I'm using intrusive_from_this here
+  AT_CHECK(
+      tensor.options().device() == kCPU,
+      "quantize only works for CPU backend right now.");
+  QTensor qv = new_qtensor(
+      sizes,
+      tensor.options().dtype(at::kQInt8),
+      tensor.is_variable(),
+      intrusive_from_this());
+  auto qvd = qv.data<qint8>();
+  tensor.contiguous();
+  const float* svd = tensor.data<float>();
+  for (int i = 0; i < tensor.numel(); ++i) {
+    qvd[i] = quantize_uint8(scale_, zero_point_, svd[i]);
+  }
+  return qv;
+}
+
+RealTensor PerTensorAffineQuantizer::dequantize(QTensor tensor) {
+  std::vector<int64_t> sizes = tensor.sizes().vec();
+  at::TensorOptions real_options = tensor.options().dtype(at::kFloat);
+
+  RealTensor rv = at::empty(sizes, real_options);
+  tensor.contiguous();
+  const auto* qvd = tensor.data<qint8>();
+  float* rvd = rv.data<float>();
+  for (auto i = 0; i < tensor.numel(); ++i) {
+    rvd[i] = (static_cast<float>(qvd[i].val_) - zero_point_) * scale_;
+  }
+  return rv;
+}
+
+Quantizer::~Quantizer() {}
+
+} // namespace at
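A property of the quantize/dequantize pair above worth keeping in mind: for inputs inside the representable range [(qmin - zero_point) * scale, (qmax - zero_point) * scale], the round trip lands within half a quantization step of the original value. (Note, incidentally, that tensor.contiguous() returns a new tensor rather than modifying in place; the standalone check below, with an arbitrary scale and zero_point, sidesteps tensors entirely.)

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const float scale = 0.1f;
  const uint8_t zero_point = 30;
  // Representable range for this (scale, zero_point):
  // [(0 - 30) * 0.1, (255 - 30) * 0.1] = [-3.0, 22.5]
  for (float x = -3.0f; x <= 22.5f; x += 0.0173f) {
    int32_t q = zero_point + static_cast<int32_t>(std::nearbyint(x / scale));
    q = std::min(255, std::max(0, q));
    float back = (q - zero_point) * scale;
    // Round trip stays within half a quantization step (plus float slack).
    assert(std::fabs(back - x) <= scale / 2 + 1e-4f);
  }
  return 0;
}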
diff --git a/aten/src/ATen/quantized/Quantizer.h b/aten/src/ATen/quantized/Quantizer.h
new file mode 100644
index 0000000..931d489
--- /dev/null
+++ b/aten/src/ATen/quantized/Quantizer.h
@@ -0,0 +1,251 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+// TODO: move to c10 namespace after we
+// unified caffe2::Tensor and at::Tensor
+
+namespace at {
+
+struct QTensorImpl;
+
+using QTensor = Tensor;
+using RealTensor = Tensor;
+
+struct Quantizer;
+using QuantizerPtr = c10::intrusive_ptr<Quantizer>;
+
+/**
+ * Quantizer is the class for storing all the information
+ * that's necessary to perform quantize and dequantize
+ * operations.
+ *
+ * We might have different types of quantization schemes and this is
+ * the base class for all quantizers.
+ *
+ * QTensorImpl will hold a pointer to Quantizer so that we can support
+ * different quantization schemes on Tensor.
+ *
+ * For example, the most common quantization scheme, Affine Quantization,
+ * requires scale and zero_point as parameters; we store scale and zero_point
+ * inside the instance and use them to quantize a float Tensor or
+ * dequantize a quantized Tensor.
+ *
+ * When you add new types of leaf Quantizer classes, please also
+ * make sure to add a corresponding QScheme enum since
+ * they should have a one-to-one mapping.
+ *
+ * Note about intrusive_ptr:
+ * QTensor holds an intrusive_ptr to Quantizer, and multiple Tensors can
+ * share the same Quantizer. Quantizer should be immutable.
+ */
+struct CAFFE2_API Quantizer : public c10::intrusive_ptr_target {
+  const QScheme qscheme_;
+  explicit Quantizer(QScheme qscheme) : qscheme_(qscheme) {}
+  virtual ~Quantizer();
+
+  // Copied from torch/csrc/jit/scope.h
+  QuantizerPtr intrusive_from_this() {
+    c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
+                                           // from a raw `this` pointer
+                                           // so we need to bump the refcount
+                                           // to account for this ownership
+    return c10::intrusive_ptr<Quantizer>::reclaim(this);
+  }
+
+  virtual QScheme qscheme() {
+    return qscheme_;
+  }
+
+  /**
+   * quantize a float Tensor into a quantized Tensor.
+   */
+  virtual QTensor quantize(RealTensor t) = 0;
+
+  /**
+   * dequantize a quantized Tensor into a float Tensor.
+   */
+  virtual RealTensor dequantize(QTensor t) = 0;
+};
+
+/**
+ * UniformQuantizer is the parent class for all uniform quantizers.
+ * These quantization schemes map float values uniformly to
+ * the quantized value. For example, the affine quantizer is
+ * the most commonly used scheme in this category.
+ */
+struct CAFFE2_API UniformQuantizer : public Quantizer {
+  explicit UniformQuantizer(QScheme qscheme) : Quantizer(qscheme) {}
+};
+
+/**
+ * NonUniformQuantizer is the parent class for all non-uniform quantizers.
+ * These quantization schemes may map float values non-uniformly to the
+ * quantized value. K-means quantization is a representative example in this
+ * category.
+ */
+struct CAFFE2_API NonUniformQuantizer : public Quantizer {
+  explicit NonUniformQuantizer(QScheme qscheme) : Quantizer(qscheme) {}
+};
+
+// There is also StochasticQuantizer which is uniform but not affine
+
+/**
+ * AffineQuantizer uses an affine transformation to do quantization.
+ *
+ * For quantize:
+ * Y = clamp(round(X / scale) + zero_point, min, max)
+ * For dequantize:
+ * X = (Y - zero_point) * scale
+ */
+struct CAFFE2_API AffineQuantizer : public UniformQuantizer {
+  explicit AffineQuantizer(QScheme qscheme) : UniformQuantizer(qscheme) {}
+};
+
+/**
+ * SymmetricQuantizer is similar to AffineQuantizer except that it
+ * does not have a zero_point.
+ *
+ * For quantize:
+ * Y = clamp(round(X / scale), min, max)
+ * For dequantize:
+ * X = Y * scale
+ */
+struct CAFFE2_API SymmetricQuantizer : public UniformQuantizer {
+  explicit SymmetricQuantizer(QScheme qscheme) : UniformQuantizer(qscheme) {}
+};
+
+/**
+ * PerTensorSymmetricQuantizer stores a single scale number which is
+ * used for quantizing all the values in the given Tensor
+ */
+struct CAFFE2_API PerTensorSymmetricQuantizer : public SymmetricQuantizer {
+  explicit PerTensorSymmetricQuantizer(float scale)
+      : SymmetricQuantizer(kPerTensorSymmetric), scale_(scale) {}
+  float scale_{1.0};
+};
+
+/**
+ * PerChannelSymmetricQuantizer stores a vector of scale numbers and
+ * applies symmetric quantization using a different scale on each channel.
+ *
+ * Also note that per-channel quantization is mostly applied to the output
+ * channels of weights, since per-input-channel weight quantization or
+ * per-channel quantization of activations can't be efficiently supported
+ * on most processors: it would require each multiplication result within
+ * a single dot product to have a different scale.
+ */
+struct CAFFE2_API PerChannelSymmetricQuantizer : public SymmetricQuantizer {
+  explicit PerChannelSymmetricQuantizer(
+      const std::vector<float>& scales,
+      const std::vector<int64_t>& axis)
+      : SymmetricQuantizer(kPerChannelSymmetric), scales_(scales), axis_(axis) {
+    AT_CHECK(
+        axis_.size() == 1,
+        "Per channel symmetric quantization in multiple axis is not supported yet.");
+  }
+
+  std::vector<float> scales() const {
+    return scales_;
+  }
+
+  std::vector<int64_t> axis() const {
+    return axis_;
+  }
+
+ private:
+  const std::vector<float> scales_;
+  const std::vector<int64_t> axis_;
+};
+
+/**
+ * PerTensorAffineQuantizer stores a scale and a zero_point, which is used for
+ * all the values in the Tensor.
+ */
+struct CAFFE2_API PerTensorAffineQuantizer : public AffineQuantizer {
+  explicit PerTensorAffineQuantizer(float scale, uint8_t zero_point)
+      : AffineQuantizer(kPerTensorAffine),
+        scale_(scale),
+        zero_point_(zero_point) {}
+
+  QTensor quantize(RealTensor tensor) override;
+  RealTensor dequantize(QTensor tensor) override;
+
+  float scale() const {
+    return scale_;
+  }
+
+  uint8_t zero_point() const {
+    return zero_point_;
+  }
+
+ private:
+  const float scale_;
+  const uint8_t zero_point_;
+};
+
+/**
+ * PerChannelAffineQuantizer is the same as PerTensorAffineQuantizer
+ * except that we have an independent scale and zero_point parameter
+ * for each channel.
+ */
+struct CAFFE2_API PerChannelAffineQuantizer : public AffineQuantizer {
+  explicit PerChannelAffineQuantizer(
+      const std::vector<float>& scales,
+      const std::vector<uint8_t>& zero_points,
+      const std::vector<int64_t>& axis)
+      : AffineQuantizer(kPerChannelAffine),
+        scales_(scales),
+        zero_points_(zero_points),
+        axis_(axis) {
+    AT_CHECK(
+        axis_.size() == 1,
+        "Per channel affine quantization in multiple axis is not supported yet.");
+  }
+
+  std::vector<float> scales() const {
+    return scales_;
+  }
+
+  std::vector<uint8_t> zero_points() const {
+    return zero_points_;
+  }
+
+  std::vector<int64_t> axis() const {
+    return axis_;
+  }
+
+ private:
+  const std::vector<float> scales_;
+  const std::vector<uint8_t> zero_points_;
+  const std::vector<int64_t> axis_;
+};
+
+// This is an internal utility function for getting at the QTensorImpl.
+// You should only use this for writing low level
+// setters/getters for QTensorImpl fields; otherwise, you should use
+// the low level setters/getters that were implemented using this.
+// This may be called repeatedly, so make sure it's pretty cheap.
+CAFFE2_API QTensorImpl* get_qtensorimpl(const QTensor& self);
+
+// Quantize a float value into a uint8 value given scale and zero_point
+CAFFE2_API qint8 quantize_uint8(float scale, uint8_t zero_point, float value);
+
+// double and int64_t are used because of the native function API; these are
+// the only argument types we have right now in native functions
+CAFFE2_API QuantizerPtr
+make_per_tensor_affine_quantizer(double scale, int64_t zero_point);
+
+// Create a QTensor given arguments for a normal Tensor and a quantizer
+QTensor new_qtensor(
+    IntArrayRef sizes,
+    const TensorOptions& options,
+    bool is_variable,
+    QuantizerPtr quantizer);
+
+} // namespace at
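To see how the hierarchy is meant to grow, here is a hypothetical leaf quantizer written against this header: per-tensor symmetric int8, i.e. the affine scheme with zero_point pinned to 0. Everything beyond the header's own names is illustrative, and it assumes new_qtensor is linkable from other translation units (in this patch it is declared here but defined inline in Quantizer.cpp, so that definition would need to move for this to link):

#include <ATen/ATen.h>
#include <ATen/quantized/Quantizer.h>

namespace at {

// Hypothetical leaf quantizer (illustrative only, not part of this patch):
// per-tensor symmetric quantization, i.e. affine with zero_point fixed at 0.
struct MyPerTensorSymmetricQuantizer : public SymmetricQuantizer {
  explicit MyPerTensorSymmetricQuantizer(float scale)
      : SymmetricQuantizer(kPerTensorSymmetric), scale_(scale) {}

  QTensor quantize(RealTensor t) override {
    QTensor q = new_qtensor(
        t.sizes(), t.options().dtype(at::kQInt8), t.is_variable(),
        intrusive_from_this());
    // Keep the contiguous copy alive while we read through its pointer.
    RealTensor tc = t.contiguous();
    const float* fd = tc.data<float>();
    auto* qd = q.data<qint8>();
    for (int64_t i = 0; i < t.numel(); ++i) {
      // Reuse the helper from this patch with zero_point = 0.
      qd[i] = quantize_uint8(scale_, /*zero_point=*/0, fd[i]);
    }
    return q;
  }

  RealTensor dequantize(QTensor t) override {
    RealTensor r = at::empty(t.sizes(), t.options().dtype(at::kFloat));
    QTensor tc = t.contiguous();
    const auto* qd = tc.data<qint8>();
    float* rd = r.data<float>();
    for (int64_t i = 0; i < t.numel(); ++i) {
      rd[i] = static_cast<float>(qd[i].val_) * scale_;  // x = q * scale
    }
    return r;
  }

  const float scale_;
};

} // namespace at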
diff --git a/aten/src/ATen/templates/TensorMethods.h b/aten/src/ATen/templates/TensorMethods.h
index 5dea983..bbaa19b 100644
--- a/aten/src/ATen/templates/TensorMethods.h
+++ b/aten/src/ATen/templates/TensorMethods.h
@@ -133,7 +133,10 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST)
     return item().to##name(); \
   }
 
-AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ITEM)
+AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(DEFINE_ITEM)
 #undef DEFINE_ITEM
 
+// TODO: after is_quantized() is implemented,
+// implement item() (returning a float) for quantized Tensor
+
 } //namespace at
diff --git a/aten/src/ATen/templates/Type.h b/aten/src/ATen/templates/Type.h
index 52f9eb2..d1d2a28 100644
--- a/aten/src/ATen/templates/Type.h
+++ b/aten/src/ATen/templates/Type.h
@@ -40,6 +40,7 @@ using TensorList = ArrayRef<Tensor>;
 
 class Context;
 struct Generator;
+struct Quantizer;
 
 static inline void noop_deleter(void*) {}
diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index 4e8b955..bcee7ef 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -21,6 +21,7 @@ list(APPEND ATen_CPU_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized_test.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp)
diff --git a/aten/src/ATen/test/quantized_test.cpp b/aten/src/ATen/test/quantized_test.cpp
new file mode 100644
index 0000000..6561448
--- /dev/null
+++ b/aten/src/ATen/test/quantized_test.cpp
@@ -0,0 +1,43 @@
+#include <gtest/gtest.h>
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+// For quantize_uint8
+#include <ATen/quantized/Quantizer.h>
+
+using namespace at;
+
+TEST(TestQTensor, QuantDequantAPIs) {
+  auto num_elements = 10;
+  Tensor r = at::ones({num_elements});
+  const float scale = 1.0;
+  const int32_t zero_point = 2;
+  Tensor qr = r.quantize_linear(scale, zero_point);
+  ASSERT_EQ(qr.q_scale().to<float>(), scale);
+  ASSERT_EQ(qr.q_zero_point().to<int32_t>(), zero_point);
+
+  // TODO: Uncomment when quantizer is ready
+  // auto* quantizer = static_cast<PerTensorAffineQuantizer*>(qr.quantizer());
+  // ASSERT_EQ(quantizer->scale(), scale);
+  // ASSERT_EQ(quantizer->zero_point(), zero_point);
+
+  // Check for correct quantization
+  auto r_data = r.data<float>();
+  auto qr_data = qr.data<qint8>();
+  for (auto i = 0; i < num_elements; ++i) {
+    ASSERT_EQ(
+        quantize_uint8(scale, zero_point, r_data[i]).val_, qr_data[i].val_);
+  }
+
+  // Check for correct dequantization
+  Tensor rqr = qr.dequantize();
+  auto rqr_data = rqr.data<float>();
+  for (auto i = 0; i < num_elements; ++i) {
+    ASSERT_EQ(r_data[i], rqr_data[i]);
+  }
+}
diff --git a/aten/src/TH/THBlasUtils.h b/aten/src/TH/THBlasUtils.h
index de9ddea..86efc40 100644
--- a/aten/src/TH/THBlasUtils.h
+++ b/aten/src/TH/THBlasUtils.h
@@ -16,7 +16,7 @@ inline void THBlas_axpy(int64_t n, T a, T *x, int64_t incx, T *y, int64_t incy);
    TH ## name ## Blas_axpy(n, a, x, incx, y, incy); \
  }
 
-AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(AXPY_SPECIALIZATION)
+AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(AXPY_SPECIALIZATION)
 
 template <typename T>
@@ -29,4 +29,4 @@ inline void THBlas_copy(int64_t n, T *x, int64_t incx, T *y, int64_t incy);
    TH ## name ## Blas_copy(n, x, incx, y, incy); \
  }
 
-AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(COPY_SPECIALIZATION)
+AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(COPY_SPECIALIZATION)
diff --git a/c10/core/QScheme.h b/c10/core/QScheme.h
new file mode 100644
index 0000000..4757a77
--- /dev/null
+++ b/c10/core/QScheme.h
@@ -0,0 +1,24 @@
+#pragma once
+
+namespace c10 {
+
+/**
+ * QScheme is an enum that specifies the type of quantization. This has a
+ * one-to-one correspondence with Quantizer.
+ * Please refer to ATen/quantized/Quantizer.h to see the Quantizer classes.
+ */
+enum class QScheme : uint8_t {
+  NO_QUANT,
+  PER_TENSOR_AFFINE,
+  PER_CHANNEL_AFFINE,
+  PER_TENSOR_SYMMETRIC,
+  PER_CHANNEL_SYMMETRIC
+};
+
+constexpr auto kNoQuant = QScheme::NO_QUANT;
+constexpr auto kPerTensorAffine = QScheme::PER_TENSOR_AFFINE;
+constexpr auto kPerChannelAffine = QScheme::PER_CHANNEL_AFFINE;
+constexpr auto kPerTensorSymmetric = QScheme::PER_TENSOR_SYMMETRIC;
+constexpr auto kPerChannelSymmetric = QScheme::PER_CHANNEL_SYMMETRIC;
+
+} // namespace c10
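Since QScheme and the Quantizer leaves are meant to map one-to-one, a natural consumer is a factory switch. A hypothetical helper along those lines (illustrative only; in this patch, only the per-tensor affine case is actually constructible):

#include <ATen/quantized/Quantizer.h>
#include <c10/core/QScheme.h>
#include <c10/util/Exception.h>

// Hypothetical factory: dispatch on QScheme to build the matching quantizer.
// Only the per-tensor affine factory exists in this patch, so every other
// scheme is rejected.
at::QuantizerPtr make_quantizer(
    c10::QScheme qscheme, double scale, int64_t zero_point) {
  switch (qscheme) {
    case c10::kPerTensorAffine:
      return at::make_per_tensor_affine_quantizer(scale, zero_point);
    case c10::kPerTensorSymmetric:
    case c10::kPerChannelAffine:
    case c10::kPerChannelSymmetric:
    case c10::kNoQuant:
    default:
      AT_ERROR("no quantizer factory for this QScheme yet");
  }
}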
diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h
index 7952e86..23cd5a0 100644
--- a/c10/core/Scalar.h
+++ b/c10/core/Scalar.h
@@ -31,7 +31,7 @@ class C10_API Scalar {
   // We can't set v in the initializer list using the
   // syntax v{ .member = ... } because it doesn't work on MSVC
 
-  AT_FORALL_SCALAR_TYPES(DEFINE_IMPLICIT_CTOR)
+  AT_FORALL_SCALAR_TYPES_EXCEPT_QINT(DEFINE_IMPLICIT_CTOR)
 
 #undef DEFINE_IMPLICIT_CTOR
 
@@ -71,7 +71,8 @@ class C10_API Scalar {
   }
 
   // TODO: Support ComplexHalf accessor
-  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR)
+  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(
+      DEFINE_ACCESSOR)
 
   // also support scalar.to<T>();
   template <typename T>
@@ -114,6 +115,6 @@ inline T Scalar::to() {
   inline T Scalar::to() { \
     return to##name(); \
   }
-AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO)
+AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(DEFINE_TO)
 #undef DEFINE_TO
 } // namespace c10
diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h
index 11549e1..99e7975 100644
--- a/c10/core/ScalarType.h
+++ b/c10/core/ScalarType.h
@@ -11,6 +11,9 @@
 namespace c10 {
 
+// TODO: check all usages of these macros and make sure
+// the use case makes sense for qint
+
 // NB: Order matters for this macro; it is relied upon in
 // _promoteTypesLookup and the serialization format.
 #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
@@ -25,7 +28,8 @@ namespace c10 {
   _(at::ComplexHalf, ComplexHalf, z)        /* 8 */ \
   _(std::complex<float>, ComplexFloat, z)   /* 9 */ \
   _(std::complex<double>, ComplexDouble, z) /* 10 */ \
-  _(bool, Bool, i) /* 11 */
+  _(bool, Bool, i) /* 11 */ \
+  _(c10::qint8, QInt8, i) /* 12 */
 
 // If you want to support ComplexHalf for real, replace occurrences
 // of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But
@@ -41,6 +45,20 @@ namespace c10 {
   _(double, Double, d) \
   _(std::complex<float>, ComplexFloat, z) \
   _(std::complex<double>, ComplexDouble, z) \
+  _(bool, Bool, i) \
+  _(c10::qint8, QInt8, i)
+
+#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_AND_QINT(_) \
+  _(uint8_t, Byte, i) \
+  _(int8_t, Char, i) \
+  _(int16_t, Short, i) \
+  _(int, Int, i) \
+  _(int64_t, Long, i) \
+  _(at::Half, Half, d) \
+  _(float, Float, d) \
+  _(double, Double, d) \
+  _(std::complex<float>, ComplexFloat, z) \
+  _(std::complex<double>, ComplexDouble, z) \
   _(bool, Bool, i)
 
 #define AT_FORALL_SCALAR_TYPES(_) \
@@ -51,17 +69,28 @@ namespace c10 {
   _(int64_t, Long, i) \
   _(at::Half, Half, d) \
   _(float, Float, d) \
+  _(double, Double, d) \
+  _(c10::qint8, QInt8, i)
+
+#define AT_FORALL_SCALAR_TYPES_EXCEPT_QINT(_) \
+  _(uint8_t, Byte, i) \
+  _(int8_t, Char, i) \
+  _(int16_t, Short, i) \
+  _(int, Int, i) \
+  _(int64_t, Long, i) \
+  _(at::Half, Half, d) \
+  _(float, Float, d) \
   _(double, Double, d)
 
-#define AT_FORALL_SCALAR_TYPES_AND_BOOL(_) \
-  _(uint8_t, Byte, i) \
-  _(int8_t, Char, i) \
-  _(int16_t, Short, i) \
-  _(int, Int, i) \
-  _(int64_t, Long, i) \
-  _(at::Half, Half, d) \
-  _(float, Float, d) \
-  _(double, Double, d) \
+#define AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT(_) \
+  _(uint8_t, Byte, i) \
+  _(int8_t, Char, i) \
+  _(int16_t, Short, i) \
+  _(int, Int, i) \
+  _(int64_t, Long, i) \
+  _(at::Half, Half, d) \
+  _(float, Float, d) \
+  _(double, Double, d) \
   _(bool, Bool, i)
 
 #define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \
@@ -71,6 +100,16 @@ namespace c10 {
   _(int, Int, i) \
   _(int64_t, Long, i) \
   _(float, Float, d) \
+  _(double, Double, d) \
+  _(c10::qint8, QInt8, i)
+
+#define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(_) \
+  _(uint8_t, Byte, i) \
+  _(int8_t, Char, i) \
+  _(int16_t, Short, i) \
+  _(int, Int, i) \
+  _(int64_t, Long, i) \
+  _(float, Float, d) \
   _(double, Double, d)
 
 enum class ScalarType : int8_t {
@@ -182,6 +221,11 @@ static inline bool isComplexType(ScalarType t) {
       t == ScalarType::ComplexDouble);
 }
 
+static inline bool isQIntType(ScalarType t) {
+  // Don't forget to extend this when adding new QInt types
+  return t == ScalarType::QInt8;
+}
+
 static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
   // This is generated according to NumPy's promote_types
   constexpr auto u1 = ScalarType::Byte;
@@ -202,6 +246,11 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
         "promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
   }
 
+  if (isQIntType(a) || isQIntType(b)) {
+    AT_ERROR(
+        "promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be");
+  }
+
   // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX
   // so that's why we have to add undefined as we are not sure what is the
   // correct values for the type promotions in complex type cases.
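These AT_FORALL_* lists are X-macros: a caller defines a three-argument macro and applies the list to it, which is how DEFINE_CAST, DEFINE_ITEM, DEFINE_IMPLICIT_CTOR, and the switch cases elsewhere in this patch are stamped out, and why including or excluding QInt8 is a matter of choosing a list rather than editing every call site. A self-contained miniature of the pattern:

#include <cstdint>
#include <cstdio>

// A miniature of the AT_FORALL_SCALAR_TYPES X-macro pattern: the list applies
// a caller-supplied macro to each (ctype, Name, kind) row.
#define MY_FORALL_TYPES(_) \
  _(uint8_t, Byte, i)      \
  _(float, Float, d)       \
  _(double, Double, d)

// Stamp out one function per row, the way DEFINE_CAST / DEFINE_ITEM do.
#define DEFINE_PRINTER(ctype, name, kind)                                  \
  void print_##name() {                                                    \
    std::printf(#name ": %zu byte(s), kind '%s'\n", sizeof(ctype), #kind); \
  }
MY_FORALL_TYPES(DEFINE_PRINTER)
#undef DEFINE_PRINTER

int main() {
  print_Byte();   // Byte: 1 byte(s), kind 'i'
  print_Float();  // Float: 4 byte(s), kind 'd'
  print_Double(); // Double: 8 byte(s), kind 'd'
  return 0;
}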
diff --git a/c10/util/Half.cpp b/c10/util/Half.cpp
index 76e36ea..24b7576 100644
--- a/c10/util/Half.cpp
+++ b/c10/util/Half.cpp
@@ -1,5 +1,4 @@
 #include <c10/util/Half.h>
-
 #include <iostream>
 
 namespace c10 {
@@ -12,5 +11,4 @@ std::ostream& operator<<(std::ostream& out, const Half& value) {
   out << (float)value;
   return out;
 }
-
 } // namespace c10
diff --git a/c10/util/qint8.h b/c10/util/qint8.h
new file mode 100644
index 0000000..e6c3721
--- /dev/null
+++ b/c10/util/qint8.h
@@ -0,0 +1,16 @@
+#pragma once
+#include <cstdint>
+
+namespace c10 {
+
+/**
+ * This is the data type for quantized Tensors. Right now we only have
+ * qint8, which is for 8-bit Tensors; we might add 4-bit, 2-bit, or 1-bit
+ * data types in the future.
+ */
+struct alignas(1) qint8 {
+  uint8_t val_;
+  explicit qint8(uint8_t val) : val_(val) {}
+};
+
+} // namespace c10
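The one-byte wrapper exists so that qint8 is a distinct C++ type from uint8_t: TypeMeta registration and overload-based dispatch can then tell quantized bytes apart from plain bytes while the storage footprint stays identical. A small standalone check of that reading (assumes the header path added by this patch):

#include <cstdint>
#include <type_traits>
#include <c10/util/qint8.h>

static_assert(sizeof(c10::qint8) == 1, "same storage footprint as uint8_t");
static_assert(!std::is_same<c10::qint8, uint8_t>::value,
              "but a distinct type, so overloads/TypeMeta can dispatch on it");

int main() {
  c10::qint8 q(200);     // construction is explicit: no silent
  uint8_t raw = q.val_;  // arithmetic conversions from integers
  return raw == 200 ? 0 : 1;
}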
diff --git a/c10/util/typeid.cpp b/c10/util/typeid.cpp
index 6253c83..ad62df5 100644
--- a/c10/util/typeid.cpp
+++ b/c10/util/typeid.cpp
@@ -80,6 +80,8 @@ CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(
     26,
     detail::_guard_long_unique<std::vector<long>>);
 
-CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(27, _CaffeHighestPreallocatedTypeId)
+CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(27, c10::qint8);
+
+CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(28, _CaffeHighestPreallocatedTypeId)
 
 } // namespace caffe2
diff --git a/c10/util/typeid.h b/c10/util/typeid.h
index 9cf9d92..8019714 100644
--- a/c10/util/typeid.h
+++ b/c10/util/typeid.h
@@ -23,6 +23,7 @@
 #include "c10/util/Exception.h"
 #include "c10/util/Half.h"
 #include "c10/util/IdWrapper.h"
+#include "c10/util/qint8.h"
 #include "c10/util/Type.h"
 
@@ -622,5 +623,7 @@ CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(
     26,
     detail::_guard_long_unique<std::vector<long>>)
 
-CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(27, _CaffeHighestPreallocatedTypeId)
+CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(27, c10::qint8);
+
+CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(28, _CaffeHighestPreallocatedTypeId)
 } // namespace caffe2
diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h
index 9354cd7..b597084 100644
--- a/caffe2/contrib/aten/aten_op_template.h
+++ b/caffe2/contrib/aten/aten_op_template.h
@@ -15,7 +15,7 @@ static std::unordered_map<std::string, int> op_to_key = {
 
 namespace caffe2 {
 
-using at::Half; // for AT_FORALL_SCALAR_TYPES_AND_BOOL
+using at::Half; // for AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT
 
 template <class Context>
 class ATenOp : public Operator<Context> {
@@ -47,7 +47,7 @@
       case at::k##aten_name: \
         return TypeMeta::Make<ctype>();
     switch(st) {
-      AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CASE)
+      AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT(DEFINE_CASE)
       default:
         CAFFE_THROW("Unknown ATen Type");
     }
@@ -114,10 +114,10 @@
     }
   }
 
-  // the AT_FORALL_SCALAR_TYPES_AND_BOOL macro just gives a 'i' or 'd' argument
-  // for each type to specify if it is stored as a integer or a double.
-  // We need this workaround here to extract the value in the scalar losslessly
-  // because in some cases like 'sum' Torch promotes float to double
+  // the AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT macro just gives an 'i' or
+  // 'd' argument for each type to specify if it is stored as an integer or a
+  // double. We need this workaround here to extract the value in the scalar
+  // losslessly because in some cases like 'sum' Torch promotes float to double
   // and will complain if we downcast it with toFloat, causing it
   // to lose precision
   double extract_d(const at::Scalar & s) {
@@ -134,8 +134,8 @@
       auto value = extract_##native(scalar); \
       assignToValue(dst, at::convert(value)); \
     } break;
-      AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CASE)
-      #undef DEFINE_CASE
+      AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT(DEFINE_CASE)
+#undef DEFINE_CASE
       default:
         CAFFE_THROW("Unknown ATen Type");
     }
diff --git a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc
index 1fa560e..3f0b2f3 100644
--- a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc
@@ -80,7 +80,7 @@ void cast_op_cpu(
     int64_t to) {
   switch (input.scalar_type()) {
 #define CASE(ctype,name,_2) case ScalarType:: name : return cast_op_cpu_impl<ctype>(input, output, to);
-    AT_FORALL_SCALAR_TYPES_AND_BOOL(CASE)
+    AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT(CASE)
 #undef CASE
     default: throw std::runtime_error(string() + "Unsupported scalar type " + toString(input.scalar_type()));
   }
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index b9ad547..68d6997 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -205,6 +205,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: cumprod
    .. automethod:: cumsum
    .. automethod:: data_ptr
+   .. automethod:: dequantize
    .. automethod:: det
    .. automethod:: dense_dim
    .. automethod:: detach
@@ -354,6 +355,9 @@ view of a storage and defines numeric operations on it.
    .. automethod:: pstrf
    .. automethod:: put_
    .. automethod:: qr
+   .. automethod:: quantize_linear
+   .. automethod:: q_scale
+   .. automethod:: q_zero_point
    .. automethod:: random_
    .. automethod:: reciprocal
    .. automethod:: reciprocal_
diff --git a/test/test_torch.py b/test/test_torch.py
index ffc8943..ef21ea6 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2653,6 +2653,18 @@ class _TestTorchMixin(object):
         if TEST_NUMPY:
             assertEqual('cuda:1', lambda: torch.tensor(np.random.randn(2, 3), device='cuda:1'))
 
+    def test_qtensor(self):
+        num_elements = 10
+        r = torch.ones(num_elements, dtype=torch.float)
+        scale = 1.0
+        zero_point = 2
+        qr = r.quantize_linear(scale, zero_point)
+        self.assertEqual(qr.q_scale(), scale)
+        self.assertEqual(qr.q_zero_point(), zero_point)
+        rqr = qr.dequantize()
+        for i in range(num_elements):
+            self.assertEqual(r[i], rqr[i])
+
     def test_to(self):
         def test_copy_behavior(t, non_blocking=False):
             self.assertIs(t, t.to(t, non_blocking=non_blocking))
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 9f6ccf8..58fb3b2 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -149,7 +149,7 @@ SUPPORTED_RETURN_TYPES = {
    'std::tuple',
    'std::tuple',
    'std::vector<Tensor>',
-    'Scalar', 'bool', 'int64_t', 'void*', 'void'
+    'Scalar', 'bool', 'int64_t', 'void*', 'void',
}

TENSOR_OPTIONS = CodeTemplate("""\
diff --git a/tools/autograd/templates/variable_factories.h b/tools/autograd/templates/variable_factories.h
index 7fbcea2..29b1e55 100644
--- a/tools/autograd/templates/variable_factories.h
+++ b/tools/autograd/templates/variable_factories.h
@@ -4,6 +4,7 @@
 #include
 #include
+#include
 #include
 #include
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 55ff438..1b79a44 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -671,6 +671,13 @@ data_ptr() -> int
 Returns the address of the first element of :attr:`self` tensor.
 """)
 
+add_docstr_all('dequantize',
+               r"""
+dequantize() -> Tensor
+
+Given a quantized Tensor, dequantize it and return the dequantized float Tensor.
+""")
+
 add_docstr_all('dense_dim',
                r"""
 dense_dim() -> int
@@ -1775,6 +1782,31 @@ qr() -> (Tensor, Tensor)
 See :func:`torch.qr`
 """)
 
+add_docstr_all('quantize_linear',
+               r"""
+quantize_linear(scale, zero_point) -> Tensor
+
+Quantize a float Tensor using an affine quantization scheme with the given
+scale and zero_point. Returns the quantized Tensor.
+""")
+
+add_docstr_all('q_scale',
+               r"""
+q_scale() -> float
+
+Given a Tensor quantized by linear (affine) quantization,
+returns the scale of the underlying quantizer().
+""")
+
+add_docstr_all('q_zero_point',
+               r"""
+q_zero_point() -> int
+
+Given a Tensor quantized by linear (affine) quantization,
+returns the zero_point of the underlying quantizer().
+""")
+
 add_docstr_all('random_',
                r"""
 random_(from=0, to=None, *, generator=None) -> Tensor
diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp
index b165551..51c8dc5 100644
--- a/torch/csrc/utils/tensor_dtypes.cpp
+++ b/torch/csrc/utils/tensor_dtypes.cpp
@@ -40,6 +40,8 @@ static std::pair<std::string, std::string> getDtypeNames(
       return std::make_pair("complex128", "");
     case at::ScalarType::Bool:
       return std::make_pair("bool", "");
+    case at::ScalarType::QInt8:
+      return std::make_pair("qint8", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }