[mlir][Transform]Significantly cleanup scf.foreach_thread and GPU transform permutati...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Sun, 13 Nov 2022 13:28:32 +0000 (05:28 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 14 Nov 2022 17:19:49 +0000 (09:19 -0800)
Previously, the need for a dense permutation leaked into the thread_dim_mapping specification.
This revision allows to use a sparse specification of the thread_dim_mapping and the proper completion / sorting is applied automatically.

In the process, the sematics of scf.foreach_thread is tightened to require a matching number of thread dimensions and mappings.
The relevant negative test is added.

Differential Revision: https://reviews.llvm.org/D137906

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/GPU/transform-gpu-failing.mlir
mlir/test/Dialect/GPU/transform-gpu.mlir
mlir/test/Dialect/SCF/invalid.mlir

index 3fa890b..67a8e43 100644 (file)
@@ -536,15 +536,14 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
       return getBody()->getArguments().drop_front(getRank());
     }
 
-    /// Return the thread indices in the order specified by the
-    /// given mapping argument. Return failure is
-    /// mapping is not a valid permutation.
-    FailureOr<SmallVector<Value>> getPermutedThreadIndices(ArrayRef<int64_t> mapping);
-
-    /// Return the number of threads in the order specified by the
-    /// given mapping argument.
-    /// Return failure is mapping is not a valid permutation.
-    FailureOr<SmallVector<OpFoldResult>> getPermutedNumThreads(OpBuilder &b, ArrayRef<int64_t> mapping);
+    /// Helper to sort `values` according to matching `keys`.
+    /// Take a custom `compare` binary comparator which returns true if the first
+    /// element is smaller than the second (i.e. compatible with std::sort).
+    /// This is a helper typically used to sort numThreads values before they are
+    /// mapped to concrete physical dimensions of hardware.
+    static SmallVector<Value> getValuesSortedByKey(
+      ArrayRef<Attribute> keys, ValueRange values,
+      llvm::function_ref<bool(Attribute, Attribute)> compare);
 
     // The ensureTerminator method generated by SingleBlockImplicitTerminator is
     // unaware of the fact that our terminator also needs a region to be
index 460eb2f..5fb90ac 100644 (file)
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/None.h"
@@ -157,45 +158,75 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
                       SmallVectorImpl<Value> &)>
         blockIdGenerator,
     SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp) {
+  // Step 0. Target-specific verifications. There is no good place to anchor
+  // those right now: the ForeachThreadOp is target-independent and the
+  // transform op does not apply to individual ForeachThreadOp.
+  MLIRContext *ctx = foreachThreadOp->getContext();
+  Location loc = foreachThreadOp->getLoc();
+  Attribute bX = GPUBlockMappingAttr::get(ctx, Blocks::DimX);
+  Attribute bY = GPUBlockMappingAttr::get(ctx, Blocks::DimY);
+  Attribute bZ = GPUBlockMappingAttr::get(ctx, Blocks::DimZ);
   if (foreachThreadOp.getNumResults() > 0)
     return transformOp.emitSilenceableError()
-           << "only bufferized scf.foreach_thread lowers to gpu.block_id";
+           << "only bufferized scf.foreach_thread lowers to "
+              "gpu.block_id";
   if (foreachThreadOp.getNumThreads().size() > 3)
     return transformOp.emitSilenceableError()
-           << "scf.foreach_thread with rank > 3 does not lower to gpu.block_id";
-
-  // Step 0. Outline the compute workload region and set up the workload
-  // operands.
-  SmallVector<int64_t> mapping;
+           << "scf.foreach_thread with rank > 3 does not lower to "
+              "gpu.block_id";
+  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
+        return !v.getDefiningOp<arith::ConstantIndexOp>();
+      })) {
+    return transformOp.emitSilenceableError()
+           << "unsupported dynamic griddim size";
+  }
   if (!foreachThreadOp.getMapping().has_value())
     return transformOp.emitSilenceableError() << "mapping must be present";
-  for (DeviceMappingAttrInterface map :
-       foreachThreadOp.getMapping()->getValue()) {
-    if (auto blockMap = map.dyn_cast<GPUBlockMappingAttr>()) {
-      mapping.push_back((int64_t)blockMap.getBlock());
-    } else {
-      return transformOp.emitSilenceableError()
-             << "mapping must be #gpu.block<x/y/z/>";
-    }
+  SmallVector<Attribute> blockMapping =
+      llvm::to_vector(foreachThreadOp.getMapping()->getValue());
+  if (llvm::any_of(blockMapping, [](DeviceMappingAttrInterface map) {
+        return !map.isa<GPUBlockMappingAttr>();
+      })) {
+    return transformOp.emitSilenceableError()
+           << "mapping must be #gpu.block<x/y/z/>";
   }
 
-  FailureOr<SmallVector<OpFoldResult>> potentialGridDim =
-      foreachThreadOp.getPermutedNumThreads(rewriter, mapping);
-
-  if (failed(potentialGridDim) ||
-      llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) {
-        return !getConstantIntValue(ofr).has_value();
-      })) {
-    return transformOp.emitSilenceableError() << "unsupported dynamic gridDim";
+  // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
+  SmallVector<Value> numBlocks =
+      llvm::to_vector(foreachThreadOp.getNumThreads());
+  // Ensure we have 3 block sizes, one for each id.
+  Value one;
+  for (auto attr : {bX, bY, bZ}) {
+    if (std::find(blockMapping.begin(), blockMapping.end(), attr) ==
+        blockMapping.end()) {
+      blockMapping.push_back(attr);
+      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      numBlocks.push_back(one);
+    }
   }
 
-  for (OpFoldResult ofr : *potentialGridDim)
-    gridDims.push_back(getConstantIntValue(ofr).value());
+  // Step 2. sort the values by the corresponding GPUBlockMappingAttr.
+  auto comparator = [](Attribute a, Attribute b) -> bool {
+    return static_cast<int64_t>(a.cast<GPUBlockMappingAttr>().getBlock()) <
+           static_cast<int64_t>(b.cast<GPUBlockMappingAttr>().getBlock());
+  };
+  SmallVector<Value> gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey(
+      blockMapping, numBlocks, comparator);
+  for (Value v : gridDimValues)
+    gridDims.push_back(v.getDefiningOp<arith::ConstantIndexOp>().value());
 
+  // Step 3. Generate the blockIds using the provided generator and map the
+  // induction variables to the newly created ops.
   SmallVector<Value> blockOps;
   blockIdGenerator(rewriter, foreachThreadOp, blockOps);
+  BlockAndValueMapping bvm;
+  for (auto [blockIdx, blockDim] :
+       llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
+    bvm.map(blockIdx, blockOps[static_cast<int64_t>(
+                          blockDim.cast<GPUBlockMappingAttr>().getBlock())]);
+  }
 
-  // Step 1. Move the body of foreachThreadOp.
+  // Step 4. Move the body of foreachThreadOp.
   // Erase the terminator first, it will not be used since we are on buffers.
   rewriter.eraseOp(foreachThreadOp.getTerminator());
   Block *targetBlock = foreachThreadOp->getBlock();
@@ -204,20 +235,16 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
   targetBlock->getOperations().splice(insertionPoint,
                                       sourceBlock.getOperations());
 
-  // Step 2. RAUW thread indices to thread ops.
-  SmallVector<Value> threadIndices =
-      *foreachThreadOp.getPermutedThreadIndices(mapping);
-  assert(blockOps.size() == 3 && "3 block id ops are required");
-  for (auto [blockIdx, blockOp] : llvm::zip(threadIndices, blockOps)) {
-    Value val = blockIdx;
-    Value blkOp = blockOp;
-    if (!val)
-      continue;
-    for (Operation *user : llvm::make_early_inc_range(val.getUsers()))
-      user->replaceUsesOfWith(val, blkOp);
+  // Step 5. RAUW thread indices to thread ops.
+  for (Value blockIdx : foreachThreadOp.getThreadIndices()) {
+    for (Operation *user : llvm::make_early_inc_range(blockIdx.getUsers())) {
+      rewriter.updateRootInPlace(user, [&]() {
+        user->replaceUsesOfWith(blockIdx, bvm.lookup(blockIdx));
+      });
+    }
   }
 
-  // Step 3. Erase old op.
+  // Step 6. Erase old op.
   rewriter.eraseOp(foreachThreadOp);
 
   return DiagnosedSilenceableFailure::success();
@@ -252,11 +279,10 @@ static void generateGpuBlockIds(RewriterBase &rewriter,
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(foreachOp);
   IndexType indexType = rewriter.getIndexType();
-  SmallVector<Dimension> gpuDims{Dimension::x, Dimension::y, Dimension::z};
-  for (int64_t idx : llvm::seq<int64_t>(0, gpuDims.size())) {
-    blockOps.push_back(
-        rewriter.create<BlockIdOp>(loc, indexType, gpuDims[idx]));
-  }
+  blockOps = SmallVector<Value>{
+      rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
+      rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
+      rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)};
 }
 
 DiagnosedSilenceableFailure
@@ -333,6 +359,9 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
     const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
     llvm::Optional<TransformOpInterface> transformOp) {
+  // Step 0. Target-specific verifications. There is no good place to anchor
+  // those right now: the ForeachThreadOp is target-independent and the
+  // transform op does not apply to individual ForeachThreadOp.
   auto failureHelper =
       [&](const Twine &message) -> DiagnosedSilenceableFailure {
     if (transformOp.has_value()) {
@@ -340,54 +369,79 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     }
     return emitDefiniteFailure(foreachThreadOp, message);
   };
-
+  MLIRContext *ctx = foreachThreadOp->getContext();
+  Location loc = foreachThreadOp->getLoc();
+  Attribute tX = GPUThreadMappingAttr::get(ctx, Threads::DimX);
+  Attribute tY = GPUThreadMappingAttr::get(ctx, Threads::DimY);
+  Attribute tZ = GPUThreadMappingAttr::get(ctx, Threads::DimZ);
   if (foreachThreadOp.getNumResults() > 0)
     return failureHelper(
         "only bufferized scf.foreach_thread lowers to gpu.thread_id");
-
   if (foreachThreadOp.getNumThreads().size() > 3)
     return failureHelper(
         "scf.foreach_thread with rank > 3 does not lower to gpu.thread_id");
-
-  SmallVector<int64_t> mapping;
+  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
+        return !v.getDefiningOp<arith::ConstantIndexOp>();
+      })) {
+    return failureHelper("unsupported dynamic blockdim size");
+  }
   if (!foreachThreadOp.getMapping().has_value())
     return failureHelper("mapping must be present");
