[mlir][Linalg] Relax PadTensor tiling constraints and expose it to strategies.
authorNicolas Vasilache <ntv@google.com>
Mon, 17 Jan 2022 17:07:46 +0000 (17:07 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 17 Jan 2022 17:13:55 +0000 (17:13 +0000)
Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D117334

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir

index cbf0304..8e7ea21 100644 (file)
@@ -46,6 +46,9 @@ bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
 //===----------------------------------------------------------------------===//
 using LinalgLoops = SmallVector<Operation *, 4>;
 
+void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
+                                     const LinalgTilingOptions &options);
+
 /// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
 /// were decomposed previously.
index 859f3f8..2a052c0 100644 (file)
@@ -100,6 +100,8 @@ struct LinalgStrategyTilePass
                                              filter);
     else
       tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
+    if (anchorOpName == linalg::PadTensorOp::getOperationName())
+      populatePadTensorTilingPatterns(tilingPattern, options);
     (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
   }
 
index 89ca833..36bd434 100644 (file)
@@ -354,7 +354,9 @@ static LogicalResult tilePadTensorOp(RewriterBase &builder, PadTensorOp op,
   int64_t rank = op.getResultType().getRank();
   SmallVector<Value> tileSizes =
       options.tileSizeComputationFunction(builder, op);
-  assert(static_cast<int64_t>(tileSizes.size()) == rank);
+  // Normalize untiled padding dimensions to 0.
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  tileSizes.append(rank - tileSizes.size(), zero);
   // Compute lower and upper bounds of the loop nest.
   SmallVector<Range> ranges = op.getIterationDomain(builder);
   SmallVector<Value> lbs, dims, allDims, steps;
@@ -490,6 +492,12 @@ static void insertTilingPatterns(RewritePatternSet &patterns,
   patterns.add<PadTensorOpTilingPattern>(ctx, options);
 }
 
+void mlir::linalg::populatePadTensorTilingPatterns(
+    RewritePatternSet &patterns, const LinalgTilingOptions &options) {
+  auto *ctx = patterns.getContext();
+  patterns.add<PadTensorOpTilingPattern>(ctx, options);
+}
+
 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
   MLIRContext *ctx = funcOp.getContext();
   RewritePatternSet patterns(ctx);
index 46fe369..a837935 100644 (file)
@@ -2,6 +2,8 @@
 // RUN: FileCheck %s -check-prefix=TILE2
 // RUN: mlir-opt %s -linalg-tile="tile-sizes=0,3" -resolve-shaped-type-result-dims -cse -split-input-file | \
 // RUN: FileCheck %s -check-prefix=TILE1
+// This test only checks that tiling does not crash.
+// RUN: mlir-opt %s -linalg-tile="tile-sizes=2" -resolve-shaped-type-result-dims -cse -split-input-file
 
 //  TILE2-DAG:  #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)>
 //  TILE2-DAG:  #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 7)>