[mlir] Bubble up tensor.extract_slice above linalg operation
authorOkwan Kwon <okkwon@gmail.com>
Fri, 18 Mar 2022 16:48:34 +0000 (16:48 +0000)
committerOkwan Kwon <okkwon@gmail.com>
Thu, 31 Mar 2022 16:48:38 +0000 (16:48 +0000)
Bubble up extract_slice above Linalg operation.

A sequence of operations

    %0 = linalg.<op> ... arg0, arg1, ...
    %1 = tensor.extract_slice %0 ...

can be replaced with

    %0 = tensor.extract_slice %arg0
    %1 = tensor.extract_slice %arg1
    %2 = linalg.<op> ... %0, %1, ...

This results in the reduce computation of the linalg operation.

The implementation uses the tiling utility functions. One difference
from the tiling process is that we don't need to insert the checking
code for the out-of-bound accesses. The use of the slice itself
represents that the code writer is sure about the boundary condition.
To avoid adding the boundary condtion check code, `omitPartialTileCheck`
is introduced for the tiling utility functions.

Differential Revision: https://reviews.llvm.org/D122437

13 files changed:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp [new file with mode: 0644]
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

index ab892fd1fcb858c6f3fd45d3ab17fed02a014486..5b6f99f74a38ef043d7f4c992c1d3cf764fade3f 100644 (file)
@@ -119,6 +119,9 @@ void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
 /// Patterns that are used to inline constant operands into linalg generic ops.
 void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
 
+/// Patterns that are used to bubble up extract slice op above linalg op.
+void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
+
 /// Options that control fusion of elementwise operations.
 struct LinalgElementwiseFusionOptions {
   /// Enable fusion of reshapes into the shape with elementwise operations. By
index 4f991588d1d43ea4541e9ab035ac2192ddc5a046..4f707dd85dc186e5274ac8bc0149dcbdb89b54d1 100644 (file)
@@ -166,16 +166,21 @@ SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
 
 /// Creates an extract_slice/subview op for a single `valueToTile` with
 /// `builder`. This new operation extracts a tile of `valueToTile`, starting
-/// at offsets `lbs` and with sizes `subShapeSizes`.
+/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
+/// controls whether to omit the partial/boundary tile condition check in cases
+/// where we statically know that it is unnecessary.
 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                      ValueRange tileSizes, AffineMap map, ValueRange lbs,
-                     ValueRange ubs, ValueRange subShapeSizes);
+                     ValueRange ubs, ValueRange subShapeSizes,
+                     bool omitPartialTileCheck);
 
 /// Creates extract_slice/subview ops for all `valuesToTile` of the given
 /// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop
 /// nest for tiling with the given induction variables `ivs` and tile sizes
 /// `tileSizes`. `sizeBounds` are the iteration space bounds for *all* the
-/// implicit loops in `linalgOp`.
+/// implicit loops in `linalgOp`. `omitPartialTileCheck` controls whether to
+/// omit the partial/boundary tile condition check in cases where we statically
+/// know that it is unnecessary.
 ///
 /// Note that a constant zero in `tileSizes` means no tiling at that implicit
 /// loop. The number of non-zero values in `tileSizes` should be equal to the
@@ -184,7 +189,8 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
                                       LinalgOp linalgOp,
                                       ArrayRef<Value> valuesToTile,
                                       ValueRange ivs, ValueRange tileSizes,
-                                      ArrayRef<Value> sizeBounds);
+                                      ArrayRef<Value> sizeBounds,
+                                      bool omitPartialTileCheck);
 
 /// Add the tile loop induction variables `ivs` to the IndexOp results found in
 /// the body of the `tiledOp` to account for the tile offset.
index 82a33b0654402521c63c330a8aa3d9ca1933d427..87ac693492113efb657def5f49e83f116a555b38 100644 (file)
@@ -480,7 +480,7 @@ AffineMap inversePermutation(AffineMap map);
 /// ```mlir
 ///    affine_map<(d0, d1) -> (d0, 0, 0)>
 /// ```
-AffineMap inverseAndBroadcastProjectedPermuation(AffineMap map);
+AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map);
 
 /// Concatenates a list of `maps` into a single AffineMap, stepping over
 /// potentially empty maps. Assumes each of the underlying map has 0 symbols.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
