From 2b81d3c6c6fd7b3fcffba626c5df3a9a66a3deb1 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 16 Jan 2020 09:30:17 -0500 Subject: [PATCH] [mlir][Linalg] Fix Linalg EDSC builders Summary: This diff fixes the fact that the method `mlir::edsc::makeGenericLinalgOp` incorrectly adds 2 blocks to Linalg ops. Tests are updated accordingly. Reviewers: ftynse, hanchung, herhut, pifon2a, asaadaldien Reviewed By: asaadaldien Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72780 --- mlir/include/mlir/EDSC/Builders.h | 13 ++++++++++++ mlir/lib/Dialect/Linalg/EDSC/Builders.cpp | 12 ++++++----- mlir/lib/EDSC/Builders.cpp | 33 +++++++++++++++++++++++++++++++ mlir/test/EDSC/builder-api-test.cpp | 6 +++--- 4 files changed, 56 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index 14a4e5a..14ce342 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -251,6 +251,16 @@ public: /// not yet bound to mlir::Value. BlockBuilder(BlockHandle *bh, ArrayRef args); + /// Constructs a new mlir::Block with argument types derived from `args` and + /// appends it as the last block in the region. + /// Captures the new block in `bh` and its arguments into `args`. + /// Enters the new mlir::Block* and sets the insertion point to its end. + /// + /// Prerequisites: + /// The ValueHandle `args` are typed delayed ValueHandles; i.e. they are + /// not yet bound to mlir::Value. + BlockBuilder(BlockHandle *bh, Region ®ion, ArrayRef args); + /// The only purpose of this operator is to serve as a sequence point so that /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is /// scoped within a BlockBuilder. @@ -450,6 +460,9 @@ public: /// Delegates block creation to MLIR and wrap the resulting mlir::Block. static BlockHandle create(ArrayRef argTypes); + /// Delegates block creation to MLIR and wrap the resulting mlir::Block. + static BlockHandle createInRegion(Region ®ion, ArrayRef argTypes); + operator bool() { return block != nullptr; } operator mlir::Block *() { return block; } mlir::Block *getBlock() { return block; } diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index 8d58125..0940f56 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -184,14 +184,16 @@ Operation *mlir::edsc::makeGenericLinalgOp( ? getElementTypeOrSelf(it.value()) : it.value().getType()); - assert(op->getRegions().front().empty()); - op->getRegions().front().push_front(new Block); - OpBuilder bb(op->getRegions().front()); - ScopedContext scope(bb, op->getLoc()); + assert(op->getNumRegions() == 1); + assert(op->getRegion(0).empty()); + OpBuilder opBuilder(op); + ScopedContext scope(opBuilder, op->getLoc()); BlockHandle b; auto handles = makeValueHandles(blockTypes); - BlockBuilder(&b, makeHandlePointers(MutableArrayRef(handles)))( + BlockBuilder(&b, op->getRegion(0), + makeHandlePointers(MutableArrayRef(handles)))( [&] { regionBuilder(b.getBlock()->getArguments()); }); + assert(op->getRegion(0).getBlocks().size() == 1); return op; } diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index b966003..33aed8e 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -133,6 +133,22 @@ BlockHandle mlir::edsc::BlockHandle::create(ArrayRef argTypes) { return res; } +BlockHandle mlir::edsc::BlockHandle::createInRegion(Region ®ion, + ArrayRef argTypes) { + auto ¤tB = ScopedContext::getBuilder(); + BlockHandle res; + region.push_back(new Block); + res.block = ®ion.back(); + // createBlock sets the insertion point inside the block. + // We do not want this behavior when using declarative builders with nesting. + OpBuilder::InsertionGuard g(currentB); + currentB.setInsertionPoint(res.block, res.block->begin()); + for (auto t : argTypes) { + res.block->addArgument(t); + } + return res; +} + static Optional emitStaticFor(ArrayRef lbs, ArrayRef ubs, int64_t step) { @@ -285,6 +301,23 @@ mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, enter(bh->getBlock()); } +mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, Region ®ion, + ArrayRef args) { + assert(!*bh && "BlockHandle already captures a block, use " + "the explicit BockBuilder(bh, Append())({}) syntax instead."); + SmallVector types; + for (auto *a : args) { + assert(!a->hasValue() && + "Expected delayed ValueHandle that has not yet captured."); + types.push_back(a->getType()); + } + *bh = BlockHandle::createInRegion(region, types); + for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) { + *(std::get<0>(it)) = ValueHandle(std::get<1>(it)); + } + enter(bh->getBlock()); +} + /// Only serves as an ordering point between entering nested block and creating /// stmts. void mlir::edsc::BlockBuilder::operator()(function_ref fun) { diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index d991188..27c5d4d 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -876,7 +876,7 @@ TEST_FUNC(linalg_pointwise_test) { // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} -/// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): +/// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): // CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32 // CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32 // CHECK: linalg.yield %[[a4]] : f32 @@ -906,7 +906,7 @@ TEST_FUNC(linalg_matmul_test) { // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1)>, // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d1)>], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} -/// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): +/// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): // CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32 // CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32 // CHECK: linalg.yield %[[a4]] : f32 @@ -937,7 +937,7 @@ TEST_FUNC(linalg_conv_nhwc) { // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1)>, // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d4, d1 + d2 * 7)>], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} -// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): +// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): // CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32 // CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32 // CHECK: linalg.yield %[[a4]] : f32 -- 2.7.4