--- /dev/null
+//===- Utils.h - Transform dialect utilities --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORMS_UTILS_UTILS_H
+#define MLIR_DIALECT_TRANSFORMS_UTILS_UTILS_H
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+class OpAsmPrinter;
+
+namespace transform {
+class TransformState;
+
+/// Printer hook for custom directive in assemblyFormat.
+///
+/// custom<PackedOrDynamicIndexList>($packed, $values, $integers)
+///
+/// where `values` are variadic Index values, `integers` is an `I64ArrayAttr`
+/// and `packed` is a single transform dialect handle who's mapped payload ops
+/// have a single Index result and represent the index list. Either `packed`
+/// or the other two parameters may be specified.
+///
+/// This allows idiomatic printing of mixed value and integer attributes in a
+/// list or with a single handle. E.g., `[%arg0, 7, 42, %arg42]` or just `%h`.
+void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+ Value packed, OperandRange values,
+ ArrayRef<int64_t> integers);
+
+/// Pasrer hook for custom directive in assemblyFormat.
+///
+/// custom<PackedOrDynamicIndexList>($packed, $values, $integers)
+///
+/// See `printPackedOrDynamicIndexList` for details.
+ParseResult parsePackedOrDynamicIndexList(
+ OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers);
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORMS_UTILS_UTILS_H
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformUtils.h"
+#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
/*target=*/target,
/*num_threads=*/ValueRange{},
/*tile_sizes=*/dynamicTileSizes,
+ /*packed_num_threads=*/Value(),
+ /*packed_tile_sizes=*/Value(),
/*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
/*static_tile_sizes=*/staticTileSizesAttr,
/*mapping=*/mapping);
/*target=*/target,
/*num_threads=*/dynamicNumThreads,
/*tile_sizes=*/ValueRange{},
+ /*packed_num_threads=*/Value(),
+ /*packed_tile_sizes=*/Value(),
/*static_num_threads=*/staticNumThreadsAttr,
/*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
/*mapping=*/mapping);
}
-// Given a list of OpFoldResults that are either index attrs or op
-// handles, return a list of OpFoldResults where all op handles are
-// replaced with the first (and only) OpResult of that payload op. (There
-// must be exactly one mapped payload op and it must have exactly one
-// index result.)
+/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
+/// to exactly one op with one index result, return that value.
static DiagnosedSilenceableFailure unpackPDLOperations(
transform::TransformState &state, TransformOpInterface transformOp,
SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
for (OpFoldResult ofr : ofrs) {
- // Don't try to unpack non-PDL operation.
- if (ofr.is<Attribute>() ||
- !ofr.get<Value>().getType().isa<pdl::OperationType>()) {
+ if (ofr.is<Attribute>()) {
+ if (!ofr.get<Attribute>().isa<IntegerAttr>())
+ return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
result.push_back(ofr);
continue;
}
ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>());
- for (Operation *op : payloadOps) {
- if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
- DiagnosedSilenceableFailure diag =
- transformOp.emitSilenceableError()
- << "payload op must have exactly 1 index result";
- diag.attachNote(op->getLoc())
- << "has " << op->getNumResults() << " results";
- return diag;
- }
- result.push_back(op->getResult(0));
+ if (payloadOps.size() != 1) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "handle must be mapped to exactly one payload op";
+ diag.attachNote(ofr.get<Value>().getLoc())
+ << "mapped to " << payloadOps.size() << " payload ops";
+ return diag;
}
+
+ Operation *op = payloadOps[0];
+ if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "payload op must have exactly 1 index result";
+ diag.attachNote(op->getLoc())
+ << "has " << op->getNumResults() << " results";
+ return diag;
+ }
+ result.push_back(op->getResult(0));
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+// Given a list of OpFoldResults that are either index attrs or op
+// handles, return a list of OpFoldResults where all op handles are
+// replaced with the first (and only) OpResult of that payload op. (There
+// must be exactly one mapped payload op and it must have exactly one
+// index result.)
+static DiagnosedSilenceableFailure
+unpackPDLOperations(transform::TransformState &state,
+ TransformOpInterface transformOp,
+ SmallVector<OpFoldResult> &result, Value packedHandle) {
+ ArrayRef<Operation *> payloadOps = state.getPayloadOps(packedHandle);
+ for (Operation *op : payloadOps) {
+ if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "payload op must have exactly 1 index result";
+ diag.attachNote(op->getLoc())
+ << "has " << op->getNumResults() << " results";
+ return diag;
+ }
+ result.push_back(op->getResult(0));
}
return DiagnosedSilenceableFailure::success();
if (targets.empty())
return DiagnosedSilenceableFailure::success();
- // getMixedNumThreads are OpFoldResults[index attributes or PDL operation].
- // Convert to OpFoldResults[index attributes or payload op].
- SmallVector<OpFoldResult> numThreads;
- DiagnosedSilenceableFailure status =
- unpackPDLOperations(state, transformOp, numThreads, mixedNumThreads);
- if (!status.succeeded())
- return status;
-
- // getMixedTileSizes are OpFoldResults[index attributes or PDL operation].
- // Convert to OpFoldResults[index attributes or payload op].
- SmallVector<OpFoldResult> tileSizes;
- status = unpackPDLOperations(state, transformOp, tileSizes, mixedTileSizes);
- if (!status.succeeded())
- return status;
-
// Transform all targets one by one.
for (Operation *target : targets) {
auto tilableOp = dyn_cast<TilingInterface>(target);
FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
if (!mixedNumThreads.empty()) {
tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
- numThreads, mapping);
+ mixedNumThreads, mapping);
} else {
tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
- rewriter, tilableOp, tileSizes, mapping);
+ rewriter, tilableOp, mixedTileSizes, mapping);
}
if (failed(tilingResult))
transform::TransformResults &transformResults,
transform::TransformState &state) {
IRRewriter rewriter(getContext());
+ auto transformOp = cast<TransformOpInterface>(getOperation());
ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
// Result payload ops.
SmallVector<Operation *> tileOps;
SmallVector<Operation *> tiledOps;
+ // Unpack handles.
+ SmallVector<OpFoldResult> mixedNumThreads;
+ DiagnosedSilenceableFailure status =
+ getPackedNumThreads()
+ ? unpackPDLOperations(state, transformOp, mixedNumThreads,
+ getPackedNumThreads())
+ : unpackPDLOperations(state, transformOp, mixedNumThreads,
+ getMixedNumThreads());
+ if (!status.succeeded())
+ return status;
+ SmallVector<OpFoldResult> mixedTileSizes;
+ status = getPackedTileSizes()
+ ? unpackPDLOperations(state, transformOp, mixedTileSizes,
+ getPackedTileSizes())
+ : unpackPDLOperations(state, transformOp, mixedTileSizes,
+ getMixedTileSizes());
+ if (!status.succeeded())
+ return status;
+
DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl(
- rewriter, state, cast<TransformOpInterface>(getOperation()), targets,
- getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps,
- tiledOps);
+ rewriter, state, transformOp, targets, mixedNumThreads, mixedTileSizes,
+ getMapping(), tileOps, tiledOps);
if (!diag.succeeded()) {
transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
}
LogicalResult TileToForeachThreadOp::verify() {
- if (getMixedNumThreads().empty() == getMixedTileSizes().empty())
- return emitOpError("either num_threads or tile_sizes must be specified");
+ int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
+ static_cast<int>(getPackedNumThreads() != Value());
+ if (numThreadsSpec > 1)
+ return emitOpError(
+ "num_threads and packed_num_threads are mutually exclusive");
+ int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
+ static_cast<int>(getPackedTileSizes() != Value());
+ if (tileSizesSpec > 1)
+ return emitOpError(
+ "tile_sizes and packed_tile_sizes are mutually exclusive");
+ if (numThreadsSpec == 0 && tileSizesSpec == 0)
+ return emitOpError(
+ "either (packed_)num_threads or (packed_)tile_sizes must be specified");
return success();
}
--- /dev/null
+//===- Utils.cpp - Transform dialect utilities ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/Utils/Utils.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+void transform::printPackedOrDynamicIndexList(OpAsmPrinter &printer,
+ Operation *op, Value packed,
+ OperandRange values,
+ ArrayRef<int64_t> integers) {
+ if (packed) {
+ assert(values.empty() && integers.empty() && "expected no values/integers");
+ printer << packed;
+ return;
+ }
+ printDynamicIndexList(printer, op, values, integers);
+}
+
+ParseResult transform::parsePackedOrDynamicIndexList(
+ OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers) {
+ OpAsmParser::UnresolvedOperand packedOperand;
+ if (parser.parseOptionalOperand(packedOperand).has_value()) {
+ packed.emplace(std::move(packedOperand));
+ integers = parser.getBuilder().getDenseI64ArrayAttr({});
+ return success();
+ }
+ return parseDynamicIndexList(parser, values, integers);
+}