Introduce `tensor.pack` and `tensor.unpack` operations
authorLorenzo Chelini <l.chelini@icloud.com>
Tue, 15 Nov 2022 09:30:50 +0000 (10:30 +0100)
committerLorenzo Chelini <l.chelini@icloud.com>
Tue, 22 Nov 2022 08:11:59 +0000 (09:11 +0100)
Pack and Unpack return new tensors within which the individual elements
are reshuffled according to the packing specification. This has the
consequence of modifying the canonical order in which a given operator
(i.e., Matmul) accesses the individual elements. After bufferization,
this typically translates to increased access locality and cache
behavior improvement, e.g., eliminating cache line splitting.

Co-authored-by: Mahesh Ravishankar <ravishankarm@google.com>
Co-authored-by: Han-Chung Wang <hanchung@google.com>
RFC: https://discourse.llvm.org/t/rfc-tensor-pack-and-tensor-unpack/66408/1

Reviewed By: nicolasvasilache, rengolin, hanchung

Differential Revision: https://reviews.llvm.org/D138119

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/invalid.mlir
mlir/test/Dialect/Tensor/ops.mlir
mlir/test/Transforms/loop-invariant-code-motion.mlir

index 352002b..661a8f8 100644 (file)
@@ -1664,6 +1664,170 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
 }
 
 //===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
+class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
+      Tensor_Op<mnemonic, !listconcat(traits, [
+        DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+        DestinationStyleOpInterface,
+        ConditionallySpeculatable, NoMemoryEffect,
+        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+        TypesMatchWith<"result type matches type of dest",
+                   "dest", "result",
+                   "$_self">])> {
+
+  code commonExtraClassDeclaration = [{
+    int64_t getSourceRank() { return getSource().getType().getRank(); };
+    int64_t getDestRank() { return getDest().getType().getRank(); };
+    RankedTensorType getSourceType() { 
+      return getSource().getType().cast<RankedTensorType>(); };
+    RankedTensorType getDestType() {
+      return getDest().getType().cast<RankedTensorType>(); };
+
+    /// Return position for init operand. Init operand is `dest`. 
+    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+      return {1, 2}; // `dest` operand
+    }
+
+    /// Interface method for ConditionallySpeculatable.
+    Speculation::Speculatability getSpeculatability();   
+    /// Return a mapping from positions `inner_dims_pos` to their 
+    /// tile factors.
+    DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();
+    
+    /// Return the tile sizes as OpFoldResult.
+    SmallVector<OpFoldResult> getMixedTiles();
+    
+    /// Return the tile sizes as `int64_t`. If a tile size is dynamic 
+    /// a sentinel `kDynamic` is introduced at that position in 
+    /// the returned vector.
+    SmallVector<int64_t> getStaticTiles();
+  }];
+  
+  let hasVerifier = 1;
+}
+
+def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
+    AttrSizedOperandSegments]> {
+  let summary = "tensor pack operation";
+  let description = [{ 
+    The pack operation converts an input tensor to a higher-dimensional tensor
+    with a tiled and packed layout. The mandatory `inner_dims_pos` attribute
+    specifies a permutation for the original dimensions, while `inner_tiles` is the
+    tiling factor for each dimension. The optional attribute `outer_dims_perm`
+    specifies the order for the tiled data dimension, while the attribute
+    `padding_value` specifies a padding value at the boundary on non-perfectly
+    divisible dimensions. Padding is optional: 
+    - If absent, it is UB if the tile does not perfectly divide the dimension.  
+    - If present, it will pad along high dimensions (high-padding) to make the 
+      tile complete. 
+
+    Example NC_to_NCnc:
+
+    ```mlir
+    tensor.pack %source inner_dims_pos = [0, 1]
+      inner_tiles = [8, 32] into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+    ```
+    Example CK to KCck
+
+    ```mlir
+    tensor.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+      inner_tiles = [8, 32] into %dest : tensor<128x256xf32> -> tensor<8x16x8x32xf32>
+    ```
+
+    In all cases, dimension at position 0 in the input tensor (128) is tiled
+    with a factor of 8, while dimension at position 1 (256) is tiled with a factor
+    of 32. In the second example, the outer data dimensions are interchanged
+    according to `outer_dims_perm`.
+
+    Example NC_to_NCnc with padding:
+
+    ```mlir
+    tensor.pack %arg padding_value(%pad : f32) inner_dims_pos = [0, 1]
+      inner_tiles = [8, 2] into %arg1 : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+    ```
+
+  }];
+  let arguments = (ins AnyRankedTensor:$source,
+                       AnyRankedTensor:$dest,
+                       Optional<AnyType>:$padding_value,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
+                       DenseI64ArrayAttr:$inner_dims_pos,
+                       Variadic<Index>:$inner_tiles,
+                       I64ArrayAttr:$static_inner_tiles);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    $source 
+    (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
+    (`outer_dims_perm` `=` $outer_dims_perm^)?  
+    `inner_dims_pos` `=` $inner_dims_pos
+    `inner_tiles` `=`
+    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles,
+                             "ShapedType::kDynamic")
+    `into` $dest attr-dict `:` type($source) `->` type($dest)
+  }];
+
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    // Method to get the `ShapedType` of the result based on the inner tiles,
+    // position of the inner tiles (innerDimsPos)  and interchange vector of  
+    // outer loops (outerDimsPerm).
+    static ShapedType inferPackedType(ShapedType sourceType,
+        ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+        ArrayRef<int64_t> outerDimsPerm = {});
+  }]; 
+}
+
+//===----------------------------------------------------------------------===//
+// UnPackOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
+  let summary = "tensor unpack operation";
+  let description = [{
+    The unpack operation converts a tensor with a tiled and packed layout to a
+    lower-dimensional tensor. Similar to `pack`,  the mandatory attributes
+    `inner_dims_pos` specifies a permutation for the inner data dimensions, while
+    `inner_tiles` is the tiling factor. The attribute `outer_dims_perm` has the
+    exact behavior as the one described in `pack`. In `unpack`, it is UB if the
+    tile does not perfectly divide the dimension.
+
+    Example NCnc_to_NC:
+
+    ```mlir
+    tensor.unpack %source inner_dims_pos = [0, 1]
+      inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+    ```
+
+    Example CK to KCck:
+
+    ```mlir
+    tensor.unapck %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] 
+      inner_tiles = [8, 32] into %dest : tensor<8x16x8x32xf32> -> tensor<128x256xf32>
+    ```
+  }];
+  let arguments = (ins AnyRankedTensor:$source,
+                       AnyRankedTensor:$dest,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
+                       DenseI64ArrayAttr:$inner_dims_pos,
+                       Variadic<Index>:$inner_tiles,
+                       I64ArrayAttr:$static_inner_tiles);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    $source
+    (`outer_dims_perm` `=` $outer_dims_perm^)?
+    `inner_dims_pos` `=` $inner_dims_pos
+    `inner_tiles` `=`
+    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles,
+                             "ShapedType::kDynamic")
+    `into` $dest attr-dict `:` type($source) `->` type($dest)
+  }];
+
+  let extraClassDeclaration = commonExtraClassDeclaration;
+}
+
+//===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
 
