From 752a8202303089386a3973dee753dc78d2a657e2 Mon Sep 17 00:00:00 2001
From: "haozhe.zhu"
Date: Fri, 17 Sep 2021 09:52:47 -0700
Subject: [PATCH] Bf16 matmul (#64619)

Summary:
Re-create PR to fix https://github.com/pytorch/pytorch/pull/61891. Drop the support for addbmm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64619

Reviewed By: jbschlosser

Differential Revision: D30902995

Pulled By: VitalyFedyunin

fbshipit-source-id: dc318d73adff8f6974c9752d0d097e69276f8206
---
 aten/src/ATen/native/Blas.cpp                |  19 ++-
 aten/src/ATen/native/LinearAlgebra.cpp       |  12 +-
 aten/src/ATen/native/mkldnn/Matmul.cpp       | 146 +++++++++++++++++++++
 aten/src/ATen/native/mkldnn/Matmul.h         |  23 ++++
 aten/src/ATen/native/mkldnn/Utils.h          |   1 +
 test/test_linalg.py                          |  72 +++++-----
 tools/build_variables.bzl                    |   1 +
 .../_internal/common_methods_invocations.py  |   8 +-
 8 files changed, 244 insertions(+), 38 deletions(-)
 create mode 100644 aten/src/ATen/native/mkldnn/Matmul.cpp
 create mode 100644 aten/src/ATen/native/mkldnn/Matmul.h

diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp
index 114de63..c18c657 100644
--- a/aten/src/ATen/native/Blas.cpp
+++ b/aten/src/ATen/native/Blas.cpp
@@ -3,6 +3,9 @@
 #include
 #include
 #include
+#include
+
+#include
 
 namespace at {
 namespace meta {
@@ -62,6 +65,13 @@ TORCH_IMPL_FUNC(addmv_out_cpu)(const Tensor &self, const Tensor &mat, const Tens
     at::native::copy_(const_cast(result), *self_);
   }
   if (result.numel() != 0) {
+
+    NoNamesGuard guard;
+    if (use_mkldnn_bf16_matmul(mat, vec, /*result=*/Tensor())){
+      mkldnn_matmul(mat, vec, result, beta_.to(), alpha_.to());
+      return;
+    }
+
     auto r_stride = result.stride(0);
     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, mat.scalar_type(), "addmv_impl_cpu", [&] {
       auto beta = beta_.to();
@@ -148,7 +158,14 @@ Tensor dot(const Tensor &self, const Tensor &other){
   at::NoNamesGuard guard;
   dot_check(self, other);
 
-  return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, self.scalar_type(), "dot", [&] {
+  if (use_mkldnn_bf16_matmul(self, other, /*result=*/Tensor())){
+    // mkldnn matmul expect result have sizes info to create ideep tensor
+    auto r = at::empty({1, 1}, self.options());
+    mkldnn_matmul(self, other, r, /*beta=*/0);
+    return r;
+  }
+
+  return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "dot", [&] {
     Tensor result = at::empty({}, self.options());
     result.fill_(dot_impl(self.numel(), self.data_ptr(), self.stride(0), other.data_ptr(), other.stride(0)));
     return result;
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 0014826..c8c34eb 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -23,7 +24,6 @@
 #include
 #include
 
-
 namespace at {
 namespace meta {
 TORCH_META_FUNC(addmm)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
@@ -984,6 +984,11 @@ static void addmm_impl_cpu_(
     result.copy_(self);
   }
 
+  if (use_mkldnn_bf16_matmul(m1, m2, result)){
+    mkldnn_matmul(m1, m2, result, beta.to(), alpha.to());
+    return;
+  }
+
   bool transpose_c = false;
   Tensor c;
 
@@ -1254,6 +1259,11 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor&
         || (strides[1] == 1 && strides[2] >= sizes[1]);
   };
 
+  if (use_mkldnn_bf16_matmul(batch1, batch2, self_or_result)){
+    mkldnn_matmul(batch1, batch2, self_or_result, beta.to(), alpha.to());
+    return self_or_result;
+  }
+
   if (contraction_size * res_rows * res_cols < 400) {
     if (is_bmm_out) {
       AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, batch1.scalar_type(), "bmm", [&] {
diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp
new file mode 100644
index 0000000..69568a0
--- /dev/null
+++ b/aten/src/ATen/native/mkldnn/Matmul.cpp
@@ -0,0 +1,146 @@
+#include
+#include
+#include
+#include
+#if !AT_MKLDNN_ENABLED()
+
+namespace at {
+namespace native {
+
+void mkldnn_matmul(
+    const Tensor &mat1,
+    const Tensor &mat2,
+    const Tensor &result,
+    float beta,
+    float alpha) {
+  TORCH_CHECK(false, "mkldnn_matmul: ATen not compiled with MKLDNN support");
+}
+
+bool use_mkldnn_bf16_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const c10::optional& result_opt){
+  return false;
+}
+
+} // namespace native
+} // namespace at
+
+#else // AT_MKLDNN_ENABLED
+
+#include
+#include
+
+namespace at {
+namespace native {
+
+void mkldnn_matmul(
+    const Tensor &mat1,
+    const Tensor &mat2,
+    const Tensor &result,
+    float beta,
+    float alpha) {
+  TORCH_CHECK((mat1.dim() == 2 && mat2.dim() == 2) ||  // aten::addmm
+              (mat1.dim() == 3 && mat2.dim() == 3) ||  // aten::bmm, aten::baddbmm
+              (mat1.dim() == 2 && mat2.dim() == 1) ||  // aten::mv
+              (mat1.dim() == 1 && mat2.dim() == 1),    // aten::dot
+              "mkldnn_matmul: unsupported dims for mat and mat2");
+  TORCH_CHECK(mat1.scalar_type() == at::kBFloat16 &&
+              mat2.scalar_type() == at::kBFloat16 &&
+              result.scalar_type() == at::kBFloat16, "mkldnn_matmul: only enabled for bf16 path");
+  TORCH_CHECK(mkldnn_bf16_device_check(),
+    "mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
+
+  auto mat1_unsqueezed = mat1.dim() == 1 ? mat1.unsqueeze(0) : mat1;
+  auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2;
+  auto result_unsqueezed = result.dim() == 1 ? result.unsqueeze(1) : result;
+
+  ideep::attr_t op_attr;
+  // "addmm", "addbmm" "baddbmm" in pytorch allow bias to be 2-D or 3-D tensor
+  // but mkldnn matmul primitive only support bias be 1-D tensors
+  // to address their differences, we use mkldnn post ops to perform a fused "add" after matrix multiplication is over
+  if (beta != 0.0f) op_attr = ideep::attr_t::fuse_sum();
+  // If alpha = 0, we do not need to actually do the gemm computation
+  if (alpha == 0)
+    return;
+
+  auto is_mkldnn_optimized_format = [&](const Tensor& t) {
+    if (t.is_contiguous()) return true;
+    const auto sizes = t.sizes();
+    const auto strides = t.strides();
+    if (t.dim() == 2){
+      return strides[0] == 1 && strides[1] == sizes[0];
+    } else {
+      // dim = 3
+      return strides[0] == sizes[1] * sizes[2] && strides[1] == 1 && strides[2] == sizes[1];
+    }
+  };
+
+  // Mkldnn only optimized for contiguous or transposed (transpose last 2 dim if 3-D tensor) format now
+  // Will remove this "contiguous" after mkldnn have fully supported
+  Tensor mat1_ = is_mkldnn_optimized_format(mat1_unsqueezed) ? mat1_unsqueezed : mat1_unsqueezed.contiguous();
+  Tensor mat2_ = is_mkldnn_optimized_format(mat2_unsqueezed) ? mat2_unsqueezed : mat2_unsqueezed.contiguous();
+
+  // mkldnn_matmul only proceed CPU tensor
+  const ideep::tensor x = itensor_view_from_dense(mat1_);
+  const ideep::tensor w = itensor_view_from_dense(mat2_);
+  ideep::tensor y = itensor_view_from_dense(result_unsqueezed);
+  ideep::matmul_forward::compute(x, w, y, alpha, beta,
+      ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);
+  if (y.get_data_handle() != result.data_ptr()){
+    // ideep will query onednn expect format of output
+    // if given output format is not expected, ideep will re-init an output buffer
+    // under this case, we need copy the re-inited buffer back to given buffer
+    ideep::tensor public_y = itensor_view_from_dense(result);
+    y.reorder_to(public_y);
+  }
+
+  if (mat1.dim() == 1 && mat2.dim() == 1){
+    // aten::dot
+    result.squeeze_();
+  }
+
+}
+
+inline bool checksize(const Tensor& mat1, const Tensor& mat2){
+  // if dim = 2, mat1's size = (m * n), mat2's size = (n * k)
+  // else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
+  // else called from aten::mv, mat1.size = (m * n), mat2.size = (n)
+  // only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel
+  static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
+  if (mat1.dim() == 1 && mat2.dim() == 1) {
+    // aten::dot
+    return mat1.size(0) > mkldnn_gemm_min_size;
+  } else if (mat1.dim() == 2 && mat2.dim() == 1) {
+    // aten::mv
+    return mat1.size(0) * mat1.size(1) > mkldnn_gemm_min_size;
+  } else if (mat1.dim() == 2 && mat2.dim() == 2) {
+    // aten::addmm
+    return mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_gemm_min_size;
+  } else {
+    // aten::bmm, aten::baddbmm
+    return mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) > mkldnn_gemm_min_size;
+  }
+}
+
+bool use_mkldnn_bf16_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const c10::optional& result_opt) {
+  c10::MaybeOwned result_maybe_owned = at::borrow_from_optional_tensor(result_opt);
+  const Tensor& result = *result_maybe_owned;
+  return (
+    at::globalContext().userEnabledMkldnn() &&
+    mat1.scalar_type() == kBFloat16 &&
+    mat2.scalar_type() == kBFloat16 &&
+    (!result.defined() || result.scalar_type() == kBFloat16) &&
+    mat1.numel() != 0 &&
+    mat2.numel() != 0 &&
+    mkldnn_bf16_device_check() &&
+    checksize(mat1, mat2));
+}
+
+} // namespace native
+} // namespace at
+
+#endif // AT_MKLDNN_ENABLED
diff --git a/aten/src/ATen/native/mkldnn/Matmul.h b/aten/src/ATen/native/mkldnn/Matmul.h
new file mode 100644
index 0000000..f19365c
--- /dev/null
+++ b/aten/src/ATen/native/mkldnn/Matmul.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include
+#include
+
+namespace at { namespace native {
+
+// result = beta * result + alpha * gemm(mat1, mat2)
+TORCH_API void mkldnn_matmul(
+    const Tensor &mat1,
+    const Tensor &mat2,
+    const Tensor &result,
+    float beta=1,
+    float alpha=1);
+
+bool use_mkldnn_bf16_matmul(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const c10::optional& result_opt);
+
+}
+
+}
diff --git a/aten/src/ATen/native/mkldnn/Utils.h b/aten/src/ATen/native/mkldnn/Utils.h
index abfafd5..a27b842 100644
--- a/aten/src/ATen/native/mkldnn/Utils.h
+++ b/aten/src/ATen/native/mkldnn/Utils.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include
 #include
 #include
 #include
diff --git a/test/test_linalg.py b/test/test_linalg.py
index f64e081..08fc535 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -3991,8 +3991,14 @@ class TestLinalg(TestCase):
         def check(x, y):
             # Compare with numpy
            res = torch_fn(x, y)
-            ref = torch.from_numpy(np.array(np_fn(x.cpu().numpy(), y.cpu().numpy())))
-            self.assertEqual(res.cpu(), ref)
+            if x.dtype == torch.bfloat16:
+                ref = torch.from_numpy(np.array(np_fn(x.cpu().float().numpy(), y.cpu().float().numpy())))
+            else:
+                ref = torch.from_numpy(np.array(np_fn(x.cpu().numpy(), y.cpu().numpy())))
+            if res.dtype == torch.bfloat16:
+                self.assertEqual(res.cpu(), ref.bfloat16())
+            else:
+                self.assertEqual(res.cpu(), ref)
 
             # Test out variant
             out = torch.empty_like(res)
@@ -4005,19 +4011,20 @@ class TestLinalg(TestCase):
             check(x, y)
 
         # Contiguous
-        x = torch.randn(10, dtype=dtype, device=device)
-        y = torch.randn(10, dtype=dtype, device=device)
+        x = 0.1 * torch.randn(5000, dtype=dtype, device=device)
+        y = 0.1 * torch.randn(5000, dtype=dtype, device=device)
         check(x, y)
 
         # 0 strided
-        y = torch.randn(1, dtype=dtype, device=device).expand(10)
+        y = 0.1 * torch.randn(1, dtype=dtype, device=device).expand(5000)
         check(x, y)
 
         # 2 strided
         check(x[::2], y[::2])
 
-    @dtypes(torch.float, torch.cfloat)
-    @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
+    @dtypes(torch.float, torch.cfloat, torch.bfloat16)
+    @dtypesIfCUDA(torch.float, torch.cfloat)
+    @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5, torch.bfloat16: 1e-0})
     def test_dot_vs_numpy(self, device, dtype):
         self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot)
 
@@ -5951,31 +5958,32 @@ scipy_lobpcg  | {:10.2e}  | {:10.2e}  | {:6} | N/A
         # have to use torch.randn(...).to(bfloat16) instead of
         # torch.randn(..., dtype=bfloat16). randn does not support
         # bfloat16 yet.
+        # "*0.2" to reduce errors for low precision
         ts = [
-            torch.randn(10, device=device).to(dtype),
-            torch.randn(1, device=device).to(dtype).expand(10),
+            0.2 * torch.randn(50, device=device).to(dtype),
+            0.2 * torch.randn(1, device=device).to(dtype).expand(50),
         ]
         vs = [
-            torch.randn(100, device=device).to(dtype),
-            torch.ones(1, device=device).to(dtype).expand(100),  # to reduce errors for low precision
+            0.2 * torch.randn(100, device=device).to(dtype),
+            0.2 * torch.ones(1, device=device).to(dtype).expand(100),  # to reduce errors for low precision
         ]
         ms = [
             # 0d
-            torch.ones((), device=device).to(dtype).expand(10, 100),  # to reduce errors for low precision
+            0.2 * torch.ones((), device=device).to(dtype).expand(50, 100),  # to reduce errors for low precision
             # 1d
-            torch.randn((1, 100), device=device).to(dtype).expand(10, 100),
+            0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100),
            # this initialization reduces errors for low precision for broadcasted matrices
            # by making sure that intermediate and result values are exactly representable
            # in low precision type
-            torch.randint(3, (10, 1), dtype=torch.float, device=device).to(dtype).expand(10, 100),
+            0.2 * torch.randint(3, (50, 1), dtype=torch.float, device=device).to(dtype).expand(50, 100),
             # 2d
-            torch.randn((10, 100), device=device).to(dtype),
-            torch.randn((100, 10), device=device).to(dtype).t(),
+            0.2 * torch.randn((50, 100), device=device).to(dtype),
+            0.2 * torch.randn((100, 50), device=device).to(dtype).t(),
         ]
 
         for m, v, t in itertools.product(ms, vs, ts):
             self._test_addmm_addmv(torch.addmv, t, m, v)
 
         # Test beta=0, t=nan
-        t = torch.full((10,), math.nan, device=device).to(dtype)
+        t = torch.full((50,), math.nan, device=device).to(dtype)
         for m, v in itertools.product(ms, vs):
             self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)
@@ -6153,12 +6161,12 @@ scipy_lobpcg  | {:10.2e}  | {:10.2e}  | {:6} | N/A
             return torch.randint(0, 100, (x, y), dtype=dtype, device=device)
 
         def genf_bfloat(x, y):
-            return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype)
+            return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1
 
         def genf_float(x, y):
             return torch.randn(x, y, dtype=dtype, device=device)
 
-        for (n, m, p) in [(20, 10, 5), (15, 5, 10), (5, 18, 10)]:
+        for (n, m, p) in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]:
             if (dtype == torch.int32) or (dtype == torch.int64):
                 genf = genf_int
             elif (dtype == torch.bfloat16):
                 genf = genf_bfloat
@@ -6229,7 +6237,7 @@ scipy_lobpcg  | {:10.2e}  | {:10.2e}  | {:6} | N/A
             return
 
         batch_sizes = [1, 10]
-        M, N, O = 23, 8, 12
+        M, N, O = 23, 15, 12
         numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
 
         is_supported = True
@@ -6251,8 +6259,8 @@ scipy_lobpcg  | {:10.2e}  | {:10.2e}  | {:6} | N/A
         def generate_inputs(num_batches):
             # transposed tensors
             for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
-                b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1)
-                b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1)
+                b1 = make_tensor((num_batches, M, N), device, dtype, low=-0.1, high=0.1)
+                b2 = make_tensor((num_batches, N, O), device, dtype, low=-0.1, high=0.1)
                 b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                 b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                 yield b1, b2
@@ -6260,8 +6268,8 @@ scipy_lobpcg  | {:10.2e}  | {:10.2e}  | {:6} | N/A
             for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
                 shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
                 shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
-                b1 = make_tensor(shape1, device, dtype, low=-1, high=1).expand(num_batches, M, N)
-                b2 = make_tensor(shape2, device, dtype, low=-1, high=1).expand(num_batches, N, O)
+                b1 = make_tensor(shape1, device, dtype, low=-0.1, high=0.1).expand(num_batches, M, N)
+                b2 = make_tensor(shape2, device, dtype, low=-0.1, high=0.1).expand(num_batches, N, O)
                 yield b1, b2
             # zero-sized tensors
             for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
@@ -6341,7 +6349,7 @@ scipy_lobpcg  | {:10.2e}  | {:10.2e}  | {:6} | N/A
             return
 
         num_batches = 2
-        M, N, O = 2, 3, 4
+        M, N, O = 16, 17, 18
 
         is_supported = True
         if dtype == torch.bfloat16:
@@ -6367,8 +6375,8 @@ scipy_lobpcg  | {:10.2e}  | {:10.2e}  | {:6} | N/A
        # transposed tensors
        for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
            for perm3 in itertools.permutations((0, 1)):
-                b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1)
-                b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1)
+                b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) * 0.1
+                b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1) * 0.1
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                ref = torch.from_numpy(
@@ -6380,8 +6388,8 @@ scipy_lobpcg  | {:10.2e}  | {:10.2e}  | {:6} | N/A
        for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
            shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
            shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
-            b1 = make_tensor(shape1, device, dtype, low=-1, high=1).expand(num_batches, M, N)
-            b2 = make_tensor(shape2, device, dtype, low=-1, high=1).expand(num_batches, N, O)
+            b1 = make_tensor(shape1, device, dtype, low=-1, high=1).expand(num_batches, M, N) * 0.1
+            b2 = make_tensor(shape2, device, dtype, low=-1, high=1).expand(num_batches, N, O) * 0.1
            ref = torch.from_numpy(
                b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
            ).to(device=device, dtype=dtype).sum(0)
@@ -6391,8 +6399,8 @@ scipy_lobpcg  | {:10.2e}  | {:10.2e}  | {:6} | N/A
        for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
            shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
            shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
-            b1 = make_tensor(shape1, device, dtype, low=-1, high=1)
-            b2 = make_tensor(shape2, device, dtype, low=-1, high=1)
+            b1 = make_tensor(shape1, device, dtype, low=-1, high=1) * 0.1
+            b2 = make_tensor(shape2, device, dtype, low=-1, high=1) * 0.1
            ref = torch.from_numpy(
                b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
            ).to(device=device, dtype=dtype).sum(0)
@@ -6414,7 +6422,7 @@ scipy_lobpcg  | {:10.2e}  | {:10.2e}  | {:6} | N/A
            return
 
        num_batches = 10
-        M, N, O = 12, 8, 5
+        M, N, O = 12, 8, 50
 
        is_supported = True
        if dtype == torch.bfloat16 and self.device_type == 'cuda':
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index a8b9ea1..6b11812 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -871,6 +871,7 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/native/mkldnn/TensorShape.cpp",
     "aten/src/ATen/native/mkldnn/UnaryOps.cpp",
     "aten/src/ATen/native/mkldnn/Utils.cpp",
+    "aten/src/ATen/native/mkldnn/Matmul.cpp",
     "aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp",
     "aten/src/ATen/record_function.cpp",
     "aten/src/ATen/SavedTensorHooks.cpp",
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1b9939f..19b3712 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -6132,14 +6132,14 @@ op_db: List[OpInfo] = [
            ),
            sample_inputs_func=sample_inputs_baddbmm),
     OpInfo('dot',
-           dtypes=all_types_and_complex_and(torch.float16),
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            assert_autodiffed=True,
            sample_inputs_func=sample_inputs_dot_vdot,
            supports_forward_ad=True,
            ),
     OpInfo('vdot',
-           dtypes=all_types_and_complex_and(torch.float16),
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            sample_inputs_func=sample_inputs_dot_vdot,
            supports_forward_ad=True,
@@ -7352,7 +7352,7 @@ op_db: List[OpInfo] = [
     OpInfo('matmul',
            aliases=('linalg.matmul',),
            dtypes=floating_types(),
-           dtypesIfCPU=all_types_and_complex(),
+           dtypesIfCPU=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            dtypesIfROCM=floating_types_and(torch.half, torch.bfloat16),
            backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
@@ -8213,7 +8213,7 @@ op_db: List[OpInfo] = [
     OpInfo('__rmatmul__',
            op=torch.Tensor.__rmatmul__,
            dtypes=floating_types(),
-           dtypesIfCPU=all_types_and_complex(),
+           dtypesIfCPU=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else [],
                                            torch.complex64, torch.complex128),
            backward_dtypesIfCUDA=floating_types_and(torch.float16,
-- 
2.7.4
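
For reference, a minimal usage sketch (not part of the patch) of the CPU bf16 matmul path this change enables. The shapes, the 0.1 scale factor, and the tolerances below are arbitrary illustrative choices; the oneDNN kernel is only taken when torch.backends.mkldnn is enabled and mkldnn_bf16_device_check() passes (avx512bw, avx512vl, avx512dq), otherwise the reference dispatch path is used and the result should still agree within bf16 precision.

import torch

# Illustrative only: compare a bf16 CPU matmul against a float32 baseline.
if torch.backends.mkldnn.is_available():
    a = 0.1 * torch.randn(64, 128)
    b = 0.1 * torch.randn(128, 32)
    out_bf16 = torch.mm(a.bfloat16(), b.bfloat16())   # may route through mkldnn_matmul after this patch
    out_fp32 = torch.mm(a, b)                          # float32 reference
    # bf16 keeps ~8 bits of mantissa, so compare with loose (illustrative) tolerances.
    torch.testing.assert_close(out_bf16.float(), out_fp32, atol=1e-1, rtol=1e-2)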