Add a linalg.for operation to support non-affine loop constructs
authorNicolas Vasilache <ntv@google.com>
Tue, 28 May 2019 22:05:51 +0000 (15:05 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:06:21 +0000 (20:06 -0700)
    The affine.for operation has restrictions that make it suitable for dependence analysis. The Linalg dialect aims at being more general.
    This CL introduces linalg.for, and its associated terminator, along with a simple roundtripping test.
    A `linalg.for` only takes one value of index type for lower bound, upper bound and step.

    Example usage:
    ```
    linalg.for %iv = %lb to %ub step %step {
      ... // body
    }
    ```

--

PiperOrigin-RevId: 250369722

mlir/include/mlir/Linalg/IR/LinalgOps.h
mlir/include/mlir/Linalg/IR/LinalgOps.td
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Linalg/IR/LinalgTypes.cpp
mlir/test/Linalg/roundtrip.mlir

index 26a5cf3..92f2630 100644 (file)
@@ -77,6 +77,82 @@ public:
   }
 };
 
+/// The "linalg.for" operation represents a loop nest taking 3 SSA value as
+/// operands that represent the lower bound, upper bound and step respectively.
+/// The operation defines an SSA value for its induction variable. It has one
+/// region capturing the loop body. The induction variable is represented as an
+/// argument of this region. This SSA value always has type index, which is the
+/// size of the machine word. The step is a value of type index, required to be
+/// positive.
+/// The lower and upper bounds specify a half-open range: the range includes the
+/// lower bound but does not include the upper bound.
+///
+/// The body region must contain exactly one block that terminates with
+/// "linalg.terminator".  Calling linalg::ForOp::build will create such region
+/// and insert the terminator, so will the parsing even in cases if it is absent
+/// from the custom format. For example:
+///
+/// ```mlir
+///    linalg.for %iv = %lb to %ub step %step {
+///      ... // body
+///    }
+/// ```
+class ForOp
+    : public Op<ForOp, OpTrait::NOperands<3>::Impl, OpTrait::ZeroResult> {
+public:
+  using Op::Op;
+
+  // Hooks to customize behavior of this op.
+  static void build(Builder *builder, OperationState *result, Value *lb,
+                    Value *ub, Value *step);
+  LogicalResult verify();
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+
+  static StringRef getOperationName() { return "linalg.for"; }
+
+  /// Return a Builder set up to insert operations immediately before the
+  /// terminator.
+  FuncBuilder getBodyBuilder() {
+    Block *body = getBody();
+    return FuncBuilder(body, std::prev(body->end()));
+  }
+
+  /// Get the body of the ForOp.
+  Block *getBody() { return &getRegion().front(); }
+
+  /// Get the body region of the ForOp.
+  Region &getRegion() { return getOperation()->getRegion(0); }
+
+  /// Returns the induction variable for this loop.
+  Value *getInductionVar() { return getBody()->getArgument(0); }
+
+  //===--------------------------------------------------------------------===//
+  // Bounds and step
+  //===--------------------------------------------------------------------===//
+  /// Returns the lower bound operand.
+  Value *getLowerBound() { return getOperand(0); }
+
+  /// Returns the upper bound operand.
+  Value *getUpperBound() { return getOperand(1); }
+
+  /// Returns loop step.
+  Value *getStep() { return getOperand(2); }
+
+  /// Set lower bound.
+  void setLowerBound(Value *lb) { setOperand(0, lb); }
+
+  /// Set upper bound.
+  void setUpperBound(Value *ub) { setOperand(1, ub); }
+
+  /// Set loop step.
+  void setStep(Value *step) { setOperand(2, step); }
+};
+
+/// Returns the loop parent of an induction variable. If the provided value is
+/// not an induction variable, then return nullptr.
+ForOp getForInductionVarOwner(Value *val);
+
 /// A linalg.LoadOp is the counterpart of load but operating on ViewType
 /// instead of MemRefType.
 ///
index 824db9a..d9b0f65 100644 (file)
@@ -112,4 +112,26 @@ def RangeIntersectOp : Linalg_Op<"range_intersect", [NoSideEffect]>,
     }]>];
 }
 
