[mlir][linalg] Add control to pad-slice swap pattern
authorLei Zhang <antiagainst@google.com>
Wed, 16 Feb 2022 15:28:51 +0000 (10:28 -0500)
committerLei Zhang <antiagainst@google.com>
Wed, 16 Feb 2022 16:19:35 +0000 (11:19 -0500)
The pad-slice swap pattern generates `scf.if` and `tensor.generate`
to guard against zero-sized slices if it cannot prove the slice is
always non-zero. This is safe but quite conservative. It can be
unnecessary for cases where we know by problem definition such cases
does not exist, even if with dynamic shaped ops or unknown tile/slice
sizes, e.g., convolution padding size = 1 with kernel dim size = 3.

So this commit introduces a control to the pattern to specify
whether to generate the if constructs to handle such cases better,
given that once the if constructs is materialized, it's very hard
to analyze and simplify.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D117017

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 24230a3..80ec20a 100644 (file)
@@ -1399,10 +1399,27 @@ LogicalResult applyStagedPatterns(
 /// Rewrite extract_slice(pad_tensor(x)) into pad_tensor(extract_slice(x)).
 struct ExtractSliceOfPadTensorSwapPattern
     : public OpRewritePattern<tensor::ExtractSliceOp> {
-  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+  /// A function to control pattern application and rewrite logic.
+  ///
+  /// The function will be given the slice op and should return:
+  /// -  None: to fail the match and not apply the pattern;
+  /// -  true: to apply the pattern with zero slice guard;
+  /// - false: to apply the pattern without zero slice guard.
+  ///
+  /// See the documentation for tensor::bubbleUpPadSlice regarding zero slice
+  /// guard.
+  using ControlFn = std::function<llvm::Optional<bool>(tensor::ExtractSliceOp)>;
+
+  ExtractSliceOfPadTensorSwapPattern(MLIRContext *context,
+                                     ControlFn controlFn = nullptr,
+                                     PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
 
   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                 PatternRewriter &rewriter) const override;
+
+private:
+  ControlFn controlFn;
 };
 
 //===----------------------------------------------------------------------===//
index 6cd8197..4bdc774 100644 (file)
 namespace mlir {
 namespace tensor {
 
+class PadOp;
+
+/// Bubbles up a slice of this pad by taking the slice first and then performing
+/// the padding. `offsets` and `strides` specifies each dimension's start offset
+/// and size for the slice. The slice has unit strides along all dimensions.
+///
+/// Specifically, this function converts:
+/// ```
+/// %0 = tensor.pad %source low[...] high[...] { linalg.yield %cst }
+/// %1 = <extract-slice> %0 offsets=[...], sizes[...]
+/// ```
+/// into
+/// ```
+/// %0 = tensor.extract_slice %source ...
+/// %0 = tensor.pad %0 low[...] high[...] { linalg.yield %cst }
+/// ```
+///
+/// If `generateZeroSliceGuard` is true, the generated IR will contain logic
+/// to guard against the case that we might take a zero-sized slice from the
+/// original source. For such cases, we `tensor.generate` to generate the
+/// full tensor.
+Operation *bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
+                            ArrayRef<OpFoldResult> offsets,
+                            ArrayRef<OpFoldResult> sizes,
+                            bool generateZeroSliceGuard = true);
+
 /// Registers external models for Tiling interface for tensor ops.
 /// Currently, it registers:
 ///
index 4493cd4..6897cb9 100644 (file)
@@ -54,6 +54,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRStandardOpsTransforms
   MLIRStandardToLLVM
   MLIRTensor
+  MLIRTensorTilingInterfaceImpl
   MLIRTensorTransforms
   MLIRTransforms
   MLIRTransformUtils
index 65a0ed8..486e8b6 100644 (file)
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -911,23 +912,26 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
 
 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
+  if (!sliceOp.hasUnitStride())
+    return failure();
+
   auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
   if (!padOp)
     return failure();
-  // Only unit stride supported.
-  if (!sliceOp.hasUnitStride())
-    return failure();
 
-  TilingInterface tilingInterface =
-      dyn_cast<TilingInterface>(padOp.getOperation());
+  bool zeroSliceGuard = true;
+  if (controlFn) {
+    if (Optional<bool> control = controlFn(sliceOp))
+      zeroSliceGuard = control.getValue();
+    else
+      return failure();
+  }
+
   Operation *tiledPadOp =
-      tilingInterface
-          .getTiledImplementation(
-              rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
-              sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
-          .front();
+      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
+                               sliceOp.getMixedSizes(), zeroSliceGuard);
   // All shapes are static and the data source is actually used. Rewrite into
-  // pad_tensor(subtensor(x)).
+  // pad(extract_slice(x)).
   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
   return success();
 }
index 8a0efe9..5ecdea1 100644 (file)
@@ -63,215 +63,223 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes,
                          bool /*tileDestOperands*/) const {
-    auto padOp = cast<PadOp>(op);
-    // Only constant padding value supported.
-    Value padValue = padOp.getConstantPaddingValue();
-    if (!padValue)
+    Operation *result =
+        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
+    if (!result)
       return {};
+    return {result};
+  }
+};
 