index 019cffe..95dbd77 100644 (file)
@@ -18,6 +18,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -2945,6 +2946,369 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
 }
 
 //===----------------------------------------------------------------------===//
+// PackOp/UnPackOp Common
+//===----------------------------------------------------------------------===//
+
+template <typename OpTy>
+static LogicalResult
+reifyResultShapesImpl(OpTy op, OpBuilder &builder,
+                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+                "applies to only pack or unpack operations");
+  int64_t destRank = op.getDestRank();
+  reifiedReturnShapes.resize(1, SmallVector<Value>(destRank));
+  for (auto dim : llvm::seq<int64_t>(0, destRank)) {
+    reifiedReturnShapes[0][dim] =
+        builder.createOrFold<tensor::DimOp>(op.getLoc(), op.getDest(), dim);
+  }
+  return success();
+}
+
+template <typename OpTy>
+static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
+  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+                "applies to only pack or unpack operations");
+  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
+  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
+  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
+  assert(tiles.size() == dimsToTile.size() &&
+         "tiles must match indices of dimension to block");
+  // bind the dimension `i` with the tile factor.
+  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
+    dimAndTileMapping[dimsToTile[i]] = tiles[i];
+  return dimAndTileMapping;
+}
+
+template <typename OpTy>
+static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
+  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+                "applies to only pack or unpack operations");
+  SmallVector<OpFoldResult> mixedInnerTiles;
+  unsigned dynamicValIndex = 0;
+  for (Attribute attr : op.getStaticInnerTiles()) {
+    auto tileAttr = attr.cast<IntegerAttr>();
+    if (!ShapedType::isDynamic(tileAttr.getInt()))
+      mixedInnerTiles.push_back(tileAttr);
+    else
+      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
+  }
+  return mixedInnerTiles;
+}
+
+template <typename OpTy>
+static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
+  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+                "applies to only pack or unpack operations");
+  SmallVector<Value> dynamicTiles;
+  SmallVector<int64_t> staticTiles;
+  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles,
+                             ShapedType::kDynamic);
+  return staticTiles;
+}
+
+/// Returns true if `dimsPos` is invalid. It is invalid when:
+/// a) It contains duplicate.
+/// b) At least one dimension is out of bound (`dimPos` is >= 0 and < rank).
+/// c) The number of elements in `dimsPos` is > than `rank`.
+static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
+                                             size_t rank) {
+  size_t dimsPosSize = dimsPos.size();
+  if (dimsPosSize > rank)
+    return true;
+  DenseSet<int64_t> uniqued;
+  for (int64_t dim : dimsPos)
+    uniqued.insert(dim);
+  if (dimsPosSize != uniqued.size())
+    return true;
+  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
+    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
+  });
+}
+
+/// Returns true if the dimension of `sourceShape` is smaller than the dimension
+/// of the `limitShape`.
+static bool areAllInBound(ArrayRef<int64_t> sourceShape,
+                          ArrayRef<int64_t> limitShape) {
+  assert(
+      sourceShape.size() == limitShape.size() &&
+      "expected source shape rank, and limit of the shape to have same rank");
+  return llvm::all_of(
+      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
+        int64_t sourceExtent = std::get<0>(it);
+        int64_t limit = std::get<1>(it);
+        return ShapedType::isDynamic(sourceExtent) ||
+               ShapedType::isDynamic(limit) || sourceExtent <= limit;
+      });
+}
+
+template <typename OpTy>
+static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
+  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+                "applies to only pack or unpack operations");
+  Operation *op = packOrUnPack.getOperation();
+
+  // Return true if we have a zero-value tile.
+  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
+    return llvm::any_of(
+        tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
+  };
+
+  // Verify tiles. Do not allow zero tiles.
+  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
+  if (hasZeros(mixedTiles))
+    return op->emitError("invalid zero tile factor");
+
+  // Verify inner_dims_pos and outer_dims_perm.
+  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                ? packOrUnPack.getSourceType()
+                                : packOrUnPack.getDestType();
+  size_t unpackedRank = unpackedType.getRank();
+  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
+  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
+  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
+    return op->emitError("invalid inner_dims_pos vector");
+  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
+    return op->emitError("invalid outer_dims_perm vector");
+
+  // Tiling factors must be less than or equal to the input rank for pack (or
+  // output rank for unpack), and must match the number of `inner_dims_pos`.
+  if (mixedTiles.size() > unpackedRank) {
+    return op->emitError("tiling factors must be less than or equal to the "
+                         "input rank for pack or output rank for unpack");
+  }
+  if (mixedTiles.size() != innerDimsPos.size()) {
+    return op->emitError(
+        "tiling factors must equal the number of dimensions to tile");
+  }
+
+  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
+                              ? packOrUnPack.getDestType()
+                              : packOrUnPack.getSourceType();
+  size_t packedRank = packedType.getRank();
+  // Require output rank to match input rank + number of blocking factors.
+  if (unpackedRank + mixedTiles.size() != packedRank) {
+    return op->emitError(
+        "packed rank must equal unpacked rank + tiling factors");
+  }
+
+  // Verify result shape is greater than the minimum expected
+  // by the pack operation, and that the output shape
+  // represents full tiles.
+  ShapedType expectedPackedType = PackOp::inferPackedType(
+      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
+  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
+    return op->emitError("the shape of output is not large enough to hold the "
+                         "packed data. Expected at least ")
+           << expectedPackedType << ", got " << packedType;
+  }
+  if (!llvm::all_of(
+          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
+                    mixedTiles),
+          [](std::tuple<int64_t, OpFoldResult> it) {
+            Optional<int64_t> constTileSize =
+                getConstantIntValue(std::get<1>(it));
+            int64_t shape = std::get<0>(it);
+            if (!constTileSize) {
+              // If specified tile size is dynamic, output shape should
+              // be dynamic too.
+              return ShapedType::isDynamic(shape);
+            } else {
+              if (ShapedType::isDynamic(shape)) {
+                // For the shape being dynamic when tile size is
+                // specified, return true. In canonical form a constant
+                // tile size should lead to constant shape of the tiled
+                // dimension, but not needed for verification.
+                return true;
+              }
+              return shape == constTileSize.value();
+            }
+          })) {
+    return op->emitError("mismatch in inner tile sizes specified and shaped of "
+                         "tiled dimension in the packed type");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
+void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "pack");
+}
+
+LogicalResult
+PackOp::reifyResultShapes(OpBuilder &builder,
+                          ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
+}
+
+DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
+  return getDimAndTileMappingImpl(*this);
+}
+
+SmallVector<OpFoldResult> PackOp::getMixedTiles() {
+  return getMixedTilesImpl(*this);
+}
+
+SmallVector<int64_t> PackOp::getStaticTiles() {
+  return getStaticTilesImpl(*this);
+}
+
+/// Check if we have enough static information to catch undefined behavior when
+/// the tile size does not divide perfectly the dimension of the input tensor.
+static bool
+areNotFullTiles(ArrayRef<int64_t> inputShape,
+                DenseMap<int64_t, OpFoldResult> const &dimAndTileMapping) {
+  int64_t rank = inputShape.size();
+  for (int64_t dim = 0; dim < rank; dim++) {
+    if (ShapedType::isDynamic(inputShape[dim]))
+      continue;
+    auto it = dimAndTileMapping.find(dim);
+    if (it == dimAndTileMapping.end())
+      continue;
+    Optional<int64_t> constantTile = getConstantIntValue(it->second);
+    if (!constantTile)
+      continue;
+    if (inputShape[dim] % (*constantTile) != 0)
+      return true;
+  }
+  return false;
+}
+
+LogicalResult PackOp::verify() {
+  if (failed(commonVerifierPackAndUnPackOp(*this)))
+    return failure();
+
+  // Verify padding value, and bail out if the tile does not divide the
+  // dimension fully. In the case of dynamic tile factors or dimensions, having
+  // a partial tile is undefined behavior.
+  auto paddingValue = getPaddingValue();
+  if (paddingValue &&
+      paddingValue.getType() != getSourceType().getElementType()) {
+    return emitOpError("expected padding_value has ")
+           << getSourceType().getElementType()
+           << " but got: " << paddingValue.getType();
+  }
+
+  auto dimAndTileMapping = getDimAndTileMapping();
+  if (!paddingValue &&
+      areNotFullTiles(getSourceType().getShape(), dimAndTileMapping)) {
+    return emitOpError("invalid tile factor provided. Only full tiles are "
+                       "supported when padding_value is not set");
+  }
+  return success();
+}
+
+/// Returns a vector that interchanges `elements` starting at offset `offset`
+/// based on the indexes in `interchangeVector`.
+template <typename T>
+SmallVector<T> interchange(ArrayRef<T> elements,
+                           ArrayRef<int64_t> interchangeVector,
+                           int offset = 0) {
+  SmallVector<T> vec = llvm::to_vector(elements);
+  for (auto en : llvm::enumerate(interchangeVector))
+    vec[en.index() + offset] = elements[en.value() + offset];
+
+  return vec;
+}
+
+/// Get the expected packed type based on source type, tile factors, position of
+/// the inner tiles and permutation of the outer tiled loop.
+ShapedType PackOp::inferPackedType(ShapedType sourceType,
+                                   ArrayRef<int64_t> innerTileSizes,
+                                   ArrayRef<int64_t> innerDimsPos,
+                                   ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = llvm::to_vector(sourceType.getShape());
+  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
+    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
+      continue;
+    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
+      resultShape[tiledDim.value()] = ShapedType::kDynamic;
+      continue;
+    }
+    resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()],
+                                            innerTileSizes[tiledDim.index()]);
+  }
+
+  resultShape = interchange<int64_t>(resultShape, outerDimsPerm);
+
+  // Append the inner tile dimensions.
+  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
+  return RankedTensorType::get(resultShape, sourceType.getElementType());
+}
+
+/// Returns true if the tiles and the tiled dims are constant.
+template <typename OpTy>
+bool areTilesAndTiledDimsAllConstant(OpTy op) {
+  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+                "applies to only pack or unpack operations");
+  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
+                              ? op.getDestType()
+                              : op.getSourceType();
+  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
+  for (auto [dimDest, tile] : llvm::zip(
+           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
+    Optional<int64_t> constTileSize = getConstantIntValue(tile);
+    if (!constTileSize || ShapedType::isDynamic(dimDest))
+      return false;
+  }
+  return true;
+}
+
+Speculation::Speculatability PackOp::getSpeculatability() {
+  if (auto paddingValue = getPaddingValue())
+    return Speculation::Speculatable;
+
+  // The verifier rejects already operations if we can statically prove that the
+  // sizes of the tiles do not divide perfectly the dimension; thus, check only
+  // to have constant tiles and tiled inner dimensions.
+  if (!areTilesAndTiledDimsAllConstant(*this))
+    return Speculation::NotSpeculatable;
+
+  return Speculation::Speculatable;
+}
+
+//===----------------------------------------------------------------------===//
+// UnPackOp
+//===----------------------------------------------------------------------===//
+
+void UnPackOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "unpack");
+}
+
+LogicalResult
+UnPackOp::reifyResultShapes(OpBuilder &builder,
+                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
+}
+
+DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
+  return getDimAndTileMappingImpl(*this);
+}
+
+SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
+  return getMixedTilesImpl(*this);
+}
+
+SmallVector<int64_t> UnPackOp::getStaticTiles() {
+  return getStaticTilesImpl(*this);
+}
+
+LogicalResult UnPackOp::verify() {
+  return commonVerifierPackAndUnPackOp(*this);
+}
+
+Speculation::Speculatability UnPackOp::getSpeculatability() {
+  // See PackOp::getSpeculatability.
+  if (!areTilesAndTiledDimsAllConstant(*this))
+    return Speculation::NotSpeculatable;
+
+  return Speculation::Speculatable;
+}
+
+//===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
index d1d1c35..b085053 100644 (file)
@@ -522,3 +522,92 @@ func.func @empty_wrong_number_of_operands(%sz : index) {
   %out = tensor.empty(%sz) : tensor<2x?x?x5xf32>
   return
 }
