[MLIR] Fold outer dims permutation to pack when propagating
authorLorenzo Chelini <l.chelini@icloud.com>
Sun, 8 Jan 2023 13:04:18 +0000 (14:04 +0100)
committerLorenzo Chelini <l.chelini@icloud.com>
Fri, 13 Jan 2023 15:11:44 +0000 (16:11 +0100)
Instead of folding the transpose into the linalg.generic keep the
transposition in the packing operation, effectively making the
linalg.generic transparent to the propagation. Additionally, if the init
operand of the generic has users pack the init and pass it as the
operand to the generic.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D141483

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/test/Dialect/Linalg/data-layout-propagation.mlir

index 5e54097..5660704 100644 (file)
@@ -87,11 +87,39 @@ static PackInfo getPackingInfoFromConsumer(
   return packInfo;
 }
 
+static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
+                                             ArrayRef<AffineExpr> exprs) {
+  // Compute `outer_dims_perm`. See example:
+  // current exprs      : (d0, d1, d2, d3) -> (d2, d3)
+  // perm               : [0, 3, 1, 2]
+  // First map d2, d3 with their position in the array as:
+  // currentPositionTileLoops: dim | pos
+  //                           d2  | 0
+  //                           d3  | 1
+  // then scan `perm` in order and get the `outer_dims_perm`
+  // to be used, here it would be [1, 0].
+  assert(!perm.empty() && "expect perm not to be empty");
+  assert(!exprs.empty() && "expect exprs not to be empty");
+  if (exprs.size() == 1)
+    return {};
+  SmallVector<int64_t> outerDimsPerm;
+  DenseMap<int64_t, int64_t> currentPositionTileLoops;
+  for (auto [pos, expr] : llvm::enumerate(exprs)) {
+    unsigned posInDomain = expr.cast<AffineDimExpr>().getPosition();
+    currentPositionTileLoops[posInDomain] = pos;
+  }
+  for (int64_t loopIdx : perm) {
+    if (currentPositionTileLoops.count(loopIdx))
+      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
+  }
+  return outerDimsPerm;
+}
+
 /// Returns a tuple for packed operand and indexing_map with the assumptions:
 ///   1) The generic op is the producer of the pack op.
 ///   2) The generic op has only one result.
 /// If the operand is a scalar or packing dimensions are all irrelevant to the
-/// operand, the opreand and the updated indexing map will be returned.
+/// operand, the operand and the updated indexing map will be returned.
 /// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
 ///
 ///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
@@ -148,16 +176,26 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
     exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
   }
 
-  // Step 2. Fold transpose variants (i.e., outerDimsPerm) into generic op.
-  // TODO: should we propagate the permutation of outer dims to the pack op?
+  // Step 2. Handle outer dim permutations.
   SmallVector<int64_t> outerDimsPerm;
   if (!packInfo.outerDimsOnDomainPerm.empty()) {
+    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
+
+    // Step 2.1: Fold transpose into the linalg.generic.
     SmallVector<int64_t> inversedOuterPerm =
         invertPermutationVector(packInfo.outerDimsOnDomainPerm);
     for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
       int64_t dimPos = exprs[i].cast<AffineDimExpr>().getPosition();
       exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
     }
+    // Step 2.2: Undo the transposition on `exprs` and propagate the
+    // transposition on the pack using outerDimsPerm.
+    if (!outerDimsPerm.empty()) {
+      SmallVector<AffineExpr> auxVec = exprs;
+      for (const auto &en : enumerate(outerDimsPerm))
+        auxVec[en.index()] = exprs[en.value()];
+      exprs = auxVec;
+    }
   }
   auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
 
@@ -254,9 +292,7 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
     indexingMaps.push_back(packedIndexingMap);
   }
 
-  int64_t numLoops = genericOp.getNumLoops();
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
-  int64_t newNumLoops = numLoops + numInnerLoops;
   SmallVector<utils::IteratorType> iterTypes =
       genericOp.getIteratorTypesArray();
   iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
@@ -265,24 +301,18 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
   auto [packedOutOperand, packedOutIndexingMap] =
       getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp,
                                      opOperand);
-  SmallVector<AffineExpr> outExprs(
-      packedOutIndexingMap.getResults().drop_back(numInnerLoops));
-  // Apply transpose to the indexing map, because we'll replace the init operand
-  // with the destination of pack op.
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector<AffineExpr>(outExprs, outerDimsPerm);
-  }
-  for (int i = 0; i < numInnerLoops; ++i)
-    outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i));
-  AffineMap outMap =
-      AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext());
-  indexingMaps.push_back(outMap);
+  indexingMaps.push_back(packedOutIndexingMap);
 
+  // We'll replace the init operand with the destination of pack op if the init
+  // operand has not users in the body of the linalg.generic (pure elementwise).
+  // If it has users we need to pack the init operand too and replace the init
+  // with the packing result.
+  Value dest = (genericOp.getRegionOutputArgs()[0].use_empty())
+                   ? packOp.getDest()
+                   : packedOutOperand;
   auto newGenericOp = rewriter.create<linalg::GenericOp>(
-      loc, packOp.getDestType(), inputOperands, packOp.getDest(), indexingMaps,
-      iterTypes, /*bodyBuild=*/nullptr,
-      linalg::getPrunedAttributeList(genericOp));
+      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
   rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                              newGenericOp.getRegion().begin());
   return newGenericOp;
index bb84272..cd9d3ac 100644 (file)
@@ -96,17 +96,16 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: ten
     into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
   return %pack : tensor<16x4x32x16xi32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK:      func.func @elem_pack_transpose_outer_dims
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16xi32>
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
 // CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:     into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<4x16x32x16xi32>
+// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:     into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
 // CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
 // CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[PACK_ARG0]]
 // CHECK-SAME:     outs(%[[DEST]]
@@ -131,17 +130,16 @@ func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>,
     into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
   return %pack : tensor<16x4x16x32xi32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK:      func.func @elem_pack_transpose_inner_and_outer_dims
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
 // CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
 // CHECK-SAME:     into %[[ARG0_EMPTY]]
 // CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
 // CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[PACK_ARG0]]
 // CHECK-SAME:     outs(%[[DEST]]
@@ -285,7 +283,7 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1)>
-func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32>
+func.func @transpose_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32>
 {
   %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
   %transpose = linalg.generic {
@@ -308,3 +306,61 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x
     into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>
   return %4 : tensor<200x4x16x100x16x32xi32>
 }
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
+// CHECK:     func.func @transpose_pack_with_outer_dims
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<200x4x16x100x16x32xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:  outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32]
+// CHECK-SAME:  into %[[ARG0_EMPTY]]
+// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
+// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:  inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG2_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP]]]
+// CHECK-SAME:  ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
+// CHECK-SAME:  outs(%[[DEST]]
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg4 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %empty = tensor.empty() : tensor<16x4x32x16xi32>
+  %pack = tensor.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [32, 16]
+    into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+  return %pack : tensor<16x4x32x16xi32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG0_EMPTY]]
+// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG1_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME:  ins(%[[PACKED_ARG0]]
+// CHECK-SAME:  outs(%[[PACKED_ARG1]]