+def TerminatorOp :
+    Linalg_Op<"terminator", [NativeOpTrait<"IsTerminator">]> {
+  let summary = "linalg terminator operation";
+  let description = [{
+    "linalg.terminator" is a special terminator operation for blocks inside
+    linalg loops and branches. It unconditionally transmits the control flow to
+    the successor of the operation enclosing the region.
+
+    This operation does _not_ have a custom syntax. However, linalg control
+    operations omit the terminator in their custom syntax for brevity.
+
+       linalg.terminator
+  }];
+
+  // No custom parsing/printing form.
+  let parser = ?;
+  let printer = ?;
+
+  // Fully specified by traits.
+  let verifier = ?;
+}
+
 #endif // LINALG_OPS
index 24a47da..f222fce 100644 (file)
@@ -112,6 +112,129 @@ ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser,
 ////////////////////////////////////////////////////////////////////////////////
 // ForOp.
 ////////////////////////////////////////////////////////////////////////////////
+// Check that if a "block" has a terminator, it is an `TerminatorOp`.
+static LogicalResult checkHasTerminator(OpState &op, Block &block) {
+  if (block.empty() || isa<TerminatorOp>(block.back()))
+    return success();
+
+  op.emitOpError("expects regions to end with '" +
+                 TerminatorOp::getOperationName() + "'")
+          .attachNote()
+      << "in custom textual format, the absence of terminator implies '"
+      << TerminatorOp::getOperationName() << "'";
+  return failure();
+}
+
+// Insert `linalg.terminator` at the end of the ForOp only region's only block
+// if it does not have a terminator already.  If a new `linalg.terminator` is
+// inserted, the location is specified by `loc`. If the region is empty, insert
+// a new block first.
+static void ensureTerminator(Region &region, Builder &builder, Location loc) {
+  if (region.empty())
+    region.push_back(new Block);
+
+  Block &block = region.back();
+  if (!block.empty() && block.back().isKnownTerminator())
+    return;
+
+  OperationState terminatorState(builder.getContext(), loc,
+                                 TerminatorOp::getOperationName());
+  TerminatorOp::build(&builder, &terminatorState);
+  block.push_back(Operation::create(terminatorState));
+}
+
+void mlir::linalg::ForOp::build(Builder *builder, OperationState *result,
+                                Value *lb, Value *ub, Value *step) {
+  result->addOperands({lb, ub, step});
+  Region *bodyRegion = result->addRegion();
+  Block *body = new Block();
+  body->addArgument(IndexType::get(builder->getContext()));
+  bodyRegion->push_back(body);
+  ensureTerminator(*bodyRegion, *builder, result->location);
+}
+
+LogicalResult mlir::linalg::ForOp::verify() {
+  if (!getLowerBound()->getType().isa<IndexType>())
+    return emitOpError("lower bound operand must be an index");
+  if (!getUpperBound()->getType().isa<IndexType>())
+    return emitOpError("upper bound operand must be an index");
+  if (!getLowerBound()->getType().dyn_cast<IndexType>())
+    return emitOpError("step operand must be an index");
+  if (auto cst =
+          dyn_cast_or_null<ConstantIndexOp>(getLowerBound()->getDefiningOp()))
+    if (cst.getValue() <= 0)
+      return emitOpError("constant step operand must be positive");
+
+  if (std::next(getOperation()->getRegions().begin()) !=
+      getOperation()->getRegions().end())
+    return emitOpError("operation expected to have exactly one region");
+
+  auto &bodyRegion = getOperation()->getRegion(0);
+  // The body region must contain a single basic block.
+  if (bodyRegion.empty() || std::next(bodyRegion.begin()) != bodyRegion.end())
+    return emitOpError("expected body region to have a single block");
+  // Check that the body defines as single block argument for the induction
+  // variable.
+  auto *body = getBody();
+  if (body->getNumArguments() != 1 ||
+      !body->getArgument(0)->getType().isIndex())
+    return emitOpError("expected body to have a single index argument for "
+                       "the induction variable");
+  if (failed(checkHasTerminator(*this, *body)))
+    return failure();
+  return success();
+}
+
+void mlir::linalg::ForOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *getInductionVar() << " = "
+     << *getLowerBound() << " to " << *getUpperBound() << " step "
+     << *getStep();
+  p->printRegion(getRegion(),
+                 /*printEntryBlockArgs=*/false,
+                 /*printBlockTerminators=*/false);
+  p->printOptionalAttrDict(getAttrs());
+}
+
+ParseResult mlir::linalg::ForOp::parse(OpAsmParser *parser,
+                                       OperationState *result) {
+  auto &builder = parser->getBuilder();
+  OpAsmParser::OperandType inductionVariable, lb, ub, step;
+  // Parse the induction variable followed by '='.
+  if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual())
+    return failure();
+
+  // Parse loop bounds.
+  Type indexType = builder.getIndexType();
+  if (parser->parseOperand(lb) ||
+      parser->resolveOperand(lb, indexType, result->operands) ||
+      parser->parseKeyword("to") || parser->parseOperand(ub) ||
+      parser->resolveOperand(ub, indexType, result->operands) ||
+      parser->parseKeyword("step") || parser->parseOperand(step) ||
+      parser->resolveOperand(step, indexType, result->operands))
+    return failure();
+
+  // Parse the body region.
+  Region *body = result->addRegion();
+  if (parser->parseRegion(*body, inductionVariable, indexType))
+    return failure();
+
+  ensureTerminator(*body, builder, result->location);
+
+  // Parse the optional attribute list.
+  if (parser->parseOptionalAttributeDict(result->attributes))
+    return failure();
+
+  return success();
+}
+
+mlir::linalg::ForOp mlir::linalg::getForInductionVarOwner(Value *val) {
+  auto *ivArg = dyn_cast<BlockArgument>(val);
+  if (!ivArg)
+    return ForOp();
+  assert(ivArg->getOwner() && "unlinked block argument");
+  auto *containingInst = ivArg->getOwner()->getContainingOp();
+  return dyn_cast_or_null<ForOp>(containingInst);
+}
 
 ////////////////////////////////////////////////////////////////////////////////
 // LoadOp.
