[mlir][linalg] lower index operations during linalg to loop lowering.
authorTobias Gysi <gysit@google.com>
Tue, 13 Apr 2021 08:37:40 +0000 (08:37 +0000)
committerTobias Gysi <gysit@google.com>
Tue, 13 Apr 2021 09:04:09 +0000 (09:04 +0000)
The patch extends the linalg to loop lowering pass to replace all linalg index operations by the induction variables of the generated loop nests.

Differential Revision: https://reviews.llvm.org/D100364

mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/test/Dialect/Linalg/loop-order.mlir
mlir/test/Dialect/Linalg/loops.mlir

index 8b2b0cd..bb1000a 100644 (file)
@@ -516,6 +516,47 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
   return loops;
 }
 
+/// Replace the index operations in the body of the loop nest by the matching
+/// induction variables. If available use the interchange vector to map the
+/// interchanged induction variables to the dimension of the index operation.
+static void replaceIndexOpsByInductionVariables(
+    LinalgOp linalgOp, PatternRewriter &rewriter, ArrayRef<Operation *> loopOps,
+    ArrayRef<unsigned> interchangeVector) {
+  // Extract the induction variables of the loop nest from outer to inner.
+  SmallVector<Value> allIvs;
+  for (Operation *loopOp : loopOps) {
+    llvm::TypeSwitch<Operation *>(loopOp)
+        .Case([&](scf::ParallelOp parallelOp) {
+          allIvs.append(parallelOp.getInductionVars().begin(),
+                        parallelOp.getInductionVars().end());
+        })
+        .Case([&](scf::ForOp forOp) {
+          allIvs.push_back(forOp.getInductionVar());
+        })
+        .Case([&](AffineForOp affineForOp) {
+          allIvs.push_back(affineForOp.getInductionVar());
+        })
+        .Default([&](Operation *op) { assert(false && "unexpected op"); });
+  }
+  assert(linalgOp.getNumLoops() == allIvs.size() &&
+         "expected the number of loops and induction variables to match");
+  // Replace the index operations in the body of the innermost loop op.
+  if (!loopOps.empty()) {
+    LoopLikeOpInterface loopOp = loopOps.back();
+    for (IndexOp indexOp :
+         llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>())) {
+      // Search the indexing dimension in the interchange vector if available.
+      assert(interchangeVector.empty() ||
+             interchangeVector.size() == linalgOp.getNumLoops());
+      const auto *it = llvm::find(interchangeVector, indexOp.dim());
+      uint64_t dim = it != interchangeVector.end()
+                         ? std::distance(interchangeVector.begin(), it)
+                         : indexOp.dim();
+      rewriter.replaceOp(indexOp, allIvs[dim]);
+    }
+  }
+}
+
 namespace {
 template <typename LoopType>
 class LinalgRewritePattern : public RewritePattern {
@@ -528,11 +569,14 @@ public:
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     auto linalgOp = dyn_cast<LinalgOp>(op);
-    // TODO: remove hasIndexSemantics check once index ops are supported.
-    if (!linalgOp || linalgOp.hasIndexSemantics())
+    if (!isa<LinalgOp>(op))
       return failure();
-    if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector))
+    Optional<LinalgLoops> loopOps =
+        linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector);
+    if (!loopOps.hasValue())
       return failure();
+    replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue(),
+                                        interchangeVector);
     rewriter.eraseOp(op);
     return success();
   }
