[mlir] address post-commit review for D127724
authorAlex Zinenko <zinenko@google.com>
Wed, 15 Jun 2022 16:38:14 +0000 (18:38 +0200)
committerAlex Zinenko <zinenko@google.com>
Wed, 15 Jun 2022 16:43:05 +0000 (18:43 +0200)
- make transform.alternatives op apply only to isolated-from-above payload IR
  scopes;
- fix potential leak;
- fix several typos.

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

index 1a36e2f..27a68b0 100644 (file)
@@ -170,7 +170,7 @@ def Transform_Dialect : Dialect {
     perform the required recovery steps thus succeeding themselves. On
     the other hand, they must propagate irrecoverable failures. For such
     failures, the diagnostics are emitted immediately whereas their emission is
-    postponed for recoverable faliures. Transformation container operations may
+    postponed for recoverable failures. Transformation container operations may
     also fail to recover from a theoretically recoverable failure, in which case
     they are expected to emit the diagnostic and turn the failure into an
     irrecoverable one. A recoverable failure produced by applying the top-level
index df33f95..acb6769 100644 (file)
@@ -20,42 +20,42 @@ namespace mlir {
 ///   - success;
 ///   - silencable (recoverable) failure with yet-unreported diagnostic;
 ///   - definite failure.
-/// Silencable failure is intended to communicate information about
+/// Silenceable failure is intended to communicate information about
 /// transformations that did not apply but in a way that supports recovery,
 /// for example, they did not modify the payload IR or modified it in some
 /// predictable way. They are associated with a Diagnostic that provides more
-/// details on the failure. Silencable failure can be discarded, turning the
+/// details on the failure. Silenceable failure can be discarded, turning the
 /// result into success, or "reported", emitting the diagnostic and turning the
 /// result into definite failure. Transform IR operations containing other
 /// operations are allowed to do either with the results of the nested
 /// transformations, but must propagate definite failures as their diagnostics
 /// have been already reported to the user.
-class LLVM_NODISCARD DiagnosedSilencableFailure {
+class LLVM_NODISCARD DiagnosedSilenceableFailure {
 public:
-  explicit DiagnosedSilencableFailure(LogicalResult result) : result(result) {}
-  DiagnosedSilencableFailure(const DiagnosedSilencableFailure &) = delete;
-  DiagnosedSilencableFailure &
-  operator=(const DiagnosedSilencableFailure &) = delete;
-  DiagnosedSilencableFailure(DiagnosedSilencableFailure &&) = default;
-  DiagnosedSilencableFailure &
-  operator=(DiagnosedSilencableFailure &&) = default;
-
-  /// Constructs a DiagnosedSilencableFailure in the success state.
-  static DiagnosedSilencableFailure success() {
-    return DiagnosedSilencableFailure(::mlir::success());
+  explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
+  DiagnosedSilenceableFailure(const DiagnosedSilenceableFailure &) = delete;
+  DiagnosedSilenceableFailure &
+  operator=(const DiagnosedSilenceableFailure &) = delete;
+  DiagnosedSilenceableFailure(DiagnosedSilenceableFailure &&) = default;
+  DiagnosedSilenceableFailure &
+  operator=(DiagnosedSilenceableFailure &&) = default;
+
+  /// Constructs a DiagnosedSilenceableFailure in the success state.
+  static DiagnosedSilenceableFailure success() {
+    return DiagnosedSilenceableFailure(::mlir::success());
   }
 
-  /// Constructs a DiagnosedSilencableFailure in the failure state. Typically,
+  /// Constructs a DiagnosedSilenceableFailure in the failure state. Typically,
   /// a diagnostic has been emitted before this.
-  static DiagnosedSilencableFailure definiteFailure() {
-    return DiagnosedSilencableFailure(::mlir::failure());
+  static DiagnosedSilenceableFailure definiteFailure() {
+    return DiagnosedSilenceableFailure(::mlir::failure());
   }
 
-  /// Constructs a DiagnosedSilencableFailure in the silencable failure state,
+  /// Constructs a DiagnosedSilenceableFailure in the silencable failure state,
   /// ready to emit the given diagnostic. This is considered a failure
   /// regardless of the diagnostic severity.
-  static DiagnosedSilencableFailure silencableFailure(Diagnostic &&diag) {
-    return DiagnosedSilencableFailure(std::forward<Diagnostic>(diag));
+  static DiagnosedSilenceableFailure silencableFailure(Diagnostic &&diag) {
+    return DiagnosedSilenceableFailure(std::forward<Diagnostic>(diag));
   }
 
   /// Converts all kinds of failure into a LogicalResult failure, emitting the
@@ -75,7 +75,7 @@ public:
   }
 
   /// Returns `true` if this is a silencable failure.
-  bool isSilencableFailure() const { return diagnostic.hasValue(); }
+  bool isSilenceableFailure() const { return diagnostic.hasValue(); }
 
   /// Returns `true` if this is a success.
   bool succeeded() const {
@@ -98,26 +98,26 @@ public:
 
   /// Streams the given values into the diagnotic. Expects this object to be a
   /// silencable failure.
-  template <typename T> DiagnosedSilencableFailure &operator<<(T &&value) & {
-    assert(isSilencableFailure() &&
+  template <typename T> DiagnosedSilenceableFailure &operator<<(T &&value) & {
+    assert(isSilenceableFailure() &&
            "can only append output in silencable failure state");
     *diagnostic << std::forward<T>(value);
     return *this;
   }
-  template <typename T> DiagnosedSilencableFailure &&operator<<(T &&value) && {
+  template <typename T> DiagnosedSilenceableFailure &&operator<<(T &&value) && {
     return std::move(this->operator<<(std::forward<T>(value)));
   }
 
   /// Attaches a note to the diagnostic. Expects this object to be a silencable
   /// failure.
   Diagnostic &attachNote(Optional<Location> loc = llvm::None) {
-    assert(isSilencableFailure() &&
+    assert(isSilenceableFailure() &&
            "can only attach notes to silencable failures");
     return diagnostic->attachNote(loc);
   }
 
 private:
-  explicit DiagnosedSilencableFailure(Diagnostic &&diagnostic)
+  explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic)
       : diagnostic(std::move(diagnostic)), result(failure()) {}
 
   /// The diagnostic associated with this object. If present, the object is
@@ -226,7 +226,7 @@ public:
 
   /// Applies the transformation specified by the given transform op and updates
   /// the state accordingly.
-  DiagnosedSilencableFailure applyTransform(TransformOpInterface transform);
+  DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform);
 
   /// Records the mapping between a block argument in the transform IR and a
   /// list of operations in the payload IR. The arguments must be defined in
@@ -524,7 +524,7 @@ namespace detail {
 /// the payload IR, depending on what is available in the context.
 LogicalResult
 mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
-                                             Operation *op, unsigned region);
+                                             Operation *op, Region &region);
 
 /// Verification hook for PossibleTopLevelTransformOpTrait.
 LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
@@ -562,10 +562,18 @@ public:
   /// and the relevant list of Payload IR operations in the given state. The
   /// state is expected to be already scoped at the region of this operation.
   /// Returns failure if the mapping failed, e.g., the value is already mapped.
-  LogicalResult mapBlockArguments(TransformState &state, unsigned region = 0) {
+  LogicalResult mapBlockArguments(TransformState &state, Region &region) {
+    assert(region.getParentOp() == this->getOperation() &&
+           "op comes from the wrong region");
     return detail::mapPossibleTopLevelTransformOpBlockArguments(
         state, this->getOperation(), region);
   }
+  LogicalResult mapBlockArguments(TransformState &state) {
+    assert(
+        this->getOperation()->getNumRegions() == 1 &&
+        "must indicate the region to map if the operation has more than one");
+    return mapBlockArguments(state, this->getOperation()->getRegion(0));
+  }
 };
 
 /// Trait implementing the TransformOpInterface for operations applying a
@@ -586,8 +594,8 @@ public:
   /// Calls `applyToOne` for every payload operation associated with the operand
   /// of this transform IR op. If `applyToOne` returns ops, associates them with
   /// the result of this transform op.
-  DiagnosedSilencableFailure apply(TransformResults &transformResults,
-                                   TransformState &state);
+  DiagnosedSilenceableFailure apply(TransformResults &transformResults,
+                                    TransformState &state);
 
   /// Checks that the op matches the expectations of this trait.
   static LogicalResult verifyTrait(Operation *op);
@@ -738,7 +746,7 @@ appendTransformResultToVector(Ty result,
 ///   `targets` contains operations of the same class and a silencable failure
 ///   is reported if it does not.
 template <typename FnTy>
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 applyTransformToEach(ArrayRef<Operation *> targets,
                      SmallVectorImpl<Operation *> &results, FnTy transform) {
   using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
@@ -753,21 +761,21 @@ applyTransformToEach(ArrayRef<Operation *> targets,
     if (!specificOp) {
       Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
       diag << "attempted to apply transform to the wrong op kind";
-      return DiagnosedSilencableFailure::silencableFailure(std::move(diag));
+      return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
     }
 
     auto result = transform(specificOp);
     if (failed(appendTransformResultToVector(result, results)))
-      return DiagnosedSilencableFailure::definiteFailure();
+      return DiagnosedSilenceableFailure::definiteFailure();
   }
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 } // namespace detail
 } // namespace transform
 } // namespace mlir
 
 template <typename OpTy>
-mlir::DiagnosedSilencableFailure
+mlir::DiagnosedSilenceableFailure
 mlir::transform::TransformEachOpTrait<OpTy>::apply(
     TransformResults &transformResults, TransformState &state) {
   using TransformOpType = typename llvm::function_traits<
@@ -775,7 +783,7 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
   ArrayRef<Operation *> targets =
       state.getPayloadOps(this->getOperation()->getOperand(0));
   SmallVector<Operation *> results;
-  DiagnosedSilencableFailure result = detail::applyTransformToEach(
+  DiagnosedSilenceableFailure result = detail::applyTransformToEach(
       targets, results, [&](TransformOpType specificOp) {
         return static_cast<OpTy *>(this)->applyToOne(specificOp);
       });
@@ -786,7 +794,7 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
     transformResults.set(
         this->getOperation()->getResult(0).template cast<OpResult>(), results);
   }
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 template <typename OpTy>
index b8b6a0a..3ce99c4 100644 (file)
@@ -42,7 +42,7 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
         special status object indicating whether the transformation succeeded
         or failed, and, if it failed, whether the failure is recoverable or not.
       }],
-      /*returnType=*/"::mlir::DiagnosedSilencableFailure",
+      /*returnType=*/"::mlir::DiagnosedSilenceableFailure",
       /*name=*/"apply",
       /*arguments=*/(ins
           "::mlir::transform::TransformResults &":$transformResults,
@@ -64,9 +64,9 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
 
     /// Creates the silencable failure object with a diagnostic located at the
     /// current operation.
-    DiagnosedSilencableFailure emitSilencableError() {
+    DiagnosedSilenceableFailure emitSilenceableError() {
       Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
-      return DiagnosedSilencableFailure::silencableFailure(std::move(diag));
+      return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
     }
   }];
 }
index f415e03..dd0e82e 100644 (file)
@@ -44,13 +44,14 @@ def AlternativesOp : TransformDialectOp<"alternatives",
     The single operand of this operation is the scope in which the alternative
     transformation sequences are attempted, that is, an operation in the payload
     IR that contains all the other operations that may be modified by the
-    transformations. There is no check that the transforms are indeed scoped
-    as their "apply" methods can be arbitrarily complex. Therefore it is the
-    responsibility of the user to ensure that the transforms are scoped
-    correctly, or to produce an irrecoverable error and thus abort the execution
-    without attempting the remaining alternatives. Note that the payload IR
-    outside of the given scope is not necessarily in the valid state, or even
-    accessible to the tranfsormation.
+    transformations. The scope operation must be isolated from above. There is
+    no check that the transforms are indeed scoped as their "apply" methods can
+    be arbitrarily complex. Therefore it is the responsibility of the user to
+    ensure that the transforms are scoped correctly, or to produce an
+    irrecoverable error and thus abort the execution without attempting the
+    remaining alternatives. Note that the payload IR outside of the given scope
+    is not necessarily in the valid state, or even accessible to the
+    tranfsormation.
     
     The changes to the IR within the scope performed by transforms in the failed
     alternative region are reverted before attempting the next region.
@@ -72,8 +73,8 @@ def AlternativesOp : TransformDialectOp<"alternatives",
     ```mlir
     %result = transform.alternatives %scope {
     ^bb0(%arg0: !pdl.operation):
-      // Try a failible transformation.
-      %0 = transform.failible %arg0 // ...
+      // Try a fallible transformation.
+      %0 = transform.fallible %arg0 // ...
       // If succeeded, yield the the result of the transformation.
       transform.yield %0 : !pdl.operation
     }, {
index c1f532c..66144cf 100644 (file)
@@ -24,7 +24,7 @@ using namespace mlir::transform;
 // OneShotBufferizeOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
                                      TransformState &state) {
   OneShotBufferizationOptions options;
@@ -39,19 +39,19 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
   for (Operation *target : payloadOps) {
     auto moduleOp = dyn_cast<ModuleOp>(target);
     if (getTargetIsModule() && !moduleOp)
-      return emitSilencableError() << "expected ModuleOp target";
+      return emitSilenceableError() << "expected ModuleOp target";
     if (options.bufferizeFunctionBoundaries) {
       if (!moduleOp)
-        return emitSilencableError() << "expected ModuleOp target";
+        return emitSilenceableError() << "expected ModuleOp target";
       if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
-        return emitSilencableError() << "bufferization failed";
+        return emitSilenceableError() << "bufferization failed";
     } else {
       if (failed(bufferization::runOneShotBufferize(target, options)))
-        return emitSilencableError() << "bufferization failed";
+        return emitSilenceableError() << "bufferization failed";
     }
   }
 
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 void transform::OneShotBufferizeOp::getEffects(
index d239cad..a5e865b 100644 (file)
@@ -164,7 +164,7 @@ static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
   return success();
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
                          mlir::transform::TransformState &state) {
   LinalgTilingAndFusionOptions fusionOptions;
@@ -188,8 +188,8 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
                                tileLoopNest->getLoopOps().end()};
         return tiledLinalgOp;
       });
-  return failed(result) ? DiagnosedSilencableFailure::definiteFailure()
-                        : DiagnosedSilencableFailure::success();
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
 }
 
 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
@@ -398,7 +398,7 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
 // TileOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::TileOp::apply(TransformResults &transformResults,
                          TransformState &state) {
   LinalgTilingOptions tilingOptions;
@@ -415,7 +415,7 @@ transform::TileOp::apply(TransformResults &transformResults,
         SimpleRewriter rewriter(linalgOp.getContext());
         return pattern.returningMatchAndRewrite(linalgOp, rewriter);
       });
-  return DiagnosedSilencableFailure(result);
+  return DiagnosedSilenceableFailure(result);
 }
 
 ParseResult transform::TileOp::parse(OpAsmParser &parser,
index 6b39a35..f7821da 100644 (file)
@@ -31,7 +31,7 @@ public:
 // GetParentForOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::GetParentForOp::apply(transform::TransformResults &results,
                                  transform::TransformState &state) {
   SetVector<Operation *> parents;
@@ -41,10 +41,10 @@ transform::GetParentForOp::apply(transform::TransformResults &results,
     for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
       loop = current->getParentOfType<scf::ForOp>();
       if (!loop) {
-        DiagnosedSilencableFailure diag = emitSilencableError()
-                                          << "could not find an '"
-                                          << scf::ForOp::getOperationName()
-                                          << "' parent";
+        DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                           << "could not find an '"
+                                           << scf::ForOp::getOperationName()
+                                           << "' parent";
         diag.attachNote(target->getLoc()) << "target op";
         return diag;
       }
@@ -53,7 +53,7 @@ transform::GetParentForOp::apply(transform::TransformResults &results,
     parents.insert(loop);
   }
   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -85,7 +85,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
   return executeRegionOp;
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::LoopOutlineOp::apply(transform::TransformResults &results,
                                 transform::TransformState &state) {
   SmallVector<Operation *> transformed;
@@ -96,8 +96,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
     SimpleRewriter rewriter(getContext());
     scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
     if (!exec) {
-      DiagnosedSilencableFailure diag = emitSilencableError()
-                                        << "failed to outline";
+      DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                         << "failed to outline";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
@@ -107,7 +107,7 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
 
     if (failed(outlined)) {
       (void)reportUnknownTransformError(target);
-      return DiagnosedSilencableFailure::definiteFailure();
+      return DiagnosedSilenceableFailure::definiteFailure();
     }
 
     if (symbolTableOp) {
@@ -120,7 +120,7 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
     transformed.push_back(*outlined);
   }
   results.set(getTransformed().cast<OpResult>(), transformed);
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
index ad6935f..ecf1cbe 100644 (file)
@@ -188,16 +188,16 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
   return success();
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::TransformState::applyTransform(TransformOpInterface transform) {
   LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
   if (options.getExpensiveChecksEnabled() &&
       failed(checkAndRecordHandleInvalidation(transform))) {
-    return DiagnosedSilencableFailure::definiteFailure();
+    return DiagnosedSilenceableFailure::definiteFailure();
   }
 
   transform::TransformResults results(transform->getNumResults());
-  DiagnosedSilencableFailure result(transform.apply(results, *this));
+  DiagnosedSilenceableFailure result(transform.apply(results, *this));
   if (!result.succeeded())
     return result;
 
@@ -223,10 +223,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
            "payload IR association for a value other than the result of the "
            "current transform op");
     if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
-      return DiagnosedSilencableFailure::definiteFailure();
+      return DiagnosedSilenceableFailure::definiteFailure();
   }
 
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -277,15 +277,14 @@ transform::TransformResults::get(unsigned resultNumber) const {
 //===----------------------------------------------------------------------===//
 
 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
-    TransformState &state, Operation *op, unsigned region) {
+    TransformState &state, Operation *op, Region &region) {
   SmallVector<Operation *> targets;
   if (op->getNumOperands() != 0)
     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
   else
     targets.push_back(state.getTopLevel());
 
-  return state.mapBlockArguments(op->getRegion(region).front().getArgument(0),
-                                 targets);
+  return state.mapBlockArguments(region.front().getArgument(0), targets);
 }
 
 LogicalResult
index 3037cf6..d071c8e 100644 (file)
@@ -164,7 +164,7 @@ static void forwardTerminatorOperands(Block *block,
   }
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::AlternativesOp::apply(transform::TransformResults &results,
                                  transform::TransformState &state) {
   SmallVector<Operation *> originals;
@@ -178,7 +178,14 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
       InFlightDiagnostic diag =
           emitError() << "scope must not contain the transforms being applied";
       diag.attachNote(original->getLoc()) << "scope";
-      return DiagnosedSilencableFailure::definiteFailure();
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+      InFlightDiagnostic diag =
+          emitError()
+          << "only isolated-from-above ops can be alternative scopes";
+      diag.attachNote(original->getLoc()) << "scope";
+      return DiagnosedSilenceableFailure(std::move(diag));
     }
   }
 
@@ -190,18 +197,18 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
     auto scope = state.make_region_scope(reg);
     auto clones = llvm::to_vector(
         llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
-    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
-      return DiagnosedSilencableFailure::definiteFailure();
     auto deleteClones = llvm::make_scope_exit([&] {
       for (Operation *clone : clones)
         clone->erase();
     });
+    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
+      return DiagnosedSilenceableFailure::definiteFailure();
 
     bool failed = false;
     for (Operation &transform : reg.front().without_terminator()) {
-      DiagnosedSilencableFailure result =
+      DiagnosedSilenceableFailure result =
           state.applyTransform(cast<TransformOpInterface>(transform));
-      if (result.isSilencableFailure()) {
+      if (result.isSilenceableFailure()) {
         LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
                           << "\n");
         failed = true;
@@ -209,7 +216,7 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
       }
 
       if (::mlir::failed(result.silence()))
-        return DiagnosedSilencableFailure::definiteFailure();
+        return DiagnosedSilenceableFailure::definiteFailure();
     }
 
     // If all operations in the given alternative succeeded, no need to consider
@@ -227,10 +234,10 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
         rewriter.replaceOp(original, clone->getResults());
       }
       forwardTerminatorOperands(&reg.front(), state, results);
-      return DiagnosedSilencableFailure::success();
+      return DiagnosedSilenceableFailure::success();
     }
   }
-  return emitSilencableError() << "all alternatives failed";
+  return emitSilenceableError() << "all alternatives failed";
 }
 
 LogicalResult transform::AlternativesOp::verify() {
@@ -260,15 +267,15 @@ LogicalResult transform::AlternativesOp::verify() {
 // GetClosestIsolatedParentOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure transform::GetClosestIsolatedParentOp::apply(
+DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   SetVector<Operation *> parents;
   for (Operation *target : state.getPayloadOps(getTarget())) {
     Operation *parent =
         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
     if (!parent) {
-      DiagnosedSilencableFailure diag =
-          emitSilencableError()
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError()
           << "could not find an isolated-from-above parent op";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
@@ -276,14 +283,14 @@ DiagnosedSilencableFailure transform::GetClosestIsolatedParentOp::apply(
     parents.insert(parent);
   }
   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
 // PDLMatchOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::PDLMatchOp::apply(transform::TransformResults &results,
                              transform::TransformState &state) {
   auto *extension = state.getExtension<PatternApplicatorExtension>();
@@ -294,28 +301,28 @@ transform::PDLMatchOp::apply(transform::TransformResults &results,
     if (failed(extension->findAllMatches(
             getPatternName().getLeafReference().getValue(), root, targets))) {
       emitOpError() << "could not find pattern '" << getPatternName() << "'";
-      return DiagnosedSilencableFailure::definiteFailure();
+      return DiagnosedSilenceableFailure::definiteFailure();
     }
   }
   results.set(getResult().cast<OpResult>(), targets);
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
 // SequenceOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::SequenceOp::apply(transform::TransformResults &results,
                              transform::TransformState &state) {
   // Map the entry block argument to the list of operations.
   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
   if (failed(mapBlockArguments(state)))
-    return DiagnosedSilencableFailure::definiteFailure();
+    return DiagnosedSilenceableFailure::definiteFailure();
 
   // Apply the sequenced ops one by one.
   for (Operation &transform : getBodyBlock()->without_terminator()) {
-    DiagnosedSilencableFailure result =
+    DiagnosedSilenceableFailure result =
         state.applyTransform(cast<TransformOpInterface>(transform));
     if (!result.succeeded())
       return result;
@@ -324,7 +331,7 @@ transform::SequenceOp::apply(transform::TransformResults &results,
   // Forward the operation mapping for values yielded from the sequence to the
   // values produced by the sequence op.
   forwardTerminatorOperands(getBodyBlock(), state, results);
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 /// Returns `true` if the given op operand may be consuming the handle value in
@@ -486,7 +493,7 @@ void transform::SequenceOp::getRegionInvocationBounds(
 // WithPDLPatternsOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
                                     transform::TransformState &state) {
   OwningOpRef<ModuleOp> pdlModuleOp =
@@ -505,7 +512,7 @@ transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
 
   auto scope = state.make_region_scope(getBody());
   if (failed(mapBlockArguments(state)))
-    return DiagnosedSilencableFailure::definiteFailure();
+    return DiagnosedSilenceableFailure::definiteFailure();
   return state.applyTransform(transformOp);
 }
 
index 22cfc00..45e01e8 100644 (file)
@@ -348,3 +348,33 @@ module {
   }
 }
 
+// -----
+
+func.func @foo(%arg0: index, %arg1: index, %arg2: index) {
+  // expected-note @below {{scope}}
+  scf.for %i = %arg0 to %arg1 step %arg2 {
+    %0 = arith.constant 0 : i32
+  }
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_const : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "arith.constant"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @match_const in %arg1
+    %1 = transform.loop.get_parent_for %0
+    // expected-error @below {{only isolated-from-above ops can be alternative scopes}}
+    alternatives %1 {
+    ^bb2(%arg2: !pdl.operation):
+    }
+  }
+}
index 81cba58..5905218 100644 (file)
@@ -38,13 +38,13 @@ public:
     return llvm::StringLiteral("transform.test_transform_op");
   }
 
-  DiagnosedSilencableFailure apply(transform::TransformResults &results,
-                                   transform::TransformState &state) {
+  DiagnosedSilenceableFailure apply(transform::TransformResults &results,
+                                    transform::TransformState &state) {
     InFlightDiagnostic remark = emitRemark() << "applying transformation";
     if (Attribute message = getMessage())
       remark << " " << message;
 
-    return DiagnosedSilencableFailure::success();
+    return DiagnosedSilenceableFailure::success();
   }
 
   Attribute getMessage() { return getOperation()->getAttr("message"); }
@@ -91,9 +91,9 @@ public:
         "transform.test_transform_unrestricted_op_no_interface");
   }
 
-  DiagnosedSilencableFailure apply(transform::TransformResults &results,
-                                   transform::TransformState &state) {
-    return DiagnosedSilencableFailure::success();
+  DiagnosedSilenceableFailure apply(transform::TransformResults &results,
+                                    transform::TransformState &state) {
+    return DiagnosedSilenceableFailure::success();
   }
 
   // No side effects.
@@ -101,7 +101,7 @@ public:
 };
 } // namespace
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 mlir::test::TestProduceParamOrForwardOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   if (getOperation()->getNumOperands() != 0) {
@@ -111,7 +111,7 @@ mlir::test::TestProduceParamOrForwardOperandOp::apply(
     results.set(getResult().cast<OpResult>(),
                 reinterpret_cast<Operation *>(*getParameter()));
   }
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
@@ -120,50 +120,51 @@ LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
   return success();
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 mlir::test::TestConsumeOperand::apply(transform::TransformResults &results,
                                       transform::TransformState &state) {
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
   assert(payload.size() == 1 && "expected a single target op");
   auto value = reinterpret_cast<intptr_t>(payload[0]);
   if (static_cast<uint64_t>(value) != getParameter()) {
-    return emitSilencableError()
+    return emitSilenceableError()
            << "op expected the operand to be associated with " << getParameter()
            << " got " << value;
   }
 
   emitRemark() << "succeeded";
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilencableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
+DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
   for (Operation *op : payload)
     op->emitRemark() << getMessage();
 
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results,
                                           transform::TransformState &state) {
   state.addExtension<TestTransformStateExtension>(getMessageAttr());
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilencableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply(
+DiagnosedSilenceableFailure
+mlir::test::TestCheckIfTestExtensionPresentOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   auto *extension = state.getExtension<TestTransformStateExtension>();
   if (!extension) {
     emitRemark() << "extension absent";
-    return DiagnosedSilencableFailure::success();
+    return DiagnosedSilenceableFailure::success();
   }
 
   InFlightDiagnostic diag = emitRemark()
@@ -175,54 +176,54 @@ DiagnosedSilencableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply(
            "operations");
   }
 
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilencableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
+DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   auto *extension = state.getExtension<TestTransformStateExtension>();
   if (!extension) {
     emitError() << "TestTransformStateExtension missing";
-    return DiagnosedSilencableFailure::definiteFailure();
+    return DiagnosedSilenceableFailure::definiteFailure();
   }
 
   if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(),
                                       getOperation())))
-    return DiagnosedSilencableFailure::definiteFailure();
-  return DiagnosedSilencableFailure::success();
+    return DiagnosedSilenceableFailure::definiteFailure();
+  return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilencableFailure mlir::test::TestRemoveTestExtensionOp::apply(
+DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   state.removeExtension<TestTransformStateExtension>();
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
-DiagnosedSilencableFailure mlir::test::TestTransformOpWithRegions::apply(
+DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply(
     transform::TransformResults &results, transform::TransformState &state) {
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 void mlir::test::TestTransformOpWithRegions::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
 
-DiagnosedSilencableFailure
+DiagnosedSilenceableFailure
 mlir::test::TestBranchingTransformOpTerminator::apply(
     transform::TransformResults &results, transform::TransformState &state) {
-  return DiagnosedSilencableFailure::success();
+  return DiagnosedSilenceableFailure::success();
 }
 
 void mlir::test::TestBranchingTransformOpTerminator::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
 
-DiagnosedSilencableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
+DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   emitRemark() << getRemark();
   for (Operation *op : state.getPayloadOps(getTarget()))
     op->erase();
 
   if (getFailAfterErase())
-    return emitSilencableError() << "silencable error";
-  return DiagnosedSilencableFailure::success();
+    return emitSilenceableError() << "silencable error";
+  return DiagnosedSilenceableFailure::success();
 }
 
 namespace {