[mlir][Linalg] Add a structured.pack_transpose transform op
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 18 Jan 2023 20:26:13 +0000 (12:26 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 20 Jan 2023 08:30:16 +0000 (00:30 -0800)
This transform is complementary to the `structured.pack` op, which
packs a whole op but does not allow transposing the individual
operands.

`structured.pack_transpose` allows transposing single operands connected to
pack or unpack ops after the fact.

This makes the system overall more composable than e.g. a giant transform
op with all permutations specified at once.
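
For example, a whole-op packing can be followed by a transposition of a single
packed operand (this mirrors one of the tests below):

  %1 = transform.structured.pack %0 packed_sizes = [4, 0]
      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
  %pack = transform.get_producer_of_operand %1[0]
    : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
  %2, %pack_2, %empty_unpack_2 =
    transform.structured.pack_transpose %pack with_compute_op(%1)
    outer_perm = [1, 0]
    : (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
    -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !pdl.operation)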

Differential Revision: https://reviews.llvm.org/D142053

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/transform-op-pack.mlir

index a04c48f..9b51228 100644 (file)
@@ -773,6 +773,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     LogicalResult reifyResultShapes(OpBuilder &b,
         ReifiedRankedShapedTypeDims &reifiedReturnShapes);
 
+    /// Return the index in the indexingMaps vector that corresponds to this
+    /// `opOperand`.
+    int64_t getIndexingMapIndex(OpOperand *opOperand);
+
     //========================================================================//
     // Forwarding functions to access interface methods from the
     // DestinationStyleOpInterface.
index f6c601f..991aa04 100644 (file)
@@ -363,8 +363,11 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
 def PackOp : Op<Transform_Dialect, "structured.pack", [
-                TransformOpInterface,
+                DeclareOpInterfaceMethods<TransformOpInterface>,
                 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,]> {
   let description = [{
     Pack a LinalgOp by applying a data tiling transformation on the op and
@@ -439,15 +442,74 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
   }];
 
   let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure apply(
-      transform::TransformResults &transformResults,
-      transform::TransformState &state);
-
     ::llvm::SmallVector<::mlir::OpFoldResult> getMixedPackedSizes();
   }];
 }
 
 //===----------------------------------------------------------------------===//
+// PackTransposeOp
+//===----------------------------------------------------------------------===//
+def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
+                         FunctionalStyleTransformOpTrait,
+                         MemoryEffectsOpInterface,
+                         DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Apply a transposition to a single `tensor.pack` (resp. `tensor.unpack`) and 
+    update the `linalg.generic` op that consumes (resp. produces) the operation.
+
+    This transform allows composing a simple `structured.pack` with additional
+    transpositions to e.g. match the data format required by a specific library
+    call or ISA instruction.
+
+    The transpose spec must specify at least one of `outer_perm` or `inner_perm`
+    attributes, which will act upon the `outer_dims_perm` or `inner_dims_pos` of
+    the specified `tensor.pack` or `tensor.unpack` op.
+
+    If the `target` of this op is a `tensor.pack` then a new `tensor.empty` will
+    be created along with transposed versions of the `tensor.pack` and the 
+    consuming `linalg.generic`, which is expected to be the sole consumer.
+
+    If the `target` of this op is a `tensor.unpack` then the whole pack / compute
+    / unpack chain will be transposed and transposed clones of `tensor.pack`,
+    the consuming `linalg.generic` and the tail `tensor.unpack` will be created.
+
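+    For example, packing a `tensor<7x3xf16>` by 4 along dimension 0 and then
+    applying `outer_perm = [1, 0]` to the resulting `tensor.pack` rewrites it
+    into (a sketch drawn from the tests below; `%t`, `%cst` and `%e` are
+    illustrative names):
+
+      %e = tensor.empty() : tensor<3x2x4xf16>
+      %p = tensor.pack %t padding_value(%cst : f16) outer_dims_perm = [1, 0]
+          inner_dims_pos = [0] inner_tiles = [4] into %e
+          : tensor<7x3xf16> -> tensor<3x2x4xf16>
+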
+    #### Return modes
+
+    This operation targets a single `tensor.pack` / `tensor.unpack` op and a
+    single matching `linalg.generic` that consumes / produces the op. Otherwise,
+    it produces a silenceable failure.
+
+    This operation may produce a silenceable failure if the transpose spec is
+    ill-formed (i.e. `outer_perm` or `inner_perm` are not permutations of the
+    proper rank) or if the transposition of any of the involved operations
+    fails for any reason.
+
+    This operation returns 3 handles, one to the transformed LinalgOp, one to
+    the transformed `tensor.pack` and one to the transformed `tensor.unpack`.
+    The last handle for `tensor.unpack` is empty if `target_pack_or_un_pack_op`
+    was not itself a `tensor.unpack`.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target_pack_or_un_pack_op,
+                       TransformHandleTypeInterface:$target_linalg_op,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_perm,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$inner_perm);
+  let results = (outs TransformHandleTypeInterface:$packed_op,
+                      TransformHandleTypeInterface:$pack_op,
+                      TransformHandleTypeInterface:$un_pack_op);
+  let assemblyFormat = [{
+    $target_pack_or_un_pack_op
+    `with_compute_op` `(` $target_linalg_op `)`
+    (`outer_perm` `=` $outer_perm^ )?
+    (`inner_perm` `=` $inner_perm^ )?
+    attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
 // PadOp
 //===----------------------------------------------------------------------===//
 
index f0f1238..1b7b177 100644 (file)
@@ -1776,6 +1776,21 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+
+    /// Build and return a new PackOp that is a clone of the current PackOp with
+    /// (innerDimsPos, innerTiles) (resp. outerDimsPerm) permuted by
+    /// innerPermutation (resp. outerPermutation).
+    /// A new `tensor.empty` of the proper shape is built in the process.
+    /// Asserts that:
+    ///   - At least one of innerPermutation or outerPermutation is non-empty.
+    ///   - If not empty, innerPermutation is a valid permutation of size
+    ///     matching innerDimsPos.
+    ///   - If not empty, outerPermutation is a valid permutation of size
+    ///     matching outerDimsPerm.
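+    /// For example (illustrative values), innerDimsPos = [0, 1] with
+    /// innerPermutation = [1, 0] yields a clone with innerDimsPos = [1, 0]
+    /// and the two entries of innerTiles swapped accordingly.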
+    PackOp createTransposedClone(OpBuilder &b,
+                                 Location loc,
+                                 ArrayRef<int64_t> innerPermutation,
+                                 ArrayRef<int64_t> outerPermutation);
   }];
 
   let hasCanonicalizeMethod = 1;
@@ -1832,7 +1847,23 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
     CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
   ];
 
-  let extraClassDeclaration = commonExtraClassDeclaration;
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    /// Build and return a new UnPackOp that is a clone of the current UnPackOp
+    /// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) permuted by
+    /// innerPermutation (resp. outerPermutation).
+    /// Asserts that:
+    ///   - At least one of innerPermutation or outerPermutation is non-empty.
+    ///   - If not empty, innerPermutation is a valid permutation of size
+    ///     matching innerDimsPos.
+    ///   - If not empty, outerPermutation is a valid permutation of size
+    ///     matching outerDimsPerm.
+    UnPackOp createTransposedClone(OpBuilder &b,
+                                   Location loc,
+                                   Value transposedSource,
+                                   ArrayRef<int64_t> innerPermutation,
+                                   ArrayRef<int64_t> outerPermutation);
+  }];
+
   let hasCanonicalizeMethod = 1;
 }
 
index b508871..e5e0bdd 100644 (file)
@@ -621,6 +621,22 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
   return success();
 }
 
+/// Return the index in the indexingMaps vector that corresponds to this
+/// `opOperand`.
+int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
+  auto operandNumber = opOperand->getOperandNumber();
+  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
+  // DPS input operands appear first in `indexingMaps`, in operand order.
+  if (dpsIface.isDpsInput(opOperand))
+    return operandNumber;
+  auto [start, end] = dpsIface.getDpsInitsPositionRange();
+  assert(dpsIface.isDpsInit(opOperand) && "expected a DPS init operand");
+  // Account for potential inputs that are not DPS and may not appear in
+  // `indexingMaps`.
+  return dpsIface.getNumDpsInputs() + operandNumber - start;
+}
+
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
 
index 6efb813..4247a6c 100644 (file)
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Debug.h"
@@ -1161,16 +1165,12 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
   // Fail on multi-op handles.
   auto linalgOp = dyn_cast<linalg::LinalgOp>(targetOps.front());
   if (targetOps.size() != 1 || !linalgOp) {
-    // TODO: remove this unnecessary set to empty once crashes are fixed.
-    transformResults.set(getPackedOp().cast<OpResult>(), {});
     return emitSilenceableError()
            << "requires target to map to exactly 1 LinalgOp (got "
            << targetOps.size() << ")";
   }
   // Fail on mismatched number of pack sizes.
   if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
-    // TODO: remove this unnecessary set to empty once crashes are fixed.
-    transformResults.set(getPackedOp().cast<OpResult>(), {});
     return emitSilenceableError()
            << "requires number of packed sizes match the number of loops ("
            << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
@@ -1194,6 +1194,263 @@ void transform::PackOp::getEffects(
   transform::consumesHandle(getTarget(), effects);
   transform::onlyReadsHandle(getPackedSizes(), effects);
   transform::producesHandle(getPackedOp(), effects);
+  transform::modifiesPayload(effects);
+}
+
+//===---------------------------------------------------------------------===//
+// PackTransposeOp
+//===---------------------------------------------------------------------===//
+
+namespace {
+enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
+} // namespace
+
+/// Return true if `permutation` is a valid permutation of the `outer_dims_perm`
+/// (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (OuterOrInnerPerm::Inner)
+/// of the `tensor.pack` or `tensor.unpack` `op`.
+/// This is the case when the `permutation` rank matches the rank expected by
+/// `op` and `permutation` is itself a permutation vector.
+/// Return true if either `op` or `permutation` is empty to allow a simpler
+/// polymorphic implementation.
+template <typename RelayoutOpTy>
+bool isValidPackingPermutation(
+    RelayoutOpTy op, ArrayRef<int64_t> permutation,
+    OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
+  static_assert(
+      llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
+      "applies to only pack or unpack operations");
+  if (!op || permutation.empty())
+    return true;
+  int64_t innerRank = op.getInnerDimsPos().size();
+  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
+    return permutation.size() == innerRank && isPermutationVector(permutation);
+  // op.getOuterDimsPerm() may be empty, in which case it is identity.
+  // Don't rely on it.
+  if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
+    return permutation.size() == op.getSourceRank() &&
+           isPermutationVector(permutation);
+  }
+  return permutation.size() == op.getDestRank() &&
+         isPermutationVector(permutation);
+}
+
+/// Return a copy of `tensorType` after permutation by `permutationVector`.
+// Note: Should be a new method in MemRef/RankedTensor/VectorType::Builder
+// but this would introduce a dependence on Dialect in IR.
+// TODO: Restructure.
+static RankedTensorType permuteShape(RankedTensorType tensorType,
+                                     ArrayRef<int64_t> permutationVector) {
+  SmallVector<int64_t> shape(tensorType.getShape());
+  applyPermutationToVector(shape, permutationVector);
+  return RankedTensorType::Builder(tensorType).setShape(shape);
+}
+
+/// Return a new GenericOp obtained by transposing opOperand by the permutation
+/// vector:
+///   - the corresponding indexing map is transposed by `permutation`
+///   - the corresponding operand value is replaced by `transposedValue`
+/// `linalgOp` is replaced by the returned op in the process.
+/// Asserts that `transposedValue` is of the proper transposed ShapedType.
+static LinalgOp transposeOneLinalgOperandAndReplace(
+    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
+    ArrayRef<int64_t> permutation, Value transposedValue) {
+  // Sanity check the operand.
+  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
+
+  // Sanity check of the expected transposed tensor type.
+  auto tensorType = permuteShape(
+      opOperand.get().getType().cast<RankedTensorType>(), permutation);
+  assert(tensorType == transposedValue.getType() &&
+         "expected tensor type mismatch");
+
+  // Compute the transposed indexing map.
+  // AffineMap::getPermutationMap expects unsigned indices, so convert.
+  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
+      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
+  AffineMap permutationMap =
+      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
+  AffineMap transposedMap =
+      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
+
+  // Set the transposed indexing map in the proper position.
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
+  // Set the transposedValue in the proper operand position.
+  SmallVector<Value> operands = linalgOp->getOperands();
+  operands[opOperand.getOperandNumber()] = transposedValue;
+
+  ValueRange operandsRef(operands);
+  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
+      /*location=*/linalgOp->getLoc(),
+      /*resultTensorTypes=*/
+      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
+      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
+      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
+      /*indexingMaps=*/indexingMaps,
+      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
+  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
+  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
+
+  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
+}
+
+LogicalResult transform::PackTransposeOp::verify() {
+  if (!isPermutationVector(getInnerPerm())) {
+    return emitOpError() << getInnerPermAttrName()
+                         << " is not a valid permutation";
+  }
+  if (!isPermutationVector(getOuterPerm())) {
+    return emitOpError() << getOuterPermAttrName()
+                         << " is not a valid permutation";
+  }
+  if (getInnerPerm().empty() && getOuterPerm().empty()) {
+    return emitOpError() << " at least one of " << getInnerPermAttrName()
+                         << " or " << getOuterPermAttrName()
+                         << " must be specified";
+  }
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
+                                  transform::TransformState &state) {
+  ArrayRef<Operation *> packOrUnpackOps =
+      state.getPayloadOps(getTargetPackOrUnPackOp());
+  ArrayRef<Operation *> linalgOps = state.getPayloadOps(getTargetLinalgOp());
+  // Step 1. If nothing to pack, propagate success.
+  if (packOrUnpackOps.empty()) {
+    transformResults.set(getPackedOp().cast<OpResult>(), {});
+    transformResults.set(getPackOp().cast<OpResult>(), {});
+    transformResults.set(getUnPackOp().cast<OpResult>(), {});
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Step 2. A number of runtime sanity checks and error messages.
+  // Step 2.1. Fail on multi-op handles.
+  if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) {
+    return emitSilenceableError()
+           << "requires target to map to exactly 1 packing op and 1 packed op ("
+           << "got " << packOrUnpackOps.size() << " and " << linalgOps.size()
+           << ")";
+  }
+
+  // Step 2.2. Fail on wrong type.
+  auto packOp = dyn_cast<tensor::PackOp>(packOrUnpackOps.front());
+  auto unPackOp = dyn_cast<tensor::UnPackOp>(packOrUnpackOps.front());
+  if (!packOp && !unPackOp) {
+    return emitSilenceableError() << "requires target to map to a "
+                                     "tensor.pack or tensor.unpack";
+  }
+  LinalgOp linalgOpTarget = dyn_cast<linalg::LinalgOp>(linalgOps.front());
+  if (!linalgOpTarget)
+    return emitSilenceableError() << "requires a LinalgOp target";
+
+  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
+  LinalgOp linalgOp;
+  if (packOp && packOp.getResult().hasOneUse())
+    linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
+  else if (unPackOp)
+    linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
+  if (linalgOp != linalgOpTarget) {
+    auto errorMsg =
+        packOp ? StringLiteral{"not a single use by the LinalgOp target"}
+               : StringLiteral{"not produced by the LinalgOp target"};
+    return emitSilenceableError() << errorMsg;
+  }
+
+  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical PackOp.
+  if (unPackOp) {
+    assert(!packOp && "packOp must be null on entry when unPackOp is not null");
+    OpOperand *packUse = linalgOp.getDpsInitOperand(
+        unPackOp.getSource().cast<OpResult>().getResultNumber());
+    packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
+    if (!packOp || !packOp.getResult().hasOneUse())
+      return emitSilenceableError() << "could not find matching pack op";
+  }
+
+  // Step 2.5. Fail if any permutation does not validate.
+  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
+    ArrayRef<int64_t> perm =
+        (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
+    auto errorMsg = (permType == OuterOrInnerPerm::Outer)
+                        ? StringLiteral{"invalid outer_perm"}
+                        : StringLiteral{"invalid inner_perm"};
+    if (!isValidPackingPermutation(packOp, perm, permType) ||
+        !isValidPackingPermutation(unPackOp, perm, permType)) {
+      Operation *packOrUnpackOp =
+          unPackOp ? unPackOp.getOperation() : packOp.getOperation();
+      return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
+    }
+  }
+
+  // From here on, packOp and linalgOp are always present, unPackOp may or may
+  // not be present.
+  assert(packOp && linalgOp && "unexpected null op");
+
+  // Step 3. Actually transpose the ops.
+  Location loc = linalgOp.getLoc();
+  IRRewriter rewriter(getContext());
+
+  // Step 3.a. Transpose packOp.
+  rewriter.setInsertionPoint(packOp);
+  tensor::PackOp transposedPackOp = packOp.createTransposedClone(
+      rewriter, loc, getInnerPerm(), getOuterPerm());
+
+  // Step 3.b. Transpose linalgOp.
+  assert(packOp.getResult().hasOneUse() && "expect single use");
+  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
+  // identity. Don't rely on it.
+  int64_t numLeadingDims = packOp.getSourceRank();
+  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
+  // Step 3.b.i. Compute the permutation on the whole operand.
+  // The leading part just reuses the outerPerm.
+  SmallVector<int64_t> permutation(getOuterPerm());
+  if (permutation.empty())
+    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
+  // The trailing part needs to reindex positions by `numLeadingDims`.
+  if (getInnerPerm().empty()) {
+    llvm::append_range(
+        permutation,
+        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
+  } else {
+    llvm::append_range(permutation,
+                       llvm::map_range(getInnerPerm(), [&](int64_t pos) {
+                         return numLeadingDims + pos;
+                       }));
+  }
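+  // E.g. outerPerm = [1, 0] with numLeadingDims = 2 and an empty innerPerm
+  // with numTrailingDims = 1 yields permutation = [1, 0, 2].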
+  assert(isPermutationVector(permutation) && "invalid permutation");
+  // Step 3.b.ii. Save the packUse operand number in case we need to get the
+  // tied OpResult after `linalgOp` has been replaced.
+  OpOperand &packUse = *(packOp.getResult().getUses().begin());
+  int64_t packUseOperandNumber = packUse.getOperandNumber();
+  // Step 3.b.iii. Actually perform the transposition.
+  rewriter.setInsertionPoint(linalgOp);
+  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
+      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
+
+  // Step 3.c. Maybe transpose unPackOp.
+  tensor::UnPackOp transposedUnPackOp;
+  if (unPackOp) {
+    OpOperand &opOperand =
+        transposedLinalgOp->getOpOperand(packUseOperandNumber);
+    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
+    rewriter.setInsertionPoint(unPackOp);
+    transposedUnPackOp = unPackOp.createTransposedClone(
+        rewriter, loc, transposedResult, getInnerPerm(), getOuterPerm());
+  }
+
+  // Step 4. Replace and return results.
+  rewriter.replaceOp(packOp, transposedPackOp->getResults());
+  transformResults.set(getPackOp().cast<OpResult>(), {transposedPackOp});
+  // Note: `linalgOp` itself was replaced in `transposeOneLinalgOperandAndReplace`.
+  transformResults.set(getPackedOp().cast<OpResult>(), {transposedLinalgOp});
+  if (unPackOp) {
+    rewriter.replaceOp(unPackOp, transposedUnPackOp->getResults());
+    transformResults.set(getUnPackOp().cast<OpResult>(), {transposedUnPackOp});
+  } else {
+    transformResults.set(getUnPackOp().cast<OpResult>(), {});
+  }
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===---------------------------------------------------------------------===//
@@ -1359,7 +1616,7 @@ transform::ReplaceOp::apply(TransformResults &transformResults,
     if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
         target->getNumRegions() > 0)
       return emitDefiniteFailure()
-             << "expected target that is isloated from above";
+             << "expected target that is isolated from above";
   }
 
   // Clone and replace.
@@ -1907,32 +2164,31 @@ transform::TileOp::apply(TransformResults &transformResults,
     scf::SCFTilingOptions tilingOptions;
     unsigned index = en.index();
     if (!tileSizes.empty()) {
-      tilingOptions.setTileSizeComputationFunction(
-          [&, index](OpBuilder &b, Operation *) {
-            SmallVector<Value, 4> sizes;
-            sizes.reserve(tileSizes.size());
-            unsigned dynamicIdx = 0;
-            for (OpFoldResult ofr : getMixedSizes()) {
-              if (auto attr = ofr.dyn_cast<Attribute>()) {
-                sizes.push_back(b.create<arith::ConstantIndexOp>(
-                    getLoc(), attr.cast<IntegerAttr>().getInt()));
-                continue;
-              }
-              ArrayRef<Operation *> dynamicSizes =
-                  dynamicSizeProducers[dynamicIdx];
-              ArrayRef<int64_t> params = paramSizes[dynamicIdx];
-              ++dynamicIdx;
-              assert((dynamicSizes.empty() ^ params.empty()) &&
-                     "expected either dynamic sizes or parameters");
-              if (!params.empty()) {
-                sizes.push_back(
-                    b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
-              } else {
-                sizes.push_back(dynamicSizes[index]->getResult(0));
-              }
-            }
-            return sizes;
-          });
+      tilingOptions.setTileSizeComputationFunction([&, index](OpBuilder &b,
+                                                              Operation *) {
+        SmallVector<Value, 4> sizes;
+        sizes.reserve(tileSizes.size());
+        unsigned dynamicIdx = 0;
+        for (OpFoldResult ofr : getMixedSizes()) {
+          if (auto attr = ofr.dyn_cast<Attribute>()) {
+            sizes.push_back(b.create<arith::ConstantIndexOp>(
+                getLoc(), attr.cast<IntegerAttr>().getInt()));
+            continue;
+          }
+          ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
+          ArrayRef<int64_t> params = paramSizes[dynamicIdx];
+          ++dynamicIdx;
+          assert((dynamicSizes.empty() ^ params.empty()) &&
+                 "expected either dynamic sizes or parameters");
+          if (!params.empty()) {
+            sizes.push_back(
+                b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
+          } else {
+            sizes.push_back(dynamicSizes[index]->getResult(0));
+          }
+        }
+        return sizes;
+      });
     }
 
     tilingOptions.setInterchange(getInterchange());
@@ -2149,27 +2405,27 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
 
   // Transform all targets one by one.
   for (Operation *target : targets) {
-    auto tilableOp = dyn_cast<TilingInterface>(target);
-    if (!tilableOp) {
+    auto tileableOp = dyn_cast<TilingInterface>(target);
+    if (!tileableOp) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
           << "only TilingInterface ops are supported";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
-    rewriter.setInsertionPoint(tilableOp);
+    rewriter.setInsertionPoint(tileableOp);
     FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
     if (!mixedNumThreads.empty()) {
-      tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
+      tilingResult = linalg::tileToForeachThreadOp(rewriter, tileableOp,
                                                    mixedNumThreads, mapping);
     } else {
       tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
-          rewriter, tilableOp, mixedTileSizes, mapping);
+          rewriter, tileableOp, mixedTileSizes, mapping);
     }
 
     if (failed(tilingResult))
-      return transformOp.emitDefaultSilenceableFailure(tilableOp);
-    rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
+      return transformOp.emitDefaultSilenceableFailure(tileableOp);
+    rewriter.replaceOp(tileableOp, tilingResult->tileOp->getResults());
 
     tileOps.push_back(tilingResult->tileOp);
     tiledOps.push_back(tilingResult->tiledOp);
index d8c337d..4515d71 100644 (file)
@@ -3231,7 +3231,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
               return true;
             }
             return shape == constTileSize.value();
-
           })) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");
@@ -3239,6 +3238,57 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   return success();
 }
 
+namespace {
+/// Subset of PackOp/UnPackOp fields used to compute the result of applying
+/// various permutations to the op.
+// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
+// these. These may or may not become true foldings / canonicalizations
+// depending on how aggressive we want to be in automatically folding
+// transposes.
+struct PackOrUnPackTransposeResult {
+  SmallVector<int64_t> innerDimsPos;
+  SmallVector<OpFoldResult> innerTiles;
+  SmallVector<int64_t> outerDimsPerm;
+};
+} // namespace
+
+template <typename OpTy>
+static PackOrUnPackTransposeResult
+commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
+                                   ArrayRef<int64_t> innerPermutation,
+                                   ArrayRef<int64_t> outerPermutation) {
+  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+                "applies to only pack or unpack operations");
+  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
+         "some permutation must be non-empty");
+  PackOrUnPackTransposeResult metadata;
+  metadata.innerDimsPos =
+      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
+  metadata.innerTiles =
+      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
+  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
+                             ? packOrUnPackOp.getSourceRank()
+                             : packOrUnPackOp.getDestRank();
+  metadata.outerDimsPerm =
+      packOrUnPackOp.getOuterDimsPerm().empty()
+          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
+          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
+  if (!innerPermutation.empty()) {
+    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
+           isPermutationVector(innerPermutation) &&
+           "invalid inner permutation");
+    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
+    applyPermutationToVector(metadata.innerTiles, innerPermutation);
+  }
+  if (!outerPermutation.empty()) {
+    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
+           isPermutationVector(outerPermutation) &&
+           "invalid outer permutation");
+    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
+  }
+  return metadata;
+}
+
 //===----------------------------------------------------------------------===//
 // PackOp
 //===----------------------------------------------------------------------===//
@@ -3386,6 +3436,19 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
   return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
 }
 
+PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
+                                     ArrayRef<int64_t> innerPermutation,
+                                     ArrayRef<int64_t> outerPermutation) {
+  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
+      *this, innerPermutation, outerPermutation);
+  Value transposedDest =
+      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
+                              metadata.innerDimsPos, metadata.outerDimsPerm);
+  return b.create<PackOp>(loc, getSource(), transposedDest,
+                          metadata.innerDimsPos, metadata.innerTiles,
+                          getPaddingValue(), metadata.outerDimsPerm);
+}
+
 /// Returns true if the tiles and the tiled dims are constant.
 template <typename OpTy>
 bool areTilesAndTiledDimsAllConstant(OpTy op) {
@@ -3508,6 +3571,17 @@ void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
         builder.getDenseI64ArrayAttr(staticTileSizes));
 }
 
+UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
+                                         Value transposedSource,
+                                         ArrayRef<int64_t> innerPermutation,
+                                         ArrayRef<int64_t> outerPermutation) {
+  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
+      *this, innerPermutation, outerPermutation);
+  return b.create<UnPackOp>(loc, transposedSource, getDest(),
+                            metadata.innerDimsPos, metadata.innerTiles,
+                            metadata.outerDimsPerm);
+}
+
 /// pack(unpack(x)) -> x
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
index d1304bb..b8c569f 100644 (file)
@@ -49,21 +49,21 @@ transform.sequence failures(propagate) {
   iterator_types = ["reduction", "parallel"]
 }
 
-//    CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//    CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 //    CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d1)>
 
 //  CHECK-LABEL: @col_reduction_2d_static
 //   CHECK-SAME:   %[[T0:.+]]: tensor<7x3xf16>,
 //   CHECK-SAME:   %[[T1:.+]]: tensor<3xf16>
 func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) -> tensor<3xf16> {
-  //      CHECK:  %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x4xf16>
+  //      CHECK:  %[[EMPTY:.*]] = tensor.empty() : tensor<3x2x4xf16>
   //      CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16) 
-  // CHECK-SAME:   inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<2x3x4xf16>
+  // CHECK-SAME:   outer_dims_perm = [1, 0] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<3x2x4xf16>
   //  CHECK-NOT: tensor.pack
   //      CHECK: linalg.generic 
   // CHECK-SAME:   indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
   // CHECK-SAME:   iterator_types = ["reduction", "parallel", "reduction"]
-  // CHECK-SAME:   ins(%{{.*}} : tensor<2x3x4xf16>)
+  // CHECK-SAME:   ins(%{{.*}} : tensor<3x2x4xf16>)
   // CHECK-SAME:  outs(%{{.*}} : tensor<3xf16>)
   %2 = linalg.generic #col_reduction_2d_trait ins(%t0 : tensor<7x3xf16>) outs(%t1 : tensor<3xf16>) {
   ^bb0(%in: f16, %out: f16):
@@ -78,8 +78,15 @@ func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) ->
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  transform.structured.pack %0 packed_sizes = [4, 0]
+  %1 = transform.structured.pack %0 packed_sizes = [4, 0]
       : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+  %pack = transform.get_producer_of_operand %1[0] 
+    : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
+  %2, %pack_2, %empty_unpack_2 = 
+    transform.structured.pack_transpose %pack with_compute_op(%1) 
+    outer_perm = [1, 0]
+     : (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">) 
+    -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !pdl.operation)
 }
 
 // -----
@@ -183,7 +190,7 @@ transform.sequence failures(propagate) {
 //                                                                            K   N   n   k
 // CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
 //                                                                            M   N   m   n
-// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d4, d3)>
 
 // CHECK-LABEL: @matmul
 //  CHECK-SAME:   %[[A:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
@@ -196,19 +203,19 @@ func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
   // CHECK-SAME:   : tensor<?x?xf32> -> tensor<?x?x2x4xf32>
   //      CHECK: %[[PACK_B:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [3, 4]
   // CHECK-SAME:   : tensor<?x?xf32> -> tensor<?x?x3x4xf32>
-  //      CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
-  // CHECK-SAME:   : tensor<?x?xf32> -> tensor<?x?x2x3xf32>
+  //      CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [3, 2]
+  // CHECK-SAME:   : tensor<?x?xf32> -> tensor<?x?x3x2xf32>
 
   //      CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
   // CHECK-SAME:     iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} 
   // CHECK-SAME:  ins(%{{.*}} : tensor<?x?x2x4xf32>, tensor<?x?x3x4xf32>)
-  // CHECK-SAME: outs(%{{.*}} : tensor<?x?x2x3xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<?x?x3x2xf32>)
   %0 = linalg.matmul  ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
                      outs(%C: tensor<?x?xf32>)
     -> tensor<?x?xf32>
 
-  //      CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
-  // CHECK-SAME:   : tensor<?x?x2x3xf32> -> tensor<?x?xf32>
+  //      CHECK: tensor.unpack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [3, 2]
+  // CHECK-SAME:   : tensor<?x?x3x2xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
@@ -218,6 +225,14 @@ transform.sequence failures(propagate) {
     //                                            M  N  K
     %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
       : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+
+    %unpack = transform.get_consumers_of_result %1[0] 
+      : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">)
+    %2, %pack_2, %unpack_2 = 
+      transform.structured.pack_transpose %unpack with_compute_op(%1) 
+      outer_perm = [1, 0] inner_perm = [1, 0]
+      : (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">) 
+      -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
 }
 
 // -----
@@ -404,3 +419,177 @@ transform.sequence failures(propagate) {
     %1 = transform.structured.pack %0 packed_sizes = [2, 3] 
       : (!pdl.operation) -> (!transform.op<"linalg.generic">)
 }
+
+// -----
+
+func.func @no_single_packing_op(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
+  %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+  %1 = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
+  %2 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+  return
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+      // expected-error @below {{requires target to map to exactly 1 packing op and 1 packed op (got 2 and 1)}}
+    transform.structured.pack_transpose %0 with_compute_op(%1) 
+    inner_perm = [0]
+      : (!pdl.operation, !pdl.operation) 
+      -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+// -----
+
+func.func @no_single_pack_unpack(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = tensor.empty() : tensor<f32>
+  return
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["arith.constant"]} in %arg1
+    %1 = transform.structured.match ops{["tensor.empty"]} in %arg1
+      // expected-error @below {{requires target to map to a tensor.pack or tensor.unpack}}
+    transform.structured.pack_transpose %0 with_compute_op(%1) 
+    inner_perm = [0]
+      : (!pdl.operation, !pdl.operation) 
+      -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+// -----
+
+func.func @no_linalg_target(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
+  %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+  %1 = arith.constant 0 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1 = transform.structured.match ops{["arith.constant"]} in %arg1
+      // expected-error @below {{requires a LinalgOp target}}
+    transform.structured.pack_transpose %0 with_compute_op(%1) 
+    inner_perm = [0]
+      : (!pdl.operation, !pdl.operation) 
+      -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+// -----
+
+func.func @no_single_use_by_linalg(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
+  %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+  %f0 = arith.constant 0.0 : f32
+  %1 = tensor.empty() : tensor<f32>
+  %2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<f32>) -> tensor<f32>
+  return
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1 = transform.structured.match ops{["linalg.fill"]} in %arg1
+      // expected-error @below {{not a single use by the LinalgOp target}}
+    transform.structured.pack_transpose %0 with_compute_op(%1) 
+    inner_perm = [0]
+      : (!pdl.operation, !pdl.operation) 
+      -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+// -----
+
+func.func @not_produced_by_linalg(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
+  %a = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+  %b = tensor.unpack %a inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
+  %f0 = arith.constant 0.0 : f32
+  %1 = tensor.empty() : tensor<f32>
+  %2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<f32>) -> tensor<f32>
+  return
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+    %1 = transform.structured.match ops{["linalg.fill"]} in %arg1
+      // expected-error @below {{not produced by the LinalgOp target}}
+    transform.structured.pack_transpose %0 with_compute_op(%1) 
+    inner_perm = [0]
+      : (!pdl.operation, !pdl.operation) 
+      -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+// -----
+
+func.func @no_matching_pack(%source: tensor<16xf32>) {
+  %f0 = arith.constant 0.0 : f32
+  %1 = tensor.empty() : tensor<4x4xf32>
+  %2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
+  %b = tensor.unpack %2 inner_dims_pos = [0] inner_tiles = [4] into %source : tensor<4x4xf32> -> tensor<16xf32>
+  return
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+    %1 = transform.structured.match ops{["linalg.fill"]} in %arg1
+      // expected-error @below {{could not find matching pack op}}
+    transform.structured.pack_transpose %0 with_compute_op(%1) 
+    inner_perm = [0]
+      : (!pdl.operation, !pdl.operation) 
+      -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+// -----
+
+func.func @invalid_outer_perm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %0 = linalg.matmul  ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+
+    %unpack = transform.get_consumers_of_result %1[0] 
+      : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">)
+    %2, %pack_2, %unpack_2 = 
+      // expected-error @below {{invalid outer_perm}}
+      transform.structured.pack_transpose %unpack with_compute_op(%1) 
+      outer_perm = [1]
+      : (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">) 
+      -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
+}
+
+// -----
+
+func.func @invalid_inner_perm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %0 = linalg.matmul  ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+
+    %unpack = transform.get_consumers_of_result %1[0] 
+      : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">)
+    %2, %pack_2, %unpack_2 = 
+      // expected-error @below {{invalid inner_perm}}
+      transform.structured.pack_transpose %unpack with_compute_op(%1) 
+      inner_perm = [1]
+      : (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">) 
+      -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
+}