new file mode 100644 (file)
index 0000000..53200d8
--- /dev/null
@@ -0,0 +1,139 @@
+//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns that transforms linalg.<op> +
+// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
+// the computation for the linalg op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Bubble up extract_slice above Linalg operation.
+///
+/// A sequence of operations
+///
+/// ```mlir
+/// %0 = linalg.<op> ... arg0, arg1, ...
+/// %1 = tensor.extract_slice %0 ...
+/// ```
+///
+/// can be replaced with
+///
+/// ```mlir
+/// %0 = tensor.extract_slice %arg0
+/// %1 = tensor.extract_slice %arg1
+/// %2 = linalg.<op> ... %0, %1, ...
+/// ```
+///
+/// This results in the reduce computation of the linalg operation.
+///
+struct BubbleUpExtractSliceOpPattern
+    : OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const final {
+    Value source = sliceOp.source();
+    auto linalgOp = source.getDefiningOp<LinalgOp>();
+    if (!linalgOp) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "expected source to be linalg op");
+    }
+
+    // TODO: we might relax this if we want heuristics to detect that all uses
+    // are small portion of the output.
+    if (!linalgOp->hasOneUse()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "expected single use of linalg op");
+    }
+
+    if (linalgOp.getNumOutputs() != 1) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "expected single output of linalg op");
+    }
+
+    if (!linalgOp.hasTensorSemantics()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "expected tensor of linalg op");
+    }
+
+    if (!sliceOp.hasUnitStride())
+      return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
+
+    OpOperand *outOperand = linalgOp.getOutputOperand(0);
+    AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
+    if (!indexingMap.isProjectedPermutation()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "expected a projected permutation for output");
+    }
+
+    auto linalgLoc = linalgOp.getLoc();
+    auto allShapeSizes =
+        linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
+    AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
+    if (!shapeSizesToLoopsMap) {
+      return rewriter.notifyMatchFailure(
+          linalgOp, "failed to get loops map from shape sizes");
+    }
+    auto sizeBounds = applyMapToValues(rewriter, linalgLoc,
+                                       shapeSizesToLoopsMap, allShapeSizes);
+
+    auto sliceLoc = sliceOp.getLoc();
+    auto offsetVals = getValueOrCreateConstantIndexOp(
+        rewriter, sliceLoc, sliceOp.getMixedOffsets());
+    auto sizeVals = getValueOrCreateConstantIndexOp(rewriter, sliceLoc,
+                                                    sliceOp.getMixedSizes());
+
+    // The offsets and sizes from the slice operation only give you the tile
+    // size of the output. Use that compute the tile sizes and offsets of the
+    // loops. For loops not used to access the output, set the tile sizes to
+    // loop bounds and set the offset to 0.
+    Value zero = rewriter.create<arith::ConstantIndexOp>(linalgLoc, 0);
+    SmallVector<Value, 4> tileOffsets(sizeBounds.size(), zero);
+    SmallVector<Value, 4> tileSizes = sizeBounds;
+    for (auto const &result : enumerate(indexingMap.getResults())) {
+      unsigned position = result.value().cast<AffineDimExpr>().getPosition();
+      tileOffsets[position] = offsetVals[result.index()];
+      tileSizes[position] = sizeVals[result.index()];
+    }
+
+    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+
+    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+        rewriter, linalgLoc, linalgOp, valuesToTile, tileOffsets, tileSizes,
+        sizeBounds, /*omitPartialTileCheck=*/true);
+
+    SmallVector<Type, 4> resultTensorTypes;
+    for (OpOperand *opOperand : linalgOp.getOutputTensorOperands())
+      resultTensorTypes.push_back(
+          tiledOperands[opOperand->getOperandNumber()].getType());
+
+    Operation *newOp =
+        linalgOp.clone(rewriter, linalgLoc, resultTensorTypes, tiledOperands);
+    rewriter.replaceOp(sliceOp, newOp->getResults());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateBubbleUpExtractSliceOpPatterns(
+    RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<BubbleUpExtractSliceOpPattern>(context);
+}
index 77457dac3113e0c23ffdd6d467e573973c12ed72..19955d450f248271058b5748ed04b0d12a831283 100644 (file)
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLinalgTransforms
+  BubbleUpExtractSlice.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   CodegenStrategy.cpp
index 94edb8b630876b2a3b5ed0b277c06b07f177e9a4..066008c114525565fd94b0089f987f069c333fa2 100644 (file)
@@ -142,9 +142,9 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   clonedShapes.reserve(producer.getNumInputsAndOutputs());
 
   // Compute subranges for all tensor input/output operands.
