From: Jakub Kuderski Date: Wed, 30 Nov 2022 22:11:35 +0000 (-0500) Subject: [mlir][vector] Clean up use of `llvm::zip` in `VectorOps.cpp` X-Git-Tag: upstream/17.0.6~25808 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f22a573b2b8afaee88001168eeeb70c77f28a03e;p=platform%2Fupstream%2Fllvm.git [mlir][vector] Clean up use of `llvm::zip` in `VectorOps.cpp` - Use `zip_equal` where iteratees are supposted to have equal lenght. - Use `zip_first` where the first iteratee is supposed to be the shortest. - Use `llvm::enumerate` instead of calculating index manually. - Use structured bindings to unpack tuples where appropriate. - Fix a bug in a comparison in `intersectsWhereNonNegative`. Both `zip_first` (after D138858) and `zip_equal` (introduced in D138865) assert interatee lengths, which allows us to more precisely convey whether we want to iterate over the common prefix (`zip`), or expect all lengths to be the same (`zip_equal`). Reviewed By: dcaballe, antiagainst Differential Revision: https://reviews.llvm.org/D139022 --- diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 328601c..f8c10bd 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -87,10 +87,9 @@ static MaskFormat getMaskFormat(Value mask) { auto shape = m.getType().getShape(); bool allTrue = true; bool allFalse = true; - for (auto pair : llvm::zip(masks, shape)) { - int64_t i = std::get<0>(pair).cast().getInt(); - int64_t u = std::get<1>(pair); - if (i < u) + for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) { + int64_t i = maskIdx.cast().getInt(); + if (i < dimSize) allTrue = false; if (i > 0) allFalse = false; @@ -1178,10 +1177,10 @@ private: /// Comparison is on the common prefix (i.e. zip). template bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) { - for (auto it : llvm::zip(a, b)) { - if (std::get<0>(it) < 0 || std::get<0>(it) < 0) + for (auto [elemA, elemB] : llvm::zip(a, b)) { + if (elemA < 0 || elemB < 0) continue; - if (std::get<0>(it) != std::get<1>(it)) + if (elemA != elemB) return false; } return true; @@ -1729,7 +1728,8 @@ computeBroadcastedUnitDims(ArrayRef srcShape, int64_t rankDiff = dstShape.size() - srcShape.size(); int64_t dstDim = rankDiff; llvm::SetVector res; - for (auto [s1, s2] : llvm::zip(srcShape, dstShape.drop_front(rankDiff))) { + for (auto [s1, s2] : + llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) { if (s1 != s2) { assert(s1 == 1 && "expected dim-1 broadcasting"); res.insert(dstDim); @@ -2384,18 +2384,16 @@ static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef shape, StringRef attrName, bool halfOpen = true, int64_t min = 0) { - assert(arrayAttr.size() <= shape.size()); - unsigned index = 0; - for (auto it : llvm::zip(arrayAttr, shape)) { - auto val = std::get<0>(it).cast().getInt(); - auto max = std::get<1>(it); + for (auto [index, attrDimPair] : + llvm::enumerate(llvm::zip_first(arrayAttr, shape))) { + int64_t val = std::get<0>(attrDimPair).cast().getInt(); + int64_t max = std::get<1>(attrDimPair); if (!halfOpen) max += 1; if (val < min || val >= max) return op.emitOpError("expected ") << attrName << " dimension " << index << " to be confined to [" << min << ", " << max << ")"; - ++index; } return success(); } @@ -2410,8 +2408,8 @@ static LogicalResult isSumOfIntegerArrayAttrConfinedToShape( bool halfOpen = true, int64_t min = 1) { assert(arrayAttr1.size() <= shape.size()); assert(arrayAttr2.size() <= shape.size()); - unsigned index = 0; - for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) { + for (auto [index, it] : + llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) { auto val1 = std::get<0>(it).cast().getInt(); auto val2 = std::get<1>(it).cast().getInt(); auto max = std::get<2>(it); @@ -2421,7 +2419,6 @@ static LogicalResult isSumOfIntegerArrayAttrConfinedToShape( return op.emitOpError("expected sum(") << attrName1 << ", " << attrName2 << ") dimension " << index << " to be confined to [" << min << ", " << max << ")"; - ++index; } return success(); } @@ -2962,11 +2959,9 @@ public: // Compute slice of vector mask region. SmallVector sliceMaskDimSizes; - assert(sliceOffsets.size() == maskDimSizes.size()); - for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) { - int64_t maskDimSize = std::get<0>(it); - int64_t sliceOffset = std::get<1>(it); - int64_t sliceSize = std::get<2>(it); + sliceMaskDimSizes.reserve(maskDimSizes.size()); + for (auto [maskDimSize, sliceOffset, sliceSize] : + llvm::zip_equal(maskDimSizes, sliceOffsets, sliceSizes)) { int64_t sliceMaskDimSize = std::max( static_cast(0), std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset); @@ -4236,9 +4231,9 @@ public: } // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ. - for (const auto &it : - llvm::zip(insertOp.getMixedSizes(), extractOp.getMixedSizes())) { - if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it))) { + for (auto [insertSize, extractSize] : + llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) { + if (!isEqualConstantIntOrValue(insertSize, extractSize)) { return rewriter.notifyMatchFailure( insertOp, "InsertSliceOp and ExtractSliceOp sizes differ"); } @@ -5208,10 +5203,10 @@ public: // Gather constant mask dimension sizes. SmallVector maskDimSizes; - for (auto it : llvm::zip(createMaskOp.operands(), - createMaskOp.getType().getShape())) { - auto *defOp = std::get<0>(it).getDefiningOp(); - int64_t maxDimSize = std::get<1>(it); + maskDimSizes.reserve(createMaskOp->getNumOperands()); + for (auto [operand, maxDimSize] : llvm::zip_equal( + createMaskOp.operands(), createMaskOp.getType().getShape())) { + Operation *defOp = operand.getDefiningOp(); int64_t dimSize = cast(defOp).value(); dimSize = std::min(dimSize, maxDimSize); // If one of dim sizes is zero, set all dims to zero. @@ -5438,10 +5433,7 @@ LogicalResult ScanOp::verify() { if (i != reductionDim) expectedShape.push_back(srcShape[i]); } - if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape), - [](std::tuple s) { - return std::get<0>(s) != std::get<1>(s); - })) { + if (!llvm::equal(initialValueShapes, expectedShape)) { return emitOpError("incompatible input/initial value shapes"); } @@ -5588,8 +5580,8 @@ void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result, OpBuilder::InsertionGuard guard(builder); Region *warpRegion = result.addRegion(); Block *block = builder.createBlock(warpRegion); - for (auto it : llvm::zip(blockArgTypes, args)) - block->addArgument(std::get<0>(it), std::get<1>(it).getLoc()); + for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args)) + block->addArgument(type, arg.getLoc()); } /// Helper check if the distributed vector type is consistent with the expanded @@ -5636,16 +5628,16 @@ LogicalResult WarpExecuteOnLane0Op::verify() { return emitOpError( "expected same number of yield operands and return values."); int64_t warpSize = getWarpSize(); - for (auto it : llvm::zip(getWarpRegion().getArguments(), getArgs())) { - if (failed(verifyDistributedType(std::get<0>(it).getType(), - std::get<1>(it).getType(), warpSize, - getOperation()))) + for (auto [regionArg, arg] : + llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) { + if (failed(verifyDistributedType(regionArg.getType(), arg.getType(), + warpSize, getOperation()))) return failure(); } - for (auto it : llvm::zip(yield.getOperands(), getResults())) { - if (failed(verifyDistributedType(std::get<0>(it).getType(), - std::get<1>(it).getType(), warpSize, - getOperation()))) + for (auto [yieldOperand, result] : + llvm::zip_equal(yield.getOperands(), getResults())) { + if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(), + warpSize, getOperation()))) return failure(); } return success();