[mlir][linalg] Prepare pad to static bounding box for scalar operands.
authorTobias Gysi <gysit@google.com>
Fri, 11 Jun 2021 13:21:32 +0000 (13:21 +0000)
committerTobias Gysi <gysit@google.com>
Fri, 11 Jun 2021 13:51:29 +0000 (13:51 +0000)
Adapt pad to static bounding box to support structured ops taking scalar operands.

Differential Revision: https://reviews.llvm.org/D103891

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

index fb2735c..44acac1 100644 (file)
@@ -123,18 +123,17 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
 ///      created PadTensorOp.
 /// Return failure if the operand cannot be padded to a static shape.
 static LogicalResult padOperandToSmallestStaticBoundingBox(
-    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand &operand,
+    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
     const LinalgTilingOptions &options, Value &result) {
-  auto tensorType = operand.get().getType().cast<RankedTensorType>();
   // Already static shape, no need to pad.
-  if (tensorType.hasStaticShape())
+  if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
     return success();
-  auto subtensor = operand.get().getDefiningOp<SubTensorOp>();
+  auto subtensor = opOperand->get().getDefiningOp<SubTensorOp>();
   // Not a subtensor, cannot construct a static bounding box.
   if (!subtensor)
     return failure();
   SmallVector<int64_t> staticSizes;
-  staticSizes.reserve(tensorType.getRank());
+  staticSizes.reserve(opToPad.getRank(opOperand));
   auto shapedOp =
       cast<OffsetSizeAndStrideOpInterface>(subtensor.getOperation());
   for (auto size : shapedOp.getMixedSizes()) {
@@ -148,11 +147,11 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
           opToPad, "No constant bounding box can be found for padding");
     staticSizes.push_back(indexAttr.getInt());
   }
-  Value pad = options.paddingValueComputationFunction(rewriter, operand);
-  auto staticTensorType =
-      RankedTensorType::get(staticSizes, tensorType.getElementType());
+  Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
+  auto staticTensorType = RankedTensorType::get(
+      staticSizes, getElementTypeOrSelf(opOperand->get().getType()));
   result = linalg::PadTensorOp::createPadHighOp(
-      staticTensorType, operand.get(), pad, opToPad->getLoc(), rewriter);
+      staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
   return success();
 }
 
@@ -183,9 +182,8 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
     // If padding was requested but the shape cannot be bounded statically then
     // the pattern fails to apply.
     if (failed(padOperandToSmallestStaticBoundingBox(
-            rewriter, opToPad, *opOperand, options, paddedOperand))) {
+            rewriter, opToPad, opOperand, options, paddedOperand)))
       return failure();
-    }
     newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
   }