} // namespace func
namespace scf {
class ForOp;
+class IfOp;
} // namespace scf
} // namespace mlir
}];
}
+def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface, TransformEachOpTrait]> {
+ let description = [{
+ Given an scf.if conditional, inject user-defined information that it is
+ always safe to execute only the if or else branch.
+
+ This is achieved by just replacing the scf.if by the content of one of its
+ branches.
+
+ This is particularly useful for user-controlled rewriting of conditionals
+ that exist solely to guard against out-of-bounds behavior.
+
+ At the moment, no assume or assert operation is emitted as it is not always
+ desirable. In the future, this may be controlled by a dedicated attribute.
+
+ #### Return modes
+
+ The transform only consumes its operand and does not produce any result.
+ The transform definitely fails if `take_else_branch` is specified and the
+ `else` region is empty.
+ }];
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ OptionalAttr<UnitAttr>:$take_else_branch);
+ let results = (outs);
+
+ let assemblyFormat = [{
+ $target
+ (`take_else_branch` $take_else_branch^)?
+ attr-dict
+ `:` functional-type(operands, results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::scf::IfOp ifOp,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // SCF_TRANSFORM_OPS
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;
}
//===----------------------------------------------------------------------===//
+// TakeAssumedBranchOp
+//===----------------------------------------------------------------------===//
+/// Replaces the given op with the contents of the given single-block region,
+/// using the operands of the block terminator to replace operation results.
+static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op,
+ Region ®ion) {
+ assert(llvm::hasSingleElement(region) && "expected single-region block");
+ Block *block = ®ion.front();
+ Operation *terminator = block->getTerminator();
+ ValueRange results = terminator->getOperands();
+ rewriter.inlineBlockBefore(block, op, /*blockArgs=*/{});
+ rewriter.replaceOp(op, results);
+ rewriter.eraseOp(terminator);
+}
+
+DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne(
+ scf::IfOp ifOp, transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ TrackingListener listener(state, *this);
+ IRRewriter rewriter(ifOp->getContext(), &listener);
+ rewriter.setInsertionPoint(ifOp);
+
+ Region ®ion =
+ getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
+ if (!llvm::hasSingleElement(region)) {
+ return emitDefiniteFailure()
+ << "requires an scf.if op with a single-block "
+ << ((getTakeElseBranch()) ? "`else`" : "`then`") << " region";
+ }
+ replaceOpWithRegion(rewriter, ifOp, region);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::TakeAssumedBranchOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTarget(), effects);
+ modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
--- /dev/null
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics --allow-unregistered-dialect | FileCheck %s
+
+func.func @if_no_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+ scf.if %cond {
+ "some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
+ scf.yield
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+ %if = transform.structured.match ops{["scf.if"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+
+ // expected-error @+1 {{requires an scf.if op with a single-block `else` region}}
+ transform.scf.take_assumed_branch %if take_else_branch
+ : (!transform.any_op) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: tile_tensor_pad
+func.func @tile_tensor_pad(
+ %arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index)
+ -> tensor<20x40xf32>
+{
+ // CHECK: scf.forall
+ // CHECK-NOT: scf.if
+ // CHECK-NOT: tensor.generate
+ // CHECK-NOT: else
+ // CHECK: tensor.pad {{.*}} nofold
+ %0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] {
+ ^bb0(%arg9: index, %arg10: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?xf32> to tensor<20x40xf32>
+ return %0 : tensor<20x40xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ : (!transform.any_op) -> !pdl.operation
+ transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]
+
+ %if = transform.structured.match ops{["scf.if"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.scf.take_assumed_branch %if take_else_branch
+ : (!transform.any_op) -> ()
+}