[mlir][vector] Add scalable vectors support to OuterProductOp
authorFrank (Fang) Gao <fang.gao1@huawei.com>
Fri, 13 Jan 2023 15:47:44 +0000 (10:47 -0500)
committerPrabhdeep Singh Soni <prabhdeep.singh.soni3@huawei.com>
Fri, 13 Jan 2023 15:54:18 +0000 (10:54 -0500)
This will probably be the first in a series of patches that tries to
enable code generation for ARM SME (extension of SVE).

Since SME's core operation is the outer product instruction, I figured
that it would probably be a good idea to enable the outer product
operation to properly accept and generate scalable vectors.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D138718

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

index cc56ce6..e2a3e61 100644 (file)
@@ -2657,10 +2657,18 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
   if (!vLHS)
     return parser.emitError(parser.getNameLoc(),
                             "expected vector type for operand #1");
-  VectorType resType =
-      vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
-                             vLHS.getElementType())
-           : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
+
+  unsigned numScalableDims = vLHS.getNumScalableDims();
+  VectorType resType;
+  if (vRHS) {
+    numScalableDims += vRHS.getNumScalableDims();
+    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
+                              vLHS.getElementType(), numScalableDims);
+  } else {
+    // Scalar RHS operand
+    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
+                              numScalableDims);
+  }
 
   if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
     result.attributes.append(
@@ -2696,6 +2704,9 @@ LogicalResult OuterProductOp::verify() {
       return emitOpError("expected #1 operand dim to match result dim #1");
     if (vRHS.getDimSize(0) != vRES.getDimSize(1))
       return emitOpError("expected #2 operand dim to match result dim #2");
+    if (vRHS.isScalable() != vLHS.isScalable())
+      return emitOpError("expected either all or none of vector operands #1 "
+                         "and #2 to be scalable");
   } else {
     // An AXPY operation.
     if (vRES.getRank() != 1)