[mlir][tensor] Expose padding requirement of pack ops to a static method
authorHanhan Wang <hanchung@google.com>
Thu, 9 Mar 2023 18:03:31 +0000 (10:03 -0800)
committerHanhan Wang <hanchung@google.com>
Thu, 9 Mar 2023 18:03:46 +0000 (10:03 -0800)
It also simplifies the implementation of the method. The map is not needed in the check.

Reviewed By: chelini

Differential Revision: https://reviews.llvm.org/D145522

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

index 80c0ba5..09b7775 100644 (file)
@@ -1788,6 +1788,13 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
 
+    // Returns true if we have enough static information to catch undefined
+    // behavior when the tile size does not divide perfectly the dimension of
+    // the input tensor.
+    static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
+                                    ArrayRef<int64_t> innerDimsPos,
+                                    ArrayRef<OpFoldResult> innerTiles);
+
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
index 2253133..baf2130 100644 (file)
@@ -3422,22 +3422,16 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
   return getStaticTilesImpl(*this);
 }
 
-/// Check if we have enough static information to catch undefined behavior when
-/// the tile size does not divide perfectly the dimension of the input tensor.
-static bool
-areNotFullTiles(ArrayRef<int64_t> inputShape,
-                DenseMap<int64_t, OpFoldResult> const &dimAndTileMapping) {
-  int64_t rank = inputShape.size();
-  for (int64_t dim = 0; dim < rank; dim++) {
-    if (ShapedType::isDynamic(inputShape[dim]))
+bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
+                                 ArrayRef<int64_t> innerDimsPos,
+                                 ArrayRef<OpFoldResult> innerTiles) {
+  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
+    if (ShapedType::isDynamic(inputShape[pos]))
       continue;
-    auto it = dimAndTileMapping.find(dim);
-    if (it == dimAndTileMapping.end())
-      continue;
-    std::optional<int64_t> constantTile = getConstantIntValue(it->second);
+    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
     if (!constantTile)
       continue;
-    if (inputShape[dim] % (*constantTile) != 0)
+    if (inputShape[pos] % (*constantTile) != 0)
       return true;
   }
   return false;
@@ -3458,9 +3452,9 @@ LogicalResult PackOp::verify() {
            << " but got: " << paddingValue.getType();
   }
 
-  auto dimAndTileMapping = getDimAndTileMapping();
   if (!paddingValue &&
-      areNotFullTiles(getSourceType().getShape(), dimAndTileMapping)) {
+      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
+                          getMixedTiles())) {
     return emitOpError("invalid tile factor provided. Only full tiles are "
                        "supported when padding_value is not set");
   }