From 3855c246395303c2cde38c832a21cca577851b39 Mon Sep 17 00:00:00 2001
From: CaoE
Date: Mon, 13 Sep 2021 17:58:20 -0700
Subject: [PATCH] Add BFloat16 support for cross, tril, triu, tril_indices,
 triu_indices and cumsum operators on CPU (#62454)

Summary:
Add BFloat16 support for cross, tril, triu, tril_indices, triu_indices and cumsum operators on CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62454

Reviewed By: albanD

Differential Revision: D30845805

Pulled By: heitorschueroff

fbshipit-source-id: f83836862e38109ec929e83567133e9e88096b8b
---
 aten/src/ATen/native/TensorFactories.cpp      |  4 ++--
 aten/src/ATen/native/TriangularOps.cpp        |  8 ++++----
 aten/src/ATen/native/cpu/CrossKernel.cpp      |  2 +-
 test/test_tensor_creation_ops.py              | 24 ++++++++++++++++++++++
 .../_internal/common_methods_invocations.py   |  3 +++
 5 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 4712c3d..67ef8b6 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -979,7 +979,7 @@ Tensor tril_indices_cpu(
   //
   // 3. sequential RAM + transpose: create an n X 2 Tensor, fill the Tensor
   //    sequentially, and then transpose it.
-  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "tril_indices", [&]() -> void {
+  AT_DISPATCH_ALL_TYPES_AND(kBFloat16, result.scalar_type(), "tril_indices", [&]() -> void {
     // fill the Tensor with correct values
     scalar_t* result_data = result.data_ptr<scalar_t>();
     int64_t i = 0;
@@ -1017,7 +1017,7 @@ Tensor triu_indices_cpu(
   // create an empty Tensor with correct size
   auto result = at::native::empty_cpu({2, triu_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt);

-  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "triu_indices", [&]() -> void {
+  AT_DISPATCH_ALL_TYPES_AND(kBFloat16, result.scalar_type(), "triu_indices", [&]() -> void {
     // fill the Tensor with correct values
     scalar_t* result_data = result.data_ptr<scalar_t>();
     int64_t i = 0;
diff --git a/aten/src/ATen/native/TriangularOps.cpp b/aten/src/ATen/native/TriangularOps.cpp
index ec1741d..765069b 100644
--- a/aten/src/ATen/native/TriangularOps.cpp
+++ b/aten/src/ATen/native/TriangularOps.cpp
@@ -99,7 +99,7 @@ Tensor& tril_cpu_(Tensor &self, int64_t k) {
   Tensor self_c;
   std::tie(inplace, self_c) = checkTrilTriuBatchContiguous(self, true);
   Tensor result = inplace ? self : at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
     apply_triu_tril<scalar_t, false>(result, self_c, inplace, k);
   });
   if (!inplace) self.copy_(result);
@@ -113,7 +113,7 @@ Tensor& tril_cpu_out(const Tensor& self, int64_t k, Tensor &result) {
   }
   Tensor self_c;
   std::tie(std::ignore, self_c) = checkTrilTriuBatchContiguous(self, false);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
     apply_triu_tril<scalar_t, false>(result, self_c, false, k);
   });
   return result;
@@ -134,7 +134,7 @@ Tensor& triu_cpu_(Tensor &self, int64_t k) {
   Tensor self_c;
   std::tie(inplace, self_c) = checkTrilTriuBatchContiguous(self, true);
   Tensor result = inplace ? self : at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
     apply_triu_tril<scalar_t, true>(result, self_c, inplace, k);
   });
   if (!inplace) self.copy_(result);
@@ -148,7 +148,7 @@ Tensor& triu_cpu_out(const Tensor& self, int64_t k, Tensor &result) {
   }
   Tensor self_c;
   std::tie(std::ignore, self_c) = checkTrilTriuBatchContiguous(self, false);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
    apply_triu_tril<scalar_t, true>(result, self_c, false, k);
  });
  return result;
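
The TriangularOps changes above only widen the dispatch macros from the _AND2 to the _AND3 variants, so the existing tril/triu kernels are additionally instantiated for at::ScalarType::BFloat16; the apply_triu_tril loop itself is unchanged. A minimal sketch of the user-visible effect, assuming a CPU build that includes this patch (shapes and values are illustrative only):

    import torch

    # Before this patch, tril/triu on a bfloat16 CPU tensor failed at
    # dispatch time; with the _AND3 macros they run natively in bfloat16.
    x = torch.randn(4, 4).bfloat16()
    lower = torch.tril(x)     # zeros out everything above the main diagonal
    upper = torch.triu(x, 1)  # keeps only elements strictly above it
    assert lower.dtype == torch.bfloat16 and upper.dtype == torch.bfloat16
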
diff --git a/aten/src/ATen/native/cpu/CrossKernel.cpp b/aten/src/ATen/native/cpu/CrossKernel.cpp
index 55e0229..d5bbc81 100644
--- a/aten/src/ATen/native/cpu/CrossKernel.cpp
+++ b/aten/src/ATen/native/cpu/CrossKernel.cpp
@@ -65,7 +65,7 @@ static void apply_cross(Tensor& result, const Tensor& a, const Tensor& b, const
 }

 static void cross_kernel_impl(Tensor& result, const Tensor& a, const Tensor& b, const int64_t dim) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(result.scalar_type(), "cross", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, result.scalar_type(), "cross", [&]() {
     apply_cross<scalar_t>(result, a, b, dim);
   });
 }
diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py
index 2404f02..e698768 100644
--- a/test/test_tensor_creation_ops.py
+++ b/test/test_tensor_creation_ops.py
@@ -311,6 +311,21 @@ class TestTensorCreation(TestCase):
         for s, d, dtype in product(shapes, diagonals, dtypes):
             run_test(s, device, d, dtype)

+    @onlyCPU
+    def test_triu_tril_bfloat16(self, device):
+        op_funcs = [torch.tril, torch.triu]
+        for op_fun in op_funcs:
+            input = torch.randn(3, 3, dtype=torch.float32, device=device).bfloat16().requires_grad_(True)
+            input2 = input.detach().clone().float().requires_grad_(True)
+            out = op_fun(input)
+            out.sum().backward()
+            out2 = op_fun(input2)
+            out2.sum().backward()
+            self.assertEqual(out.dtype, torch.bfloat16)
+            self.assertEqual(input.grad.dtype, torch.bfloat16)
+            self.assertEqual(out, out2.bfloat16())
+            self.assertEqual(input.grad, input2.grad.bfloat16(), atol=0.01, rtol=0)
+
     def test_diagflat(self, device):
         dtype = torch.float32
         # Basic sanity test
@@ -1213,6 +1228,15 @@ class TestTensorCreation(TestCase):
         self.assertEqual(b.triu(2), output)
         self.assertRaises(RuntimeError, lambda: b.triu_(2))

+    @onlyCPU
+    def test_triu_tril_indices_bfloat16(self, device):
+        op_funcs = [torch.tril_indices, torch.triu_indices]
+        for op_fun in op_funcs:
+            out = op_fun(4, 3, 1, dtype=torch.bfloat16)
+            out2 = op_fun(4, 3, 1, dtype=torch.float)
+            self.assertEqual(out.dtype, torch.bfloat16)
+            self.assertEqual(out, out2.bfloat16())
+
     # TODO: update to work on CUDA, too
     @onlyCPU
     def test_stack(self, device):
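
The CrossKernel change follows the same pattern for torch.cross, and the new tests compare each bfloat16 result against a float32 reference with a loose tolerance (atol=0.01, rtol=0), since bfloat16 carries only about three decimal digits of precision. A hedged sketch of that comparison outside the test harness, again assuming a patched CPU build:

    import torch

    a = torch.randn(4, 3).bfloat16()
    b = torch.randn(4, 3).bfloat16()
    c = torch.cross(a, b, dim=1)  # now dispatches for bfloat16 on CPU
    assert c.dtype == torch.bfloat16

    # Re-run in float32 to gauge the rounding error of the low-precision path.
    c_ref = torch.cross(a.float(), b.float(), dim=1)
    print((c.float() - c_ref).abs().max())  # small; limited by bfloat16 rounding
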
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index a2b9fea..a4281bb 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -6437,6 +6437,7 @@ op_db: List[OpInfo] = [
            skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),)),
     OpInfo('cross',
            dtypes=all_types_and_complex(),
+           dtypesIfCPU=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=all_types_and_complex_and(torch.half),
            sample_inputs_func=sample_inputs_cross,
            supports_forward_ad=True,
@@ -9012,10 +9013,12 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_transpose_swapdims),
     OpInfo('tril',
+           dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            dtypes=all_types_and_complex_and(torch.bool, torch.half),
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_tril_triu),
     OpInfo('triu',
+           dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            dtypes=all_types_and_complex_and(torch.bool, torch.half),
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_tril_triu),
-- 
2.7.4
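
The OpInfo entries advertise the new CPU coverage through dtypesIfCPU, so the generic operator test suites exercise bfloat16 for cross, tril, and triu automatically. For the index factories, the row and column indices involved are small integers that bfloat16 represents exactly, which is why test_triu_tril_indices_bfloat16 can demand exact equality against a float32 reference. A small sketch of that behavior, assuming a patched build:

    import torch

    # Mirrors the 4x3, offset=1 case from test_triu_tril_indices_bfloat16.
    idx = torch.tril_indices(4, 3, 1, dtype=torch.bfloat16)
    ref = torch.tril_indices(4, 3, 1, dtype=torch.float)
    assert idx.dtype == torch.bfloat16
    assert torch.equal(idx, ref.bfloat16())  # indices < 4 are exact in bfloat16
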