}
};
-// TODO(ajcbik): remove this rule once LinAlg tests are cleaned up
-class VectorOuterProductOpConversion : public ConvertToLLVMPattern {
-public:
- explicit VectorOuterProductOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::OuterProductOp::getOperationName(),
- context, typeConverter) {}
-
- PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
- auto *ctx = op->getContext();
- auto vLHS = adaptor.lhs().getType().cast<LLVM::LLVMType>();
- auto vRHS = adaptor.rhs().getType().cast<LLVM::LLVMType>();
- auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
- auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
- auto llvmArrayOfVectType = typeConverter.convertType(
- cast<vector::OuterProductOp>(op).getResult().getType());
- Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
- Value a = adaptor.lhs(), b = adaptor.rhs();
- Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
- SmallVector<Value, 8> lhs, accs;
- lhs.reserve(rankLHS);
- accs.reserve(rankLHS);
- for (unsigned d = 0, e = rankLHS; d < e; ++d) {
- // shufflevector explicitly requires i32.
- auto attr = rewriter.getI32IntegerAttr(d);
- SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
- auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
- Value aD = nullptr, accD = nullptr;
- // 1. Broadcast the element a[d] into vector aD.
- aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
- // 2. If acc is present, extract 1-d vector acc[d] into accD.
- if (acc)
- accD = rewriter.create<LLVM::ExtractValueOp>(
- loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
- // 3. Compute aD outer b (plus accD, if relevant).
- Value aOuterbD =
- accD
- ? rewriter.create<LLVM::FMAOp>(loc, vRHS, aD, b, accD).getResult()
- : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
- // 4. Insert as value `d` in the descriptor.
- desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
- desc, aOuterbD,
- rewriter.getI64ArrayAttr(d));
- }
- rewriter.replaceOp(op, desc);
- return matchSuccess();
- }
-};
-
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorTypeCastOpConversion(MLIRContext *context,
VectorShuffleOpConversion, VectorExtractElementOpConversion,
VectorExtractOpConversion, VectorFMAOp1DConversion,
VectorInsertElementOpConversion, VectorInsertOpConversion,
- VectorOuterProductOpConversion, VectorTypeCastOpConversion,
- VectorPrintOpConversion>(ctx, converter);
+ VectorTypeCastOpConversion, VectorPrintOpConversion>(
+ ctx, converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
// RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | FileCheck %s --check-prefix=LLVM-LOOPS
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefix=LLVM-LOOPS
func @range(%arg0: index) {
%c0 = constant 0 : index
// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
// LLVM-LOOPS-LABEL: func @matmul_vec_impl(
-// LLVM-LOOPS: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// LLVM-LOOPS: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32, 1 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// LLVM-LOOPS: llvm.shufflevector {{.*}} [2 : i32, 2 : i32, 2 : i32, 2 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// LLVM-LOOPS: llvm.shufflevector {{.*}} [3 : i32, 3 : i32, 3 : i32, 3 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-// LLVM-LOOPS-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]">
-// LLVM-LOOPS-NEXT: "llvm.intr.fma"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
-// LLVM-LOOPS-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">
-
+// LLVM-LOOPS-SAME: %[[A:.*0]]: memref<?x?xvector<4xf32>>,
+// LLVM-LOOPS-SAME: %[[B:.*1]]: memref<?x?xvector<4xf32>>,
+// LLVM-LOOPS-SAME: %[[C:.*2]]: memref<?x?xvector<4x4xf32>>)
+// LLVM-LOOPS: %[[C0:.*]] = constant 0 : index
+// LLVM-LOOPS: %[[C1:.*]] = constant 1 : index
+// LLVM-LOOPS: %[[T0:.*]] = dim %[[A]], 0 : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: %[[T1:.*]] = dim %[[A]], 1 : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: %[[T2:.*]] = dim %[[B]], 1 : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: loop.for %[[I:.*]] = %[[C0]] to %[[T0]] step %[[C1]] {
+// LLVM-LOOPS: loop.for %[[J:.*]] = %[[C0]] to %[[T2]] step %[[C1]] {
+// LLVM-LOOPS: loop.for %[[K:.*]] = %[[C0]] to %[[T1]] step %[[C1]] {
+// LLVM-LOOPS: %[[T3:.*]] = load %[[A]][%[[I]], %[[K]]] : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: %[[T4:.*]] = load %[[B]][%[[K]], %[[J]]] : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: %[[T5:.*]] = load %[[C]][%[[I]], %[[J]]] : memref<?x?xvector<4x4xf32>>
+// LLVM-LOOPS: %[[T6:.*]] = vector.outerproduct %3, %4, %5 : vector<4xf32>, vector<4xf32>
+// LLVM-LOOPS: store %[[T6]], %[[C]][%[[I]], %[[J]]] : memref<?x?xvector<4x4xf32>>
#indexed_matmul_trait = {
args_in = 2,