Fix zero-dim handling in torch.matmul (#63359)
authorRichard Zou <zou3519@gmail.com>
Tue, 17 Aug 2021 20:39:52 +0000 (13:39 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 17 Aug 2021 20:44:47 +0000 (13:44 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63359

Fixes #63352. The problem was that in e.g. `torch.matmul(A, B)` with A,
B having shapes [3, 2, 0] and [0, 2], the code attempts to call
`A.view(-1, 0)` which fails due to "-1 being ambiguous". The solution is
to manually compute what we want the shape of the view to be.

Test Plan: - new tests

Reviewed By: ngimel

Differential Revision: D30351583

Pulled By: zou3519

fbshipit-source-id: 7625691fe8b85d96a4073409596a932c303e3e8c

aten/src/ATen/native/LinearAlgebra.cpp
torch/testing/_internal/common_methods_invocations.py

index 1be3788..bbb6fce 100644 (file)
@@ -1455,7 +1455,12 @@ Tensor matmul(
     }
 
     // fold the batch into the first dimension
-    Tensor t1 = tensor1.expect_contiguous()->view({-1, size1[size1.size() - 1]});
+    // Why not tensor1.view(-1, size1[size1.size() -1])?
+    // If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous.
+    // This can happen in e.g. [3, 5, 0] @ [0, 0].
+    // So we manually compute the folding as a result.
+    const auto dim1_size = c10::multiply_integers(size1.begin(), size1.end() - 1);
+    auto t1 = tensor1.expect_contiguous()->view({dim1_size, size1[size1.size() - 1]});
     Tensor output = has_out ? at::_unsafe_view(at::mm_out(out, t1, t2), output_size)
                             : at::_unsafe_view(t1.mm(t2), output_size);
     return has_out ? out.set_(output) : output;
index ee43a02..b281c5e 100644 (file)
@@ -4129,10 +4129,13 @@ def sample_inputs_matmul(op_info, device, dtype, requires_grad):
                   ((S, M), (M,)),
                   ((M,), (M, S)),
                   ((S, M), (M, S)),
+                  ((S, 0), (0, M)),
                   ((S, S, M), (M,)),
                   ((S, S, M), (M, S)),
+                  ((S, S, 0), (0, S)),
                   ((M,), (S, M, S)),
                   ((S, M), (S, M, S)),
+                  ((0, 0), (S, 0, 0)),
                   ((S, S, M, M), (S, S, M, S)),
                   ((S, S, M, M), (M,)),
                   ((M,), (S, S, M, S)))