Static cast size_t -> int64_t instead of vice versa for equals comparisons
authorGeoffrey Martin-Noble <gcmn@google.com>
Fri, 31 May 2019 23:41:21 +0000 (16:41 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:13:49 +0000 (20:13 -0700)
    These were just introduced by a previous CL moving MemRef getRank to return int64_t. size_t could be smaller than 64 bits and in equals comparisons, signed vs unsigned doesn't matter. In these cases, we know right now that the particular int64_t is not larger than max size_t (because it currently comes directly from a size() call), the alternative cast plus equals comparison is always safe, so we might as well do it that way and no longer require reasoning deeper into the callstack.

    We are already assuming that size() calls fit into int64_t in a number of other cases like the aforementioned getRank() (since exabytes of RAM are rare). If we want to avoid this assumption we will have to come up with a principled way to do it throughout.

--

PiperOrigin-RevId: 250980297

mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp
mlir/lib/StandardOps/Ops.cpp
mlir/lib/VectorOps/VectorOps.cpp

index 59aad04..0e9a4c3 100644 (file)
@@ -40,7 +40,7 @@ void linalg::ViewOp::build(Builder *b, OperationState *result, Value *memRef,
                            ArrayRef<Value *> indexings) {
   MemRefType memRefType = memRef->getType().cast<MemRefType>();
   result->addOperands({memRef});
-  assert(indexings.size() == static_cast<size_t>(memRefType.getRank()) &&
+  assert(static_cast<int64_t>(indexings.size()) == memRefType.getRank() &&
          "unexpected number of indexings (must match the memref rank)");
 
   result->addOperands(indexings);
@@ -107,7 +107,7 @@ ParseResult linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
   if (!memRefType)
     return parser->emitError(parser->getNameLoc(),
                              "memRef type expected for first type");
-  if (indexingsInfo.size() != static_cast<size_t>(memRefType.getRank()))
+  if (static_cast<int64_t>(indexingsInfo.size()) != memRefType.getRank())
     return parser->emitError(parser->getNameLoc(),
                              "expected " + Twine(memRefType.getRank()) +
                                  " indexings");
@@ -116,7 +116,7 @@ ParseResult linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
     return parser->emitError(parser->getNameLoc(), "view type expected");
 
   ArrayRef<Type> indexingTypes = ArrayRef<Type>(types).drop_front().drop_back();
-  if (indexingTypes.size() != static_cast<size_t>(memRefType.getRank()))
+  if (static_cast<int64_t>(indexingTypes.size()) != memRefType.getRank())
     return parser->emitError(parser->getNameLoc(),
                              "expected " + Twine(memRefType.getRank()) +
                                  " indexing types");
index 39791fd..b97c149 100644 (file)
@@ -338,7 +338,8 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
     auto newMemRefType = MemRefType::get(
         newShapeConstants, memrefType.getElementType(),
         memrefType.getAffineMaps(), memrefType.getMemorySpace());
-    assert(newOperands.size() == newMemRefType.getNumDynamicDims());
+    assert(static_cast<int64_t>(newOperands.size()) ==
+           newMemRefType.getNumDynamicDims());
 
     // Create and insert the alloc op for the new memref.
     auto newAlloc =
@@ -1459,15 +1460,15 @@ ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
   }
 
   // Check that source/destination index list size matches associated rank.
-  if (srcIndexInfos.size() !=
-          static_cast<size_t>(types[0].cast<MemRefType>().getRank()) ||
-      dstIndexInfos.size() !=
-          static_cast<size_t>(types[1].cast<MemRefType>().getRank()))
+  if (static_cast<int64_t>(srcIndexInfos.size()) !=
+          types[0].cast<MemRefType>().getRank() ||
+      static_cast<int64_t>(dstIndexInfos.size()) !=
+          types[1].cast<MemRefType>().getRank())
     return parser->emitError(parser->getNameLoc(),
                              "memref rank not equal to indices count");
 
-  if (tagIndexInfos.size() !=
-      static_cast<size_t>(types[2].cast<MemRefType>().getRank()))
+  if (static_cast<int64_t>(tagIndexInfos.size()) !=
+      types[2].cast<MemRefType>().getRank())
     return parser->emitError(parser->getNameLoc(),
                              "tag memref rank not equal to indices count");
 
@@ -1546,8 +1547,8 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
     return parser->emitError(parser->getNameLoc(),
                              "expected tag to be of memref type");
 
-  if (tagIndexInfos.size() !=
-      static_cast<size_t>(type.cast<MemRefType>().getRank()))
+  if (static_cast<int64_t>(tagIndexInfos.size()) !=
+      type.cast<MemRefType>().getRank())
     return parser->emitError(parser->getNameLoc(),
                              "tag memref rank not equal to indices count");
 
index f65961e..23b2f99 100644 (file)
@@ -148,7 +148,7 @@ ParseResult VectorTransferReadOp::parse(OpAsmParser *parser,
 
   // Extract optional paddingValue.
   // At this point, indexInfo may contain the optional paddingValue, pop it out.
-  if (indexInfo.size() != static_cast<size_t>(memrefType.getRank()))
+  if (static_cast<int64_t>(indexInfo.size()) != memrefType.getRank())
     return parser->emitError(parser->getNameLoc(),
                              "expected " + Twine(memrefType.getRank()) +
                                  " indices to the memref");