Updates internal `assert_allclose` callsites in favor of `assert_close` (#61841)
author    Philip Meier <github.pmeier@posteo.de>
Thu, 19 Aug 2021 19:45:32 +0000 (12:45 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 19 Aug 2021 19:50:41 +0000 (12:50 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61841

Redo of #60863.

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D30408145

Pulled By: mruberry

fbshipit-source-id: 0b34ebc7f23ba38ecd89640b61d8aca59b7eab58

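For context on the mechanical change below: every internal call to the deprecated `torch.testing.assert_allclose` is rewritten as a call to `torch.testing.assert_close`. The newer helper takes `rtol`/`atol` as keyword-only arguments that must be passed together (or both omitted), and it checks dtypes by default; this is visible in a few hunks below where inputs are adjusted (e.g. a NumPy result wrapped with `torch.from_numpy` in test_nn.py) or explicit tolerances are spelled out. A minimal sketch of the migration follows; the tensors and tolerance values are illustrative only and not taken from this PR:

    import torch

    actual = torch.randn(3)
    expected = actual + 1e-7  # tiny perturbation, well within the tolerances below

    # Before (deprecated helper that this PR removes from internal call sites):
    # torch.testing.assert_allclose(actual, expected, rtol=1.3e-6, atol=1e-5)

    # After: rtol and atol are keyword-only and must be given together
    # (or both omitted to fall back to dtype-based defaults).
    torch.testing.assert_close(actual, expected, rtol=1.3e-6, atol=1e-5)
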
26 files changed:
benchmarks/cpp/tensorexpr/bench_ops.py
docs/source/jit.rst
test/mobile/test_bytecode.py
test/mobile/test_lite_script_module.py
test/quantization/core/test_quantized_op.py
test/quantization/jit/test_deprecated_jit_quant.py
test/test_fx.py
test/test_fx_experimental.py
test/test_jit.py
test/test_jit_fuser_te.py
test/test_mobile_optimizer.py
test/test_nn.py
test/test_pruning_op.py
test/test_reductions.py
test/test_static_runtime.py
test/test_tensorexpr.py
test/test_tensorexpr_pybind.py
test/test_throughput_benchmark.py
test/test_torch.py
test/test_xnnpack_integration.py
torch/fx/experimental/fx2trt/example/fx2trt_example.py
torch/jit/_trace.py
torch/testing/_core.py
torch/testing/_deprecated.py
torch/testing/_internal/common_quantization.py
torch/testing/_internal/distributed/distributed_test.py

benchmarks/cpp/tensorexpr/bench_ops.py
index ca40e5d..12d766a 100644
@@ -59,7 +59,7 @@ for op in unary_ops:
         traced(x)
 
     # Validate result.
-    torch.testing.assert_allclose(op(x), traced(x))
+    torch.testing.assert_close(op(x), traced(x))
 
     # Benchmark.
     bench_iters = 100
@@ -94,7 +94,7 @@ def test_batch_norm():
             traced(x, y, z)
 
         # Validate result.
-        torch.testing.assert_allclose(op(x, y, z), traced(x, y, z))
+        torch.testing.assert_close(op(x, y, z), traced(x, y, z))
 
         # Benchmark.
         bench_iters = 100
docs/source/jit.rst
index eeb0d2a..f791c1c 100644
@@ -475,7 +475,7 @@ In this case, data-dependent control flow like this can be captured using
     #print(str(scripted_fn.graph).strip())
 
     for input_tuple in [inputs] + check_inputs:
-        torch.testing.assert_allclose(fn(*input_tuple), scripted_fn(*input_tuple))
+        torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple))
 
 .. testoutput::
     :hide:
test/mobile/test_bytecode.py
index 5511e6a..95baa86 100644
@@ -228,7 +228,7 @@ class testVariousModelVersions(TestCase):
     #             # Load model and run forward method
     #             mobile_module = _load_for_lite_interpreter(str(tmp_input_model_path))
     #             mobile_module_result = mobile_module(module_input)
-    #             torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result)
+    #             torch.testing.assert_close(mobile_module_result, expected_mobile_module_result)
     #             current_to_version -= 1
 
     #         # Check backport failure case
@@ -270,7 +270,7 @@ class testVariousModelVersions(TestCase):
                 module_input = 1
                 mobile_module_result = mobile_module(module_input)
                 expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)
-                torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result)
+                torch.testing.assert_close(mobile_module_result, expected_mobile_module_result)
                 shutil.rmtree(tmpdirname)
 
     # Check just the _backport_for_mobile_to_buffer mechanism but not the function implementations
@@ -296,7 +296,7 @@ class testVariousModelVersions(TestCase):
             module_input = 1
             mobile_module_result = mobile_module(module_input)
             expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)
-            torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result)
+            torch.testing.assert_close(mobile_module_result, expected_mobile_module_result)
 
 
     def test_get_model_ops_and_info(self):
test/mobile/test_lite_script_module.py
index 369371f..a86669e 100644
@@ -48,13 +48,13 @@ class TestLiteScriptModule(TestCase):
         mobile_module = _load_for_lite_interpreter(buffer)
 
         mobile_module_result = mobile_module(input)
-        torch.testing.assert_allclose(script_module_result, mobile_module_result)
+        torch.testing.assert_close(script_module_result, mobile_module_result)
 
         mobile_module_forward_result = mobile_module.forward(input)
-        torch.testing.assert_allclose(script_module_result, mobile_module_forward_result)
+        torch.testing.assert_close(script_module_result, mobile_module_forward_result)
 
         mobile_module_run_method_result = mobile_module.run_method("forward", input)
-        torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result)
+        torch.testing.assert_close(script_module_result, mobile_module_run_method_result)
 
     def test_save_mobile_module_with_debug_info_with_trace(self):
         class A(torch.nn.Module):
@@ -117,13 +117,13 @@ class TestLiteScriptModule(TestCase):
         mobile_module = _load_for_lite_interpreter(buffer)
 
         mobile_module_result = mobile_module(input)
-        torch.testing.assert_allclose(script_module_result, mobile_module_result)
+        torch.testing.assert_close(script_module_result, mobile_module_result)
 
         mobile_module_forward_result = mobile_module.forward(input)
-        torch.testing.assert_allclose(script_module_result, mobile_module_forward_result)
+        torch.testing.assert_close(script_module_result, mobile_module_forward_result)
 
         mobile_module_run_method_result = mobile_module.run_method("forward", input)
-        torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result)
+        torch.testing.assert_close(script_module_result, mobile_module_run_method_result)
 
     def test_find_and_run_method(self):
         class MyTestModule(torch.nn.Module):
@@ -154,7 +154,7 @@ class TestLiteScriptModule(TestCase):
 
         bundled_inputs = mobile_module.run_method("get_all_bundled_inputs")
         mobile_module_result = mobile_module.forward(*bundled_inputs[0])
-        torch.testing.assert_allclose(script_module_result, mobile_module_result)
+        torch.testing.assert_close(script_module_result, mobile_module_result)
 
     def test_method_calls_with_optional_arg(self):
         class A(torch.nn.Module):
@@ -183,7 +183,7 @@ class TestLiteScriptModule(TestCase):
         input = torch.tensor([5])
         script_module_forward_result = script_module.forward(input)
         mobile_module_forward_result = mobile_module.forward(input)
-        torch.testing.assert_allclose(
+        torch.testing.assert_close(
             script_module_forward_result,
             mobile_module_forward_result
         )
@@ -198,7 +198,7 @@ class TestLiteScriptModule(TestCase):
 
         # now both match again
         mobile_module_forward_result = mobile_module.forward(input, 2)
-        torch.testing.assert_allclose(
+        torch.testing.assert_close(
             script_module_forward_result,
             mobile_module_forward_result
         )
test/quantization/core/test_quantized_op.py
index d0a2dea..6c94586 100644
@@ -1617,8 +1617,8 @@ class TestQuantizedOps(TestCase):
         quantized_out = torch.topk(qX, k, dim=dim, largest=largest, sorted=sorted)
 
         assert(len(unquantized_out) == len(quantized_out))
-        torch.testing.assert_allclose(quantized_out[0].dequantize(), unquantized_out[0])
-        torch.testing.assert_allclose(quantized_out[1], unquantized_out[1])
+        torch.testing.assert_close(quantized_out[0].dequantize(), unquantized_out[0])
+        torch.testing.assert_close(quantized_out[1], unquantized_out[1])
 
     @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
                                               min_side=1, max_side=10),
@@ -1643,8 +1643,8 @@ class TestQuantizedOps(TestCase):
         quantized_out = torch.topk(qX, k, dim=dim, largest=largest, sorted=sorted)
 
         assert(len(unquantized_out) == len(quantized_out))
-        torch.testing.assert_allclose(quantized_out[0].dequantize(), unquantized_out[0])
-        torch.testing.assert_allclose(quantized_out[1], unquantized_out[1])
+        torch.testing.assert_close(quantized_out[0].dequantize(), unquantized_out[0])
+        torch.testing.assert_close(quantized_out[1], unquantized_out[1])
 
 
     """Tests quantize concatenation (both fused and not)."""
@@ -1846,7 +1846,7 @@ class TestQuantizedOps(TestCase):
         else:
             out = torch.ops.quantized.cat([qX, qY], dim=1, scale=scale, zero_point=zero_point)
 
-        torch.testing.assert_allclose(out.dequantize(), ref.dequantize())
+        torch.testing.assert_close(out.dequantize(), ref.dequantize())
         self.assertNotEqual(out.stride(), sorted(out.stride()))
 
     @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=1, max_dims=5,
@@ -3400,8 +3400,7 @@ class TestQuantizedEmbeddingOps(TestCase):
             num_embeddings, embedding_dim, include_last_offset, weights,
             per_sample_weights, indices, offsets)
 
-        torch.testing.assert_allclose(reference_result, result, atol=atol,
-                                      rtol=rtol)
+        torch.testing.assert_close(reference_result, result, atol=atol, rtol=rtol)
 
 
         if bit_rate == 8 or bit_rate == 4:
@@ -3424,7 +3423,7 @@ class TestQuantizedEmbeddingOps(TestCase):
                         per_sample_weights=per_sample_weights,
                         compressed_indices_mapping=torch.tensor(mapping_table),
                         include_last_offset=include_last_offset)
-            torch.testing.assert_allclose(reference_result, result, atol=atol, rtol=rtol)
+            torch.testing.assert_close(reference_result, result, atol=atol, rtol=rtol)
 
 
 
@@ -3510,7 +3509,7 @@ class TestQuantizedEmbeddingOps(TestCase):
         qresult = quant_op(packed_weight, indices, pruned_weights=False)
 
         ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False)
-        torch.testing.assert_allclose(ref, qresult, atol=0.005, rtol=1e-3)
+        torch.testing.assert_close(ref, qresult, atol=0.005, rtol=1e-3)
 
 
     def test_embedding_2d_indices(self):
@@ -3533,7 +3532,7 @@ class TestQuantizedEmbeddingOps(TestCase):
         qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
         packed_weight = prepack_op(qweight)
         qresult = quant_op(packed_weight, indices, pruned_weights=False)
-        torch.testing.assert_allclose(ref, qresult, atol=0.05, rtol=1e-3)
+        torch.testing.assert_close(ref, qresult, atol=0.05, rtol=1e-3)
 
     def test_embedding_bag_2d_indices(self):
         """
@@ -3555,7 +3554,7 @@ class TestQuantizedEmbeddingOps(TestCase):
         pt_prepack_op = torch.ops.quantized.embedding_bag_byte_prepack
         q_weights = pt_prepack_op(weights)
         qresult = pt_op(q_weights, indices, mode=0, pruned_weights=False)
-        torch.testing.assert_allclose(result, qresult, atol=0.05, rtol=1e-3)
+        torch.testing.assert_close(result, qresult, atol=0.05, rtol=1e-3)
 
         # Test TorchBind based embedding_bag operator
         obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
@@ -3569,7 +3568,7 @@ class TestQuantizedEmbeddingOps(TestCase):
         packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight)
         qresult = torch.ops.quantized.embedding_bag_byte(packed_weight, indices, mode=0)
 
-        torch.testing.assert_allclose(result, qresult, atol=0.05, rtol=1e-3)
+        torch.testing.assert_close(result, qresult, atol=0.05, rtol=1e-3)
 
 
 class TestQuantizedConv(TestCase):
test/quantization/jit/test_deprecated_jit_quant.py
index 662ead3..68ddb5c 100644
@@ -99,7 +99,7 @@ class TestDeprecatedJitQuantized(JitTestCase):
 
             self.assertEqual(len(outs), len(ref_outs))
             for out, ref_out in zip(outs, ref_outs):
-                torch.testing.assert_allclose(out, ref_out)
+                torch.testing.assert_close(out, ref_out)
 
     @skipIfNoFBGEMM
     def test_rnn_quantized(self):
@@ -165,32 +165,32 @@ class TestDeprecatedJitQuantized(JitTestCase):
             # Compare int8 quantized to unquantized
             output_int8, final_hiddens_int8 = cell_int8(x, hiddens)
 
-            torch.testing.assert_allclose(output_int8, ref_out)
+            torch.testing.assert_close(output_int8, ref_out)
             for out, ref in zip(final_hiddens_int8, ref_hid):
-                torch.testing.assert_allclose(out, ref)
+                torch.testing.assert_close(out, ref)
 
             # Compare fp16 quantized to unquantized
             output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens)
 
-            torch.testing.assert_allclose(output_fp16, ref_out)
+            torch.testing.assert_close(output_fp16, ref_out)
             for out, ref in zip(final_hiddens_fp16, ref_hid):
-                torch.testing.assert_allclose(out, ref)
+                torch.testing.assert_close(out, ref)
 
             def compare_quantized_unquantized(ScriptWrapper, cell):
                 wrapper = ScriptWrapper(cell)
 
                 # Compare quantize scripted module to unquantized
                 script_out, script_hid = wrapper(x, hiddens)
-                torch.testing.assert_allclose(script_out, ref_out)
+                torch.testing.assert_close(script_out, ref_out)
                 for out, ref in zip(script_hid, ref_hid):
-                    torch.testing.assert_allclose(out, ref)
+                    torch.testing.assert_close(out, ref)
 
                 # Compare export/import to unquantized
                 export_import_wrapper = self.getExportImportCopyWithPacking(wrapper)
                 ei_out, ei_hid = export_import_wrapper(x, hiddens)
-                torch.testing.assert_allclose(ei_out, ref_out)
+                torch.testing.assert_close(ei_out, ref_out)
                 for out, ref in zip(ei_hid, ref_hid):
-                    torch.testing.assert_allclose(out, ref)
+                    torch.testing.assert_close(out, ref)
 
             if isinstance(cell, torch.jit.quantized.QuantizedGRU):
                 class ScriptWrapper(torch.jit.ScriptModule):
@@ -252,8 +252,8 @@ class TestDeprecatedJitQuantized(JitTestCase):
             fb_fp16 = self.getExportImportCopyWithPacking(traced_fp16)
             y_fp16 = fb_fp16(value)
 
-            torch.testing.assert_allclose(y_int8, y_ref, rtol=0.0001, atol=1e-3)
-            torch.testing.assert_allclose(y_fp16, y_ref, rtol=0.0001, atol=1e-3)
+            torch.testing.assert_close(y_int8, y_ref, rtol=0.0001, atol=1e-3)
+            torch.testing.assert_close(y_fp16, y_ref, rtol=0.0001, atol=1e-3)
 
     @skipIfNoFBGEMM
     def test_erase_class_tensor_shapes(self):
test/test_fx.py
index e39469d..c55e97d 100644
@@ -593,17 +593,17 @@ class TestFX(JitTestCase):
         x = torch.rand(3, 4)
         ref_out = msm(x)
         test_out = lowered(x)
-        torch.testing.assert_allclose(test_out, ref_out)
+        torch.testing.assert_close(test_out, ref_out)
 
         # Test TorchScript compilation
         scripted_lowered = torch.jit.script(lowered)
         script_out = scripted_lowered(x)
-        torch.testing.assert_allclose(script_out, ref_out)
+        torch.testing.assert_close(script_out, ref_out)
 
         # Test TorchScript ser/de
         import_copy = self.getExportImportCopy(scripted_lowered)
         imported_out = import_copy(x)
-        torch.testing.assert_allclose(imported_out, ref_out)
+        torch.testing.assert_close(imported_out, ref_out)
 
     def test_reserved_getattr(self):
         """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
test/test_fx_experimental.py
index 00f3201..f000b0a 100644
@@ -876,7 +876,7 @@ terrible spacing
             traced = symbolic_trace(WrapperMod())
             normalized = NormalizeOperators(traced).transform()
             x, y = torch.randn(3, 4), torch.randn(3, 4)
-            torch.testing.assert_allclose(traced(x, y), normalized(x, y))
+            torch.testing.assert_close(traced(x, y), normalized(x, y))
             self.assertFalse(
                 any(n.target in ops_to_test for n in normalized.graph.nodes)
             )
@@ -891,7 +891,7 @@ terrible spacing
             traced = symbolic_trace(WrapperMod())
             normalized = NormalizeOperators(traced).transform()
             x = torch.randn(3, 4)
-            torch.testing.assert_allclose(traced(x), normalized(x))
+            torch.testing.assert_close(traced(x), normalized(x))
             self.assertFalse(
                 any(n.target in ops_to_test for n in normalized.graph.nodes)
             )
@@ -1413,12 +1413,12 @@ class {test_classname}(torch.nn.Module):
         with torch.no_grad():
             model = Foo().eval()
             optimized_model = optimization.optimize_for_inference(model)
-            torch.testing.assert_allclose(model(inp), optimized_model(inp))
+            torch.testing.assert_close(model(inp), optimized_model(inp))
 
             optimized_model2 = optimization.optimize_for_inference(
                 model, pass_config={"remove_dropout": False}
             )
-            torch.testing.assert_allclose(model(inp), optimized_model2(inp))
+            torch.testing.assert_close(model(inp), optimized_model2(inp))
 
     @skipIfNoTorchVision
     @skipIfNoMkldnn
@@ -1450,7 +1450,7 @@ class {test_classname}(torch.nn.Module):
 
                 orig_out = model(inp)
                 new_out = optimized_model(inp)
-                torch.testing.assert_allclose(orig_out, new_out)
+                torch.testing.assert_close(orig_out, new_out)
 
 
 class TestNormalizeOperators(JitTestCase):
test/test_jit.py
index 99df960..2dd0d47 100644
@@ -497,7 +497,7 @@ class TestJit(JitTestCase):
         FileCheck().check_not("aten::relu(") \
             .check("aten::_add_relu(") \
             .run(m.graph)
-        torch.testing.assert_allclose(orig_res, new_res)
+        torch.testing.assert_close(orig_res, new_res)
 
         # add, relu_
         a = torch.rand((7, 11))
@@ -516,7 +516,7 @@ class TestJit(JitTestCase):
         FileCheck().check_not("aten::relu_(") \
             .check("aten::_add_relu(") \
             .run(m.graph)
-        torch.testing.assert_allclose(orig_res, new_res)
+        torch.testing.assert_close(orig_res, new_res)
 
         class Madd_(torch.nn.Module):
             def __init__(self, relu_op):
@@ -547,10 +547,10 @@ class TestJit(JitTestCase):
             .check_not("aten::relu_(") \
             .check("aten::_add_relu_(") \
             .run(m.graph)
-        torch.testing.assert_allclose(orig_res, new_res)
+        torch.testing.assert_close(orig_res, new_res)
         # Since _add_relu_ does inplace mutation ensure
         # a_copy is modified
-        torch.testing.assert_allclose(orig_res, a_copy)
+        torch.testing.assert_close(orig_res, a_copy)
 
         class Madd_out(torch.nn.Module):
             def __init__(self, relu_op):
@@ -585,10 +585,10 @@ class TestJit(JitTestCase):
             .check_not("aten::relu_(") \
             .check("aten::_add_relu(") \
             .run(m.graph)
-        torch.testing.assert_allclose(orig_res, new_res)
+        torch.testing.assert_close(orig_res, new_res)
         # Since _add_relu_ with out=a does inplace mutation ensure
         # a_copy is modified
-        torch.testing.assert_allclose(orig_res, a_copy)
+        torch.testing.assert_close(orig_res, a_copy)
 
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple executor doesn't have shape information")
     def test_peephole_optimize_shape_ops(self):
@@ -8888,7 +8888,7 @@ dedent """
     def test_pack_unpack_state(self):
         sm = TestScript.DerivedStateModule()
         x = torch.rand(3, 4, dtype=torch.float)
-        torch.testing.assert_allclose(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
+        torch.testing.assert_close(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
 
         # Test save path
         self.assertFalse(sm.pack_called.item())
@@ -8899,11 +8899,11 @@ dedent """
         # ensure unpack was called after serialization so as to leave the module in an initialized state
         self.assertTrue(sm.unpack_called.item())
 
-        torch.testing.assert_allclose(sm.derived, torch.neg(sm.param))
+        torch.testing.assert_close(sm.derived, torch.neg(sm.param))
 
         # Test load paths
         self.assertTrue(imported.unpack_called.item())
-        torch.testing.assert_allclose(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
+        torch.testing.assert_close(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
 
     @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
     def test_torch_functional(self):
@@ -9101,11 +9101,11 @@ dedent """
                 return self.submod(x + self.buf)
 
         m = Mod()
-        torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
+        torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
         m.apply(lambda s: s._pack())
-        torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.zeros(3, 4))
+        torch.testing.assert_close(m(torch.zeros(3, 4)), torch.zeros(3, 4))
         m.apply(lambda s: s._unpack())
-        torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
+        torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
 
     def test_torch_any(self):
         def fn(x):
@@ -10958,7 +10958,7 @@ dedent """
         torch._C._jit_pass_remove_dropout(m._c)
         res = m(data)
         FileCheck().check_not("aten::dropout").run(str(m.graph))
-        torch.testing.assert_allclose(ref_res, res, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_res, res, rtol=1e-2, atol=1e-3)
 
     def test_unfold_zero_dim(self):
         def fn(x):
test/test_jit_fuser_te.py
index ba47547..64c26b7 100644
@@ -1186,7 +1186,7 @@ class TestTEFuser(JitTestCase):
             ref = fn(input_v, mask)
             try:
                 t = torch.jit.trace(fn, (input_v, mask))
-                torch.testing.assert_allclose(ref, t(input_v, mask))
+                torch.testing.assert_close(ref, t(input_v, mask))
                 print(torch.jit.last_executed_optimized_graph())
                 self.assertLastGraphAllFused()
             except Exception as e:
@@ -1287,7 +1287,7 @@ class TestTEFuser(JitTestCase):
                 continue
             try:
                 t = torch.jit.trace(fn, (x,))
-                torch.testing.assert_allclose(ref, t(x))
+                torch.testing.assert_close(ref, t(x))
                 self.assertAllFused(t.graph_for(x))
             except Exception as e:
                 raise RuntimeError(
@@ -1683,7 +1683,7 @@ class TestTEFuser(JitTestCase):
             for _ in range(4):
                 for pair in zip(script(*inputs), eager(*inputs)):
                     test, ref = pair
-                    torch.testing.assert_allclose(test, ref)
+                    torch.testing.assert_close(test, ref)
                     self.assertAllFused(script.graph_for(*inputs))
 
     def test_sub_gt_and(self):
@@ -1776,10 +1776,10 @@ class TestTEFuser(JitTestCase):
                 one = torch.tensor([[1]]).to(dtype2)
                 script = torch.jit.trace(eager, (x, zero))
                 for _ in range(3):
-                    torch.testing.assert_allclose(
+                    torch.testing.assert_close(
                         script(x, zero),
                         eager(x, zero))
-                    torch.testing.assert_allclose(
+                    torch.testing.assert_close(
                         script(x, one),
                         eager(x, one))
                 self.assertAllFused(script.graph_for(x, one))
@@ -1824,7 +1824,7 @@ class TestTEFuser(JitTestCase):
                 xs -= 0.1 * xs.grad
                 x.grad = None
                 xs.grad = None
-        torch.testing.assert_allclose(y, ys)
+        torch.testing.assert_close(y, ys)
 
     def test_relu_fwd_bwd(self):
         def eager(x):
@@ -1907,12 +1907,12 @@ class TestTEFuser(JitTestCase):
             for _ in range(3):
                 script(x)
 
-            torch.testing.assert_allclose(eager(x), script(x))
+            torch.testing.assert_close(eager(x), script(x))
 
             # Now when an input hits the unrolled path, it will produce an
             # incorrectly-sized tensor, since size=1 has been burned in.
             x = torch.ones((8, 1))
-            torch.testing.assert_allclose(eager(x), script(x))
+            torch.testing.assert_close(eager(x), script(x))
 
 works_list = [
     '__radd__',
test/test_mobile_optimizer.py
index 78ebb55..19f07e2 100644
@@ -119,7 +119,7 @@ class TestOptimizer(TestCase):
                    .check_not("aten::relu(") \
                    .check_count("aten::_add_relu(", 1, exactly=True) \
                    .run(optimized_scripted_model.graph)
-        torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
 
         FileCheck().check_not("Tensor = aten::conv2d") \
                    .check_not("Tensor = prim::CallFunction") \
@@ -131,7 +131,7 @@ class TestOptimizer(TestCase):
                    .check_not("aten::relu(") \
                    .check_count("aten::_add_relu(", 1, exactly=True) \
                    .run(optimized_scripted_model.foo.graph)
-        torch.testing.assert_allclose(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3)
 
 
         optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
@@ -142,7 +142,7 @@ class TestOptimizer(TestCase):
                    .check_not("prepacked::linear_clamp_run") \
                    .check_not("prepacked::conv2d_clamp_run") \
                    .run(optimized_scripted_model_no_prepack.graph)
-        torch.testing.assert_allclose(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3)
 
 
         bn_test_module = BNTestModule()
@@ -157,14 +157,14 @@ class TestOptimizer(TestCase):
         bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack)
         self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
         bn_input = torch.rand(1, 1, 6, 6)
-        torch.testing.assert_allclose(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)
 
         optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
         no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn)
         FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
                    .run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
         bn_input = torch.rand(1, 1, 6, 6)
-        torch.testing.assert_allclose(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)
 
         class MyMobileOptimizedTagTest(torch.nn.Module):
             def __init__(self):
@@ -231,7 +231,7 @@ class TestOptimizer(TestCase):
         FileCheck().check_not("dropout.__") \
             .check_count("aten::_add_relu(", 1, exactly=True) \
             .run(optimized_scripted_model.foo.graph)
-        torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
 
         class BNTestNoForwardModule(torch.nn.Module):
             def __init__(self):
@@ -257,7 +257,7 @@ class TestOptimizer(TestCase):
         bn_fold_no_forward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo'])
         self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_forward_scripted_module)), 1)
         bn_input = torch.rand(1, 1, 6, 6)
-        torch.testing.assert_allclose(
+        torch.testing.assert_close(
             bn_no_forward_scripted_module.foo(bn_input),
             bn_fold_no_forward_scripted_module.foo(bn_input),
             rtol=1e-2,
@@ -493,7 +493,7 @@ class TestOptimizer(TestCase):
             data = torch.randn(4, 1, 4, 4)
             m_res = m(data)
             m_optim_res = m_optim(data)
-            torch.testing.assert_allclose(m_res, m_optim_res, rtol=1e-2, atol=1e-3)
+            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)
 
             # generic case
 
@@ -507,7 +507,7 @@ class TestOptimizer(TestCase):
             data = torch.randn(4, 1, 4, 4)
             m_res = m(data)
             m_optim_res = m_optim(data)
-            torch.testing.assert_allclose(m_res, m_optim_res, rtol=1e-2, atol=1e-3)
+            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)
 
     @unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision")
     def test_mobilenet_optimize_for_mobile(self):
test/test_nn.py
index ccf6f6e..d21e047 100644
@@ -4717,7 +4717,7 @@ class TestNN(NNTestCase):
         packed_w_tensor = torch.fbgemm_pack_gemm_matrix_fp16(w_tensor)
         actual_output = torch.fbgemm_linear_fp16_weight(x_tensor, packed_w_tensor, b_tensor)
         expected_output = fc_op(X, W, b)
-        torch.testing.assert_allclose(expected_output, actual_output.cpu(), atol=1e-3, rtol=1e-3)
+        torch.testing.assert_close(torch.from_numpy(expected_output), actual_output.cpu(), atol=1e-3, rtol=1e-3)
 
     def test_embeddingbag_from_pretrained(self):
         a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
@@ -6797,8 +6797,7 @@ class TestNN(NNTestCase):
             encoder_input = torch.tensor([[[20., 30., 40., 50.]]])
             result = model(encoder_input)
             ref_output = torch.tensor([[[2.249815, 0.131006, -0.702199, 0.177868]]])
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output)
+            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
 
             # deterministic input
             encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
@@ -6806,8 +6805,7 @@ class TestNN(NNTestCase):
             result = model(encoder_input)
             ref_output = perm_fn(torch.tensor([[[2.264103, 0.121417, -0.696012, 0.159724]],
                                                [[2.264103, 0.121417, -0.696012, 0.159724]]]))
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output)
+            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
 
             # deterministic input
             encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
@@ -6831,8 +6829,7 @@ class TestNN(NNTestCase):
                                                 [2.4237977, 0.03290575, -0.60561789, -0.05940082]],
                                                [[2.41383916, 0.02686345, -0.61256377, -0.06380707],
                                                 [2.42000277, 0.03800944, -0.60824798, -0.04754947]]]))
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output)
+            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
 
     def test_transformerdecoderlayer(self):
         # this is a deterministic test for TransformerDecoderLayer
@@ -7013,8 +7010,7 @@ class TestNN(NNTestCase):
             memory_input = torch.tensor([[[60., 70., 80., 90.]]])
             result = model(decoder_input, memory_input)
             ref_output = torch.tensor([[[2.306435, 0.095946, -0.675796, 0.10687]]])
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output)
+            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
 
             # deterministic input
             decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
@@ -7023,8 +7019,7 @@ class TestNN(NNTestCase):
             result = model(decoder_input, memory_input)
             ref_output = perm_fn(torch.tensor([[[2.415448, 0.054389, -0.610932, -0.0156613]],
                                                [[2.415448, 0.054389, -0.610932, -0.0156613]]]))
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output)
+            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
 
             # deterministic input
             decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
@@ -7034,8 +7029,7 @@ class TestNN(NNTestCase):
             result = model(decoder_input, memory_input)
             ref_output = perm_fn(torch.tensor([[[2.338531, 0.087709, -0.65776, 0.080646]],
                                                [[2.338531, 0.087709, -0.65776, 0.080646]]]))
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output)
+            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
 
             # deterministic input
             decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
@@ -7061,8 +7055,7 @@ class TestNN(NNTestCase):
                                                 [2.42216881, 0.03586554, -0.6067524, -0.05289126]],
                                                [[2.42205716, 0.03488046, -0.60683681, -0.05460596],
                                                 [2.42240309, 0.0354595, -0.60659063, -0.05378816]]]))
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output)
+            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
 
     def test_transformerencoder(self):
         def get_a_test_layer(use_cuda, activation, batch_first=False):
@@ -7130,13 +7123,13 @@ class TestNN(NNTestCase):
                                                 [2.422901, 0.024187, -0.606178, -0.074929]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
             # all 0
             mask = torch.zeros([2, 5]).to(device) == 1
             result = model(encoder_input, src_key_padding_mask=mask)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
             mask[0, 1] = 1
             mask[1, 3] = 1
             mask[1, 4] = 1
@@ -7153,7 +7146,7 @@ class TestNN(NNTestCase):
                                                 [2.4242, 0.024653, -0.605266, -0.074959]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
             # test case 2, multiple layers no norm
             model = nn.TransformerEncoder(encoder_layer, 2).to(device)
@@ -7170,7 +7163,7 @@ class TestNN(NNTestCase):
                                                 [2.419075, 0.017449, -0.608722, -0.085014]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
             model = nn.TransformerEncoder(encoder_layer, 6).to(device)
             result = model(encoder_input, src_key_padding_mask=mask)
@@ -7186,7 +7179,7 @@ class TestNN(NNTestCase):
                                                 [2.419101, 0.017453, -0.608704, -0.085025]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
             # test case 3, multiple layers with norm
             # d_model = 4
@@ -7205,7 +7198,7 @@ class TestNN(NNTestCase):
                                                 [1.695952, -0.357637, -0.893065, -0.445251]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
             model = nn.TransformerEncoder(encoder_layer, 6, norm=norm).to(device)
             result = model(encoder_input, src_key_padding_mask=mask)
@@ -7221,7 +7214,7 @@ class TestNN(NNTestCase):
                                                 [1.695955, -0.357639, -0.893051, -0.445265]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
 
     def test_transformerdecoder(self):
@@ -7271,7 +7264,7 @@ class TestNN(NNTestCase):
             ref_output = torch.tensor(
                 [[[2.314351, 0.094805, -0.671322, 0.101977]]]).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-3)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)
 
             # deterministic input
             decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
@@ -7282,7 +7275,7 @@ class TestNN(NNTestCase):
                                                [[2.422245, 0.051716, -0.606338, -0.024756]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-4)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)
 
             # deterministic input
             decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
@@ -7294,7 +7287,7 @@ class TestNN(NNTestCase):
                                                [[2.343536, 0.085561, -0.654954, 0.074991]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-4)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)
 
             # deterministic input
             decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
@@ -7324,7 +7317,7 @@ class TestNN(NNTestCase):
                                                 [2.432306, 0.028858, -0.599542, -0.072846]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
             # key_padding_mask
             key_padding_mask = torch.zeros(2, 3).to(device) == 1
@@ -7338,7 +7331,7 @@ class TestNN(NNTestCase):
                                                 [2.432306, 0.028858, -0.599542, -0.072846]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
             # key_padding_mask
             key_padding_mask[0, 2] = 1
@@ -7354,7 +7347,7 @@ class TestNN(NNTestCase):
                                                 [2.432659, 0.029244, -0.599294, -0.072382]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
             # memory_key_padding_mask
             key_padding_mask = torch.zeros(2, 5).to(device) == 1
@@ -7368,7 +7361,7 @@ class TestNN(NNTestCase):
                                                 [2.432306, 0.028858, -0.599542, -0.072846]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
             # memory_key_padding_mask
             key_padding_mask[0, 4] = 1
@@ -7385,7 +7378,7 @@ class TestNN(NNTestCase):
                                                 [2.433075, 0.028543, -0.598987, -0.073985]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
             # multiple layers no norm
             model = nn.TransformerDecoder(decoder_layer, 2).to(device)
@@ -7397,7 +7390,7 @@ class TestNN(NNTestCase):
             ref_output = torch.tensor(
                 [[[2.31316, 0.0950293, -0.671995, 0.102802]]]).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-3)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)
 
             # multiple layers no norm
             model = nn.TransformerDecoder(decoder_layer, 6).to(device)
@@ -7430,7 +7423,7 @@ class TestNN(NNTestCase):
                                                 [2.43113, 0.0279516, -0.600376, -0.0736896]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
             # multiple layers with norm
             # d_model = 4
@@ -7444,7 +7437,7 @@ class TestNN(NNTestCase):
             ref_output = torch.tensor(
                 [[[1.66166, -0.326986, -1.01466, -0.320017]]]).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-3)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)
 
             # multiple layers with norm
             model = nn.TransformerDecoder(decoder_layer, 6, norm=norm).to(device)
@@ -7477,7 +7470,7 @@ class TestNN(NNTestCase):
                                                 [1.69571, -0.357363, -0.894154, -0.444196]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
             # gelu activation test cases
             activation = "gelu"
@@ -7495,7 +7488,7 @@ class TestNN(NNTestCase):
             result = model(decoder_input, memory_input)
             ref_output = torch.tensor([[[2.306435, 0.095946, -0.675796, 0.10687]]]).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-3)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)
 
             # deterministic input
             decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
@@ -7505,7 +7498,7 @@ class TestNN(NNTestCase):
             ref_output = perm_fn(torch.tensor([[[2.415448, 0.054389, -0.610932, -0.0156613]],
                                                [[2.415448, 0.054389, -0.610932, -0.0156613]]])).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-4)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)
 
             # deterministic input
             decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
@@ -7516,7 +7509,7 @@ class TestNN(NNTestCase):
             ref_output = perm_fn(torch.tensor([[[2.338531, 0.087709, -0.65776, 0.080646]],
                                                [[2.338531, 0.087709, -0.65776, 0.080646]]])).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-4)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)
 
             # deterministic input
             decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
@@ -7546,7 +7539,7 @@ class TestNN(NNTestCase):
                                                 [2.42240309, 0.0354595, -0.60659063, -0.05378816]]]
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5)
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
 
     @unittest.skipIf(not (TEST_CUDNN and TEST_MULTIGPU), 'CUDNN or multi-gpu not available')
     def test_cudnn_rnn_dropout_states_device(self):
test/test_pruning_op.py
index 28f31ae..97a499b 100644
@@ -50,7 +50,7 @@ class PruningOpTest(TestCase):
         ref_pruned_weights, ref_compressed_indices_map = get_reference_result(
             embedding_weights, mask, indices_type)
 
-        torch.testing.assert_allclose(pt_pruned_weights, ref_pruned_weights)
+        torch.testing.assert_close(pt_pruned_weights, ref_pruned_weights)
         self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map)
         self.assertEqual(pt_compressed_indices_map.dtype, indices_type)
 
test/test_reductions.py
index 42edfb3..f3f0d4c 100644
@@ -2664,36 +2664,38 @@ class TestReductions(TestCase):
             self.assertEqual(np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg,
                              exact_dtype=False)
 
-            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True), msg=error_msg)
+            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True),
+                             msg=error_msg)
             self.assertEqual(np_function(np_input, axis=2, keepdims=True), fn(master_input, dim=2, keepdim=True),
                              msg=error_msg, exact_dtype=False)
 
-            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True), msg=error_msg)
+            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True),
+                             msg=error_msg)
             self.assertEqual(np_function(np_input, axis=-1, keepdims=True), fn(master_input, dim=-1, keepdim=True),
                              msg=error_msg, exact_dtype=False)
 
-            # Check if returned data is correct.
-            check_func = (torch.testing.assert_allclose if math.isnan(return_value) or math.isinf(return_value) else
-                          self.assertEqual)
-
-            check_func(torch.full((2, 4), return_value, device=device), fn(master_input, dim=1), msg=error_msg)
-            check_func(torch.full((2, 4), return_value, device=device), fn(master_input, dim=-2), msg=error_msg)
-            check_func(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=1, keepdim=True), msg=error_msg)
-            check_func(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=-2, keepdim=True), msg=error_msg)
+            self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=1), msg=error_msg)
+            self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=-2), msg=error_msg)
+            self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=1, keepdim=True),
+                             msg=error_msg)
+            self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=-2, keepdim=True),
+                             msg=error_msg)
 
             if name != 'logsumexp':
                 # The scipy function does not work for reduction the zero dimension
-                check_func(np.float32(np_function(np_input, axis=1)), fn(master_input, dim=1).cpu().numpy(), msg=error_msg)
-                check_func(np.float32(np_function(np_input, axis=-2)), fn(master_input, dim=-2).cpu().numpy(), msg=error_msg)
-                check_func(np.float32(np_function(np_input, axis=1, keepdims=True)),
-                           fn(master_input, dim=1, keepdim=True).cpu().numpy(),
-                           msg=error_msg)
-                check_func(np.float32(np_function(np_input, axis=-2, keepdims=True)),
-                           fn(master_input, dim=-2, keepdim=True).cpu().numpy(),
-                           msg=error_msg)
+                self.assertEqual(np.float32(np_function(np_input, axis=1)), fn(master_input, dim=1).cpu().numpy(),
+                                 msg=error_msg)
+                self.assertEqual(np.float32(np_function(np_input, axis=-2)), fn(master_input, dim=-2).cpu().numpy(),
+                                 msg=error_msg)
+                self.assertEqual(np.float32(np_function(np_input, axis=1, keepdims=True)),
+                                 fn(master_input, dim=1, keepdim=True).cpu().numpy(),
+                                 msg=error_msg)
+                self.assertEqual(np.float32(np_function(np_input, axis=-2, keepdims=True)),
+                                 fn(master_input, dim=-2, keepdim=True).cpu().numpy(),
+                                 msg=error_msg)
 
                 # logsumexp throws a type error when not specifying dim so test separately.
-                check_func(torch.full((), return_value, device=device), fn(master_input), msg=error_msg)
+                self.assertEqual(torch.full((), return_value, device=device), fn(master_input), msg=error_msg)
             else:
                 self.assertRaises(TypeError, lambda: fn(master_input))
 
test/test_static_runtime.py
index 9b38a5a..94043e2 100644
@@ -186,10 +186,10 @@ class TestStaticModule(TestCase):
         o_test_kw = attention_a(src, src, value=src, mask=src_mask)
 
         for a, b in zip(o_ref, o_test):
-            torch.testing.assert_allclose(a, b)
+            torch.testing.assert_close(a, b)
 
         for a, b in zip(o_ref, o_test_kw):
-            torch.testing.assert_allclose(a, b)
+            torch.testing.assert_close(a, b)
 
     def test_multihead_attention_layer_benchmark(self):
         HID_DIM = 256
@@ -228,20 +228,20 @@ class TestStaticModule(TestCase):
             top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
         ref_bot = bot_l(bot_inp)
         acc_bot = bot_l_acc(bot_inp)[0]
-        torch.testing.assert_allclose(acc_bot, ref_bot)
+        torch.testing.assert_close(acc_bot, ref_bot)
         ref_top = top_l(top_inp)
         acc_top = top_l_acc(top_inp)[0]
-        torch.testing.assert_allclose(acc_top, ref_top)
+        torch.testing.assert_close(acc_top, ref_top)
         for _ in range(5):
             with torch.no_grad():
                 bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
                 top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
             ref_bot = bot_l(bot_inp)
             acc_bot = bot_l_acc(bot_inp)[0]
-            torch.testing.assert_allclose(acc_bot, ref_bot)
+            torch.testing.assert_close(acc_bot, ref_bot)
             ref_top = top_l(top_inp)
             acc_top = top_l_acc(top_inp)[0]
-            torch.testing.assert_allclose(acc_top, ref_top)
+            torch.testing.assert_close(acc_top, ref_top)
 
     def test_trivial_graph(self):
         s = torch.full((2, 2), 2)
@@ -249,7 +249,7 @@ class TestStaticModule(TestCase):
         o_ref = tg(s, s, s)
         tg_a = StaticModule(tg)
         o_test = tg_a(s, s, s)[0]
-        torch.testing.assert_allclose(o_ref, o_test)
+        torch.testing.assert_close(o_ref, o_test)
 
     def test_leaky_relu(self):
         s = torch.randn(5, 5)
@@ -257,7 +257,7 @@ class TestStaticModule(TestCase):
         o_ref = tg(s)
         tg_a = StaticModule(tg)
         o_test = tg_a(s)[0]
-        torch.testing.assert_allclose(o_ref, o_test)
+        torch.testing.assert_close(o_ref, o_test)
 
     def test_attr(self):
         """
@@ -293,7 +293,7 @@ class TestStaticModule(TestCase):
         ms = torch.jit.script(m)
         sm = StaticModule(ms)
         output_sm = sm(input)[0]
-        torch.testing.assert_allclose(output_s, output_sm)
+        torch.testing.assert_close(output_s, output_sm)
         sm.benchmark([input], {}, 2, 2)
         sm.benchmark_individual_ops([input], {}, 2, 2)
         sm.benchmark([], {"x": input}, 2, 2)
@@ -307,7 +307,7 @@ class TestStaticModule(TestCase):
         torch._C._fuse_to_static_module(tg.graph)
         assert "StaticSubgraph" in str(tg.graph)
         o_test = tg(s, s, s)
-        torch.testing.assert_allclose(o_ref, o_test)
+        torch.testing.assert_close(o_ref, o_test)
 
     @unittest.skip("Temporarily disabled")
     def test_fusion_multihead_attention_layer(self):
@@ -332,7 +332,7 @@ class TestStaticModule(TestCase):
         o_test = attention(src, src, src, src_mask)
 
         for a, b in zip(o_ref, o_test):
-            torch.testing.assert_allclose(a, b)
+            torch.testing.assert_close(a, b)
 
     @unittest.skip("Temporarily disabled")
     def test_fusion_loop(self):
@@ -344,7 +344,7 @@ class TestStaticModule(TestCase):
         torch._C._fuse_to_static_module(lg.graph)
         assert "StaticSubgraph" in str(lg.graph)
         o_test = lg(a, b, c)
-        torch.testing.assert_allclose(o_ref, o_test)
+        torch.testing.assert_close(o_ref, o_test)
 
     @unittest.skip("Temporarily disabled")
     def test_fusion_outputs(self):
@@ -357,7 +357,7 @@ class TestStaticModule(TestCase):
         assert "StaticSubgraph" in str(og.graph)
         o_test = og(a, b, b, c)
         for i in o_ref.keys():
-            torch.testing.assert_allclose(o_ref[i], o_test[i])
+            torch.testing.assert_close(o_ref[i], o_test[i])
 
 
 if __name__ == "__main__":
test/test_tensorexpr.py
index 5014510..6353113 100644
@@ -1468,7 +1468,7 @@ class TestTensorExprFuser(BaseTestClass):
         am_s = getModule(True)
         ref = am(x, x, x)
         test = am_s(x, x, x)
-        torch.testing.assert_allclose(ref, test)
+        torch.testing.assert_close(ref, test)
 
         # Now do the aliasing
         am.a = am.b
@@ -1477,7 +1477,7 @@ class TestTensorExprFuser(BaseTestClass):
         am_s.a = am_s.b
         test = am_s(x, x, x)
 
-        torch.testing.assert_allclose(ref, test)
+        torch.testing.assert_close(ref, test)
 
     def test_alias_analysis_inputs(self):
         class AliasModule(nn.Module):
@@ -1510,7 +1510,7 @@ class TestTensorExprFuser(BaseTestClass):
         x = torch.randn(128, 128)
         test = am_s(x, x, x)
 
-        torch.testing.assert_allclose(ref, test)
+        torch.testing.assert_close(ref, test)
 
     def test_alias_analysis_input_and_module(self):
         class AliasModule(nn.Module):
@@ -1545,7 +1545,7 @@ class TestTensorExprFuser(BaseTestClass):
         am_s.b = x
         test = am_s(x, x, x)
 
-        torch.testing.assert_allclose(ref, test)
+        torch.testing.assert_close(ref, test)
 
     def test_multiple_outputs(self):
         for device in self.devices:
index 4138b2f..d838892 100644 (file)
@@ -44,7 +44,7 @@ class TestTensorExprPyBind(JitTestCase):
             tB = torch.randn(n)
             tC = torch.empty(n)
             cg.call([tA, tB, tC])
-            torch.testing.assert_allclose(tA + tB, tC)
+            torch.testing.assert_close(tA + tB, tC)
 
     def test_call_raw(self):
         with kernel_arena_scope():
@@ -55,7 +55,7 @@ class TestTensorExprPyBind(JitTestCase):
             tB = torch.randn(n, dtype=torch.float64)
             tC = torch.empty(n, dtype=torch.float64)
             cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()])
-            torch.testing.assert_allclose(tA + tB, tC)
+            torch.testing.assert_close(tA + tB, tC)
 
     def test_external_calls(self):
         with kernel_arena_scope():
@@ -77,7 +77,7 @@ class TestTensorExprPyBind(JitTestCase):
             tB = torch.ones(4, 1)
             tC = torch.empty(1, 1)
             codegen.call([tA, tB, tC])
-            torch.testing.assert_allclose(torch.matmul(tA, tB), tC)
+            torch.testing.assert_close(torch.matmul(tA, tB), tC)
 
     def test_dynamic_shape(self):
         with kernel_arena_scope():
@@ -103,7 +103,7 @@ class TestTensorExprPyBind(JitTestCase):
                 tB = torch.randn(n, dtype=torch.double)
                 tC = torch.empty(n, dtype=torch.double)
                 cg.call([tA, tB, tC, n])
-                torch.testing.assert_allclose(tA - tB, tC)
+                torch.testing.assert_close(tA - tB, tC)
 
             test_with_shape(8)
             test_with_shape(31)
index 9d60344..139ca0c 100644 (file)
@@ -1,7 +1,6 @@
 
 import torch
 from torch.utils import ThroughputBenchmark
-from torch.testing import assert_allclose
 
 from torch.testing._internal.common_utils import run_tests, TestCase, TemporaryFileName
 
@@ -56,7 +55,7 @@ class TestThroughputBenchmark(TestCase):
             # or just unpack the list of inputs
             module_result = module(*inputs[i])
             bench_result = bench.run_once(*inputs[i])
-            assert_allclose(bench_result, module_result)
+            torch.testing.assert_close(bench_result, module_result)
 
         stats = bench.benchmark(
             num_calling_threads=4,
index 6766d50..515052a 100644 (file)
@@ -1566,7 +1566,7 @@ class AbstractTestCases:
             n_half = len(ref_sample) // 2
             _ = engine.draw(n=n_half)
             sample = engine.draw(n=n_half)
-            torch.testing.assert_allclose(sample, ref_sample[n_half:])
+            torch.testing.assert_close(sample, ref_sample[n_half:])
 
         def test_sobolengine_continuing_scrambled(self):
             self.test_sobolengine_continuing(scramble=True)
@@ -1578,7 +1578,7 @@ class AbstractTestCases:
             engine.reset()
             self.assertEqual(engine.num_generated, 0)
             sample = engine.draw(n=len(ref_sample))
-            torch.testing.assert_allclose(sample, ref_sample)
+            torch.testing.assert_close(sample, ref_sample)
 
         def test_sobolengine_reset_scrambled(self):
             self.test_sobolengine_reset(scramble=True)
@@ -1588,7 +1588,7 @@ class AbstractTestCases:
             engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
             engine.fast_forward(4)
             sample = engine.draw(n=4)
-            torch.testing.assert_allclose(sample, ref_sample[4:])
+            torch.testing.assert_close(sample, ref_sample[4:])
             # alternate fast forwarding with sampling
             engine.reset()
             even_draws = []
@@ -1597,9 +1597,10 @@
                     even_draws.append(engine.draw())
                 else:
                     engine.fast_forward(1)
-            torch.testing.assert_allclose(
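+            # Convert the numpy result to a tensor so assert_close compares two tensors.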
+            torch.testing.assert_close(
                 ref_sample[[i for i in range(8) if i % 2 == 0]],
-                np.concatenate(even_draws),
+                torch.from_numpy(np.concatenate(even_draws)),
             )
 
         def test_sobolengine_fast_forward_scrambled(self):
@@ -1609,13 +1609,13 @@ class AbstractTestCases:
             d = 50
             engine = torch.quasirandom.SobolEngine(d, scramble=scramble, seed=123456)
             sample = engine.draw(1024)
-            torch.testing.assert_allclose(
+            torch.testing.assert_close(
                 torch.mean(sample, dim=0), torch.full((d,), 0.5), atol=2, rtol=2
             )
-            torch.testing.assert_allclose(
+            torch.testing.assert_close(
                 np.percentile(sample, 25, axis=0), np.repeat(0.25, d), atol=2, rtol=2
             )
-            torch.testing.assert_allclose(
+            torch.testing.assert_close(
                 np.percentile(sample, 75, axis=0), np.repeat(0.75, d), atol=2, rtol=2
             )
 
@@ -2440,7 +2440,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j,  ..., 1.+1.j, 1.+1.j, 1.+1.j],
             actual_norm, actual_mean, actual_stdev = \
                 torch.ops._caffe2.LayerNorm(torch.tensor(X), torch.tensor(
                     weight), torch.tensor(bias), 1, epsilon, True)
-            torch.testing.assert_allclose(expected_norm, actual_norm)
+            torch.testing.assert_close(expected_norm, actual_norm)
 
         def test_memory_format(self):
             def test_helper(x, memory_format):
index 4fa64e7..a0f8328 100644 (file)
@@ -34,7 +34,7 @@ class TestXNNPACKOps(TestCase):
         ref_result = F.linear(input_data, weight, bias)
         packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
         output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
-        torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
 
     @given(input_size=st.integers(2, 32),
            weight_output_dim=st.integers(2, 64),
@@ -49,7 +49,7 @@ class TestXNNPACKOps(TestCase):
         ref_result = F.linear(input_data, weight, bias)
         packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
         output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
-        torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
 
 
     @given(batch_size=st.integers(0, 3),
@@ -107,7 +107,7 @@ class TestXNNPACKOps(TestCase):
         packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(weight, bias,
                                                                       strides, paddings, dilations, groups)
         xnnpack_result = torch.ops.prepacked.conv2d_clamp_run(input_data, packed_weight_bias)
-        torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
 
     @given(batch_size=st.integers(1, 3),
            input_channels_per_group=st.integers(1, 32),
@@ -174,7 +174,7 @@ class TestXNNPACKOps(TestCase):
                                                                                 output_paddings, dilations,
                                                                                 groups)
         xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run(input_data, packed_weight_bias)
-        torch.testing.assert_allclose(ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3)
 
 @unittest.skipUnless(torch.backends.xnnpack.enabled,
                      " XNNPACK must be enabled for these tests."
@@ -214,7 +214,7 @@ class TestXNNPACKSerDes(TestCase):
         input_data = torch.rand(data_shape)
         ref_result = scripted_linear(input_data)
         output_linearprepacked = scripted_linear_clamp_prepacked(input_data)
-        torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
 
         # Serialize the modules and then deserialize
         input_data = torch.rand(data_shape)
@@ -228,7 +228,7 @@ class TestXNNPACKSerDes(TestCase):
         deserialized_linear_clamp_prepacked = torch.jit.load(buffer)
         ref_result = deserialized_linear(input_data)
         output_linearprepacked = deserialized_linear_clamp_prepacked(input_data)
-        torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
 
     @given(batch_size=st.integers(0, 3),
            input_channels_per_group=st.integers(1, 32),
@@ -309,7 +309,7 @@ class TestXNNPACKSerDes(TestCase):
             weight, bias, strides, paddings, dilations, groups))
         ref_result = scripted_conv2d(input_data)
         xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
-        torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
 
         # Serialize the modules and then deserialize
         input_data = torch.rand((batch_size, input_channels, height, width))
@@ -325,7 +325,7 @@ class TestXNNPACKSerDes(TestCase):
         deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
         ref_result = deserialized_conv2d(input_data)
         xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
-        torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
 
     @given(batch_size=st.integers(0, 3),
            input_channels_per_group=st.integers(1, 32),
@@ -417,7 +417,7 @@ class TestXNNPACKSerDes(TestCase):
             weight, bias, strides, paddings, output_paddings, dilations, groups))
         ref_result = scripted_conv2d(input_data)
         xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
-        torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
 
         # Serialize the modules and then deserialize
         input_data = torch.rand((batch_size, input_channels, height, width))
@@ -433,7 +433,7 @@ class TestXNNPACKSerDes(TestCase):
         deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
         ref_result = deserialized_conv2d(input_data)
         xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
-        torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
 
     @given(batch_size=st.integers(0, 3),
            input_channels_per_group=st.integers(1, 32),
@@ -549,7 +549,7 @@ class TestXNNPACKSerDes(TestCase):
                 groups))
         ref_result = scripted_m(input_data)
         xnnpack_result = scripted_m_prepacked(input_data)
-        torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
 
         # Serialize the modules and then deserialize
         input_data = torch.rand((batch_size, input_channels, height, width))
@@ -564,7 +564,7 @@ class TestXNNPACKSerDes(TestCase):
         deserialized_m_prepacked = torch.jit.load(buffer)
         ref_result = deserialized_m(input_data)
         xnnpack_result = deserialized_m_prepacked(input_data)
-        torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
 
 
 @unittest.skipUnless(torch.backends.xnnpack.enabled,
@@ -610,7 +610,7 @@ class TestXNNPACKRewritePass(TestCase):
                 else:
                     FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
             xnnpack_result = deserialized_scripted_model(input_data)
-            torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
 
     def test_linear(self):
         data_shape = [2, 3, 32]
@@ -965,7 +965,7 @@ class TestXNNPACKConv1dTransformPass(TestCase):
                 else:
                     FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
             transformed_result = deserialized_scripted_model(input_data)
-            torch.testing.assert_allclose(ref_result, transformed_result, rtol=1e-2, atol=1e-3)
+            torch.testing.assert_close(ref_result, transformed_result, rtol=1e-2, atol=1e-3)
 
             optimized_buffer = io.BytesIO()
             torch.jit.save(optimized_scripted_model, optimized_buffer)
@@ -980,7 +980,7 @@ class TestXNNPACKConv1dTransformPass(TestCase):
                 else:
                     FileCheck().check_count(pattern, v, exactly=True).run(deserialized_optimized_scripted_model.graph)
             xnnpack_result = deserialized_optimized_scripted_model(input_data)
-            torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
 
 
     def test_conv1d_basic(self):
index fff539d..76bf69a 100644 (file)
@@ -236,7 +236,7 @@ if __name__ == "__main__":
 
     # Assert results are equal with the original model.
     rn18 = rn18.cuda()
-    torch.testing.assert_allclose(split_mod(x), rn18(x))
+    torch.testing.assert_close(split_mod(x), rn18(x))
 
     import time
     NITER = 100
index 069b73e..5a2f6e5 100644 (file)
@@ -24,7 +24,7 @@ from torch._jit_internal import _qualified_name, is_scripting, get_callable_argu
 from torch.autograd import function
 from torch.nn import Module
 
-from torch.testing._core import _get_default_tolerance
+from torch.testing._asserts import _get_default_rtol_and_atol
 
 _flatten = torch._C._jit_flatten
 _unflatten = torch._C._jit_unflatten
@@ -417,7 +417,8 @@ def _check_trace(
                     check_tensor_val = n_check.t("value")
 
                     try:
-                        torch.testing.assert_allclose(mod_tensor_val, check_tensor_val)
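+                        # equal_nan=True keeps the legacy assert_allclose default of treating NaNs as equal.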
+                        torch.testing.assert_close(mod_tensor_val, check_tensor_val, equal_nan=True)
                     except (RuntimeError, AssertionError) as e:
                         if tensor_compare_errors is None:
                             tensor_compare_errors = ""
@@ -489,11 +489,14 @@ def _check_trace(
                         orig = orig.to_dense()
                     if ref.is_mkldnn:
                         ref = ref.to_dense()
-                    torch.testing.assert_allclose(
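+                    # check_tolerance replaces only rtol; atol keeps the per-dtype default
+                    # the removed assert_allclose would have computed for these tensors.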
+                    torch.testing.assert_close(
                         orig.double(),
                         ref.double(),
                         rtol=check_tolerance,
-                        atol=_get_default_tolerance(orig, ref)[1],
+                        atol=_get_default_rtol_and_atol(orig, ref)[1],
+                        equal_nan=True,
                     )
                 except AssertionError as e:
                     maybe_warn_nondeterministic()
index 9a5fb0c..d980615 100644 (file)
@@ -18,7 +18,6 @@ __all__ = [
     "all_types_and_complex",
     "all_types_and_complex_and",
     "all_types_and_half",
-    "assert_allclose",
     "complex_types",
     "empty_types",
     "floating_and_complex_types",
@@ -246,30 +245,6 @@ def _compare_scalars_internal(a, b, *, rtol: float, atol: float, equal_nan: Unio
 
     return _helper(a, b, " ")
 
-def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='') -> None:
-    if not isinstance(actual, torch.Tensor):
-        actual = torch.tensor(actual)
-    if not isinstance(expected, torch.Tensor):
-        expected = torch.tensor(expected, dtype=actual.dtype)
-    if expected.shape != actual.shape:
-        raise AssertionError("expected tensor shape {0} doesn't match with actual tensor "
-                             "shape {1}!".format(expected.shape, actual.shape))
-    if rtol is None or atol is None:
-        if rtol is not None or atol is not None:
-            raise ValueError("rtol and atol must both be specified or both be unspecified")
-        rtol, atol = _get_default_tolerance(actual, expected)
-
-    result, debug_msg = _compare_tensors_internal(actual, expected,
-                                                  rtol=rtol, atol=atol,
-                                                  equal_nan=equal_nan)
-
-    if result:
-        return
-
-    if msg is None or msg == '':
-        msg = debug_msg
-
-    raise AssertionError(msg)
 
 def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor:
     if tensor.numel() <= 1:  # can't make non-contiguous
@@ -406,19 +381,3 @@ def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dt
 
 def get_all_device_types() -> List[str]:
     return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-
-# 'dtype': (rtol, atol)
-_default_tolerances = {
-    'float64': (1e-5, 1e-8),  # NumPy default
-    'float32': (1e-4, 1e-5),  # This may need to be changed
-    'float16': (1e-3, 1e-3),  # This may need to be changed
-}
-
-
-def _get_default_tolerance(a, b=None) -> Tuple[float, float]:
-    if b is None:
-        dtype = str(a.dtype).split('.')[-1]  # e.g. "float32"
-        return _default_tolerances.get(dtype, (0, 0))
-    a_tol = _get_default_tolerance(a)
-    b_tol = _get_default_tolerance(b)
-    return (max(a_tol[0], b_tol[0]), max(a_tol[1], b_tol[1]))
index 7355aee..3cf7338 100644 (file)
@@ -5,17 +5,24 @@ we don't internalize without warning, but still go through a deprecation cycle.
 
 import functools
 import warnings
-from typing import Any, Callable
+from typing import Any, Callable, Optional, Tuple
 
 import torch
 
 
-__all__ = ["rand", "randn"]
+__all__ = [
+    "rand",
+    "randn",
+    "assert_allclose",
+]
 
 
 def warn_deprecated(instructions: str) -> Callable:
     def outer_wrapper(fn: Callable) -> Callable:
-        msg = f"torch.testing.{fn.__name__} is deprecated and will be removed in the future. {instructions.strip()}"
+        msg = (
+            f"torch.testing.{fn.__name__} is deprecated and will be removed in a future release. "
+            f"{instructions.strip()}"
+        )
 
         @functools.wraps(fn)
         def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
@@ -29,3 +36,59 @@ def warn_deprecated(instructions: str) -> Callable:
 
 rand = warn_deprecated("Use torch.rand instead.")(torch.rand)
 randn = warn_deprecated("Use torch.randn instead.")(torch.randn)
+
+
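+# Per-dtype (rtol, atol) defaults matching the removed _default_tolerances from
+# torch.testing._core; dtypes not listed fall back to (0.0, 0.0), i.e. exact comparison.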
+_DTYPE_PRECISIONS = {
+    torch.float16: (1e-3, 1e-3),
+    torch.float32: (1e-4, 1e-5),
+    torch.float64: (1e-5, 1e-8),
+}
+
+
+def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
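+    # Use the loosest tolerance of the two dtypes, mirroring the removed
+    # _get_default_tolerance helper from torch.testing._core.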
+    actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0))
+    expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0))
+    return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol)
+
+
+# TODO: include the deprecation as soon as torch.testing.assert_close is stable
+# @warn_deprecated(
+#     "Use torch.testing.assert_close instead. "
+#     "For detailed upgrade instructions see https://github.com/pytorch/pytorch/issues/61844."
+# )
+def assert_allclose(
+    actual: Any,
+    expected: Any,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
+    equal_nan: bool = True,
+    msg: str = "",
+) -> None:
+    if not isinstance(actual, torch.Tensor):
+        actual = torch.tensor(actual)
+    if not isinstance(expected, torch.Tensor):
+        expected = torch.tensor(expected, dtype=actual.dtype)
+
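+    # Only fall back to the legacy per-dtype defaults when neither tolerance is given.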
+    if rtol is None and atol is None:
+        rtol, atol = _get_default_rtol_and_atol(actual, expected)
+
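+    # The relaxed check_* flags keep this wrapper as permissive as the removed
+    # implementation, which only compared shapes and values; ``msg or None``
+    # preserves the old behavior where an empty string selected the default message.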
+    torch.testing.assert_close(
+        actual,
+        expected,
+        rtol=rtol,
+        atol=atol,
+        equal_nan=equal_nan,
+        check_device=True,
+        check_dtype=False,
+        check_stride=False,
+        check_is_coalesced=False,
+        msg=msg or None,
+    )
index 2470b53..6b2d1dd 100644 (file)
@@ -975,12 +975,12 @@ class QuantizationLiteTestCase(QuantizationTestCase):
 
                     mobile_module_result = mobile_module(input)
 
-                    torch.testing.assert_allclose(script_module_result, mobile_module_result)
+                    torch.testing.assert_close(script_module_result, mobile_module_result)
                     mobile_module_forward_result = mobile_module.forward(input)
-                    torch.testing.assert_allclose(script_module_result, mobile_module_forward_result)
+                    torch.testing.assert_close(script_module_result, mobile_module_forward_result)
 
                     mobile_module_run_method_result = mobile_module.run_method("forward", input)
-                    torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result)
+                    torch.testing.assert_close(script_module_result, mobile_module_run_method_result)
                 except AssertionError as e:
                     if retry == max_retry:
                         raise e
index 2a126ab..096b718 100644 (file)
@@ -4119,20 +4119,15 @@ class DistributedTest:
                 grad_hook = net_with_hook.module.weight.grad
                 avg_hook = grad_hook.clone()
                 # Verify hook grad with expected.
-                # Cannot use exact match here due to a very small accuracy loss,
-                # e.g. 1e-05, for powerSGD hook case.
-                assert_func = (
-                    self.assertEqual
-                    if hook == default.allreduce_hook
-                    else torch.testing.assert_allclose
-                )
-                assert_func(
-                    avg_hook[0, 0],
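+                # assertEqual performs an approximate comparison with default
+                # tolerances, so it also covers the small powerSGD accuracy loss.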
+                self.assertEqual(
+                    avg_hook[0, 0].item(),
                     expected_grad,
                     msg=f"Expected hook grad of {expected_grad} but got {avg_hook[0, 0]}",
                 )
                 # Verify hook grad with vanilla allreduce
-                assert_func(
+                self.assertEqual(
                     avg_hook[0, 0],
                     avg[0, 0],
                     msg=f"Expected hook grad to be close to allreduce {avg[0, 0]}, but got {avg_hook[0, 0]}",
@@ -4937,8 +4930,8 @@ class DistributedTest:
                 model.module.running_mean,
                 model.module.running_var,
             )
-            torch.testing.assert_allclose(running_mean, all_input_var.mean(1))
-            torch.testing.assert_allclose(running_var, all_input_var.var(1))
+            torch.testing.assert_close(running_mean, all_input_var.mean(1))
+            torch.testing.assert_close(running_var, all_input_var.var(1))
 
         @sandcastle_skip_if(
             BACKEND != "nccl" and BACKEND != "gloo",