        result_ref = m_ref(data)
        self.assertTrue(torch.equal(result, result_ref))
+    def test_sub_scalar(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                x = x + 1
+                x = x - 1
+                return x
+
+        m = M().eval()
+        m = prepare_fx(m, {"": default_qconfig})
+        m = convert_fx(m)
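+        # the add is quantized, but sub is not supported with int8 inputs,
+        # so we expect a single quantize at the input and a single
+        # dequantize before the fp32 sub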
+        occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 1,
+            ns.call_method("dequantize"): 1
+        }
+        self.checkGraphModuleNodes(m, expected_node_occurrence=occurrence)
+
@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops
module, functional, qconfig, is_reference, node_list)
    def test_bmm_int_reference(self):
+        """ int8 is not supported for bmm so we won't produce a reference
+            pattern for it
+        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_method('dequantize'),
-            ns.call_method('dequantize'),
            ns.call_function(torch.bmm),
-            ns.call_function(torch.quantize_per_tensor),
-            ns.call_method('dequantize'),
        ]
        m = M().eval()
        the pattern matched to this QuantizeHandler instance during the
        prepare step.
        """
+        dtypes = get_qconfig_dtypes(qconfig)
+        if not (self.binary_op in binary_op_supported_dtypes and dtypes in binary_op_supported_dtypes[self.binary_op]):
+            return False
        if self.num_tensor_args == 1:
            return True
        elif self.all_node_args_are_tensors and self.input_output_observed():
        if is_reference:
            act_dtype = activation_dtype(qconfig)
-            if act_dtype == torch.float:
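+            # in reference mode, copy the node over as fp32 when activations
+            # are float or the dtype combination is unsupported for this op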
+            dtypes = get_qconfig_dtypes(qconfig)
+            if act_dtype == torch.float or \
+                    not (self.binary_op in binary_op_supported_dtypes and dtypes in binary_op_supported_dtypes[self.binary_op]):
                return quantized_graph.node_copy(node, load_arg(quantized=torch.float))
            else:
                if self.num_tensor_args == 2: