[mlir][gpu] Change ParalellLoopMappingAttr to AttrDef
authorMogball <jeffniu22@gmail.com>
Thu, 9 Jun 2022 21:33:41 +0000 (21:33 +0000)
committerMogball <jeffniu22@gmail.com>
Thu, 9 Jun 2022 22:23:21 +0000 (22:23 +0000)
It was a StructAttr. Also adds a FieldParser for AffineMap.

Depends on D127348

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D127350

12 files changed:
mlir/include/mlir/Dialect/GPU/GPUBase.td
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h
mlir/include/mlir/Dialect/GPU/ParallelLoopMapperAttr.td
mlir/include/mlir/Dialect/GPU/Passes.td
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/DialectImplementation.h
mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp
mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
mlir/test/Dialect/GPU/mapping.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index d19050b..6b307a9 100644 (file)
@@ -13,6 +13,7 @@
 #ifndef GPU_BASE
 #define GPU_BASE
 
+include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/OpBase.td"
 
 //===----------------------------------------------------------------------===//
@@ -117,4 +118,13 @@ def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// GPU Attributes.
+//===----------------------------------------------------------------------===//
+
+class GPU_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<GPU_Dialect, attrName, traits> {
+  let mnemonic = attrMnemonic;
+}
+
 #endif // GPU_BASE
index 10f9dbd..e1e818f 100644 (file)
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/DLTI/DLTIBase.td"
 include "mlir/Dialect/GPU/GPUBase.td"
+include "mlir/Dialect/GPU/ParallelLoopMapperAttr.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/FunctionInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
index 9ae3683..40798ea 100644 (file)
 #ifndef MLIR_DIALECT_GPU_PARALLELLOOPMAPPER_H
 #define MLIR_DIALECT_GPU_PARALLELLOOPMAPPER_H
 
-#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Support/LLVM.h"
-#include "llvm/ADT/DenseMap.h"
-
-#include "mlir/Dialect/GPU/ParallelLoopMapperEnums.h.inc"
+#include "llvm/ADT/StringRef.h"
 
 namespace mlir {
 
@@ -29,8 +27,6 @@ class Region;
 
 } // namespace mlir
 
-#include "mlir/Dialect/GPU/ParallelLoopMapperAttr.h.inc"
-
 namespace mlir {
 namespace scf {
 class ParallelOp;
@@ -41,24 +37,13 @@ namespace gpu {
 /// Name of the mapping attribute produced by loop mappers.
 StringRef getMappingAttrName();
 
-/// Get the value of the processor in the ParallelLoopDimMapping attribute.
-inline Processor getProcessor(ParallelLoopDimMapping attr) {
-  return static_cast<Processor>(attr.processor().getInt());
-}
-
-/// Helper function to create a ParallelDimMapperAttr.
-/// TODO: Replace its uses with an auto-gened method.
-ParallelLoopDimMapping getParallelLoopDimMappingAttr(Processor processor,
-                                                     AffineMap map,
-                                                     AffineMap bound);
-
 /// Sets the mapping attribute of a scf.parallel operation. Verifies that the
 /// mapping passed is valid.
 /// - the number of DimMapperAttr provided is same as the number of loops of
 ///   the `ploopOp`.
 /// - the mapping does not map multiple loops to the same processor.
 LogicalResult setMappingAttr(scf::ParallelOp ploopOp,
-                             ArrayRef<ParallelLoopDimMapping> mapping);
+                             ArrayRef<ParallelLoopDimMappingAttr> mapping);
 } // namespace gpu
 } // namespace mlir
 #endif // MLIR_DIALECT_GPU_PARALLELLOOPMAPPER_H
index 52ef8b5..9f16365 100644 (file)
 include "mlir/Dialect/GPU/GPUBase.td"
 include "mlir/IR/EnumAttr.td"
 