-  for (DeviceMappingAttrInterface map :
-       foreachThreadOp.getMapping()->getValue()) {
-    if (auto threadMap = map.dyn_cast<GPUThreadMappingAttr>()) {
-      mapping.push_back((int64_t)threadMap.getThread());
-    } else {
-      return failureHelper("mapping must be #gpu.thread<x/y/z/>");
-    }
-  }
-  FailureOr<SmallVector<OpFoldResult>> potentialBlockDim =
-      foreachThreadOp.getPermutedNumThreads(rewriter, mapping);
-  if (failed(potentialBlockDim) ||
-      llvm::any_of(*potentialBlockDim, [](OpFoldResult ofr) {
-        return !getConstantIntValue(ofr).has_value();
+  SmallVector<Attribute> threadMapping =
+      llvm::to_vector(foreachThreadOp.getMapping()->getValue());
+  if (llvm::any_of(threadMapping, [](DeviceMappingAttrInterface map) {
+        return !map.isa<GPUThreadMappingAttr>();
       })) {
-    return failureHelper("unsupported dynamic blockdim size");
+    return transformOp->emitSilenceableError()
+           << "mapping must be #gpu.thread<x/y/z/>";
   }
 
-  SmallVector<int64_t> blockDim =
-      llvm::to_vector(llvm::map_range(*potentialBlockDim, [](OpFoldResult ofr) {
-        return getConstantIntValue(ofr).value();
+  // Step 1. Complete the threadMapping to a full mapping (with 1s) if
+  // necessary.
+  SmallVector<Value> numThreads =
+      llvm::to_vector(foreachThreadOp.getNumThreads());
+  // Ensure we have 3 block sizes, one for each id.
+  Value one;
+  for (auto attr : {tX, tY, tZ}) {
+    if (std::find(threadMapping.begin(), threadMapping.end(), attr) ==
+        threadMapping.end()) {
+      threadMapping.push_back(attr);
+      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      numThreads.push_back(one);
+    }
+  }
+
+  // Step 2. sort the values by the corresponding GPUThreadMappingAttr.
+  auto comparator = [](Attribute a, Attribute b) -> bool {
+    return static_cast<int64_t>(a.cast<GPUThreadMappingAttr>().getThread()) <
+           static_cast<int64_t>(b.cast<GPUThreadMappingAttr>().getThread());
+  };
+  SmallVector<Value> blockDimValues =
+      scf::ForeachThreadOp::getValuesSortedByKey(threadMapping, numThreads,
+                                                 comparator);
+  SmallVector<int64_t> blockDims =
+      llvm::to_vector(llvm::map_range(blockDimValues, [](Value v) {
+        return v.getDefiningOp<arith::ConstantIndexOp>().value();
       }));
 
-  // Step 1. Create the gpu.thread ops
-  Location loc = foreachThreadOp.getLoc();
+  // Step 3. Create the gpu.thread ops and map the induction variables to the
+  // newly created ops.
   IndexType indexType = rewriter.getIndexType();
-
-  SmallVector<Dimension> gpuDims{Dimension::x, Dimension::y, Dimension::z};
-  SmallVector<Value> threadOps;
-  for (int64_t idx : llvm::seq<int64_t>(0, blockDim.size())) {
-    threadOps.push_back(
-        rewriter.create<ThreadIdOp>(loc, indexType, gpuDims[idx]));
+  SmallVector<Value> threadOps{
+      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
+      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
+      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
+  BlockAndValueMapping bvm;
+  for (auto [blockIdx, blockDim] :
+       llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
+    bvm.map(blockIdx, threadOps[static_cast<int64_t>(
+                          blockDim.cast<GPUThreadMappingAttr>().getThread())]);
   }
-  // Step 2. Maybe create conditionals to predicate the region.
+
+  // Step 4. Maybe create conditionals to predicate the region.
   Value predicate;
   for (auto [threadId, blockDim, globalBlockDim] :
-       llvm::zip(threadOps, blockDim, globalBlockDims)) {
+       llvm::zip(threadOps, blockDims, globalBlockDims)) {
     if (blockDim > globalBlockDim) {
       return failureHelper(
           "The requested GPU threads are fewer than the number of loop trip "
@@ -404,19 +458,19 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
                   : tmpPredicate;
   }
 
-  // Step 3. Move the body of foreachThreadOp.
+  // Step 5. Move the body of foreachThreadOp.
   // Erase the terminator first, it will not be used.
   rewriter.eraseOp(foreachThreadOp.getTerminator());
   Block *targetBlock;
   Block::iterator insertionPoint;
   if (predicate) {
-    // Step 3.a. If predicated, move at the beginning.
+    // Step 5.a. If predicated, move at the beginning.
     auto ifOp =
         rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
     targetBlock = ifOp.thenBlock();
     insertionPoint = ifOp.thenBlock()->begin();
   } else {
-    // Step 3.a. Otherwise, move inline just before foreachThreadOp.
+    // Step 5.b. Otherwise, move inline just before foreachThreadOp.
     targetBlock = foreachThreadOp->getBlock();
     insertionPoint = Block::iterator(foreachThreadOp);
   }
@@ -424,25 +478,21 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
   targetBlock->getOperations().splice(insertionPoint,
                                       sourceBlock.getOperations());
 
-  // Step 4. RAUW thread indices to thread ops.
-  SmallVector<Value> threadIndices =
-      *foreachThreadOp.getPermutedThreadIndices(mapping);
-  for (auto [threadIdx, threadOp] : llvm::zip(threadIndices, threadOps)) {
-    Value val = threadIdx;
-    Value op = threadOp;
-    if (!val)
-      continue;
-    for (Operation *user : llvm::make_early_inc_range(val.getUsers())) {
-      user->replaceUsesOfWith(val, op);
+  // Step 6. RAUW thread indices to thread ops.
+  for (Value threadIdx : foreachThreadOp.getThreadIndices()) {
+    for (Operation *user : llvm::make_early_inc_range(threadIdx.getUsers())) {
+      rewriter.updateRootInPlace(user, [&]() {
+        user->replaceUsesOfWith(threadIdx, bvm.lookup(threadIdx));
+      });
     }
   }
 
-  // Step 5. syncthreads.
+  // Step 7. syncthreads.
   // TODO: Need warpsync
   if (syncAfterDistribute)
     rewriter.create<BarrierOp>(loc);
 
-  // Step 6. Erase old op.
+  // Step 8. Erase old op.
   rewriter.eraseOp(foreachThreadOp);
 
   return DiagnosedSilenceableFailure::success();
index 6a85fb6..b39edc0 100644 (file)
@@ -1114,12 +1114,15 @@ LogicalResult ForeachThreadOp::verify() {
     if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType())
       return emitOpError("type mismatch between ")
              << i << "-th output and corresponding block argument";
-   if (getMapping().has_value())
+  if (getMapping().has_value() && !getMapping()->empty()) {
+    if (static_cast<int64_t>(getMapping()->size()) != getRank())
+      return emitOpError() << "mapping attribute size must match op rank";
     for (auto map : getMapping()->getValue()) {
       if (!isa<DeviceMappingAttrInterface>(map))
         return emitOpError()
                << getMappingAttrName() << " is not device mapping attribute";
     }
+  }
 
   return success();
 }
@@ -1294,59 +1297,21 @@ PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
   return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
 }
 
-template <typename T>
-static FailureOr<SmallVector<T>> permute(const SmallVector<T> &vals,
-                                         ArrayRef<int64_t> perm) {
-  if (vals.size() != perm.size())
-    return failure();
-  SmallVector<T> result(vals.size());
-  SmallVector<bool> seen(vals.size());
-  for (auto [idx, val] : llvm::zip(perm, vals)) {
-    // Already seen, invalid mapping.
-    if (seen[idx])
-      return failure();
-    result[idx] = val;
-    seen[idx] = true;
-  }
-  // Some not seen, invalid mapping.
-  if (!llvm::all_of(seen, [](bool b) { return b; }))
-    return failure();
-  return result;
-}
-
-/// Helper to get apply the `mapping` permutation of a
-/// `foreachThreadOp` to `values`.
-template <typename T>
-static FailureOr<SmallVector<T>>
-getValuesPermutedByThreadMapping(scf::ForeachThreadOp foreachThreadOp,
-                                 const SmallVector<T> &values,
-                                 ArrayRef<int64_t> mapping) {
-  // Apply mapping permutation if specified.
-  FailureOr<SmallVector<T>> maybePermuted = permute(values, mapping);
-  if (failed(maybePermuted))
-    return foreachThreadOp->emitError("invalid permutation");
-  return *maybePermuted;
-  return values;
-}
-
-/// Return the thread indices in the order specified by the mapping
-/// attribute. Return failure is mapping is not a valid permutation.
-FailureOr<SmallVector<Value>>
-ForeachThreadOp::getPermutedThreadIndices(ArrayRef<int64_t> mapping) {
-  SmallVector<Value> threadCountValues = this->getThreadIndices();
-  threadCountValues.resize(3, Value());
-  return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping);
-}
-
-/// Return the number of threads in the order specified by the
-/// mapping attribute.
-/// Return failure is mapping is not a valid permutation.
-FailureOr<SmallVector<OpFoldResult>>
-ForeachThreadOp::getPermutedNumThreads(OpBuilder &b,
-                                       ArrayRef<int64_t> mapping) {
-  SmallVector<OpFoldResult> threadCountValues = this->getNumThreads();
-  threadCountValues.resize(3, b.getIndexAttr(1));
-  return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping);
+/// Helper to sort `values` according to matching `keys`.
+SmallVector<Value> ForeachThreadOp::getValuesSortedByKey(
+    ArrayRef<Attribute> keys, ValueRange values,
+    llvm::function_ref<bool(Attribute, Attribute)> compare) {
+  if (keys.empty())
+    return values;
+  assert(keys.size() == values.size() && "unexpected mismatching sizes");
+  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
+  std::sort(indices.begin(), indices.end(),
+            [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
+  SmallVector<Value> res;
+  res.reserve(values.size());
+  for (int64_t i = 0, e = indices.size(); i < e; ++i)
+    res.push_back(values[indices[i]]);
+  return res;
 }
 
 ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
index de50c59..c45d9c0 100644 (file)
@@ -67,7 +67,7 @@ func.func @map_nested_foreach_to_threads_fewer_threads(%x: memref<2 x 32 x f32>,
         %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
 
@@ -79,7 +79,7 @@ func.func @map_nested_foreach_to_threads_fewer_threads(%x: memref<2 x 32 x f32>,
         %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
 
@@ -106,7 +106,7 @@ func.func @map_nested_foreach_to_threads_dynamic_trip_count(%x: memref<2 x 32 x
         %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
   return %y : memref<2 x 32 x f32>
@@ -131,7 +131,7 @@ func.func @map_nested_foreach_to_threads_4d_loop(%x: memref<2x32x32x32xf32>, %y:
     scf.foreach_thread (%i, %j, %k, %l) in (%c2, %c32,%c32,%c32) {
         %4 = memref.load %x[%i, %j, %k, %l] : memref<2x32x32x32xf32>        
         memref.store %4, %y[%i, %j, %k, %l] : memref<2x32x32x32xf32>
-     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>, #gpu.thread<z>] }
     gpu.terminator
   }
   return %y : memref<2x32x32x32xf32>
@@ -197,14 +197,14 @@ func.func @map_foreach_to_blocks_not_unique(%x: memref<2 x 32 x f32>, %y: memref
         %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
 
     scf.foreach_thread (%i, %j) in (%c7, %c9) {
         %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32>
         %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
 
@@ -232,14 +232,14 @@ func.func @map_foreach_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-  }  { mapping = [#gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>] }
+  }  { mapping = [#gpu.thread<x>, #gpu.thread<y>] }
 
   scf.foreach_thread (%i, %j) in (%c7, %c9) {
       %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32>
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-  }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+  }  { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
   
   return %y : memref<2 x 32 x f32>
 }
@@ -261,7 +261,7 @@ func.func @map_foreach_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-  }  { mapping = [#gpu.block<x>, #gpu.block<y>, #gpu.block<z>] }
+  }  { mapping = [#gpu.block<x>, #gpu.block<y>] }
   return %y : memref<2 x 32 x f32>
 }
 
index eb7208b..a5e1303 100644 (file)
@@ -24,7 +24,7 @@ func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream
         %5 = memref.load %y[%i, %j] : !type
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : !type
-     }  { mapping = [#gpu.block<x>, #gpu.block<y>, #gpu.block<z>]}
+     }  { mapping = [#gpu.block<x>, #gpu.block<y>]}
     gpu.terminator
   }
   return %y : !type
@@ -73,12 +73,12 @@ func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g
         %5 = memref.load %y[%i, %j] : !type
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : !type
-     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>]}
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>]}
      scf.foreach_thread (%i) in (%c12) {
         %7 = memref.load %t[%i] : !type1d
         %8 = arith.addf %alpha, %7 : f32
         memref.store %8, %t[%i] : !type1d
-     }  {mapping = [#gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>] }
+     }  {mapping = [#gpu.thread<x>] }
     gpu.terminator
   }
   return %y : !type
@@ -118,8 +118,8 @@ func.func @saxpy4d(%x: !type4d, %y: !type4d, %alpha : f32) -> !type4d {
       %5 = memref.load %y[%i, %j, %k, %l] : !type4d
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j, %k, %l] : !type4d
-    }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
-  }  { mapping = [#gpu.block<x>, #gpu.block<y>, #gpu.block<z>] }
+    }  { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
+  }  { mapping = [#gpu.block<x>, #gpu.block<y>] }
   return %y : !type4d
 }
 
@@ -151,7 +151,7 @@ func.func @saxpy2d_no_barrier(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %
         %5 = memref.load %y[%i, %j] : !type
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : !type
-     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
   return %y : !type
index fa91ba0..5fbb0f7 100644 (file)
@@ -575,6 +575,21 @@ func.func @wrong_terminator_op(%in: tensor<100xf32>, %out: tensor<100xf32>) {
 
 // -----
 
+func.func @mismatched_mapping(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> {
+  %one = arith.constant 1 : index
+  %c65535 = arith.constant 65535 : index
+  // expected-error @below {{'scf.foreach_thread' op mapping attribute size must match op rank}}
+  scf.foreach_thread (%i, %j) in (%c65535, %c65535) {
+      %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32>
+      %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
+      %6 = math.fma %alpha, %4, %5 : f32
+      memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
+  }  { mapping = [#gpu.block<x>, #gpu.block<y>, #gpu.block<z>] }
+  return %y : memref<2 x 32 x f32>
+}
+
+// -----
+
 func.func @switch_wrong_case_count(%arg0: index) {
   // expected-error @below {{'scf.index_switch' op has 0 case regions but 1 case values}}
   "scf.index_switch"(%arg0) ({