#define MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
namespace mlir {
+class LLVMTypeConverter;
class ModulePassBase;
+class OwningRewritePatternList;
+/// Collect a set of patterns to convert from the Vector dialect to LLVM.
+void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ OwningRewritePatternList &patterns);
+
+/// Create a pass to convert vector operations to the LLVMIR dialect.
ModulePassBase *createLowerVectorToLLVMPass();
} // namespace mlir
}
def OuterProductOp :
Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
- Arguments<(ins AnyVector:$lhs, AnyVector:$rhs)>,
+ Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>,
Results<(outs AnyVector)> {
- let summary = "outerproduct operation";
+ let summary = "vector outerproduct with optional fused add";
let description = [{
Takes 2 1-D vectors and returns the 2-D vector containing the outer product.
- Example:
- ```
+ An optional extra 2-D vector argument may be specified in which case the
+ operation returns the sum of the outer product and the extra vector. When
+ lowered to the LLVMIR dialect, this form emits `llvm.fmuladd`, which can
+ lower to actual `fma` instructions in LLVM.
+
+ Examples
+
%2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32>
return %2: vector<4x8xf32>
- ```
+
+ %3 = vector.outerproduct %0, %1, %2:
+ vector<4xf32>, vector<8xf32>, vector<4x8xf32>
+ return %3: vector<4x8xf32>
}];
let extraClassDeclaration = [{
VectorType getOperandVectorTypeLHS() {
VectorType getOperandVectorTypeRHS() {
return rhs()->getType().cast<VectorType>();
}
+ VectorType getOperandVectorTypeACC() {
+ return (llvm::size(acc()) == 0) ? VectorType() :
+ (*acc().begin())->getType().cast<VectorType>();
+ }
VectorType getVectorType() {
return getResult()->getType().cast<VectorType>();
}
auto positionArrayAttr = extractOp.position();
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
- Value *extracted =
- rewriter
- .create<LLVM::ExtractValueOp>(loc, llvmResultType,
- adaptor.vector(), positionArrayAttr)
- .getResult();
+ Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmResultType, adaptor.vector(), positionArrayAttr);
rewriter.replaceOp(op, extracted);
return matchSuccess();
}
auto *context = op->getContext();
Value *extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
- auto indexType = rewriter.getIndexType();
+ auto i32Type = rewriter.getIntegerType(32);
if (positionAttrs.size() > 1) {
auto nDVectorType = vectorType;
auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(),
nDVectorType.getElementType());
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
- extracted = rewriter
- .create<LLVM::ExtractValueOp>(
- loc, lowering.convertType(oneDVectorType), extracted,
- nMinusOnePositionAttrs)
- .getResult();
+ extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, lowering.convertType(oneDVectorType), extracted,
+ nMinusOnePositionAttrs);
}
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
- auto constant = rewriter
- .create<LLVM::ConstantOp>(
- loc, lowering.convertType(indexType), position)
- .getResult();
+ auto constant = rewriter.create<LLVM::ConstantOp>(
+ loc, lowering.convertType(i32Type), position);
extracted =
- rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant)
- .getResult();
+ rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(op, extracted);
return matchSuccess();
auto loc = op->getLoc();
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
auto *ctx = op->getContext();
- auto vt1 = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
- auto vt2 = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
- auto rankV1 = vt1.getUnderlyingType()->getVectorNumElements();
- auto rankV2 = vt2.getUnderlyingType()->getVectorNumElements();
+ auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
+ auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
+ auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
+ auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
auto llvmArrayOfVectType = lowering.convertType(
cast<vector::OuterProductOp>(op).getResult()->getType());
- Value *desc =
- rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType).getResult();
- for (unsigned i = 0, e = rankV1; i < e; ++i) {
- // Emit the following pattern:
- // vec(a[i]) * b -> llvmStructOfVectType[i]
- Value *a = adaptor.lhs(), *b = adaptor.rhs();
- // shufflevector explicitly requires i32 /
- auto attr = rewriter.getI32IntegerAttr(i);
- SmallVector<Attribute, 4> broadcastAttr(rankV2, attr);
- auto broadcastArrayAttr = ArrayAttr::get(broadcastAttr, ctx);
- auto *broadcasted =
- rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, broadcastArrayAttr)
- .getResult();
- auto *multiplied =
- rewriter.create<LLVM::FMulOp>(loc, broadcasted, b).getResult();
- desc = rewriter
- .create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, desc,
- multiplied,
- positionAttr(rewriter, i))
- .getResult();
+ Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
+ Value *a = adaptor.lhs(), *b = adaptor.rhs();
+ Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
+ SmallVector<Value *, 8> lhs, accs;
+ lhs.reserve(rankLHS);
+ accs.reserve(rankLHS);
+ for (unsigned d = 0, e = rankLHS; d < e; ++d) {
+ // shufflevector explicitly requires i32.
+ auto attr = rewriter.getI32IntegerAttr(d);
+ SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
+ auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
+ Value *aD = nullptr, *accD = nullptr;
+ // 1. Broadcast the element a[d] into vector aD.
+ aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
+ // 2. If acc is present, extract 1-d vector acc[d] into accD.
+ if (acc)
+ accD = rewriter.create<LLVM::ExtractValueOp>(loc, vRHS, acc,
+ positionAttr(rewriter, d));
+ // 3. Compute aD outer b (plus accD, if relevant).
+ Value *aOuterbD =
+ accD ? rewriter.create<LLVM::fmuladd>(loc, vRHS, aD, b, accD)
+ .getResult()
+ : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
+ // 4. Insert as value `d` in the descriptor.
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, llvmArrayOfVectType, desc, aOuterbD, positionAttr(rewriter, d));
}
rewriter.replaceOp(op, desc);
return matchSuccess();
};
/// Populate the given list with patterns that convert from Vector to LLVM.
/// Inserts ExtractElementOpConversion and OuterProductOpConversion into
/// `patterns`, each constructed from an MLIRContext and `converter`.
-static void
-populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
+void mlir::populateVectorToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// In the two-argument form the context is read from the converter's dialect
// rather than being supplied by the caller.
patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>(
- ctx, converter);
+ converter.getDialect()->getContext(), converter);
}
namespace {
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
- populateVectorToLLVMConversionPatterns(converter, patterns, &getContext());
+ populateVectorToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
ConversionTarget target(getContext());
// Prints the custom assembly form of OuterProductOp: the operation name, the
// lhs and rhs operands, the first `acc` operand when one is present, and then
// the lhs and rhs types.
static void print(OpAsmPrinter *p, OuterProductOp op) {
*p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
+ if (llvm::size(op.acc()) > 0)
+ *p << ", " << **op.acc().begin();
// Only the lhs and rhs types are printed; the acc and result types are not
// part of the textual form.
*p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
}
static ParseResult parseOuterProductOp(OpAsmParser *parser,
OperationState *result) {
- SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
- Type t0, t1;
- if (parser->parseOperandList(operandsInfo) || parser->parseColonType(t0) ||
- parser->parseComma() || parser->parseType(t1))
+ SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
+ Type tLHS, tRHS;
+ if (parser->parseOperandList(operandsInfo) || parser->parseColonType(tLHS) ||
+ parser->parseComma() || parser->parseType(tRHS))
return failure();
- VectorType v0 = t0.dyn_cast<VectorType>();
- VectorType v1 = t1.dyn_cast<VectorType>();
- if (!v0 || !v1)
+ if (operandsInfo.size() < 2)
+ return parser->emitError(parser->getNameLoc(),
+ "expected at least 2 operands");
+ VectorType vLHS = tLHS.dyn_cast<VectorType>();
+ VectorType vRHS = tRHS.dyn_cast<VectorType>();
+ if (!vLHS || !vRHS)
return parser->emitError(parser->getNameLoc(), "expected 2 vector types");
- VectorType resType = VectorType::get({v0.getDimSize(0), v1.getDimSize(0)},
- v0.getElementType());
- return failure(parser->resolveOperands(operandsInfo, {t0, t1},
- parser->getCurrentLocation(),
- result->operands) ||
- parser->addTypeToList(resType, result->types));
+ VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
+ vLHS.getElementType());
+ return failure(
+ parser->resolveOperand(operandsInfo[0], tLHS, result->operands) ||
+ parser->resolveOperand(operandsInfo[1], tRHS, result->operands) ||
+ (operandsInfo.size() > 2 &&
+ parser->resolveOperand(operandsInfo[2], resType, result->operands)) ||
+ parser->addTypeToList(resType, result->types));
}
// Verifies the shape constraints of OuterProductOp: lhs and rhs must be 1-D,
// the result must be 2-D, dimension 0 of lhs must equal result dimension 0,
// dimension 0 of rhs must equal result dimension 1, and an accumulator, when
// present, must have exactly the result type. Emits an op error and returns
// it on the first violated constraint; returns success otherwise.
static LogicalResult verify(OuterProductOp op) {
- VectorType v1 = op.getOperandVectorTypeLHS(),
- v2 = op.getOperandVectorTypeRHS(), res = op.getVectorType();
- if (v1.getRank() != 1)
+ VectorType vLHS = op.getOperandVectorTypeLHS(),
+ vRHS = op.getOperandVectorTypeRHS(),
+ vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
+ if (vLHS.getRank() != 1)
return op.emitOpError("expected 1-d vector for operand #1");
- if (v2.getRank() != 1)
+ if (vRHS.getRank() != 1)
return op.emitOpError("expected 1-d vector for operand #2");
- if (res.getRank() != 2)
+ if (vRES.getRank() != 2)
return op.emitOpError("expected 2-d vector result");
- if (v1.getDimSize(0) != res.getDimSize(0))
- return op.emitOpError(
- "expected first operand dim to match first result dim");
- if (v2.getDimSize(0) != res.getDimSize(1))
- return op.emitOpError(
- "expected second operand dim to match second result dim");
+ if (vLHS.getDimSize(0) != vRES.getDimSize(0))
+ return op.emitOpError("expected #1 operand dim to match result dim #1");
+ if (vRHS.getDimSize(0) != vRES.getDimSize(1))
+ return op.emitOpError("expected #2 operand dim to match result dim #2");
// getOperandVectorTypeACC() yields a null VectorType when no accumulator
// operand was supplied, in which case this check is skipped.
+ if (vACC && vACC != vRES)
+ return op.emitOpError("expected operand #3 of same type as result type");
return success();
}
+
//===----------------------------------------------------------------------===//
// VectorTransferReadOp
//===----------------------------------------------------------------------===//
// RUN: mlir-opt %s -vector-lower-to-llvm-dialect | FileCheck %s
-func @vec_1d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
- %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
- %3 = vector.extractelement %2[0 : i32]: vector<4x8xf32>
- return %3 : vector<8xf32>
+func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> {
+ %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
+ return %2 : vector<2x3xf32>
}
-// CHECK-LABEL: vec_1d
-// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]">
-// CHECK-5: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]">
-// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x <8 x float>]">
-// CHECK: llvm.return {{.*}} : !llvm<"<8 x float>">
+// CHECK-LABEL: outerproduct
+// CHECK: llvm.undef : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>">
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
-func @vec_2d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> {
- %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
- return %2 : vector<4x8xf32>
+func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+ %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
+ return %2 : vector<2x3xf32>
}
-// CHECK-LABEL: vec_2d
-// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]">
-// CHECK-4: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>">
-// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]">
-// CHECK: llvm.return {{.*}} : !llvm<"[4 x <8 x float>]">
+// CHECK-LABEL: outerproduct_add
+// CHECK: llvm.undef : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
+// CHECK: "llvm.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
+// CHECK: "llvm.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
+// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
-func @vec_3d(%arg0: vector<4x8x16xf32>) -> vector<8x16xf32> {
- %0 = vector.extractelement %arg0[0 : i32]: vector<4x8x16xf32>
- return %0 : vector<8x16xf32>
+func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
+ %0 = vector.extractelement %arg0[0 : i32]: vector<4x3x16xf32>
+ return %0 : vector<3x16xf32>
}
-// CHECK-LABEL: vec_3d
-// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
-// CHECK: llvm.return %{{.*}} : !llvm<"[8 x <16 x float>]">
\ No newline at end of file
+// CHECK-LABEL: extract_vec_2d_from_vec_3d
+// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
+// CHECK: llvm.return %{{.*}} : !llvm<"[3 x <16 x float>]">
+
+func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
+ %0 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32>
+ return %0 : f32
+}
+// CHECK-LABEL: extract_element_from_vec_3d
+// CHECK: llvm.extractvalue %{{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
+// CHECK: llvm.constant(0 : i32) : !llvm.i32
+// CHECK: llvm.extractelement %{{.*}}, %{{.*}} : !llvm<"<16 x float>">
+// CHECK: llvm.return %{{.*}} : !llvm.float
\ No newline at end of file
// -----
-// CHECK-LABEL: position_empty
-func @position_empty(%arg0: vector<4x8x16xf32>) {
+func @extract_element_vector_type(%arg0: index) {
+ // expected-error@+1 {{expected vector type}}
+ %1 = vector.extractelement %arg0[] : index
+}
+
+// -----
+
+func @extractelement_position_empty(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected non-empty position attribute}}
%1 = vector.extractelement %arg0[] : vector<4x8x16xf32>
}
// -----
-// CHECK-LABEL: position_rank_overflow
-func @position_rank_overflow(%arg0: vector<4x8x16xf32>) {
+func @extractelement_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute of rank smaller than vector}}
%1 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32>
}
// -----
-// CHECK-LABEL: position_overflow
-func @position_overflow(%arg0: vector<4x8x16xf32>) {
+func @extractelement_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
+ // expected-error@+1 {{expected position attribute of rank smaller than vector}}
+ %1 = "vector.extractelement" (%arg0) { position = [0 : i32, 0 : i32, 0 : i32, 0 : i32] } : (vector<4x8x16xf32>) -> (vector<16xf32>)
+}
+
+// -----
+
+func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}}
%1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32>
}
// -----
-// CHECK-LABEL: position_underflow
-func @position_overflow(%arg0: vector<4x8x16xf32>) {
+// A negative index in position #3 is expected to be diagnosed.
+func @extractelement_position_underflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}}
%1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32>
}
// -----
-// CHECK-LABEL: outerproduct_non_vector_operand
+func @outerproduct_num_operands(%arg0: f32) {
+ // expected-error@+1 {{expected at least 2 operands}}
+ %1 = vector.outerproduct %arg0 : f32, f32
+}
+// -----
+
func @outerproduct_non_vector_operand(%arg0: f32) {
// expected-error@+1 {{expected 2 vector types}}
%1 = vector.outerproduct %arg0, %arg0 : f32, f32
// -----
-// CHECK-LABEL: outerproduct_operand_1
func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
// expected-error@+1 {{expected 1-d vector for operand #1}}
%1 = vector.outerproduct %arg1, %arg1 : vector<4x8xf32>, vector<4x8xf32>
// -----
-// CHECK-LABEL: outerproduct_operand_2
func @outerproduct_operand_2(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
// expected-error@+1 {{expected 1-d vector for operand #2}}
%1 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<4x8xf32>
}
+
+// -----
+
+func @outerproduct_result_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) {
+ // expected-error@+1 {{expected 2-d vector result}}
+ %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<8xf32>)
+}
+
+// -----
+
+func @outerproduct_operand_1_dim_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) {
+ // expected-error@+1 {{expected #1 operand dim to match result dim #1}}
+ %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<8x16xf32>)
+}
+
+// -----
+
+func @outerproduct_operand_2_dim_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) {
+ // expected-error@+1 {{expected #2 operand dim to match result dim #2}}
+ %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<4x16xf32>)
+}
+
+// -----
+
+func @outerproduct_operand_3_result_type_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x16xf32>) {
+ // expected-error@+1 {{expected operand #3 of same type as result type}}
+ %1 = "vector.outerproduct" (%arg0, %arg1, %arg2) : (vector<4xf32>, vector<8xf32>, vector<4x16xf32>) -> (vector<4x8xf32>)
+}
}
// CHECK-LABEL: outerproduct
-func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> {
+func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
%0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
- return %0 : vector<4x8xf32>
+ // CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32>
+ %1 = vector.outerproduct %arg0, %arg1, %arg2 : vector<4xf32>, vector<8xf32>
+ return %1 : vector<4x8xf32>
}