-    // Helper variables and functions for various arithmetic operations. These
-    // are used extensively for computing new offset/length and padding values.
-    Location loc = op->getLoc();
-    AffineExpr dim0, dim1;
-    bindDims(b.getContext(), dim0, dim1);
-    // Add two integers.
-    auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
-    auto add = [&](Value v1, Value v2) {
-      return b.createOrFold<AffineApplyOp>(loc, addMap, ValueRange{v1, v2});
-    };
-    // Subtract two integers.
-    auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
-    auto sub = [&](Value v1, Value v2) {
-      return b.createOrFold<AffineApplyOp>(loc, subMap, ValueRange{v1, v2});
-    };
-    // Take the minimum of two integers.
-    auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
-    auto min = [&](Value v1, Value v2) {
-      return b.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
-    };
-    // Take the maximum of two integers.
-    auto max = [&](Value v1, Value v2) {
-      return b.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
-    };
-    // Zero index-typed integer.
-    auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
+} // namespace
 
-    // Helper function for filling static/dynamic low/high padding indices
-    // vectors of PadOp.
-    auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
-                           SmallVector<int64_t> &staticIndices) {
-      if (auto constInt = getConstantIntValue(val)) {
-        staticIndices.push_back(*constInt);
-      } else {
-        staticIndices.push_back(ShapedType::kDynamicSize);
-        dynIndices.push_back(val);
-      }
-    };
+Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    bool generateZeroSliceGuard) {
+  // Only constant padding value supported.
+  Value padValue = padOp.getConstantPaddingValue();
+  if (!padValue)
+    return nullptr;
 
-    // Compute new offsets, lengths, low padding, high padding.
-    SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
-    SmallVector<Value> newLows, newHighs;
-    SmallVector<int64_t> staticNewLows, staticNewHighs;
-    // Set to true if the original data source is not read at all.
-    bool hasZeroLen = false;
-    // Same as hasZeroLen, but for dynamic dimension sizes. This condition
-    // is true if the original data source turns out to be unused at runtime.
-    Value dynHasZeroLenCond;
+  // Helper variables and functions for various arithmetic operations. These
+  // are used extensively for computing new offset/length and padding values.
+  Location loc = padOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(b.getContext(), dim0, dim1);
+  // Add two integers.
+  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
+  auto add = [&](Value v1, Value v2) {
+    return b.createOrFold<AffineApplyOp>(loc, addMap, ValueRange{v1, v2});
+  };
+  // Subtract two integers.
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](Value v1, Value v2) {
+    return b.createOrFold<AffineApplyOp>(loc, subMap, ValueRange{v1, v2});
+  };
+  // Take the minimum of two integers.
+  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
+  auto min = [&](Value v1, Value v2) {
+    return b.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
+  };
+  // Take the maximum of two integers.
+  auto max = [&](Value v1, Value v2) {
+    return b.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
+  };
+  // Zero index-typed integer.
+  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
 
-    int64_t rank = padOp.getSourceType().getRank();
-    for (unsigned dim = 0; dim < rank; ++dim) {
-      auto low =
-          getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedLowPad()[dim]);
-      bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
-      auto high =
-          getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedHighPad()[dim]);
-      bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
-      auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]);
-      auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]);
-      auto srcSize = b.createOrFold<tensor::DimOp>(loc, padOp.source(), dim);
+  // Helper function for filling static/dynamic low/high padding indices
+  // vectors of PadOp.
+  auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
+                         SmallVector<int64_t> &staticIndices) {
+    if (auto constInt = getConstantIntValue(val)) {
+      staticIndices.push_back(*constInt);
+    } else {
+      staticIndices.push_back(ShapedType::kDynamicSize);
+      dynIndices.push_back(val);
+    }
+  };
 
-      // The new amount of low padding is `low - offset`. Except for the case
-      // where none of the low padding is read. In that case, the new amount of
-      // low padding is zero.
-      //
-      // Optimization: If low = 0, then newLow = 0.
-      Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
-      appendIndex(newLow, newLows, staticNewLows);
+  // Compute new offsets, lengths, low padding, high padding.
+  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+  SmallVector<Value> newLows, newHighs;
+  SmallVector<int64_t> staticNewLows, staticNewHighs;
+  // Set to true if the original data source is not read at all.
+  bool hasZeroLen = false;
+  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
+  // is true if the original data source turns out to be unused at runtime.
+  Value dynHasZeroLenCond;
 
-      // Start reading the data from position `offset - low`. Since the original
-      // read may have started in the low padding zone, this value could be
-      // negative. Therefore, start reading from:
-      //
-      // max(offset - low, 0)
-      //
-      // The original read could also have started in the high padding zone.
-      // In that case, set the offset to the end of source tensor. The new
-      // ExtractSliceOp length will be zero in that case. (Effectively reading
-      // no data from the source.)
-      //
-      // Optimization: If low = 0, then the formula can be simplified.
-      Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
-                                  : min(offset, srcSize);
-      newOffsets.push_back(getAsOpFoldResult(newOffset));
+  int64_t rank = padOp.getSourceType().getRank();
+  for (unsigned dim = 0; dim < rank; ++dim) {
+    auto low =
+        getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedLowPad()[dim]);
+    bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
+    auto high =
+        getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedHighPad()[dim]);
+    bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
+    auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]);
+    auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]);
+    auto srcSize = b.createOrFold<tensor::DimOp>(loc, padOp.source(), dim);
 
-      // The original ExtractSliceOp was reading until position `offset +
-      // length`. Therefore, the corresponding position within the source tensor
-      // is:
-      //
-      // offset + length - low
-      //
-      // In case the original ExtractSliceOp stopped reading within the low
-      // padding zone, this value can be negative. In that case, the end
-      // position of the read should be zero. (Similar to newOffset.)
-      //
-      // The original read could also have stopped in the high padding zone.
-      // In that case, set the end positition of the read should be the end of
-      // the source tensor. (Similar to newOffset.)
-      //
-      // endLoc = min(max(offset - low + length, 0), srcSize)
-      //
-      // The new ExtractSliceOp length is `endLoc - newOffset`.
-      //
-      // Optimization: If low = 0, then the formula can be simplified.
-      Value endLoc =
-          hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
-                    : min(add(offset, length), srcSize);
-      Value newLength = sub(endLoc, newOffset);
-      newLengths.push_back(getAsOpFoldResult(newLength));
+    // The new amount of low padding is `low - offset`. Except for the case
+    // where none of the low padding is read. In that case, the new amount of
+    // low padding is zero.
+    //
+    // Optimization: If low = 0, then newLow = 0.
+    Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
+    appendIndex(newLow, newLows, staticNewLows);
 
-      // Check if newLength is zero. In that case, no SubTensorOp should be
-      // executed.
-      if (auto newLengthInt = getConstantIntValue(newLength)) {
-        hasZeroLen |= *newLengthInt == 0;
-      } else {
-        Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                              newLength, zero);
-        dynHasZeroLenCond =
-            dynHasZeroLenCond
-                ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
-                : check;
-      }
+    // Start reading the data from position `offset - low`. Since the original
+    // read may have started in the low padding zone, this value could be
+    // negative. Therefore, start reading from:
+    //
+    // max(offset - low, 0)
+    //
+    // The original read could also have started in the high padding zone.
+    // In that case, set the offset to the end of source tensor. The new
+    // ExtractSliceOp length will be zero in that case. (Effectively reading
+    // no data from the source.)
+    //
+    // Optimization: If low = 0, then the formula can be simplified.
+    Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
+                                : min(offset, srcSize);
+    newOffsets.push_back(getAsOpFoldResult(newOffset));
 
