[mlir][SCF] Add an scf.take_assumed_branch transform op.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 12 Apr 2023 13:19:45 +0000 (06:19 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 12 Apr 2023 15:47:20 +0000 (08:47 -0700)
Given an scf.if conditional, using this transformation is akin to injecting
user-specified information that it is always safe to execute only the specified
`if` or `else` branch.

This is achieved by just replacing the scf.if by the content of one of its
branches.

This is particularly useful for user-controlled rewriting of conditionals
that exist solely to guard against out-of-bounds behavior.

At the moment, no assume or assert operation is emitted as it is not always
desirable. In the future, this may be controlled by a dedicated attribute.

Differential Revision: https://reviews.llvm.org/D148125

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir [new file with mode: 0644]

index 91e4214..c5cc2da 100644 (file)
@@ -19,6 +19,7 @@ class FuncOp;
 } // namespace func
 namespace scf {
 class ForOp;
+class IfOp;
 } // namespace scf
 } // namespace mlir
 
index b286850..0399a5a 100644 (file)
@@ -215,4 +215,45 @@ def LoopCoalesceOp : Op<Transform_Dialect, "loop.coalesce", [
   }];
 }
 
+def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Given an scf.if conditional, inject user-defined information that it is
+    always safe to execute only the if or else branch. 
+    
+    This is achieved by just replacing the scf.if by the content of one of its
+    branches.
+
+    This is particularly useful for user-controlled rewriting of conditionals
+    that exist solely to guard against out-of-bounds behavior.
+
+    At the moment, no assume or assert operation is emitted as it is not always
+    desirable. In the future, this may be controlled by a dedicated attribute.
+
+    #### Return modes
+
+    The transform only consumes its operand and does not produce any result.
+    The transform definitely fails if `take_else_branch` is specified and the
+    `else` region is empty.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       OptionalAttr<UnitAttr>:$take_else_branch);
+  let results = (outs);
+
+  let assemblyFormat = [{
+      $target
+      (`take_else_branch` $take_else_branch^)?
+      attr-dict
+       `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::scf::IfOp ifOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // SCF_TRANSFORM_OPS
index b35e104..b87fc77 100644 (file)
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
 using namespace mlir;
@@ -246,6 +247,46 @@ transform::LoopCoalesceOp::applyToOne(Operation *op,
 }
 
 //===----------------------------------------------------------------------===//
+// TakeAssumedBranchOp
+//===----------------------------------------------------------------------===//
+/// Replaces the given op with the contents of the given single-block region,
+/// using the operands of the block terminator to replace operation results.
+static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op,
+                                Region &region) {
+  assert(llvm::hasSingleElement(region) && "expected single-region block");
+  Block *block = &region.front();
+  Operation *terminator = block->getTerminator();
+  ValueRange results = terminator->getOperands();
+  rewriter.inlineBlockBefore(block, op, /*blockArgs=*/{});
+  rewriter.replaceOp(op, results);
+  rewriter.eraseOp(terminator);
+}
+
+DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne(
+    scf::IfOp ifOp, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(ifOp->getContext(), &listener);
+  rewriter.setInsertionPoint(ifOp);
+
+  Region &region =
+      getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
+  if (!llvm::hasSingleElement(region)) {
+    return emitDefiniteFailure()
+           << "requires an scf.if op with a single-block "
+           << ((getTakeElseBranch()) ? "`else`" : "`then`") << " region";
+  }
+  replaceOpWithRegion(rewriter, ifOp, region);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::TakeAssumedBranchOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir b/mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir
new file mode 100644 (file)
index 0000000..15d9e56
--- /dev/null
@@ -0,0 +1,50 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics --allow-unregistered-dialect | FileCheck %s
+
+func.func @if_no_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+  scf.if %cond {
+    "some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
+    scf.yield
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+  %if = transform.structured.match ops{["scf.if"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+
+  // expected-error @+1 {{requires an scf.if op with a single-block `else` region}}
+  transform.scf.take_assumed_branch %if take_else_branch 
+    : (!transform.any_op) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: tile_tensor_pad
+func.func @tile_tensor_pad(
+  %arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index) 
+    -> tensor<20x40xf32>
+{
+  //     CHECK: scf.forall
+  // CHECK-NOT:   scf.if
+  // CHECK-NOT:     tensor.generate
+  // CHECK-NOT:   else
+  //     CHECK:     tensor.pad {{.*}} nofold 
+  %0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] {
+        ^bb0(%arg9: index, %arg10: index):
+          tensor.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<20x40xf32>
+  return %0 : tensor<20x40xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 
+    : (!transform.any_op) -> !pdl.operation
+  transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]
+
+  %if = transform.structured.match ops{["scf.if"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.scf.take_assumed_branch %if take_else_branch 
+    : (!transform.any_op) -> ()
+}