torch.ao migration: fake_quantize.py, phase 1 (#64814)
authorVasiliy Kuznetsov <vasiliy@fb.com>
Mon, 13 Sep 2021 22:20:44 +0000 (15:20 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 13 Sep 2021 22:22:28 +0000 (15:22 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64814

1. move the file
```
hg mv caffe2/torch/quantization/fake_quantize.py caffe2/torch/ao/quantization/
```

2. create a new file in the old location and copy the imports
3. fix all callsites inside `torch`

Test Plan:
```
buck test mode/dev //caffe2/test:quantization
```

Reviewed By: z-a-f

Differential Revision: D30866792

fbshipit-source-id: 7a221cb46c0ab01f1c5de9be061f09ecc83ce23e

test/quantization/ao_migration/test_quantize.py
torch/ao/quantization/__init__.py
torch/ao/quantization/fake_quantize.py [new file with mode: 0644]
torch/quantization/__init__.py
torch/quantization/fake_quantize.py
torch/quantization/qconfig.py
torch/quantization/quantization_mappings.py

index 9ada537..d6e6109 100644 (file)
@@ -35,10 +35,10 @@ class AOMigrationTestCase(TestCase):
 
 
 class TestAOMigrationQuantizePy(AOMigrationTestCase):
-    def test_package_import(self):
+    def test_package_import_quantize(self):
         self._test_package_import('quantize')
 
-    def test_function_import(self):
+    def test_function_import_quantize(self):
         function_list = [
             '_convert',
             '_observer_forward_hook',
@@ -94,3 +94,32 @@ class TestAOMigrationQuantizePy(AOMigrationTestCase):
             'quantize_dynamic_jit',
         ]
         self._test_function_import('quantize_jit', function_list)
+
+    def test_package_import_fake_quantize(self):
+        self._test_package_import('fake_quantize')
+
+    def test_function_import_fake_quantize(self):
+        function_list = [
+            '_is_per_channel',
+            '_is_per_tensor',
+            '_is_symmetric_quant',
+            'FakeQuantizeBase',
+            'FakeQuantize',
+            'FixedQParamsFakeQuantize',
+            'FusedMovingAvgObsFakeQuantize',
+            'default_fake_quant',
+            'default_weight_fake_quant',
+            'default_symmetric_fixed_qparams_fake_quant',
+            'default_affine_fixed_qparams_fake_quant',
+            'default_per_channel_weight_fake_quant',
+            'default_histogram_fake_quant',
+            'default_fused_act_fake_quant',
+            'default_fused_wt_fake_quant',
+            'default_fused_per_channel_wt_fake_quant',
+            '_is_fake_quant_script_module',
+            'disable_fake_quant',
+            'enable_fake_quant',
+            'disable_observer',
+            'enable_observer',
+        ]
+        self._test_function_import('fake_quantize', function_list)
index e69de29..51029f9 100644 (file)
@@ -0,0 +1,11 @@
+from .fake_quantize import *  # noqa: F403
+
+# TODO(future PR): fix the typo, should be `__all__`
+_all__ = [
+    # FakeQuantize (for qat)
+    'default_fake_quant', 'default_weight_fake_quant',
+    'default_symmetric_fixed_qparams_fake_quant',
+    'default_affine_fixed_qparams_fake_quant',
+    'default_per_channel_weight_fake_quant',
+    'default_histogram_fake_quant',
+]
diff --git a/torch/ao/quantization/fake_quantize.py b/torch/ao/quantization/fake_quantize.py
new file mode 100644 (file)
index 0000000..2c4d997
--- /dev/null
@@ -0,0 +1,397 @@
+import torch
+from torch.nn import Module
+from torch.quantization.observer import (
+    MovingAverageMinMaxObserver,
+    HistogramObserver,
+    MovingAveragePerChannelMinMaxObserver,
+    _with_args,
+)
+import re
+from abc import ABC, abstractmethod
+from typing import Any, Tuple
+
+def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
+    return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]
+
+def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
+    return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
+
+def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool:
+    return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]
+
+class FakeQuantizeBase(ABC, Module):
+    r""" Base fake quantize module
+    Any fake quantize implementation should derive from this class.
+
+    Concrete fake quantize module should follow the same API. In forward, they will update
+    the statistics of the observed Tensor and fake quantize the input. They should also provide a
+    `calculate_qparams` function that computes the quantization parameters given
+    the collected statistics.
+
+    """
+
+    fake_quant_enabled: torch.Tensor
+    observer_enabled: torch.Tensor
+
+    def __init__(self):
+        super().__init__()
+        # fake_quant_enabled and observer_enabled are buffers to support their
+        # replication in DDP. Data type is uint8 because NCCL does not support
+        # bool tensors.
+        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
+        self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))
+
+    @abstractmethod
+    def forward(self, x):
+        pass
+
+    @abstractmethod
+    def calculate_qparams(self, **kwargs):
+        pass
+
+    @torch.jit.export
+    def enable_fake_quant(self, enabled: bool = True) -> None:
+        self.fake_quant_enabled[0] = 1 if enabled else 0
+
+    @torch.jit.export
+    def disable_fake_quant(self):
+        self.enable_fake_quant(False)
+
+    @torch.jit.export
+    def enable_observer(self, enabled: bool = True) -> None:
+        self.observer_enabled[0] = 1 if enabled else 0
+
+    @torch.jit.export
+    def disable_observer(self):
+        self.enable_observer(False)
+
+    with_args = classmethod(_with_args)
+
+class FakeQuantize(FakeQuantizeBase):
+    r""" Simulate the quantize and dequantize operations in training time.
+    The output of this module is given by
+
+    x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
+
+
+
+    * :attr:`scale` defines the scale factor used for quantization.
+
+    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to
+
+    * :attr:`quant_min` specifies the minimum allowable quantized value.
+
+    * :attr:`quant_max` specifies the maximum allowable quantized value.
+
+    * :attr:`fake_quant_enable` controls the application of fake quantization on tensors, note that
+      statistics can still be updated.
+
+    * :attr:`observer_enable` controls statistics collection on tensors
+
+    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
+                    allowable values are torch.qint8 and torch.quint8. The values of quant_min and
+                    quant_max should be chosen to be consistent with the dtype
+
+
+    Args:
+        observer (module): Module for observing statistics on input tensors and calculating scale
+                           and zero-point.
+        quant_min (int): The minimum allowable quantized value.
+        quant_max (int): The maximum allowable quantized value.
+        observer_kwargs (optional): Arguments for the observer module
+
+    Attributes:
+        observer (Module): User provided module that collects statistics on the input tensor and
+                           provides a method to calculate scale and zero-point.
+
+    """
+
+    scale: torch.Tensor
+    zero_point: torch.Tensor
+
+    def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
+        super().__init__()
+        assert quant_min <= quant_max, \
+            'quant_min must be less than or equal to quant_max'
+        self.quant_min = quant_min
+        self.quant_max = quant_max
+        self.activation_post_process = observer(**observer_kwargs)
+        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound'
+        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound'
+        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
+        self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
+        self.dtype = self.activation_post_process.dtype
+        self.qscheme = self.activation_post_process.qscheme
+        self.ch_axis = self.activation_post_process.ch_axis \
+            if hasattr(self.activation_post_process, 'ch_axis') else -1
+        assert _is_per_channel(self.qscheme) or \
+            _is_per_tensor(self.qscheme), \
+            'Only per channel and per tensor quantization are supported in fake quantize' + \
+            ' got qscheme: ' + str(self.qscheme)
+        self.is_per_channel = _is_per_channel(self.qscheme)
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        return self.activation_post_process.calculate_qparams()
+
+    def forward(self, X):
+        if self.observer_enabled[0] == 1:
+            self.activation_post_process(X.detach())
+            _scale, _zero_point = self.calculate_qparams()
+            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
+            if self.scale.shape != _scale.shape:
+                self.scale.resize_(_scale.shape)
+                self.zero_point.resize_(_zero_point.shape)
+            self.scale.copy_(_scale)
+            self.zero_point.copy_(_zero_point)
+
+        if self.fake_quant_enabled[0] == 1:
+            if self.is_per_channel:
+                X = torch.fake_quantize_per_channel_affine(
+                    X, self.scale, self.zero_point,
+                    self.ch_axis, self.quant_min, self.quant_max)
+            else:
+                X = torch.fake_quantize_per_tensor_affine(
+                    X, self.scale, self.zero_point,
+                    self.quant_min, self.quant_max)
+        return X
+
+    @torch.jit.export
+    def extra_repr(self):
+        return 'fake_quant_enabled={}, observer_enabled={}, ' \
+               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
+               'scale={}, zero_point={}'.format(
+                   self.fake_quant_enabled, self.observer_enabled,
+                   self.quant_min, self.quant_max,
+                   self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        # We cannot currently register scalar values as buffers, so need to manually
+        # specify serialization here.
+        super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'scale'] = self.scale
+        destination[prefix + 'zero_point'] = self.zero_point
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        # Removing this function throws an error that the the size of the loaded tensor does not match the original size
+        # i.e., These buffers start out with numel 0 and become numel 1 once they have their first forward pass.
+        local_state = ['scale', 'zero_point']
+        for name in local_state:
+            key = prefix + name
+            if key in state_dict:
+                val = state_dict[key]
+                # Custom handling to allow loading scale and zero_point
+                # of size N into uninitialized buffers of size 0. The
+                # buffers are resized here, and the values are copied in
+                # the default state_dict loading code of the parent.
+                if name == 'scale':
+                    self.scale.resize_(val.shape)
+                else:
+                    assert name == 'zero_point'
+                    self.zero_point.resize_(val.shape)
+                # For torchscript module we need to update the attributes here since we do not
+                # call the `_load_from_state_dict` function defined module.py
+                if torch.jit.is_scripting():
+                    if name == 'scale':
+                        self.scale.copy_(val)
+                    else:
+                        assert name == 'zero_point'
+                        self.zero_point.copy_(val)
+            elif strict:
+                missing_keys.append(key)
+        super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                                        missing_keys, unexpected_keys, error_msgs)
+
+class FixedQParamsFakeQuantize(FakeQuantizeBase):
+    """ Simulate quantize and dequantize with fixed quantization
+    parameters in training time. Only per tensor quantization
+    is supported.
+    Args:
+        `scale` (float): fixed scale for the fake quantize module
+        `zero_point` (int): fixed zero point for the fake quantize module
+        `dtype`, `qscheme`, `quant_min`, `quant_max`
+    """
+
+    scale: torch.Tensor
+    zero_point: torch.Tensor
+
+    def __init__(self,
+                 scale,
+                 zero_point,
+                 dtype=torch.quint8,
+                 qscheme=torch.per_tensor_affine,
+                 quant_min=0,
+                 quant_max=255):
+        super().__init__()
+        assert quant_min <= quant_max, 'quant_min should be less than or equal to quant_max'
+        self.quant_min = quant_min
+        self.quant_max = quant_max
+        self.register_buffer('scale', torch.tensor([scale], dtype=torch.float))
+        self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.int))
+        self.dtype = dtype
+        self.qscheme = qscheme
+        assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported' + \
+            ' FixedQParamsFakeQuantize module, got qscheme:' + str(self.qscheme)
+
+    def forward(self, X):
+        if self.fake_quant_enabled[0] == 1:
+            X = torch.fake_quantize_per_tensor_affine(X, self.scale,
+                                                      self.zero_point, self.quant_min,
+                                                      self.quant_max)
+        return X
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        return self.scale, self.zero_point
+
+    @torch.jit.export
+    def extra_repr(self):
+        return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
+               'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
+                   self.fake_quant_enabled, self.observer_enabled,
+                   self.scale, self.zero_point, self.dtype,
+                   self.quant_min, self.quant_max, self.qscheme)
+
+class FusedMovingAvgObsFakeQuantize(FakeQuantize):
+    r"""Fused module that is used to observe the input tensor (compute min/max), compute
+    scale/zero_point and fake_quantize the tensor.
+    This module uses calculation similar MovingAverageMinMaxObserver for the inputs,
+    to compute the min/max values in order to compute the scale/zero_point.
+    The qscheme input in the observer is used to differentiate between symmetric/affine
+    quantization scheme.
+
+    The output of this module is given by
+    x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
+
+    Similar to :class:`~torch.quantization.FakeQuantize`, and accepts the same attributes as the
+    base class.
+
+    """
+
+    def __init__(
+        self,
+        observer: Any = MovingAverageMinMaxObserver,
+        quant_min: int = 0,
+        quant_max: int = 255,
+        **observer_kwargs: Any
+    ) -> None:
+        super().__init__(observer, quant_min, quant_max, **observer_kwargs)
+        assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)),\
+            "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
+        self.quant_min: int = quant_min
+        self.quant_max: int = quant_max
+        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
+        self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
+        self.is_symmetric_quant = _is_symmetric_quant(self.activation_post_process.qscheme)
+
+        self.quant_min, self.quant_max = self.activation_post_process.quant_min, self.activation_post_process.quant_max
+
+    @torch.jit.export
+    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.activation_post_process.calculate_qparams()
+
+    @torch.jit.export
+    def extra_repr(self) -> str:
+        return (
+            "fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, "
+            "dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}".format(
+                self.fake_quant_enabled,
+                self.observer_enabled,
+                self.scale,
+                self.zero_point,
+                self.dtype,
+                self.quant_min,
+                self.quant_max,
+                self.qscheme,
+                self.activation_post_process.reduce_range,
+            )
+        )
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        return torch.fused_moving_avg_obs_fake_quant(
+            X,
+            self.observer_enabled,
+            self.fake_quant_enabled,
+            self.activation_post_process.min_val,
+            self.activation_post_process.max_val,
+            self.scale,
+            self.zero_point,
+            self.activation_post_process.averaging_constant,
+            self.quant_min,
+            self.quant_max,
+            self.ch_axis,
+            self.is_per_channel,
+            self.is_symmetric_quant,
+        )
+
+default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
+                                            dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
+default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
+                                                   dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
+
+# TODO(future PR): remove these defaults and enforce activation functions
+# to explicitly specify their output range
+default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
+    scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255)
+default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
+    scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255)
+
+default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
+                                                               quant_min=-128,
+                                                               quant_max=127,
+                                                               dtype=torch.qint8,
+                                                               qscheme=torch.per_channel_symmetric,
+                                                               reduce_range=False,
+                                                               ch_axis=0)
+default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
+                                                      quant_min=0,
+                                                      quant_max=255,
+                                                      dtype=torch.quint8,
+                                                      qscheme=torch.per_tensor_affine,
+                                                      reduce_range=True)
+
+default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                       quant_min=0,
+                                                                       quant_max=255,
+                                                                       dtype=torch.quint8,)
+
+
+default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                      quant_min=-128,
+                                                                      quant_max=127,
+                                                                      dtype=torch.qint8,
+                                                                      qscheme=torch.per_tensor_symmetric)
+
+default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
+                                                                                  quant_min=-128,
+                                                                                  quant_max=127,
+                                                                                  dtype=torch.qint8,
+                                                                                  qscheme=torch.per_channel_symmetric)
+
+def _is_fake_quant_script_module(mod):
+    ''' Returns true if given mod is an instance of FakeQuantize script module.
+    '''
+    if isinstance(mod, torch.jit.RecursiveScriptModule):
+        # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
+        suffix = mod._c.qualified_name.split('.', 1)[1]
+        name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
+        return name == 'torch.ao.quantization.fake_quantize.FakeQuantize' or \
+            name == 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'
+    return False
+
+def disable_fake_quant(mod):
+    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
+        mod.disable_fake_quant()
+
+def enable_fake_quant(mod):
+    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
+        mod.enable_fake_quant()
+
+def disable_observer(mod):
+    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
+        mod.disable_observer()
+
+def enable_observer(mod):
+    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
+        mod.enable_observer()
index 1b66d02..5897753 100644 (file)
@@ -18,6 +18,7 @@ def default_eval_fn(model, calib_data):
     for data, target in calib_data:
         model(data)
 
