//===----------------------------------------------------------------------===//
def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait]> {
+ [FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformOpInterface,
+ TransformEachOpTrait]> {
let description = [{
Decomposes named complex operations, such as higher-dimensional
(depthwise) convolutions, into combinations of lower-dimensional equivalents
def RewriteInDestinationPassingStyleOp : Op<
Transform_Dialect, "structured.rewrite_in_destination_passing_style",
- [MemoryEffectsOpInterface,
- NavigationTransformOpTrait,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ [FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformOpInterface,
+ TransformEachOpTrait]> {
let description = [{
Rewrite a supported tensor operation that is not in destination-passing style
into a form that is in destination-passing style.
$target attr-dict
`:` functional-type($target, results)
}];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
}
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
-transform::RewriteInDestinationPassingStyleOp::apply(
- transform::TransformResults &results, transform::TransformState &state) {
+transform::RewriteInDestinationPassingStyleOp::applyToOne(
+ Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
SmallVector<Operation *> res;
- ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
- for (Operation *target : targetOps) {
- IRRewriter rewriter(target->getContext());
- rewriter.setInsertionPoint(target);
- FailureOr<Operation *> maybeResult =
- TypeSwitch<Operation *, FailureOr<Operation *>>(target)
- .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
- [&rewriter](auto op) {
- return rewriteInDestinationPassingStyle(rewriter, op);
- });
- if (failed(maybeResult))
- return emitDefaultSilenceableFailure(target);
- res.push_back(*maybeResult);
- }
- results.set(getResult().cast<OpResult>(), res);
+ IRRewriter rewriter(target->getContext());
+ rewriter.setInsertionPoint(target);
+ FailureOr<Operation *> maybeResult =
+ TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+ .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+ [&rewriter](auto op) {
+ return rewriteInDestinationPassingStyle(rewriter, op);
+ });
+ if (failed(maybeResult))
+ return emitDefaultSilenceableFailure(target);
+ results.push_back(*maybeResult);
return DiagnosedSilenceableFailure::success();
}