index 0e20eb8..453463b 100644 (file)
@@ -31,7 +31,7 @@ using namespace mlir::linalg;
 mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
     : Dialect("linalg", context) {
   addTypes<BufferType, RangeType, ViewType>();
-  addOperations<BufferAllocOp, BufferDeallocOp, LoadOp, RangeOp, StoreOp,
+  addOperations<BufferAllocOp, BufferDeallocOp, ForOp, LoadOp, RangeOp, StoreOp,
                 SliceOp, ViewOp>();
   addOperations<
 #define GET_OP_LIST
index c42b8d7..d730a20 100644 (file)
@@ -70,3 +70,25 @@ func @range_intersect(%arg0: !linalg.range, %arg1: !linalg.range) -> !linalg.ran
 // CHECK-LABEL: func @range_intersect(%arg0: !linalg.range, %arg1: !linalg.range) -> !linalg.range {
 //  CHECK-NEXT:   %0 = linalg.range_intersect %arg0, %arg1 : !linalg.range
 //  CHECK-NEXT:   return %0 : !linalg.range
+
+func @linalg_for(%arg0 : index, %arg1 : index, %arg2 : index) {
+  linalg.for %i0 = %arg0 to %arg1 step %arg2 {
+    linalg.for %i1 = %arg0 to %arg1 step %arg2 {
+      %min_cmp = cmpi "slt", %i0, %i1 : index
+      %min = select %min_cmp, %i0, %i1 : index
+      %max_cmp = cmpi "sge", %i0, %i1 : index
+      %max = select %max_cmp, %i0, %i1 : index
+      linalg.for %i2 = %min to %max step %i1 {
+      }
+    }
+  }
+  return
+}
+// CHECK-LABEL: func @linalg_for(%arg0: index, %arg1: index, %arg2: index) {
+//  CHECK-NEXT:   linalg.for %i0 = %arg0 to %arg1 step %arg2 {
+//  CHECK-NEXT:     linalg.for %i1 = %arg0 to %arg1 step %arg2 {
+//  CHECK-NEXT:       %0 = cmpi "slt", %i0, %i1 : index
+//  CHECK-NEXT:       %1 = select %0, %i0, %i1 : index
+//  CHECK-NEXT:       %2 = cmpi "sge", %i0, %i1 : index
+//  CHECK-NEXT:       %3 = select %2, %i0, %i1 : index
+//  CHECK-NEXT:       linalg.for %i2 = %1 to %3 step %i1 {