[mlir][tensor] Add shape inference methods to tensor::PackOp.
authorHanhan Wang <hanchung@google.com>
Fri, 10 Feb 2023 01:24:26 +0000 (17:24 -0800)
committerHanhan Wang <hanchung@google.com>
Fri, 10 Feb 2023 01:24:42 +0000 (17:24 -0800)
Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D143686

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

index e702189..9e1c8bc 100644 (file)
@@ -1772,6 +1772,14 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
+    // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
+    // This is a static method to allow getting the shape of the destination
+    // expected while creating a `pack` op.
+    static SmallVector<OpFoldResult> getResultShape(OpBuilder &builder,
+        Location loc, ArrayRef<OpFoldResult> sourceDims,
+        ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos,
+        ArrayRef<int64_t> outerDimsPerm = {});
+
     // Method to get the `ShapedType` of the result based on the inner tiles,
     // position of the inner tiles (innerDimsPos)  and interchange vector of
     // outer loops (outerDimsPerm).
index d35895a..74b3f93 100644 (file)
@@ -3479,14 +3479,29 @@ LogicalResult PackOp::verify() {
   return success();
 }
 
-/// Get the expected packed type based on source type, tile factors, position of
-/// the inner tiles and permutation of the outer tiled loop.
-ShapedType PackOp::inferPackedType(ShapedType sourceType,
-                                   ArrayRef<int64_t> innerTileSizes,
-                                   ArrayRef<int64_t> innerDimsPos,
-                                   ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = llvm::to_vector(sourceType.getShape());
-  for (const auto &tiledDim : llvm::enumerate(innerDimsPos)) {
+/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
+/// Value's to kDynamic, even if they are arith.constant values.
+static SmallVector<int64_t>
+asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
+  SmallVector<int64_t> result;
+  for (auto o : ofrs) {
+    // Have to do this first, as getConstantIntValue special-cases constants.
+    if (o.dyn_cast<Value>())
+      result.push_back(ShapedType::kDynamic);
+    else
+      result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
+  }
+  return result;
+}
+
+/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
+/// the packed type. Having a shared helper helps implement these two methods in
+/// a way that ensures that they agree on which dimensions are dynamic.
+static SmallVector<int64_t> getPackOpResultTypeShape(
+    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
+    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
+  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
     if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
       continue;
     if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
@@ -3497,11 +3512,60 @@ ShapedType PackOp::inferPackedType(ShapedType sourceType,
                                             innerTileSizes[tiledDim.index()]);
   }
 
+  // Swap tile loops if outer_dims_perm is available.
   if (!outerDimsPerm.empty())
     applyPermutationToVector(resultShape, outerDimsPerm);
 
   // Append the inner tile dimensions.
   resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
+  return resultShape;
+}
+
+SmallVector<OpFoldResult> PackOp::getResultShape(
+    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
+    ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+    ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
+
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
+  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
+    resultDims[tiledDim.value()] = makeComposedFoldedAffineApply(
+        builder, loc, ceilDivExpr,
+        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
+  }
+  if (!outerDimsPerm.empty())
+    applyPermutationToVector(resultDims, outerDimsPerm);
+  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
+
+  SmallVector<int64_t> resultTypeShape =
+      getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
+                               asShapeWithAnyValueAsDynamic(innerTileSizes),
+                               innerDimsPos, outerDimsPerm);
+
+  // Fix-up `resultDims` to ensure that they are Value's if and only if the
+  // result type shape says it's a dynamic dim. This is needed as callers may
+  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
+  // dynamic dims returned by that.
+  for (unsigned i = 0; i < resultDims.size(); ++i) {
+    if (!ShapedType::isDynamic(resultTypeShape[i]))
+      continue;
+    resultDims[i] =
+        getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
+  }
+
+  return resultDims;
+}
+
+/// Get the expected packed type based on source type, tile factors, position of
+/// the inner tiles and permutation of the outer tiled loop.
+ShapedType PackOp::inferPackedType(ShapedType sourceType,
+                                   ArrayRef<int64_t> innerTileSizes,
+                                   ArrayRef<int64_t> innerDimsPos,
+                                   ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
+      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
   return RankedTensorType::get(resultShape, sourceType.getElementType());
 }