index 968ffdc..c572967 100644 (file)
@@ -24,22 +24,49 @@ func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) {
 
 // -----
 
-func @index_op(%arg0: memref<4x8xindex>) {
-  linalg.generic {
-    indexing_maps = [affine_map<(i, j) -> (i, j)>],
-    iterator_types = ["parallel", "parallel"]}
-  outs(%arg0 : memref<4x8xindex>) {
-  ^bb0(%arg1: index):   // no predecessors
-    %0 = linalg.index 1 : index
-    linalg.yield %0 : index
+#map = affine_map<(i, j, k, l, m) -> (i, j, k, l, m)>
+func @generic(%output: memref<1x2x3x4x5xindex>) {
+  linalg.generic {indexing_maps = [#map],
+                  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+    outs(%output : memref<1x2x3x4x5xindex>) {
+    ^bb0(%arg0 : index):
+    %i = linalg.index 0 : index
+    %j = linalg.index 1 : index
+    %k = linalg.index 2 : index
+    %l = linalg.index 3 : index
+    %m = linalg.index 4 : index
+    %0 = addi %i, %j : index
+    %1 = addi %0, %k : index
+    %2 = addi %1, %l : index
+    %3 = addi %2, %m : index
+    linalg.yield %3: index
   }
   return
 }
-// LOOP-LABEL: @index_op
-//      LOOP:   linalg.generic
 
-// PARALLEL-LABEL: @index_op
-//      PARALLEL:   linalg.generic
+// LOOP: scf.for %[[m:.*]] = %c0 to %c5 step %c1
+// LOOP:   scf.for %[[i:.*]] = %c0 to %c1 step %c1
+// LOOP:     scf.for %[[l:.*]] = %c0 to %c4 step %c1
+// LOOP:       scf.for %[[j:.*]] = %c0 to %c2 step %c1
+// LOOP:         scf.for %[[k:.*]] = %c0 to %c3 step %c1
+// LOOP:           %{{.*}} = addi %[[i]], %[[j]] : index
+// LOOP:           %{{.*}} = addi %{{.*}}, %[[k]] : index
+// LOOP:           %{{.*}} = addi %{{.*}}, %[[l]] : index
+// LOOP:           %{{.*}} = addi %{{.*}}, %[[m]] : index
 
-// AFFINE-LABEL: @index_op
-//      AFFINE:   linalg.generic
+// PARALLEL:                   scf.parallel (%[[m:.*]], %[[i:.*]], %[[l:.*]], %[[j:.*]], %[[k:.*]]) =
+// PARALLEL-SAME:   to (%c5, %c1, %c4, %c2, %c3)
+// PARALLEL:        %{{.*}} = addi %[[i]], %[[j]] : index
+// PARALLEL:        %{{.*}} = addi %{{.*}}, %[[k]] : index
+// PARALLEL:        %{{.*}} = addi %{{.*}}, %[[l]] : index
+// PARALLEL:        %{{.*}} = addi %{{.*}}, %[[m]] : index
+
+// AFFINE: affine.for %[[m:.*]] = 0 to 5
+// AFFINE:   affine.for %[[i:.*]] = 0 to 1
+// AFFINE:     affine.for %[[l:.*]] = 0 to 4
+// AFFINE:       affine.for %[[j:.*]] = 0 to 2
+// AFFINE:         affine.for %[[k:.*]] = 0 to 3
+// AFFINE:           %{{.*}} = addi %[[i]], %[[j]] : index
+// AFFINE:           %{{.*}} = addi %{{.*}}, %[[k]] : index
+// AFFINE:           %{{.*}} = addi %{{.*}}, %[[l]] : index
+// AFFINE:           %{{.*}} = addi %{{.*}}, %[[m]] : index
index 5be3525..b0ffc7f 100644 (file)
@@ -880,6 +880,61 @@ func @generic_region(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1:
   library_call = "some_external_function_name_2",
   doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"
 }
+func @generic_index_region(
+        %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+        %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
+        %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+  linalg.generic #trait4
+      ins(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
+     outs(%arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
+                         memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      %i = linalg.index 0 : index
+      %j = linalg.index 1 : index
+      %k = linalg.index 2 : index
+      %result_1 = mulf %a, %b : f32
+
+      %ij = addi %i, %j : index
+      %ijk = addi %ij, %k : index
+      %ijk_int = index_cast %ijk : index to i32
+      %ijk_float = sitofp %ijk_int : i32 to f32
+
+      %result_2 = addf %c, %ijk_float : f32
+      linalg.yield %result_1, %result_2 : f32, f32
+  }
+  return
+}
+
+// CHECKLOOP-LABEL: @generic_index_region
+//       CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
+//       CHECKLOOP:   scf.for %[[j:.*]] = {{.*}}
+//       CHECKLOOP:     scf.for %[[k:.*]] = {{.*}}
+//       CHECKLOOP:       %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]]
+//       CHECKLOOP:       %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]]
+//       CHECKLOOP:       %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]]
+//       CHECKLOOP:       %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32
+//       CHECKLOOP:       %[[ij:.*]] = addi %[[i]], %[[j]] : index
+//       CHECKLOOP:       %[[ijk:.*]] = addi %[[ij]], %[[k]] : index
+//       CHECKLOOP:       %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32
+//       CHECKLOOP:       %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32
+//       CHECKLOOP:       %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
+//       CHECKLOOP:       store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
+//       CHECKLOOP:       store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
+
+// CHECKPARALLEL-LABEL: @generic_index_region
+//       CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]])
+//       CHECKPARALLEL:   %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]]
+//       CHECKPARALLEL:   %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]]
+//       CHECKPARALLEL:   %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]]
+//       CHECKPARALLEL:   %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32
+//       CHECKPARALLEL:   %[[ij:.*]] = addi %[[i]], %[[j]] : index
+//       CHECKPARALLEL:   %[[ijk:.*]] = addi %[[ij]], %[[k]] : index
+//       CHECKPARALLEL:   %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32
+//       CHECKPARALLEL:   %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32
+//       CHECKPARALLEL:   %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
+//       CHECKPARALLEL:   store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
+//       CHECKPARALLEL:   store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
+
 func @indexed_generic_region(
         %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
         %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
@@ -973,6 +1028,43 @@ func @generic_op_zero_rank(%arg0: memref<f32>, %arg1: memref<3x4xf32>)
 //       CHECKPARALLEL:   %[[a:.*]] = memref.load %[[ARG0]][]
 //       CHECKPARALLEL:   store %[[a]], %[[ARG1]][%[[i]], %[[j]]]
 
+func @generic_index_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
+{
+  linalg.generic #trait_broadcast
+      ins(%arg0 : memref<i32>)
+     outs(%arg1 : memref<3x4xi32>) {
+    ^bb(%a: i32, %b: i32) :
+      %i = linalg.index 0 : index
+      %j = linalg.index 1 : index
+      %ij = addi %i, %j : index
+      %ij_int = index_cast %ij : index to i32
+      %result = addi %a, %ij_int : i32
+      linalg.yield %result : i32
+  }
+  return
+}
+
+// CHECKLOOP-LABEL: @generic_index_op_zero_rank
+//  CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
+//  CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
+//       CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
+//       CHECKLOOP:   scf.for %[[j:.*]] = {{.*}}
+//       CHECKLOOP:     %[[a:.*]] = memref.load %[[ARG0]][
+//       CHECKLOOP:     %[[ij:.*]] = addi %[[i]], %[[j]] : index
+//       CHECKLOOP:     %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
+//       CHECKLOOP:     %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
+//       CHECKLOOP:     store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
+
+// CHECKPARALLEL-LABEL: @generic_index_op_zero_rank
+//  CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
+//  CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
+//       CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]])
+//       CHECKPARALLEL:   %[[a:.*]] = memref.load %[[ARG0]][
+//       CHECKPARALLEL:   %[[ij:.*]] = addi %[[i]], %[[j]] : index
+//       CHECKPARALLEL:   %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
+//       CHECKPARALLEL:   %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
+//       CHECKPARALLEL:   store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
+
 func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
 {
   linalg.indexed_generic #trait_broadcast
@@ -1065,6 +1157,47 @@ func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>)
   library_call = "some_reduce_external_fn"
 }
 
+func @generic_index_op_1D_reduce(%arg0: memref<?xf32>,
+                                %arg1: memref<f32>,
+                                %arg2: memref<f32>)
+{
+  linalg.generic #trait_reduce_init_1D
+      ins(%arg0, %arg1 : memref<?xf32>, memref<f32>)
+     outs(%arg2 : memref<f32>) {
+    ^bb(%a: f32, %b: f32, %c: f32) :
+      %i = linalg.index 0 : index
+      %0 = constant 0 : index
+      %1 = cmpi eq, %0, %i : index
+      %2 = select %1, %b, %c : f32
+      %3 = addf %a, %2 : f32
+      linalg.yield %3 : f32
+  }
+  return
+}
+// CHECKLOOP-LABEL: @generic_index_op_1D_reduce
+//  CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+//  CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+//  CHECKLOOP-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+//       CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
+//       CHECKLOOP:   %[[a:.*]] = memref.load %[[ARG0]][%[[i]]]
+//       CHECKLOOP:   %[[b:.*]] = memref.load %[[ARG1]][]
+//       CHECKLOOP:   %[[c:.*]] = memref.load %[[ARG2]][]
+//       CHECKLOOP:   %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
+//       CHECKLOOP:   %[[e:.*]] = addf %[[a]], %[[d]]
+//       CHECKLOOP:   store %[[e]], %[[ARG2]][]
+
+// CHECKPARALLEL-LABEL: @generic_index_op_1D_reduce
+//  CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+//  CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+//  CHECKPARALLEL-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+//       CHECKPARALLEL: scf.for %[[i:.*]] = {{.*}}
+//       CHECKPARALLEL:   %[[a:.*]] = memref.load %[[ARG0]][%[[i]]]
+//       CHECKPARALLEL:   %[[b:.*]] = memref.load %[[ARG1]][]
+//       CHECKPARALLEL:   %[[c:.*]] = memref.load %[[ARG2]][]
+//       CHECKPARALLEL:   %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
+//       CHECKPARALLEL:   %[[e:.*]] = addf %[[a]], %[[d]]
+//       CHECKPARALLEL:   store %[[e]], %[[ARG2]][]
+
 func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
                                    %arg1: memref<f32>,
                                    %arg2: memref<f32>)