From c912981bbd6d94218d66546aa22620660894f0ae Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Fri, 31 May 2019 16:41:21 -0700 Subject: [PATCH] Static cast size_t -> int64_t instead of vice versa for equals comparisons These were just introduced by a previous CL moving MemRef getRank to return int64_t. size_t could be smaller than 64 bits and in equals comparisons, signed vs unsigned doesn't matter. In these cases, we know right now that the particular int64_t is not larger than max size_t (because it currently comes directly from a size() call), the alternative cast plus equals comparison is always safe, so we might as well do it that way and no longer require reasoning deeper into the callstack. We are already assuming that size() calls fit into int64_t in a number of other cases like the aforementioned getRank() (since exabytes of RAM are rare). If we want to avoid this assumption we will have to come up with a principled way to do it throughout. -- PiperOrigin-RevId: 250980297 --- mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp | 6 +++--- mlir/lib/StandardOps/Ops.cpp | 19 ++++++++++--------- mlir/lib/VectorOps/VectorOps.cpp | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp b/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp index 59aad04..0e9a4c3 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp @@ -40,7 +40,7 @@ void linalg::ViewOp::build(Builder *b, OperationState *result, Value *memRef, ArrayRef indexings) { MemRefType memRefType = memRef->getType().cast(); result->addOperands({memRef}); - assert(indexings.size() == static_cast(memRefType.getRank()) && + assert(static_cast(indexings.size()) == memRefType.getRank() && "unexpected number of indexings (must match the memref rank)"); result->addOperands(indexings); @@ -107,7 +107,7 @@ ParseResult linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) { if (!memRefType) return parser->emitError(parser->getNameLoc(), "memRef type expected for first type"); - if (indexingsInfo.size() != static_cast(memRefType.getRank())) + if (static_cast(indexingsInfo.size()) != memRefType.getRank()) return parser->emitError(parser->getNameLoc(), "expected " + Twine(memRefType.getRank()) + " indexings"); @@ -116,7 +116,7 @@ ParseResult linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) { return parser->emitError(parser->getNameLoc(), "view type expected"); ArrayRef indexingTypes = ArrayRef(types).drop_front().drop_back(); - if (indexingTypes.size() != static_cast(memRefType.getRank())) + if (static_cast(indexingTypes.size()) != memRefType.getRank()) return parser->emitError(parser->getNameLoc(), "expected " + Twine(memRefType.getRank()) + " indexing types"); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 39791fd..b97c149 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -338,7 +338,8 @@ struct SimplifyAllocConst : public OpRewritePattern { auto newMemRefType = MemRefType::get( newShapeConstants, memrefType.getElementType(), memrefType.getAffineMaps(), memrefType.getMemorySpace()); - assert(newOperands.size() == newMemRefType.getNumDynamicDims()); + assert(static_cast(newOperands.size()) == + newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. auto newAlloc = @@ -1459,15 +1460,15 @@ ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { } // Check that source/destination index list size matches associated rank. - if (srcIndexInfos.size() != - static_cast(types[0].cast().getRank()) || - dstIndexInfos.size() != - static_cast(types[1].cast().getRank())) + if (static_cast(srcIndexInfos.size()) != + types[0].cast().getRank() || + static_cast(dstIndexInfos.size()) != + types[1].cast().getRank()) return parser->emitError(parser->getNameLoc(), "memref rank not equal to indices count"); - if (tagIndexInfos.size() != - static_cast(types[2].cast().getRank())) + if (static_cast(tagIndexInfos.size()) != + types[2].cast().getRank()) return parser->emitError(parser->getNameLoc(), "tag memref rank not equal to indices count"); @@ -1546,8 +1547,8 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { return parser->emitError(parser->getNameLoc(), "expected tag to be of memref type"); - if (tagIndexInfos.size() != - static_cast(type.cast().getRank())) + if (static_cast(tagIndexInfos.size()) != + type.cast().getRank()) return parser->emitError(parser->getNameLoc(), "tag memref rank not equal to indices count"); diff --git a/mlir/lib/VectorOps/VectorOps.cpp b/mlir/lib/VectorOps/VectorOps.cpp index f65961e..23b2f99 100644 --- a/mlir/lib/VectorOps/VectorOps.cpp +++ b/mlir/lib/VectorOps/VectorOps.cpp @@ -148,7 +148,7 @@ ParseResult VectorTransferReadOp::parse(OpAsmParser *parser, // Extract optional paddingValue. // At this point, indexInfo may contain the optional paddingValue, pop it out. - if (indexInfo.size() != static_cast(memrefType.getRank())) + if (static_cast(indexInfo.size()) != memrefType.getRank()) return parser->emitError(parser->getNameLoc(), "expected " + Twine(memrefType.getRank()) + " indices to the memref"); -- 2.7.4