}
};
-/// The "linalg.view" op produces a linalg.view which is a multi-dimensional
-/// range abstraction on top of an underlying linalg.buffer. This gives an
-/// indexing structure to an otherwise non-indexable linalg.buffer.
-///
-/// A "linalg.view" takes a buffer and a variadic number of ranges and produces
-/// a `view` of the same elemental type as the buffer and of rank the number of
-/// ranges:
-///
-/// ```{.mlir}
-/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
-/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
-/// %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
-/// ```
-class ViewOp : public Op<ViewOp, OpTrait::VariadicOperands, OpTrait::OneResult,
- OpTrait::HasNoSideEffect> {
- enum { FirstIndexingOperand = 1 };
-
-public:
- using Op::Op;
-
- // Hooks to customize the behavior of this op.
- static llvm::StringRef getOperationName() { return "linalg.view"; }
- static void build(Builder *b, OperationState *result, Value *buffer,
- llvm::ArrayRef<Value *> indexings);
- LogicalResult verify();
- static ParseResult parse(OpAsmParser *parser, OperationState *result);
- void print(OpAsmPrinter *p);
-
- // Op-specific functionality.
- unsigned getRank() { return getViewType().getRank(); }
- Type getElementType() { return getViewType().getElementType(); }
- ViewType getViewType() { return getType().cast<ViewType>(); }
- Value *getSupportingBuffer() { return getOperand(0); }
- // Get the underlying indexing at a given rank.
- Value *getIndexing(unsigned rank) { return *(getIndexings().begin() + rank); }
- // Get all the indexings in this view.
- Operation::operand_range getIndexings() {
- return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
- }
-};
-
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.h.inc"
}];
}
+def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
+ Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
+ Results<(outs View)> {
+ let summary = "view operation";
+ let description = [{
+ The "linalg.view" op produces a linalg.view which is a multi-dimensional
+ range abstraction on top of an underlying linalg.buffer. This gives an
+ indexing structure to an otherwise non-indexable linalg.buffer.
+
+ A "linalg.view" takes a buffer and a variadic number of ranges and produces
+ a `view` of rank the number of ranges. The elemental type may not match the
+ buffer element type:
+
+ Examples:
+ ```
+ %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+ %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
+ %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xvector<4xf32>>
+ ```
+ }];
+
+ let builders = [OpBuilder<
+ "Builder *b, OperationState *result, Value *buffer, "
+ "ArrayRef<Value *> ranges, Type resultType = Type(), "
+ "ArrayRef<NamedAttribute> attrs = {}">];
+
+ let verifier = [{
+ if (getViewType().getRank() != llvm::size(ranges()))
+ return emitOpError("the view rank must be the number of its ranges");
+ return success();
+ }];
+
+ let extraClassDeclaration = [{
+ enum { FirstIndexingOperand = 1 };
+ unsigned getRank() { return getViewType().getRank(); }
+ Type getElementType() { return getViewType().getElementType(); }
+ ViewType getViewType() { return getType().cast<ViewType>(); }
+ /// Get the underlying indexing at a given rank.
+ Value *getRange(unsigned rank) {
+ assert(rank < getRank() && "rank overflow");
+ return *(ranges().begin() + rank);
+ }
+ }];
+}
+
def YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
Arguments<(ins Variadic<AnyType>:$values)> {
let summary = "Linalg yield operation";
Value *min, *max, *step;
if (view) {
// Cannot traverse block arguments, fail.
- if (isa<BlockArgument>(view.getIndexing(dim)))
+ if (isa<BlockArgument>(view.getRange(dim)))
return matchFailure();
// Record min, max, step for further processing.
- auto range = cast<RangeOp>(view.getIndexing(dim)->getDefiningOp());
+ auto range = cast<RangeOp>(view.getRange(dim)->getDefiningOp());
std::tie(min, max, step) =
std::make_tuple(range.min(), range.max(), range.step());
} else if (subView) {
return success();
}
-//////////////////////////////////////////////////////////////////////////////
-// ViewOp
-//////////////////////////////////////////////////////////////////////////////
-void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
- Value *buffer, ArrayRef<Value *> indexings) {
- BufferType bufferType = buffer->getType().cast<BufferType>();
- result->addOperands({buffer});
- result->addOperands(indexings);
- assert(
- std::none_of(indexings.begin(), indexings.end(),
- [](Value *v) { return !v->getType().isa<RangeType>(); }) &&
- "linalg.view takes only arguments of type linalg.range");
-
- Type elementType = bufferType.getElementType();
- result->addTypes(
- {ViewType::get(b->getContext(), elementType, indexings.size())});
-}
-
-LogicalResult mlir::linalg::ViewOp::verify() {
- if (llvm::empty(getOperands()))
- return emitOpError(
- "requires at least a buffer operand followed by indexings");
- auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
- if (!bufferType)
- return emitOpError("first operand must be of BufferType");
- unsigned index = 0;
- for (auto indexing : getIndexings()) {
- if (!indexing->getType().isa<RangeType>()) {
- return emitOpError() << index << "^th index must be of range type";
- }
- ++index;
- }
- if (getViewType().getRank() != index)
- return emitOpError()
- << "the rank of the view must be the number of its indexings";
- return success();
-}
-
-ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
- OperationState *result) {
- OpAsmParser::OperandType bufferInfo;
- SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
- Type bType, type;
- if (parser->parseOperand(bufferInfo) ||
- parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColon() || parser->parseType(bType) ||
- parser->parseArrow() || parser->parseType(type)) {
- return failure();
- }
-
- BufferType bufferType = bType.dyn_cast<BufferType>();
- if (!bufferType) {
- return parser->emitError(parser->getNameLoc(), "buffer type expected");
- }
-
- ViewType viewType = type.dyn_cast<ViewType>();
- if (!viewType)
- return parser->emitError(parser->getNameLoc(), "view type expected");
- if (viewType.getRank() != indexingsInfo.size())
- return parser->emitError(parser->getNameLoc(), "expected")
- << viewType.getRank() << " range indexings";
- return failure(
- parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
- (!indexingsInfo.empty() &&
- parser->resolveOperands(indexingsInfo, RangeType::get(type.getContext()),
- result->operands)) ||
- parser->addTypeToList(viewType, result->types));
-}
-
-// A ViewOp prints as:
-//
-// ```{.mlir}
-// linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-// ```
-//
-// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
-// holding a range.
-void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
- *p << getOperationName() << " " << *getSupportingBuffer() << "[";
- interleave(
- getIndexings().begin(), getIndexings().end(), [&](Value *v) { *p << *v; },
- [&]() { *p << ", "; });
- *p << "] : " << getSupportingBuffer()->getType() << " -> " << getType();
-}
-
///////////////////// Operations defined with Tablegen /////////////////////////
// For such operations that do not correspond to library calls (i.e. defined in
// LinalgOps.td), we define an overloaded `print` function and a
// parse`className` function.
+//===----------------------------------------------------------------------===//
+// BufferAllocOp
+//===----------------------------------------------------------------------===//
+
static void print(OpAsmPrinter *p, BufferAllocOp op) {
*p << op.getOperationName() << " ";
if (!llvm::empty(op.size()))
return success();
}
+//===----------------------------------------------------------------------===//
+// BufferDeallocOp
+//===----------------------------------------------------------------------===//
+
static void print(OpAsmPrinter *p, BufferDeallocOp op) {
*p << op.getOperationName() << " " << *op.buffer();
p->printOptionalAttrDict(op.getAttrs());
*p << " : " << op.getOperand()->getType();
}
+//===----------------------------------------------------------------------===//
+// BufferSizeOp
+//===----------------------------------------------------------------------===//
+
static ParseResult parseBufferSizeOp(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType op;
}
//===----------------------------------------------------------------------===//
+// ViewOp
+//===----------------------------------------------------------------------===//
+void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
+ Value *buffer, ArrayRef<Value *> ranges,
+ Type resultType,
+ ArrayRef<NamedAttribute> attrs) {
+ if (!resultType) {
+ Type elementType = buffer->getType().cast<BufferType>().getElementType();
+ resultType = ViewType::get(b->getContext(), elementType, ranges.size());
+ }
+ build(b, result, resultType, buffer, ranges);
+ result->addAttributes(attrs);
+}
+
+static ParseResult parseViewOp(OpAsmParser *parser, OperationState *result) {
+ OpAsmParser::OperandType bufferInfo;
+ SmallVector<OpAsmParser::OperandType, 8> rangesInfo;
+ Type bType, vType;
+ if (parser->parseOperand(bufferInfo) ||
+ parser->parseOperandList(rangesInfo, OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColon() || parser->parseType(bType) ||
+ parser->parseArrow() || parser->parseType(vType)) {
+ return failure();
+ }
+
+ BufferType bufferType = bType.dyn_cast<BufferType>();
+ if (!bufferType) {
+ return parser->emitError(parser->getNameLoc(), "buffer type expected");
+ }
+
+ ViewType viewType = vType.dyn_cast<ViewType>();
+ if (!viewType)
+ return parser->emitError(parser->getNameLoc(), "view type expected");
+ if (viewType.getRank() != rangesInfo.size())
+ return parser->emitError(parser->getNameLoc(), "expected")
+ << viewType.getRank() << " range ranges";
+ return failure(
+ parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
+ (!rangesInfo.empty() &&
+ parser->resolveOperands(rangesInfo, RangeType::get(vType.getContext()),
+ result->operands)) ||
+ parser->addTypeToList(viewType, result->types));
+}
+
+// A ViewOp prints as:
+//
+// ```{.mlir}
+// linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+// ```
+//
+// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
+// holding a range.
+static void print(OpAsmPrinter *p, ViewOp op) {
+ *p << op.getOperationName() << " " << *op.buffer() << "[";
+ interleaveComma(op.ranges(), *p, [&](Value *v) { *p << *v; });
+ *p << "] : " << op.buffer()->getType() << " -> " << op.getType();
+}
+
+//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
*p << " : " << op.getViewType();
}
+//===----------------------------------------------------------------------===//
+// SubViewOp
+//===----------------------------------------------------------------------===//
+
static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType inputView, resultView;
Type viewType;