+# TODO(future PR): fix the typo, should be `__all__`
 _all__ = [
     'QuantWrapper', 'QuantStub', 'DeQuantStub',
     # Top level API for eager mode quantization
index 7d470b8..3b28b75 100644 (file)
-import torch
-from torch.nn import Module
-from .observer import MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver, _with_args
-import re
-from abc import ABC, abstractmethod
-from typing import Any, Tuple
-
-def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
-    return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]
-
-def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
-    return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
-
-def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool:
-    return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]
-
-class FakeQuantizeBase(ABC, Module):
-    r""" Base fake quantize module
-    Any fake quantize implementation should derive from this class.
-
-    Concrete fake quantize module should follow the same API. In forward, they will update
-    the statistics of the observed Tensor and fake quantize the input. They should also provide a
-    `calculate_qparams` function that computes the quantization parameters given
-    the collected statistics.
-
-    """
-
-    fake_quant_enabled: torch.Tensor
-    observer_enabled: torch.Tensor
-
-    def __init__(self):
-        super().__init__()
-        # fake_quant_enabled and observer_enabled are buffers to support their
-        # replication in DDP. Data type is uint8 because NCCL does not support
-        # bool tensors.
-        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
-        self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))
-
-    @abstractmethod
-    def forward(self, x):
-        pass
-
-    @abstractmethod
-    def calculate_qparams(self, **kwargs):
-        pass
-
-    @torch.jit.export
-    def enable_fake_quant(self, enabled: bool = True) -> None:
-        self.fake_quant_enabled[0] = 1 if enabled else 0
-
-    @torch.jit.export
-    def disable_fake_quant(self):
-        self.enable_fake_quant(False)
-
-    @torch.jit.export
-    def enable_observer(self, enabled: bool = True) -> None:
-        self.observer_enabled[0] = 1 if enabled else 0
-
-    @torch.jit.export
-    def disable_observer(self):
-        self.enable_observer(False)
-
-    with_args = classmethod(_with_args)
-
-class FakeQuantize(FakeQuantizeBase):
-    r""" Simulate the quantize and dequantize operations in training time.
-    The output of this module is given by
-
-    x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
-
-
-
-    * :attr:`scale` defines the scale factor used for quantization.
-
-    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to
-
-    * :attr:`quant_min` specifies the minimum allowable quantized value.
-
-    * :attr:`quant_max` specifies the maximum allowable quantized value.
-
-    * :attr:`fake_quant_enable` controls the application of fake quantization on tensors, note that
-      statistics can still be updated.
-
-    * :attr:`observer_enable` controls statistics collection on tensors
-
-    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
-                    allowable values are torch.qint8 and torch.quint8. The values of quant_min and
-                    quant_max should be chosen to be consistent with the dtype
-
-
-    Args:
-        observer (module): Module for observing statistics on input tensors and calculating scale
-                           and zero-point.
-        quant_min (int): The minimum allowable quantized value.
-        quant_max (int): The maximum allowable quantized value.
-        observer_kwargs (optional): Arguments for the observer module
-
-    Attributes:
-        observer (Module): User provided module that collects statistics on the input tensor and
-                           provides a method to calculate scale and zero-point.
-
-    """
-
-    scale: torch.Tensor
-    zero_point: torch.Tensor
-
-    def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
-        super().__init__()
-        assert quant_min <= quant_max, \
-            'quant_min must be less than or equal to quant_max'
-        self.quant_min = quant_min
-        self.quant_max = quant_max
-        self.activation_post_process = observer(**observer_kwargs)
-        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound'
-        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound'
-        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
-        self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
-        self.dtype = self.activation_post_process.dtype
-        self.qscheme = self.activation_post_process.qscheme
-        self.ch_axis = self.activation_post_process.ch_axis \
-            if hasattr(self.activation_post_process, 'ch_axis') else -1
-        assert _is_per_channel(self.qscheme) or \
-            _is_per_tensor(self.qscheme), \
-            'Only per channel and per tensor quantization are supported in fake quantize' + \
-            ' got qscheme: ' + str(self.qscheme)
-        self.is_per_channel = _is_per_channel(self.qscheme)
-
-    @torch.jit.export
-    def calculate_qparams(self):
-        return self.activation_post_process.calculate_qparams()
-
-    def forward(self, X):
-        if self.observer_enabled[0] == 1:
-            self.activation_post_process(X.detach())
-            _scale, _zero_point = self.calculate_qparams()
-            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
-            if self.scale.shape != _scale.shape:
-                self.scale.resize_(_scale.shape)
-                self.zero_point.resize_(_zero_point.shape)
-            self.scale.copy_(_scale)
-            self.zero_point.copy_(_zero_point)
-
-        if self.fake_quant_enabled[0] == 1:
-            if self.is_per_channel:
-                X = torch.fake_quantize_per_channel_affine(
-                    X, self.scale, self.zero_point,
-                    self.ch_axis, self.quant_min, self.quant_max)
-            else:
-                X = torch.fake_quantize_per_tensor_affine(
-                    X, self.scale, self.zero_point,
-                    self.quant_min, self.quant_max)
-        return X
-
-    @torch.jit.export
-    def extra_repr(self):
-        return 'fake_quant_enabled={}, observer_enabled={}, ' \
-               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
-               'scale={}, zero_point={}'.format(
-                   self.fake_quant_enabled, self.observer_enabled,
-                   self.quant_min, self.quant_max,
-                   self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)
-
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
-        # We cannot currently register scalar values as buffers, so need to manually
-        # specify serialization here.
-        super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
-        destination[prefix + 'scale'] = self.scale
-        destination[prefix + 'zero_point'] = self.zero_point
-
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        # Removing this function throws an error that the the size of the loaded tensor does not match the original size
-        # i.e., These buffers start out with numel 0 and become numel 1 once they have their first forward pass.
-        local_state = ['scale', 'zero_point']
-        for name in local_state:
-            key = prefix + name
-            if key in state_dict:
-                val = state_dict[key]
-                # Custom handling to allow loading scale and zero_point
-                # of size N into uninitialized buffers of size 0. The
-                # buffers are resized here, and the values are copied in
-                # the default state_dict loading code of the parent.
-                if name == 'scale':
-                    self.scale.resize_(val.shape)
-                else:
-                    assert name == 'zero_point'
-                    self.zero_point.resize_(val.shape)
-                # For torchscript module we need to update the attributes here since we do not
-                # call the `_load_from_state_dict` function defined module.py
-                if torch.jit.is_scripting():
-                    if name == 'scale':
-                        self.scale.copy_(val)
-                    else:
-                        assert name == 'zero_point'
-                        self.zero_point.copy_(val)
-            elif strict:
-                missing_keys.append(key)
-        super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
-                                                        missing_keys, unexpected_keys, error_msgs)
-
-class FixedQParamsFakeQuantize(FakeQuantizeBase):
-    """ Simulate quantize and dequantize with fixed quantization
-    parameters in training time. Only per tensor quantization
-    is supported.
-    Args:
-        `scale` (float): fixed scale for the fake quantize module
-        `zero_point` (int): fixed zero point for the fake quantize module
-        `dtype`, `qscheme`, `quant_min`, `quant_max`
-    """
-
-    scale: torch.Tensor
-    zero_point: torch.Tensor
-
-    def __init__(self,
-                 scale,
-                 zero_point,
-                 dtype=torch.quint8,
-                 qscheme=torch.per_tensor_affine,
-                 quant_min=0,
-                 quant_max=255):
-        super().__init__()
-        assert quant_min <= quant_max, 'quant_min should be less than or equal to quant_max'
-        self.quant_min = quant_min
-        self.quant_max = quant_max
-        self.register_buffer('scale', torch.tensor([scale], dtype=torch.float))
-        self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.int))
-        self.dtype = dtype
-        self.qscheme = qscheme
-        assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported' + \
-            ' FixedQParamsFakeQuantize module, got qscheme:' + str(self.qscheme)
-
-    def forward(self, X):
-        if self.fake_quant_enabled[0] == 1:
-            X = torch.fake_quantize_per_tensor_affine(X, self.scale,
-                                                      self.zero_point, self.quant_min,
-                                                      self.quant_max)
-        return X
-
-    @torch.jit.export
-    def calculate_qparams(self):
-        return self.scale, self.zero_point
-
-    @torch.jit.export
-    def extra_repr(self):
-        return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
-               'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
-                   self.fake_quant_enabled, self.observer_enabled,
-                   self.scale, self.zero_point, self.dtype,
-                   self.quant_min, self.quant_max, self.qscheme)
-
-class FusedMovingAvgObsFakeQuantize(FakeQuantize):
-    r"""Fused module that is used to observe the input tensor (compute min/max), compute
-    scale/zero_point and fake_quantize the tensor.
-    This module uses calculation similar MovingAverageMinMaxObserver for the inputs,
-    to compute the min/max values in order to compute the scale/zero_point.
-    The qscheme input in the observer is used to differentiate between symmetric/affine
-    quantization scheme.
-
-    The output of this module is given by
-    x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
-
-    Similar to :class:`~torch.quantization.FakeQuantize`, and accepts the same attributes as the
-    base class.
-
-    """
-
-    def __init__(
-        self,
-        observer: Any = MovingAverageMinMaxObserver,
-        quant_min: int = 0,
-        quant_max: int = 255,
-        **observer_kwargs: Any
-    ) -> None:
-        super().__init__(observer, quant_min, quant_max, **observer_kwargs)
-        assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)),\
-            "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
-        self.quant_min: int = quant_min
-        self.quant_max: int = quant_max
-        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
-        self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
-        self.is_symmetric_quant = _is_symmetric_quant(self.activation_post_process.qscheme)
-
-        self.quant_min, self.quant_max = self.activation_post_process.quant_min, self.activation_post_process.quant_max
-
-    @torch.jit.export
-    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
-        return self.activation_post_process.calculate_qparams()
-
-    @torch.jit.export
-    def extra_repr(self) -> str:
-        return (
-            "fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, "
-            "dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}".format(
-                self.fake_quant_enabled,
-                self.observer_enabled,
-                self.scale,
-                self.zero_point,
-                self.dtype,
-                self.quant_min,
-                self.quant_max,
-                self.qscheme,
-                self.activation_post_process.reduce_range,
-            )
-        )
-
-    def forward(self, X: torch.Tensor) -> torch.Tensor:
-        return torch.fused_moving_avg_obs_fake_quant(
-            X,
-            self.observer_enabled,
-            self.fake_quant_enabled,
-            self.activation_post_process.min_val,
-            self.activation_post_process.max_val,
-            self.scale,
-            self.zero_point,
-            self.activation_post_process.averaging_constant,
-            self.quant_min,
-            self.quant_max,
-            self.ch_axis,
-            self.is_per_channel,
-            self.is_symmetric_quant,
-        )
-
-default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
-                                            dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
-default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
-                                                   dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
-
-# TODO(future PR): remove these defaults and enforce activation functions
-# to explicitly specify their output range
-default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
-    scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255)
-default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
-    scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255)
-
-default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
-                                                               quant_min=-128,
-                                                               quant_max=127,
-                                                               dtype=torch.qint8,
-                                                               qscheme=torch.per_channel_symmetric,
-                                                               reduce_range=False,
-                                                               ch_axis=0)
-default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
-                                                      quant_min=0,
-                                                      quant_max=255,
-                                                      dtype=torch.quint8,
-                                                      qscheme=torch.per_tensor_affine,
-                                                      reduce_range=True)
-
-default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
-                                                                       quant_min=0,
-                                                                       quant_max=255,
-                                                                       dtype=torch.quint8,)
-
-
-default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
-                                                                      quant_min=-128,
-                                                                      quant_max=127,
-                                                                      dtype=torch.qint8,
-                                                                      qscheme=torch.per_tensor_symmetric)
-
-default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
-                                                                                  quant_min=-128,
-                                                                                  quant_max=127,
-                                                                                  dtype=torch.qint8,
-                                                                                  qscheme=torch.per_channel_symmetric)
-
-def _is_fake_quant_script_module(mod):
-    ''' Returns true if given mod is an instance of FakeQuantize script module.
-    '''
-    if isinstance(mod, torch.jit.RecursiveScriptModule):
-        # qualified name looks like '__torch__.torch.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
-        suffix = mod._c.qualified_name.split('.', 1)[1]
-        name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
-        return name == 'torch.quantization.fake_quantize.FakeQuantize' or \
-            name == 'torch.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'
-    return False
-
-def disable_fake_quant(mod):
-    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
-        mod.disable_fake_quant()
-
-def enable_fake_quant(mod):
-    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
-        mod.enable_fake_quant()
-
-def disable_observer(mod):
-    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
-        mod.disable_observer()
-
-def enable_observer(mod):
-    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
-        mod.enable_observer()
+# flake8: noqa: F401
+r"""
+This file is in the process of migration to `torch/ao/quantization`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please, add it to the
+`torch/ao/quantization/fake_quantize.py`, while adding an import statement
+here.
+"""
+
+from torch.ao.quantization.fake_quantize import (
+    _is_per_channel,
+    _is_per_tensor,
+    _is_symmetric_quant,
+    FakeQuantizeBase,
+    FakeQuantize,
+    FixedQParamsFakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+    default_fake_quant,
+    default_weight_fake_quant,
+    default_symmetric_fixed_qparams_fake_quant,
+    default_affine_fixed_qparams_fake_quant,
+    default_per_channel_weight_fake_quant,
+    default_histogram_fake_quant,
+    default_fused_act_fake_quant,
+    default_fused_wt_fake_quant,
+    default_fused_per_channel_wt_fake_quant,
+    _is_fake_quant_script_module,
+    disable_fake_quant,
+    enable_fake_quant,
+    disable_observer,
+    enable_observer,
+)
index ae89b4a..2bd1d99 100644 (file)
@@ -5,10 +5,16 @@ from .observer import (HistogramObserver, MovingAverageMinMaxObserver,
                        default_float_qparams_observer, default_observer,
                        default_per_channel_weight_observer,
                        default_placeholder_observer, default_weight_observer)
-from .fake_quantize import (FakeQuantize, default_fake_quant,
-                            default_per_channel_weight_fake_quant,
-                            default_weight_fake_quant, default_fused_act_fake_quant, default_fused_wt_fake_quant,
-                            FusedMovingAvgObsFakeQuantize, default_fused_per_channel_wt_fake_quant)
+from torch.ao.quantization.fake_quantize import (
+    FakeQuantize,
+    default_fake_quant,
+    default_per_channel_weight_fake_quant,
+    default_weight_fake_quant,
+    default_fused_act_fake_quant,
+    default_fused_wt_fake_quant,
+    FusedMovingAvgObsFakeQuantize,
+    default_fused_per_channel_wt_fake_quant,
+)
 import torch
 import torch.nn as nn
 
index 2c577a6..a2c7519 100644 (file)
@@ -16,7 +16,7 @@ import torch.nn.qat as nnqat
 from typing import Optional, Union, Dict, Set, Callable, Any
 
 from torch.ao.quantization.stubs import QuantStub, DeQuantStub
-from .fake_quantize import (
+from torch.ao.quantization.fake_quantize import (
     default_affine_fixed_qparams_fake_quant,
     default_symmetric_fixed_qparams_fake_quant,
 )