Get rid of separate getRank() on MemRef now that it subclasses ShapedType
authorGeoffrey Martin-Noble <gcmn@google.com>
Fri, 31 May 2019 20:28:37 +0000 (13:28 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:13:00 +0000 (20:13 -0700)
    We are moving towards int64_t for these shape/dimension -related values to avoid buginess with unsigned arithmetic

--

PiperOrigin-RevId: 250945322

mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp
mlir/include/mlir/IR/StandardTypes.h
mlir/lib/StandardOps/Ops.cpp
mlir/lib/VectorOps/VectorOps.cpp

index 26c298a..84234c3 100644 (file)
@@ -188,7 +188,7 @@ public:
     // dimensions, extracts the size from the memref descriptor.
     auto memrefSize = [int64Ty, pos, i64cst](MemRefType type, Value *memref,
                                              int dim) -> Value * {
-      assert(static_cast<size_t>(dim) < type.getRank());
+      assert(dim < type.getRank());
       if (type.getShape()[dim] != -1) {
         return i64cst(type.getShape()[dim]);
       }
index 1ce241a..59aad04 100644 (file)
@@ -40,7 +40,7 @@ void linalg::ViewOp::build(Builder *b, OperationState *result, Value *memRef,
                            ArrayRef<Value *> indexings) {
   MemRefType memRefType = memRef->getType().cast<MemRefType>();
   result->addOperands({memRef});
-  assert(indexings.size() == memRefType.getRank() &&
+  assert(indexings.size() == static_cast<size_t>(memRefType.getRank()) &&
          "unexpected number of indexings (must match the memref rank)");
 
   result->addOperands(indexings);
@@ -107,7 +107,7 @@ ParseResult linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
   if (!memRefType)
     return parser->emitError(parser->getNameLoc(),
                              "memRef type expected for first type");
-  if (indexingsInfo.size() != memRefType.getRank())
+  if (indexingsInfo.size() != static_cast<size_t>(memRefType.getRank()))
     return parser->emitError(parser->getNameLoc(),
                              "expected " + Twine(memRefType.getRank()) +
                                  " indexings");
@@ -116,7 +116,7 @@ ParseResult linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
     return parser->emitError(parser->getNameLoc(), "view type expected");
 
   ArrayRef<Type> indexingTypes = ArrayRef<Type>(types).drop_front().drop_back();
-  if (indexingTypes.size() != memRefType.getRank())
+  if (indexingTypes.size() != static_cast<size_t>(memRefType.getRank()))
     return parser->emitError(parser->getNameLoc(),
                              "expected " + Twine(memRefType.getRank()) +
                                  " indexing types");
index 713fc6c..730ceed 100644 (file)
@@ -397,9 +397,6 @@ public:
                    location);
   }
 
-  // TODO(b/132735995) Get rid of this unsigned override.
-  unsigned getRank() const { return getShape().size(); }
-
   ArrayRef<int64_t> getShape() const;
 
   /// Returns an array of affine map pointers representing the memref affine
index 66a1226..39791fd 100644 (file)
@@ -1258,11 +1258,11 @@ static LogicalResult verify(DimOp op) {
   auto indexAttr = op.getAttrOfType<IntegerAttr>("index");
   if (!indexAttr)
     return op.emitOpError("requires an integer attribute named 'index'");
-  uint64_t index = indexAttr.getValue().getZExtValue();
+  int64_t index = indexAttr.getValue().getSExtValue();
 
   auto type = op.getOperand()->getType();
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
-    if (index >= static_cast<uint64_t>(tensorType.getRank()))
+    if (index >= tensorType.getRank())
       return op.emitOpError("index is out of range");
   } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
     if (index >= memrefType.getRank())
@@ -1459,12 +1459,15 @@ ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
   }
 
   // Check that source/destination index list size matches associated rank.
-  if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() ||
-      dstIndexInfos.size() != types[1].cast<MemRefType>().getRank())
+  if (srcIndexInfos.size() !=
+          static_cast<size_t>(types[0].cast<MemRefType>().getRank()) ||
+      dstIndexInfos.size() !=
+          static_cast<size_t>(types[1].cast<MemRefType>().getRank()))
     return parser->emitError(parser->getNameLoc(),
                              "memref rank not equal to indices count");
 
-  if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank())
+  if (tagIndexInfos.size() !=
+      static_cast<size_t>(types[2].cast<MemRefType>().getRank()))
     return parser->emitError(parser->getNameLoc(),
                              "tag memref rank not equal to indices count");
 
@@ -1543,7 +1546,8 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
     return parser->emitError(parser->getNameLoc(),
                              "expected tag to be of memref type");
 
-  if (tagIndexInfos.size() != type.cast<MemRefType>().getRank())
+  if (tagIndexInfos.size() !=
+      static_cast<size_t>(type.cast<MemRefType>().getRank()))
     return parser->emitError(parser->getNameLoc(),
                              "tag memref rank not equal to indices count");
 
index b40a1d9..f65961e 100644 (file)
@@ -148,7 +148,7 @@ ParseResult VectorTransferReadOp::parse(OpAsmParser *parser,
 
   // Extract optional paddingValue.
   // At this point, indexInfo may contain the optional paddingValue, pop it out.
-  if (indexInfo.size() != memrefType.getRank())
+  if (indexInfo.size() != static_cast<size_t>(memrefType.getRank()))
     return parser->emitError(parser->getNameLoc(),
                              "expected " + Twine(memrefType.getRank()) +
                                  " indices to the memref");