From 8154370b49b3de5276622d0e65c73a5a0f17f56d Mon Sep 17 00:00:00 2001 From: Nagy Mostafa Date: Fri, 6 Sep 2019 11:02:31 -0700 Subject: [PATCH] Add custom builder for AffineIfOp Closes tensorflow/mlir#109 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/109 from nmostafa:nmostafa/AffineIfOp 7dbf2115f0092ffab26381ea8704aa05a0253971 PiperOrigin-RevId: 267633077 --- mlir/include/mlir/Dialect/AffineOps/AffineOps.td | 3 +- mlir/include/mlir/EDSC/Intrinsics.h | 3 +- mlir/lib/Dialect/AffineOps/AffineOps.cpp | 11 ++++++++ mlir/test/EDSC/builder-api-test.cpp | 36 ++++++++++++++++++++++++ 4 files changed, 51 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td index 4961ce8..370adfd 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td @@ -212,9 +212,10 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> { let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion); let skipDefaultBuilders = 1; + let builders = [ OpBuilder<"Builder *builder, OperationState *result, " - "Value *cond, bool withElseRegion"> + "IntegerSet set, ArrayRef args, bool withElseRegion"> ]; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h index 98e9cea..a9a15df 100644 --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -198,10 +198,11 @@ template struct OperationBuilder : public OperationHandle { OperationBuilder() : OperationHandle(OperationHandle::create()) {} }; -using alloc = ValueBuilder; using affine_apply = ValueBuilder; +using affine_if = OperationBuilder; using affine_load = ValueBuilder; using affine_store = OperationBuilder; +using alloc = ValueBuilder; using call = OperationBuilder; using constant_float = ValueBuilder; using constant_index = ValueBuilder; diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 2161ae0..3a2e0b0 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -1662,6 +1662,17 @@ void AffineIfOp::setConditional(IntegerSet set, ArrayRef operands) { getOperation()->setOperands(operands); } +void AffineIfOp::build(Builder *builder, OperationState *result, IntegerSet set, + ArrayRef args, bool withElseRegion) { + result->addOperands(args); + result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(set)); + Region *thenRegion = result->addRegion(); + Region *elseRegion = result->addRegion(); + AffineIfOp::ensureTerminator(*thenRegion, *builder, result->location); + if (withElseRegion) + AffineIfOp::ensureTerminator(*elseRegion, *builder, result->location); +} + namespace { // This is a pattern to canonicalize an affine if op's conditional (integer // set + operands). diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 367355c..e8dbc87 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -23,6 +23,7 @@ #include "mlir/EDSC/Helpers.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" @@ -746,6 +747,41 @@ TEST_FUNC(empty_map_load_store) { f.erase(); } +// CHECK-LABEL: func @affine_if_op +// CHECK: affine.if ([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}} +// CHECK-NOT: else +// CHECK: affine.if ([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}} +// CHECK-NEXT: } else { +TEST_FUNC(affine_if_op) { + using namespace edsc; + using namespace edsc::intrinsics; + using namespace edsc::op; + auto f32Type = FloatType::getF32(&globalContext()); + auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0); + auto f = makeFunction("affine_if_op", {}, {memrefType}); + + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + + ValueHandle zero = constant_index(0), ten = constant_index(10); + + SmallVector isEq = {false, false, false, false}; + SmallVector affineExprs = { + builder.getAffineDimExpr(0), // d0 >= 0 + builder.getAffineDimExpr(1), // d1 >= 0 + builder.getAffineSymbolExpr(0), // s0 >= 0 + builder.getAffineSymbolExpr(1) // s1 >= 0 + }; + auto intSet = builder.getIntegerSet(2, 2, affineExprs, isEq); + + SmallVector affineIfArgs = {zero, zero, ten, ten}; + intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/false); + intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/true); + + f.print(llvm::outs()); + f.erase(); +} + int main() { RUN_TESTS(); return 0; -- 2.7.4