class TestAOMigrationQuantizePy(AOMigrationTestCase):
- def test_package_import(self):
+ def test_package_import_quantize(self):
self._test_package_import('quantize')
- def test_function_import(self):
+ def test_function_import_quantize(self):
function_list = [
'_convert',
'_observer_forward_hook',
'quantize_dynamic_jit',
]
self._test_function_import('quantize_jit', function_list)
+
+ def test_package_import_fake_quantize(self):
+ self._test_package_import('fake_quantize')
+
+ def test_function_import_fake_quantize(self):
+ function_list = [
+ '_is_per_channel',
+ '_is_per_tensor',
+ '_is_symmetric_quant',
+ 'FakeQuantizeBase',
+ 'FakeQuantize',
+ 'FixedQParamsFakeQuantize',
+ 'FusedMovingAvgObsFakeQuantize',
+ 'default_fake_quant',
+ 'default_weight_fake_quant',
+ 'default_symmetric_fixed_qparams_fake_quant',
+ 'default_affine_fixed_qparams_fake_quant',
+ 'default_per_channel_weight_fake_quant',
+ 'default_histogram_fake_quant',
+ 'default_fused_act_fake_quant',
+ 'default_fused_wt_fake_quant',
+ 'default_fused_per_channel_wt_fake_quant',
+ '_is_fake_quant_script_module',
+ 'disable_fake_quant',
+ 'enable_fake_quant',
+ 'disable_observer',
+ 'enable_observer',
+ ]
+ self._test_function_import('fake_quantize', function_list)
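# Editorial sketch, not part of this diff: one plausible shape of the
# `_test_function_import` helper exercised above, assuming `AOMigrationTestCase`
# resolves each name in both the legacy `torch.quantization.<module>` and the
# migrated `torch.ao.quantization.<module>` namespaces and asserts they are the
# same object. The real helper lives in the PyTorch test suite and may differ.
import importlib

def _test_function_import_sketch(module_name, function_list):
    old_mod = importlib.import_module('torch.quantization.' + module_name)
    new_mod = importlib.import_module('torch.ao.quantization.' + module_name)
    for name in function_list:
        assert getattr(old_mod, name) is getattr(new_mod, name), (
            name + ' differs between the old and new locations')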
+from .fake_quantize import * # noqa: F403
+
+# TODO(future PR): fix the typo, should be `__all__`
+_all__ = [
+ # FakeQuantize (for qat)
+ 'default_fake_quant', 'default_weight_fake_quant',
+ 'default_symmetric_fixed_qparams_fake_quant',
+ 'default_affine_fixed_qparams_fake_quant',
+ 'default_per_channel_weight_fake_quant',
+ 'default_histogram_fake_quant',
+]
--- /dev/null
+++ b/torch/ao/quantization/fake_quantize.py
+import torch
+from torch.nn import Module
+from torch.quantization.observer import (
+ MovingAverageMinMaxObserver,
+ HistogramObserver,
+ MovingAveragePerChannelMinMaxObserver,
+ _with_args,
+)
+import re
+from abc import ABC, abstractmethod
+from typing import Any, Tuple
+
+def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
+ return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]
+
+def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
+ return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
+
+def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool:
+ return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]
+
+class FakeQuantizeBase(ABC, Module):
+ r""" Base fake quantize module
+ Any fake quantize implementation should derive from this class.
+
+ Concrete fake quantize module should follow the same API. In forward, they will update
+ the statistics of the observed Tensor and fake quantize the input. They should also provide a
+ `calculate_qparams` function that computes the quantization parameters given
+ the collected statistics.
+
+ """
+
+ fake_quant_enabled: torch.Tensor
+ observer_enabled: torch.Tensor
+
+ def __init__(self):
+ super().__init__()
+ # fake_quant_enabled and observer_enabled are buffers to support their
+ # replication in DDP. Data type is uint8 because NCCL does not support
+ # bool tensors.
+ self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
+ self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))
+
+ @abstractmethod
+ def forward(self, x):
+ pass
+
+ @abstractmethod
+ def calculate_qparams(self, **kwargs):
+ pass
+
+ @torch.jit.export
+ def enable_fake_quant(self, enabled: bool = True) -> None:
+ self.fake_quant_enabled[0] = 1 if enabled else 0
+
+ @torch.jit.export
+ def disable_fake_quant(self):
+ self.enable_fake_quant(False)
+
+ @torch.jit.export
+ def enable_observer(self, enabled: bool = True) -> None:
+ self.observer_enabled[0] = 1 if enabled else 0
+
+ @torch.jit.export
+ def disable_observer(self):
+ self.enable_observer(False)
+
+ with_args = classmethod(_with_args)
+
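# Editorial sketch, not part of this diff: `with_args` is `_with_args` bound as
# a classmethod; it returns a partial-application factory, so qconfig entries
# can carry constructor arguments around without instantiating the module yet.
# Illustrated with the FakeQuantize subclass defined just below:
fq_factory = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255)
fq_a = fq_factory()  # a fresh FakeQuantize instance
fq_b = fq_factory()  # an independent second instance
assert fq_a is not fq_b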
+class FakeQuantize(FakeQuantizeBase):
+ r""" Simulate the quantize and dequantize operations in training time.
+ The output of this module is given by
+
+ x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
+
+
+
+ * :attr:`scale` defines the scale factor used for quantization.
+
+    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps
+
+ * :attr:`quant_min` specifies the minimum allowable quantized value.
+
+ * :attr:`quant_max` specifies the maximum allowable quantized value.
+
+    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors; note that
+      statistics can still be updated.
+
+    * :attr:`observer_enabled` controls statistics collection on tensors
+
+ * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
+ allowable values are torch.qint8 and torch.quint8. The values of quant_min and
+ quant_max should be chosen to be consistent with the dtype
+
+
+ Args:
+ observer (module): Module for observing statistics on input tensors and calculating scale
+ and zero-point.
+ quant_min (int): The minimum allowable quantized value.
+ quant_max (int): The maximum allowable quantized value.
+ observer_kwargs (optional): Arguments for the observer module
+
+ Attributes:
+ observer (Module): User provided module that collects statistics on the input tensor and
+ provides a method to calculate scale and zero-point.
+
+ """
+
+ scale: torch.Tensor
+ zero_point: torch.Tensor
+
+ def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
+ super().__init__()
+ assert quant_min <= quant_max, \
+ 'quant_min must be less than or equal to quant_max'
+ self.quant_min = quant_min
+ self.quant_max = quant_max
+ self.activation_post_process = observer(**observer_kwargs)
+ assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound'
+ assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound'
+ self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
+ self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
+ self.dtype = self.activation_post_process.dtype
+ self.qscheme = self.activation_post_process.qscheme
+ self.ch_axis = self.activation_post_process.ch_axis \
+ if hasattr(self.activation_post_process, 'ch_axis') else -1
+ assert _is_per_channel(self.qscheme) or \
+ _is_per_tensor(self.qscheme), \
+            'Only per channel and per tensor quantization are supported in fake quantize,' + \
+ ' got qscheme: ' + str(self.qscheme)
+ self.is_per_channel = _is_per_channel(self.qscheme)
+
+ @torch.jit.export
+ def calculate_qparams(self):
+ return self.activation_post_process.calculate_qparams()
+
+ def forward(self, X):
+ if self.observer_enabled[0] == 1:
+ self.activation_post_process(X.detach())
+ _scale, _zero_point = self.calculate_qparams()
+ _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
+ if self.scale.shape != _scale.shape:
+ self.scale.resize_(_scale.shape)
+ self.zero_point.resize_(_zero_point.shape)
+ self.scale.copy_(_scale)
+ self.zero_point.copy_(_zero_point)
+
+ if self.fake_quant_enabled[0] == 1:
+ if self.is_per_channel:
+ X = torch.fake_quantize_per_channel_affine(
+ X, self.scale, self.zero_point,
+ self.ch_axis, self.quant_min, self.quant_max)
+ else:
+ X = torch.fake_quantize_per_tensor_affine(
+ X, self.scale, self.zero_point,
+ self.quant_min, self.quant_max)
+ return X
+
+ @torch.jit.export
+ def extra_repr(self):
+ return 'fake_quant_enabled={}, observer_enabled={}, ' \
+ 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
+ 'scale={}, zero_point={}'.format(
+ self.fake_quant_enabled, self.observer_enabled,
+ self.quant_min, self.quant_max,
+ self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)
+
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
+        # We cannot currently register scalar values as buffers, so we need to manually
+ # specify serialization here.
+ super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
+ destination[prefix + 'scale'] = self.scale
+ destination[prefix + 'zero_point'] = self.zero_point
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+        # Removing this function throws an error that the size of the loaded tensor does not match the original size
+        # i.e., these buffers start out with numel 0 and become numel 1 once they have their first forward pass.
+ local_state = ['scale', 'zero_point']
+ for name in local_state:
+ key = prefix + name
+ if key in state_dict:
+ val = state_dict[key]
+ # Custom handling to allow loading scale and zero_point
+ # of size N into uninitialized buffers of size 0. The
+ # buffers are resized here, and the values are copied in
+ # the default state_dict loading code of the parent.
+ if name == 'scale':
+ self.scale.resize_(val.shape)
+ else:
+ assert name == 'zero_point'
+ self.zero_point.resize_(val.shape)
+                    # For torchscript modules we need to update the attributes here since we do not
+                    # call the `_load_from_state_dict` function defined in module.py
+ if torch.jit.is_scripting():
+ if name == 'scale':
+ self.scale.copy_(val)
+ else:
+ assert name == 'zero_point'
+ self.zero_point.copy_(val)
+ elif strict:
+ missing_keys.append(key)
+ super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs)
+
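# Editorial sketch, not part of this diff: the FakeQuantize docstring formula
# worked out by hand and checked against the aten kernel. With scale=0.1,
# zero_point=128, quant_min=0, quant_max=255:
#   0.5  -> round(0.5/0.1 + 128) = 133 -> clamp to [0, 255] -> (133 - 128) * 0.1 = 0.5
#   50.0 -> round(50/0.1 + 128) = 628 -> clamp to 255       -> (255 - 128) * 0.1 = 12.7
x = torch.tensor([0.5, 50.0])
y = torch.fake_quantize_per_tensor_affine(x, 0.1, 128, 0, 255)
assert torch.allclose(y, torch.tensor([0.5, 12.7]))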
+class FixedQParamsFakeQuantize(FakeQuantizeBase):
+ """ Simulate quantize and dequantize with fixed quantization
+    parameters during training. Only per tensor quantization
+ is supported.
+ Args:
+ `scale` (float): fixed scale for the fake quantize module
+ `zero_point` (int): fixed zero point for the fake quantize module
+ `dtype`, `qscheme`, `quant_min`, `quant_max`
+ """
+
+ scale: torch.Tensor
+ zero_point: torch.Tensor
+
+ def __init__(self,
+ scale,
+ zero_point,
+ dtype=torch.quint8,
+ qscheme=torch.per_tensor_affine,
+ quant_min=0,
+ quant_max=255):
+ super().__init__()
+ assert quant_min <= quant_max, 'quant_min should be less than or equal to quant_max'
+ self.quant_min = quant_min
+ self.quant_max = quant_max
+ self.register_buffer('scale', torch.tensor([scale], dtype=torch.float))
+ self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.int))
+ self.dtype = dtype
+ self.qscheme = qscheme
+        assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported in' + \
+            ' FixedQParamsFakeQuantize module, got qscheme: ' + str(self.qscheme)
+
+ def forward(self, X):
+ if self.fake_quant_enabled[0] == 1:
+ X = torch.fake_quantize_per_tensor_affine(X, self.scale,
+ self.zero_point, self.quant_min,
+ self.quant_max)
+ return X
+
+ @torch.jit.export
+ def calculate_qparams(self):
+ return self.scale, self.zero_point
+
+ @torch.jit.export
+ def extra_repr(self):
+ return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
+ 'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
+ self.fake_quant_enabled, self.observer_enabled,
+ self.scale, self.zero_point, self.dtype,
+ self.quant_min, self.quant_max, self.qscheme)
+
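# Editorial sketch, not part of this diff: fixed qparams suit ops whose output
# range is known ahead of time, so no observer statistics are needed. For
# example, sigmoid maps into [0, 1), which scale=1/256 and zero_point=0 cover
# exactly for quint8 (the same values as default_affine_fixed_qparams_fake_quant
# defined below):
sigmoid_fq = FixedQParamsFakeQuantize(scale=1.0 / 256.0, zero_point=0)
out = sigmoid_fq(torch.sigmoid(torch.randn(4)))  # outputs snapped to a k/256 grid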
+class FusedMovingAvgObsFakeQuantize(FakeQuantize):
+ r"""Fused module that is used to observe the input tensor (compute min/max), compute
+ scale/zero_point and fake_quantize the tensor.
+    This module uses a calculation similar to MovingAverageMinMaxObserver on the inputs
+    to compute the min/max values, which are then used to compute the scale/zero_point.
+ The qscheme input in the observer is used to differentiate between symmetric/affine
+ quantization scheme.
+
+    The output of this module is given by::
+
+        x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point) * scale
+
+    This module is similar to :class:`~torch.quantization.FakeQuantize`, and accepts the same
+    attributes as the base class.
+
+ """
+
+ def __init__(
+ self,
+ observer: Any = MovingAverageMinMaxObserver,
+ quant_min: int = 0,
+ quant_max: int = 255,
+ **observer_kwargs: Any
+ ) -> None:
+ super().__init__(observer, quant_min, quant_max, **observer_kwargs)
+        assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)),\
+            "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver " \
+            "and MovingAveragePerChannelMinMaxObserver"
+ self.quant_min: int = quant_min
+ self.quant_max: int = quant_max
+ self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
+ self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
+ self.is_symmetric_quant = _is_symmetric_quant(self.activation_post_process.qscheme)
+
+ self.quant_min, self.quant_max = self.activation_post_process.quant_min, self.activation_post_process.quant_max
+
+ @torch.jit.export
+ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ return self.activation_post_process.calculate_qparams()
+
+ @torch.jit.export
+ def extra_repr(self) -> str:
+ return (
+ "fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, "
+ "dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}".format(
+ self.fake_quant_enabled,
+ self.observer_enabled,
+ self.scale,
+ self.zero_point,
+ self.dtype,
+ self.quant_min,
+ self.quant_max,
+ self.qscheme,
+ self.activation_post_process.reduce_range,
+ )
+ )
+
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
+ return torch.fused_moving_avg_obs_fake_quant(
+ X,
+ self.observer_enabled,
+ self.fake_quant_enabled,
+ self.activation_post_process.min_val,
+ self.activation_post_process.max_val,
+ self.scale,
+ self.zero_point,
+ self.activation_post_process.averaging_constant,
+ self.quant_min,
+ self.quant_max,
+ self.ch_axis,
+ self.is_per_channel,
+ self.is_symmetric_quant,
+ )
+
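# Editorial sketch, not part of this diff: the fused module collapses the
# observe -> calculate_qparams -> fake-quantize sequence of FakeQuantize into
# the single `torch.fused_moving_avg_obs_fake_quant` kernel; the intent is
# numerical equivalence with the unfused module at lower overhead.
fused = FusedMovingAvgObsFakeQuantize()
unfused = FakeQuantize(observer=MovingAverageMinMaxObserver)
x = torch.randn(8)
assert torch.allclose(fused(x), unfused(x))  # expected to agree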
+default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
+ dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
+default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
+ dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
+
+# TODO(future PR): remove these defaults and enforce activation functions
+# to explicitly specify their output range
+default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
+ scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255)
+default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
+ scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255)
+
+default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
+ quant_min=-128,
+ quant_max=127,
+ dtype=torch.qint8,
+ qscheme=torch.per_channel_symmetric,
+ reduce_range=False,
+ ch_axis=0)
+default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
+ quant_min=0,
+ quant_max=255,
+ dtype=torch.quint8,
+ qscheme=torch.per_tensor_affine,
+ reduce_range=True)
+
+default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+ quant_min=0,
+ quant_max=255,
+ dtype=torch.quint8,)
+
+
+default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+ quant_min=-128,
+ quant_max=127,
+ dtype=torch.qint8,
+ qscheme=torch.per_tensor_symmetric)
+
+default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
+ quant_min=-128,
+ quant_max=127,
+ dtype=torch.qint8,
+ qscheme=torch.per_channel_symmetric)
+
+def _is_fake_quant_script_module(mod):
+    ''' Returns true if the given mod is an instance of a FakeQuantize script module.
+ '''
+ if isinstance(mod, torch.jit.RecursiveScriptModule):
+ # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
+ suffix = mod._c.qualified_name.split('.', 1)[1]
+ name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
+ return name == 'torch.ao.quantization.fake_quantize.FakeQuantize' or \
+ name == 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'
+ return False
+
+def disable_fake_quant(mod):
+ if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
+ mod.disable_fake_quant()
+
+def enable_fake_quant(mod):
+ if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
+ mod.enable_fake_quant()
+
+def disable_observer(mod):
+ if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
+ mod.disable_observer()
+
+def enable_observer(mod):
+ if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
+ mod.enable_observer()
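# Editorial sketch, not part of this diff: these toggles are module-level
# functions so they compose with `nn.Module.apply`. A common QAT recipe is to
# freeze the quantization parameters near the end of training (`qat_model`
# below is a hypothetical prepared QAT model):
qat_model.apply(torch.ao.quantization.disable_observer)   # stop updating scale/zero_point
qat_model.apply(torch.ao.quantization.enable_fake_quant)  # keep simulating quantized numerics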
for data, target in calib_data:
model(data)
+# TODO(future PR): fix the typo, should be `__all__`
_all__ = [
'QuantWrapper', 'QuantStub', 'DeQuantStub',
# Top level API for eager mode quantization
-import torch
-from torch.nn import Module
-from .observer import MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver, _with_args
-import re
-from abc import ABC, abstractmethod
-from typing import Any, Tuple
-
-def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
- return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]
-
-def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
- return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
-
-def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool:
- return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]
-
-class FakeQuantizeBase(ABC, Module):
- r""" Base fake quantize module
- Any fake quantize implementation should derive from this class.
-
- Concrete fake quantize module should follow the same API. In forward, they will update
- the statistics of the observed Tensor and fake quantize the input. They should also provide a
- `calculate_qparams` function that computes the quantization parameters given
- the collected statistics.
-
- """
-
- fake_quant_enabled: torch.Tensor
- observer_enabled: torch.Tensor
-
- def __init__(self):
- super().__init__()
- # fake_quant_enabled and observer_enabled are buffers to support their
- # replication in DDP. Data type is uint8 because NCCL does not support
- # bool tensors.
- self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
- self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))
-
- @abstractmethod
- def forward(self, x):
- pass
-
- @abstractmethod
- def calculate_qparams(self, **kwargs):
- pass
-
- @torch.jit.export
- def enable_fake_quant(self, enabled: bool = True) -> None:
- self.fake_quant_enabled[0] = 1 if enabled else 0
-
- @torch.jit.export
- def disable_fake_quant(self):
- self.enable_fake_quant(False)
-
- @torch.jit.export
- def enable_observer(self, enabled: bool = True) -> None:
- self.observer_enabled[0] = 1 if enabled else 0
-
- @torch.jit.export
- def disable_observer(self):
- self.enable_observer(False)
-
- with_args = classmethod(_with_args)
-
-class FakeQuantize(FakeQuantizeBase):
- r""" Simulate the quantize and dequantize operations in training time.
- The output of this module is given by
-
- x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
-
-
-
- * :attr:`scale` defines the scale factor used for quantization.
-
- * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to
-
- * :attr:`quant_min` specifies the minimum allowable quantized value.
-
- * :attr:`quant_max` specifies the maximum allowable quantized value.
-
- * :attr:`fake_quant_enable` controls the application of fake quantization on tensors, note that
- statistics can still be updated.
-
- * :attr:`observer_enable` controls statistics collection on tensors
-
- * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
- allowable values are torch.qint8 and torch.quint8. The values of quant_min and
- quant_max should be chosen to be consistent with the dtype
-
-
- Args:
- observer (module): Module for observing statistics on input tensors and calculating scale
- and zero-point.
- quant_min (int): The minimum allowable quantized value.
- quant_max (int): The maximum allowable quantized value.
- observer_kwargs (optional): Arguments for the observer module
-
- Attributes:
- observer (Module): User provided module that collects statistics on the input tensor and
- provides a method to calculate scale and zero-point.
-
- """
-
- scale: torch.Tensor
- zero_point: torch.Tensor
-
- def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
- super().__init__()
- assert quant_min <= quant_max, \
- 'quant_min must be less than or equal to quant_max'
- self.quant_min = quant_min
- self.quant_max = quant_max
- self.activation_post_process = observer(**observer_kwargs)
- assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound'
- assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound'
- self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
- self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
- self.dtype = self.activation_post_process.dtype
- self.qscheme = self.activation_post_process.qscheme
- self.ch_axis = self.activation_post_process.ch_axis \
- if hasattr(self.activation_post_process, 'ch_axis') else -1
- assert _is_per_channel(self.qscheme) or \
- _is_per_tensor(self.qscheme), \
- 'Only per channel and per tensor quantization are supported in fake quantize' + \
- ' got qscheme: ' + str(self.qscheme)
- self.is_per_channel = _is_per_channel(self.qscheme)
-
- @torch.jit.export
- def calculate_qparams(self):
- return self.activation_post_process.calculate_qparams()
-
- def forward(self, X):
- if self.observer_enabled[0] == 1:
- self.activation_post_process(X.detach())
- _scale, _zero_point = self.calculate_qparams()
- _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
- if self.scale.shape != _scale.shape:
- self.scale.resize_(_scale.shape)
- self.zero_point.resize_(_zero_point.shape)
- self.scale.copy_(_scale)
- self.zero_point.copy_(_zero_point)
-
- if self.fake_quant_enabled[0] == 1:
- if self.is_per_channel:
- X = torch.fake_quantize_per_channel_affine(
- X, self.scale, self.zero_point,
- self.ch_axis, self.quant_min, self.quant_max)
- else:
- X = torch.fake_quantize_per_tensor_affine(
- X, self.scale, self.zero_point,
- self.quant_min, self.quant_max)
- return X
-
- @torch.jit.export
- def extra_repr(self):
- return 'fake_quant_enabled={}, observer_enabled={}, ' \
- 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
- 'scale={}, zero_point={}'.format(
- self.fake_quant_enabled, self.observer_enabled,
- self.quant_min, self.quant_max,
- self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)
-
- def _save_to_state_dict(self, destination, prefix, keep_vars):
- # We cannot currently register scalar values as buffers, so need to manually
- # specify serialization here.
- super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
- destination[prefix + 'scale'] = self.scale
- destination[prefix + 'zero_point'] = self.zero_point
-
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
- missing_keys, unexpected_keys, error_msgs):
- # Removing this function throws an error that the the size of the loaded tensor does not match the original size
- # i.e., These buffers start out with numel 0 and become numel 1 once they have their first forward pass.
- local_state = ['scale', 'zero_point']
- for name in local_state:
- key = prefix + name
- if key in state_dict:
- val = state_dict[key]
- # Custom handling to allow loading scale and zero_point
- # of size N into uninitialized buffers of size 0. The
- # buffers are resized here, and the values are copied in
- # the default state_dict loading code of the parent.
- if name == 'scale':
- self.scale.resize_(val.shape)
- else:
- assert name == 'zero_point'
- self.zero_point.resize_(val.shape)
- # For torchscript module we need to update the attributes here since we do not
- # call the `_load_from_state_dict` function defined module.py
- if torch.jit.is_scripting():
- if name == 'scale':
- self.scale.copy_(val)
- else:
- assert name == 'zero_point'
- self.zero_point.copy_(val)
- elif strict:
- missing_keys.append(key)
- super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
- missing_keys, unexpected_keys, error_msgs)
-
-class FixedQParamsFakeQuantize(FakeQuantizeBase):
- """ Simulate quantize and dequantize with fixed quantization
- parameters in training time. Only per tensor quantization
- is supported.
- Args:
- `scale` (float): fixed scale for the fake quantize module
- `zero_point` (int): fixed zero point for the fake quantize module
- `dtype`, `qscheme`, `quant_min`, `quant_max`
- """
-
- scale: torch.Tensor
- zero_point: torch.Tensor
-
- def __init__(self,
- scale,
- zero_point,
- dtype=torch.quint8,
- qscheme=torch.per_tensor_affine,
- quant_min=0,
- quant_max=255):
- super().__init__()
- assert quant_min <= quant_max, 'quant_min should be less than or equal to quant_max'
- self.quant_min = quant_min
- self.quant_max = quant_max
- self.register_buffer('scale', torch.tensor([scale], dtype=torch.float))
- self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.int))
- self.dtype = dtype
- self.qscheme = qscheme
- assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported' + \
- ' FixedQParamsFakeQuantize module, got qscheme:' + str(self.qscheme)
-
- def forward(self, X):
- if self.fake_quant_enabled[0] == 1:
- X = torch.fake_quantize_per_tensor_affine(X, self.scale,
- self.zero_point, self.quant_min,
- self.quant_max)
- return X
-
- @torch.jit.export
- def calculate_qparams(self):
- return self.scale, self.zero_point
-
- @torch.jit.export
- def extra_repr(self):
- return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
- 'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
- self.fake_quant_enabled, self.observer_enabled,
- self.scale, self.zero_point, self.dtype,
- self.quant_min, self.quant_max, self.qscheme)
-
-class FusedMovingAvgObsFakeQuantize(FakeQuantize):
- r"""Fused module that is used to observe the input tensor (compute min/max), compute
- scale/zero_point and fake_quantize the tensor.
- This module uses calculation similar MovingAverageMinMaxObserver for the inputs,
- to compute the min/max values in order to compute the scale/zero_point.
- The qscheme input in the observer is used to differentiate between symmetric/affine
- quantization scheme.
-
- The output of this module is given by
- x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
-
- Similar to :class:`~torch.quantization.FakeQuantize`, and accepts the same attributes as the
- base class.
-
- """
-
- def __init__(
- self,
- observer: Any = MovingAverageMinMaxObserver,
- quant_min: int = 0,
- quant_max: int = 255,
- **observer_kwargs: Any
- ) -> None:
- super().__init__(observer, quant_min, quant_max, **observer_kwargs)
- assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)),\
- "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
- self.quant_min: int = quant_min
- self.quant_max: int = quant_max
- self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
- self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
- self.is_symmetric_quant = _is_symmetric_quant(self.activation_post_process.qscheme)
-
- self.quant_min, self.quant_max = self.activation_post_process.quant_min, self.activation_post_process.quant_max
-
- @torch.jit.export
- def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
- return self.activation_post_process.calculate_qparams()
-
- @torch.jit.export
- def extra_repr(self) -> str:
- return (
- "fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, "
- "dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}".format(
- self.fake_quant_enabled,
- self.observer_enabled,
- self.scale,
- self.zero_point,
- self.dtype,
- self.quant_min,
- self.quant_max,
- self.qscheme,
- self.activation_post_process.reduce_range,
- )
- )
-
- def forward(self, X: torch.Tensor) -> torch.Tensor:
- return torch.fused_moving_avg_obs_fake_quant(
- X,
- self.observer_enabled,
- self.fake_quant_enabled,
- self.activation_post_process.min_val,
- self.activation_post_process.max_val,
- self.scale,
- self.zero_point,
- self.activation_post_process.averaging_constant,
- self.quant_min,
- self.quant_max,
- self.ch_axis,
- self.is_per_channel,
- self.is_symmetric_quant,
- )
-
-default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
- dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
-default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
- dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
-
-# TODO(future PR): remove these defaults and enforce activation functions
-# to explicitly specify their output range
-default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
- scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255)
-default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
- scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255)
-
-default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
- quant_min=-128,
- quant_max=127,
- dtype=torch.qint8,
- qscheme=torch.per_channel_symmetric,
- reduce_range=False,
- ch_axis=0)
-default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
- quant_min=0,
- quant_max=255,
- dtype=torch.quint8,
- qscheme=torch.per_tensor_affine,
- reduce_range=True)
-
-default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
- quant_min=0,
- quant_max=255,
- dtype=torch.quint8,)
-
-
-default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
- quant_min=-128,
- quant_max=127,
- dtype=torch.qint8,
- qscheme=torch.per_tensor_symmetric)
-
-default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
- quant_min=-128,
- quant_max=127,
- dtype=torch.qint8,
- qscheme=torch.per_channel_symmetric)
-
-def _is_fake_quant_script_module(mod):
- ''' Returns true if given mod is an instance of FakeQuantize script module.
- '''
- if isinstance(mod, torch.jit.RecursiveScriptModule):
- # qualified name looks like '__torch__.torch.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
- suffix = mod._c.qualified_name.split('.', 1)[1]
- name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
- return name == 'torch.quantization.fake_quantize.FakeQuantize' or \
- name == 'torch.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'
- return False
-
-def disable_fake_quant(mod):
- if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
- mod.disable_fake_quant()
-
-def enable_fake_quant(mod):
- if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
- mod.enable_fake_quant()
-
-def disable_observer(mod):
- if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
- mod.disable_observer()
-
-def enable_observer(mod):
- if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
- mod.enable_observer()
+# flake8: noqa: F401
+r"""
+This file is in the process of migration to `torch/ao/quantization`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please add it to
+`torch/ao/quantization/fake_quantize.py`, and add a corresponding import
+statement here.
+"""
+
+from torch.ao.quantization.fake_quantize import (
+ _is_per_channel,
+ _is_per_tensor,
+ _is_symmetric_quant,
+ FakeQuantizeBase,
+ FakeQuantize,
+ FixedQParamsFakeQuantize,
+ FusedMovingAvgObsFakeQuantize,
+ default_fake_quant,
+ default_weight_fake_quant,
+ default_symmetric_fixed_qparams_fake_quant,
+ default_affine_fixed_qparams_fake_quant,
+ default_per_channel_weight_fake_quant,
+ default_histogram_fake_quant,
+ default_fused_act_fake_quant,
+ default_fused_wt_fake_quant,
+ default_fused_per_channel_wt_fake_quant,
+ _is_fake_quant_script_module,
+ disable_fake_quant,
+ enable_fake_quant,
+ disable_observer,
+ enable_observer,
+)
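# Editorial sketch, not part of this diff: because this shim re-exports the
# `torch.ao.quantization` objects, both import paths resolve to the very same
# class objects, so existing isinstance checks and qconfigs keep working:
from torch.quantization.fake_quantize import FakeQuantize as legacy_fake_quantize
from torch.ao.quantization.fake_quantize import FakeQuantize as ao_fake_quantize
assert legacy_fake_quantize is ao_fake_quantize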
default_float_qparams_observer, default_observer,
default_per_channel_weight_observer,
default_placeholder_observer, default_weight_observer)
-from .fake_quantize import (FakeQuantize, default_fake_quant,
- default_per_channel_weight_fake_quant,
- default_weight_fake_quant, default_fused_act_fake_quant, default_fused_wt_fake_quant,
- FusedMovingAvgObsFakeQuantize, default_fused_per_channel_wt_fake_quant)
+from torch.ao.quantization.fake_quantize import (
+ FakeQuantize,
+ default_fake_quant,
+ default_per_channel_weight_fake_quant,
+ default_weight_fake_quant,
+ default_fused_act_fake_quant,
+ default_fused_wt_fake_quant,
+ FusedMovingAvgObsFakeQuantize,
+ default_fused_per_channel_wt_fake_quant,
+)
import torch
import torch.nn as nn
from typing import Optional, Union, Dict, Set, Callable, Any
from torch.ao.quantization.stubs import QuantStub, DeQuantStub
-from .fake_quantize import (
+from torch.ao.quantization.fake_quantize import (
default_affine_fixed_qparams_fake_quant,
default_symmetric_fixed_qparams_fake_quant,
)