From: Nicolas Vasilache Date: Tue, 28 May 2019 22:05:51 +0000 (-0700) Subject: Add a linalg.for operation to support non-affine loop constructs X-Git-Tag: llvmorg-11-init~1466^2~1576 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3ad0fa95d1eab30c517817a6bcc191ecdf77a1e8;p=platform%2Fupstream%2Fllvm.git Add a linalg.for operation to support non-affine loop constructs The affine.for operation has restrictions that make it suitable for dependence analysis. The Linalg dialect aims at being more general. This CL introduces linalg.for, and its associated terminator, along with a simple roundtripping test. A `linalg.for` only takes one value of index type for lower bound, upper bound and step. Example usage: ``` linalg.for %iv = %lb to %ub step %step { ... // body } ``` -- PiperOrigin-RevId: 250369722 --- diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Linalg/IR/LinalgOps.h index 26a5cf3..92f2630 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.h @@ -77,6 +77,82 @@ public: } }; +/// The "linalg.for" operation represents a loop nest taking 3 SSA value as +/// operands that represent the lower bound, upper bound and step respectively. +/// The operation defines an SSA value for its induction variable. It has one +/// region capturing the loop body. The induction variable is represented as an +/// argument of this region. This SSA value always has type index, which is the +/// size of the machine word. The step is a value of type index, required to be +/// positive. +/// The lower and upper bounds specify a half-open range: the range includes the +/// lower bound but does not include the upper bound. +/// +/// The body region must contain exactly one block that terminates with +/// "linalg.terminator". Calling linalg::ForOp::build will create such region +/// and insert the terminator, so will the parsing even in cases if it is absent +/// from the custom format. For example: +/// +/// ```mlir +/// linalg.for %iv = %lb to %ub step %step { +/// ... // body +/// } +/// ``` +class ForOp + : public Op::Impl, OpTrait::ZeroResult> { +public: + using Op::Op; + + // Hooks to customize behavior of this op. + static void build(Builder *builder, OperationState *result, Value *lb, + Value *ub, Value *step); + LogicalResult verify(); + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + + static StringRef getOperationName() { return "linalg.for"; } + + /// Return a Builder set up to insert operations immediately before the + /// terminator. + FuncBuilder getBodyBuilder() { + Block *body = getBody(); + return FuncBuilder(body, std::prev(body->end())); + } + + /// Get the body of the ForOp. + Block *getBody() { return &getRegion().front(); } + + /// Get the body region of the ForOp. + Region &getRegion() { return getOperation()->getRegion(0); } + + /// Returns the induction variable for this loop. + Value *getInductionVar() { return getBody()->getArgument(0); } + + //===--------------------------------------------------------------------===// + // Bounds and step + //===--------------------------------------------------------------------===// + /// Returns the lower bound operand. + Value *getLowerBound() { return getOperand(0); } + + /// Returns the upper bound operand. + Value *getUpperBound() { return getOperand(1); } + + /// Returns loop step. + Value *getStep() { return getOperand(2); } + + /// Set lower bound. + void setLowerBound(Value *lb) { setOperand(0, lb); } + + /// Set upper bound. + void setUpperBound(Value *ub) { setOperand(1, ub); } + + /// Set loop step. + void setStep(Value *step) { setOperand(2, step); } +}; + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +ForOp getForInductionVarOwner(Value *val); + /// A linalg.LoadOp is the counterpart of load but operating on ViewType /// instead of MemRefType. /// diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Linalg/IR/LinalgOps.td index 824db9a..d9b0f655 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.td @@ -112,4 +112,26 @@ def RangeIntersectOp : Linalg_Op<"range_intersect", [NoSideEffect]>, }]>]; } +def TerminatorOp : + Linalg_Op<"terminator", [NativeOpTrait<"IsTerminator">]> { + let summary = "linalg terminator operation"; + let description = [{ + "linalg.terminator" is a special terminator operation for blocks inside + linalg loops and branches. It unconditionally transmits the control flow to + the successor of the operation enclosing the region. + + This operation does _not_ have a custom syntax. However, linalg control + operations omit the terminator in their custom syntax for brevity. + + linalg.terminator + }]; + + // No custom parsing/printing form. + let parser = ?; + let printer = ?; + + // Fully specified by traits. + let verifier = ?; +} + #endif // LINALG_OPS diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index 24a47da..f222fce 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -112,6 +112,129 @@ ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser, //////////////////////////////////////////////////////////////////////////////// // ForOp. //////////////////////////////////////////////////////////////////////////////// +// Check that if a "block" has a terminator, it is an `TerminatorOp`. +static LogicalResult checkHasTerminator(OpState &op, Block &block) { + if (block.empty() || isa(block.back())) + return success(); + + op.emitOpError("expects regions to end with '" + + TerminatorOp::getOperationName() + "'") + .attachNote() + << "in custom textual format, the absence of terminator implies '" + << TerminatorOp::getOperationName() << "'"; + return failure(); +} + +// Insert `linalg.terminator` at the end of the ForOp only region's only block +// if it does not have a terminator already. If a new `linalg.terminator` is +// inserted, the location is specified by `loc`. If the region is empty, insert +// a new block first. +static void ensureTerminator(Region ®ion, Builder &builder, Location loc) { + if (region.empty()) + region.push_back(new Block); + + Block &block = region.back(); + if (!block.empty() && block.back().isKnownTerminator()) + return; + + OperationState terminatorState(builder.getContext(), loc, + TerminatorOp::getOperationName()); + TerminatorOp::build(&builder, &terminatorState); + block.push_back(Operation::create(terminatorState)); +} + +void mlir::linalg::ForOp::build(Builder *builder, OperationState *result, + Value *lb, Value *ub, Value *step) { + result->addOperands({lb, ub, step}); + Region *bodyRegion = result->addRegion(); + Block *body = new Block(); + body->addArgument(IndexType::get(builder->getContext())); + bodyRegion->push_back(body); + ensureTerminator(*bodyRegion, *builder, result->location); +} + +LogicalResult mlir::linalg::ForOp::verify() { + if (!getLowerBound()->getType().isa()) + return emitOpError("lower bound operand must be an index"); + if (!getUpperBound()->getType().isa()) + return emitOpError("upper bound operand must be an index"); + if (!getLowerBound()->getType().dyn_cast()) + return emitOpError("step operand must be an index"); + if (auto cst = + dyn_cast_or_null(getLowerBound()->getDefiningOp())) + if (cst.getValue() <= 0) + return emitOpError("constant step operand must be positive"); + + if (std::next(getOperation()->getRegions().begin()) != + getOperation()->getRegions().end()) + return emitOpError("operation expected to have exactly one region"); + + auto &bodyRegion = getOperation()->getRegion(0); + // The body region must contain a single basic block. + if (bodyRegion.empty() || std::next(bodyRegion.begin()) != bodyRegion.end()) + return emitOpError("expected body region to have a single block"); + // Check that the body defines as single block argument for the induction + // variable. + auto *body = getBody(); + if (body->getNumArguments() != 1 || + !body->getArgument(0)->getType().isIndex()) + return emitOpError("expected body to have a single index argument for " + "the induction variable"); + if (failed(checkHasTerminator(*this, *body))) + return failure(); + return success(); +} + +void mlir::linalg::ForOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *getInductionVar() << " = " + << *getLowerBound() << " to " << *getUpperBound() << " step " + << *getStep(); + p->printRegion(getRegion(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + p->printOptionalAttrDict(getAttrs()); +} + +ParseResult mlir::linalg::ForOp::parse(OpAsmParser *parser, + OperationState *result) { + auto &builder = parser->getBuilder(); + OpAsmParser::OperandType inductionVariable, lb, ub, step; + // Parse the induction variable followed by '='. + if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual()) + return failure(); + + // Parse loop bounds. + Type indexType = builder.getIndexType(); + if (parser->parseOperand(lb) || + parser->resolveOperand(lb, indexType, result->operands) || + parser->parseKeyword("to") || parser->parseOperand(ub) || + parser->resolveOperand(ub, indexType, result->operands) || + parser->parseKeyword("step") || parser->parseOperand(step) || + parser->resolveOperand(step, indexType, result->operands)) + return failure(); + + // Parse the body region. + Region *body = result->addRegion(); + if (parser->parseRegion(*body, inductionVariable, indexType)) + return failure(); + + ensureTerminator(*body, builder, result->location); + + // Parse the optional attribute list. + if (parser->parseOptionalAttributeDict(result->attributes)) + return failure(); + + return success(); +} + +mlir::linalg::ForOp mlir::linalg::getForInductionVarOwner(Value *val) { + auto *ivArg = dyn_cast(val); + if (!ivArg) + return ForOp(); + assert(ivArg->getOwner() && "unlinked block argument"); + auto *containingInst = ivArg->getOwner()->getContainingOp(); + return dyn_cast_or_null(containingInst); +} //////////////////////////////////////////////////////////////////////////////// // LoadOp. diff --git a/mlir/lib/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Linalg/IR/LinalgTypes.cpp index 0e20eb8..453463b 100644 --- a/mlir/lib/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Linalg/IR/LinalgTypes.cpp @@ -31,7 +31,7 @@ using namespace mlir::linalg; mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context) : Dialect("linalg", context) { addTypes(); - addOperations(); addOperations< #define GET_OP_LIST diff --git a/mlir/test/Linalg/roundtrip.mlir b/mlir/test/Linalg/roundtrip.mlir index c42b8d7..d730a20 100644 --- a/mlir/test/Linalg/roundtrip.mlir +++ b/mlir/test/Linalg/roundtrip.mlir @@ -70,3 +70,25 @@ func @range_intersect(%arg0: !linalg.range, %arg1: !linalg.range) -> !linalg.ran // CHECK-LABEL: func @range_intersect(%arg0: !linalg.range, %arg1: !linalg.range) -> !linalg.range { // CHECK-NEXT: %0 = linalg.range_intersect %arg0, %arg1 : !linalg.range // CHECK-NEXT: return %0 : !linalg.range + +func @linalg_for(%arg0 : index, %arg1 : index, %arg2 : index) { + linalg.for %i0 = %arg0 to %arg1 step %arg2 { + linalg.for %i1 = %arg0 to %arg1 step %arg2 { + %min_cmp = cmpi "slt", %i0, %i1 : index + %min = select %min_cmp, %i0, %i1 : index + %max_cmp = cmpi "sge", %i0, %i1 : index + %max = select %max_cmp, %i0, %i1 : index + linalg.for %i2 = %min to %max step %i1 { + } + } + } + return +} +// CHECK-LABEL: func @linalg_for(%arg0: index, %arg1: index, %arg2: index) { +// CHECK-NEXT: linalg.for %i0 = %arg0 to %arg1 step %arg2 { +// CHECK-NEXT: linalg.for %i1 = %arg0 to %arg1 step %arg2 { +// CHECK-NEXT: %0 = cmpi "slt", %i0, %i1 : index +// CHECK-NEXT: %1 = select %0, %i0, %i1 : index +// CHECK-NEXT: %2 = cmpi "sge", %i0, %i1 : index +// CHECK-NEXT: %3 = select %2, %i0, %i1 : index +// CHECK-NEXT: linalg.for %i2 = %1 to %3 step %i1 {