Add custom builder for AffineIfOp
authorNagy Mostafa <nagy.mostafa@gmail.com>
Fri, 6 Sep 2019 18:02:31 +0000 (11:02 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Sep 2019 18:03:03 +0000 (11:03 -0700)
Closes tensorflow/mlir#109

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/109 from nmostafa:nmostafa/AffineIfOp 7dbf2115f0092ffab26381ea8704aa05a0253971
PiperOrigin-RevId: 267633077

mlir/include/mlir/Dialect/AffineOps/AffineOps.td
mlir/include/mlir/EDSC/Intrinsics.h
mlir/lib/Dialect/AffineOps/AffineOps.cpp
mlir/test/EDSC/builder-api-test.cpp

index 4961ce8..370adfd 100644 (file)
@@ -212,9 +212,10 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
   let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
 
   let skipDefaultBuilders = 1;
+
   let builders = [
     OpBuilder<"Builder *builder, OperationState *result, "
-              "Value *cond, bool withElseRegion">
+              "IntegerSet set, ArrayRef<Value *> args, bool withElseRegion">
   ];
 
   let extraClassDeclaration = [{
index 98e9cea..a9a15df 100644 (file)
@@ -198,10 +198,11 @@ template <typename Op> struct OperationBuilder : public OperationHandle {
   OperationBuilder() : OperationHandle(OperationHandle::create<Op>()) {}
 };
 
-using alloc = ValueBuilder<AllocOp>;
 using affine_apply = ValueBuilder<AffineApplyOp>;
+using affine_if = OperationBuilder<AffineIfOp>;
 using affine_load = ValueBuilder<AffineLoadOp>;
 using affine_store = OperationBuilder<AffineStoreOp>;
+using alloc = ValueBuilder<AllocOp>;
 using call = OperationBuilder<mlir::CallOp>;
 using constant_float = ValueBuilder<ConstantFloatOp>;
 using constant_index = ValueBuilder<ConstantIndexOp>;
index 2161ae0..3a2e0b0 100644 (file)
@@ -1662,6 +1662,17 @@ void AffineIfOp::setConditional(IntegerSet set, ArrayRef<Value *> operands) {
   getOperation()->setOperands(operands);
 }
 
+void AffineIfOp::build(Builder *builder, OperationState *result, IntegerSet set,
+                       ArrayRef<Value *> args, bool withElseRegion) {
+  result->addOperands(args);
+  result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));
+  Region *thenRegion = result->addRegion();
+  Region *elseRegion = result->addRegion();
+  AffineIfOp::ensureTerminator(*thenRegion, *builder, result->location);
+  if (withElseRegion)
+    AffineIfOp::ensureTerminator(*elseRegion, *builder, result->location);
+}
+
 namespace {
 // This is a pattern to canonicalize an affine if op's conditional (integer
 // set + operands).
index 367355c..e8dbc87 100644 (file)
@@ -23,6 +23,7 @@
 #include "mlir/EDSC/Helpers.h"
 #include "mlir/EDSC/Intrinsics.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StandardTypes.h"
@@ -746,6 +747,41 @@ TEST_FUNC(empty_map_load_store) {
   f.erase();
 }
 
+// CHECK-LABEL: func @affine_if_op
+// CHECK:       affine.if ([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}}
+// CHECK-NOT:   else
+// CHECK:       affine.if ([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}}
+// CHECK-NEXT:  } else {
+TEST_FUNC(affine_if_op) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+  using namespace edsc::op;
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
+  auto f = makeFunction("affine_if_op", {}, {memrefType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+
+  ValueHandle zero = constant_index(0), ten = constant_index(10);
+
+  SmallVector<bool, 4> isEq = {false, false, false, false};
+  SmallVector<AffineExpr, 4> affineExprs = {
+      builder.getAffineDimExpr(0),    // d0 >= 0
+      builder.getAffineDimExpr(1),    // d1 >= 0
+      builder.getAffineSymbolExpr(0), // s0 >= 0
+      builder.getAffineSymbolExpr(1)  // s1 >= 0
+  };
+  auto intSet = builder.getIntegerSet(2, 2, affineExprs, isEq);
+
+  SmallVector<Value *, 4> affineIfArgs = {zero, zero, ten, ten};
+  intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/false);
+  intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/true);
+
+  f.print(llvm::outs());
+  f.erase();
+}
+
 int main() {
   RUN_TESTS();
   return 0;