From 5d80a48cef373e22393af1b1f4f4e3f2ad948a76 Mon Sep 17 00:00:00 2001
From: anjali411
Date: Wed, 1 Sep 2021 16:11:38 -0700
Subject: [PATCH] Add fast path for addmm when the inputs are conjugate
 (#59380)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59380

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D28898374

Pulled By: anjali411

fbshipit-source-id: eab0e64d37bb57c18b54cabb8e5c00666338ba04
---
 aten/src/ATen/ConjugateFallback.cpp          |  11 +++
 aten/src/ATen/cuda/CUDABlas.cpp              |   4 +-
 aten/src/ATen/native/CPUBlas.cpp             |   4 +-
 aten/src/ATen/native/CPUBlas.h               |   2 +-
 aten/src/ATen/native/LinearAlgebra.cpp       |  35 +++++--
 aten/src/ATen/native/NegateFallback.cpp      |   1 +
 aten/src/ATen/native/TensorFactories.cpp     |  13 +--
 aten/src/ATen/native/cuda/Blas.cpp           |  65 +++++++++----
 test/test_linalg.py                          |  32 +++++++
 test/test_torch.py                           |  11 ++-
 .../_internal/common_methods_invocations.py  | 104 +++++++++++++++++----
 11 files changed, 223 insertions(+), 59 deletions(-)

diff --git a/aten/src/ATen/ConjugateFallback.cpp b/aten/src/ATen/ConjugateFallback.cpp
index a64ef49..2cf9538 100644
--- a/aten/src/ATen/ConjugateFallback.cpp
+++ b/aten/src/ATen/ConjugateFallback.cpp
@@ -60,6 +60,17 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
   m.impl("vdot", torch::CppFunction::makeFallthrough());
   m.impl("dot.out", torch::CppFunction::makeFallthrough());
   m.impl("vdot.out", torch::CppFunction::makeFallthrough());
+  m.impl("alias", torch::CppFunction::makeFallthrough());
+  m.impl("mm", torch::CppFunction::makeFallthrough());
+  m.impl("mm.out", torch::CppFunction::makeFallthrough());
+  m.impl("addmm", torch::CppFunction::makeFallthrough());
+  m.impl("addmm_", torch::CppFunction::makeFallthrough());
+  m.impl("addmm.out", torch::CppFunction::makeFallthrough());
+  m.impl("bmm", torch::CppFunction::makeFallthrough());
+  m.impl("bmm.out", torch::CppFunction::makeFallthrough());
+  m.impl("baddbmm", torch::CppFunction::makeFallthrough());
+  m.impl("baddbmm_", torch::CppFunction::makeFallthrough());
+  m.impl("baddbmm.out", torch::CppFunction::makeFallthrough());
 }

 } // namespace at
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 75e59d0..70c3dda 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -64,8 +64,8 @@ static void _cublasAdjustLdLevel3(
     int64_t* lda,
     int64_t* ldb,
     int64_t* ldc) {
-  bool transa_ = ((transa == 't') || (transa == 'T'));
-  bool transb_ = ((transb == 't') || (transb == 'T'));
+  bool transa_ = ((transa != 'n') && (transa != 'N'));
+  bool transb_ = ((transb != 'n') && (transb != 'N'));

   // Note: leading dimensions generally are checked that they are > 0
   // and at least as big the result requires (even if the value won't
diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp
index 1a1f673..f14e4dc 100644
--- a/aten/src/ATen/native/CPUBlas.cpp
+++ b/aten/src/ATen/native/CPUBlas.cpp
@@ -78,7 +78,7 @@ char to_blas(TransposeType trans) {
   switch (trans) {
     case Transpose: return 't';
     case NoTranspose: return 'n';
-    // case ConjTranspose: return 'c';
+    case ConjTranspose: return 'c';
   }
   TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
 }
@@ -89,7 +89,7 @@ fbgemm::matrix_op_t to_fbgemm(TransposeType trans) {
   switch (trans) {
     case Transpose: return fbgemm::matrix_op_t::Transpose;
     case NoTranspose: return fbgemm::matrix_op_t::NoTranspose;
-    // case ConjTranspose: return fbgemm::matrix_op_t::Transpose;
fbgemm"); } TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); } diff --git a/aten/src/ATen/native/CPUBlas.h b/aten/src/ATen/native/CPUBlas.h index e61207f..3a483e4 100644 --- a/aten/src/ATen/native/CPUBlas.h +++ b/aten/src/ATen/native/CPUBlas.h @@ -12,7 +12,7 @@ namespace cpublas { enum TransposeType { Transpose, NoTranspose, - // ConjTranspose, -- Not implemented + ConjTranspose, }; namespace internal { diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 10576a0..2ae6202 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -959,7 +959,6 @@ Tensor outer(const Tensor& self, const Tensor& vec2) { static void addmm_impl_cpu_( Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) { TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2); - // Array access is faster than .size(n) and .stride(n) const auto self_sizes = self.sizes(); auto m1_strides = m1.strides(); @@ -992,18 +991,18 @@ static void addmm_impl_cpu_( if (result_strides[0] == 1 && (result_sizes[1] == 1 || result_strides[1] >= std::max(int64_t{1}, result_sizes[0]))) { transpose_c = false; - c = result; + c = result.resolve_conj(); } else if (result_strides[1] == 1 && (result_sizes[0] == 1 || result_strides[0] >= std::max(int64_t{1}, result_sizes[1]))) { std::swap(m1, m2); std::swap(m1_sizes, m2_sizes); std::swap(m1_strides, m2_strides); transpose_c = true; - c = result; + c = result.resolve_conj(); } else { transpose_c = false; // make c FORTRAN contiguous - c = result.transpose(0, 1).contiguous().transpose_(0, 1); + c = result.resolve_conj().transpose(0, 1).contiguous().transpose_(0, 1); } const int64_t m = result_sizes[transpose_c ? 1 : 0]; @@ -1017,7 +1016,7 @@ static void addmm_impl_cpu_( if (m1_strides[transpose_c ? 1 : 0] == 1 && m1_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, m)) { transpose_a = false; - a = m1; + a = m1.resolve_conj(); } else if (m1_strides[transpose_c ? 0 : 1] == 1 && m1_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, k)) { transpose_a = true; @@ -1034,7 +1033,7 @@ static void addmm_impl_cpu_( if (m2_strides[transpose_c ? 1 : 0] == 1 && m2_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, k)) { transpose_b = false; - b = m2; + b = m2.resolve_conj(); } else if (m2_strides[transpose_c ? 0 : 1] == 1 && m2_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, n)) { transpose_b = true; @@ -1048,13 +1047,16 @@ static void addmm_impl_cpu_( const int64_t ldb = b.strides()[(transpose_b == transpose_c) ? 1 : 0]; const int64_t ldc = c.strides()[transpose_c ? 0 : 1]; + // Always ensure the conjugation for c is resolved since there's no way to specify c's conjugation in the gemm call + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c.is_conj()); + // Apply BLAS routine AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, result.scalar_type(), "addmm_impl_cpu_", [&]{ at::native::cpublas::gemm( - transpose_a ? cpublas::Transpose : cpublas::NoTranspose, - transpose_b ? cpublas::Transpose : cpublas::NoTranspose, + transpose_a ? a.is_conj() ? cpublas::ConjTranspose : cpublas::Transpose : cpublas::NoTranspose, + transpose_b ? b.is_conj() ? 
+            transpose_b ? b.is_conj() ? cpublas::ConjTranspose : cpublas::Transpose : cpublas::NoTranspose,
             m, n, k,
             alpha.to<scalar_t>(),
             a.data_ptr<scalar_t>(), lda,
@@ -1349,8 +1351,18 @@ Tensor& baddbmm_out_cpu(const Tensor& self_, const Tensor& batch1, const Tensor&
   return at::native::baddbmm__cpu(result, batch1, batch2, beta, alpha);
 }

+Tensor& conjugate_mutable_input_if_needed(Tensor& self, bool conjugate) {
+  if (conjugate) {
+    self.conj_physical_();
+  }
+  return self;
+}
+
 Tensor& baddbmm__cpu(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
-  return bmm_out_or_baddbmm_(self, batch1, batch2, beta, alpha, false);
+  bool self_is_conj = self.is_conj();
+  conjugate_mutable_input_if_needed(self, self_is_conj);
+  bmm_out_or_baddbmm_(self, batch1.resolve_conj(), batch2.resolve_conj(), beta, alpha, false);
+  return conjugate_mutable_input_if_needed(self, self_is_conj);
 }

 Tensor bmm_cpu(const Tensor& self, const Tensor& mat2) {
@@ -1363,7 +1375,10 @@ Tensor& bmm_out_cpu(const Tensor& batch1, const Tensor& batch2, Tensor &result)
   Scalar alpha(1.0);
   {
     NoNamesGuard guard;
-    bmm_out_or_baddbmm_(result, batch1, batch2, beta, alpha, true);
+    bool result_is_conj = result.is_conj();
+    conjugate_mutable_input_if_needed(result, result_is_conj);
+    bmm_out_or_baddbmm_(result, batch1.resolve_conj(), batch2.resolve_conj(), beta, alpha, true);
+    conjugate_mutable_input_if_needed(result, result_is_conj);
   }
   namedinference::propagate_names_if_nonempty(
       result,
diff --git a/aten/src/ATen/native/NegateFallback.cpp b/aten/src/ATen/native/NegateFallback.cpp
index 86dbe05..d8381f5 100644
--- a/aten/src/ATen/native/NegateFallback.cpp
+++ b/aten/src/ATen/native/NegateFallback.cpp
@@ -55,6 +55,7 @@ TORCH_LIBRARY_IMPL(aten, Negative, m) {
   m.impl("view", torch::CppFunction::makeFallthrough());
   m.impl("_unsafe_view", torch::CppFunction::makeFallthrough());
   m.impl("reshape", torch::CppFunction::makeFallthrough());
+  m.impl("alias", torch::CppFunction::makeFallthrough());
 }

 } // namespace at
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 3ee909b..4712c3d 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -1411,17 +1411,18 @@ Tensor from_file(c10::string_view filename, c10::optional shared, c10::opt

 Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format) {
   auto memory_format =
       optional_memory_format.value_or(MemoryFormat::Preserve);
+  Tensor self;
   if (memory_format == MemoryFormat::Preserve) {
     if (src.is_non_overlapping_and_dense()) {
-      // Copy all strides
-      auto self = at::empty_strided(src.sizes(), src.strides(), src.options());
-      self.copy_(src);
-      return self;
+      // Copy all strides, this is marginally faster than calling empty_like
+      self = at::empty_strided(src.sizes(), src.strides(), src.options());
     } else {
-      memory_format = src.suggest_memory_format();
+      self = at::empty_like(src);
     }
+  } else {
+    self = at::empty_like(src, src.options(), memory_format);
   }
-  auto self = at::empty_like(src, src.options(), memory_format);
+  self.copy_(src);
   return self;
 }
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index b447910..269307d 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -4,24 +4,51 @@
 #include
 #include
-
 namespace at { namespace native {

 namespace {

+// TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492
+c10::MaybeOwned<Tensor> inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) {
+  if (resolve_conj && tensor.is_conj()) {
+    return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
+  } else {
+    return c10::MaybeOwned<Tensor>::borrowed(tensor);
+  }
+}
+
+c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) {
+  if (tensor.is_non_overlapping_and_dense()) { // common case
+      transpose_tensor = tensor.is_contiguous();
+      return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor);
+  }
+  IntArrayRef tensor_strides = tensor.strides();
+  IntArrayRef tensor_sizes = tensor.sizes();
+  if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
+    transpose_tensor = false;
+    return resolve_conj_if_indicated(tensor, !transpose_result);
+  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
+    transpose_tensor = true;
+    return resolve_conj_if_indicated(tensor, transpose_result);
+  } else {
+    transpose_tensor = true;
+    return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
+  }
+}
+
 c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) {
   if (tensor.is_non_overlapping_and_dense()) { // common case
       transpose_tensor = tensor.is_contiguous();
-      return c10::MaybeOwned<Tensor>::borrowed(tensor);
+      return resolve_conj_if_indicated(tensor, true);
   }
   IntArrayRef tensor_strides = tensor.strides();
   IntArrayRef tensor_sizes = tensor.sizes();
   if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
     transpose_tensor = false;
-    return c10::MaybeOwned<Tensor>::borrowed(tensor);
+    return resolve_conj_if_indicated(tensor, true);
   } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
     transpose_tensor = true;
-    return c10::MaybeOwned<Tensor>::borrowed(tensor);
+    return resolve_conj_if_indicated(tensor, true);
   } else {
     transpose_tensor = true;
     return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
   }
@@ -39,19 +66,19 @@ c10::MaybeOwned<Tensor> prepare_batch_matrix_for_cublas(const Tensor& tensor, bo
   if (tensor_strides[fast_dim] == 1 &&
     (tensor_strides[leading_dim] >= std::max<int64_t>(1, m))) {
     transpose_tensor = false;
-    tensor_ = c10::MaybeOwned<Tensor>::borrowed(tensor);
-    ld_tensor = tensor_strides[leading_dim];
+    tensor_ = resolve_conj_if_indicated(tensor, true);
+    ld_tensor = tensor_->strides()[leading_dim];
   } else if ((tensor_strides[leading_dim] == 1) &&
     (tensor_strides[fast_dim] >= std::max<int64_t>(1, n))) {
     transpose_tensor = true;
-    tensor_ = c10::MaybeOwned<Tensor>::borrowed(tensor);
-    ld_tensor = tensor_strides[fast_dim];
+    tensor_ = resolve_conj_if_indicated(tensor, false);
+    ld_tensor = tensor_->strides()[fast_dim];
   } else {
     transpose_tensor = !transpose_result;
     // gemm call requires leading dimension and stride parameters to be non-zero
     bool is_stride_non_zero = tensor.strides()[1] != 0 && tensor.strides()[2] != 0;
     if (tensor.is_contiguous() && is_stride_non_zero) {
-      tensor_ = c10::MaybeOwned<Tensor>::borrowed(tensor);
+      tensor_ = resolve_conj_if_indicated(tensor, transpose_result);
     } else {
       tensor_ = c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
     }
@@ -104,8 +131,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   c10::MaybeOwned<Tensor> result_ = prepare_matrix_for_cublas(result, transpose_result);
   bool transpose_mat1;
   bool transpose_mat2;
-  c10::MaybeOwned<Tensor> mat1_ = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1);
-  c10::MaybeOwned<Tensor> mat2_ = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2);
+  auto mat1_ = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result);
+  auto mat2_ = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result);

   if (transpose_result) {
     transpose_mat1 = !transpose_mat1;
@@ -141,6 +168,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
         c10::nullopt /* pin_memory */));
   }

+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
+
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "addmm_cuda", [&] {
     scalar_t alpha_val = alpha.to<scalar_t>();
     scalar_t beta_val = beta.to<scalar_t>();
@@ -148,8 +177,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
     scalar_t* mat2_ptr = mat2_->data_ptr<scalar_t>();
     scalar_t* result_ptr = result_->data_ptr<scalar_t>();
     at::cuda::blas::gemm<scalar_t>(
-        transpose_mat1 ? 't' : 'n',
-        transpose_mat2 ? 't' : 'n',
+        transpose_mat1 ? mat1_->is_conj() ? 'c' : 't' : 'n',
+        transpose_mat2 ? mat2_->is_conj() ? 'c' : 't' : 'n',
         m, n, k,
         alpha_val,
         mat1_ptr, mat1_ld,
@@ -207,11 +236,11 @@ Tensor& baddbmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor&

   if ((result_strides[1] == 1) &&
       ((result_sizes[2] == 1) || (result_strides[2] >= std::max<int64_t>(1, result_sizes[1])))) {
-    result_ = c10::MaybeOwned<Tensor>::borrowed(result);
+    result_ = resolve_conj_if_indicated(result, true);
   } else if ((result_strides[2] == 1) &&
       (result_sizes[1] == 1 || (result_strides[1] >= std::max<int64_t>(1, result_sizes[2])))) {
     transpose_result = true;
-    result_ = c10::MaybeOwned<Tensor>::borrowed(result);
+    result_ = resolve_conj_if_indicated(result, true);
   } else {
     result_ = c10::MaybeOwned<Tensor>::owned(result.transpose(1, 2).clone(at::MemoryFormat::Contiguous).transpose(1, 2));
   }
@@ -230,6 +259,8 @@ Tensor& baddbmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor&
   ldc = result_->strides()[leading_dim];
   int64_t num_batches = result_->sizes()[0];

+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
+
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddbmm_cuda", [&] {
     scalar_t alpha_val = alpha.to<scalar_t>();
     scalar_t beta_val = beta.to<scalar_t>();
@@ -237,8 +268,8 @@ Tensor& baddbmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor&
     scalar_t* batch2_ptr = batch2_->data_ptr<scalar_t>();
     scalar_t* result_ptr = result_->data_ptr<scalar_t>();
     at::cuda::blas::bgemm<scalar_t>(
-      transpose_batch1 ? 't' : 'n',
-      transpose_batch2 ? 't' : 'n',
+      transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n',
+      transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n',
       m, n, k,
       alpha_val,
       batch1_ptr, lda, batch1_->strides()[0],
diff --git a/test/test_linalg.py b/test/test_linalg.py
index f7ce392..fbd219b 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -6166,6 +6166,38 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
             _test_mm(n, m, p, dtype, genf)

     @onlyOnCPUAndCUDA
+    def test_mm_bmm_non_memory_dense(self, device):
+        def _slice(tensor, fn):
+            return fn(tensor)[..., ::2]
+        A = torch.randn(3, 6, dtype=torch.cfloat, device=device)
+        B = torch.randn(3, 3, dtype=torch.cfloat, device=device)
+        out = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
+        out1 = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
+        A_conj = _slice(A, torch.conj)
+        A_conj_physical = _slice(A, torch.conj_physical)
+
+        self.assertEqual(torch.mm(A_conj, B, out=out), torch.mm(A_conj_physical, B, out=out))
+        self.assertEqual(torch.mm(A_conj.t(), B, out=out), torch.mm(A_conj_physical.t(), B, out=out))
+
+        Ab = torch.randn(2, 3, 6, dtype=torch.cfloat, device=device)
+        Bb = torch.randn(2, 3, 3, dtype=torch.cfloat, device=device)
+        Bb_ = torch.randn(1, 3, 3, dtype=torch.cfloat, device=device).expand(2, 3, 3)
+        out_b = torch.empty(2, 3, 3, device=device, dtype=torch.complex64).transpose(-1, -2)
+
+        Ab_conj = _slice(Ab, torch.conj)
+        Ab_conj_physical = _slice(Ab, torch.conj_physical)
+
+        def t_b(tensor):
+            return tensor.transpose(-1, -2)
+
+        self.assertEqual(torch.bmm(Ab_conj, Bb, out=out_b), torch.bmm(Ab_conj_physical, Bb, out=out_b))
+        self.assertEqual(torch.bmm(t_b(Ab_conj), Bb, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb, out=out_b))
+
+        # test broadcasting
+        self.assertEqual(torch.bmm(Ab_conj, Bb_, out=out_b), torch.bmm(Ab_conj_physical, Bb_, out=out_b))
+        self.assertEqual(torch.bmm(t_b(Ab_conj), Bb_, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb_, out=out_b))
+
+    @onlyOnCPUAndCUDA
     @dtypes(torch.float32, torch.float64)
     def test_strided_mm_bmm(self, device, dtype):
         # Tests strided view case with stride smaller than corresponding dimension size
diff --git a/test/test_torch.py b/test/test_torch.py
index b267b9c..a790839 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5328,6 +5328,13 @@ else:
         y = x.as_strided([2, 1, 5], [1, 0, 2])
         self.assertEqual(y, y.clone())

+    def test_clone_not_memory_dense(self):
+        # github issue: https://github.com/pytorch/pytorch/issues/64176
+        x = torch.randn(10, 8).t()[::2, ::2]
+        y = x.clone()
+        # should retain permutation after densification
+        self.assertTrue(y.stride() == (1, 4))
+
     @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda')))
     @dtypes(*set(torch.testing.get_all_math_dtypes('cpu')))
     def test_addcmul(self, device, dtype):
@@ -6013,9 +6020,9 @@ else:
         out_dc = torch.empty(size * size, device=device)[::2]
         for v, m in product(vals_list, mask_list):
             if m.is_contiguous():
-                expected = v[:, ::2].clone().view(-1)
+                expected = v[:, ::2].clone().reshape((-1, ))
             else:
-                expected = v[::2].clone().view(-1)
+                expected = v[::2].clone().reshape((-1, ))
             out = torch.masked_select(v, m)
             self.assertEqual(out, expected, atol=0, rtol=0)
             torch.masked_select(v, m, out=out_dc)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index fe8e36f..10aae41 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1606,15 +1606,29 @@ def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):


 def sample_inputs_mm(op_info, device, dtype, requires_grad, **kwargs):
-    args_list = (
-        ((S, M), (M, S)),
-    )
-    inputs = tuple(SampleInput(make_tensor(first_shape, device, dtype,
-                                           requires_grad=requires_grad),
-                               args=(make_tensor(second_shape, device, dtype,
-                                     requires_grad=requires_grad),))
-                   for first_shape, second_shape in args_list)
-    return inputs
+    first_shape, second_shape = (S, M), (M, S)
+    sample_inputs = []
+    sample_inputs.append(
+        SampleInput(make_tensor(first_shape, device, dtype,
+                                requires_grad=requires_grad),
+                    args=(make_tensor(second_shape, device, dtype,
+                          requires_grad=requires_grad),)))
+
+    if dtype.is_complex:
+        sample_inputs.append(
+            SampleInput(make_tensor(first_shape, device, dtype,
+                                    requires_grad=requires_grad),
+                        args=(
+                            make_tensor(second_shape, device, dtype,
+                                        requires_grad=requires_grad).conj(),)))
+
+        sample_inputs.append(
+            SampleInput(make_tensor(first_shape, device, dtype,
+                                    requires_grad=requires_grad).transpose(0, 1),
+                        args=(
+                            make_tensor(second_shape, device, dtype,
+                                        requires_grad=requires_grad).transpose(0, 1).conj(),)))
+    return sample_inputs

 def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
     alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6)
@@ -1627,15 +1641,40 @@ def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
         ((), (2, 2), (2, 3), True)
     ]
     test_cases = tests_list + tests_with_lhs_broadcasting  # type: ignore[operator]
-    inputs = tuple(SampleInput(make_tensor(shape_a, device, dtype, requires_grad=requires_grad),
-                               args=(make_tensor(shape_b, device, dtype,
-                                                 requires_grad=requires_grad),
-                                     make_tensor(shape_c, device, dtype,
-                                                 requires_grad=requires_grad)),
-                               kwargs={'alpha': alpha_val, 'beta': beta_val},
-                               broadcasts_input=broadcasts_input)
-                   for shape_a, shape_b, shape_c, broadcasts_input in test_cases)
-    return inputs
+
+    sample_inputs = []
+
+    for shape_a, shape_b, shape_c, broadcasts_input in test_cases:
+        sample_inputs.append(
+            SampleInput(
+                make_tensor(shape_a, device, dtype, requires_grad=requires_grad),
+                args=(
+                    make_tensor(shape_b, device, dtype,
+                                requires_grad=requires_grad),
+                    make_tensor(shape_c, device, dtype,
+                                requires_grad=requires_grad)),
+                kwargs={'alpha': alpha_val, 'beta': beta_val},
+                broadcasts_input=broadcasts_input))
+
+    if dtype.is_complex:
+        shape = (3, 3)
+        sample_inputs.append(
+            SampleInput(make_tensor(shape, device, dtype, requires_grad=requires_grad),
+                        args=(
+                            make_tensor(shape, device, dtype,
+                                        requires_grad=requires_grad).t().conj(),
+                            make_tensor(shape, device, dtype,
+                                        requires_grad=requires_grad)),
+                        kwargs={'alpha': alpha_val, 'beta': beta_val},))
+        sample_inputs.append(
+            SampleInput(make_tensor(shape, device, dtype, requires_grad=requires_grad),
+                        args=(
+                            make_tensor(shape, device, dtype,
+                                        requires_grad=requires_grad),
+                            make_tensor(shape, device, dtype,
+                                        requires_grad=requires_grad).t().conj()),
+                        kwargs={'alpha': alpha_val, 'beta': beta_val},))
+    return sample_inputs

 def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
     return (
@@ -1767,6 +1806,23 @@ def sample_inputs_baddbmm(op_info, device, dtype, requires_grad, **kwargs):
             sample_inputs.append(SampleInput(args[0], args=(args[1], args[2]),
                                  kwargs=dict(beta=beta * (1 + 2j), alpha=alpha * (2 + 3j)),
                                  broadcasts_input=broadcasts_input))
+
+    if dtype.is_complex:
+        shapes = [(S, S, S), (S, M, S), (S, S, M)]
+        args = (make_tensor(shapes[0], device, dtype,
+                            low=None, high=None,
+                            requires_grad=requires_grad),
+                make_tensor(shapes[1], device, dtype,
+                            low=None, high=None,
+                            requires_grad=requires_grad),
+                make_tensor(shapes[2], device, dtype,
+                            low=None, high=None,
+                            requires_grad=requires_grad))
+        sample_inputs.append(
+            SampleInput(
+                args[0].transpose(-1, 1), args=(args[1].transpose(-1, 1).conj(), args[2].transpose(-1, 1).conj()),
+                kwargs=dict(beta=beta * (1 + 2j), alpha=alpha * (2 + 3j)),))
+
     return tuple(sample_inputs)

 def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
     return (
@@ -5847,6 +5903,13 @@ op_db: List[OpInfo] = [
                                            *[torch.bfloat16] if SM53OrLater else [],
                                            torch.complex64, torch.complex128),
            supports_forward_ad=True,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
+                   'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
+                   'TestMathBits', 'test_conj_view', device_type='cuda')],
            skips=(
                # FIXME: bfloat16 backward support likely depends on CUDA11+
                # and SM53+
@@ -7045,7 +7108,6 @@ op_db: List[OpInfo] = [
            skips=(
                # matmul does not correctly warn when resizing out= inputs
                SkipInfo('TestCommon', 'test_out'),
-               SkipInfo('TestCommon', 'test_conj_view', device_type='cpu'),
            )),
     OpInfo('max',
            op=torch.max,
@@ -7835,6 +7897,10 @@ op_db: List[OpInfo] = [
            assert_autodiffed=True,
            sample_inputs_func=sample_inputs_matmul,
            supports_out=False,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
+                   'TestMathBits', 'test_conj_view')],
            skips=(
                SkipInfo('TestJit', 'test_variant_consistency_jit',),
            )),
-- 
2.7.4
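
Editor's note -- illustrative usage, not part of the patch above. This is a minimal sketch of the behavior the new fast path targets, assuming a PyTorch build that includes this change and ships lazy conjugate views (Tensor.conj() returns a view with the conj bit set rather than copying data). With the fast path, mm/addmm/bmm/baddbmm can forward a conjugated, transposed operand to BLAS as a 'c' (conjugate-transpose) operand instead of materializing the conjugation first; numerically the result must match the conj_physical() reference, which is what the added tests check.

    # Sketch only -- assumes conjugate-view support (PyTorch 1.10+ development builds).
    import torch

    a = torch.randn(3, 4, dtype=torch.cfloat)
    b = torch.randn(4, 3, dtype=torch.cfloat)

    a_conj = a.conj()            # lazy conjugate view, no data copied
    assert a_conj.is_conj()

    # Conjugated + transposed operand: eligible for the 'c' op in the BLAS call.
    fast = torch.mm(a_conj.t(), b.t())

    # Reference path: materialize the conjugation up front.
    ref = torch.mm(a.conj_physical().t(), b.t())
    assert torch.allclose(fast, ref)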