From 495e7e4815d3f9a4000a6671022fd2608440db75 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Tue, 17 Aug 2021 13:39:52 -0700 Subject: [PATCH] Fix zero-dim handling in torch.matmul (#63359) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63359 Fixes #63352. The problem was that in e.g. `torch.matmul(A, B)` with A and B having shapes [3, 2, 0] and [0, 2], the code attempted to call `A.view(-1, 0)`, which fails because the -1 is ambiguous: the tensor has 0 elements and the last dimension is 0, so any size for the first dimension would be consistent. The solution is to manually compute what we want the shape of the view to be. Test Plan: - new tests Reviewed By: ngimel Differential Revision: D30351583 Pulled By: zou3519 fbshipit-source-id: 7625691fe8b85d96a4073409596a932c303e3e8c --- aten/src/ATen/native/LinearAlgebra.cpp | 7 ++++++- torch/testing/_internal/common_methods_invocations.py | 3 +++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 1be3788..bbb6fce 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1455,7 +1455,12 @@ Tensor matmul( } // fold the batch into the first dimension - Tensor t1 = tensor1.expect_contiguous()->view({-1, size1[size1.size() - 1]}); + // Why not tensor1.view(-1, size1[size1.size() - 1])? + // If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous. + // This can happen in e.g. [3, 5, 0] @ [0, 0]. + // So we compute the size of the folded dimension manually instead. + const auto dim1_size = c10::multiply_integers(size1.begin(), size1.end() - 1); + auto t1 = tensor1.expect_contiguous()->view({dim1_size, size1[size1.size() - 1]}); Tensor output = has_out ? at::_unsafe_view(at::mm_out(out, t1, t2), output_size) : at::_unsafe_view(t1.mm(t2), output_size); return has_out ? 
out.set_(output) : output; diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index ee43a02..b281c5e 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -4129,10 +4129,13 @@ def sample_inputs_matmul(op_info, device, dtype, requires_grad): ((S, M), (M,)), ((M,), (M, S)), ((S, M), (M, S)), + ((S, 0), (0, M)), ((S, S, M), (M,)), ((S, S, M), (M, S)), + ((S, S, 0), (0, S)), ((M,), (S, M, S)), ((S, M), (S, M, S)), + ((0, 0), (S, 0, 0)), ((S, S, M, M), (S, S, M, S)), ((S, S, M, M), (M,)), ((M,), (S, S, M, S))) -- 2.7.4