self.checkGraphModuleNodes(m, expected_node_list=node_list)
def test_qconfig_qat_module_type(self):
    """QAT prepare/convert with per-object-type qconfigs.

    A ``Linear -> ReLU`` pair (expressed as an ``nn.Sequential`` subclass)
    should be fused and quantized to ``nniq.LinearReLU``, while a standalone
    ``Linear`` becomes ``nnq.Linear``; everything else stays float because
    the global ("") qconfig is ``None``.

    NOTE(review): this block contained unresolved diff markers and was
    missing the ``qconfig_dict = {`` opener; reconstructed to the resolved
    post-patch state.
    """
    class LinearRelu(nn.Sequential):
        def __init__(self):
            super().__init__(
                nn.Linear(5, 5),
                nn.ReLU(),
            )

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin_relu = LinearRelu()
            self.linear = nn.Linear(5, 5)

        def forward(self, x):
            x = self.lin_relu(x)
            x = self.linear(x)
            return x

    # QAT APIs require the model to be in training mode.
    model = M().train()

    qconfig_dict = {
        # No global qconfig: only modules matched by "object_type" below
        # are quantized.
        "": None,
        "object_type": [
            (torch.nn.Linear, default_qat_qconfig),
            (torch.nn.ReLU, default_qat_qconfig),
        ],
    }
    m = prepare_qat_fx(model, qconfig_dict)
    m(torch.rand(5, 5))  # one fake-quant forward pass to exercise observers
    # The quantized module types asserted below (nniq.LinearReLU,
    # nnq.Linear) only exist after conversion.
    m = convert_fx(m)
    m(torch.rand(5, 5))
    node_list = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.LinearReLU),
        ns.call_module(nnq.Linear),
        ns.call_method("dequantize"),
    ]
    self.checkGraphModuleNodes(m, expected_node_list=node_list)
all_qat_mappings = get_combined_dict(
get_default_qat_module_mappings(), additional_qat_module_mapping)
object_type_dict = qconfig_dict.get("object_type", None)
- for k, v in object_type_dict.items():
+ new_object_type_dict = object_type_dict.copy()
+ for k, v in new_object_type_dict.items():
if k in all_qat_mappings:
object_type_dict[all_qat_mappings[k]] = v
return qconfig_dict