-      // The amount of high padding is simply the number of elements remaining,
-      // so that the result has the same length as the original ExtractSliceOp.
-      // As an optimization, if the original high padding is zero, then the new
-      // high padding must also be zero.
-      Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
-      appendIndex(newHigh, newHighs, staticNewHighs);
+    // The original ExtractSliceOp was reading until position `offset +
+    // length`. Therefore, the corresponding position within the source tensor
+    // is:
+    //
+    // offset + length - low
+    //
+    // In case the original ExtractSliceOp stopped reading within the low
+    // padding zone, this value can be negative. In that case, the end
+    // position of the read should be zero. (Similar to newOffset.)
+    //
+    // The original read could also have stopped in the high padding zone.
+    // In that case, set the end positition of the read should be the end of
+    // the source tensor. (Similar to newOffset.)
+    //
+    // endLoc = min(max(offset - low + length, 0), srcSize)
+    //
+    // The new ExtractSliceOp length is `endLoc - newOffset`.
+    //
+    // Optimization: If low = 0, then the formula can be simplified.
+    Value endLoc = hasLowPad
+                       ? min(max(add(sub(offset, low), length), zero), srcSize)
+                       : min(add(offset, length), srcSize);
+    Value newLength = sub(endLoc, newOffset);
+    newLengths.push_back(getAsOpFoldResult(newLength));
 
-      // Only unit stride supported.
-      newStrides.push_back(b.getIndexAttr(1));
+    // Check if newLength is zero. In that case, no SubTensorOp should be
+    // executed.
+    if (auto newLengthInt = getConstantIntValue(newLength)) {
+      hasZeroLen |= *newLengthInt == 0;
+    } else {
+      Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                            newLength, zero);
+      dynHasZeroLenCond =
+          dynHasZeroLenCond
+              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
+              : check;
     }
 
-    // The shape of the result can be obtained from the sizes passed in.
-    SmallVector<Value> dynDims;
-    SmallVector<int64_t> shape;
-    dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize);
-    RankedTensorType resultType =
-        RankedTensorType::get(shape, padOp.getResultType().getElementType());
+    // The amount of high padding is simply the number of elements remaining,
+    // so that the result has the same length as the original ExtractSliceOp.
+    // As an optimization, if the original high padding is zero, then the new
+    // high padding must also be zero.
+    Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
+    appendIndex(newHigh, newHighs, staticNewHighs);
+
+    // Only unit stride supported.
+    newStrides.push_back(b.getIndexAttr(1));
+  }
 
-    // Insert cast to ensure that types match. (May be folded away.)
-    auto castResult = [&](Value val) -> Operation * {
-      auto castOp = b.create<tensor::CastOp>(loc, resultType, val);
-      return castOp;
-    };
+  // The shape of the result can be obtained from the sizes passed in.
+  SmallVector<Value> dynDims;
+  SmallVector<int64_t> shape;
+  dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize);
+  RankedTensorType resultType =
+      RankedTensorType::get(shape, padOp.getResultType().getElementType());
 
-    // In cases where the original data source is unused: Emit a GenerateOp and
-    // do not generate a SliceOp. (The result shape of the SliceOp would
-    // have a dimension of size 0, the semantics of which is unclear.)
-    auto createGenerateOp = [&]() {
-      // Create GenerateOp.
-      auto generateOp = b.create<tensor::GenerateOp>(
-          loc, resultType, dynDims,
-          [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
-            builder.create<tensor::YieldOp>(gLoc, padValue);
-          });
-      return castResult(generateOp);
-    };
+  // Insert cast to ensure that types match. (May be folded away.)
+  auto castResult = [&](Value val) -> Operation * {
+    return b.create<tensor::CastOp>(loc, resultType, val);
+  };
 
