[mlir] make DiagnosedSilenceableError(LogicalResult) ctor private
authorAlex Zinenko <zinenko@google.com>
Wed, 2 Nov 2022 15:22:43 +0000 (15:22 +0000)
committerAlex Zinenko <zinenko@google.com>
Mon, 12 Dec 2022 12:52:06 +0000 (12:52 +0000)
Now we have more convenient functions to construct silenceable errors
while emitting diagnostics, and the constructor is ambiguous as it
doesn't tell whether the logical error is silencebale or definite.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D137257

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

index bbcfabe..99e12a1 100644 (file)
@@ -34,7 +34,6 @@ namespace mlir {
 /// failures as their diagnostics have been already reported to the user.
 class [[nodiscard]] DiagnosedSilenceableFailure {
 public:
-  explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
   DiagnosedSilenceableFailure(const DiagnosedSilenceableFailure &) = delete;
   DiagnosedSilenceableFailure &
   operator=(const DiagnosedSilenceableFailure &) = delete;
@@ -156,6 +155,7 @@ public:
   }
 
 private:
+  explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
   explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic)
       : result(failure()) {
     diagnostics.emplace_back(std::move(diagnostic));
index fe29f30..3378153 100644 (file)
@@ -51,23 +51,12 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
   ];
 
   let extraSharedClassDeclaration = [{
-    /// Emits a generic transform error for the current transform operation
-    /// targeting the given Payload IR operation and returns failure. Should
-    /// be only used as a last resort when the transformation itself provides
-    /// no further indication as to the reason of the failure.
-    ::mlir::LogicalResult reportUnknownTransformError(
-        ::mlir::Operation *target) {
-      ::mlir::InFlightDiagnostic diag = $_op->emitError() << "failed to apply";
-      diag.attachNote(target->getLoc()) << "attempted to apply to this op";
-      return diag;
-    }
-
     /// Creates the silenceable failure object with a diagnostic located at the
     /// current operation. Silenceable failure must be suppressed or reported
     /// explicitly at some later time.
     DiagnosedSilenceableFailure
     emitSilenceableError(const ::llvm::Twine &message = {}) {
-      return ::mlir::emitSilenceableFailure($_op);
+      return ::mlir::emitSilenceableFailure($_op, message);
     }
 
     /// Creates the definite failure object with a diagnostic located at the
@@ -78,6 +67,17 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
       return ::mlir::emitDefiniteFailure($_op, message);
     }
 
+    /// Emits a generic definite failure for the current transform operation
+    /// targeting the given Payload IR operation and returns failure. Should
+    /// be only used as a last resort when the transformation itself provides
+    /// no further indication as to the reason of the failure.
+    DiagnosedDefiniteFailure emitDefaultDefiniteFailure(
+        ::mlir::Operation *target) {
+      auto diag = ::mlir::emitDefiniteFailure($_op, "failed to apply");
+      diag.attachNote(target->getLoc()) << "attempted to apply to this op";
+      return diag;
+    }
+
     /// Creates the default silenceable failure for a transform op that failed
     /// to properly apply to a target.
     DiagnosedSilenceableFailure emitDefaultSilenceableFailure(
index baf18dc..57cce19 100644 (file)
@@ -119,7 +119,7 @@ createGpuLaunch(RewriterBase &rewriter, Location loc,
                                        blkSizeX, blkSizeY, blkSizeZ);
   rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
   rewriter.create<TerminatorOp>(loc);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 /// Alter kernel configuration of the given kernel.
index 3ae4163..c8dd269 100644 (file)
@@ -79,20 +79,20 @@ transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
                                                      Conv1DNwcWcfOp>>(target);
   if (succeeded(windowedNhwc)) {
     results.push_back(*windowedNhwc);
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   FailureOr<LinalgOp> windowedNchw =
       tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
                                                      Conv1DNcwFcwOp>>(target);
   if (succeeded(windowedNchw)) {
     results.push_back(*windowedNchw);
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   FailureOr<LinalgOp> depthwise =
       tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
   if (succeeded(depthwise)) {
     results.push_back(*depthwise);
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   results.assign(1, nullptr);
   return emitDefaultSilenceableFailure(target);
@@ -206,7 +206,8 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
         return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
             rewriter, tilingInterfaceOp, tileAndFuseOptions);
       });
-  return DiagnosedSilenceableFailure(result);
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
 }
 
 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
@@ -568,12 +569,12 @@ transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
   // Exit early if no transformation is needed.
   if (isa<GenericOp>(target)) {
     results.push_back(target);
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
   if (succeeded(generic)) {
     results.push_back(generic->getOperation());
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   results.assign(1, nullptr);
   return emitDefaultSilenceableFailure(target);
@@ -592,7 +593,7 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
   // Exit early if no transformation is needed.
   if (interchangeVector.empty()) {
     results.push_back(target);
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   TrivialPatternRewriter rewriter(target->getContext());
   FailureOr<GenericOp> res =
@@ -600,7 +601,7 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
   if (failed(res))
     return DiagnosedSilenceableFailure::definiteFailure();
   results.push_back(res->getOperation());
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 LogicalResult transform::InterchangeOp::verify() {
@@ -639,8 +640,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
   if (payloadOps.size() != 1) {
     results.set(getResult().cast<OpResult>(), {});
-    return DiagnosedSilenceableFailure(
-        this->emitOpError("requires exactly one target handle"));
+    return emitDefiniteFailure("requires exactly one target handle");
   }
 
   SmallVector<Operation *> res;
@@ -687,7 +687,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
 
   payloadOps.front()->walk(matchFun);
   results.set(getResult().cast<OpResult>(), res);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===---------------------------------------------------------------------===//
@@ -792,7 +792,7 @@ transform::PadOp::applyToOne(linalg::LinalgOp target,
       tryApply<LinalgPaddingPattern>(target, paddingOptions);
   if (succeeded(result)) {
     results.push_back(result->getOperation());
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
 
   results.assign(1, nullptr);
@@ -866,15 +866,15 @@ transform::PromoteOp::applyToOne(linalg::LinalgOp target,
     promotionOptions = promotionOptions.setAlignment(*getAlignment());
 
   if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultDefiniteFailure(target);
 
   TrivialPatternRewriter rewriter(target->getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
   if (failed(res))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultDefiniteFailure(target);
   results.push_back(target);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -909,7 +909,7 @@ transform::ReplaceOp::apply(TransformResults &transformResults,
     replacements.push_back(replacement);
   }
   transformResults.set(getReplacement().cast<OpResult>(), replacements);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 void transform::ReplaceOp::getEffects(
@@ -972,10 +972,10 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
   FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
       rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
   if (failed(maybeTilingResult))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultDefiniteFailure(target);
 
   results.append(maybeTilingResult->tiledOps);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1171,13 +1171,13 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
           ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
           : splitReduction(rewriter, target, splitFn, getUseAlloc());
   if (failed(splitResult))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultDefiniteFailure(target);
 
   results.push_back(splitResult->initOrAlloc);
   results.push_back(splitResult->fillOp);
   results.push_back(splitResult->splitLinalgOp);
   results.push_back(splitResult->resultCombiningLinalgOp);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1200,12 +1200,12 @@ DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
       sizes);
 
   if (failed(result))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultSilenceableFailure(target);
   results.push_back(result->loops.front());
   results.push_back(result->initialOp);
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1235,7 +1235,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
   results.push_back(result->initialOp);
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1523,7 +1523,7 @@ static DiagnosedSilenceableFailure unpackPDLOperations(
     }
   }
 
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
@@ -1533,7 +1533,7 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
     ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> mapping,
     SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
   if (targets.empty())
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
 
   // getMixedNumThreads are OpFoldResults[index attributes or PDL operation].
   // Convert to OpFoldResults[index attributes or payload op].
@@ -1577,7 +1577,7 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
     tileOps.push_back(tilingResult->tileOp);
     tiledOps.push_back(tilingResult->tiledOp);
   }
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
@@ -1604,7 +1604,7 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
   transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
   transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
 
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 void transform::TileToForeachThreadOp::getEffects(
@@ -1852,10 +1852,10 @@ transform::VectorizeOp::applyToOne(Operation *target,
     linalg::populatePadOpVectorizationPatterns(patterns);
 
   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultDefiniteFailure(target);
 
   results.push_back(target);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
index 1d7a8b7..391164c 100644 (file)
@@ -33,7 +33,7 @@ transform::MemRefMultiBufferOp::applyToOne(memref::AllocOp target,
   }
 
   results.push_back(newBuffer.value());
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
index 8777662..21deab6 100644 (file)
@@ -103,10 +103,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
     FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
         rewriter, location, exec.getRegion(), getFuncName(), &call);
 
-    if (failed(outlined)) {
-      (void)reportUnknownTransformError(target);
-      return DiagnosedSilenceableFailure::definiteFailure();
-    }
+    if (failed(outlined))
+      return emitDefaultDefiniteFailure(target);
 
     if (symbolTableOp) {
       SymbolTable &symbolTable =
@@ -139,7 +137,7 @@ transform::LoopPeelOp::applyToOne(scf::ForOp target,
       scf::peelAndCanonicalizeForLoop(rewriter, target, result);
   // TODO: Return both the peeled loop and the remainder loop.
   results.push_back(failed(status) ? target : result);
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -200,7 +198,7 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target,
       pattern.returningMatchAndRewrite(target, rewriter);
   if (succeeded(patternResult)) {
     results.push_back(*patternResult);
-    return DiagnosedSilenceableFailure(success());
+    return DiagnosedSilenceableFailure::success();
   }
   results.assign(1, nullptr);
   return emitDefaultSilenceableFailure(target);
@@ -225,7 +223,7 @@ transform::LoopUnrollOp::applyToOne(Operation *op,
     diag << "Op failed to unroll";
     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
-  return DiagnosedSilenceableFailure(success());
+  return DiagnosedSilenceableFailure::success();
 }
 
 //===----------------------------------------------------------------------===//