import torch.nn.quantized as nnq
import torch.nn.quantized._reference as nnqr
import torch.nn.quantized.dynamic as nnqd
-import torch.nn.functional as F
import torch.quantization
from torch.quantization import (
[4, 8],
[True, False],
[True, False],
- [True, False],
[True, False])
for (batch_size, in_features, out_features, use_bias,
- use_fused, per_channel, is_reference) in options:
+ use_fused, per_channel) in options:
self._test_linear_api_impl(
batch_size, in_features, out_features, use_bias, use_fused,
- per_channel, is_reference)
+ per_channel)
- def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel, is_reference):
+ def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel):
if torch.backends.quantized.engine == 'qnnpack':
per_channel = False
- # (use_fused, is_reference) -> quantized class
+ # use_fused -> quantized class
class_map = {
- (True, True) : nniqr.LinearReLU,
- (True, False) : nniq.LinearReLU,
- (False, True) : nnqr.Linear,
- (False, False) : nnq.Linear,
+ True: nniq.LinearReLU,
+ False: nnq.Linear,
}
W = torch.rand(out_features, in_features).float()
B = torch.rand(out_features).float() if use_bias else None
scale = 0.5
zero_point = 3
- qlinear = class_map[(use_fused, is_reference)](in_features, out_features)
+ qlinear = class_map[use_fused](in_features, out_features)
qlinear_copy = qlinear # deepcopy does not work right now
# qlinear_copy = copy.deepcopy(qlinear)
# Check if the module implementation matches calling the
# ops directly
- if is_reference:
- weight = qlinear._qweight
- bias = qlinear._bias
- weight_dequant = weight.dequantize()
- X_q_dq = X_q.dequantize()
- Z_ref = F.linear(X_q_dq, weight_dequant, bias)
- if use_fused:
- Z_ref = F.relu(Z_ref, inplace=True)
- Z_ref = torch.quantize_per_tensor(Z_ref, scale, zero_point, torch.quint8)
+ W_pack = qlinear._packed_params._packed_params
+ if use_fused:
+ Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)
else:
- W_pack = qlinear._packed_params._packed_params
- if use_fused:
- Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)
- else:
- Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)
+ Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)
self.assertEqual(Z_ref, Z_q)
self.assertTrue(
else:
self.assertEqual(model_dict[key], loaded_dict[key])
- loaded_qlinear = class_map[(use_fused, is_reference)](
+ loaded_qlinear = class_map[use_fused](
in_features, out_features)
loaded_qlinear.load_state_dict(loaded_dict)
- if is_reference:
- self.assertEqual(qlinear._qweight, loaded_qlinear._qweight)
- self.assertEqual(qlinear._bias, loaded_qlinear._bias)
- else:
- linear_unpack = torch.ops.quantized.linear_unpack
- self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
- linear_unpack(loaded_qlinear._packed_params._packed_params))
+ linear_unpack = torch.ops.quantized.linear_unpack
+ self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
+ linear_unpack(loaded_qlinear._packed_params._packed_params))
self.assertEqual(qlinear.scale, loaded_qlinear.scale)
self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
# make sure loaded_qlinear has the same dir as qlinear since
self.checkScriptable(loaded_qlinear, [[X_q]], check_save_load=True)
self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
- if not is_reference:
- self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
+ self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
Z_q2 = loaded_qlinear(X_q)
self.assertEqual(Z_q, Z_q2)
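
# A minimal standalone sketch (not part of the patch) of calling the
# quantized linear ops exercised in the test above; the shapes and qparams
# here are illustrative assumptions, and a quantized engine (fbgemm/qnnpack)
# must be available.
import torch
W = torch.rand(5, 10)
W_q = torch.quantize_per_tensor(W, scale=0.1, zero_point=0, dtype=torch.qint8)
W_pack = torch.ops.quantized.linear_prepack(W_q, None)  # bias=None
X = torch.rand(2, 10)
X_q = torch.quantize_per_tensor(X, scale=0.1, zero_point=0, dtype=torch.quint8)
Z_q = torch.ops.quantized.linear(X_q, W_pack, 0.5, 3)  # output scale, zero_point
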
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.quantized as nnq
+import torch.nn.quantized._reference as nnqr
import torch.nn.quantized.dynamic as nnqd
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
LinearModule,
(),
(linear_module_input,),
- ns.call_module(nn.Linear) if is_reference else ns.call_module(nnqd.Linear),
+ ns.call_module(nnqr.Linear) if is_reference else ns.call_module(nnqd.Linear),
None,
),
(
LinearModule,
(),
(linear_module_input,),
- ns.call_module(nn.Linear if is_reference else nnq.Linear),
+ ns.call_module(nnqr.Linear if is_reference else nnq.Linear),
None,
),
]
""" Test quantizing functional conv and linear with reference option
"""
tests = self._get_conv_linear_test_cases(is_reference=True)
+
+ def _get_keys(prefix, is_dynamic):
+ all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]]
+ if not is_dynamic:
+ all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]])
+ return all_keys
+
for (is_dynamic, ModuleClass, module_constructor_inputs,
inputs, quantized_node, weight_prepack_node) in tests:
quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
qr = result_dict["quantized_reference"]
def checkWeightQParams(model):
- for module_name in ("linear", "conv"):
+ for module_name in ("conv",):
if hasattr(model, module_name):
self.assertTrue(hasattr(qr.get_submodule(module_name), "_weight_qparams"))
self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name())
+ for module_name in ("linear",):
+ if hasattr(model, module_name):
+ self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme"))
+ self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale"))
+ self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point"))
+ self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name())
- def checkSerDeser(model):
- for module_name in ("linear", "conv"):
+ def checkSerDeser(model, is_dynamic):
+ for module_name in ("conv",):
if hasattr(model, module_name):
# make sure serialization works
state_dict = copy.deepcopy(model.state_dict())
module._weight_qparams["scale"] = None
model.load_state_dict(state_dict)
self.assertTrue(torch.equal(prev_scale, module._weight_qparams["scale"]))
+ for module_name in ("linear",):
+ if hasattr(model, module_name):
+ # make sure serialization works
+ state_dict = copy.deepcopy(model.state_dict())
+ all_keys = _get_keys(module_name, is_dynamic)
+ for key in all_keys:
+ self.assertTrue(key in state_dict)
+ # check load_state_dict restores states
+ module = getattr(model, module_name)
+ prev_scale = module.weight_scale
+ module.weight_scale = None
+ model.load_state_dict(state_dict)
+ module = getattr(model, module_name)
+ self.assertTrue(torch.equal(prev_scale, module.weight_scale))
checkWeightQParams(qr)
# make sure the qparams are preserved after copy
checkWeightQParams(qr)
- checkSerDeser(qr)
+ checkSerDeser(qr, is_dynamic)
@skipIfNoFBGEMM
def test_dynamic_quant_weight_observer(self):
]
self.checkGraphModuleNodes(m, expected_node_list=node_list)
+ def test_ref_linear_module(self):
+ """ Make sure the numerics for models with ref linear module
+ matches models with fbgemm/qnnpack module
+ """
+ class M1(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(10, 5)
+
+ def forward(self, x):
+ return self.linear(x)
+
+ class M2(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(10, 5)
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, x):
+ return self.relu(self.linear(x))
+
+ for M in [M1, M2]:
+ m = M().eval()
+ m = prepare_fx(m, {"": default_qconfig})
+ m_copy = copy.deepcopy(m)
+ m = convert_fx(m, is_reference=False)
+ m_ref = convert_fx(m_copy, is_reference=True)
+ data = torch.randn(5, 10)
+ result = m(data)
+ result_ref = m_ref(data)
+ self.assertTrue(torch.equal(result, result_ref))
+
@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops
import torch
-from .linear_relu import LinearReLU
from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
__all__ = [
- 'LinearReLU',
'ConvReLU1d',
'ConvReLU2d',
'ConvReLU3d',
+++ /dev/null
-import torch
-import torch.nn.intrinsic as nni
-import torch.nn.quantized._reference as nnqr
-import torch.nn.functional as F
-
-class LinearReLU(nnqr.Linear):
- _FLOAT_MODULE = nni.LinearReLU
-
- def __init__(
- self,
- in_features,
- out_features,
- bias=True,
- dtype=torch.qint8):
- super().__init__(in_features, out_features, bias, dtype)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x_dequant = x.dequantize()
- weight_dequant = self._qweight.dequantize()
- float_result = F.linear(x_dequant, weight_dequant, self._bias)
- float_result = F.relu(float_result, inplace=True)
- # NEEDFIX: we don't have dtype in the Linear module APIs right now!
- result = torch.quantize_per_tensor(
- float_result, self.scale, self.zero_point, torch.quint8)
- return result
-
- def _get_name(self):
- return "QuantizedLinearReLU(Reference)"
import torch
-import torch.nn.quantized as nnq
+import torch.nn as nn
import torch.nn.functional as F
-from typing import Optional
+from typing import Optional, Dict, Any
+from .utils import _quantize_and_dequantize_weight
+from .utils import _save_weight_qparams
+from .utils import _get_weight_qparam_keys
-class Linear(nnq.Linear):
- """ A backend independent version of nn.quantized.Linear
- we will not pack the parameters in this module, since weight packing is an
- optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
- this is useful when user want to use this module in other backends like Glow.
+class Linear(nn.Linear):
+ """ A reference quantized linear module that fits into the FX
+ Graph Mode Quantization workflow.
+ Activations are floating point Tensors; the weight is stored as a
+ floating point Tensor as well, and in forward the weight is quantized
+ and dequantized before calling the floating point functional linear
+ operator.
"""
- def __init__(self, in_features, out_features, bias_=True,
- dtype=torch.qint8):
- super().__init__(in_features, out_features, bias_, dtype)
- self._qweight, self._bias = self._packed_params._weight_bias()
- del self._packed_params
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias_: bool = True,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ weight_qparams: Optional[Dict[str, Any]] = None):
+ super().__init__(in_features, out_features, bias_, device, dtype)
+ if weight_qparams is None:
+ weight_qparams = {
+ "qscheme": torch.per_tensor_affine,
+ "dtype": torch.quint8,
+ "scale": 1.0,
+ "zero_point": 0
+ }
+ self.weight_qscheme = weight_qparams["qscheme"]
+ self.weight_dtype = weight_qparams["dtype"]
+ assert self.weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
+ f"qscheme: {self.weight_qscheme} is not supported in reference quantized linear module"
+ if self.weight_qscheme is not None:
+ self.register_buffer(
+ "weight_scale",
+ torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device))
+ self.register_buffer(
+ "weight_zero_point",
+ torch.tensor(
+ weight_qparams["zero_point"],
+ dtype=torch.int, device=device))
+ if self.weight_qscheme == torch.per_channel_affine:
+ self.register_buffer(
+ "weight_axis",
+ torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device))
+ else:
+ # added for TorchScriptability, not used
+ self.register_buffer(
+ "weight_axis",
+ torch.tensor(0, dtype=torch.int, device=device))
def _get_name(self):
return "QuantizedLinear(Reference)"
+ def get_weight(self):
+ """
+ Fake quantize (quantize and dequantize) the weight with
+ the quantization parameters for weight, this is used to
+ simulate the numerics for the quantized weight in a quantized
+ model
+ """
+ # supress mypy warning
+ assert isinstance(self.weight, torch.Tensor)
+ assert isinstance(self.weight_scale, torch.Tensor)
+ assert isinstance(self.weight_zero_point, torch.Tensor)
+ assert isinstance(self.weight_axis, torch.Tensor)
+ return _quantize_and_dequantize_weight(
+ self.weight, self.weight_qscheme, self.weight_dtype, self.weight_scale,
+ self.weight_zero_point, self.weight_axis)
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
- x_dequant = x.dequantize()
- weight_dequant = self._qweight.dequantize()
- float_result = F.linear(x_dequant, weight_dequant, self._bias)
- # NEEDFIX: we don't have dtype in the Linear module APIs right now!
- result = torch.quantize_per_tensor(
- float_result, self.scale, self.zero_point, torch.quint8)
+ """
+ we have:
+ w(float) -- quant - dequant \
+ x(float) ------------- F.linear ---
+
+ In the full model, we will see
+ w(float) -- quant - *dequant \
+ x -- quant --- *dequant -- *F.linear --- *quant - dequant
+ and the backend should be able to fuse the ops with `*` into a quantized linear
+ """
+ weight_dequant = self.get_weight()
+ result = F.linear(x, weight_dequant, self.bias)
return result
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
- destination[prefix + '_qweight'] = self._qweight
- destination[prefix + '_bias'] = self._bias
+ _save_weight_qparams(
+ destination, prefix, self.weight_qscheme, self.weight_dtype,
+ self.weight_scale, self.weight_zero_point, self.weight_axis)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
- self._qweight = state_dict[prefix + '_qweight']
- self._bias = state_dict[prefix + '_bias']
- state_dict.pop(prefix + '_qweight')
- state_dict.pop(prefix + '_bias')
+ for key in _get_weight_qparam_keys(state_dict, prefix):
+ setattr(self, key, state_dict[prefix + key])
+ state_dict.pop(prefix + key)
super()._load_from_state_dict(
state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs)
- def _weight_bias(self):
- return self._qweight, self._bias
-
- def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
- self._qweight = w
- self._bias = b
+ @classmethod
+ def from_float(cls, float_linear, weight_qparams):
+ qref_linear = Linear(
+ float_linear.in_features, float_linear.out_features,
+ float_linear.bias is not None, device=float_linear.weight.device,
+ dtype=float_linear.weight.dtype, weight_qparams=weight_qparams)
+ qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach())
+ if float_linear.bias is not None:
+ qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach())
+ return qref_linear
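
# Hedged usage sketch (not part of the patch), assuming this module is
# importable as torch.nn.quantized._reference.Linear: build a reference
# linear from a float linear via from_float and run it on float Tensors.
import torch
import torch.nn.quantized._reference as nnqr
float_linear = torch.nn.Linear(10, 5)
weight_qparams = {
    "qscheme": torch.per_tensor_affine,
    "dtype": torch.qint8,
    "scale": 0.1,
    "zero_point": 0,
}
ref_linear = nnqr.Linear.from_float(float_linear, weight_qparams)
out = ref_linear(torch.randn(2, 10))  # float in, float out; weight is fake-quantized
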
--- /dev/null
+import torch
+from typing import Dict, Any
+
+def _quantize_and_dequantize_weight(
+ weight: torch.Tensor,
+ weight_qscheme: torch.qscheme,
+ weight_dtype: torch.dtype,
+ weight_scale: torch.Tensor,
+ weight_zero_point: torch.Tensor,
+ weight_axis: torch.Tensor):
+ """ Quantize and then dequantize the weight based on
+ the quantization parameters
+ """
+ if weight_qscheme == torch.per_tensor_affine:
+ weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
+ weight_dequant = weight.dequantize()
+ elif weight_qscheme == torch.per_channel_affine:
+ weight = torch.quantize_per_channel(
+ weight, weight_scale,
+ weight_zero_point, weight_axis.item(), weight_dtype) # type: ignore[arg-type]
+ weight_dequant = weight.dequantize()
+ else:
+ weight_dequant = weight
+ return weight_dequant
+
+def _save_weight_qparams(destination, prefix, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis):
+ destination[prefix + "weight_qscheme"] = weight_qscheme
+ destination[prefix + "weight_dtype"] = weight_dtype
+ if weight_qscheme is not None:
+ destination[prefix + "weight_scale"] = weight_scale
+ destination[prefix + "weight_zero_point"] = weight_zero_point
+ if weight_qscheme == torch.per_channel_affine:
+ destination[prefix + "weight_axis"] = weight_axis
+
+def _get_weight_qparam_keys(
+ state_dict: Dict[str, Any],
+ prefix: str):
+ keys = ["weight_qscheme", "weight_dtype"]
+ weight_qscheme = state_dict[prefix + "weight_qscheme"]
+ if weight_qscheme is not None:
+ keys.append("weight_scale")
+ keys.append("weight_zero_point")
+ if weight_qscheme == torch.per_channel_affine:
+ keys.append("weight_axis")
+ return keys
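
# Illustrative check (not part of the patch) of the fake-quant helper above:
# the per-tensor path should match a manual quantize/dequantize round trip.
import torch
w = torch.randn(5, 10)
scale = torch.tensor(0.1)
zero_point = torch.tensor(0, dtype=torch.int)
axis = torch.tensor(0, dtype=torch.int)  # unused for per-tensor qschemes
w_dq = _quantize_and_dequantize_weight(
    w, torch.per_tensor_affine, torch.qint8, scale, zero_point, axis)
expected = torch.quantize_per_tensor(w, scale, zero_point, torch.qint8).dequantize()
assert torch.equal(w_dq, expected)
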
# Get the float linear and attach qscheme and qparams
# to the module
float_linear = self.linear
+ fused_linear = None
if isinstance(float_linear, (torch.nn.qat.Linear, torch.nn.intrinsic.qat.LinearReLU)):
float_linear = float_linear.to_float()
# change qat linear to linear
setattr(modules[parent_name], name, float_linear)
# Attach weight fake quant to the linear module
if isinstance(float_linear, torch.nn.intrinsic.LinearReLU):
+ fused_linear = float_linear
float_linear = float_linear[0]
weight_post_process = self.linear.weight_fake_quant
else:
if isinstance(float_linear, torch.nn.intrinsic.LinearReLU):
+ fused_linear = float_linear
float_linear = self.linear[0] # type: ignore[index]
# Attach the weight observer to the module
weight_post_process = qconfig.weight() # type: ignore[union-attr]
weight_post_process(float_linear.weight) # type: ignore[operator]
weight_qparams = get_qparam_dict(weight_post_process)
- _to_reference(float_linear, weight_qparams)
+ # TODO: include the configuration in backend_config_dict
+ # we can have a map from module to reference module
+ # and allow user to register new ones
+ qlinear_cls = get_static_quant_module_class(
+ type(float_linear), is_reference=is_reference)
+ ref_linear = qlinear_cls.from_float(float_linear, weight_qparams)
+
+ # if the parent is a fused linear (Sequential), we replace its first
+ # item with the reference linear, otherwise we update the linear
+ # instance in the module tree directly
+ if fused_linear is not None:
+ fused_linear[0] = ref_linear
+ else:
+ parent_name, name = _parent_name(self.linear_node.target)
+ setattr(modules[parent_name], name, ref_linear)
op_out = quantized_graph.create_node(
'call_module',
self.linear_node.target,
# Default map for swapping float module to reference quantized modules
DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+ nn.Linear: nnqr.Linear,
nn.Conv1d: nnqr.Conv1d,
nn.Conv2d: nnqr.Conv2d,
nn.Conv3d: nnqr.Conv3d,
- nn.Linear: nnqr.Linear,
nni.ConvReLU1d: nniqr.ConvReLU1d,
nni.ConvReLU2d: nniqr.ConvReLU2d,
nni.ConvReLU3d: nniqr.ConvReLU3d,
- nni.LinearReLU: nniqr.LinearReLU,
# QAT Modules
- nnqat.Linear: nnqr.Linear,
nnqat.Conv2d: nnqr.Conv2d,
nnqat.Conv3d: nnqr.Conv3d,
nniqat.ConvBn1d: nnqr.Conv1d,
nniqat.ConvBnReLU3d: nniqr.ConvReLU3d,
nniqat.ConvReLU2d: nniqr.ConvReLU2d,
nniqat.ConvReLU3d: nniqr.ConvReLU3d,
- nniqat.LinearReLU: nniqr.LinearReLU,
}
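
# Illustrative lookup (hedged, assuming get_static_quant_module_class consults
# the reference mapping above when is_reference=True, and that the import
# paths below match this PyTorch version):
import torch.nn as nn
import torch.nn.quantized._reference as nnqr
from torch.quantization.quantization_mappings import get_static_quant_module_class
ref_cls = get_static_quant_module_class(nn.Linear, is_reference=True)
assert ref_cls is nnqr.Linear
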
# Default map for swapping float module to quantized ones