[mlir][vector] Clean up use of `llvm::zip` in `VectorOps.cpp`
authorJakub Kuderski <kubak@google.com>
Wed, 30 Nov 2022 22:11:35 +0000 (17:11 -0500)
committerJakub Kuderski <kubak@google.com>
Wed, 30 Nov 2022 22:13:04 +0000 (17:13 -0500)
- Use `zip_equal` where iteratees are supposted to have equal lenght.
- Use `zip_first` where the first iteratee is supposed to be the
  shortest.
- Use `llvm::enumerate` instead of calculating index manually.
- Use structured bindings to unpack tuples where appropriate.
- Fix a bug in a comparison in `intersectsWhereNonNegative`.

Both `zip_first` (after D138858) and `zip_equal` (introduced in D138865)
assert interatee lengths, which allows us to more precisely convey
whether we want to iterate over the common prefix (`zip`), or expect all
lengths to be the same (`zip_equal`).

Reviewed By: dcaballe, antiagainst

Differential Revision: https://reviews.llvm.org/D139022

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

index 328601c..f8c10bd 100644 (file)
@@ -87,10 +87,9 @@ static MaskFormat getMaskFormat(Value mask) {
     auto shape = m.getType().getShape();
     bool allTrue = true;
     bool allFalse = true;
-    for (auto pair : llvm::zip(masks, shape)) {
-      int64_t i = std::get<0>(pair).cast<IntegerAttr>().getInt();
-      int64_t u = std::get<1>(pair);
-      if (i < u)
+    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
+      int64_t i = maskIdx.cast<IntegerAttr>().getInt();
+      if (i < dimSize)
         allTrue = false;
       if (i > 0)
         allFalse = false;
@@ -1178,10 +1177,10 @@ private:
   /// Comparison is on the common prefix (i.e. zip).
   template <typename ContainerA, typename ContainerB>
   bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
-    for (auto it : llvm::zip(a, b)) {
-      if (std::get<0>(it) < 0 || std::get<0>(it) < 0)
+    for (auto [elemA, elemB] : llvm::zip(a, b)) {
+      if (elemA < 0 || elemB < 0)
         continue;
-      if (std::get<0>(it) != std::get<1>(it))
+      if (elemA != elemB)
         return false;
     }
     return true;
@@ -1729,7 +1728,8 @@ computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
   int64_t rankDiff = dstShape.size() - srcShape.size();
   int64_t dstDim = rankDiff;
   llvm::SetVector<int64_t> res;
-  for (auto [s1, s2] : llvm::zip(srcShape, dstShape.drop_front(rankDiff))) {
+  for (auto [s1, s2] :
+       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
     if (s1 != s2) {
       assert(s1 == 1 && "expected dim-1 broadcasting");
       res.insert(dstDim);
@@ -2384,18 +2384,16 @@ static LogicalResult
 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                   ArrayRef<int64_t> shape, StringRef attrName,
                                   bool halfOpen = true, int64_t min = 0) {
-  assert(arrayAttr.size() <= shape.size());
-  unsigned index = 0;
-  for (auto it : llvm::zip(arrayAttr, shape)) {
-    auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
-    auto max = std::get<1>(it);
+  for (auto [index, attrDimPair] :
+       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
+    int64_t val = std::get<0>(attrDimPair).cast<IntegerAttr>().getInt();
+    int64_t max = std::get<1>(attrDimPair);
     if (!halfOpen)
       max += 1;
     if (val < min || val >= max)
       return op.emitOpError("expected ")
              << attrName << " dimension " << index << " to be confined to ["
              << min << ", " << max << ")";
-    ++index;
   }
   return success();
 }
@@ -2410,8 +2408,8 @@ static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
     bool halfOpen = true, int64_t min = 1) {
   assert(arrayAttr1.size() <= shape.size());
   assert(arrayAttr2.size() <= shape.size());
-  unsigned index = 0;
-  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
+  for (auto [index, it] :
+       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
     auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
     auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
     auto max = std::get<2>(it);
@@ -2421,7 +2419,6 @@ static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
       return op.emitOpError("expected sum(")
              << attrName1 << ", " << attrName2 << ") dimension " << index
              << " to be confined to [" << min << ", " << max << ")";
-    ++index;
   }
   return success();
 }
@@ -2962,11 +2959,9 @@ public:
 
     // Compute slice of vector mask region.
     SmallVector<int64_t, 4> sliceMaskDimSizes;
-    assert(sliceOffsets.size() == maskDimSizes.size());
-    for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
-      int64_t maskDimSize = std::get<0>(it);
-      int64_t sliceOffset = std::get<1>(it);
-      int64_t sliceSize = std::get<2>(it);
+    sliceMaskDimSizes.reserve(maskDimSizes.size());
+    for (auto [maskDimSize, sliceOffset, sliceSize] :
+         llvm::zip_equal(maskDimSizes, sliceOffsets, sliceSizes)) {
       int64_t sliceMaskDimSize = std::max(
           static_cast<int64_t>(0),
           std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
@@ -4236,9 +4231,9 @@ public:
     }
 
     // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
-    for (const auto &it :
-         llvm::zip(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
-      if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it))) {
+    for (auto [insertSize, extractSize] :
+         llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
+      if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
         return rewriter.notifyMatchFailure(
             insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
       }
@@ -5208,10 +5203,10 @@ public:
 
     // Gather constant mask dimension sizes.
     SmallVector<int64_t, 4> maskDimSizes;
-    for (auto it : llvm::zip(createMaskOp.operands(),
-                             createMaskOp.getType().getShape())) {
-      auto *defOp = std::get<0>(it).getDefiningOp();
-      int64_t maxDimSize = std::get<1>(it);
+    maskDimSizes.reserve(createMaskOp->getNumOperands());
+    for (auto [operand, maxDimSize] : llvm::zip_equal(
+             createMaskOp.operands(), createMaskOp.getType().getShape())) {
+      Operation *defOp = operand.getDefiningOp();
       int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
       dimSize = std::min(dimSize, maxDimSize);
       // If one of dim sizes is zero, set all dims to zero.
@@ -5438,10 +5433,7 @@ LogicalResult ScanOp::verify() {
     if (i != reductionDim)
       expectedShape.push_back(srcShape[i]);
   }
-  if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape),
-                   [](std::tuple<int64_t, int64_t> s) {
-                     return std::get<0>(s) != std::get<1>(s);
-                   })) {
+  if (!llvm::equal(initialValueShapes, expectedShape)) {
     return emitOpError("incompatible input/initial value shapes");
   }
 
@@ -5588,8 +5580,8 @@ void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
   OpBuilder::InsertionGuard guard(builder);
   Region *warpRegion = result.addRegion();
   Block *block = builder.createBlock(warpRegion);
-  for (auto it : llvm::zip(blockArgTypes, args))
-    block->addArgument(std::get<0>(it), std::get<1>(it).getLoc());
+  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
+    block->addArgument(type, arg.getLoc());
 }
 
 /// Helper check if the distributed vector type is consistent with the expanded
@@ -5636,16 +5628,16 @@ LogicalResult WarpExecuteOnLane0Op::verify() {
     return emitOpError(
         "expected same number of yield operands and return values.");
   int64_t warpSize = getWarpSize();
-  for (auto it : llvm::zip(getWarpRegion().getArguments(), getArgs())) {
-    if (failed(verifyDistributedType(std::get<0>(it).getType(),
-                                     std::get<1>(it).getType(), warpSize,
-                                     getOperation())))
+  for (auto [regionArg, arg] :
+       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
+    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
+                                     warpSize, getOperation())))
       return failure();
   }
-  for (auto it : llvm::zip(yield.getOperands(), getResults())) {
-    if (failed(verifyDistributedType(std::get<0>(it).getType(),
-                                     std::get<1>(it).getType(), warpSize,
-                                     getOperation())))
+  for (auto [yieldOperand, result] :
+       llvm::zip_equal(yield.getOperands(), getResults())) {
+    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
+                                     warpSize, getOperation())))
       return failure();
   }
   return success();