[MLIR][Python] Add SCFIfOp Python binding
authorchhzh123 <hc676@cornell.edu>
Sun, 13 Mar 2022 05:24:00 +0000 (05:24 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 13 Mar 2022 05:24:10 +0000 (05:24 +0000)
Current generated Python binding for the SCF dialect does not allow
users to call IfOp to create if-else branches on their own.
This PR sets up the default binding generation for scf.if operation
to address this problem.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D121076

mlir/python/mlir/dialects/_scf_ops_ext.py
mlir/test/python/dialects/scf.py

index a8924a7..3c3e673 100644 (file)
@@ -64,3 +64,44 @@ class ForOp:
     To obtain the loop-carried operands, use `iter_args`.
     """
     return self.body.arguments[1:]
+
+
+class IfOp:
+  """Specialization for the SCF if op class."""
+
+  def __init__(self,
+               cond,
+               results_=[],
+               *,
+               hasElse=False,
+               loc=None,
+               ip=None):
+    """Creates an SCF `if` operation.
+
+    - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
+    - `hasElse` determines whether the if operation has the else branch.
+    """
+    operands = []
+    operands.append(cond)
+    results = []
+    results.extend(results_)
+    super().__init__(
+        self.build_generic(
+            regions=2,
+            results=results,
+            operands=operands,
+            loc=loc,
+            ip=ip))
+    self.regions[0].blocks.append(*[])
+    if hasElse:
+        self.regions[1].blocks.append(*[])
+
+  @property
+  def then_block(self):
+    """Returns the then block of the if operation."""
+    return self.regions[0].blocks[0]
+
+  @property
+  def else_block(self):
+    """Returns the else block of the if operation."""
+    return self.regions[1].blocks[0]
index f434e80..c45931c 100644 (file)
@@ -82,3 +82,58 @@ def testOpsAsArguments():
 # CHECK:   iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
 # CHECK:     scf.yield %{{.*}}, %{{.*}}
 # CHECK:   return
+
+
+@constructAndPrintInModule
+def testIfWithoutElse():
+  bool = IntegerType.get_signless(1)
+  i32 = IntegerType.get_signless(32)
+
+  @builtin.FuncOp.from_py_func(bool)
+  def simple_if(cond):
+    if_op = scf.IfOp(cond)
+    with InsertionPoint(if_op.then_block):
+      one = arith.ConstantOp(i32, 1)
+      add = arith.AddIOp(one, one)
+      scf.YieldOp([])
+    return
+
+
+# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
+# CHECK: scf.if %[[ARG0:.*]]
+# CHECK:   %[[ONE:.*]] = arith.constant 1
+# CHECK:   %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
+# CHECK: return
+
+
+@constructAndPrintInModule
+def testIfWithElse():
+  bool = IntegerType.get_signless(1)
+  i32 = IntegerType.get_signless(32)
+
+  @builtin.FuncOp.from_py_func(bool)
+  def simple_if_else(cond):
+    if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
+    with InsertionPoint(if_op.then_block):
+      x_true = arith.ConstantOp(i32, 0)
+      y_true = arith.ConstantOp(i32, 1)
+      scf.YieldOp([x_true, y_true])
+    with InsertionPoint(if_op.else_block):
+      x_false = arith.ConstantOp(i32, 2)
+      y_false = arith.ConstantOp(i32, 3)
+      scf.YieldOp([x_false, y_false])
+    add = arith.AddIOp(if_op.results[0], if_op.results[1])
+    return
+
+
+# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
+# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0:.*]]
+# CHECK:   %[[ZERO:.*]] = arith.constant 0
+# CHECK:   %[[ONE:.*]] = arith.constant 1
+# CHECK:   scf.yield %[[ZERO]], %[[ONE]]
+# CHECK: } else {
+# CHECK:   %[[TWO:.*]] = arith.constant 2
+# CHECK:   %[[THREE:.*]] = arith.constant 3
+# CHECK:   scf.yield %[[TWO]], %[[THREE]]
+# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
+# CHECK: return