}];
}
+def PackOp : Op<Transform_Dialect, "structured.pack", [
+ TransformOpInterface,
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let description = [{
+ Pack a LinalgOp by applying a data tiling transformation on the op and
+ packing the operands according to the `packed_sizes` specification.
+
+ Iterator dimensions are tiled in their canonical order in the op spec.
+ Operands are packed according to the same canonical order of the op iterator
+ dimensions.
+
+ Specifying a packed size of 0 for an iterator removes it from consideration
+ for packing.
+
+ `tensor.pack` (resp. `tensor.unpack`) operations are inserted for the operands
+ (resp. results) that need to be packed (resp. unpacked) according to the
+ `packed_sizes` specification.
+
+ #### Example
+
+ Consider a `linalg.matmul` with indexing maps:
+ ```
+ // M N K M K
+ // affine_map<(d0, d1, d2) -> (d0, d2)>
+ // K N
+ // affine_map<(d0, d1, d2) -> (d2, d1)>
+ // M N
+ // affine_map<(d0, d1, d2) -> (d0, d1)>
+ %0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs( %C: tensor<?x?xf32>)
+ ```
+
+    Specifying `packed_sizes = [2, 3, 4]` results in tiling the iterator
+    dimensions M, N and K, in this order, in both the op and its operands.
+ ```
+ // M N K m n k M K m k
+ // affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+ // K N n k
+ // affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+ // M N m n
+ // affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+ %0 = linalg.generic_representing_some_higher_d_matmul
+          ins(%A, %B: tensor<?x?x2x4xf32>, tensor<?x?x3x4xf32>)
+         outs(    %C: tensor<?x?x2x3xf32>)
+ ```
+ In particular, note that the second operand `B` has shape `KxNxnxk` (and not
+ `KxNxkxn` as one could expect by looking **only** at the operand).
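+
+    The transform invocation that produces this packed form mirrors the
+    `@matmul` test accompanying this patch:
+    ```
+    transform.sequence failures(propagate) {
+    ^bb0(%arg1: !pdl.operation):
+      %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      //                                                M  N  K
+      %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
+        : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+    }
+    ```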
+
+    Other layouts can be obtained from this canonical transformation by
+    composing the resulting operation with a (future)
+    `transform.structured.pack_transpose` op.
+    This composition separates concerns and composes better than adding
+    additional permutation attributes to this transform op; a hypothetical
+    composition is sketched below.
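+
+    As a sketch only, assuming a hypothetical `pack_transpose` syntax (the op
+    does not exist yet; its name and form here are illustrative, not a
+    specification):
+    ```
+    // Hypothetical: after packing, permute the inner dims of the second
+    // operand from `n x k` to `k x n`.
+    %1 = transform.structured.pack_transpose %packed_matmul ...
+    ```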
+
+ #### Return modes
+
+    This operation applies to a single Linalg op; otherwise it fails.
+ This operation may produce a definiteFailure if the packing fails for any
+ reason.
+
+    The returned handle points to the packed LinalgOp.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ Variadic<PDL_Operation>:$packed_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_packed_sizes);
+ let results = (outs TransformHandleTypeInterface:$packed_op);
+ let assemblyFormat = [{
+ $target
+ `packed_sizes` `=` custom<DynamicIndexList>($packed_sizes,
+ $static_packed_sizes)
+ attr-dict
+ `:` functional-type($target, results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ transform::TransformResults &transformResults,
+ transform::TransformState &state);
+
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedPackedSizes();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
using namespace mlir::transform;
#define DEBUG_TYPE "linalg-transforms"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
/// Attempts to apply the pattern specified as template argument to the given
/// operation. The pattern is expected to have a `returningMatchAndRewrite`
return cast<LinalgOp>(result->getOperation());
}
+/// Assuming that each entry of `ofrs` is either an index attr or a transform
+/// dialect handle mapped to exactly one op with one index result, append that
+/// attr or value to `result`.
+static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
+ transform::TransformState &state, TransformOpInterface transformOp,
+ SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
+ for (OpFoldResult ofr : ofrs) {
+ if (ofr.is<Attribute>()) {
+ if (!ofr.get<Attribute>().isa<IntegerAttr>())
+ return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+ result.push_back(ofr);
+ continue;
+ }
+ ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>());
+ if (payloadOps.size() != 1) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "handle must be mapped to exactly one payload op";
+ diag.attachNote(ofr.get<Value>().getLoc())
+ << "mapped to " << payloadOps.size() << " payload ops";
+ return diag;
+ }
+
+ Operation *op = payloadOps[0];
+ if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "payload op must have exactly 1 index result";
+ diag.attachNote(op->getLoc())
+ << "has " << op->getNumResults() << " results";
+ return diag;
+ }
+ result.push_back(op->getResult(0));
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+// Given a list of OpFoldResults that are either index attrs or op
+// handles, return a list of OpFoldResults where all op handles are
+// replaced with the first (and only) OpResult of that payload op. (There
+// must be exactly one mapped payload op and it must have exactly one
+// index result.)
+static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
+ transform::TransformState &state, TransformOpInterface transformOp,
+ SmallVector<OpFoldResult> &result, Value packedHandle) {
+ ArrayRef<Operation *> payloadOps = state.getPayloadOps(packedHandle);
+ for (Operation *op : payloadOps) {
+ if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "payload op must have exactly 1 index result";
+ diag.attachNote(op->getLoc())
+ << "has " << op->getNumResults() << " results";
+ return diag;
+ }
+ result.push_back(op->getResult(0));
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//
}
//===---------------------------------------------------------------------===//
+// PackOp
+//===---------------------------------------------------------------------===//
+
+SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
+ Builder b(getContext());
+ return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
+}
+
+/// Return true if at most one result of `map` is a function of
+/// AffineDimExpr(dim).
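+/// E.g. with map = (d0, d1) -> (d0 + d1, d1), dim = 1 appears in both results
+/// so this returns false, while dim = 0 appears only in the first result and
+/// this returns true.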
+static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
+ bool found = false;
+ for (AffineExpr e : map.getResults()) {
+ if (!e.isFunctionOfDim(dim))
+ continue;
+ if (found)
+ return false;
+ found = true;
+ }
+ return true;
+}
+
+/// Return the index of the first result of `map` that is a function of
+/// AffineDimExpr(dim), std::nullopt otherwise.
+static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
+ int64_t dim) {
+ for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+ AffineExpr expr = map.getResult(i);
+ if (!expr.isFunctionOfDim(dim))
+ continue;
+ return i;
+ }
+ return std::nullopt;
+}
+
+/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
+/// `newDim` at `iteratorTypes.size()` by:
+/// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
+/// 2. Appending a `newDim` to the domain of every indexing map.
+/// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing
+/// by potentially adding a `newDim` result to `map`.
+/// The preserved invariant is that `iteratorTypes.size()` is always equal to
+/// `map.getNumDims()` for every map in `indexingMaps`.
+///
+/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
+/// Return a vector that records the optional packing for each operand.
+/// Return failure if the packed indexing cannot be represented with a LinalgOp.
+///
+/// Further details:
+/// ================
+/// The current implementation of packing (i.e. data tiling) consists of
+/// rewriting a linearized strip-mined form into a higher-dimensional access.
+/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
+/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
+/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
+///
+/// This rewrite into higher dimensional access is not possible for general
+/// AffineExpr in Linalg at the moment; it is restricted to an AffineDimExpr:
+/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
+/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
+/// The rewrite of the access would be a form not representable in Linalg:
+/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
+/// Note however that as `J` and `ii` iterate, the accesses do not have a
+/// particular alignment, so packing does not achieve alignment in this case.
+///
+/// In the future, we may want to consider a mixed-form that allows some
+/// alignment in the presence of multiple accesses:
+/// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
+/// And would rewrite accesses as:
+/// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
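+///
+/// Illustrative walk-through, using the matmul maps from the documentation of
+/// transform.structured.pack: starting from
+///   maps          = [(d0, d1, d2) -> (d0, d2),
+///                    (d0, d1, d2) -> (d2, d1),
+///                    (d0, d1, d2) -> (d0, d1)]
+///   iteratorTypes = [parallel, parallel, reduction]
+/// one step along dim = 0 (i.e. M) appends a new parallel iterator d3 and
+/// rewrites the maps to
+///   maps          = [(d0, d1, d2, d3) -> (d0, d2, d3),
+///                    (d0, d1, d2, d3) -> (d2, d1),
+///                    (d0, d1, d2, d3) -> (d0, d1, d3)]
+/// and returns packedDimPerIndexingMap = [0, std::nullopt, 0]: operands 0 and
+/// 2 are packed along their operand dimension 0, operand 1 is unaffected.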
+static FailureOr<SmallVector<std::optional<int64_t>>>
+packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
+ SmallVectorImpl<utils::IteratorType> &iteratorTypes,
+ int64_t dim) {
+ int64_t newDim = iteratorTypes.size();
+ iteratorTypes.push_back(iteratorTypes[dim]);
+
+ SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
+ indexingMaps.size(), std::nullopt);
+ SmallVector<AffineMap> newMaps;
+ for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
+ ++operandIdx) {
+ AffineMap map = indexingMaps[operandIdx];
+
+    // Unconditionally add `newDim` to the map's domain.
+ assert(map.getNumDims() == newDim && "num dims invariant violation");
+ map = map.shiftDims(1, newDim);
+
+ // Get the at-most-1 index of the result that is a function of `dim`.
+ // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
+ // logically chunks dimension `dim` into `K * dim + newDim`, where the
+ // packing factor `K` is specified separately.
+ assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
+ "num results invariant violation");
+ auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
+ if (!maybeOperandDimensionToPack.has_value()) {
+ newMaps.push_back(map);
+ continue;
+ }
+
+ // We can only pack AffineDimExpr atm.
+ if (!map.getResult(maybeOperandDimensionToPack.value())
+ .isa<AffineDimExpr>())
+ return failure();
+
+ // Add `newDim` to the results of the map.
+ map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
+ map.getNumResults());
+ newMaps.push_back(map);
+
+    // Record the fact that `operandIdx` is packed.
+ packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
+ }
+ indexingMaps = newMaps;
+
+ return packedDimPerIndexingMap;
+}
+
+namespace {
+
+/// Helper struct to encode packing along one dimension of a LinalgOp.
+struct PackedOperandsDim {
+ OpFoldResult packedSize;
+ SmallVector<std::optional<int64_t>> packedDimForEachOperand;
+};
+
+/// Helper struct to encode packing along all dimensions of a LinalgOp.
+struct PackedOperandsDimList {
+ void push_back(PackedOperandsDim &&packedOperandsDims) {
+ spec.emplace_back(packedOperandsDims);
+ }
+ /// Return all the dims that have been packed for operand @ `operandPos`.
+ SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
+ /// Return all the pack sizes by which an operand @ `operandPos` is packed.
+ SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
+
+private:
+ SmallVector<PackedOperandsDim> spec;
+};
+
+} // namespace
+
+SmallVector<int64_t>
+PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
+ SmallVector<int64_t> res;
+ for (int64_t i = 0, e = spec.size(); i < e; ++i) {
+ if (!spec[i].packedDimForEachOperand[operandPos].has_value())
+ continue;
+ res.push_back(spec[i].packedDimForEachOperand[operandPos].value());
+ }
+ return res;
+}
+
+SmallVector<OpFoldResult>
+PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
+ SmallVector<OpFoldResult> res;
+ for (int64_t i = 0, e = spec.size(); i < e; ++i) {
+ if (!spec[i].packedDimForEachOperand[operandPos].has_value())
+ continue;
+ res.push_back(spec[i].packedSize);
+ }
+ return res;
+}
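+
+// Illustrative note: for the matmul example from the documentation of
+// transform.structured.pack with packed_sizes = [2, 3, 4], the recorded steps
+// yield, for operand A (indexing map (d0, d2)):
+//   extractPackedDimsForOperand(0) == [0, 1]
+//   extractPackSizesForOperand(0)  == [2, 4]
+// and for operand B (indexing map (d2, d1)):
+//   extractPackedDimsForOperand(1) == [1, 0]
+//   extractPackSizesForOperand(1)  == [3, 4]
+// These become the `inner_dims_pos` / `inner_tiles` of the created
+// tensor.pack ops, which explains the `KxNxnxk` layout of the packed `B`.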
+
+/// Pack a single LinalgOp by the sizes specified in `packedSizeHandles`.
+/// There must be one `packedSizeHandles` entry per `linalgOp` iterator.
+/// Return the packed Linalg op on success, failure otherwise.
+static FailureOr<linalg::LinalgOp>
+packOneLinalgOp(RewriterBase &rewriter, transform::TransformState &state,
+ TransformOpInterface transformOp, linalg::LinalgOp linalgOp,
+ ArrayRef<OpFoldResult> packedSizeHandles) {
+ assert(packedSizeHandles.size() == linalgOp.getNumLoops() &&
+ "incorrect number of pack sizes");
+
+ Location loc = linalgOp->getLoc();
+ SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+ SmallVector<utils::IteratorType> iteratorTypes =
+ linalgOp.getIteratorTypesArray();
+ LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
+ llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
+ llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
+ DBGSNL(););
+
+ // Unpack handles to constants or actual SSA index values.
+ SmallVector<OpFoldResult> packedSizes;
+  DiagnosedSilenceableFailure status = unpackSingleIndexResultPDLOperations(
+      state, transformOp, packedSizes, packedSizeHandles);
+  if (!status.succeeded())
+    return failure();
+
+ // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
+ PackedOperandsDimList listOfPackedOperandsDim;
+ for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
+ std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
+ // Skip tile sizes explicitly set to 0.
+ if (maybeConstant.has_value() && maybeConstant.value() == 0)
+ continue;
+
+ PackedOperandsDim packedOperandsDims;
+ packedOperandsDims.packedSize = packedSizes[i];
+ FailureOr<SmallVector<std::optional<int64_t>>>
+ maybePackedDimForEachOperand =
+ packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
+ if (failed(maybePackedDimForEachOperand))
+ return failure();
+ packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
+ listOfPackedOperandsDim.push_back(std::move(packedOperandsDims));
+
+ LLVM_DEBUG(
+ DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
+ << "\n";
+ llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
+ llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
+ llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
+ DBGS() << "packedDimForEachOperand: ");
+ DBGSNL(););
+ }
+
+ // Step 2. Propagate packing to all LinalgOp operands.
+ SmallVector<Value> inputsAndInits, results;
+ for (auto operandsList :
+ {linalgOp.getDpsInputOperands(), linalgOp.getDpsInitOperands()}) {
+ for (OpOperand *opOperandPtr : operandsList) {
+ int64_t pos = opOperandPtr->getOperandNumber();
+ Value operand = opOperandPtr->get();
+ SmallVector<int64_t> innerPos =
+ listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
+ SmallVector<OpFoldResult> innerPackSizes =
+ listOfPackedOperandsDim.extractPackSizesForOperand(pos);
+ LLVM_DEBUG(
+ DBGS() << "operand: " << operand << "\n";
+ llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
+ llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
+ DBGSNL(););
+ if (innerPackSizes.empty()) {
+ inputsAndInits.push_back(operand);
+ continue;
+ }
+ Value dest = tensor::PackOp::createDestinationTensor(
+ rewriter, loc, operand, innerPackSizes, innerPos,
+ /*outerDimsPerm=*/{});
+ // TODO: value of the padding attribute should be determined by consumers.
+ Attribute zeroAttr =
+ rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
+ Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+ inputsAndInits.push_back(rewriter.create<tensor::PackOp>(
+ loc, operand, dest, innerPos, innerPackSizes, zero));
+ }
+ }
+
+ // Step 3. Build the packed op, use the type of `inits` as result types.
+ ValueRange inputs =
+ ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
+ ValueRange inits =
+ ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
+ auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
+ linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
+ iteratorTypes);
+ packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
+
+ // Step 4. Propagate packing to all the op results.
+ for (OpResult result : packedLinalgOp->getResults()) {
+ int64_t resultNum = result.getResultNumber();
+ tensor::PackOp maybePackedInit =
+ inits[resultNum].getDefiningOp<tensor::PackOp>();
+ if (!maybePackedInit) {
+ results.push_back(result);
+ continue;
+ }
+ // Build the symmetrical UnPackOp to the existing PackOp.
+ results.push_back(rewriter.create<tensor::UnPackOp>(
+ packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
+ maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+ }
+
+ // Step 5. Replace `linalgOp`.
+ rewriter.replaceOp(linalgOp, results);
+
+ // Return packedLinalgOp.
+ return cast<linalg::LinalgOp>(packedLinalgOp.getOperation());
+}
+
+DiagnosedSilenceableFailure
+transform::PackOp::apply(transform::TransformResults &transformResults,
+ transform::TransformState &state) {
+ ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+ // If nothing to pack, propagate success.
+ if (targetOps.empty()) {
+ transformResults.set(getPackedOp().cast<OpResult>(), {});
+ return DiagnosedSilenceableFailure::success();
+ }
+ // Fail on multi-op handles.
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(targetOps.front());
+ if (targetOps.size() != 1 || !linalgOp) {
+ // TODO: remove this unnecessary set to empty once crashes are fixed.
+ transformResults.set(getPackedOp().cast<OpResult>(), {});
+ return emitSilenceableError()
+ << "requires target to map to exactly 1 LinalgOp (got "
+ << targetOps.size() << ")";
+ }
+ // Fail on mismatched number of pack sizes.
+ if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
+ // TODO: remove this unnecessary set to empty once crashes are fixed.
+ transformResults.set(getPackedOp().cast<OpResult>(), {});
+ return emitSilenceableError()
+ << "requires number of packed sizes match the number of loops ("
+ << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
+ << ")";
+ }
+
+ IRRewriter rewriter(linalgOp->getContext());
+ rewriter.setInsertionPoint(linalgOp);
+ FailureOr<LinalgOp> maybeResult =
+ packOneLinalgOp(rewriter, state, *this, linalgOp, getMixedPackedSizes());
+ if (failed(maybeResult))
+ return emitDefiniteFailure("data tiling failed");
+
+ transformResults.set(getPackedOp().cast<OpResult>(),
+ maybeResult->getOperation());
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::PackOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::consumesHandle(getTarget(), effects);
+ transform::onlyReadsHandle(getPackedSizes(), effects);
+ transform::producesHandle(getPackedOp(), effects);
+}
+
+//===---------------------------------------------------------------------===//
// PadOp
//===---------------------------------------------------------------------===//
/*mapping=*/mapping);
}
-/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
-/// to exactly one op with one index result, return that value.
-static DiagnosedSilenceableFailure unpackPDLOperations(
- transform::TransformState &state, TransformOpInterface transformOp,
- SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
- for (OpFoldResult ofr : ofrs) {
- if (ofr.is<Attribute>()) {
- if (!ofr.get<Attribute>().isa<IntegerAttr>())
- return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
- result.push_back(ofr);
- continue;
- }
- ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>());
- if (payloadOps.size() != 1) {
- DiagnosedSilenceableFailure diag =
- transformOp.emitSilenceableError()
- << "handle must be mapped to exactly one payload op";
- diag.attachNote(ofr.get<Value>().getLoc())
- << "mapped to " << payloadOps.size() << " payload ops";
- return diag;
- }
-
- Operation *op = payloadOps[0];
- if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
- DiagnosedSilenceableFailure diag =
- transformOp.emitSilenceableError()
- << "payload op must have exactly 1 index result";
- diag.attachNote(op->getLoc())
- << "has " << op->getNumResults() << " results";
- return diag;
- }
- result.push_back(op->getResult(0));
- }
-
- return DiagnosedSilenceableFailure::success();
-}
-
-// Given a list of OpFoldResults that are either index attrs or op
-// handles, return a list of OpFoldResults where all op handles are
-// replaced with the first (and only) OpResult of that payload op. (There
-// must be exactly one mapped payload op and it must have exactly one
-// index result.)
-static DiagnosedSilenceableFailure
-unpackPDLOperations(transform::TransformState &state,
- TransformOpInterface transformOp,
- SmallVector<OpFoldResult> &result, Value packedHandle) {
- ArrayRef<Operation *> payloadOps = state.getPayloadOps(packedHandle);
- for (Operation *op : payloadOps) {
- if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
- DiagnosedSilenceableFailure diag =
- transformOp.emitSilenceableError()
- << "payload op must have exactly 1 index result";
- diag.attachNote(op->getLoc())
- << "has " << op->getNumResults() << " results";
- return diag;
- }
- result.push_back(op->getResult(0));
- }
-
- return DiagnosedSilenceableFailure::success();
-}
-
DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
RewriterBase &rewriter, transform::TransformState &state,
TransformOpInterface transformOp, ArrayRef<Operation *> targets,
SmallVector<OpFoldResult> mixedNumThreads;
DiagnosedSilenceableFailure status =
getPackedNumThreads()
- ? unpackPDLOperations(state, transformOp, mixedNumThreads,
- getPackedNumThreads())
- : unpackPDLOperations(state, transformOp, mixedNumThreads,
- getMixedNumThreads());
+ ? unpackSingleIndexResultPDLOperations(
+ state, transformOp, mixedNumThreads, getPackedNumThreads())
+ : unpackSingleIndexResultPDLOperations(
+ state, transformOp, mixedNumThreads, getMixedNumThreads());
if (!status.succeeded())
return status;
SmallVector<OpFoldResult> mixedTileSizes;
status = getPackedTileSizes()
- ? unpackPDLOperations(state, transformOp, mixedTileSizes,
- getPackedTileSizes())
- : unpackPDLOperations(state, transformOp, mixedTileSizes,
- getMixedTileSizes());
+ ? unpackSingleIndexResultPDLOperations(
+ state, transformOp, mixedTileSizes, getPackedTileSizes())
+ : unpackSingleIndexResultPDLOperations(
+ state, transformOp, mixedTileSizes, getMixedTileSizes());
if (!status.succeeded())
return status;
--- /dev/null
+// RUN: mlir-opt -test-transform-dialect-interpreter -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#reduction_2d_trait = {
+ indexing_maps = [#map, #map1],
+ iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d0)>
+
+// CHECK-LABEL: @reduction_2d_static
+// CHECK-SAME: %[[T0:.+]]: tensor<3x7xf16>,
+// CHECK-SAME: %[[T1:.+]]: tensor<3xf16>
+func.func @reduction_2d_static(%t0: tensor<3x7xf16>, %t1: tensor<3xf16>) -> tensor<3xf16> {
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<3x2x4xf16>
+ // CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16)
+ // CHECK-SAME: inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]] : tensor<3x7xf16> -> tensor<3x2x4xf16>
+ // CHECK-NOT: tensor.pack
+ // CHECK: linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
+ // CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
+ // CHECK-SAME: ins(%{{.*}} : tensor<3x2x4xf16>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<3xf16>)
+ %2 = linalg.generic #reduction_2d_trait ins(%t0 : tensor<3x7xf16>) outs(%t1 : tensor<3xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %3 = arith.addf %in, %out : f16
+ linalg.yield %3 : f16
+ } -> tensor<3xf16>
+
+ // CHECK-NOT: tensor.unpack
+ return %2 : tensor<3xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = transform.structured.pack %0 packed_sizes = [0, 4]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+#col_reduction_2d_trait = {
+ indexing_maps = [#map, #map1],
+ iterator_types = ["reduction", "parallel"]
+}
+
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: @col_reduction_2d_static
+// CHECK-SAME: %[[T0:.+]]: tensor<7x3xf16>,
+// CHECK-SAME: %[[T1:.+]]: tensor<3xf16>
+func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) -> tensor<3xf16> {
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x4xf16>
+ // CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16)
+ // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<2x3x4xf16>
+ // CHECK-NOT: tensor.pack
+ // CHECK: linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
+ // CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
+ // CHECK-SAME: ins(%{{.*}} : tensor<2x3x4xf16>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<3xf16>)
+ %2 = linalg.generic #col_reduction_2d_trait ins(%t0 : tensor<7x3xf16>) outs(%t1 : tensor<3xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %3 = arith.addf %in, %out : f16
+ linalg.yield %3 : f16
+ } -> tensor<3xf16>
+
+ // CHECK-NOT: tensor.unpack
+ return %2 : tensor<3xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = transform.structured.pack %0 packed_sizes = [4, 0]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#reduction_2d_trait = {
+ indexing_maps = [#map, #map1],
+ iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-DAG: #[[$DIV4:.*]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d0)>
+
+// CHECK-LABEL: @reduction_2d_dynamic
+// CHECK-SAME: %[[T0:.+]]: tensor<?x?xf16>,
+// CHECK-SAME: %[[T1:.+]]: tensor<?xf16>
+func.func @reduction_2d_dynamic(%t0: tensor<?x?xf16>, %t1: tensor<?xf16>) -> tensor<?xf16> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?xf16>
+ // CHECK-DAG: %[[D1:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf16>
+ // CHECK: %[[D1B4:.*]] = affine.apply #[[$DIV4]]()[%[[D1]]]
+ // CHECK: %[[EMPTY:.*]] = tensor.empty(%[[D0]], %[[D1B4]]) : tensor<?x?x4xf16>
+ // CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16)
+ // CHECK-SAME: inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]] : tensor<?x?xf16> -> tensor<?x?x4xf16>
+ // CHECK-NOT: tensor.pack
+ // CHECK: linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
+ // CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
+ // CHECK-SAME: ins(%{{.*}} : tensor<?x?x4xf16>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<?xf16>)
+ %2 = linalg.generic #reduction_2d_trait ins(%t0 : tensor<?x?xf16>) outs(%t1 : tensor<?xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %3 = arith.addf %in, %out : f16
+ linalg.yield %3 : f16
+ } -> tensor<?xf16>
+
+ // CHECK-NOT: tensor.unpack
+ return %2 : tensor<?xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = transform.structured.pack %0 packed_sizes = [0, 4]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#reduction_2d_trait = {
+ indexing_maps = [#map, #map1],
+ iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-DAG: #[[$DIV3:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
+// CHECK-DAG: #[[$DIV4:.*]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+
+// CHECK-LABEL: @reduction_2d_dynamic
+// CHECK-SAME: %[[T0:.+]]: tensor<?x?xf16>,
+// CHECK-SAME: %[[T1:.+]]: tensor<?xf16>
+func.func @reduction_2d_dynamic(%t0: tensor<?x?xf16>, %t1: tensor<?xf16>) -> tensor<?xf16> {
+ // CHECK: %[[PACKED_0:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16)
+ // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 4] into %{{.*}} : tensor<?x?xf16> -> tensor<?x?x3x4xf16>
+ // CHECK: %[[PACKED_1:.*]] = tensor.pack %[[T1]] padding_value(%{{.*}} : f16)
+ // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [3] into %{{.*}} : tensor<?xf16> -> tensor<?x3xf16>
+ // CHECK-NOT: tensor.pack
+ // CHECK: linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
+ // CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel", "reduction"]
+ // CHECK-SAME: ins(%{{.*}} : tensor<?x?x3x4xf16>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<?x3xf16>)
+ %2 = linalg.generic #reduction_2d_trait ins(%t0 : tensor<?x?xf16>) outs(%t1 : tensor<?xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %3 = arith.addf %in, %out : f16
+ linalg.yield %3 : f16
+ } -> tensor<?xf16>
+
+ // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0] inner_tiles = [3] into %{{.*}} : tensor<?x3xf16> -> tensor<?xf16>
+ return %2 : tensor<?xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ transform.structured.pack %0 packed_sizes = [3, 4]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+// M N K m n k M K m k
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// K N n k
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+// M N m n
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: @matmul
+// CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[B:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:[0-9a-zA-Z]+]]: tensor<?x?xf32>
+func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
+
+ // CHECK: %[[PACK_A:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 4]
+ // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x2x4xf32>
+ // CHECK: %[[PACK_B:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [3, 4]
+ // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x3x4xf32>
+ // CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
+ // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x2x3xf32>
+
+ // CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+ // CHECK-SAME: ins(%{{.*}} : tensor<?x?x2x4xf32>, tensor<?x?x3x4xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<?x?x2x3xf32>)
+ %0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+
+ // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
+ // CHECK-SAME: : tensor<?x?x2x3xf32> -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ // M N K
+ %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+// N F H W C KH KW f c
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d4, d2 + d5, d3 + d6, d8)>
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d4, d5, d6, d7, d8)>
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d7)>
+
+// CHECK-LABEL: @conv_2d_nchw_fchw
+// CHECK-SAME: %[[INPUT:.+]]: tensor<14x512x28x28xf32>,
+// CHECK-SAME: %[[FILTER:.+]]: tensor<1024x512x1x1xf32>
+// CHECK-SAME: %[[INIT:.+]]: tensor<14x1024x28x28xf32>
+func.func @conv_2d_nchw_fchw(%i: tensor<14x512x28x28xf32>, %f: tensor<1024x512x1x1xf32>,
+ %o: tensor<14x1024x28x28xf32>) -> tensor<14x1024x28x28xf32> {
+
+ // CHECK: %[[PACK_INPUT:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1] inner_tiles = [8]
+ // CHECK-SAME: : tensor<14x512x28x28xf32> -> tensor<14x64x28x28x8xf32>
+ // CHECK: %[[PACK_FILTER:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+ // CHECK-SAME: : tensor<1024x512x1x1xf32> -> tensor<256x64x1x1x4x8xf32>
+  // CHECK: %[[PACK_OUTPUT:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1] inner_tiles = [4]
+ // CHECK-SAME: : tensor<14x1024x28x28xf32> -> tensor<14x256x28x28x4xf32>
+ // CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "parallel", "reduction"]}
+ // CHECK-SAME: ins(%{{.*}} : tensor<14x64x28x28x8xf32>, tensor<256x64x1x1x4x8xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<14x256x28x28x4xf32>)
+ %0 = linalg.conv_2d_nchw_fchw ins(%i, %f: tensor<14x512x28x28xf32>, tensor<1024x512x1x1xf32>)
+ outs(%o: tensor<14x1024x28x28xf32>) -> tensor<14x1024x28x28xf32>
+
+ // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [1] inner_tiles = [4]
+ // CHECK-SAME: : tensor<14x256x28x28x4xf32> -> tensor<14x1024x28x28xf32>
+ return %0: tensor<14x1024x28x28xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match interface{LinalgOp} in %arg1
+ // N F H W C KH KW
+ %1 = transform.structured.pack %0 packed_sizes = [0, 4, 0, 0, 8, 0, 0]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+// N H W F KH KW C f c
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d4, d2 + d5, d6, d8)>
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d4, d5, d6, d3, d7, d8)>
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d7)>
+
+// CHECK-LABEL: @conv_2d_nhwc_hwcf
+// CHECK-SAME: %[[INPUT:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[FILTER:.+]]: tensor<1x?x?x?xf32>
+// CHECK-SAME: %[[INIT:.+]]: tensor<?x1x?x?xf32>
+func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+
+ // CHECK: %[[PACK_INPUT:.*]] = tensor.pack %{{.*}} inner_dims_pos = [3] inner_tiles = [6]
+ // CHECK-SAME: : tensor<?x1x?x?xf32> -> tensor<?x1x?x?x6xf32>
+ // CHECK: %[[PACK_FILTER:.*]] = tensor.pack %{{.*}} inner_dims_pos = [3, 2] inner_tiles = [4, 6]
+ // CHECK-SAME: : tensor<1x?x?x?xf32> -> tensor<1x?x?x?x4x6xf32>
+ // CHECK: %[[PACK_OUTPUT:.*]] = tensor.pack %{{.*}} inner_dims_pos = [3] inner_tiles = [4]
+ // CHECK-SAME: : tensor<?x1x?x?xf32> -> tensor<?x1x?x?x4xf32>
+
+ // CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "parallel", "reduction"]}
+ // CHECK-SAME: ins(%{{.*}} : tensor<?x1x?x?x6xf32>, tensor<1x?x?x?x4x6xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<?x1x?x?x4xf32>)
+ %0 = linalg.conv_2d_nhwc_hwcf
+ ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
+ outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+
+ // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [3] inner_tiles = [4]
+ // CHECK-SAME: : tensor<?x1x?x?x4xf32> -> tensor<?x1x?x?xf32>
+ return %0 : tensor<?x1x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match interface{LinalgOp} in %arg1
+ // N H W F KH KW C
+ %1 = transform.structured.pack %0 packed_sizes = [0, 0, 0, 4, 0, 0, 6]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+// CHECK-DAG: affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+// M N K n k M K k
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
+// K N n k
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d1, d3, d4)>
+// M N n
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+
+// CHECK-LABEL: @matmul_dynamic_pack_size
+// CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[B:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:[0-9a-zA-Z]+]]: tensor<?x?xf32>
+func.func @matmul_dynamic_pack_size(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
+ // CHECK: %[[TS:.*]] = "some_tile_size"() : () -> index
+ %sz = "some_tile_size"() : () -> (index)
+
+ // CHECK: %[[PACK_A:.*]] = tensor.pack %[[A]] {{.*}} inner_dims_pos = [1] inner_tiles = [%[[TS]]]
+ // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x?xf32>
+ // CHECK: %[[PACK_B:.*]] = tensor.pack %[[B]] {{.*}} inner_dims_pos = [1, 0] inner_tiles = [%[[TS]], %[[TS]]]
+ // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+ // CHECK: %[[PACK_C:.*]] = tensor.pack %[[C]] {{.*}} inner_dims_pos = [1] inner_tiles = [%[[TS]]]
+ // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x?xf32>
+ // CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"]}
+ // CHECK-SAME: ins(%{{.*}} : tensor<?x?x?xf32>, tensor<?x?x?x?xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<?x?x?xf32>)
+ %0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+
+ // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [1] inner_tiles = [%[[TS]]] into %[[C]]
+ // CHECK-SAME: : tensor<?x?x?xf32> -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %sz = transform.structured.match ops{["some_tile_size"]} in %arg1
+ %1 = transform.structured.pack %0 packed_sizes = [0, %sz, %sz]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+func.func @conv_cant_pack(%i: tensor<14x512x28x28xf32>, %f: tensor<1024x512x1x1xf32>,
+ %o: tensor<14x1024x28x28xf32>) -> tensor<14x1024x28x28xf32> {
+ %0 = linalg.conv_2d_nchw_fchw ins(%i, %f: tensor<14x512x28x28xf32>, tensor<1024x512x1x1xf32>)
+ outs(%o: tensor<14x1024x28x28xf32>) -> tensor<14x1024x28x28xf32>
+ return %0: tensor<14x1024x28x28xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match interface{LinalgOp} in %arg1
+ // N F H W C KH KW
+ // expected-error @below {{data tiling failed}}
+ %1 = transform.structured.pack %0 packed_sizes = [0, 0, 4, 0, 0, 0, 0]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+ %1 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+ return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ // expected-error @below {{requires target to map to exactly 1 LinalgOp (got 2)}}
+ %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+
+// -----
+
+func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+  // expected-error @below {{requires number of packed sizes to match the number of loops (2 vs 3)}}
+ %1 = transform.structured.pack %0 packed_sizes = [2, 3]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}