-def BlockX : I64EnumAttrCase<"BlockX", 0>;
-def BlockY : I64EnumAttrCase<"BlockY", 1>;
-def BlockZ : I64EnumAttrCase<"BlockZ", 2>;
-def ThreadX : I64EnumAttrCase<"ThreadX", 3>;
-def ThreadY : I64EnumAttrCase<"ThreadY", 4>;
-def ThreadZ : I64EnumAttrCase<"ThreadZ", 5>;
-def Sequential : I64EnumAttrCase<"Sequential", 6>;
-
-def ProcessorAttr : I64EnumAttr<"Processor", "processor for loop mapping", [
+def BlockX : I64EnumAttrCase<"BlockX", 0, "block_x">;
+def BlockY : I64EnumAttrCase<"BlockY", 1, "block_y">;
+def BlockZ : I64EnumAttrCase<"BlockZ", 2, "block_z">;
+def ThreadX : I64EnumAttrCase<"ThreadX", 3, "thread_x">;
+def ThreadY : I64EnumAttrCase<"ThreadY", 4, "thread_y">;
+def ThreadZ : I64EnumAttrCase<"ThreadZ", 5, "thread_z">;
+def Sequential : I64EnumAttrCase<"Sequential", 6, "sequential">;
+
+def ProcessorEnum : I64EnumAttr<"Processor", "processor for loop mapping", [
     BlockX, BlockY, BlockZ, ThreadX, ThreadY, ThreadZ, Sequential]> {
   let cppNamespace = "::mlir::gpu";
 }
@@ -37,12 +37,15 @@ def ProcessorAttr : I64EnumAttr<"Processor", "processor for loop mapping", [
 //       substitution.
 // bound : An affine map that is used to compute the bound of the hardware
 //         id based on an upper bound of the number of iterations.
-def ParallelLoopDimMappingAttr :
-    StructAttr<"ParallelLoopDimMapping", GPU_Dialect,
-               [StructFieldAttr<"processor", ProcessorAttr>,
-                StructFieldAttr<"map", AffineMapAttr>,
-                StructFieldAttr<"bound", AffineMapAttr>]>;
-
+def ParallelLoopDimMappingAttr 
+    : GPU_Attr<"ParallelLoopDimMapping", "loop_dim_map"> {
+  let parameters = (ins
+    EnumParameter<ProcessorEnum>:$processor,
+    "AffineMap":$map,
+    "AffineMap":$bound
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
 
 def ParallelLoopMappingAttr :
     TypedArrayAttrBase<ParallelLoopDimMappingAttr,
index f5786e8..a144fa4 100644 (file)
@@ -34,6 +34,7 @@ def GpuMapParallelLoopsPass
   let summary = "Greedily maps loops to GPU hardware dimensions.";
   let constructor = "mlir::createGpuMapParallelLoopsPass()";
   let description = "Greedily maps loops to GPU hardware dimensions.";
+  let dependentDialects = ["mlir::gpu::GPUDialect"];
 }
 
 #endif // MLIR_DIALECT_GPU_PASSES
index 205a062..4e214af 100644 (file)
@@ -517,7 +517,8 @@ public:
   Operation *cloneWithoutRegions(Operation &op) {
     return insert(op.cloneWithoutRegions());
   }
-  template <typename OpT> OpT cloneWithoutRegions(OpT op) {
+  template <typename OpT>
+  OpT cloneWithoutRegions(OpT op) {
     return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
   }
 
index bbe4da6..e9ec9d2 100644 (file)
@@ -127,6 +127,17 @@ struct FieldParser<
   }
 };
 
+/// Parse an affine map.
+template <>
+struct FieldParser<AffineMap> {
+  static FailureOr<AffineMap> parse(AsmParser &parser) {
+    AffineMap map;
+    if (failed(parser.parseAffineMap(map)))
+      return failure();
+    return map;
+  }
+};
+
 } // namespace mlir
 
 #endif // MLIR_IR_DIALECTIMPLEMENTATION_H
index 5dc6c2d..901810e 100644 (file)
@@ -429,12 +429,13 @@ static LogicalResult processParallelLoop(
     Attribute mappingAttribute;
     Value iv, lowerBound, upperBound, step;
     std::tie(mappingAttribute, iv, lowerBound, upperBound, step) = config;
-    auto annotation = mappingAttribute.dyn_cast<gpu::ParallelLoopDimMapping>();
+    auto annotation =
+        mappingAttribute.dyn_cast<gpu::ParallelLoopDimMappingAttr>();
     if (!annotation)
       return parallelOp.emitOpError()
              << "expected mapping attribute for lowering to GPU";
     Value newIndex;
-    gpu::Processor processor = gpu::getProcessor(annotation);
+    gpu::Processor processor = annotation.getProcessor();
 
     if (isMappedToProcessor(processor)) {
       // Use the corresponding thread/grid index as replacement for the loop iv.
@@ -449,11 +450,11 @@ static LogicalResult processParallelLoop(
           rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) +
               rewriter.getAffineSymbolExpr(1));
       newIndex = rewriter.create<AffineApplyOp>(
-          loc, annotation.map().getValue().compose(lowerAndStep),
+          loc, annotation.getMap().compose(lowerAndStep),
           ValueRange{operand, step, lowerBound});
       // If there was also a bound, insert that, too.
       // TODO: Check that we do not assign bounds twice.
-      if (annotation.bound().getValue()) {
+      if (annotation.getBound()) {
         // We pass as the single operand to the bound-map the number of
         // iterations, which is (upperBound - lowerBound) ceilDiv step. To
         // support inner loops with dynamic upper bounds (as generated by e.g.
@@ -493,7 +494,7 @@ static LogicalResult processParallelLoop(
               ((rewriter.getAffineDimExpr(0) - rewriter.getAffineSymbolExpr(0))
                    .ceilDiv(rewriter.getAffineSymbolExpr(1))));
           Value launchBound = rewriter.create<AffineApplyOp>(
-              loc, annotation.bound().getValue().compose(stepMap),
+              loc, annotation.getBound().compose(stepMap),
               ValueRange{
                   ensureLaunchIndependent(
                       cloningMap.lookupOrDefault(upperBound)),
index c7a1ef3..d2e8ed6 100644 (file)
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/AffineMap.h"
-#include "mlir/Pass/Pass.h"
-
-#include "mlir/Dialect/GPU/ParallelLoopMapperAttr.cpp.inc"
-#include "mlir/Dialect/GPU/ParallelLoopMapperEnums.cpp.inc"
 
 namespace mlir {
 
@@ -29,22 +25,13 @@ using scf::ParallelOp;
 
 StringRef gpu::getMappingAttrName() { return "mapping"; }
 
-gpu::ParallelLoopDimMapping
-gpu::getParallelLoopDimMappingAttr(Processor processor, AffineMap map,
-                                   AffineMap bound) {
-  MLIRContext *context = map.getContext();
-  OpBuilder builder(context);
-  return ParallelLoopDimMapping::get(
-      ProcessorAttr::get(builder.getContext(), processor),
-      AffineMapAttr::get(map), AffineMapAttr::get(bound), context);
-}
-
-LogicalResult gpu::setMappingAttr(ParallelOp ploopOp,
-                                  ArrayRef<ParallelLoopDimMapping> mapping) {
+LogicalResult
+gpu::setMappingAttr(ParallelOp ploopOp,
+                    ArrayRef<ParallelLoopDimMappingAttr> mapping) {
   // Verify that each processor is mapped to only once.
   llvm::DenseSet<gpu::Processor> specifiedMappings;
   for (auto dimAttr : mapping) {
-    gpu::Processor processor = getProcessor(dimAttr);
+    gpu::Processor processor = dimAttr.getProcessor();
     if (processor != gpu::Processor::Sequential &&
         specifiedMappings.count(processor))
       return ploopOp.emitError(
@@ -123,10 +110,10 @@ static void mapParallelOp(ParallelOp parallelOp,
 
   MLIRContext *ctx = parallelOp.getContext();
   Builder b(ctx);
-  SmallVector<ParallelLoopDimMapping, 4> attrs;
+  SmallVector<ParallelLoopDimMappingAttr, 4> attrs;
   attrs.reserve(parallelOp.getNumLoops());
   for (int i = 0, e = parallelOp.getNumLoops(); i < e; ++i) {
-    attrs.push_back(getParallelLoopDimMappingAttr(
+    attrs.push_back(b.getAttr<ParallelLoopDimMappingAttr>(
         getHardwareIdForMapping(mappingLevel, i), b.getDimIdentityMap(),
         b.getDimIdentityMap()));
   }
index 9beb6ed..e6966a8 100644 (file)
@@ -11,7 +11,7 @@ func.func @parallel_loop_bidy_bidx(%arg0 : index, %arg1 : index, %arg2 : index,
                                           step (%arg4, %step)  {
     %val = memref.load %buf[%i0, %i1] : memref<?x?xf32>
     memref.store %val, %res[%i1, %i0] : memref<?x?xf32>
-  } { mapping = [{processor = 1, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>}, {processor = 0, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>}] }
+  } { mapping = [#gpu.loop_dim_map<processor = block_y, map = (d0) -> (d0), bound = (d0) -> (d0)>, #gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>] }
   return
 }
 
@@ -56,12 +56,12 @@ func.func @parallel_loop_tiled(%arg0 : index, %arg1 : index, %arg2 : index,
       %val = memref.load %buf[%idx0, %idx1] : memref<?x?xf32>
       memref.store %val, %res[%idx1, %idx0] : memref<?x?xf32>
     } { mapping = [
-        {processor = 4, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>},
-        {processor = 3, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>}
+        #gpu.loop_dim_map<processor = thread_y, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+        #gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>
      ] }
   } { mapping = [
-      {processor = 1, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>},
-      {processor = 0, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>}
+      #gpu.loop_dim_map<processor = block_y, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+      #gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>
     ] }
   return
 }
@@ -109,8 +109,8 @@ func.func @parallel_loop_bidy_seq(%arg0 : index, %arg1 : index, %arg2 : index,
     %val = memref.load %buf[%i0, %i1] : memref<?x?xf32>
     memref.store %val, %res[%i1, %i0] : memref<?x?xf32>
   } { mapping = [
-      {processor = 1, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>},
-      {processor = 6, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>}
+      #gpu.loop_dim_map<processor = block_y, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+      #gpu.loop_dim_map<processor = sequential, map = (d0) -> (d0), bound = (d0) -> (d0)>
     ] }
   return
 }
@@ -156,12 +156,12 @@ func.func @parallel_loop_tiled_seq(%arg0 : index, %arg1 : index, %arg2 : index,
       %val = memref.load %buf[%idx0, %idx1] : memref<?x?xf32>
       memref.store %val, %res[%idx1, %idx0] : memref<?x?xf32>
     } { mapping = [
-        {processor = 4, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>},
-        {processor = 6, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>}
+        #gpu.loop_dim_map<processor = thread_y, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+        #gpu.loop_dim_map<processor = sequential, map = (d0) -> (d0), bound = (d0) -> (d0)>
       ] }
   } { mapping = [
-      {processor = 1, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>},
-      {processor = 6, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>}
+      #gpu.loop_dim_map<processor = block_y, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+      #gpu.loop_dim_map<processor = sequential, map = (d0) -> (d0), bound = (d0) -> (d0)>
     ] }
   return
 }
@@ -234,9 +234,9 @@ module {
         %20 = arith.addf %17, %18 : f32
         memref.store %20, %16[%arg5, %arg6] : memref<?x?xf32, #map3>
         scf.yield
-      } {mapping = [{bound = affine_map<(d0) -> (d0)>, map = affine_map<(d0) -> (d0)>, processor = 3 : i64}, {bound = affine_map<(d0) -> (d0)>, map = affine_map<(d0) -> (d0)>, processor = 4 : i64}]}
+      } {mapping = [#gpu.loop_dim_map<bound = (d0) -> (d0), map = (d0) -> (d0), processor = thread_x>, #gpu.loop_dim_map<bound = (d0) -> (d0), map = (d0) -> (d0), processor = thread_y>]}
       scf.yield
-    } {mapping = [{bound = affine_map<(d0) -> (d0)>, map = affine_map<(d0) -> (d0)>, processor = 0 : i64}, {bound = affine_map<(d0) -> (d0)>, map = affine_map<(d0) -> (d0)>, processor = 1 : i64}]}
+    } {mapping = [#gpu.loop_dim_map<bound = (d0) -> (d0), map = (d0) -> (d0), processor = block_x>, #gpu.loop_dim_map<bound = (d0) -> (d0), map = (d0) -> (d0), processor = block_y>]}
     return
   }
 }
@@ -310,7 +310,7 @@ func.func @parallel_loop_optional_attr() {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   scf.parallel (%i0) = (%c0) to (%c1) step (%c1) {
-  } { mapping = [{processor = 0, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>}], optional_attr = 1 }
+  } { mapping = [#gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>], optional_attr = 1 }
   // CHECK: optional_attr = 1
   return
 }
@@ -327,8 +327,8 @@ func.func @parallel_double_map(%arg0 : index, %arg1 : index, %arg2 : index,
   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                                           step (%four, %four)  {
   } { mapping = [
-      {processor = 1, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>},
-      {processor = 1, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>}
+      #gpu.loop_dim_map<processor = block_y, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+      #gpu.loop_dim_map<processor = block_y, map = (d0) -> (d0), bound = (d0) -> (d0)>
     ] }
   return
 }
@@ -356,12 +356,12 @@ func.func @parallel_loop_loop_variant_bound(%arg0 : index, %arg1 : index, %arg2
       %val = memref.load %buf[%idx0, %idx1] : memref<?x?xf32>
       memref.store %val, %res[%idx1, %idx0] : memref<?x?xf32>
     } { mapping = [
-        {processor = 4, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>},
-        {processor = 6, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>}
+        #gpu.loop_dim_map<processor = thread_y, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+        #gpu.loop_dim_map<processor = sequential, map = (d0) -> (d0), bound = (d0) -> (d0)>
       ] }
   } { mapping = [
-      {processor = 1, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>},
-      {processor = 6, map = affine_map<(d0) -> (d0)>, bound = affine_map<(d0) -> (d0)>}
+      #gpu.loop_dim_map<processor = block_y, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+      #gpu.loop_dim_map<processor = sequential, map = (d0) -> (d0), bound = (d0) -> (d0)>
     ] }
   return
 }
index 8c23364..3959873 100644 (file)
@@ -14,14 +14,13 @@ func.func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
   return
 }
 
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL:   func @parallel_loop(
 // CHECK:           scf.parallel
 // CHECK:             scf.parallel
-// CHECK:      {mapping = [{bound = #[[$MAP]], map = #[[$MAP]], processor = 3 : i64},
-// CHECK-SAME:             {bound = #[[$MAP]], map = #[[$MAP]], processor = 4 : i64}]}
-// CHECK:      {mapping = [{bound = #[[$MAP]], map = #[[$MAP]], processor = 0 : i64},
-// CHECK-SAME:             {bound = #[[$MAP]], map = #[[$MAP]], processor = 1 : i64}]}
+// CHECK:      {mapping = [#gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+// CHECK-SAME:             #gpu.loop_dim_map<processor = thread_y, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+// CHECK:      {mapping = [#gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+// CHECK-SAME:             #gpu.loop_dim_map<processor = block_y, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
 // CHECK-NOT: mapping
 
 // -----
@@ -43,21 +42,20 @@ func.func @parallel_loop_4d(%arg0 : index, %arg1 : index, %arg2 : index,
   return
 }
 
-// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL:   func @parallel_loop_4d(
 // CHECK:           scf.parallel
 // CHECK:             scf.parallel
 // CHECK:               scf.parallel
-// CHECK:      {mapping = [{bound = #[[$MAP]], map = #[[$MAP]], processor = 6 : i64},
-// CHECK-SAME:             {bound = #[[$MAP]], map = #[[$MAP]], processor = 6 : i64},
-// CHECK-SAME:             {bound = #[[$MAP]], map = #[[$MAP]], processor = 6 : i64},
-// CHECK-SAME:             {bound = #[[$MAP]], map = #[[$MAP]], processor = 6 : i64}]}
-// CHECK:      {mapping = [{bound = #[[$MAP]], map = #[[$MAP]], processor = 3 : i64},
-// CHECK-SAME:             {bound = #[[$MAP]], map = #[[$MAP]], processor = 4 : i64},
-// CHECK-SAME:             {bound = #[[$MAP]], map = #[[$MAP]], processor = 5 : i64},
-// CHECK-SAME:             {bound = #[[$MAP]], map = #[[$MAP]], processor = 6 : i64}]}
-// CHECK:      {mapping = [{bound = #[[$MAP]], map = #[[$MAP]], processor = 0 : i64},
-// CHECK-SAME:             {bound = #[[$MAP]], map = #[[$MAP]], processor = 1 : i64},
-// CHECK-SAME:             {bound = #[[$MAP]], map = #[[$MAP]], processor = 2 : i64},
-// CHECK-SAME:             {bound = #[[$MAP]], map = #[[$MAP]], processor = 6 : i64}]}
+// CHECK:      {mapping = [#gpu.loop_dim_map<processor = sequential, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+// CHECK-SAME:             #gpu.loop_dim_map<processor = sequential, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+// CHECK-SAME:             #gpu.loop_dim_map<processor = sequential, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+// CHECK-SAME:             #gpu.loop_dim_map<processor = sequential, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+// CHECK:      {mapping = [#gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+// CHECK-SAME:             #gpu.loop_dim_map<processor = thread_y, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+// CHECK-SAME:             #gpu.loop_dim_map<processor = thread_z, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+// CHECK-SAME:             #gpu.loop_dim_map<processor = sequential, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+// CHECK:      {mapping = [#gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+// CHECK-SAME:             #gpu.loop_dim_map<processor = block_y, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+// CHECK-SAME:             #gpu.loop_dim_map<processor = block_z, map = (d0) -> (d0), bound = (d0) -> (d0)>,
+// CHECK-SAME:             #gpu.loop_dim_map<processor = sequential, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
 // CHECK-NOT: mapping
index a0d7c8b..6c36e20 100644 (file)
@@ -3453,6 +3453,7 @@ td_library(
     srcs = [
         "include/mlir/Dialect/GPU/GPUBase.td",
         "include/mlir/Dialect/GPU/GPUOps.td",
+        "include/mlir/Dialect/GPU/ParallelLoopMapperAttr.td",
     ],
     includes = ["include"],
     deps = [
@@ -3466,35 +3467,6 @@ td_library(
 )
 
 gentbl_cc_library(
-    name = "ParallelLoopMapperAttrGen",
-    strip_include_prefix = "include",
-    tbl_outs = [
-        (
-            ["-gen-struct-attr-decls"],
-            "include/mlir/Dialect/GPU/ParallelLoopMapperAttr.h.inc",
-        ),
-        (
-            ["-gen-struct-attr-defs"],
-            "include/mlir/Dialect/GPU/ParallelLoopMapperAttr.cpp.inc",
-        ),
-        (
-            ["-gen-enum-decls"],
-            "include/mlir/Dialect/GPU/ParallelLoopMapperEnums.h.inc",
-        ),
-        (
-            ["-gen-enum-defs"],
-            "include/mlir/Dialect/GPU/ParallelLoopMapperEnums.cpp.inc",
-        ),
-    ],
-    tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/GPU/ParallelLoopMapperAttr.td",
-    deps = [
-        ":AttrTdFiles",
-        ":GPUOpsTdFiles",
-    ],
-)
-
-gentbl_cc_library(
     name = "GPUBaseIncGen",
     strip_include_prefix = "include",
     tbl_outs = [
@@ -3571,7 +3543,9 @@ cc_library(
             "lib/Dialect/GPU/IR/*.h",
         ],
     ),
-    hdrs = ["include/mlir/Dialect/GPU/GPUDialect.h"],
+    hdrs = [
+        "include/mlir/Dialect/GPU/GPUDialect.h",
+    ],
     includes = ["include"],
     deps = [
         ":ArithmeticDialect",
@@ -3644,7 +3618,6 @@ cc_library(
         ":GPUPassIncGen",
         ":MemRefDialect",
         ":IR",
-        ":ParallelLoopMapperAttrGen",
         ":Parser",
         ":Pass",
         ":ROCDLToLLVMIRTranslation",
@@ -5068,7 +5041,6 @@ cc_library(
         ":FuncDialect",
         ":IR",
         ":MemRefDialect",
-        ":ParallelLoopMapperAttrGen",
         ":Pass",
         ":SCFDialect",
         ":TensorDialect",