}
};
+/// The "linalg.for" operation represents a loop nest taking 3 SSA value as
+/// operands that represent the lower bound, upper bound and step respectively.
+/// The operation defines an SSA value for its induction variable. It has one
+/// region capturing the loop body. The induction variable is represented as an
+/// argument of this region. This SSA value always has type index, which is the
+/// size of the machine word. The step is a value of type index, required to be
+/// positive.
+/// The lower and upper bounds specify a half-open range: the range includes the
+/// lower bound but does not include the upper bound.
+///
+/// The body region must contain exactly one block that terminates with
+/// "linalg.terminator". Calling linalg::ForOp::build will create such region
+/// and insert the terminator, so will the parsing even in cases if it is absent
+/// from the custom format. For example:
+///
+/// ```mlir
+/// linalg.for %iv = %lb to %ub step %step {
+/// ... // body
+/// }
+/// ```
+class ForOp
+ : public Op<ForOp, OpTrait::NOperands<3>::Impl, OpTrait::ZeroResult> {
+public:
+ using Op::Op;
+
+ // Hooks to customize behavior of this op.
+ static void build(Builder *builder, OperationState *result, Value *lb,
+ Value *ub, Value *step);
+ LogicalResult verify();
+ static ParseResult parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p);
+
+ static StringRef getOperationName() { return "linalg.for"; }
+
+ /// Return a Builder set up to insert operations immediately before the
+ /// terminator.
+ FuncBuilder getBodyBuilder() {
+ Block *body = getBody();
+ return FuncBuilder(body, std::prev(body->end()));
+ }
+
+ /// Get the body of the ForOp.
+ Block *getBody() { return &getRegion().front(); }
+
+ /// Get the body region of the ForOp.
+ Region &getRegion() { return getOperation()->getRegion(0); }
+
+ /// Returns the induction variable for this loop.
+ Value *getInductionVar() { return getBody()->getArgument(0); }
+
+ //===--------------------------------------------------------------------===//
+ // Bounds and step
+ //===--------------------------------------------------------------------===//
+ /// Returns the lower bound operand.
+ Value *getLowerBound() { return getOperand(0); }
+
+ /// Returns the upper bound operand.
+ Value *getUpperBound() { return getOperand(1); }
+
+ /// Returns loop step.
+ Value *getStep() { return getOperand(2); }
+
+ /// Set lower bound.
+ void setLowerBound(Value *lb) { setOperand(0, lb); }
+
+ /// Set upper bound.
+ void setUpperBound(Value *ub) { setOperand(1, ub); }
+
+ /// Set loop step.
+ void setStep(Value *step) { setOperand(2, step); }
+};
+
+/// Returns the loop parent of an induction variable. If the provided value is
+/// not an induction variable, then return nullptr.
+ForOp getForInductionVarOwner(Value *val);
+
/// A linalg.LoadOp is the counterpart of load but operating on ViewType
/// instead of MemRefType.
///
////////////////////////////////////////////////////////////////////////////////
// ForOp.
////////////////////////////////////////////////////////////////////////////////
+// Check that if a "block" has a terminator, it is an `TerminatorOp`.
+static LogicalResult checkHasTerminator(OpState &op, Block &block) {
+ if (block.empty() || isa<TerminatorOp>(block.back()))
+ return success();
+
+ op.emitOpError("expects regions to end with '" +
+ TerminatorOp::getOperationName() + "'")
+ .attachNote()
+ << "in custom textual format, the absence of terminator implies '"
+ << TerminatorOp::getOperationName() << "'";
+ return failure();
+}
+
+// Insert `linalg.terminator` at the end of the ForOp only region's only block
+// if it does not have a terminator already. If a new `linalg.terminator` is
+// inserted, the location is specified by `loc`. If the region is empty, insert
+// a new block first.
+static void ensureTerminator(Region ®ion, Builder &builder, Location loc) {
+ if (region.empty())
+ region.push_back(new Block);
+
+ Block &block = region.back();
+ if (!block.empty() && block.back().isKnownTerminator())
+ return;
+
+ OperationState terminatorState(builder.getContext(), loc,
+ TerminatorOp::getOperationName());
+ TerminatorOp::build(&builder, &terminatorState);
+ block.push_back(Operation::create(terminatorState));
+}
+
+void mlir::linalg::ForOp::build(Builder *builder, OperationState *result,
+ Value *lb, Value *ub, Value *step) {
+ result->addOperands({lb, ub, step});
+ Region *bodyRegion = result->addRegion();
+ Block *body = new Block();
+ body->addArgument(IndexType::get(builder->getContext()));
+ bodyRegion->push_back(body);
+ ensureTerminator(*bodyRegion, *builder, result->location);
+}
+
+LogicalResult mlir::linalg::ForOp::verify() {
+ if (!getLowerBound()->getType().isa<IndexType>())
+ return emitOpError("lower bound operand must be an index");
+ if (!getUpperBound()->getType().isa<IndexType>())
+ return emitOpError("upper bound operand must be an index");
+ if (!getLowerBound()->getType().dyn_cast<IndexType>())
+ return emitOpError("step operand must be an index");
+ if (auto cst =
+ dyn_cast_or_null<ConstantIndexOp>(getLowerBound()->getDefiningOp()))
+ if (cst.getValue() <= 0)
+ return emitOpError("constant step operand must be positive");
+
+ if (std::next(getOperation()->getRegions().begin()) !=
+ getOperation()->getRegions().end())
+ return emitOpError("operation expected to have exactly one region");
+
+ auto &bodyRegion = getOperation()->getRegion(0);
+ // The body region must contain a single basic block.
+ if (bodyRegion.empty() || std::next(bodyRegion.begin()) != bodyRegion.end())
+ return emitOpError("expected body region to have a single block");
+ // Check that the body defines as single block argument for the induction
+ // variable.
+ auto *body = getBody();
+ if (body->getNumArguments() != 1 ||
+ !body->getArgument(0)->getType().isIndex())
+ return emitOpError("expected body to have a single index argument for "
+ "the induction variable");
+ if (failed(checkHasTerminator(*this, *body)))
+ return failure();
+ return success();
+}
+
+void mlir::linalg::ForOp::print(OpAsmPrinter *p) {
+ *p << getOperationName() << " " << *getInductionVar() << " = "
+ << *getLowerBound() << " to " << *getUpperBound() << " step "
+ << *getStep();
+ p->printRegion(getRegion(),
+ /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/false);
+ p->printOptionalAttrDict(getAttrs());
+}
+
+ParseResult mlir::linalg::ForOp::parse(OpAsmParser *parser,
+ OperationState *result) {
+ auto &builder = parser->getBuilder();
+ OpAsmParser::OperandType inductionVariable, lb, ub, step;
+ // Parse the induction variable followed by '='.
+ if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual())
+ return failure();
+
+ // Parse loop bounds.
+ Type indexType = builder.getIndexType();
+ if (parser->parseOperand(lb) ||
+ parser->resolveOperand(lb, indexType, result->operands) ||
+ parser->parseKeyword("to") || parser->parseOperand(ub) ||
+ parser->resolveOperand(ub, indexType, result->operands) ||
+ parser->parseKeyword("step") || parser->parseOperand(step) ||
+ parser->resolveOperand(step, indexType, result->operands))
+ return failure();
+
+ // Parse the body region.
+ Region *body = result->addRegion();
+ if (parser->parseRegion(*body, inductionVariable, indexType))
+ return failure();
+
+ ensureTerminator(*body, builder, result->location);
+
+ // Parse the optional attribute list.
+ if (parser->parseOptionalAttributeDict(result->attributes))
+ return failure();
+
+ return success();
+}
+
+mlir::linalg::ForOp mlir::linalg::getForInductionVarOwner(Value *val) {
+ auto *ivArg = dyn_cast<BlockArgument>(val);
+ if (!ivArg)
+ return ForOp();
+ assert(ivArg->getOwner() && "unlinked block argument");
+ auto *containingInst = ivArg->getOwner()->getContainingOp();
+ return dyn_cast_or_null<ForOp>(containingInst);
+}
////////////////////////////////////////////////////////////////////////////////
// LoadOp.
// CHECK-LABEL: func @range_intersect(%arg0: !linalg.range, %arg1: !linalg.range) -> !linalg.range {
// CHECK-NEXT: %0 = linalg.range_intersect %arg0, %arg1 : !linalg.range
// CHECK-NEXT: return %0 : !linalg.range
+
+func @linalg_for(%arg0 : index, %arg1 : index, %arg2 : index) {
+ linalg.for %i0 = %arg0 to %arg1 step %arg2 {
+ linalg.for %i1 = %arg0 to %arg1 step %arg2 {
+ %min_cmp = cmpi "slt", %i0, %i1 : index
+ %min = select %min_cmp, %i0, %i1 : index
+ %max_cmp = cmpi "sge", %i0, %i1 : index
+ %max = select %max_cmp, %i0, %i1 : index
+ linalg.for %i2 = %min to %max step %i1 {
+ }
+ }
+ }
+ return
+}
+// CHECK-LABEL: func @linalg_for(%arg0: index, %arg1: index, %arg2: index) {
+// CHECK-NEXT: linalg.for %i0 = %arg0 to %arg1 step %arg2 {
+// CHECK-NEXT: linalg.for %i1 = %arg0 to %arg1 step %arg2 {
+// CHECK-NEXT: %0 = cmpi "slt", %i0, %i1 : index
+// CHECK-NEXT: %1 = select %0, %i0, %i1 : index
+// CHECK-NEXT: %2 = cmpi "sge", %i0, %i1 : index
+// CHECK-NEXT: %3 = select %2, %i0, %i1 : index
+// CHECK-NEXT: linalg.for %i2 = %1 to %3 step %i1 {