[mlir][Transform] NFC - Refactor forall mapping to threads and blocks into one thing
author    Nicolas Vasilache <nicolas.vasilache@gmail.com>
          Tue, 14 Mar 2023 21:37:58 +0000 (14:37 -0700)
committer Nicolas Vasilache <nicolas.vasilache@gmail.com>
          Wed, 15 Mar 2023 12:09:39 +0000 (05:09 -0700)
Differential Revision: https://reviews.llvm.org/D146095

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp

diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
index 7c6aa7e..579922a 100644
@@ -39,11 +39,11 @@ namespace gpu {
 /// Dynamic `scf.forall` trip counts are currently not supported.
 /// Dynamic block dim sizes are currently not supported.
 DiagnosedSilenceableFailure mapForallToBlocksImpl(
-    RewriterBase &rewriter, scf::ForallOp forallOp,
+    RewriterBase &rewriter, TransformOpInterface transformOp,
+    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes,
     function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
-        blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);
+        blockIdGenerator);
 
 /// Search `scf.forall` ops nested under `target` and map each such op to GPU
 /// threads. Mapping is one-to-one and the induction variables of `scf.forall`
@@ -54,12 +54,12 @@ DiagnosedSilenceableFailure mapForallToBlocksImpl(
 /// Dynamic `scf.forall` trip counts are currently not supported.
 /// Dynamic block dim sizes are currently not supported.
 DiagnosedSilenceableFailure mapNestedForallToThreadsImpl(
-    RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &kernelBlockDims,
+    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
+    Operation *target, const SmallVectorImpl<int64_t> &kernelBlockDims,
+    bool syncAfterDistribute,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes,
     function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
-        threadIdGenerator,
-    bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
+        threadIdGenerator);
 
 /// Find the unique top level scf::ForallOp within a given target op.
 DiagnosedSilenceableFailure
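
A call-site sketch for the reordered signature (inputs first, the id-generator callback last). The `helper` object mirrors the `MappingToGpuBlocksHelper` used in GPUTransformOps.cpp below; the surrounding setup is illustrative, not part of this patch:

    SmallVector<int64_t> gridDims;
    DiagnosedSilenceableFailure diag = mlir::transform::gpu::mapForallToBlocksImpl(
        rewriter, transformOp, topLevelForallOp, gridDims,
        helper.mappingAttributes, helper.idGenerator);
    // On success, gridDims holds the x, y, z grid sizes inferred from the forall.
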
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 930bf46..27c2775 100644
@@ -124,6 +124,9 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
 SmallVector<OpFoldResult>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
                      llvm::function_ref<bool(Attribute, Attribute)> compare);
+SmallVector<int64_t>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
+                     llvm::function_ref<bool(Attribute, Attribute)> compare);
 
 } // namespace mlir
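
A minimal sketch of the new int64_t overload, using the same comparator as the GPU mapping code below; the three mapping attributes are hypothetical stand-ins for e.g. GPUThreadMappingAttr instances with ids y=1, x=0, z=2:

    // Sort per-dimension sizes so that position i holds the size mapped to id i.
    SmallVector<Attribute> keys = {threadY, threadX, threadZ}; // hypothetical attrs
    SmallVector<int64_t> sizes = {4, 8, 1};
    auto comparator = [](Attribute a, Attribute b) {
      return a.cast<DeviceMappingAttrInterface>().getMappingId() <
             b.cast<DeviceMappingAttrInterface>().getMappingId();
    };
    SmallVector<int64_t> sorted = getValuesSortedByKey(keys, sizes, comparator);
    // sorted == {8, 4, 1}: x first, then y, then z.
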
 
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 93f00fe..6d87604 100644
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 using namespace mlir::gpu;
 using namespace mlir::transform;
 
+#define DEBUG_TYPE "gpu-transforms"
+
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace {
 
-/// Helper type forfunctions that generate ids for the mapping of a scf.forall.
+/// Helper type for functions that generate ids for the mapping of a scf.forall.
 using IdGeneratorFnType = llvm::function_ref<void(RewriterBase &, scf::ForallOp,
                                                   SmallVectorImpl<Value> &)>;
 
@@ -86,7 +95,7 @@ static DiagnosedSilenceableFailure
 failureHelper(std::optional<TransformOpInterface> transformOp,
               scf::ForallOp forallOp, const Twine &message) {
   if (transformOp.has_value())
-    return transformOp->emitSilenceableError() << message;
+    return emitDefiniteFailure(*transformOp, message);
   return emitDefiniteFailure(forallOp, message);
 }
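
For context, the DBGS/LDBG macros added above emit lines prefixed with the debug type; a sketch of the expected behavior, assuming an assertions-enabled build run with `-debug-only=gpu-transforms`:

    LDBG("Start rewriteOneForallCommonImpl");
    // prints: [gpu-transforms] Start rewriteOneForallCommonImpl
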
 
@@ -273,30 +282,35 @@ alterGpuLaunch(IRRewriter &rewriter, LaunchOp gpuLaunch,
 // MapForallToBlocks
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
-    RewriterBase &rewriter, scf::ForallOp forallOp,
-    IdGeneratorFnType blockIdGenerator, SmallVectorImpl<int64_t> &gridDims,
-    TransformOpInterface transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
+static FailureOr<SmallVector<int64_t>> rewriteOneForallCommonImpl(
+    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
+    scf::ForallOp forallOp,
+    const SmallVectorImpl<int64_t> &availableMappingSizes,
+    const ArrayRef<DeviceMappingAttrInterface> &allMappingAttributes,
+    IdGeneratorFnType idGenerator) {
+  LDBG("Start rewriteOneForallCommonImpl");
 
   // Step 0. GPU-specific verifications. There is no better place to anchor
   // those right now: the ForallOp is target-independent and the transform op
   // does not apply to individual ForallOp.
   DiagnosedSilenceableFailure diag = verifyGpuMapping(transformOp, forallOp);
   if (!diag.succeeded())
-    return diag;
-
-  SmallVector<Attribute> blockMapping =
+    return failure();
+
+  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
+  SmallVector<int64_t> tmpMappingSizes = llvm::to_vector(
+      llvm::map_range(forallOp.getMixedUpperBound(), [](OpFoldResult ofr) {
+        auto maybeStaticValue = getConstantIntValue(ofr);
+        assert(maybeStaticValue && "expected static value");
+        return maybeStaticValue.value();
+      }));
+  SmallVector<Attribute> forallMappings =
       llvm::to_vector(forallOp.getMapping()->getValue());
-
-  // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
-  SmallVector<OpFoldResult> numBlocks = forallOp.getMixedUpperBound();
-  // Ensure we have 3 block sizes, one for each id.
-  for (auto attr : mappingAttributes) {
-    if (!llvm::is_contained(blockMapping, attr)) {
-      blockMapping.push_back(attr);
-      numBlocks.push_back(rewriter.getIndexAttr(1));
-    }
+  for (auto attr : allMappingAttributes) {
+    if (llvm::is_contained(forallMappings, attr))
+      continue;
+    forallMappings.push_back(attr);
+    tmpMappingSizes.push_back(1);
   }
 
  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
@@ -304,43 +318,116 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
                         DeviceMappingAttrInterface b) -> bool {
     return a.getMappingId() < b.getMappingId();
   };
-  SmallVector<OpFoldResult> gridDimValues =
-      getValuesSortedByKey(blockMapping, numBlocks, comparator);
-  gridDims =
-      llvm::to_vector(llvm::map_range(gridDimValues, [](OpFoldResult ofr) {
-        return getConstantIntValue(ofr).value();
-      }));
+  SmallVector<int64_t> mappingSizes =
+      getValuesSortedByKey(forallMappings, tmpMappingSizes, comparator);
+  LLVM_DEBUG(llvm::interleaveComma(mappingSizes, DBGS() << "mappingSizes: ");
+             llvm::dbgs() << "\n";
+             llvm::interleaveComma(forallMappings, DBGS() << "mappingAttrs: ");
+             llvm::dbgs() << "\n");
+
+  // Step 3. Generate the mappingIdOps using the provided generator and map the
+  // induction variables to the newly created ops. Replace ids of dimensions
+  // known to be of size 1 by zero to simplify the IR.
+  SmallVector<Value> mappingIdOps;
+  Location loc = forallOp.getLoc();
+  idGenerator(rewriter, forallOp, mappingIdOps);
+  LLVM_DEBUG(llvm::interleaveComma(mappingIdOps, DBGS() << "mappingIdOps: ");
+             llvm::dbgs() << "\n");
+  assert(mappingIdOps.size() == mappingSizes.size() && "expect equal sizes");
+  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  if (!availableMappingSizes.empty()) {
+    for (size_t i : llvm::seq(size_t(0), availableMappingSizes.size())) {
+      if (availableMappingSizes[i] == 1)
+        mappingIdOps[i] = zero;
+    }
+  }
 
-  // Step 3. Generate the blockids using the provided generator and map the
-  // induction variables to the newly created ops.
-  SmallVector<Value> blockOps;
-  blockIdGenerator(rewriter, forallOp, blockOps);
   IRMapping bvm;
-  for (auto [blockIdx, blockDim] :
-       llvm::zip(forallOp.getInductionVars(), blockMapping)) {
-    bvm.map(blockIdx,
-            blockOps[static_cast<int64_t>(
-                blockDim.cast<DeviceMappingAttrInterface>().getMappingId())]);
+  for (auto [iv, dim] :
+       llvm::zip_equal(forallOp.getInductionVars(),
+                       ArrayRef<Attribute>{forallMappings}.take_front(
+                           forallOp.getInductionVars().size()))) {
+    Value peIdOp = mappingIdOps[static_cast<int64_t>(
+        dim.cast<DeviceMappingAttrInterface>().getMappingId())];
+    bvm.map(iv, peIdOp);
   }
 
-  // Step 4. Move the body of forallOp.
-  // Erase the terminator first, it will not be used since we are on buffers.
+  // Step 4. Maybe create conditionals to predicate the region.
+  // Skip this step when availableMappingSizes is empty.
+  Value predicate;
+  if (!availableMappingSizes.empty()) {
+    LLVM_DEBUG(llvm::interleaveComma(availableMappingSizes,
+                                     DBGS() << "availableMappingSizes: ");
+               llvm::dbgs() << "\n");
+    for (auto [id, mappingSize, availableMappingSize] :
+         llvm::zip_equal(mappingIdOps, mappingSizes, availableMappingSizes)) {
+      if (mappingSize > availableMappingSize) {
+        (void)failureHelper(
+            transformOp, forallOp,
+            "Trying to map to fewer GPU threads than loop iterations but "
+            "overprovisioning is not yet supported. "
+            "Try additional tiling of the before mapping or map to more "
+            "threads.");
+        return failure();
+      }
+      if (mappingSize == availableMappingSize)
+        continue;
+      Value idx = rewriter.create<arith::ConstantIndexOp>(loc, mappingSize);
+      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::ult, id, idx);
+      LDBG("predicate: " << tmpPredicate);
+      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
+                                                             tmpPredicate)
+                            : tmpPredicate;
+    }
+  }
+
+  // Step 5. Move the body of forallOp.
+  // Erase the terminator first; it will not be used.
   rewriter.eraseOp(forallOp.getTerminator());
-  Block *targetBlock = forallOp->getBlock();
-  Block::iterator insertionPoint = Block::iterator(forallOp);
+  Block *targetBlock;
+  Block::iterator insertionPoint;
+  if (predicate) {
+    // Step 5.a. If predicated, move the body to the beginning of the scf.if.
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
+    targetBlock = ifOp.thenBlock();
+    insertionPoint = ifOp.thenBlock()->begin();
+  } else {
+    // Step 5.b. Otherwise, move inline at the rewriter insertion point.
+    targetBlock = forallOp->getBlock();
+    insertionPoint = rewriter.getInsertionPoint();
+  }
   Block &sourceBlock = forallOp.getRegion().front();
   targetBlock->getOperations().splice(insertionPoint,
                                       sourceBlock.getOperations());
 
-  // Step 5. RAUW thread indices to thread ops.
+  // Step 6. RAUW loop indices to the newly created mapping id ops.
   for (Value loopIndex : forallOp.getInductionVars()) {
-    Value blockIdx = bvm.lookup(loopIndex);
-    rewriter.replaceAllUsesWith(loopIndex, blockIdx);
+    Value mappingIdOp = bvm.lookup(loopIndex);
+    rewriter.replaceAllUsesWith(loopIndex, mappingIdOp);
   }
 
-  // Step 6. Erase old op.
+  // Step 7. Erase old op.
   rewriter.eraseOp(forallOp);
 
+  return mappingSizes;
+}
+
+DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
+    RewriterBase &rewriter, TransformOpInterface transformOp,
+    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
+    const ArrayRef<DeviceMappingAttrInterface> &allMappingAttributes,
+    IdGeneratorFnType idGenerator) {
+  // Pass an empty anyAvailableMappingSizes so that the common rewrite skips
+  // predication.
+  SmallVector<int64_t> anyAvailableMappingSizes;
+  FailureOr<SmallVector<int64_t>> maybeMappingSizes =
+      rewriteOneForallCommonImpl(rewriter, transformOp, forallOp,
+                                 anyAvailableMappingSizes, allMappingAttributes,
+                                 idGenerator);
+  if (failed(maybeMappingSizes))
+    return DiagnosedSilenceableFailure::definiteFailure();
+  gridDims = *maybeMappingSizes;
   return DiagnosedSilenceableFailure::success();
 }
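
To make the rejection/predication rule of Step 4 concrete, a standalone model in plain C++ (not MLIR; all names are hypothetical): a dimension is rejected when the forall extent exceeds the launched extent, left unguarded on an exact fit, and guarded by `id < mappingSize` otherwise:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Classification of one mapping dimension, following Step 4 above.
    enum class DimAction { Reject, NoPredicate, Predicate };

    // mappingSize: static trip count of the forall along this dimension.
    // availableSize: threads/blocks actually launched along this dimension.
    DimAction classifyDim(int64_t mappingSize, int64_t availableSize) {
      if (mappingSize > availableSize)
        return DimAction::Reject;      // overprovisioning is not supported
      if (mappingSize == availableSize)
        return DimAction::NoPredicate; // exact fit, every id is in bounds
      return DimAction::Predicate;     // guard with `id < mappingSize`
    }

    int main() {
      // Mapping forall sizes {8, 4, 1} onto a {16, 4, 1} launch: only the x
      // dimension needs an scf.if guard; the predicates are AND-ed together.
      std::vector<int64_t> mapping = {8, 4, 1}, available = {16, 4, 1};
      assert(classifyDim(mapping[0], available[0]) == DimAction::Predicate);
      assert(classifyDim(mapping[1], available[1]) == DimAction::NoPredicate);
      assert(classifyDim(mapping[2], available[2]) == DimAction::NoPredicate);
      return 0;
    }
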
 
@@ -389,8 +476,8 @@ transform::MapForallToBlocks::applyToOne(Operation *target,
     return diag;
   }
 
-  SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
-  if (!getGenerateGpuLaunch() && gridDim.size() != 3)
+  SmallVector<int64_t> gridDims = extractFromI64ArrayAttr(getGridDim());
+  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform requires size-3 mapping");
 
   OpBuilder::InsertionGuard guard(rewriter);
@@ -415,14 +502,14 @@ transform::MapForallToBlocks::applyToOne(Operation *target,
 
   MappingToGpuBlocksHelper helper(getContext());
   diag = mlir::transform::gpu::mapForallToBlocksImpl(
-      rewriter, topLevelForallOp, helper.idGenerator, gridDim, transformOp,
-      helper.mappingAttributes);
+      rewriter, transformOp, topLevelForallOp, gridDims,
+      helper.mappingAttributes, helper.idGenerator);
   if (!diag.succeeded())
     return diag;
 
   diag = alterGpuLaunch(rewriter, gpuLaunch,
-                        cast<TransformOpInterface>(getOperation()), gridDim[0],
-                        gridDim[1], gridDim[2]);
+                        cast<TransformOpInterface>(getOperation()), gridDims[0],
+                        gridDims[1], gridDims[2]);
 
   results.push_back(gpuLaunch);
   return diag;
@@ -432,147 +519,33 @@ transform::MapForallToBlocks::applyToOne(Operation *target,
 // MapNestedForallToThreads
 //===----------------------------------------------------------------------===//
 
-static DiagnosedSilenceableFailure rewriteOneForallToGpuThreads(
-    RewriterBase &rewriter, scf::ForallOp forallOp,
-    const SmallVectorImpl<int64_t> &kernelBlockDims,
-    const SmallVectorImpl<Value> &threadOps, bool syncAfterDistribute,
-    std::optional<TransformOpInterface> transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
-
-  // Step 0. GPU-specific verifications. There is no better place to anchor
-  // those right now: the ForallOp is target-independent and the transform op
-  // does not apply to individual ForallOp.
-  DiagnosedSilenceableFailure diag = verifyGpuMapping(transformOp, forallOp);
-  if (!diag.succeeded())
-    return diag;
-
-  Location loc = forallOp->getLoc();
-
-  SmallVector<Attribute> mapping =
-      llvm::to_vector(forallOp.getMapping()->getValue());
-
-  // Step 1. Complete the mapping to a full mapping (with 1s) if
-  // necessary.
-  SmallVector<OpFoldResult> numThreads = forallOp.getMixedUpperBound();
-  Attribute one = rewriter.getIndexAttr(1);
-  for (auto attr : mappingAttributes) {
-    if (std::find(mapping.begin(), mapping.end(), attr) == mapping.end()) {
-      mapping.push_back(attr);
-      numThreads.push_back(one);
-    }
-  }
-
-  // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
-  auto comparator = [&](DeviceMappingAttrInterface a,
-                        DeviceMappingAttrInterface b) -> bool {
-    return a.getMappingId() < b.getMappingId();
-  };
-  SmallVector<OpFoldResult> blockDimValues =
-      getValuesSortedByKey(mapping, numThreads, comparator);
-  SmallVector<int64_t> blockDims =
-      llvm::to_vector(llvm::map_range(blockDimValues, [](OpFoldResult ofr) {
-        return getConstantIntValue(ofr).value();
-      }));
-
-  // Step 3. Create the gpu.thread ops and map the induction variables to the
-  // newly created ops.
-  // Replace ids of dimension size 1 by zero to simplify the IR.
-  // TODO
-  SmallVector<Value> threadOpsUpdated(threadOps.begin(), threadOps.end());
-  assert(threadOps.size() == kernelBlockDims.size());
-  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  for (size_t i : llvm::seq(size_t(0), kernelBlockDims.size())) {
-    if (kernelBlockDims[i] == 1)
-      threadOpsUpdated[i] = zero;
-  }
-  IRMapping bvm;
-  for (auto [threadIdx, blockDim] :
-       llvm::zip(forallOp.getInductionVars(), mapping)) {
-    bvm.map(threadIdx,
-            threadOpsUpdated[blockDim.cast<DeviceMappingAttrInterface>()
-                                 .getMappingId()]);
-  }
-
-  // Step 4. Maybe create conditionals to predicate the region.
-  Value predicate;
-  for (auto [threadId, blockDim, globalBlockDim] :
-       llvm::zip(threadOpsUpdated, blockDims, kernelBlockDims)) {
-    if (blockDim > globalBlockDim) {
-      return failureHelper(
-          transformOp, forallOp,
-          "Trying to map to fewer GPU threads than loop iterations but "
-          "overprovisioning is not yet supported. "
-          "Try additional tiling of the before mapping or map to more "
-          "threads.");
-    }
-    if (blockDim == globalBlockDim)
-      continue;
-    Value threadIdx = rewriter.create<arith::ConstantIndexOp>(loc, blockDim);
-    Value tmpPredicate = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ult, threadId, threadIdx);
-    predicate =
-        predicate ? rewriter.create<arith::AndIOp>(loc, predicate, tmpPredicate)
-                  : tmpPredicate;
-  }
-
-  // Step 5. Move the body of forallOp.
-  // Erase the terminator first, it will not be used.
-  rewriter.eraseOp(forallOp.getTerminator());
-  Block *targetBlock;
-  Block::iterator insertionPoint;
-  if (predicate) {
-    // Step 5.a. If predicated, move at the beginning.
-    auto ifOp =
-        rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
-    targetBlock = ifOp.thenBlock();
-    insertionPoint = ifOp.thenBlock()->begin();
-  } else {
-    // Step 5.b. Otherwise, move inline just before forallOp.
-    targetBlock = forallOp->getBlock();
-    insertionPoint = Block::iterator(forallOp);
-  }
-  Block &sourceBlock = forallOp.getRegion().front();
-  targetBlock->getOperations().splice(insertionPoint,
-                                      sourceBlock.getOperations());
-
-  // Step 6. RAUW thread indices to thread ops.
-  for (Value loopIndex : forallOp.getInductionVars()) {
-    Value threadIdx = bvm.lookup(loopIndex);
-    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
-  }
-
-  // Step 7. syncthreads.
-  // TODO: Need warpsync
-  if (syncAfterDistribute)
-    rewriter.create<BarrierOp>(loc);
-
-  // Step 8. Erase old op.
-  rewriter.eraseOp(forallOp);
-
-  return DiagnosedSilenceableFailure::success();
-}
-
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
-    RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim, IdGeneratorFnType idGenerator,
-    bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
+    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
+    Operation *target, const SmallVectorImpl<int64_t> &kernelBlockDims,
+    bool syncAfterDistribute,
+    const ArrayRef<DeviceMappingAttrInterface> &allMappingAttributes,
+    IdGeneratorFnType idGenerator) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForallOp forallOp) {
    // Ignore foralls whose mapping attributes are not in the given set.
     for (Attribute map : forallOp.getMapping()->getValue()) {
-      if (!llvm::is_contained(mappingAttributes, map)) {
+      if (!llvm::is_contained(allMappingAttributes, map)) {
         return WalkResult::skip();
       }
     }
     diag = verifyGpuMapping(transformOp, forallOp);
     if (diag.succeeded()) {
-      rewriter.setInsertionPoint(forallOp);
-      SmallVector<Value> threadOps;
-      idGenerator(rewriter, forallOp, threadOps);
-      diag = rewriteOneForallToGpuThreads(rewriter, forallOp, blockDim,
-                                          threadOps, syncAfterDistribute,
-                                          transformOp, mappingAttributes);
+      // Take the loc ahead of time: forallOp is erased by the rewrite.
+      Location loc = forallOp.getLoc();
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPointAfter(forallOp);
+      if (failed(rewriteOneForallCommonImpl(rewriter, transformOp, forallOp,
+                                            kernelBlockDims,
+                                            allMappingAttributes, idGenerator)))
+        diag = DiagnosedSilenceableFailure::definiteFailure();
+      // Add a syncthreads if needed. TODO: warpsync
+      if (syncAfterDistribute)
+        rewriter.create<BarrierOp>(loc);
     }
     return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
   });
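
The walk above combines three policies: skip foralls carrying mapping attributes outside the given set, interrupt on the first failed rewrite, and insert the barrier after the inlined body (hence setInsertionPointAfter and capturing loc before the forall is erased). A minimal standalone model of the skip/interrupt decision, with hypothetical names:

    enum class Walk { Advance, Skip, Interrupt };

    // hasForeignMapping: the forall uses a mapping attribute outside
    // allMappingAttributes, e.g. a block-mapped loop seen during thread mapping.
    Walk visitForall(bool hasForeignMapping, bool rewriteSucceeded) {
      if (hasForeignMapping)
        return Walk::Skip; // leave it for another transform
      return rewriteSucceeded ? Walk::Advance : Walk::Interrupt;
    }
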
@@ -588,13 +561,13 @@ DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
   if (!gpuLaunch)
     return emitSilenceableError() << "Given target is not a gpu.launch";
 
-  SmallVector<int64_t> blockDim = extractFromI64ArrayAttr(getBlockDim());
-  if (blockDim.size() != 3)
+  SmallVector<int64_t> blockDims = extractFromI64ArrayAttr(getBlockDim());
+  if (blockDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform requires size-3 mapping");
 
   DiagnosedSilenceableFailure diag =
       checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
-                     blockDim[0], blockDim[1], blockDim[2]);
+                     blockDims[0], blockDims[1], blockDims[2]);
   if (diag.isSilenceableFailure()) {
     diag.attachNote(getLoc()) << getBlockDimAttrName() << " is too large";
     return diag;
@@ -602,18 +575,17 @@ DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
 
   MLIRContext *ctx = getContext();
   IRRewriter rewriter(ctx);
-  rewriter.setInsertionPoint(target);
   MappingToGpuThreadsHelper helper(ctx);
   diag = mlir::transform::gpu::mapNestedForallToThreadsImpl(
-      rewriter, target, blockDim, helper.idGenerator, getSyncAfterDistribute(),
-      transformOp, helper.mappingAttributes);
+      rewriter, transformOp, target, blockDims, getSyncAfterDistribute(),
+      helper.mappingAttributes, helper.idGenerator);
 
   if (!diag.succeeded())
     return diag;
 
   diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
-                        std::nullopt, std::nullopt, blockDim[0], blockDim[1],
-                        blockDim[2]);
+                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
+                        blockDims[2]);
 
   results.push_back(gpuLaunch.getOperation());
   return diag;
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 907a8c1..e646de9 100644
@@ -222,4 +222,10 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
   return getValuesSortedByKeyImpl(keys, values, compare);
 }
 
+SmallVector<int64_t>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
+                     llvm::function_ref<bool(Attribute, Attribute)> compare) {
+  return getValuesSortedByKeyImpl(keys, values, compare);
+}
+
 } // namespace mlir