-  clonedShapes.append(makeTiledShapes(b, loc, producer,
-                                      getTiledOperands(producer), ivs,
-                                      tileSizes, sizeBounds));
+  clonedShapes.append(makeTiledShapes(
+      b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds,
+      /**omitPartialTileCheck=*/false));
 
   // Iterate over the results in order.
   // Extract the subtensor type from the linearized range.
index 3ce6570f4bf5309dacac45c1415c926fb13000ab..9c962eb4b02c0bc55c7de0d97e2706eee796e2b8 100644 (file)
@@ -163,7 +163,8 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
   erase_value(tileIvs, nullptr);
   SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
   tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
-                                  tileSizes, producerLoopBounds);
+                                  tileSizes, producerLoopBounds,
+                                  /**omitPartialTileCheck=*/false);
 
   // Output fusion has to update the iteration arguments of the tile loop nest.
   // In particular, the iteration argument of the outermost tile loop needs to
index 4f863298ba422a824503a8f5c64b852bc95bac2e..b9a069c1d1e501fdf01aa5f91bbd1f29288c89e5 100644 (file)
@@ -178,8 +178,9 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
     SmallVector<Value> valuesToTile = operandValuesToUse;
     auto sizeBounds =
         applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
-    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
-        b, loc, op, valuesToTile, interchangedIvs, tileSizes, sizeBounds);
+    SmallVector<Value, 4> tiledOperands =
+        makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
+                        sizeBounds, /*omitPartialTileCheck=*/false);
 
     // TODO: use an interface/adaptor to avoid leaking position in
     // `tiledOperands`.
@@ -325,9 +326,9 @@ static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
         // Note: The tensor::PadOp is located outside of the loop nest. It is
         // later moved inside by ExtractSliceOfPadTensorSwapPattern.
         auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
-        Value tiledOutput =
-            makeTiledShape(b, loc, newPadOp->getResult(0), tileSizes, map,
-                           offsets, allDims, sizes);
+        Value tiledOutput = makeTiledShape(
+            b, loc, newPadOp->getResult(0), tileSizes, map, offsets, allDims,
+            sizes, /*omitPartialTileCheck=*/false);
         auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
         assert(sliceOp && "expected ExtractSliceOp");
         // Insert the tile into the output tensor.
index 4660baadf7e25da708f1d0509e107ae2a7ad00b5..845c6434994e73a635c639c3310ea2b438d6aaf8 100644 (file)
@@ -504,7 +504,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
     //   readType = VectorType::get({}, bbarg.getType());
     // } else {
     if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
-      map = inverseAndBroadcastProjectedPermuation(
+      map = inverseAndBroadcastProjectedPermutation(
           linalgOp.getTiedIndexingMap(opOperand));
       readType = VectorType::get(commonVectorShape,
                                  getElementTypeOrSelf(opOperand->get()));
index 0d2a61713a5ea79106f122b70d9e8ca042efc8b5..c4b73a61417a110e54053711178df1828a2dd081 100644 (file)
@@ -745,7 +745,8 @@ static Value fullyComposeAndAffineApply(OpBuilder &b, Location loc,
 
 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                      ValueRange tileSizes, AffineMap map, ValueRange lbs,
-                     ValueRange ubs, ValueRange subShapeSizes) {
+                     ValueRange ubs, ValueRange subShapeSizes,
+                     bool omitPartialTileCheck) {
   auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
   assert(shapedType && "only shaped types can be tiled");
   ArrayRef<int64_t> shape = shapedType.getShape();
@@ -773,7 +774,7 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
     auto m = map.getSubMap({r});
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: submap: " << m << "\n");
     auto offset = applyMapToValues(builder, loc, m, lbs).front();
-    offsets.push_back(offset);
+    offsets.push_back(getAsOpFoldResult(offset));
     auto closedIntSize =
         applyMapToValues(builder, loc, m, subShapeSizes).front();
     // Resulting size needs to be made half open interval again.
@@ -781,6 +782,17 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
     Value size =
         fullyComposeAndAffineApply(builder, loc, s0 + 1, closedIntSize);
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: raw size: " << size << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "makeTiledShape: new offset: " << offset << "\n");
+    strides.push_back(builder.getIndexAttr(1));
+
+    if (omitPartialTileCheck) {
+      // We statically know that the partial/boundary tile condition is
+      // unnecessary.
+      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
+      sizes.push_back(getAsOpFoldResult(size));
+      continue;
+    }
 
     // The size of the subview / extract_slice should be trimmed to avoid
     // out-of-bounds accesses, unless:
@@ -829,12 +841,8 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
       size = builder.create<AffineMinOp>(loc, builder.getIndexType(), minMap,
                                          operands);
     }
-
-    sizes.push_back(size);
-    LLVM_DEBUG(llvm::dbgs()
-               << "makeTiledShape: new offset: " << offset << "\n");
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
-    strides.push_back(builder.getIndexAttr(1));
+    sizes.push_back(getAsOpFoldResult(size));
   }
 
   auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
