"::mlir::pdl_interp::PDLInterpDialect",
];
+ let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
+ /// Name of the attribute attachable to the symbol table operation
+ /// containing named sequences. This is used to trigger verification.
+ constexpr const static llvm::StringLiteral
+ kWithNamedSequenceAttrName = "transform.with_named_sequence";
+
+ /// Names of the attribute attachable to an operation so it can be
+ /// identified as root by the default interpreter pass.
+ constexpr const static llvm::StringLiteral
+ kTargetTagAttrName = "transform.target_tag";
+
/// Returns the named PDL constraint functions available in the dialect
/// as a map from their name to the function.
const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
// class body to comply with visibility and full-declaration requirements.
inline RegionScope make_region_scope(Region ®ion);
+ /// Creates a new region scope for the given isolated-from-above region.
+ /// Unlike the non-isolated counterpart, there is no nesting expectation.
+ // Implementation note: this method is inline but implemented outside of the
+ // class body to comply with visibility and full-declaration requirements
+ inline RegionScope make_isolated_region_scope(Region ®ion);
+
/// A RAII object maintaining a "stack frame" for a transform IR region. When
/// applying a transform IR operation that contains a region, the caller is
/// expected to create a RegionScope before applying the ops contained in the
class RegionScope {
public:
/// Forgets the mapping from or to values defined in the associated
- /// transform IR region.
+ /// transform IR region, and restores the mapping that existed before
+ /// entering this scope.
~RegionScope() {
state.mappings.erase(region);
+ if (storedMappings.has_value())
+ state.mappings.swap(*storedMappings);
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
state.regionStack.pop_back();
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
private:
+ /// Tag structure for differentiating the constructor for isolated regions.
+ struct Isolated {};
+
/// Creates a new scope for mappings between values defined in the given
- /// transform IR region and payload IR operations.
+ /// transform IR region and payload IR objects.
RegionScope(TransformState &state, Region ®ion)
: state(state), region(®ion) {
auto res = state.mappings.try_emplace(this->region);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
+ /// Creates a new scope for mappings between values defined in the given
+ /// isolated-from-above transform IR region and payload IR objects.
+ RegionScope(TransformState &state, Region ®ion, Isolated)
+ : state(state), region(®ion) {
+ // Store the previous mapping stack locally.
+ storedMappings = llvm::SmallDenseMap<Region *, Mappings>();
+ storedMappings->swap(state.mappings);
+ state.mappings.try_emplace(this->region);
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ state.regionStack.push_back(this->region);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+ }
+
/// Back-reference to the transform state.
TransformState &state;
/// The region this scope is associated with.
Region *region;
+ /// Local copy of the mappings that existed before entering the current
+ /// region. Used only when the current region is isolated so we don't
+ /// accidentally look up the values defined outside the isolated region.
+ std::optional<llvm::SmallDenseMap<Region *, Mappings>> storedMappings =
+ std::nullopt;
+
friend RegionScope TransformState::make_region_scope(Region &);
+ friend RegionScope TransformState::make_isolated_region_scope(Region &);
};
friend class RegionScope;
/// TransformValueHandleTypeInterface.
void setValues(OpResult handle, ValueRange values);
+ /// Indicates that the result of the transform IR op at the given position
+ /// corresponds to the given range of mapped values. All mapped values are
+ /// expected to be compatible with the type of the result, e.g., if the result
+ /// is an operation handle, all mapped values are expected to be payload
+ /// operations.
+ void setMappedValues(OpResult handle, ArrayRef<MappedValue> values);
+
private:
/// Creates an instance of TransformResults that expects mappings for
/// `numSegments` values, which may be associated with payload operations or
RaggedArray<Value> values;
};
+/// Creates a RAII object the lifetime of which corresponds to the new mapping
+/// for transform IR values defined in the given region. Values defined in
+/// surrounding regions remain accessible.
TransformState::RegionScope TransformState::make_region_scope(Region ®ion) {
return RegionScope(*this, region);
}
+/// Creates a RAII object the lifetime of which corresponds to the new mapping
+/// for transform IR values defined in the given isolated-from-above region.
+/// Values defined in surrounding regions cannot be accessed.
+TransformState::RegionScope
+TransformState::make_isolated_region_scope(Region ®ion) {
+ return RegionScope(*this, region, RegionScope::Isolated());
+}
+
namespace detail {
/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
/// to either the list of operations associated with its operand or the root of
/// Verification hook for TransformOpInterface.
LogicalResult verifyTransformOpInterface(Operation *op);
+
+/// Populates `mappings` with mapped values associated with the given transform
+/// IR values in the given `state`.
+void prepareValueMappings(
+ SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
+ ValueRange values, const transform::TransformState &state);
} // namespace detail
/// This trait is supposed to be attached to Transform dialect operations that
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
+include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
"functional-type(operands, results)";
}
+def IncludeOp : TransformDialectOp<"include",
+ [CallOpInterface,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let summary = "Includes a named transform sequence";
+ let description = [{
+ The application of this transform operation is equivalent to applying the
+ operations contained in the named transform sequence with operands being
+ remapped to block arguments. The behavior of the operation when a
+ transformation in the included named sequence produces a silenceable error
+ is controlled by the `failure_propagation_mode` attribute. When set to
+ `propagate`, the failure of any nested transformation in the sequence
+ implies immediate failure of the entire sequence with a silenceable error,
+ and no further transformation is attempted. When set to `suppress`,
+ silenceable errors in nested operations are ignored and further
+ transformations are applied. Beware that even silenceable errors may leave
+ the payload IR in a state unsuitable for further transformations. It is the
+ responsibility of the user to ensure the following transformations are
+ robust enough when errors are suppressed. Definite errors are propagated
+ immediately regardless of the mode. The objects associated with the results
+ of this operation are the same as those associated with the operands of the
+ `transform.yield` in the referenced named sequence.
+ }];
+
+ let arguments = (ins SymbolRefAttr:$target,
+ FailurePropagationMode:$failure_propagation_mode,
+ Variadic<Transform_AnyHandleOrParamType>:$operands);
+ let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
+
+ let assemblyFormat =
+ "$target `failures` `(` $failure_propagation_mode `)`"
+ "`(` $operands `)` attr-dict `:` functional-type($operands, $results)";
+
+ let extraClassDeclaration = [{
+ ::mlir::CallInterfaceCallable getCallableForCallee() {
+ return getTarget();
+ }
+
+ ::mlir::Operation::operand_range getArgOperands() {
+ return getOperands();
+ }
+ }];
+}
+
def MergeHandlesOp : TransformDialectOp<"merge_handles",
[DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
let hasFolder = 1;
}
+def NamedSequenceOp : TransformDialectOp<"named_sequence",
+ [CallableOpInterface,
+ FunctionOpInterface,
+ IsolatedFromAbove,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let summary = "Named transform sequence that can be included elsewhere";
+ let description = [{
+ Defines a named (callable, function-like) sequence of other Transform
+ dialect operations that can be included using `transform.include` as part of
+ another Transform dialect construct. This sequence is not processed
+ immediately but rather dispatched to when the inclusion is processed. The
+ arguments and results can be used to communicate a subset of mapping into
+ the named sequence. The sequence must consist of a single block and end with
+ a `transform.yield` terminator. The operands of the terminator become the
+ results of the `transform.include`.
+
+ When dispatched to, the operations in the named sequence are executed one by
+ one, similarly to the regular unnamed sequence. The failure propagation mode
+ is specified on the `transform.include`. Different inclusions may use
+ different failure propagation modes. This transform operation always
+ succeeds by itself, but the inclusion may fail if any of the operations
+ fail.
+
+ Named sequences can only appear at the top-level of the Transform dialect
+ nesting structure. That is, they cannot be nested in other Transform dialect
+ operations. Furthermore, one of the ancestors must have the `SymbolTable`
+ trait and have the `transform.with_named_sequence` attribute attached.
+
+ Named sequences may include other named sequences via `transform.include`,
+ but recursion is *not* allowed.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$sym_name,
+ TypeAttrBase<"::mlir::FunctionType",
+ "function type attribute">:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs);
+ let regions = (region SizedRegion<1>:$body);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ ::llvm::ArrayRef<::mlir::Type> getArgumentTypes() {
+ return getFunctionType().getInputs();
+ }
+ ::llvm::ArrayRef<::mlir::Type> getResultTypes() {
+ return getFunctionType().getResults();
+ }
+
+ ::mlir::Region *getCallableRegion() {
+ return &getBody();
+ }
+ ::llvm::ArrayRef<::mlir::Type> getCallableResults() {
+ return getFunctionType().getResults();
+ }
+ }];
+}
+
def SplitHandlesOp : TransformDialectOp<"split_handles",
[FunctionalStyleTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>,
let assemblyFormat = "$target attr-dict (`:` type($target)^)?";
}
-
def ReplicateOp : TransformDialectOp<"replicate",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
let description = [{
The transformations indicated by the sequence are applied in order of their
appearance. Each value produced by a transformation within the sequence
- corresponds to an operation or a group of operations in the payload IR.
- The behavior of the operation when a nested transformation produces a
- silenceable error is controlled by the `failure_propagation_mode` attribute.
- When set to `propagate`, the failure of any nested transformation in the
- sequence implies immediate failure of the entire sequence with a silenceable
- error, and no further transformation is attempted. When set to `suppress`,
+ corresponds to a group of operations or values in the payload IR, or to a
+ group of parameters, depending on the type of the value. The behavior of the
+ operation when a nested transformation produces a silenceable error is
+ controlled by the `failure_propagation_mode` attribute. When set to
+ `propagate`, the failure of any nested transformation in the sequence
+ implies immediate failure of the entire sequence with a silenceable error,
+ and no further transformation is attempted. When set to `suppress`,
silenceable errors in nested operations are ignored and further
transformations are applied. Beware that even silenceable errors may leave
- the payload IR in a state unsuitable for further transformations. It is
- the responsibility of the caller to ensure the following transformations
- are robust enough when errors are suppressed. Definite errors reported by
- nested transformations abort the sequence regardless of the propagation
- mode. The set of modes may be extended in the future, e.g., to collect
- silenceable errors and report them after attempting all transformations in
- the sequence.
+ the payload IR in a state unsuitable for further transformations. It is the
+ responsibility of the caller to ensure the following transformations are
+ robust enough when errors are suppressed. Definite errors reported by nested
+ transformations abort the sequence regardless of the propagation mode. The
+ set of modes may be extended in the future, e.g., to collect silenceable
+ errors and report them after attempting all transformations in the sequence.
The entry block of this operation has a single argument that maps to either
the operand if provided or the top-level container operation of the payload
}];
let arguments = (ins
- Arg<Variadic<TransformHandleTypeInterface>, "Operation handles yielded back to the parent"
+ Arg<Variadic<Transform_AnyHandleOrParamType>,
+ "Transform values yielded back to the parent"
>:$operands);
let assemblyFormat = "operands attr-dict (`:` type($operands)^)?";
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Analysis/CallGraph.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/SCCIterator.h"
using namespace mlir;
llvm::report_fatal_error(StringRef(buffer));
}
+LogicalResult transform::TransformDialect::verifyOperationAttribute(
+ Operation *op, NamedAttribute attribute) {
+ if (attribute.getName().getValue() == kWithNamedSequenceAttrName) {
+ if (!op->hasTrait<OpTrait::SymbolTable>()) {
+ return emitError(op->getLoc()) << attribute.getName()
+ << " attribute can only be attached to "
+ "operations with symbol tables";
+ }
+
+ const mlir::CallGraph callgraph(op);
+ for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
+ if (!scc.hasCycle())
+ continue;
+
+ // Need to check this here additionally because this verification may run
+ // before we check the nested operations.
+ if ((*scc->begin())->isExternal())
+ return op->emitOpError() << "contains a call to an external operation, "
+ "which is not allowed";
+
+ Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
+ InFlightDiagnostic diag = emitError(first->getLoc())
+ << "recursion not allowed in named sequences";
+ for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
+ // Need to check this here additionally because this verification may
+ // run before we check the nested operations.
+ if ((*it)->isExternal()) {
+ return op->emitOpError() << "contains a call to an external "
+ "operation, which is not allowed";
+ }
+
+ Operation *current = (*it)->getCallableRegion()->getParentOp();
+ diag.attachNote(current->getLoc()) << "operation on recursion stack";
+ }
+ return diag;
+ }
+ return success();
+ }
+ if (attribute.getName().getValue() == kTargetTagAttrName) {
+ if (!attribute.getValue().isa<StringAttr>()) {
+ return op->emitError()
+ << attribute.getName() << " attribute must be a string";
+ }
+ return success();
+ }
+ return emitError(op->getLoc())
+ << "unknown attribute: " << attribute.getName();
+}
+
#include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc"
return success(found);
}
-LogicalResult
-transform::TransformState::mapBlockArgument(BlockArgument argument,
- ArrayRef<MappedValue> values) {
- if (argument.getType().isa<TransformHandleTypeInterface>()) {
+/// Given a list of MappedValues, cast them to the value kind implied by the
+/// interface of the handle type, and dispatch to one of the callbacks.
+static DiagnosedSilenceableFailure dispatchMappedValues(
+ Value handle, ArrayRef<transform::MappedValue> values,
+ function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
+ function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
+ function_ref<LogicalResult(ValueRange)> valuesFn) {
+ if (handle.getType().isa<transform::TransformHandleTypeInterface>()) {
SmallVector<Operation *> operations;
operations.reserve(values.size());
- for (MappedValue value : values) {
+ for (transform::MappedValue value : values) {
if (auto *op = value.dyn_cast<Operation *>()) {
operations.push_back(op);
continue;
}
- return emitError(argument.getLoc())
+ return emitSilenceableFailure(handle.getLoc())
<< "wrong kind of value provided for top-level operation handle";
}
- return setPayloadOps(argument, operations);
+ if (failed(operationsFn(operations)))
+ return DiagnosedSilenceableFailure::definiteFailure();
+ return DiagnosedSilenceableFailure::success();
}
- if (argument.getType().isa<TransformValueHandleTypeInterface>()) {
+ if (handle.getType().isa<transform::TransformValueHandleTypeInterface>()) {
SmallVector<Value> payloadValues;
payloadValues.reserve(values.size());
- for (MappedValue value : values) {
+ for (transform::MappedValue value : values) {
if (auto v = value.dyn_cast<Value>()) {
payloadValues.push_back(v);
continue;
}
- return emitError(argument.getLoc())
+ return emitSilenceableFailure(handle.getLoc())
<< "wrong kind of value provided for the top-level value handle";
}
- return setPayloadValues(argument, payloadValues);
+ if (failed(valuesFn(payloadValues)))
+ return DiagnosedSilenceableFailure::definiteFailure();
+ return DiagnosedSilenceableFailure::success();
}
- assert(argument.getType().isa<TransformParamTypeInterface>() &&
+ assert(handle.getType().isa<transform::TransformParamTypeInterface>() &&
"unsupported kind of block argument");
- SmallVector<Param> parameters;
+ SmallVector<transform::Param> parameters;
parameters.reserve(values.size());
- for (MappedValue value : values) {
+ for (transform::MappedValue value : values) {
if (auto attr = value.dyn_cast<Attribute>()) {
parameters.push_back(attr);
continue;
}
- return emitError(argument.getLoc())
+ return emitSilenceableFailure(handle.getLoc())
<< "wrong kind of value provided for top-level parameter";
}
- return setParams(argument, parameters);
+ if (failed(paramsFn(parameters)))
+ return DiagnosedSilenceableFailure::definiteFailure();
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult
+transform::TransformState::mapBlockArgument(BlockArgument argument,
+ ArrayRef<MappedValue> values) {
+ return dispatchMappedValues(
+ argument, values,
+ [&](ArrayRef<Operation *> operations) {
+ return setPayloadOps(argument, operations);
+ },
+ [&](ArrayRef<Param> params) {
+ return setParams(argument, params);
+ },
+ [&](ValueRange payloadValues) {
+ return setPayloadValues(argument, payloadValues);
+ })
+ .checkAndReport();
}
LogicalResult
this->values.replace(position, values);
}
+void transform::TransformResults::setMappedValues(
+ OpResult handle, ArrayRef<MappedValue> values) {
+ DiagnosedSilenceableFailure diag = dispatchMappedValues(
+ handle, values,
+ [&](ArrayRef<Operation *> operations) {
+ return set(handle, operations), success();
+ },
+ [&](ArrayRef<Param> params) {
+ return setParams(handle, params), success();
+ },
+ [&](ValueRange payloadValues) {
+ return setValues(handle, payloadValues), success();
+ });
+#ifndef NDEBUG
+ if (!diag.succeeded())
+ llvm::dbgs() << diag.getStatusString() << "\n";
+ assert(diag.succeeded() && "incorrect mapping");
+#endif // NDEBUG
+ (void)diag.silence();
+}
+
ArrayRef<Operation *>
transform::TransformResults::get(unsigned resultNumber) const {
assert(resultNumber < operations.size() &&
// Utilities for PossibleTopLevelTransformOpTrait.
//===----------------------------------------------------------------------===//
+void transform::detail::prepareValueMappings(
+ SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
+ ValueRange values, const transform::TransformState &state) {
+ for (Value operand : values) {
+ SmallVector<MappedValue> &mapped = mappings.emplace_back();
+ if (operand.getType().isa<TransformHandleTypeInterface>()) {
+ llvm::append_range(mapped, state.getPayloadOps(operand));
+ } else if (operand.getType().isa<TransformValueHandleTypeInterface>()) {
+ llvm::append_range(mapped, state.getPayloadValues(operand));
+ } else {
+ assert(operand.getType().isa<TransformParamTypeInterface>() &&
+ "unsupported kind of transform dialect value");
+ llvm::append_range(mapped, state.getParams(operand));
+ }
+ }
+}
+
LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
TransformState &state, Operation *op, Region ®ion) {
SmallVector<Operation *> targets;
SmallVector<SmallVector<MappedValue>> extraMappings;
if (op->getNumOperands() != 0) {
llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
- for (Value operand : op->getOperands().drop_front()) {
- SmallVector<MappedValue> &mapped = extraMappings.emplace_back();
- if (operand.getType().isa<TransformHandleTypeInterface>()) {
- llvm::append_range(mapped, state.getPayloadOps(operand));
- } else if (operand.getType().isa<TransformValueHandleTypeInterface>()) {
- llvm::append_range(mapped, state.getPayloadValues(operand));
- } else {
- assert(operand.getType().isa<TransformParamTypeInterface>() &&
- "unsupported kind of transform dialect value");
- llvm::append_range(mapped, state.getParams(operand));
- }
- }
+ prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
} else {
if (state.getNumTopLevelMappings() !=
region.front().getNumArguments() - 1) {
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
static void forwardTerminatorOperands(Block *block,
transform::TransformState &state,
transform::TransformResults &results) {
- for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
- block->getParentOp()->getOpResults())) {
- Value terminatorOperand = std::get<0>(pair);
- OpResult result = std::get<1>(pair);
- results.set(result, state.getPayloadOps(terminatorOperand));
+ for (auto &&[terminatorOperand, result] :
+ llvm::zip(block->getTerminator()->getOperands(),
+ block->getParentOp()->getOpResults())) {
+ if (result.getType().isa<transform::TransformHandleTypeInterface>()) {
+ results.set(result, state.getPayloadOps(terminatorOperand));
+ } else if (result.getType()
+ .isa<transform::TransformValueHandleTypeInterface>()) {
+ results.setValues(result, state.getPayloadValues(terminatorOperand));
+ } else {
+ assert(result.getType().isa<transform::TransformParamTypeInterface>() &&
+ "unhandled transform type interface");
+ results.setParams(result, state.getParams(terminatorOperand));
+ }
}
}
}
//===----------------------------------------------------------------------===//
+// IncludeOp
+//===----------------------------------------------------------------------===//
+
+/// Applies the transform ops contained in `block`. Maps `results` to the same
+/// values as the operands of the block terminator.
+static DiagnosedSilenceableFailure
+applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
+ transform::TransformState &state,
+ transform::TransformResults &results) {
+ // Apply the sequenced ops one by one.
+ for (Operation &transform : block.without_terminator()) {
+ DiagnosedSilenceableFailure result =
+ state.applyTransform(cast<transform::TransformOpInterface>(transform));
+ if (result.isDefiniteFailure())
+ return result;
+
+ if (result.isSilenceableFailure()) {
+ if (mode == transform::FailurePropagationMode::Propagate) {
+ // Propagate empty results in case of early exit.
+ forwardEmptyOperands(&block, state, results);
+ return result;
+ }
+ (void)result.silence();
+ }
+ }
+
+ // Forward the operation mapping for values yielded from the sequence to the
+ // values produced by the sequence op.
+ forwardTerminatorOperands(&block, state, results);
+ return DiagnosedSilenceableFailure::success();
+}
+
+DiagnosedSilenceableFailure
+transform::IncludeOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
+ getOperation(), getTarget());
+ assert(callee && "unverified reference to unknown symbol");
+
+ // Map operands to block arguments.
+ SmallVector<SmallVector<MappedValue>> mappings;
+ detail::prepareValueMappings(mappings, getOperands(), state);
+ auto scope = state.make_isolated_region_scope(callee.getBody());
+ for (auto &&[arg, map] :
+ llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
+ if (failed(state.mapBlockArgument(arg, map)))
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+ DiagnosedSilenceableFailure result = applySequenceBlock(
+ callee.getBody().front(), getFailurePropagationMode(), state, results);
+ mappings.clear();
+ detail::prepareValueMappings(
+ mappings, callee.getBody().front().getTerminator()->getOperands(), state);
+ for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
+ results.setMappedValues(result, mapping);
+ return result;
+}
+
+/// Appends to `effects` the memory effect instances on `target` with the same
+/// resource and effect as the ones the operation `iface` having on `source`.
+static void
+remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ SmallVector<MemoryEffects::EffectInstance> nestedEffects;
+ iface.getEffectsOnValue(source, nestedEffects);
+ for (const auto &effect : nestedEffects)
+ effects.emplace_back(effect.getEffect(), target, effect.getResource());
+}
+
+/// Appends to `effects` the same effects as the operations of `block` have on
+/// block arguments but associated with `operands.`
+static void
+remapArgumentEffects(Block &block, ValueRange operands,
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ for (Operation &op : block) {
+ auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
+ if (!iface)
+ continue;
+
+ for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
+ remapEffects(iface, source, target, effects);
+ }
+
+ SmallVector<MemoryEffects::EffectInstance> nestedEffects;
+ iface.getEffectsOnResource(transform::PayloadIRResource::get(),
+ nestedEffects);
+ llvm::append_range(effects, nestedEffects);
+ }
+}
+
+static DiagnosedSilenceableFailure
+verifyNamedSequenceOp(transform::NamedSequenceOp op);
+
+void transform::IncludeOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ // Bail if the callee is unknown. This may run as part of the verification
+ // process before we verified the validity of the callee or of this op.
+ auto target =
+ getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
+ if (!target)
+ return;
+ auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
+ getOperation(), getTarget());
+ if (!callee)
+ return;
+ DiagnosedSilenceableFailure earlyVerifierResult =
+ verifyNamedSequenceOp(callee);
+ if (!earlyVerifierResult.succeeded()) {
+ (void)earlyVerifierResult.silence();
+ return;
+ }
+
+ // Carry over effects from the callee.
+ remapArgumentEffects(callee.getBody().front(), getOperands(), effects);
+
+ // Proper effects.
+ onlyReadsHandle(getOperands(), effects);
+ producesHandle(getResults(), effects);
+}
+
+template <typename... Tys>
+static bool implementSameInterface(Type t1, Type t2) {
+ return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
+}
+
+LogicalResult
+transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ // Access through indirection and do additional checking because this may be
+ // running before the main op verifier.
+ auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
+ if (!targetAttr)
+ return emitOpError() << "expects a 'target' symbol reference attribute";
+
+ auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
+ *this, targetAttr);
+ if (!target)
+ return emitOpError() << "does not reference a named transform sequence";
+
+ FunctionType fnType = target.getFunctionType();
+ if (fnType.getNumInputs() != getNumOperands())
+ return emitError("incorrect number of operands for callee");
+
+ for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+ if (getOperand(i).getType() != fnType.getInput(i)) {
+ return emitOpError("operand type mismatch: expected operand type ")
+ << fnType.getInput(i) << ", but provided "
+ << getOperand(i).getType() << " for operand number " << i;
+ }
+ }
+
+ if (fnType.getNumResults() != getNumResults())
+ return emitError("incorrect number of results for callee");
+
+ for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+ Type resultType = getResult(i).getType();
+ Type funcType = fnType.getResult(i);
+ if (!implementSameInterface<TransformHandleTypeInterface,
+ TransformValueHandleTypeInterface,
+ TransformParamTypeInterface>(resultType,
+ funcType)) {
+ return emitOpError() << "type of result #" << i
+ << " must implement the same transform dialect "
+ "interface as the corresponding callee result";
+ }
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// MergeHandlesOp
//===----------------------------------------------------------------------===//
}
//===----------------------------------------------------------------------===//
+// NamedSequenceOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::NamedSequenceOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ // Nothing to do here.
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::NamedSequenceOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
+
+ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return function_interface_impl::parseFunctionOp(
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name),
+ [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
+ function_interface_impl::VariadicFlag,
+ std::string &) { return builder.getFunctionType(inputs, results); },
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+}
+
+void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
+ function_interface_impl::printFunctionOp(
+ printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
+ getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
+ getResAttrsAttrName());
+}
+
+/// Verification of a NamedSequenceOp. This does not report the error
+/// immediately, so it can be used to check for op's well-formedness before the
+/// verifier runs, e.g., during trait verification.
+static DiagnosedSilenceableFailure
+verifyNamedSequenceOp(transform::NamedSequenceOp op) {
+ if (op.isExternal())
+ return emitSilenceableFailure(op) << "cannot be empty";
+
+ if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
+ if (!parent->getAttr(
+ transform::TransformDialect::kWithNamedSequenceAttrName)) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableFailure(op)
+ << "expects the parent symbol table to have the '"
+ << transform::TransformDialect::kWithNamedSequenceAttrName
+ << "' attribute";
+ diag.attachNote(parent->getLoc()) << "symbol table operation";
+ return diag;
+ }
+ }
+
+ if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableFailure(op)
+ << "cannot be defined inside another transform op";
+ diag.attachNote(parent.getLoc()) << "ancestor transform op";
+ return diag;
+ }
+
+ if (op.getBody().front().empty())
+ return emitSilenceableFailure(op) << "expected a non-empty body block";
+
+ Operation *terminator = &op.getBody().front().back();
+ if (!isa<transform::YieldOp>(terminator)) {
+ DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
+ << "expected '"
+ << transform::YieldOp::getOperationName()
+ << "' as terminator";
+ diag.attachNote(terminator->getLoc()) << "terminator";
+ return diag;
+ }
+
+ if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
+ return emitSilenceableFailure(terminator)
+ << "expected terminator to have as many operands as the parent op "
+ "has results";
+ }
+ for (auto [i, operandType, resultType] :
+ llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
+ terminator->getOperands().getType(),
+ op.getFunctionType().getResults())) {
+ if (operandType == resultType)
+ continue;
+ return emitSilenceableFailure(terminator)
+ << "the type of the terminator operand #" << i
+ << " must match the type of the corresponding parent op result ("
+ << operandType << " vs " << resultType << ")";
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::NamedSequenceOp::verify() {
+ // Actual verification happens in a separate function for reusability.
+ return verifyNamedSequenceOp(*this).checkAndReport();
+}
+
+//===----------------------------------------------------------------------===//
// SplitHandlesOp
//===----------------------------------------------------------------------===//
if (failed(mapBlockArguments(state)))
return DiagnosedSilenceableFailure::definiteFailure();
- // Apply the sequenced ops one by one.
- for (Operation &transform : getBodyBlock()->without_terminator()) {
- DiagnosedSilenceableFailure result =
- state.applyTransform(cast<TransformOpInterface>(transform));
- if (result.isDefiniteFailure())
- return result;
-
- if (result.isSilenceableFailure()) {
- if (getFailurePropagationMode() == FailurePropagationMode::Propagate) {
- // Propagate empty results in case of early exit.
- forwardEmptyOperands(getBodyBlock(), state, results);
- return result;
- }
- (void)result.silence();
- }
- }
-
- // Forward the operation mapping for values yielded from the sequence to the
- // values produced by the sequence op.
- forwardTerminatorOperands(getBodyBlock(), state, results);
- return DiagnosedSilenceableFailure::success();
+ return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
+ results);
}
static ParseResult parseSequenceOpOperands(
return success();
}
-/// Appends to `effects` the memory effect instances on `target` with the same
-/// resource and effect as the ones the operation `iface` having on `source`.
-static void
-remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- SmallVector<MemoryEffects::EffectInstance> nestedEffects;
- iface.getEffectsOnValue(source, nestedEffects);
- for (const auto &effect : nestedEffects)
- effects.emplace_back(effect.getEffect(), target, effect.getResource());
-}
-
-namespace {
-template <typename T>
-using has_get_extra_bindings = decltype(std::declval<T &>().getExtraBindings());
-} // namespace
-
/// Populate `effects` with transform dialect memory effects for the potential
/// top-level operation. Such operations have recursive effects from nested
/// operations. When they have an operand, we can additionally remap effects on
// Carry over all effects on arguments of the entry block as those on the
// operands, this is the same value just remapped.
- for (Operation &op : *operation.getBodyBlock()) {
- auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
- if (!iface)
- continue;
-
- remapEffects(iface, operation.getBodyBlock()->getArgument(0),
- operation.getRoot(), effects);
- if constexpr (llvm::is_detected<has_get_extra_bindings, OpTy>::value) {
- for (auto [source, target] :
- llvm::zip(operation.getBodyBlock()->getArguments().drop_front(),
- operation.getExtraBindings())) {
- remapEffects(iface, source, target, effects);
- }
- }
-
- SmallVector<MemoryEffects::EffectInstance> nestedEffects;
- iface.getEffectsOnResource(transform::PayloadIRResource::get(),
- nestedEffects);
- llvm::append_range(effects, nestedEffects);
- }
+ remapArgumentEffects(*operation.getBodyBlock(), operation->getOperands(),
+ effects);
}
void transform::SequenceOp::getEffects(
::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
[&](::mlir::transform::TransformOpInterface transformOp) {
+ if (!transformOp
+ ->hasTrait<transform::PossibleTopLevelTransformOpTrait>())
+ return WalkResult::skip();
if (!topLevelTransform) {
topLevelTransform = transformOp;
return WalkResult::skip();
// expected-note @below {{no 'allocate' effect specified for result #0}}
transform.test_required_memory_effects %arg0 {has_operand_effect, modifies_payload} : (!transform.any_op) -> !transform.any_op
}
+
+// -----
+
+// expected-error @below {{attribute can only be attached to operations with symbol tables}}
+"test.unknown_container"() { transform.with_named_sequence } : () -> ()
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-error @below {{failed to verify constraint: region with 1 blocks}}
+ "transform.named_sequence"() ({}) { sym_name = "external_named_sequence", function_type = () -> () } : () -> ()
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.include @external_named_sequence failures(propagate) () : () -> ()
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-error @below {{recursion not allowed in named sequences}}
+ transform.named_sequence @self_recursion() -> () {
+ transform.include @self_recursion failures(suppress) () : () -> ()
+ }
+}
+
+// -----
+
+module @mutual_recursion attributes { transform.with_named_sequence } {
+ // expected-note @below {{operation on recursion stack}}
+ transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+ transform.include @bar failures(suppress) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+
+ // expected-error @below {{recursion not allowed in named sequences}}
+ transform.named_sequence @bar(%arg0: !transform.any_op) -> () {
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+// expected-error @below {{unknown attribute: "transform.unknown_container"}}
+module @unknown_attribute attributes { transform.unknown_container } {}
+
+// -----
+
+module {
+ transform.sequence failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{op does not reference a named transform sequence}}
+ transform.include @non_existent failures(propagate) () : () -> ()
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.sequence failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{requires attribute 'target'}}
+ "transform.include"() {failure_propagation_mode = 0} : () -> ()
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+ transform.yield
+ }
+
+ transform.sequence failures(suppress) {
+ ^bb0(%arg1: !transform.any_op):
+ // expected-error @below {{incorrect number of operands for callee}}
+ transform.include @foo failures(suppress) () : () -> ()
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+ transform.yield
+ }
+
+ transform.sequence failures(suppress) {
+ ^bb0(%arg1: !transform.op<"builtin.module">):
+ // expected-error @below {{operand type mismatch: expected operand type '!transform.any_op', but provided '!transform.op<"builtin.module">' for operand number 0}}
+ transform.include @foo failures(suppress) (%arg1) : (!transform.op<"builtin.module">) -> ()
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op) {
+ transform.yield %arg0 : !transform.any_op
+ }
+
+ transform.sequence failures(suppress) {
+ ^bb0(%arg1: !transform.any_op):
+ // expected-error @below {{incorrect number of results for callee}}
+ transform.include @foo failures(suppress) (%arg1) : (!transform.any_op) -> ()
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op) {
+ transform.yield %arg0 : !transform.any_op
+ }
+
+ transform.sequence failures(suppress) {
+ ^bb0(%arg1: !transform.any_op):
+ // expected-error @below {{type of result #0 must implement the same transform dialect interface as the corresponding callee result}}
+ transform.include @foo failures(suppress) (%arg1) : (!transform.any_op) -> (!transform.any_value)
+ }
+}
+
+// -----
+
+// expected-note @below {{symbol table operation}}
+module {
+ // expected-error @below {{expects the parent symbol table to have the 'transform.with_named_sequence' attribute}}
+ transform.named_sequence @parent_has_no_attributes() {
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence} {
+ // expected-note @below {{ancestor transform op}}
+ transform.sequence failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{cannot be defined inside another transform op}}
+ transform.named_sequence @nested() {
+ transform.yield
+ }
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence} {
+ func.func private @foo()
+
+ // expected-error @below {{expected 'transform.yield' as terminator}}
+ transform.named_sequence @nested() {
+ // expected-note @below {{terminator}}
+ func.call @foo() : () -> ()
+ }
+}
+
+
+// -----
+
+module attributes { transform.with_named_sequence} {
+ func.func private @foo()
+
+ transform.named_sequence @nested(%arg0: !transform.any_op) {
+ // expected-error @below {{expected terminator to have as many operands as the parent op has results}}
+ transform.yield %arg0 : !transform.any_op
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence} {
+ func.func private @foo()
+
+ transform.named_sequence @nested(%arg0: !transform.any_op) -> !transform.op<"builtin.module"> {
+ // expected-error @below {{the type of the terminator operand #0 must match the type of the corresponding parent op result}}
+ transform.yield %arg0 : !transform.any_op
+ }
+}
%op = transform.get_defining_op %bbarg : (!transform.any_value) -> !transform.any_op
transform.test_print_remark_at_operand %op, "matched" : !transform.any_op
}
+
+// -----
+
+module @named_inclusion attributes { transform.with_named_sequence } {
+
+ transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+ // expected-remark @below {{applying transformation "a"}}
+ transform.test_transform_op "a"
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ }
+}
+
+// -----
+
+module @named_inclusion_in_named attributes { transform.with_named_sequence } {
+
+ transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+ // expected-remark @below {{applying transformation "a"}}
+ transform.test_transform_op "a"
+ transform.yield
+ }
+
+ transform.named_sequence @bar(%arg0: !transform.any_op) -> () {
+ // expected-remark @below {{applying transformation "b"}}
+ transform.test_transform_op "b"
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.include @bar failures(suppress) (%arg0) : (!transform.any_op) -> ()
+ }
+}
+
+// -----
+
+// expected-remark @below {{operation}}
+module @named_operands attributes { transform.with_named_sequence } {
+
+ transform.named_sequence @foo(%arg0: !transform.any_op, %arg1: !transform.any_value) -> () {
+ transform.test_print_remark_at_operand %arg0, "operation" : !transform.any_op
+ transform.test_print_remark_at_operand_value %arg1, "value" : !transform.any_value
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ // expected-remark @below {{value}}
+ // expected-note @below {{value handle points to a block argument #0 in block #0 in region #0}}
+ ^bb0(%arg0: !transform.any_op):
+ %0 = transform.test_produce_value_handle_to_self_operand %arg0 : (!transform.any_op) -> !transform.any_value
+ include @foo failures(propagate) (%arg0, %0) : (!transform.any_op, !transform.any_value) -> ()
+ }
+}
+
+// -----
+
+// expected-remark @below {{operation}}
+module @named_return attributes { transform.with_named_sequence } {
+
+ // expected-remark @below {{value}}
+ // expected-note @below {{value handle points to a block argument #0 in block #0 in region #0}}
+ transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op, !transform.any_value) {
+ %0 = transform.test_produce_value_handle_to_self_operand %arg0 : (!transform.any_op) -> !transform.any_value
+ transform.yield %arg0, %0 : !transform.any_op, !transform.any_value
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ %0:2 = include @foo failures(propagate) (%arg0) : (!transform.any_op) -> (!transform.any_op, !transform.any_value)
+ transform.test_print_remark_at_operand %0#0, "operation" : !transform.any_op
+ transform.test_print_remark_at_operand_value %0#1, "value" : !transform.any_value
+ }
+}
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Transform/IR/TransformOps.td",
- deps = [":TransformDialectTdFiles"],
+ deps = [
+ ":CallInterfacesTdFiles",
+ ":TransformDialectTdFiles"
+ ],
)
gentbl_cc_library(
srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]),
hdrs = glob(["include/mlir/Dialect/Transform/IR/*.h"]),
deps = [
+ ":CallInterfaces",
":ControlFlowInterfaces",
":IR",
":PDLDialect",