From: River Riddle Date: Sat, 25 May 2019 01:01:38 +0000 (-0700) Subject: Move the definitions of LoadOp and StoreOp to the ODG framework. X-Git-Tag: llvmorg-11-init~1466^2~1603 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5a5cdb94fe73d4e492e67e46d96b7d793f388def;p=platform%2Fupstream%2Fllvm.git Move the definitions of LoadOp and StoreOp to the ODG framework. -- PiperOrigin-RevId: 249928980 --- diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index 18008f2..b399fbe 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -450,84 +450,6 @@ public: MLIRContext *context); }; -/// The "load" op reads an element from a memref specified by an index list. The -/// output of load is a new value with the same type as the elements of the -/// memref. The arity of indices is the rank of the memref (i.e., if the memref -/// loaded from is of rank 3, then 3 indices are required for the load following -/// the memref identifier). For example: -/// -/// %3 = load %0[%1, %1] : memref<4x4xi32> -/// -class LoadOp - : public Op { -public: - using Op::Op; - - // Hooks to customize behavior of this op. - static void build(Builder *builder, OperationState *result, Value *memref, - ArrayRef indices = {}); - - Value *getMemRef() { return getOperand(0); } - void setMemRef(Value *value) { setOperand(0, value); } - MemRefType getMemRefType() { - return getMemRef()->getType().cast(); - } - - operand_range getIndices() { - return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; - } - - static StringRef getOperationName() { return "std.load"; } - - LogicalResult verify(); - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); -}; - -/// The "store" op writes an element to a memref specified by an index list. -/// The arity of indices is the rank of the memref (i.e. if the memref being -/// stored to is of rank 3, then 3 indices are required for the store following -/// the memref identifier). The store operation does not produce a result. -/// -/// In the following example, the ssa value '%v' is stored in memref '%A' at -/// indices [%i, %j]: -/// -/// store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0> -/// -class StoreOp - : public Op { -public: - using Op::Op; - - // Hooks to customize behavior of this op. - static void build(Builder *builder, OperationState *result, - Value *valueToStore, Value *memref, - ArrayRef indices = {}); - - Value *getValueToStore() { return getOperand(0); } - - Value *getMemRef() { return getOperand(1); } - void setMemRef(Value *value) { setOperand(1, value); } - MemRefType getMemRefType() { - return getMemRef()->getType().cast(); - } - - operand_range getIndices() { - return {getOperation()->operand_begin() + 2, getOperation()->operand_end()}; - } - - static StringRef getOperationName() { return "std.store"; } - - LogicalResult verify(); - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); -}; - /// Prints dimension and symbol list. void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index 817079a..f7b77a1 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -487,6 +487,45 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { let hasFolder = 1; } +def LoadOp : Std_Op<"load"> { + let summary = "load operation"; + let description = [{ + The "load" op reads an element from a memref specified by an index list. The + output of load is a new value with the same type as the elements of the + memref. The arity of indices is the rank of the memref (i.e., if the memref + loaded from is of rank 3, then 3 indices are required for the load following + the memref identifier). For example: + + %3 = load %0[%1, %1] : memref<4x4xi32> + }]; + + let arguments = (ins AnyMemRef:$memref, Variadic:$indices); + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *, OperationState *result, Value *memref," + "ArrayRef indices = {}", [{ + auto memrefType = memref->getType().cast(); + result->addOperands(memref); + result->addOperands(indices); + result->types.push_back(memrefType.getElementType()); + }]>]; + + let extraClassDeclaration = [{ + Value *getMemRef() { return getOperand(0); } + void setMemRef(Value *value) { setOperand(0, value); } + MemRefType getMemRefType() { + return getMemRef()->getType().cast(); + } + + operand_range getIndices() { + return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; + } + }]; + + let hasCanonicalizer = 1; +} + def MemRefCastOp : CastOp<"memref_cast"> { let summary = "memref cast operation"; let description = [{ @@ -616,6 +655,44 @@ def SubIOp : IntArithmeticOp<"subi"> { let hasFolder = 1; } +def StoreOp : Std_Op<"store"> { + let summary = "store operation"; + let description = [{ + The "store" op writes an element to a memref specified by an index list. + The arity of indices is the rank of the memref (i.e. if the memref being + stored to is of rank 3, then 3 indices are required for the store following + the memref identifier). The store operation does not produce a result. + + In the following example, the ssa value '%v' is stored in memref '%A' at + indices [%i, %j]: + store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0> + }]; + + let arguments = (ins AnyType:$value, AnyMemRef:$memref, Variadic:$indices); + + let builders = [OpBuilder< + "Builder *, OperationState *result, Value *valueToStore, Value *memref", [{ + result->addOperands(valueToStore); + result->addOperands(memref); + }]>]; + + let extraClassDeclaration = [{ + Value *getValueToStore() { return getOperand(0); } + + Value *getMemRef() { return getOperand(1); } + void setMemRef(Value *value) { setOperand(1, value); } + MemRefType getMemRefType() { + return getMemRef()->getType().cast(); + } + + operand_range getIndices() { + return {getOperation()->operand_begin() + 2, getOperation()->operand_end()}; + } + }]; + + let hasCanonicalizer = 1; +} + def TensorCastOp : CastOp<"tensor_cast"> { let summary = "tensor cast operation"; let description = [{ diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 6b559b3..29f9b19 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -81,7 +81,7 @@ template static LogicalResult verifyCastOp(T op) { StandardOpsDialect::StandardOpsDialect(MLIRContext *context) : Dialect(/*name=*/"std", context) { - addOperations(); @@ -1698,23 +1698,15 @@ OpFoldResult ExtractElementOp::fold(ArrayRef operands) { // LoadOp //===----------------------------------------------------------------------===// -void LoadOp::build(Builder *builder, OperationState *result, Value *memref, - ArrayRef indices) { - auto memrefType = memref->getType().cast(); - result->addOperands(memref); - result->addOperands(indices); - result->types.push_back(memrefType.getElementType()); -} - -void LoadOp::print(OpAsmPrinter *p) { - *p << "load " << *getMemRef() << '['; - p->printOperands(getIndices()); +static void print(OpAsmPrinter *p, LoadOp op) { + *p << "load " << *op.getMemRef() << '['; + p->printOperands(op.getIndices()); *p << ']'; - p->printOptionalAttrDict(getAttrs()); - *p << " : " << getMemRefType(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.getMemRefType(); } -ParseResult LoadOp::parse(OpAsmParser *parser, OperationState *result) { +static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType memrefInfo; SmallVector indexInfo; MemRefType type; @@ -1730,23 +1722,16 @@ ParseResult LoadOp::parse(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(type.getElementType(), result->types)); } -LogicalResult LoadOp::verify() { - if (getNumOperands() == 0) - return emitOpError("expected a memref to load from"); +static LogicalResult verify(LoadOp op) { + if (op.getType() != op.getMemRefType().getElementType()) + return op.emitOpError("result type must match element type of memref"); - auto memRefType = getMemRef()->getType().dyn_cast(); - if (!memRefType) - return emitOpError("first operand must be a memref"); - - if (getType() != memRefType.getElementType()) - return emitOpError("result type must match element type of memref"); - - if (memRefType.getRank() != getNumOperands() - 1) - return emitOpError("incorrect number of indices for load"); + if (op.getMemRefType().getRank() != op.getNumOperands() - 1) + return op.emitOpError("incorrect number of indices for load"); - for (auto *idx : getIndices()) + for (auto *idx : op.getIndices()) if (!idx->getType().isIndex()) - return emitOpError("index to load must have 'index' type"); + return op.emitOpError("index to load must have 'index' type"); // TODO: Verify we have the right number of indices. @@ -1982,24 +1967,16 @@ OpFoldResult SelectOp::fold(ArrayRef operands) { // StoreOp //===----------------------------------------------------------------------===// -void StoreOp::build(Builder *builder, OperationState *result, - Value *valueToStore, Value *memref, - ArrayRef indices) { - result->addOperands(valueToStore); - result->addOperands(memref); - result->addOperands(indices); -} - -void StoreOp::print(OpAsmPrinter *p) { - *p << "store " << *getValueToStore(); - *p << ", " << *getMemRef() << '['; - p->printOperands(getIndices()); +static void print(OpAsmPrinter *p, StoreOp op) { + *p << "store " << *op.getValueToStore(); + *p << ", " << *op.getMemRef() << '['; + p->printOperands(op.getIndices()); *p << ']'; - p->printOptionalAttrDict(getAttrs()); - *p << " : " << getMemRefType(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.getMemRefType(); } -ParseResult StoreOp::parse(OpAsmParser *parser, OperationState *result) { +static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; SmallVector indexInfo; @@ -2018,25 +1995,18 @@ ParseResult StoreOp::parse(OpAsmParser *parser, OperationState *result) { parser->resolveOperands(indexInfo, affineIntTy, result->operands)); } -LogicalResult StoreOp::verify() { - if (getNumOperands() < 2) - return emitOpError("expected a value to store and a memref"); - - // Second operand is a memref type. - auto memRefType = getMemRef()->getType().dyn_cast(); - if (!memRefType) - return emitOpError("second operand must be a memref"); - +static LogicalResult verify(StoreOp op) { // First operand must have same type as memref element type. - if (getValueToStore()->getType() != memRefType.getElementType()) - return emitOpError("first operand must have same type memref element type"); + if (op.getValueToStore()->getType() != op.getMemRefType().getElementType()) + return op.emitOpError( + "first operand must have same type memref element type"); - if (getNumOperands() != 2 + memRefType.getRank()) - return emitOpError("store index operand count not equal to memref rank"); + if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) + return op.emitOpError("store index operand count not equal to memref rank"); - for (auto *idx : getIndices()) + for (auto *idx : op.getIndices()) if (!idx->getType().isIndex()) - return emitOpError("index to load must have 'index' type"); + return op.emitOpError("index to load must have 'index' type"); // TODO: Verify we have the right number of indices.