@@ -886,7 +894,8 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
                                       LinalgOp linalgOp,
                                       ArrayRef<Value> valuesToTile,
                                       ValueRange ivs, ValueRange tileSizes,
-                                      ArrayRef<Value> sizeBounds) {
+                                      ArrayRef<Value> sizeBounds,
+                                      bool omitPartialTileCheck) {
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
                            [](Value v) { return !isZero(v); })) &&
@@ -921,7 +930,8 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
 
     tiledShapes.push_back(makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs,
-                                         sizeBounds, subShapeSizes));
+                                         sizeBounds, subShapeSizes,
+                                         omitPartialTileCheck));
   }
 
   return tiledShapes;
index 32eb07cf1a2f48815193cb84bc86fa05dfc10313..c93b4c28d769a18f6925ac70d13de2d26c571ccc 100644 (file)
@@ -679,7 +679,7 @@ AffineMap mlir::inversePermutation(AffineMap map) {
   return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
 }
 
-AffineMap mlir::inverseAndBroadcastProjectedPermuation(AffineMap map) {
+AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) {
   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true));
   MLIRContext *context = map.getContext();
   AffineExpr zero = mlir::getAffineConstantExpr(0, context);
diff --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
new file mode 100644 (file)
index 0000000..234c2b8
--- /dev/null
@@ -0,0 +1,158 @@
+//RUN: mlir-opt -test-linalg-transform-patterns=test-bubble-up-extract-slice-op-pattern -split-input-file %s | FileCheck %s
+
+func @dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>, %arg2: index, %arg3: index, %arg4: index, %arg5:index) -> tensor<?x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%arg0 : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %add = arith.addf %b0, %b1 : f32
+      linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0 [%arg2, %arg3] [%arg4, %arg5] [1, 1]
+    : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+//      CHECK: func @dynamic
+//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg3] [%arg5] [1] : tensor<?xf32> to tensor<?xf32>
+//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[SLICE2]] : tensor<?x?xf32>)
+//      CHECK: return %[[GENERIC]] : tensor<?x?xf32>
+
+//-----
+
+func @static(%arg0: tensor<16x8xf32>, %arg1: tensor<8xf32>) -> tensor<4x2xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<16x8xf32>, tensor<8xf32>)
+    outs(%arg0 : tensor<16x8xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %add = arith.addf %b0, %b1 : f32
+      linalg.yield %add : f32
+  } -> tensor<16x8xf32>
+  %1 = tensor.extract_slice %0 [8, 4] [4, 2] [1, 1]
+    : tensor<16x8xf32> to tensor<4x2xf32>
+  return %1 : tensor<4x2xf32>
+}
+
+//      CHECK: func @static
+//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<16x8xf32> to tensor<4x2xf32>
+//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[4] [2] [1] : tensor<8xf32> to tensor<2xf32>
+//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<16x8xf32> to tensor<4x2xf32>
+//      CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<4x2xf32>)
+//      CHECK: return %[[GENERIC]] : tensor<4x2xf32>
+
+//-----
+
+func @mixed(%arg0: tensor<?x8xf32>, %arg1: tensor<8xf32>, %arg2: index, %arg3: index) -> tensor<?x2xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x8xf32>, tensor<8xf32>)
+    outs(%arg0 : tensor<?x8xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %add = arith.addf %b0, %b1 : f32
+      linalg.yield %add : f32
+  } -> tensor<?x8xf32>
+  %1 = tensor.extract_slice %0 [8, %arg2] [%arg3, 2] [1, 1]
+    : tensor<?x8xf32> to tensor<?x2xf32>
+  return %1 : tensor<?x2xf32>
+}
+
+//      CHECK: func @mixed
+//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, %arg2] [%arg3, 2] [1, 1] : tensor<?x8xf32> to tensor<?x2xf32>
+//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg2] [2] [1] : tensor<8xf32> to tensor<2xf32>
+//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, %arg2] [%arg3, 2] [1, 1] : tensor<?x8xf32> to tensor<?x2xf32>
+//      CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<?x2xf32>)
+//      CHECK: return %[[GENERIC]] : tensor<?x2xf32>
+
+//-----
+
+func @dynamic_to_static(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<4x2xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%arg0 : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %add = arith.addf %b0, %b1 : f32
+      linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0 [8, 4] [4, 2] [1, 1]
+    : tensor<?x?xf32> to tensor<4x2xf32>
+  return %1 : tensor<4x2xf32>
+}
+
+//      CHECK: func @dynamic_to_static
+//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<?x?xf32> to tensor<4x2xf32>
+//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[4] [2] [1] : tensor<?xf32> to tensor<2xf32>
+//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<?x?xf32> to tensor<4x2xf32>
+//      CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<4x2xf32>)
+//      CHECK: return %[[GENERIC]] : tensor<4x2xf32>
+
+//-----
+
+func @matmul_slice() -> tensor<2x2xf32> {
+    %lhs = arith.constant dense<1.0> : tensor<4x4xf32>
+    %rhs = arith.constant dense<1.0> : tensor<4x4xf32>
+    %dst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0], [8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0]]> : tensor<4x4xf32>
+    %0 = linalg.matmul ins(%lhs, %rhs : tensor<4x4xf32>, tensor<4x4xf32>) outs(%dst : tensor<4x4xf32>) -> tensor<4x4xf32>
+    %1 = tensor.extract_slice %0[1,1][2,2][1,1] : tensor<4x4xf32> to tensor<2x2xf32>
+    return %1 : tensor<2x2xf32>
+}
+
+// CHECK: func @matmul_slice
+// CHECK: %[[SLICE0:.+]] = arith.constant dense<1.000000e+00> : tensor<2x4xf32>
+// CHECK: %[[SLICE1:.+]] = arith.constant dense<1.000000e+00> : tensor<4x2xf32>
+// CHECK: %[[SLICE3:.+]] = tensor.extract_slice %[[CST:.+]][1, 1] [2, 2] [1, 1] : tensor<4x4xf32> to tensor<2x2xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[SLICE0]], %[[SLICE1]] : tensor<2x4xf32>, tensor<4x2xf32>) outs(%[[SLICE3]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+// CHECK: return %[[MATMUL]] : tensor<2x2xf32>
+
+//-----
+
+func @conv_slice(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>) -> tensor<1x32x32x16xf32> {
+  %c112 = arith.constant 112 : index
+  %c32 = arith.constant 32 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f32
+
+  %init = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
+
+  %conv = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+    ins(%input, %filter : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>)
+    outs(%fill : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
+
+  %slice = tensor.extract_slice %conv [0, 64, 64, 16] [1, 32, 32, 16] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x32x32x16xf32>
+
+  return %slice : tensor<1x32x32x16xf32>
+}
+
+// CHECK: func @conv_slice
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[0, 128, 128, 0] [1, 65, 65, 3] [1, 1, 1, 1] : tensor<1x225x225x3xf32> to tensor<1x65x65x3xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[0, 0, 0, 16] [3, 3, 3, 16] [1, 1, 1, 1] : tensor<3x3x3x32xf32> to tensor<3x3x3x16xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[INIT]][0, 64, 64, 16] [1, 32, 32, 16] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x32x32x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST:.+]] : f32) outs(%[[SLICE2]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[SLICE0]], %[[SLICE1]] : tensor<1x65x65x3xf32>, tensor<3x3x3x16xf32>) outs(%[[FILL]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
+// CHECK: return %[[CONV]] : tensor<1x32x32x16xf32>
index e6786b24f59399208f64dc792c35cc852d8f8d8c..06d1d81a3859289f942142a238a275aa9ee9ffe8 100644 (file)
@@ -132,6 +132,11 @@ struct TestLinalgTransforms
       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                      "tiled_loop"),
       llvm::cl::init("for")};
+  Option<bool> testBubbleUpExtractSliceOpPattern{
+      *this, "test-bubble-up-extract-slice-op-pattern",
+      llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
+                     "extract_slice + linalgOp"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -635,6 +640,12 @@ static void applySplitReduction(FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyBubbleUpExtractSliceOpPattern(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateBubbleUpExtractSliceOpPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   auto lambda = [&](void *) {
@@ -686,6 +697,8 @@ void TestLinalgTransforms::runOnOperation() {
                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
   if (testSplitReduction)
     return applySplitReduction(getOperation());
+  if (testBubbleUpExtractSliceOpPattern)
+    return applyBubbleUpExtractSliceOpPattern(getOperation());
 }
 
 namespace mlir {