Add fast path for addmm when the inputs are conjugate (#59380)
authoranjali411 <chourdiaanjali123@gmail.com>
Wed, 1 Sep 2021 23:11:38 +0000 (16:11 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 23:34:02 +0000 (16:34 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59380

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D28898374

Pulled By: anjali411

fbshipit-source-id: eab0e64d37bb57c18b54cabb8e5c00666338ba04

aten/src/ATen/ConjugateFallback.cpp
aten/src/ATen/cuda/CUDABlas.cpp
aten/src/ATen/native/CPUBlas.cpp
aten/src/ATen/native/CPUBlas.h
aten/src/ATen/native/LinearAlgebra.cpp
aten/src/ATen/native/NegateFallback.cpp
aten/src/ATen/native/TensorFactories.cpp
aten/src/ATen/native/cuda/Blas.cpp
test/test_linalg.py
test/test_torch.py
torch/testing/_internal/common_methods_invocations.py

index a64ef49..2cf9538 100644 (file)
@@ -60,6 +60,17 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
   m.impl("vdot", torch::CppFunction::makeFallthrough());
   m.impl("dot.out", torch::CppFunction::makeFallthrough());
   m.impl("vdot.out", torch::CppFunction::makeFallthrough());
+  m.impl("alias", torch::CppFunction::makeFallthrough());
+  m.impl("mm", torch::CppFunction::makeFallthrough());
+  m.impl("mm.out", torch::CppFunction::makeFallthrough());
+  m.impl("addmm", torch::CppFunction::makeFallthrough());
+  m.impl("addmm_", torch::CppFunction::makeFallthrough());
+  m.impl("addmm.out", torch::CppFunction::makeFallthrough());
+  m.impl("bmm", torch::CppFunction::makeFallthrough());
+  m.impl("bmm.out", torch::CppFunction::makeFallthrough());
+  m.impl("baddbmm", torch::CppFunction::makeFallthrough());
+  m.impl("baddbmm_", torch::CppFunction::makeFallthrough());
+  m.impl("baddbmm.out", torch::CppFunction::makeFallthrough());
 }
 
 } // namespace at
index 75e59d0..70c3dda 100644 (file)
@@ -64,8 +64,8 @@ static void _cublasAdjustLdLevel3(
     int64_t* lda,
     int64_t* ldb,
     int64_t* ldc) {
-  bool transa_ = ((transa == 't') || (transa == 'T'));
-  bool transb_ = ((transb == 't') || (transb == 'T'));
+  bool transa_ = ((transa != 'n') && (transa != 'N'));
+  bool transb_ = ((transb != 'n') && (transb != 'N'));
 
   // Note: leading dimensions generally are checked that they are > 0
   // and at least as big the result requires (even if the value won't
index 1a1f673..f14e4dc 100644 (file)
@@ -78,7 +78,7 @@ char to_blas(TransposeType trans) {
   switch (trans) {
   case Transpose: return 't';
   case NoTranspose: return 'n';
-  // case ConjTranspose: return 'c';
+  case ConjTranspose: return 'c';
   }
   TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
 }
@@ -89,7 +89,7 @@ fbgemm::matrix_op_t to_fbgemm(TransposeType trans) {
   switch (trans) {
   case Transpose: return fbgemm::matrix_op_t::Transpose;
   case NoTranspose: return fbgemm::matrix_op_t::NoTranspose;
-  // case ConjTranspose: return fbgemm::matrix_op_t::Transpose;
+  case ConjTranspose: TORCH_INTERNAL_ASSERT(false, "ConjTranspose type is not supported in fbgemm");
   }
   TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
 }
index e61207f..3a483e4 100644 (file)
@@ -12,7 +12,7 @@ namespace cpublas {
 enum TransposeType {
   Transpose,
   NoTranspose,
-  // ConjTranspose, -- Not implemented
+  ConjTranspose,
 };
 
 namespace internal {
index 10576a0..2ae6202 100644 (file)
@@ -959,7 +959,6 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {
 static void addmm_impl_cpu_(
     Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
   TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
-
   // Array access is faster than .size(n) and .stride(n)
   const auto self_sizes = self.sizes();
   auto m1_strides = m1.strides();
@@ -992,18 +991,18 @@ static void addmm_impl_cpu_(
   if (result_strides[0] == 1 &&
       (result_sizes[1] == 1 || result_strides[1] >= std::max(int64_t{1}, result_sizes[0]))) {
     transpose_c = false;
-    c = result;
+    c = result.resolve_conj();
   } else if (result_strides[1] == 1 &&
              (result_sizes[0] == 1 || result_strides[0] >= std::max(int64_t{1}, result_sizes[1]))) {
     std::swap(m1, m2);
     std::swap(m1_sizes, m2_sizes);
     std::swap(m1_strides, m2_strides);
     transpose_c = true;
-    c = result;
+    c = result.resolve_conj();
   } else {
     transpose_c = false;
     // make c FORTRAN contiguous
-    c = result.transpose(0, 1).contiguous().transpose_(0, 1);
+    c = result.resolve_conj().transpose(0, 1).contiguous().transpose_(0, 1);
   }
 
   const int64_t m = result_sizes[transpose_c ? 1 : 0];
@@ -1017,7 +1016,7 @@ static void addmm_impl_cpu_(
   if (m1_strides[transpose_c ? 1 : 0] == 1 &&
       m1_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, m)) {
     transpose_a = false;
-    a = m1;
+    a = m1.resolve_conj();
   } else if (m1_strides[transpose_c ? 0 : 1] == 1 &&
              m1_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, k)) {
     transpose_a = true;
@@ -1034,7 +1033,7 @@ static void addmm_impl_cpu_(
   if (m2_strides[transpose_c ? 1 : 0] == 1 &&
       m2_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, k)) {
     transpose_b = false;
-    b = m2;
+    b = m2.resolve_conj();
   } else if (m2_strides[transpose_c ? 0 : 1] == 1 &&
              m2_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, n)) {
     transpose_b = true;
@@ -1048,13 +1047,16 @@ static void addmm_impl_cpu_(
   const int64_t ldb = b.strides()[(transpose_b == transpose_c) ? 1 : 0];
   const int64_t ldc = c.strides()[transpose_c ? 0 : 1];
 
+  // Always ensure the conjugation for c is resolved since there's no way to specify c's conjugation in the gemm call
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c.is_conj());
+
   // Apply BLAS routine
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
       result.scalar_type(), "addmm_impl_cpu_",
       [&]{
         at::native::cpublas::gemm(
-            transpose_a ? cpublas::Transpose : cpublas::NoTranspose,
-            transpose_b ? cpublas::Transpose : cpublas::NoTranspose,
+            transpose_a ? a.is_conj() ? cpublas::ConjTranspose : cpublas::Transpose : cpublas::NoTranspose,
+            transpose_b ? b.is_conj() ? cpublas::ConjTranspose : cpublas::Transpose : cpublas::NoTranspose,
             m, n, k,
             alpha.to<scalar_t>(),
             a.data_ptr<scalar_t>(), lda,
@@ -1349,8 +1351,18 @@ Tensor& baddbmm_out_cpu(const Tensor& self_, const Tensor& batch1, const Tensor&
   return at::native::baddbmm__cpu(result, batch1, batch2, beta, alpha);
 }
 
+Tensor& conjugate_mutable_input_if_needed(Tensor& self, bool conjugate) {
+  if (conjugate) {
+    self.conj_physical_();
+  }
+  return self;
+}
+
 Tensor& baddbmm__cpu(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
-  return bmm_out_or_baddbmm_(self, batch1, batch2, beta, alpha, false);
+  bool self_is_conj = self.is_conj();
+  conjugate_mutable_input_if_needed(self, self_is_conj);
+  bmm_out_or_baddbmm_(self, batch1.resolve_conj(), batch2.resolve_conj(), beta, alpha, false);
+  return conjugate_mutable_input_if_needed(self, self_is_conj);
 }
 
 Tensor bmm_cpu(const Tensor& self, const Tensor& mat2) {
@@ -1363,7 +1375,10 @@ Tensor& bmm_out_cpu(const Tensor& batch1, const Tensor& batch2, Tensor &result)
   Scalar alpha(1.0);
   {
   NoNamesGuard guard;
-  bmm_out_or_baddbmm_(result, batch1, batch2, beta, alpha, true);
+  bool result_is_conj = result.is_conj();
+  conjugate_mutable_input_if_needed(result, result_is_conj);
+  bmm_out_or_baddbmm_(result, batch1.resolve_conj(), batch2.resolve_conj(), beta, alpha, true);
+  conjugate_mutable_input_if_needed(result, result_is_conj);
   }
   namedinference::propagate_names_if_nonempty(
       result,
index 86dbe05..d8381f5 100644 (file)
@@ -55,6 +55,7 @@ TORCH_LIBRARY_IMPL(aten, Negative, m) {
   m.impl("view", torch::CppFunction::makeFallthrough());
   m.impl("_unsafe_view", torch::CppFunction::makeFallthrough());
   m.impl("reshape", torch::CppFunction::makeFallthrough());
+  m.impl("alias", torch::CppFunction::makeFallthrough());
 }
 
 } // namespace at
index 3ee909b..4712c3d 100644 (file)
@@ -1411,17 +1411,18 @@ Tensor from_file(c10::string_view filename, c10::optional<bool> shared, c10::opt
 Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format) {
   auto memory_format =
       optional_memory_format.value_or(MemoryFormat::Preserve);
+  Tensor self;
   if (memory_format == MemoryFormat::Preserve) {
     if (src.is_non_overlapping_and_dense()) {
-      // Copy all strides
-      auto self = at::empty_strided(src.sizes(), src.strides(), src.options());
-      self.copy_(src);
-      return self;
+      // Copy all strides, this is marginally faster than calling empty_like
+      self = at::empty_strided(src.sizes(), src.strides(), src.options());
     } else {
-      memory_format = src.suggest_memory_format();
+      self = at::empty_like(src);
     }
+  } else {
+    self = at::empty_like(src, src.options(), memory_format);
   }
-  auto self = at::empty_like(src, src.options(), memory_format);
+
   self.copy_(src);
   return self;
 }
index b447910..269307d 100644 (file)
@@ -4,24 +4,51 @@
 #include <ATen/native/Resize.h>
 #include <c10/util/MaybeOwned.h>
 
-
 namespace at { namespace native {
 
 namespace {
 
+// TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492
+c10::MaybeOwned<Tensor> inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) {
+  if (resolve_conj && tensor.is_conj()) {
+    return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
+  } else {
+    return c10::MaybeOwned<Tensor>::borrowed(tensor);
+  }
+}
+
+c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) {
+  if (tensor.is_non_overlapping_and_dense()) { // common case
+      transpose_tensor = tensor.is_contiguous();
+      return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor);
+  }
+  IntArrayRef tensor_strides = tensor.strides();
+  IntArrayRef tensor_sizes = tensor.sizes();
+  if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
+    transpose_tensor = false;
+    return resolve_conj_if_indicated(tensor, !transpose_result);
+  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
+    transpose_tensor = true;
+    return resolve_conj_if_indicated(tensor, transpose_result);
+  } else {
+    transpose_tensor = true;
+    return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
+  }
+}
+
 c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) {
   if (tensor.is_non_overlapping_and_dense()) { // common case
       transpose_tensor = tensor.is_contiguous();
-      return c10::MaybeOwned<Tensor>::borrowed(tensor);
+      return resolve_conj_if_indicated(tensor, true);
   }
   IntArrayRef tensor_strides = tensor.strides();
   IntArrayRef tensor_sizes = tensor.sizes();
   if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
     transpose_tensor = false;
-    return c10::MaybeOwned<Tensor>::borrowed(tensor);
+    return resolve_conj_if_indicated(tensor, true);
   } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
     transpose_tensor = true;
-    return c10::MaybeOwned<Tensor>::borrowed(tensor);
+    return resolve_conj_if_indicated(tensor, true);
   } else {
     transpose_tensor = true;
     return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
@@ -39,19 +66,19 @@ c10::MaybeOwned<Tensor> prepare_batch_matrix_for_cublas(const Tensor& tensor, bo
   if (tensor_strides[fast_dim] == 1 &&
     (tensor_strides[leading_dim] >= std::max<int64_t>(1, m))) {
     transpose_tensor = false;
-    tensor_ = c10::MaybeOwned<Tensor>::borrowed(tensor);
-    ld_tensor = tensor_strides[leading_dim];
+    tensor_ = resolve_conj_if_indicated(tensor, true);
+    ld_tensor = tensor_->strides()[leading_dim];
   } else if ((tensor_strides[leading_dim] == 1) &&
     (tensor_strides[fast_dim] >= std::max<int64_t>(1, n))) {
     transpose_tensor = true;
-    tensor_ = c10::MaybeOwned<Tensor>::borrowed(tensor);
-    ld_tensor = tensor_strides[fast_dim];
+    tensor_ = resolve_conj_if_indicated(tensor, false);
+    ld_tensor = tensor_->strides()[fast_dim];
   } else {
     transpose_tensor = !transpose_result;
     // gemm call requires leading dimension and stride parameters to be non-zero
     bool is_stride_non_zero = tensor.strides()[1] != 0 && tensor.strides()[2] != 0;
     if (tensor.is_contiguous() && is_stride_non_zero) {
-      tensor_ = c10::MaybeOwned<Tensor>::borrowed(tensor);
+      tensor_ = resolve_conj_if_indicated(tensor, transpose_result);
     } else {
       tensor_ = c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
     }
@@ -104,8 +131,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   c10::MaybeOwned<Tensor> result_ = prepare_matrix_for_cublas(result, transpose_result);
   bool transpose_mat1;
   bool transpose_mat2;
-  c10::MaybeOwned<Tensor> mat1_ = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1);
-  c10::MaybeOwned<Tensor> mat2_ = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2);
+  auto mat1_ = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result);
+  auto mat2_ = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result);
 
   if (transpose_result) {
     transpose_mat1 = !transpose_mat1;
@@ -141,6 +168,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
             c10::nullopt /* pin_memory */));
   }
 
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
+
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "addmm_cuda", [&] {
     scalar_t alpha_val = alpha.to<scalar_t>();
     scalar_t beta_val = beta.to<scalar_t>();
@@ -148,8 +177,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
     scalar_t* mat2_ptr = mat2_->data_ptr<scalar_t>();
     scalar_t* result_ptr = result_->data_ptr<scalar_t>();
     at::cuda::blas::gemm<scalar_t>(
-      transpose_mat1 ? 't' : 'n',
-      transpose_mat2 ? 't' : 'n',
+      transpose_mat1 ? mat1_->is_conj() ? 'c' : 't' : 'n',
+      transpose_mat2 ? mat2_->is_conj() ? 'c' : 't' : 'n',
       m, n, k,
       alpha_val,
       mat1_ptr, mat1_ld,
@@ -207,11 +236,11 @@ Tensor& baddbmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor&
 
   if ((result_strides[1] == 1) &&
       ((result_sizes[2] == 1) || (result_strides[2] >= std::max<int64_t>(1, result_sizes[1])))) {
-    result_ = c10::MaybeOwned<Tensor>::borrowed(result);
+    result_ = resolve_conj_if_indicated(result, true);
   } else if ((result_strides[2] == 1) &&
     (result_sizes[1] == 1 || (result_strides[1] >= std::max<int64_t>(1, result_sizes[2])))) {
     transpose_result = true;
-    result_ = c10::MaybeOwned<Tensor>::borrowed(result);
+    result_ = resolve_conj_if_indicated(result, true);
   } else {
     result_ = c10::MaybeOwned<Tensor>::owned(result.transpose(1, 2).clone(at::MemoryFormat::Contiguous).transpose(1, 2));
   }
@@ -230,6 +259,8 @@ Tensor& baddbmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor&
   ldc = result_->strides()[leading_dim];
   int64_t num_batches = result_->sizes()[0];
 
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
+
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddbmm_cuda", [&] {
     scalar_t alpha_val = alpha.to<scalar_t>();
     scalar_t beta_val = beta.to<scalar_t>();
@@ -237,8 +268,8 @@ Tensor& baddbmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor&
     scalar_t* batch2_ptr = batch2_->data_ptr<scalar_t>();
     scalar_t* result_ptr = result_->data_ptr<scalar_t>();
     at::cuda::blas::bgemm<scalar_t>(
-      transpose_batch1 ? 't' : 'n',
-      transpose_batch2 ? 't' : 'n',
+      transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n',
+      transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n',
       m, n, k,
       alpha_val,
       batch1_ptr, lda, batch1_->strides()[0],
index f7ce392..fbd219b 100644 (file)
@@ -6166,6 +6166,38 @@ scipy_lobpcg  | {:10.2e}  | {:10.2e}  | {:6} | N/A
             _test_mm(n, m, p, dtype, genf)
 
     @onlyOnCPUAndCUDA
+    def test_mm_bmm_non_memory_dense(self, device):
+        def _slice(tensor, fn):
+            return fn(tensor)[..., ::2]
+        A = torch.randn(3, 6, dtype=torch.cfloat, device=device)
+        B = torch.randn(3, 3, dtype=torch.cfloat, device=device)
+        out = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
+        out1 = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
+        A_conj = _slice(A, torch.conj)
+        A_conj_physical = _slice(A, torch.conj_physical)
+
+        self.assertEqual(torch.mm(A_conj, B, out=out), torch.mm(A_conj_physical, B, out=out))
+        self.assertEqual(torch.mm(A_conj.t(), B, out=out), torch.mm(A_conj_physical.t(), B, out=out))
+
+        Ab = torch.randn(2, 3, 6, dtype=torch.cfloat, device=device)
+        Bb = torch.randn(2, 3, 3, dtype=torch.cfloat, device=device)
+        Bb_ = torch.randn(1, 3, 3, dtype=torch.cfloat, device=device).expand(2, 3, 3)
+        out_b = torch.empty(2, 3, 3, device=device, dtype=torch.complex64).transpose(-1, -2)
+
+        Ab_conj = _slice(Ab, torch.conj)
+        Ab_conj_physical = _slice(Ab, torch.conj_physical)
+
+        def t_b(tensor):
+            return tensor.transpose(-1, -2)
+
+        self.assertEqual(torch.bmm(Ab_conj, Bb, out=out_b), torch.bmm(Ab_conj_physical, Bb, out=out_b))
+        self.assertEqual(torch.bmm(t_b(Ab_conj), Bb, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb, out=out_b))
+
+        # test broadcasting
+        self.assertEqual(torch.bmm(Ab_conj, Bb_, out=out_b), torch.bmm(Ab_conj_physical, Bb_, out=out_b))
+        self.assertEqual(torch.bmm(t_b(Ab_conj), Bb_, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb_, out=out_b))
+
+    @onlyOnCPUAndCUDA
     @dtypes(torch.float32, torch.float64)
     def test_strided_mm_bmm(self, device, dtype):
         # Tests strided view case with stride smaller than corresponding dimension size
index b267b9c..a790839 100644 (file)
@@ -5328,6 +5328,13 @@ else:
         y = x.as_strided([2, 1, 5], [1, 0, 2])
         self.assertEqual(y, y.clone())
 
+    def test_clone_not_memory_dense(self):
+        # github issue: https://github.com/pytorch/pytorch/issues/64176
+        x = torch.randn(10, 8).t()[::2, ::2]
+        y = x.clone()
+        # should retain permutation after densification
+        self.assertTrue(y.stride() == (1, 4))
+
     @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda')))
     @dtypes(*set(torch.testing.get_all_math_dtypes('cpu')))
     def test_addcmul(self, device, dtype):
@@ -6013,9 +6020,9 @@ else:
             out_dc = torch.empty(size * size, device=device)[::2]
             for v, m in product(vals_list, mask_list):
                 if m.is_contiguous():
-                    expected = v[:, ::2].clone().view(-1)
+                    expected = v[:, ::2].clone().reshape((-1, ))
                 else:
-                    expected = v[::2].clone().view(-1)
+                    expected = v[::2].clone().reshape((-1, ))
                 out = torch.masked_select(v, m)
                 self.assertEqual(out, expected, atol=0, rtol=0)
                 torch.masked_select(v, m, out=out_dc)
index fe8e36f..10aae41 100644 (file)
@@ -1606,15 +1606,29 @@ def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
 
 
 def sample_inputs_mm(op_info, device, dtype, requires_grad, **kwargs):
-    args_list = (
-        ((S, M), (M, S)),
-    )
-    inputs = tuple(SampleInput(make_tensor(first_shape, device, dtype,
-                                           requires_grad=requires_grad),
-                               args=(make_tensor(second_shape, device, dtype,
-                                     requires_grad=requires_grad),))
-                   for first_shape, second_shape in args_list)
-    return inputs
+    first_shape, second_shape = (S, M), (M, S)
+    sample_inputs = []
+    sample_inputs.append(
+        SampleInput(make_tensor(first_shape, device, dtype,
+                                requires_grad=requires_grad),
+                    args=(make_tensor(second_shape, device, dtype,
+                                      requires_grad=requires_grad),)))
+
+    if dtype.is_complex:
+        sample_inputs.append(
+            SampleInput(make_tensor(first_shape, device, dtype,
+                                    requires_grad=requires_grad),
+                        args=(
+                            make_tensor(second_shape, device, dtype,
+                                        requires_grad=requires_grad).conj(),)))
+
+        sample_inputs.append(
+            SampleInput(make_tensor(first_shape, device, dtype,
+                                    requires_grad=requires_grad).transpose(0, 1),
+                        args=(
+                            make_tensor(second_shape, device, dtype,
+                                        requires_grad=requires_grad).transpose(0, 1).conj(),)))
+    return sample_inputs
 
 def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
     alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6)
@@ -1627,15 +1641,40 @@ def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
         ((), (2, 2), (2, 3), True)
     ]
     test_cases = tests_list + tests_with_lhs_broadcasting  # type: ignore[operator]
-    inputs = tuple(SampleInput(make_tensor(shape_a, device, dtype, requires_grad=requires_grad),
-                               args=(make_tensor(shape_b, device, dtype,
-                                                 requires_grad=requires_grad),
-                                     make_tensor(shape_c, device, dtype,
-                                                 requires_grad=requires_grad)),
-                               kwargs={'alpha': alpha_val, 'beta': beta_val},
-                               broadcasts_input=broadcasts_input)
-                   for shape_a, shape_b, shape_c, broadcasts_input in test_cases)
-    return inputs
+
+    sample_inputs = []
+
+    for shape_a, shape_b, shape_c, broadcasts_input in test_cases:
+        sample_inputs.append(
+            SampleInput(
+                make_tensor(shape_a, device, dtype, requires_grad=requires_grad),
+                args=(
+                    make_tensor(shape_b, device, dtype,
+                                requires_grad=requires_grad),
+                    make_tensor(shape_c, device, dtype,
+                                requires_grad=requires_grad)),
+                kwargs={'alpha': alpha_val, 'beta': beta_val},
+                broadcasts_input=broadcasts_input))
+
+    if dtype.is_complex:
+        shape = (3, 3)
+        sample_inputs.append(
+            SampleInput(make_tensor(shape, device, dtype, requires_grad=requires_grad),
+                        args=(
+                            make_tensor(shape, device, dtype,
+                                        requires_grad=requires_grad).t().conj(),
+                            make_tensor(shape, device, dtype,
+                                        requires_grad=requires_grad)),
+                        kwargs={'alpha': alpha_val, 'beta': beta_val},))
+        sample_inputs.append(
+            SampleInput(make_tensor(shape, device, dtype, requires_grad=requires_grad),
+                        args=(
+                            make_tensor(shape, device, dtype,
+                                        requires_grad=requires_grad),
+                            make_tensor(shape, device, dtype,
+                                        requires_grad=requires_grad).t().conj()),
+                        kwargs={'alpha': alpha_val, 'beta': beta_val},))
+    return sample_inputs
 
 def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
     return (
@@ -1767,6 +1806,23 @@ def sample_inputs_baddbmm(op_info, device, dtype, requires_grad, **kwargs):
             sample_inputs.append(SampleInput(args[0], args=(args[1], args[2]),
                                              kwargs=dict(beta=beta * (1 + 2j), alpha=alpha * (2 + 3j)),
                                              broadcasts_input=broadcasts_input))
+
+    if dtype.is_complex:
+        shapes = [(S, S, S), (S, M, S), (S, S, M)]
+        args = (make_tensor(shapes[0], device, dtype,
+                            low=None, high=None,
+                            requires_grad=requires_grad),
+                make_tensor(shapes[1], device, dtype,
+                            low=None, high=None,
+                            requires_grad=requires_grad),
+                make_tensor(shapes[2], device, dtype,
+                            low=None, high=None,
+                            requires_grad=requires_grad))
+        sample_inputs.append(
+            SampleInput(
+                args[0].transpose(-1, 1), args=(args[1].transpose(-1, 1).conj(), args[2].transpose(-1, 1).conj()),
+                kwargs=dict(beta=beta * (1 + 2j), alpha=alpha * (2 + 3j)),))
+
     return tuple(sample_inputs)
 
 def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
@@ -5847,6 +5903,13 @@ op_db: List[OpInfo] = [
                                                     *[torch.bfloat16] if SM53OrLater else [],
                                                     torch.complex64, torch.complex128),
            supports_forward_ad=True,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
+                   'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
+                   'TestMathBits', 'test_conj_view', device_type='cuda')],
            skips=(
                # FIXME: bfloat16 backward support likely depends on CUDA11+
                #   and SM53+
@@ -7045,7 +7108,6 @@ op_db: List[OpInfo] = [
            skips=(
                # matmul does not correctly warn when resizing out= inputs
                SkipInfo('TestCommon', 'test_out'),
-               SkipInfo('TestCommon', 'test_conj_view', device_type='cpu'),
            )),
     OpInfo('max',
            op=torch.max,
@@ -7835,6 +7897,10 @@ op_db: List[OpInfo] = [
            assert_autodiffed=True,
            sample_inputs_func=sample_inputs_matmul,
            supports_out=False,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
+                   'TestMathBits', 'test_conj_view')],
            skips=(
                SkipInfo('TestJit', 'test_variant_consistency_jit',),
            )),