+
+// -----
+
+func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> {
+  // expected-error@+1 {{invalid tile factor provided. Only full tiles are supported when padding_value is not set}}
+  %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : tensor<256x128xf32>  -> tensor<8x8x16x33xf32>
+  return %0 : tensor<8x8x16x33xf32>
+}
+
+// -----
+
+func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> {
+  // expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}}
+  %0 = tensor.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+// -----
+
+func.func @pack_invalid_inner_dims_pos_vector(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{invalid inner_dims_pos vector}}
+  %0 = tensor.pack %input inner_dims_pos = [2, 0] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+  return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+
+func.func @pack_invalid_duplicate_element_in_inner_dims(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{invalid inner_dims_pos vector}}
+  %0 = tensor.pack %input inner_dims_pos = [1, 1] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+  return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+
+func.func @pack_invalid_duplicate_element_in_outer_perm(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{invalid outer_dims_perm vector}}
+  %0 = tensor.pack %input outer_dims_perm = [1, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+  return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+
+func.func @unpack_invalid_out_of_bound_outer_perm(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{invalid outer_dims_perm vector}}
+  %0 = tensor.unpack %output outer_dims_perm = [2, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %input : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+  return %0 : tensor<256x128xf32>
+}
+
+// -----
+
+func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
+  %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+  return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+
+func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
+  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
+  %0 = tensor.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+  return %0 : tensor<256x128xf32>
+}
+
+// -----
+
+func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{invalid zero tile factor}}
+  %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [0, 2] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+  return %0 : tensor<8x8x32x16xf32>
+}
+
+// -----
+func.func @pack_mismatch_inner_tile_size_and_output_shape(
+  %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
+  return %0 : tensor<?x?x8x8xf32>
+}
+
+// -----
+
+func.func @unpack_mismatch_inner_tile_size_and_output_shape(
+  %input : tensor<?x?x8x8xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x8x8xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
index a561726..3bb6235 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt --split-input-file %s | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: func @cast(
 func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
@@ -13,6 +13,8 @@ func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @empty(
 //  CHECK-SAME:             %[[sz:.*]]: index
 func.func @empty(%sz: index) -> tensor<5x?x6xf32> {
@@ -21,6 +23,8 @@ func.func @empty(%sz: index) -> tensor<5x?x6xf32> {
   return %0 : tensor<5x?x6xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @empty_with_encoding(
 //  CHECK-SAME:             %[[sz:.*]]: index
 func.func @empty_with_encoding(%sz: index) -> tensor<5x?x6xf32, "foo"> {
@@ -29,6 +33,8 @@ func.func @empty_with_encoding(%sz: index) -> tensor<5x?x6xf32, "foo"> {
   return %0 : tensor<5x?x6xf32, "foo">
 }
 
+// -----
+
 // CHECK-LABEL:   func @extract(
 // CHECK-SAME:                  %[[TENSOR:.*]]: tensor<?x?x?xf32>,
 // CHECK-SAME:                  %[[INDEX:.*]]: index) {
@@ -38,6 +44,8 @@ func.func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
   return
 }
 
+// -----
+
 // CHECK-LABEL:   func @insert(
 // CHECK-SAME:                  %[[SCALAR:.*]]: f32
 // CHECK-SAME:                  %[[INDEX:.*]]: index
@@ -48,6 +56,8 @@ func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @tensor.from_elements() {
 func.func @tensor.from_elements() {
   %c0 = "arith.constant"() {value = 0: index} : () -> index
@@ -74,6 +84,8 @@ func.func @tensor.from_elements() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @tensor.generate
 func.func @tensor.generate(%m : index, %n : index)
     -> tensor<?x3x?xf32> {
@@ -85,6 +97,8 @@ func.func @tensor.generate(%m : index, %n : index)
   return %tnsr : tensor<?x3x?xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @tensor_reshape
 func.func @tensor_reshape(%unranked: tensor<*xf32>, %shape1: tensor<1xi32>,
          %shape2: tensor<2xi32>, %shape3: tensor<?xi32>) -> tensor<*xf32> {
@@ -97,6 +111,8 @@ func.func @tensor_reshape(%unranked: tensor<*xf32>, %shape1: tensor<1xi32>,
   return %new_unranked : tensor<*xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @slice({{.*}}) {
 func.func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
   %c0 = arith.constant 0 : index
@@ -120,6 +136,8 @@ func.func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @insert_slice({{.*}}) {
 func.func @insert_slice(
     %t: tensor<8x16x4xf32>,
@@ -154,6 +172,8 @@ func.func @insert_slice(
   return
 }
 
+// -----
+
 func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>)
     -> (tensor<f32>, tensor<1x1xf32>) {
   %0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
@@ -164,6 +184,8 @@ func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>)
 //       CHECK:   tensor.collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
 //       CHECK:   tensor.expand_shape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
 
+// -----
+
 func.func @legal_collapsing_reshape_dynamic_tensor
   (%arg0: tensor<?x?x?x4x?xf32>) -> tensor<?x?x?xf32>
 {
@@ -175,6 +197,8 @@ func.func @legal_collapsing_reshape_dynamic_tensor
 //      CHECK:   tensor.collapse_shape
 // CHECK-SAME:    [0], [1], [2, 3, 4]
 
+// -----
+
 func.func @rank(%t : tensor<4x4x?xf32>) {
   // CHECK: %{{.*}} = tensor.rank %{{.*}} : tensor<4x4x?xf32>
   %0 = "tensor.rank"(%t) : (tensor<4x4x?xf32>) -> index
@@ -184,6 +208,8 @@ func.func @rank(%t : tensor<4x4x?xf32>) {
   return
 }
 
+// -----
+
 func.func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
                   %pad_value: f32) -> tensor<6x?x?x?xf32> {
   %0 = tensor.pad %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] {
@@ -201,6 +227,8 @@ func.func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
 //  CHECK-SAME:     high[3, 3, %[[HIGH]], 2]
 //       CHECK:    : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
 
+// -----
+
 func.func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32> {
   %0 = tensor.pad %arg0 low[1, 2] high[2, 3] {
     ^bb0(%arg1 : index, %arg2 : index):
@@ -213,6 +241,8 @@ func.func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32
 //       CHECK:   tensor.pad %[[ARG0]] low[1, 2] high[2, 3]
 //       CHECK:    : tensor<3x4xf32> to tensor<6x9xf32>
 
+// -----
+
 func.func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
                        %pad_value: f32) -> tensor<?x?xf32> {
   %0 = tensor.pad %arg0 low[0, 0] high[%ub0, %ub1] {
@@ -230,6 +260,8 @@ func.func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
 //  CHECK-SAME:     high[%[[UB0]], %[[UB1]]]
 //       CHECK:    : tensor<2x3xf32> to tensor<?x?xf32>
 
+// -----
+
 func.func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
                          %pad_value: f32) -> tensor<2x3xf32> {
   %0 = tensor.pad %arg0 low[0, 0] high[%ub0, %ub1] {
@@ -247,6 +279,8 @@ func.func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
 //  CHECK-SAME:     high[%[[UB0]], %[[UB1]]]
 //       CHECK:    : tensor<?x?xf32> to tensor<2x3xf32>
 
+// -----
+
 // CHECK-LABEL: func @test_splat_op
 // CHECK-SAME: [[S:%arg[0-9]+]]: f32
 func.func @test_splat_op(%s : f32) {
@@ -258,6 +292,8 @@ func.func @test_splat_op(%s : f32) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func.func @gather_scatter(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<4x5x6xf32>,
 // CHECK-SAME:  %[[ARG1:.*]]: tensor<1x3x2xindex>,
@@ -281,3 +317,106 @@ func.func @gather_scatter(
     (tensor<1x3x4xf32>, tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<4x5x6xf32>
   return
 }
+
+// -----
+
+func.func @pack_nc_to_ncnc(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) -> tensor<128x256xf32> {
+  %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+  %1 = tensor.empty() : tensor<128x256xf32>
+  %2 = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %1 : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
+  return %2 : tensor<128x256xf32>
+}
+
+// CHECK-LABEL: func.func @pack_nc_to_ncnc(
+// CHECK-SAME:  %[[SOURCE:.*]]: tensor<128x256xf32>,
+// CHECK-SAME:  %[[DEST:.*]]: tensor<4x16x32x16xf32>)
+// CHECK: %[[PACKED:.*]] = tensor.pack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[DEST]] : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<128x256xf32>
+// CHECK: %{{.*}} = tensor.unpack %[[PACKED]] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[BUFF]] : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
+
+// -----
+
+func.func @pack_nc_to_ncnc_with_padding(%source: tensor<13x15xf32>, %dest: tensor<2x8x8x2xf32>, %padding: f32) -> tensor<13x15xf32> {
+  %0 = tensor.pack %source padding_value(%padding : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+  %1 = tensor.empty() : tensor<13x15xf32>
+  %2 = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32>
+  return %2 : tensor<13x15xf32>
+}
+
+// CHECK-LABEL: func.func @pack_nc_to_ncnc_with_padding(
+// CHECK-SAME:  %[[SOURCE:.*]]: tensor<13x15xf32>,
+// CHECK-SAME:  %[[DEST:.*]]: tensor<2x8x8x2xf32>,
+// CHECK-SAME:  %[[PADDING:.*]]: f32)
+// CHECK: %[[PACKED:.*]] = tensor.pack %[[SOURCE]] padding_value(%[[PADDING]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[DEST]] : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<13x15xf32>
+// CHECK: %{{.*}} = tensor.unpack %[[PACKED]] inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[BUFF]] : tensor<2x8x8x2xf32> -> tensor<13x15xf32>
+
+// -----
+
+func.func @pack_ck_to_kcck(%source: tensor<128x256xf32>, %dest: tensor<16x4x32x16xf32>) -> tensor<128x256xf32> {
+  %0 = tensor.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<16x4x32x16xf32>
+  %1 = tensor.empty() : tensor<128x256xf32>
+  %2 = tensor.unpack %0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %1 : tensor<16x4x32x16xf32> -> tensor<128x256xf32>
+  return %2 : tensor<128x256xf32>
+}
+
+// CHECK-LABEL: func.func @pack_ck_to_kcck(
+// CHECK-SAME:  %[[SOURCE:.*]]: tensor<128x256xf32>,
+// CHECK-SAME:  %[[DEST:.*]]: tensor<16x4x32x16xf32>)
+// CHECK: %[[PACKED:.*]] = tensor.pack %[[SOURCE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[DEST]] : tensor<128x256xf32> -> tensor<16x4x32x16xf32>
+// CHECK: %[[BUFF:.*]] = tensor.empty() : tensor<128x256xf32>
+// CHECK: %{{.*}} = tensor.unpack %[[PACKED]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %[[BUFF]] : tensor<16x4x32x16xf32> -> tensor<128x256xf32>
+
+// -----
+
+func.func @pad_and_pack_fully_dynamic(%source: tensor<?x?xf32>, %dest: tensor<?x?x?x?xf32>, %pad: f32, %tile_n : index, %tile_m : index) -> tensor<?x?x?x?xf32> {
+  %0 = tensor.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @pad_and_pack_fully_dynamic(
+// CHECK-SAME:  %[[SOURCE:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:  %[[DEST:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME:  %[[PAD:.*]]: f32,
+// CHECK-SAME:  %[[TILE_N:.*]]: index,
+// CHECK-SAME:  %[[TILE_M:.*]]: index)
+// CHECK: %{{.*}} = tensor.pack %[[SOURCE]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_N]], %[[TILE_M]]] into %[[DEST]] : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+
+// -----
+
+func.func @pad_and_pack_partially_dynamic(%source: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>, %pad: f32) -> tensor<?x?x8x2xf32> {
+  %0 = tensor.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+  return %0 : tensor<?x?x8x2xf32>
+}
+
+// CHECK-LABEL: func.func @pad_and_pack_partially_dynamic(
+// CHECK-SAME:  %[[SOURCE:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:  %[[DEST:.*]]: tensor<?x?x8x2xf32>,
+// CHECK-SAME:  %[[PAD:.*]]: f32)
+// CHECK: %{{.*}} = tensor.pack %[[SOURCE]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[DEST]] : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+
+// -----
+
+func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?xf32>, %tile_n : index, %tile_m : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @unpack_fully_dynamic(
+// CHECK-SAME:  %[[SOURCE:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME:  %[[DEST:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:  %[[TILE_N:.*]]: index,
+// CHECK-SAME:  %[[TILE_M:.*]]: index)
+// CHECK: %{{.*}} = tensor.unpack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_N]], %[[TILE_M]]] into %[[DEST]] : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+
+// -----
+
+func.func @unpack_partially_dynamic(%source: tensor<?x?x8x2xf32>, %dest: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest : tensor<?x?x8x2xf32> -> tensor<?x?xf32>
+  return %0: tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @unpack_partially_dynamic(
+// CHECK-SAME:  %[[SOURCE:.*]]: tensor<?x?x8x2xf32>,
+// CHECK-SAME:  %[[DEST:.*]]: tensor<?x?xf32>)
+// CHECK: %{{.*}} = tensor.unpack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %[[DEST]] : tensor<?x?x8x2xf32> -> tensor<?x?xf32>
index 4d2b5b7..090962d 100644 (file)
@@ -874,3 +874,58 @@ func.func @speculate_ceildivsi_const(
 
   return
 }
+
+// -----
+
+func.func @speculate_static_pack_and_unpack(%source: tensor<128x256xf32>, 
+  %dest: tensor<4x16x32x16xf32>, %lb: index, %ub: index, %step: index) {
+
+  // CHECK: tensor.pack
+  // CHECK-NEXT: scf.for  
+  scf.for %i = %lb to %ub step %step {
+    %packed = tensor.pack %source 
+      inner_dims_pos = [0, 1] 
+      inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+  }
+  
+  // CHECK: tensor.unpack
+  // CHECK-NEXT: scf.for 
+  scf.for %i = %lb to %ub step %step {
+    %unpacked = tensor.unpack %dest
+      inner_dims_pos = [0, 1] 
+      inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
+  }
+  return 
+}
+
+// -----
+
+func.func @speculate_dynamic_pack_and_unpack(%source: tensor<?x?xf32>,
+  %dest: tensor<?x?x?x?xf32>, %lb: index, %ub: index, %step: index,
+  %tile_m: index, %tile_n: index, %pad: f32) {
+
+  // CHECK: scf.for
+  // CHECK-NEXT: tensor.pack
+  scf.for %i = %lb to %ub step %step {
+    %packed = tensor.pack %source
+      inner_dims_pos = [0, 1]
+      inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  }
+
+  // CHECK: scf.for
+  // CHECK-NEXT: tensor.unpack
+  scf.for %i = %lb to %ub step %step {
+    %unpacked = tensor.unpack %dest
+      inner_dims_pos = [0, 1] 
+      inner_tiles = [%tile_n, %tile_m] into %source : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  }
+
+  // CHECK: tensor.pack
+  // CHECK-NEXT: scf.for
+  scf.for %i = %lb to %ub step %step {
+    %packed = tensor.pack %source padding_value(%pad : f32)
+      inner_dims_pos = [0, 1]
+      inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  }
+  return
+}