EmbeddingBagModule,
EmbeddingModule,
EmbeddingWithLinear,
+ LinearReluLinearModel,
)
# annotated models
model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
checkQuantized(model)
+ def test_linear_relu_fusion(self):
+ dtype = torch.qint8
+ model = LinearReluLinearModel().eval()
+ qconfig = default_dynamic_qconfig
+ qconfig_dict = {'': qconfig}
+ torch.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)
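+ # fusion replaces model.fc1 with an nni.LinearReLU; prepare/convert then
+ # swap it for nniqd.LinearReLU via the dynamic module mappings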
+ prepare_dynamic(model, qconfig_dict)
+ convert_dynamic(model)
+
+ def checkQuantized(model):
+ self.checkDynamicQuantizedLinearRelu(model.fc1, dtype)
+ self.checkDynamicQuantizedLinear(model.fc2, dtype)
+ self.checkScriptable(model, self.calib_data, check_save_load=True)
+ self.checkNoQconfig(model)
+
+ checkQuantized(model)
+
@given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]),
dtype=st.sampled_from([torch.qint8, torch.float16]))
def test_quantized_rnn(self, qconfig, dtype):
import torch.nn.quantized.dynamic as nnqd
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.quantized.dynamic as nniqd
import torch.multiprocessing as mp
# graph mode quantization based on fx
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
+ @skipIfNoFBGEMM
+ def test_dynamic_with_fusion(self):
+ """
+ Tests that dynamic quantization APIs work with Linear + ReLU fusion
+ """
+ class LinearRelu(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(5, 5)
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, x):
+ x = self.linear(x)
+ return self.relu(x)
+
+ class Linear(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.w = torch.ones(5, 5)
+ self.b = torch.zeros(5)
+
+ def forward(self, x):
+ return torch.nn.functional.linear(x, self.w, self.b)
+
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.mods1 = torch.nn.Sequential(LinearRelu(), LinearRelu())
+ self.mods2 = Linear()
+ self.relu = F.relu
+
+ def forward(self, x):
+ x = self.mods1(x)
+ x = self.mods2(x)
+ x = self.relu(x)
+ return x
+
+ model = M().eval()
+ qconfig_dict = {
+ "": default_dynamic_qconfig,
+ }
+ m = prepare_fx(model, qconfig_dict)
+ m = convert_fx(m)
+ m(torch.rand(5, 5))
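+ # the two fused eager modules should lower to nniqd.LinearReLU, and the
+ # functional linear followed by F.relu should fuse into the dynamic
+ # linear_relu op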
+ node_list = [
+ ns.call_module(nniqd.LinearReLU),
+ ns.call_module(nniqd.LinearReLU),
+ ns.call_function(torch.ops.quantized.linear_relu_dynamic),
+ ]
+ self.checkGraphModuleNodes(m, expected_node_list=node_list)
+
@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops
}
quant_type_to_qlinear_relu_fun = {
- # we don't have linear_relu_dynamic
- QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
+ QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_relu_dynamic),
QuantType.STATIC: ns.call_function(torch.ops.quantized.linear_relu),
QuantType.QAT: ns.call_function(torch.ops.quantized.linear_relu),
}
--- /dev/null
+from .modules import * # noqa: F403
--- /dev/null
+import torch
+from .linear_relu import LinearReLU
+
+__all__ = [
+ 'LinearReLU',
+]
--- /dev/null
+import torch
+import torch.nn.quantized.dynamic as nnqd
+import torch.nn.intrinsic as nni
+
+class LinearReLU(nnqd.Linear):
+ r"""
+ A LinearReLU module fused from Linear and ReLU modules that can be used
+ for dynamic quantization.
+ INT8 quantization is supported; FP16 support is planned (see the TODO in forward).
+
+ We adopt the same interface as :class:`torch.nn.quantized.dynamic.Linear`.
+
+ Attributes:
+ Same as torch.nn.quantized.dynamic.Linear
+
+ Examples::
+
+ >>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30)
+ >>> input = torch.randn(128, 20)
+ >>> output = m(input)
+ >>> print(output.size())
+ torch.Size([128, 30])
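+ >>> # A minimal sketch of the fused conversion flow, assuming the
+ >>> # eager-mode fuse_modules / quantize_dynamic APIs:
+ >>> float_model = nn.Sequential(nn.Linear(20, 30), nn.ReLU()).eval()
+ >>> fused = torch.quantization.fuse_modules(float_model, [['0', '1']])
+ >>> q = torch.quantization.quantize_dynamic(
+ ... fused, {'': torch.quantization.default_dynamic_qconfig})
+ >>> q[0]._get_name()
+ 'DynamicQuantizedLinearReLU'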
+ """
+ _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment]
+
+ def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
+ super().__init__(in_features, out_features, bias, dtype)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self._packed_params.dtype == torch.qint8:
+ # TODO: check whether reduce_range should be True by default here
+ Y = torch.ops.quantized.linear_relu_dynamic(
+ x, self._packed_params._packed_params, reduce_range=True)
+ # TODO Support this in a later PR
+ # elif self._packed_params.dtype == torch.float16:
+ # Y = torch.ops.quantized.linear_relu_dynamic_fp16(
+ # x, self._packed_params._packed_params)
+ else:
+ raise RuntimeError('Unsupported dtype on dynamic quantized linear relu!')
+ return Y.to(x.dtype)
+
+ def _get_name(self):
+ return 'DynamicQuantizedLinearReLU'
+
+ @classmethod
+ def from_float(cls, mod):
+ return super(LinearReLU, cls).from_float(mod)
import torch
import torch.nn.quantized as nnq
+import torch.nn.intrinsic as nni
from torch.nn.quantized.modules.utils import _quantize_weight
class Linear(nnq.Linear):
mod (Module): a float module, either produced by torch.quantization
utilities or provided by the user
"""
- float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
+ float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
+ torch.nn.intrinsic.modules.fused.LinearReLU]
+
assert type(mod) in float_modules, \
'nn.quantized.dynamic.Linear.from_float only works for one of ' + \
str([float_mod.__name__ for float_mod in float_modules])
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
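+ # for a fused float LinearReLU, quantize the inner Linear ([0]); the relu
+ # is handled by the fused quantized module that `cls` points to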
+ if type(mod) == nni.LinearReLU:
+ mod = mod[0]
if mod.qconfig is not None and mod.qconfig.weight is not None:
weight_observer = mod.qconfig.weight()
else:
qweight = mod.weight.float()
else:
raise RuntimeError('Unsupported dtype specified for dynamic quantized Linear!')
- qlinear = Linear(mod.in_features, mod.out_features, dtype=dtype)
+ qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
qlinear.set_weight_bias(qweight, mod.bias)
return qlinear
elif dtypes in [(torch.float32, torch.qint8, torch.quint8),
(torch.float32, torch.float16, None)]:
- # choose linear dynamic or linear dynamic fp16 op based on weight dtype
+ # choose the dynamic linear op based on weight dtype and relu fusion
- qlinear_op = torch.ops.quantized.linear_dynamic \
- if weight_dtype == torch.qint8 \
- else torch.ops.quantized.linear_dynamic_fp16
+ if weight_dtype == torch.qint8:
+ if self.relu_node:
+ qlinear_op = torch.ops.quantized.linear_relu_dynamic
+ else:
+ qlinear_op = torch.ops.quantized.linear_dynamic
+ else: # TODO add support for fp16 + relu fusion in a later PR
+ qlinear_op = torch.ops.quantized.linear_dynamic_fp16
+
linear_input = load_arg(quantized=torch.float)(self.linear_node.args[0])
qlinear_args = (linear_input, packed_weight) # type: ignore[assignment]
op_out = quantized_graph.create_node(
# TODO: we may need to key this map on Node and regenerate it in each
# transformation, since we might not be able to rely on the name
node_name_to_scope[op_out.name] = node_name_to_scope[self.linear_node.name]
- if self.relu_node:
+ if self.relu_node and weight_dtype is not torch.qint8:
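+ # for qint8 the relu is already fused into linear_relu_dynamic above;
+ # for fp16 we still need an explicit relu node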
op_out = quantized_graph.create_node("call_function", torch.nn.functional.relu, (op_out,), {})
return op_out
else:
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.quantized.dynamic as nniqd
import torch.nn.intrinsic.qat as nniqat
import torch.nn.intrinsic as nni
import torch.nn.qat as nnqat
nnq.Linear,
nni.LinearReLU,
nniq.LinearReLU,
+ nniqd.LinearReLU,
nnqat.Linear,
nnqd.Linear,
nniqat.LinearReLU,
nniqat.ConvReLU2d,
nniqat.ConvReLU3d,
nniqat.LinearReLU,
+ nniqd.LinearReLU,
])
MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = set([
import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.quantized.dynamic as nniqd
import torch.nn.intrinsic.quantized._reference as nniqr
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
nn.GRU: nnqd.GRU,
nn.LSTMCell: nnqd.LSTMCell,
nn.RNNCell: nnqd.RNNCell,
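+ # fused float LinearReLU -> dynamically quantized fused LinearReLU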
+ nni.LinearReLU: nniqd.LinearReLU,
}
# Allowlist for propagating the qconfig
import torch
import torch.nn as nn
import torch.nn.functional as F
+import torch.nn.intrinsic.quantized.dynamic as nniqd
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.nn.intrinsic import _FusedModule
self.assertEqual(type(mod), nnqd.Linear)
self.assertEqual(mod._packed_params.dtype, dtype)
+ def checkDynamicQuantizedLinearRelu(self, mod, dtype):
+ r"""Checks that mod has been swapped for an nnqd.Linear
+ module, the bias is float.
+ """
+ self.assertEqual(type(mod), nniqd.LinearReLU)
+ self.assertEqual(mod._packed_params.dtype, dtype)
+
def check_eager_serialization(self, ref_model, loaded_model, x):
# Check state dict serialization and torch.save APIs
model_dict = ref_model.state_dict()