[mlir] [VectorOps,LinAlg] Remove direct LLVM lowering for vector.broadcast
authoraartbik <ajcbik@google.com>
Fri, 13 Mar 2020 16:35:29 +0000 (09:35 -0700)
committeraartbik <ajcbik@google.com>
Fri, 13 Mar 2020 18:42:51 +0000 (11:42 -0700)
Summary:
The direct lowering of vector.broadcast into LLVM has been replaced by
progressive lowering into elementary vector ops. This also required a
small refactoring of a llvm.mlir test that used a direct vector.broadcast
operator (just to define a matmul).

Reviewers: nicolasvasilache, andydavis1, rriddle

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D76143

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Dialect/Linalg/llvm.mlir

index a41a9d2..828a964 100644 (file)
@@ -817,59 +817,6 @@ public:
   }
 };
 
-// TODO(ajcbik): remove this rule once LinAlg tests are cleaned up
-class VectorOuterProductOpConversion : public ConvertToLLVMPattern {
-public:
-  explicit VectorOuterProductOpConversion(MLIRContext *context,
-                                          LLVMTypeConverter &typeConverter)
-      : ConvertToLLVMPattern(vector::OuterProductOp::getOperationName(),
-                             context, typeConverter) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op->getLoc();
-    auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
-    auto *ctx = op->getContext();
-    auto vLHS = adaptor.lhs().getType().cast<LLVM::LLVMType>();
-    auto vRHS = adaptor.rhs().getType().cast<LLVM::LLVMType>();
-    auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
-    auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
-    auto llvmArrayOfVectType = typeConverter.convertType(
-        cast<vector::OuterProductOp>(op).getResult().getType());
-    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
-    Value a = adaptor.lhs(), b = adaptor.rhs();
-    Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
-    SmallVector<Value, 8> lhs, accs;
-    lhs.reserve(rankLHS);
-    accs.reserve(rankLHS);
-    for (unsigned d = 0, e = rankLHS; d < e; ++d) {
-      // shufflevector explicitly requires i32.
-      auto attr = rewriter.getI32IntegerAttr(d);
-      SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
-      auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
-      Value aD = nullptr, accD = nullptr;
-      // 1. Broadcast the element a[d] into vector aD.
-      aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
-      // 2. If acc is present, extract 1-d vector acc[d] into accD.
-      if (acc)
-        accD = rewriter.create<LLVM::ExtractValueOp>(
-            loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
-      // 3. Compute aD outer b (plus accD, if relevant).
-      Value aOuterbD =
-          accD
-              ? rewriter.create<LLVM::FMAOp>(loc, vRHS, aD, b, accD).getResult()
-              : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
-      // 4. Insert as value `d` in the descriptor.
-      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
-                                                  desc, aOuterbD,
-                                                  rewriter.getI64ArrayAttr(d));
-    }
-    rewriter.replaceOp(op, desc);
-    return matchSuccess();
-  }
-};
-
 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorTypeCastOpConversion(MLIRContext *context,
@@ -1160,8 +1107,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                   VectorShuffleOpConversion, VectorExtractElementOpConversion,
                   VectorExtractOpConversion, VectorFMAOp1DConversion,
                   VectorInsertElementOpConversion, VectorInsertOpConversion,
-                  VectorOuterProductOpConversion, VectorTypeCastOpConversion,
-                  VectorPrintOpConversion>(ctx, converter);
+                  VectorTypeCastOpConversion, VectorPrintOpConversion>(
+      ctx, converter);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
index 82ec950..290e1a2 100644 (file)
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | FileCheck %s --check-prefix=LLVM-LOOPS
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefix=LLVM-LOOPS
 
 func @range(%arg0: index) {
   %c0 = constant 0 : index
@@ -172,14 +172,22 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
 // CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
 
 // LLVM-LOOPS-LABEL: func @matmul_vec_impl(
-//   LLVM-LOOPS: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-//   LLVM-LOOPS: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32, 1 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-//   LLVM-LOOPS: llvm.shufflevector {{.*}} [2 : i32, 2 : i32, 2 : i32, 2 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-//   LLVM-LOOPS: llvm.shufflevector {{.*}} [3 : i32, 3 : i32, 3 : i32, 3 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-//   LLVM-LOOPS-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]">
-//   LLVM-LOOPS-NEXT: "llvm.intr.fma"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
-//   LLVM-LOOPS-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">
-
+//  LLVM-LOOPS-SAME: %[[A:.*0]]: memref<?x?xvector<4xf32>>,
+//  LLVM-LOOPS-SAME: %[[B:.*1]]: memref<?x?xvector<4xf32>>,
+//  LLVM-LOOPS-SAME: %[[C:.*2]]: memref<?x?xvector<4x4xf32>>)
+// LLVM-LOOPS: %[[C0:.*]] = constant 0 : index
+// LLVM-LOOPS: %[[C1:.*]] = constant 1 : index
+// LLVM-LOOPS: %[[T0:.*]] = dim %[[A]], 0 : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: %[[T1:.*]] = dim %[[A]], 1 : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: %[[T2:.*]] = dim %[[B]], 1 : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS: loop.for %[[I:.*]] = %[[C0]] to %[[T0]] step %[[C1]] {
+// LLVM-LOOPS: loop.for %[[J:.*]] = %[[C0]] to %[[T2]] step %[[C1]] {
+// LLVM-LOOPS: loop.for %[[K:.*]] = %[[C0]] to %[[T1]] step %[[C1]] {
+// LLVM-LOOPS:   %[[T3:.*]] = load %[[A]][%[[I]], %[[K]]] : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS:   %[[T4:.*]] = load %[[B]][%[[K]], %[[J]]] : memref<?x?xvector<4xf32>>
+// LLVM-LOOPS:   %[[T5:.*]] = load %[[C]][%[[I]], %[[J]]] : memref<?x?xvector<4x4xf32>>
+// LLVM-LOOPS:   %[[T6:.*]] = vector.outerproduct %3, %4, %5 : vector<4xf32>, vector<4xf32>
+// LLVM-LOOPS:   store %[[T6]], %[[C]][%[[I]], %[[J]]] : memref<?x?xvector<4x4xf32>>
 
 #indexed_matmul_trait = {
   args_in = 2,