if (!vLHS)
return parser.emitError(parser.getNameLoc(),
"expected vector type for operand #1");
- VectorType resType =
- vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
- vLHS.getElementType())
- : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
+
+ unsigned numScalableDims = vLHS.getNumScalableDims();
+ VectorType resType;
+ if (vRHS) {
+ numScalableDims += vRHS.getNumScalableDims();
+ resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
+ vLHS.getElementType(), numScalableDims);
+ } else {
+ // Scalar RHS operand
+ resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
+ numScalableDims);
+ }
if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
result.attributes.append(
return emitOpError("expected #1 operand dim to match result dim #1");
if (vRHS.getDimSize(0) != vRES.getDimSize(1))
return emitOpError("expected #2 operand dim to match result dim #2");
+ if (vRHS.isScalable() != vLHS.isScalable())
+ return emitOpError("expected either all or none of vector operands #1 "
+ "and #2 to be scalable");
} else {
// An AXPY operation.
if (vRES.getRank() != 1)
--- /dev/null
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt
+
+func.func @scalable_outerproduct(%src : memref<?xf32>) {
+ %idx = arith.constant 0 : index
+ %cst = arith.constant 1.0 : f32
+ %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
+ %1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
+
+ %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<[4]xf32>
+ vector.store %op, %src[%idx] : memref<?xf32>, vector<[4x4]xf32>
+
+ %op2 = vector.outerproduct %0, %cst : vector<[4]xf32>, f32
+ vector.store %op2, %src[%idx] : memref<?xf32>, vector<[4]xf32>
+ return
+}
+
+// -----
+
+func.func @invalid_outerproduct(%src : memref<?xf32>) {
+ %idx = arith.constant 0 : index
+ %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
+ %1 = vector.load %src[%idx] : memref<?xf32>, vector<4xf32>
+
+ // expected-error @+1 {{expected either all or none of vector operands #1 and #2 to be scalable}}
+ %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>
+}
+// -----
+
+func.func @invalid_outerproduct1(%src : memref<?xf32>) {
+ %idx = arith.constant 0 : index
+ %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4x4]xf32>
+ %1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
+
+ // expected-error @+1 {{expected 1-d vector for operand #1}}
+ %op = vector.outerproduct %0, %1 : vector<[4x4]xf32>, vector<[4]xf32>
+}