Move the definitions of LoadOp and StoreOp to the ODG framework.
authorRiver Riddle <riverriddle@google.com>
Sat, 25 May 2019 01:01:38 +0000 (18:01 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:01:52 +0000 (20:01 -0700)
--

PiperOrigin-RevId: 249928980

mlir/include/mlir/StandardOps/Ops.h
mlir/include/mlir/StandardOps/Ops.td
mlir/lib/StandardOps/Ops.cpp

index 18008f2..b399fbe 100644 (file)
@@ -450,84 +450,6 @@ public:
                                           MLIRContext *context);
 };
 
-/// The "load" op reads an element from a memref specified by an index list. The
-/// output of load is a new value with the same type as the elements of the
-/// memref. The arity of indices is the rank of the memref (i.e., if the memref
-/// loaded from is of rank 3, then 3 indices are required for the load following
-/// the memref identifier).  For example:
-///
-///   %3 = load %0[%1, %1] : memref<4x4xi32>
-///
-class LoadOp
-    : public Op<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
-public:
-  using Op::Op;
-
-  // Hooks to customize behavior of this op.
-  static void build(Builder *builder, OperationState *result, Value *memref,
-                    ArrayRef<Value *> indices = {});
-
-  Value *getMemRef() { return getOperand(0); }
-  void setMemRef(Value *value) { setOperand(0, value); }
-  MemRefType getMemRefType() {
-    return getMemRef()->getType().cast<MemRefType>();
-  }
-
-  operand_range getIndices() {
-    return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
-  }
-
-  static StringRef getOperationName() { return "std.load"; }
-
-  LogicalResult verify();
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                          MLIRContext *context);
-};
-
-/// The "store" op writes an element to a memref specified by an index list.
-/// The arity of indices is the rank of the memref (i.e. if the memref being
-/// stored to is of rank 3, then 3 indices are required for the store following
-/// the memref identifier). The store operation does not produce a result.
-///
-/// In the following example, the ssa value '%v' is stored in memref '%A' at
-/// indices [%i, %j]:
-///
-///   store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0>
-///
-class StoreOp
-    : public Op<StoreOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
-public:
-  using Op::Op;
-
-  // Hooks to customize behavior of this op.
-  static void build(Builder *builder, OperationState *result,
-                    Value *valueToStore, Value *memref,
-                    ArrayRef<Value *> indices = {});
-
-  Value *getValueToStore() { return getOperand(0); }
-
-  Value *getMemRef() { return getOperand(1); }
-  void setMemRef(Value *value) { setOperand(1, value); }
-  MemRefType getMemRefType() {
-    return getMemRef()->getType().cast<MemRefType>();
-  }
-
-  operand_range getIndices() {
-    return {getOperation()->operand_begin() + 2, getOperation()->operand_end()};
-  }
-
-  static StringRef getOperationName() { return "std.store"; }
-
-  LogicalResult verify();
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-
-  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                          MLIRContext *context);
-};
-
 /// Prints dimension and symbol list.
 void printDimAndSymbolList(Operation::operand_iterator begin,
                            Operation::operand_iterator end, unsigned numDims,
index 817079a..f7b77a1 100644 (file)
@@ -487,6 +487,45 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
   let hasFolder = 1;
 }
 
