MLIRContext *context);
};
-/// The "load" op reads an element from a memref specified by an index list. The
-/// output of load is a new value with the same type as the elements of the
-/// memref. The arity of indices is the rank of the memref (i.e., if the memref
-/// loaded from is of rank 3, then 3 indices are required for the load following
-/// the memref identifier). For example:
-///
-/// %3 = load %0[%1, %1] : memref<4x4xi32>
-///
-class LoadOp
- : public Op<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
-public:
- using Op::Op;
-
- // Hooks to customize behavior of this op.
- static void build(Builder *builder, OperationState *result, Value *memref,
- ArrayRef<Value *> indices = {});
-
- Value *getMemRef() { return getOperand(0); }
- void setMemRef(Value *value) { setOperand(0, value); }
- MemRefType getMemRefType() {
- return getMemRef()->getType().cast<MemRefType>();
- }
-
- operand_range getIndices() {
- return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
- }
-
- static StringRef getOperationName() { return "std.load"; }
-
- LogicalResult verify();
- static ParseResult parse(OpAsmParser *parser, OperationState *result);
- void print(OpAsmPrinter *p);
- static void getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context);
-};
-
-/// The "store" op writes an element to a memref specified by an index list.
-/// The arity of indices is the rank of the memref (i.e. if the memref being
-/// stored to is of rank 3, then 3 indices are required for the store following
-/// the memref identifier). The store operation does not produce a result.
-///
-/// In the following example, the ssa value '%v' is stored in memref '%A' at
-/// indices [%i, %j]:
-///
-/// store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0>
-///
-class StoreOp
- : public Op<StoreOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
-public:
- using Op::Op;
-
- // Hooks to customize behavior of this op.
- static void build(Builder *builder, OperationState *result,
- Value *valueToStore, Value *memref,
- ArrayRef<Value *> indices = {});
-
- Value *getValueToStore() { return getOperand(0); }
-
- Value *getMemRef() { return getOperand(1); }
- void setMemRef(Value *value) { setOperand(1, value); }
- MemRefType getMemRefType() {
- return getMemRef()->getType().cast<MemRefType>();
- }
-
- operand_range getIndices() {
- return {getOperation()->operand_begin() + 2, getOperation()->operand_end()};
- }
-
- static StringRef getOperationName() { return "std.store"; }
-
- LogicalResult verify();
- static ParseResult parse(OpAsmParser *parser, OperationState *result);
- void print(OpAsmPrinter *p);
-
- static void getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context);
-};
-
/// Prints dimension and symbol list.
void printDimAndSymbolList(Operation::operand_iterator begin,
Operation::operand_iterator end, unsigned numDims,
let hasFolder = 1;
}
+def LoadOp : Std_Op<"load"> {
+ let summary = "load operation";
+ let description = [{
+ The "load" op reads an element from a memref specified by an index list. The
+ output of load is a new value with the same type as the elements of the
+ memref. The arity of indices is the rank of the memref (i.e., if the memref
+ loaded from is of rank 3, then 3 indices are required for the load following
+ the memref identifier). For example:
+
+ %3 = load %0[%1, %1] : memref<4x4xi32>
+ }];
+
+ let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices);
+ let results = (outs AnyType);
+
+ let builders = [OpBuilder<
+ "Builder *, OperationState *result, Value *memref,"
+ "ArrayRef<Value *> indices = {}", [{
+ auto memrefType = memref->getType().cast<MemRefType>();
+ result->addOperands(memref);
+ result->addOperands(indices);
+ result->types.push_back(memrefType.getElementType());
+ }]>];
+
+ let extraClassDeclaration = [{
+ Value *getMemRef() { return getOperand(0); }
+ void setMemRef(Value *value) { setOperand(0, value); }
+ MemRefType getMemRefType() {
+ return getMemRef()->getType().cast<MemRefType>();
+ }
+
+ operand_range getIndices() {
+ return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
+ }
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
def MemRefCastOp : CastOp<"memref_cast"> {
let summary = "memref cast operation";
let description = [{
let hasFolder = 1;
}
+def StoreOp : Std_Op<"store"> {
+ let summary = "store operation";
+ let description = [{
+ The "store" op writes an element to a memref specified by an index list.
+ The arity of indices is the rank of the memref (i.e. if the memref being
+ stored to is of rank 3, then 3 indices are required for the store following
+ the memref identifier). The store operation does not produce a result.
+
+ In the following example, the ssa value '%v' is stored in memref '%A' at
+ indices [%i, %j]:
+ store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0>
+ }];
+
+ let arguments = (ins AnyType:$value, AnyMemRef:$memref, Variadic<Index>:$indices);
+
+ let builders = [OpBuilder<
+ "Builder *, OperationState *result, Value *valueToStore, Value *memref", [{
+ result->addOperands(valueToStore);
+ result->addOperands(memref);
+ }]>];
+
+ let extraClassDeclaration = [{
+ Value *getValueToStore() { return getOperand(0); }
+
+ Value *getMemRef() { return getOperand(1); }
+ void setMemRef(Value *value) { setOperand(1, value); }
+ MemRefType getMemRefType() {
+ return getMemRef()->getType().cast<MemRefType>();
+ }
+
+ operand_range getIndices() {
+ return {getOperation()->operand_begin() + 2, getOperation()->operand_end()};
+ }
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
def TensorCastOp : CastOp<"tensor_cast"> {
let summary = "tensor cast operation";
let description = [{
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
: Dialect(/*name=*/"std", context) {
- addOperations<CondBranchOp, DmaStartOp, DmaWaitOp, LoadOp, StoreOp,
+ addOperations<CondBranchOp, DmaStartOp, DmaWaitOp,
#define GET_OP_LIST
#include "mlir/StandardOps/Ops.cpp.inc"
>();
// LoadOp
//===----------------------------------------------------------------------===//
-void LoadOp::build(Builder *builder, OperationState *result, Value *memref,
- ArrayRef<Value *> indices) {
- auto memrefType = memref->getType().cast<MemRefType>();
- result->addOperands(memref);
- result->addOperands(indices);
- result->types.push_back(memrefType.getElementType());
-}
-
-void LoadOp::print(OpAsmPrinter *p) {
- *p << "load " << *getMemRef() << '[';
- p->printOperands(getIndices());
+static void print(OpAsmPrinter *p, LoadOp op) {
+ *p << "load " << *op.getMemRef() << '[';
+ p->printOperands(op.getIndices());
*p << ']';
- p->printOptionalAttrDict(getAttrs());
- *p << " : " << getMemRefType();
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.getMemRefType();
}
-ParseResult LoadOp::parse(OpAsmParser *parser, OperationState *result) {
+static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType type;
parser->addTypeToList(type.getElementType(), result->types));
}
-LogicalResult LoadOp::verify() {
- if (getNumOperands() == 0)
- return emitOpError("expected a memref to load from");
+static LogicalResult verify(LoadOp op) {
+ if (op.getType() != op.getMemRefType().getElementType())
+ return op.emitOpError("result type must match element type of memref");
- auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
- if (!memRefType)
- return emitOpError("first operand must be a memref");
-
- if (getType() != memRefType.getElementType())
- return emitOpError("result type must match element type of memref");
-
- if (memRefType.getRank() != getNumOperands() - 1)
- return emitOpError("incorrect number of indices for load");
+ if (op.getMemRefType().getRank() != op.getNumOperands() - 1)
+ return op.emitOpError("incorrect number of indices for load");
- for (auto *idx : getIndices())
+ for (auto *idx : op.getIndices())
if (!idx->getType().isIndex())
- return emitOpError("index to load must have 'index' type");
+ return op.emitOpError("index to load must have 'index' type");
// TODO: Verify we have the right number of indices.
// StoreOp
//===----------------------------------------------------------------------===//
-void StoreOp::build(Builder *builder, OperationState *result,
- Value *valueToStore, Value *memref,
- ArrayRef<Value *> indices) {
- result->addOperands(valueToStore);
- result->addOperands(memref);
- result->addOperands(indices);
-}
-
-void StoreOp::print(OpAsmPrinter *p) {
- *p << "store " << *getValueToStore();
- *p << ", " << *getMemRef() << '[';
- p->printOperands(getIndices());
+static void print(OpAsmPrinter *p, StoreOp op) {
+ *p << "store " << *op.getValueToStore();
+ *p << ", " << *op.getMemRef() << '[';
+ p->printOperands(op.getIndices());
*p << ']';
- p->printOptionalAttrDict(getAttrs());
- *p << " : " << getMemRefType();
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.getMemRefType();
}
-ParseResult StoreOp::parse(OpAsmParser *parser, OperationState *result) {
+static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
parser->resolveOperands(indexInfo, affineIntTy, result->operands));
}
-LogicalResult StoreOp::verify() {
- if (getNumOperands() < 2)
- return emitOpError("expected a value to store and a memref");
-
- // Second operand is a memref type.
- auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
- if (!memRefType)
- return emitOpError("second operand must be a memref");
-
+static LogicalResult verify(StoreOp op) {
// First operand must have same type as memref element type.
- if (getValueToStore()->getType() != memRefType.getElementType())
- return emitOpError("first operand must have same type memref element type");
+ if (op.getValueToStore()->getType() != op.getMemRefType().getElementType())
+ return op.emitOpError(
+ "first operand must have same type memref element type");
- if (getNumOperands() != 2 + memRefType.getRank())
- return emitOpError("store index operand count not equal to memref rank");
+ if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
+ return op.emitOpError("store index operand count not equal to memref rank");
- for (auto *idx : getIndices())
+ for (auto *idx : op.getIndices())
if (!idx->getType().isIndex())
- return emitOpError("index to load must have 'index' type");
+ return op.emitOpError("index to load must have 'index' type");
// TODO: Verify we have the right number of indices.