-    // Emit a SliceOp and a PadOp. Should not be used in cases where
-    // the result shape of the new SliceOp has a zero dimension.
-    auto createPadTensorOfSubTensor = [&]() {
-      // Create pad_tensor(subtensor(x)).
-      auto newSliceOp = b.create<tensor::ExtractSliceOp>(
-          loc, padOp.source(), newOffsets, newLengths, newStrides);
-      auto newPadOp = b.create<PadOp>(loc, newSliceOp, staticNewLows,
-                                      staticNewHighs, newLows, newHighs);
+  // In cases where the original data source is unused: Emit a GenerateOp and
+  // do not generate a SliceOp. (The result shape of the SliceOp would
+  // have a dimension of size 0, the semantics of which is unclear.)
+  auto createGenerateOp = [&]() {
+    // Create GenerateOp.
+    auto generateOp = b.create<tensor::GenerateOp>(
+        loc, resultType, dynDims,
+        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
+          builder.create<tensor::YieldOp>(gLoc, padValue);
+        });
+    return castResult(generateOp);
+  };
 
-      // Copy region to new PadOp.
-      BlockAndValueMapping bvm;
-      padOp.region().cloneInto(&newPadOp.getRegion(), bvm);
+  // Emit a SliceOp and a PadOp. Should not be used in cases where
+  // the result shape of the new SliceOp has a zero dimension.
+  auto createPadOfExtractSlice = [&]() {
+    // Create pad(extract_slice(x)).
+    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
+        loc, padOp.source(), newOffsets, newLengths, newStrides);
+    auto newPadOp = b.create<PadOp>(loc, newSliceOp, staticNewLows,
+                                    staticNewHighs, newLows, newHighs);
 
-      // Cast result and return.
-      return castResult(newPadOp);
-    };
+    // Copy region to new PadOp.
+    BlockAndValueMapping bvm;
+    padOp.region().cloneInto(&newPadOp.getRegion(), bvm);
 
-    // Rewrite subtensor(pad_tensor(x)) into a GenerateOp it is statically known
-    // that the original data source x is not used.
-    if (hasZeroLen)
-      return {createGenerateOp()};
+    // Cast result and return.
+    return castResult(newPadOp);
+  };
 
-    // If there are dynamic dimensions: Generate an scf.if check to avoid
-    // creating SliceOps with result dimensions of size 0 at runtime.
-    if (dynHasZeroLenCond) {
-      auto result = b.create<scf::IfOp>(
-          loc, resultType, dynHasZeroLenCond,
-          /*thenBuilder=*/
-          [&](OpBuilder &b, Location loc) {
-            b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
-          },
-          /*elseBuilder=*/
-          [&](OpBuilder &b, Location loc) {
-            b.create<scf::YieldOp>(loc,
-                                   createPadTensorOfSubTensor()->getResult(0));
-          });
-      return {result};
-    }
-    return {createPadTensorOfSubTensor()};
-  }
-};
+  // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that
+  // the original data source x is not used.
+  if (hasZeroLen)
+    return createGenerateOp();
 
-} // namespace
+  // If there are dynamic dimensions: Generate an scf.if check to avoid
+  // creating SliceOps with result dimensions of size 0 at runtime.
+  if (generateZeroSliceGuard && dynHasZeroLenCond) {
+    auto result = b.create<scf::IfOp>(
+        loc, resultType, dynHasZeroLenCond,
+        /*thenBuilder=*/
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
+        },
+        /*elseBuilder=*/
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(loc, createPadOfExtractSlice()->getResult(0));
+        });
+    return result;
+  }
+  return createPadOfExtractSlice();
+}
 
 void mlir::tensor::registerTilingOpInterfaceExternalModels(
     DialectRegistry &registry) {
index 8ef56db..da029a7 100644 (file)
@@ -7084,6 +7084,7 @@ cc_library(
         ":StandardOpsTransforms",
         ":Support",
         ":TensorDialect",
+        ":TensorTilingInterfaceImpl",
         ":TensorTransforms",
         ":TensorUtils",
         ":TransformUtils",