+def LoadOp : Std_Op<"load"> {
+  let summary = "load operation";
+  let description = [{
+    The "load" op reads an element from a memref specified by an index list. The
+    output of load is a new value with the same type as the elements of the
+    memref. The arity of indices is the rank of the memref (i.e., if the memref
+    loaded from is of rank 3, then 3 indices are required for the load following
+    the memref identifier). For example:
+
+      %3 = load %0[%1, %1] : memref<4x4xi32>
+  }];
+
+  let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices);
+  let results = (outs AnyType);
+
+  let builders = [OpBuilder<
+    "Builder *, OperationState *result, Value *memref,"
+    "ArrayRef<Value *> indices = {}", [{
+      auto memrefType = memref->getType().cast<MemRefType>();
+      result->addOperands(memref);
+      result->addOperands(indices);
+      result->types.push_back(memrefType.getElementType());
+  }]>];
+
+  let extraClassDeclaration = [{
+    Value *getMemRef() { return getOperand(0); }
+    void setMemRef(Value *value) { setOperand(0, value); }
+    MemRefType getMemRefType() {
+      return getMemRef()->getType().cast<MemRefType>();
+    }
+
+    operand_range getIndices() {
+      return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
+    }
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
 def MemRefCastOp : CastOp<"memref_cast"> {
   let summary = "memref cast operation";
   let description = [{
@@ -616,6 +655,44 @@ def SubIOp : IntArithmeticOp<"subi"> {
   let hasFolder = 1;
 }
 
+def StoreOp : Std_Op<"store"> {
+  let summary = "store operation";
+  let description = [{
+    The "store" op writes an element to a memref specified by an index list.
+    The arity of indices is the rank of the memref (i.e. if the memref being
+    stored to is of rank 3, then 3 indices are required for the store following
+    the memref identifier). The store operation does not produce a result.
+
+    In the following example, the ssa value '%v' is stored in memref '%A' at
+    indices [%i, %j]:
+      store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0>
+  }];
+
+  let arguments = (ins AnyType:$value, AnyMemRef:$memref, Variadic<Index>:$indices);
+
+  let builders = [OpBuilder<
+    "Builder *, OperationState *result, Value *valueToStore, Value *memref", [{
+      result->addOperands(valueToStore);
+      result->addOperands(memref);
+  }]>];
+
+  let extraClassDeclaration = [{
+      Value *getValueToStore() { return getOperand(0); }
+
+      Value *getMemRef() { return getOperand(1); }
+      void setMemRef(Value *value) { setOperand(1, value); }
+      MemRefType getMemRefType() {
+        return getMemRef()->getType().cast<MemRefType>();
+      }
+
+      operand_range getIndices() {
+        return {getOperation()->operand_begin() + 2, getOperation()->operand_end()};
+      }
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
 def TensorCastOp : CastOp<"tensor_cast"> {
   let summary = "tensor cast operation";
   let description = [{
index 6b559b3..29f9b19 100644 (file)
@@ -81,7 +81,7 @@ template <typename T> static LogicalResult verifyCastOp(T op) {
 
 StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
     : Dialect(/*name=*/"std", context) {
-  addOperations<CondBranchOp, DmaStartOp, DmaWaitOp, LoadOp, StoreOp,
+  addOperations<CondBranchOp, DmaStartOp, DmaWaitOp,
 #define GET_OP_LIST
 #include "mlir/StandardOps/Ops.cpp.inc"
                 >();
@@ -1698,23 +1698,15 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
 // LoadOp
 //===----------------------------------------------------------------------===//
 
-void LoadOp::build(Builder *builder, OperationState *result, Value *memref,
-                   ArrayRef<Value *> indices) {
-  auto memrefType = memref->getType().cast<MemRefType>();
-  result->addOperands(memref);
-  result->addOperands(indices);
-  result->types.push_back(memrefType.getElementType());
-}
-
-void LoadOp::print(OpAsmPrinter *p) {
-  *p << "load " << *getMemRef() << '[';
-  p->printOperands(getIndices());
+static void print(OpAsmPrinter *p, LoadOp op) {
+  *p << "load " << *op.getMemRef() << '[';
+  p->printOperands(op.getIndices());
   *p << ']';
-  p->printOptionalAttrDict(getAttrs());
-  *p << " : " << getMemRefType();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getMemRefType();
 }
 
-ParseResult LoadOp::parse(OpAsmParser *parser, OperationState *result) {
+static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType memrefInfo;
   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
   MemRefType type;
@@ -1730,23 +1722,16 @@ ParseResult LoadOp::parse(OpAsmParser *parser, OperationState *result) {
       parser->addTypeToList(type.getElementType(), result->types));
 }
 
-LogicalResult LoadOp::verify() {
-  if (getNumOperands() == 0)
-    return emitOpError("expected a memref to load from");
+static LogicalResult verify(LoadOp op) {
+  if (op.getType() != op.getMemRefType().getElementType())
+    return op.emitOpError("result type must match element type of memref");
 
-  auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
-  if (!memRefType)
-    return emitOpError("first operand must be a memref");
-
-  if (getType() != memRefType.getElementType())
-    return emitOpError("result type must match element type of memref");
-
-  if (memRefType.getRank() != getNumOperands() - 1)
-    return emitOpError("incorrect number of indices for load");
+  if (op.getMemRefType().getRank() != op.getNumOperands() - 1)
+    return op.emitOpError("incorrect number of indices for load");
 
-  for (auto *idx : getIndices())
+  for (auto *idx : op.getIndices())
     if (!idx->getType().isIndex())
-      return emitOpError("index to load must have 'index' type");
+      return op.emitOpError("index to load must have 'index' type");
 
   // TODO: Verify we have the right number of indices.
 
@@ -1982,24 +1967,16 @@ OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
 // StoreOp
 //===----------------------------------------------------------------------===//
 
-void StoreOp::build(Builder *builder, OperationState *result,
-                    Value *valueToStore, Value *memref,
-                    ArrayRef<Value *> indices) {
-  result->addOperands(valueToStore);
-  result->addOperands(memref);
-  result->addOperands(indices);
-}
-
-void StoreOp::print(OpAsmPrinter *p) {
-  *p << "store " << *getValueToStore();
-  *p << ", " << *getMemRef() << '[';
-  p->printOperands(getIndices());
+static void print(OpAsmPrinter *p, StoreOp op) {
+  *p << "store " << *op.getValueToStore();
+  *p << ", " << *op.getMemRef() << '[';
+  p->printOperands(op.getIndices());
   *p << ']';
-  p->printOptionalAttrDict(getAttrs());
-  *p << " : " << getMemRefType();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getMemRefType();
 }
 
-ParseResult StoreOp::parse(OpAsmParser *parser, OperationState *result) {
+static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType storeValueInfo;
   OpAsmParser::OperandType memrefInfo;
   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
@@ -2018,25 +1995,18 @@ ParseResult StoreOp::parse(OpAsmParser *parser, OperationState *result) {
       parser->resolveOperands(indexInfo, affineIntTy, result->operands));
 }
 
-LogicalResult StoreOp::verify() {
-  if (getNumOperands() < 2)
-    return emitOpError("expected a value to store and a memref");
-
-  // Second operand is a memref type.
-  auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
-  if (!memRefType)
-    return emitOpError("second operand must be a memref");
-
+static LogicalResult verify(StoreOp op) {
   // First operand must have same type as memref element type.
-  if (getValueToStore()->getType() != memRefType.getElementType())
-    return emitOpError("first operand must have same type memref element type");
+  if (op.getValueToStore()->getType() != op.getMemRefType().getElementType())
+    return op.emitOpError(
+        "first operand must have same type memref element type");
 
-  if (getNumOperands() != 2 + memRefType.getRank())
-    return emitOpError("store index operand count not equal to memref rank");
+  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
+    return op.emitOpError("store index operand count not equal to memref rank");
 
-  for (auto *idx : getIndices())
+  for (auto *idx : op.getIndices())
     if (!idx->getType().isIndex())
-      return emitOpError("index to load must have 'index' type");
+      return op.emitOpError("index to load must have 'index' type");
 
   // TODO: Verify we have the right number of indices.