[mlir] Introduce device mapper attribute for `thread_dim_map` and `mapped to dims`
authorGuray Ozen <guray.ozen@gmail.com>
Thu, 10 Nov 2022 16:55:49 +0000 (17:55 +0100)
committerGuray Ozen <guray.ozen@gmail.com>
Fri, 11 Nov 2022 07:44:57 +0000 (08:44 +0100)
`scf.foreach_thread` defines mapping its loops to processors via an integer array, see an example below. A lowering can use this mapping. However, expressing mapping as an integer array is very confusing, especially when there are multiple levels of parallelism. In addition, the op does not verify the integer array. This change introduces device mapping attribute to make mapping descriptive and verifiable. Then it makes GPU transform dialect use it.

```
scf.foreach_thread (%i, %j) in (%c1, %c2) {
scf.foreach_thread (%i2, %j2) in (%c1, %c2)
{...} { thread_dim_mapping = [0, 1]}
} { thread_dim_mapping = [0, 1]}
```

It first introduces a `DeviceMappingInterface` which is an attribute interface. `scf.foreach_thread` defines its mapping via this interface. A lowering must define its attributes and implement this interface as well. This way gives us a clear validation.

The change also introduces two new attributes (`#gpu.thread<x/y/z>` and `#gpu.block<x,y,z>` ). After this change, the above code prints as below, as seen here, this way clarifies the loop mappings. The change also implements consuming of these two new attribute by the transform dialect. Transform dialect binds the outermost loops to the thread blocks and innermost loops to threads.

```
scf.foreach_thread (%i, %j) in (%c1, %c2) {
scf.foreach_thread (%i2, %j2) in (%c1, %c2)
{...} { thread_dim_mapping = [#gpu.thread<x>, #gpu.thread<y>]}
} { thread_dim_mapping = [#gpu.block<x>, #gpu.block<y>]}
```

Reviewed By: ftynse, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D137413

29 files changed:
mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt
mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td [new file with mode: 0644]
mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt
mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.h [new file with mode: 0644]
mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td [new file with mode: 0644]
mlir/include/mlir/Dialect/SCF/IR/SCF.h
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/SCF/IR/CMakeLists.txt
mlir/lib/Dialect/SCF/IR/DeviceMappingInterface.cpp [new file with mode: 0644]
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/GPU/transform-gpu-failing.mlir
mlir/test/Dialect/GPU/transform-gpu.mlir
mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
mlir/test/Dialect/SCF/ops.mlir
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 63eaa41..4d0208b 100644 (file)
@@ -172,6 +172,8 @@ void addAsyncDependency(Operation *op, Value token);
 
 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.h.inc"
 
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.h.inc"
 
index aef31af..cadc685 100644 (file)
@@ -16,6 +16,7 @@
 include "mlir/Dialect/DLTI/DLTIBase.td"
 include "mlir/Dialect/GPU/IR/GPUBase.td"
 include "mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td"
+include "mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/FunctionInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
index c99f3df..0cbcc3f 100644 (file)
@@ -4,3 +4,8 @@ mlir_tablegen(GPUTransformOps.cpp.inc -gen-op-defs)
 add_public_tablegen_target(MLIRGPUTransformOpsIncGen)
 
 add_mlir_doc(GPUTransformOps GPUTransformOps Dialects/ -gen-op-doc)
+
+set(LLVM_TARGET_DEFINITIONS GPUDeviceMappingAttr.td)
+mlir_tablegen(GPUDeviceMapperEnums.h.inc -gen-enum-decls)
+mlir_tablegen(GPUDeviceMapperEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRGPUDeviceMapperEnumsGen)
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
new file mode 100644 (file)
index 0000000..a93353d
--- /dev/null
@@ -0,0 +1,65 @@
+//===-- GPUDeviceMappingAttr.td - Attribute definition -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the attribute used to map loops to gpu. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GPU_DEVICE_MAPPING_ATTR
+#define GPU_DEVICE_MAPPING_ATTR
+
+include "mlir/Dialect/GPU/IR/GPUBase.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
+
+def DimX : I64EnumAttrCase<"DimX", 0, "x">;
+def DimY : I64EnumAttrCase<"DimY", 1, "y">;
+def DimZ : I64EnumAttrCase<"DimZ", 2, "z">;
+
+def ThreadsEnum : I64EnumAttr<"Threads", "threads for loop mapping", [
+    DimX, DimY, DimZ]> {
+  let cppNamespace = "::mlir::gpu";
+}
+
+def GPUThreadMappingAttr 
+    : GPU_Attr<"GPUThreadMapping", "thread", [ DeviceMappingAttrInterface ]> {
+  let parameters = (ins
+    EnumParameter<ThreadsEnum>:$thread
+  );
+  let assemblyFormat = "`<` params `>`";
+  let description = [{
+    An attribute that allows defining thread parallelism for GPU devices.
+
+    Thread (aka work item) are grouped into a thread blocks where block may be 
+    described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates 
+    that thread parallelism is desired. It can be consumed by lowering to 
+    generate GPU.
+  }];
+}
+
+def BlocksEnum : I64EnumAttr<"Blocks", "threads for loop mapping", [
+    DimX, DimY, DimZ]> {
+  let cppNamespace = "::mlir::gpu";
+}
+
+def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [ DeviceMappingAttrInterface ] >  {
+  let parameters = (ins
+    EnumParameter<BlocksEnum>:$block
+  );
+  let assemblyFormat = "`<` params `>`";
+  let description = [{
+    An attribute that allows defining thread block parallelism for GPU devices.
+
+    Thread blocks (aka work-group) are grouped into a grid where grid may be 
+    described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates 
+    that thread block parallelism is desired. It can be consumed by lowering to
+    generate GPU code.
+  }];
+}
+
+#endif // GPU_DEVICE_MAPPING_ATTR
index 0dfda8d..c0b348d 100644 (file)
@@ -29,7 +29,7 @@ def MapNestedForeachToThreads :
       The operation searches for `scf.foreach_thread` ops nested under `target` 
       and maps each such op to GPU threads. Mapping is one-to-one and the 
       induction variables of `scf.foreach_thread` are rewritten to 
-      `gpu.thread_id` according to the `thread_dim_mapping` attribute. 
+      `gpu.thread_id` according to the `mapping` attribute. 
       
       Sibling `scf.foreach_thread` are supported in which case, the union of 
       the number of threads is computed and may result in predication.
@@ -73,10 +73,10 @@ def MapNestedForeachToThreads :
                  threads(%tx, %ty, %tz) in (%tx = %3, %ty = %4, %tz = %5) {
         scf.foreach_thread (%i, %j) in (7, 9) {
           ... // body 1
-        } {thread_dim_mapping = [1, 0, 2]}
+        } {mapping = [#gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>]}
         scf.foreach_thread (%i) in (12) {
           ... // body 2
-        }
+        } {mapping = [#gpu.thread<x>]}
         gpu.terminator
       }
       ```
index 2583875..ac1af6c 100644 (file)
@@ -47,7 +47,7 @@ DiagnosedSilenceableFailure tileToForeachThreadOpImpl(
     RewriterBase &rewriter, transform::TransformState &state,
     TransformOpInterface transformOp, ArrayRef<Operation *> targets,
     ArrayRef<OpFoldResult> mixedNumThreads,
-    ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping,
+    ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> mapping,
     SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
 } // namespace transform
 
index 347def6..b8638f1 100644 (file)
@@ -13,6 +13,7 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
@@ -792,7 +793,7 @@ def TileToForeachThreadOp :
     a valid tiling specification (i.e. that only tiles parallel dimensions, 
     e.g. in the Linalg case).
 
-    If non-empty, the `thread_dim_mapping` is added as an attribute to the
+    If non-empty, the `mapping` is added as an attribute to the
     resulting `scf.foreach_thread`.
 
     #### Return modes
@@ -832,7 +833,7 @@ def TileToForeachThreadOp :
                    Variadic<PDL_Operation>:$tile_sizes,
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$static_num_threads,
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$static_tile_sizes,
-                   OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
+                   OptionalAttr<DeviceMappingArrayAttr>:$mapping);
   let results = (outs PDL_Operation:$foreach_thread_op,
                       PDL_Operation:$tiled_op);
 
@@ -841,22 +842,22 @@ def TileToForeachThreadOp :
                    "ArrayRef<int64_t>":$staticTileSizes,
                    CArg<"::mlir::transform::TileSizesSpec", 
                         "::mlir::transform::TileSizesSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedTileSizes,
                    CArg<"::mlir::transform::TileSizesSpec", 
                         "::mlir::transform::TileSizesSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<int64_t>":$staticNumThreads,
                    CArg<"::mlir::transform::NumThreadsSpec", 
                         "::mlir::transform::NumThreadsSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedNumThreads,
                    CArg<"::mlir::transform::NumThreadsSpec", 
                         "::mlir::transform::NumThreadsSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
   ];
 
   let assemblyFormat = [{
@@ -867,7 +868,7 @@ def TileToForeachThreadOp :
          `tile_sizes` custom<DynamicIndexList>($tile_sizes,
                                                $static_tile_sizes,
                                                "ShapedType::kDynamicSize"))
-    (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict
+    (`(` `mapping` `=` $mapping^ `)`)? attr-dict
   }];
   let hasVerifier = 1;
 
index e5e28e6..758386f 100644 (file)
@@ -423,7 +423,7 @@ computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
 
 /// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
 /// tiling by `numThreads`.
-/// If non-empty, the `threadDimMapping` is added as an attribute to the
+/// If non-empty, the `mapping` is added as an attribute to the
 /// resulting `scf.foreach_thread`.
 /// Zero tile sizes indicate that the dimension is not tiled, and can be
 /// thought of as tiling by the full size of data. It is the user's
@@ -436,14 +436,14 @@ struct ForeachThreadTilingResult {
 FailureOr<ForeachThreadTilingResult>
 tileToForeachThreadOp(RewriterBase &builder, TilingInterface op,
                       ArrayRef<OpFoldResult> numThreads,
-                      ArrayRef<int64_t> threadDimMapping = {});
+                      Optional<ArrayAttr> mapping);
 
 /// Same as `tileToForeachThreadOp`, but calculate the number of threads
 /// required using the given tileSizes.
 FailureOr<ForeachThreadTilingResult>
 tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
                                     ArrayRef<OpFoldResult> tileSizes,
-                                    ArrayRef<int64_t> threadDimMapping = {});
+                                    Optional<ArrayAttr> mapping);
 
 /// All indices returned by IndexOp should be invariant with respect to
 /// tiling. Therefore, if an operation is tiled, we have to transform the
index 804d29d..1b6f45b 100644 (file)
@@ -1,3 +1,10 @@
 add_mlir_dialect(SCFOps scf Ops)
 add_mlir_doc(SCFOps SCFDialect Dialects/ -gen-dialect-doc)
 
+set(LLVM_TARGET_DEFINITIONS DeviceMappingInterface.td)
+mlir_tablegen(DeviceMappingAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(DeviceMappingAttrInterface.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(DeviceMappingAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(DeviceMappingAttributes.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRDeviceMappingInterfacesIncGen)
+add_dependencies(mlir-generic-headers MLIRDeviceMappingInterfacesIncGen)
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.h b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.h
new file mode 100644 (file)
index 0000000..0a7fdbb
--- /dev/null
@@ -0,0 +1,22 @@
+//===- DeviceMappingInterface.h - -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definitions of the device mapping interface defined in
+// `DeviceMappingInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DEVICEMAPPINGINTERFACE_H
+#define MLIR_DEVICEMAPPINGINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.h.inc"
+
+#endif // MLIR_DEVICEMAPPINGINTERFACE_H
diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
new file mode 100644 (file)
index 0000000..2d2cafc
--- /dev/null
@@ -0,0 +1,43 @@
+//===- DeviceMappingInterface.td - Device mapping interfaces*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interfaces for the device mapping specification for the loops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DEVICEMAPPINGINTERFACE
+#define MLIR_DEVICEMAPPINGINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Attribute interfaces
+//===----------------------------------------------------------------------===//
+
+def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Attribute interface describing how to map a region to a processing unit.
+    
+    It is intended to be a generic mechanism for binding regions to execution 
+    units of an actual or virtual device. Each device first expresses its own 
+    mappings, and those mappings must implement this interface. These mappings 
+    can be used by the device-specific code generators and the desired regions 
+    can be connected to the given processing unit.
+    
+    Currently, `scf.foreach_thread` uses this interface to express the mapping 
+    of the loops it contains to the GPU's parallelism units such as threads and 
+    thread blocks.
+  }];
+}
+
+def DeviceMappingArrayAttr : 
+  TypedArrayAttrBase<DeviceMappingAttrInterface, 
+  "Device Mapping array attribute"> { }
+
+#endif // MLIR_DEVICEMAPPINGINTERFACE
index 12675c8..84b0ad3 100644 (file)
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SCF_SCF_H
 #define MLIR_DIALECT_SCF_SCF_H
 
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/RegionKindInterface.h"
index d576ca2..3fa890b 100644 (file)
@@ -16,6 +16,7 @@
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/RegionKindInterface.td"
+include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -378,14 +379,14 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     application per thread. Further lowerings are responsible for specifying
     how this is materialized on concrete hardware resources.
 
-    An optional thread_dim_mapping index array attribute specifies for each
-    virtual thread dimension, how it remaps 1-1 to a set of concrete processing
+    An optional `mapping` is an attribute array that specifies processing units
+    with their dimension, how it remaps 1-1 to a set of concrete processing
     element resources (e.g. a CUDA grid dimension or a level of concrete nested
-    async parallelism). At this time, the specification is backend-dependent and
-    is not verified by the op, beyond being an index array attribute.
-    It is the reponsibility of the lowering to interpret the index array in the
-    context of the concrete target the op is lowered to, or to ignore it when
-    the specification is ill-formed or unsupported for a particular target.
+    async parallelism). It is expressed via any attribute that implements the 
+    device mapping interface. It is the reponsibility of the lowering mechanism 
+    to interpret the `mapping` attributes in the context of the concrete target 
+    the op is lowered to, or to ignore it when the specification is ill-formed 
+    or unsupported for a particular target.
 
     The only allowed terminator is `scf.foreach_thread.perform_concurrently`.
     `scf.foreach_thread` returns one value per `shared_out` operand. The
@@ -440,11 +441,12 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     //
     ```
 
-    Example with thread_dim_mapping attribute:
+    Example with mapping attribute:
 
     ```mlir
     //
-    // Sequential context.
+    // Sequential context. Here `mapping` is expressed as GPU thread mapping 
+    // attributes
     //
     %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
         (%num_threads_1, %numthread_id_2) shared_outs(...)
@@ -456,7 +458,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
        scf.foreach_thread.perform_concurrently {
          ...
       }
-    } { thread_dim_mapping = [1, 0] }
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     // Implicit synchronization point.
     // Sequential context.
     //
@@ -480,7 +482,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
   }];
   let arguments = (ins Variadic<Index>:$num_threads,
                        Variadic<AnyRankedTensor>:$outputs,
-                       DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping);
+                       OptionalAttr<DeviceMappingArrayAttr>:$mapping);
 
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -493,10 +495,10 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
   let builders = [
     // Bodyless builder, outputs must be specified.
     OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
-                   CArg<"ArrayRef<int64_t>", "{}">:$thread_dim_mapping)>,
+                   "Optional<ArrayAttr>":$mapping)>,
     // Builder that takes a bodyBuilder lambda.
     OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
-                   "ArrayRef<int64_t>":$thread_dim_mapping,
+                   "ArrayRef<Attribute>":$mapping,
                    "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
   ];
   let extraClassDeclaration = [{
@@ -535,14 +537,14 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     }
 
     /// Return the thread indices in the order specified by the
-    /// thread_dim_mapping attribute. Return failure is
-    /// thread_dim_mapping is not a valid permutation.
-    FailureOr<SmallVector<Value>> getPermutedThreadIndices();
+    /// given mapping argument. Return failure is
+    /// mapping is not a valid permutation.
+    FailureOr<SmallVector<Value>> getPermutedThreadIndices(ArrayRef<int64_t> mapping);
 
     /// Return the number of threads in the order specified by the
-    /// thread_dim_mapping attribute.
-    /// Return failure is thread_dim_mapping is not a valid permutation.
-    FailureOr<SmallVector<OpFoldResult>> getPermutedNumThreads(OpBuilder &b);
+    /// given mapping argument.
+    /// Return failure is mapping is not a valid permutation.
+    FailureOr<SmallVector<OpFoldResult>> getPermutedNumThreads(OpBuilder &b, ArrayRef<int64_t> mapping);
 
     // The ensureTerminator method generated by SingleBlockImplicitTerminator is
     // unaware of the fact that our terminator also needs a region to be
index 1563afb..6ed0047 100644 (file)
@@ -3,10 +3,13 @@ add_mlir_dialect_library(MLIRGPUTransformOps
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU/TransformOps
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
 
   DEPENDS
   MLIRGPUTransformOpsIncGen
-
+  MLIRDeviceMappingInterfacesIncGen
+  MLIRGPUDeviceMapperEnumsGen
+  
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRGPUTransforms
index 280dbf0..23420e8 100644 (file)
@@ -166,8 +166,20 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
 
   // Step 0. Outline the compute workload region and set up the workload
   // operands.
+  SmallVector<int64_t> mapping;
+  if (!foreachThreadOp.getMapping().has_value())
+    return transformOp.emitSilenceableError() << "mapping must be present";
+  for (DeviceMappingAttrInterface map : *foreachThreadOp.getMapping()) {
+    if (auto blockMap = map.dyn_cast<GPUBlockMappingAttr>()) {
+      mapping.push_back((int64_t)blockMap.getBlock());
+    } else {
+      return transformOp.emitSilenceableError()
+             << "mapping must be #gpu.block<x/y/z/>";
+    }
+  }
+
   FailureOr<SmallVector<OpFoldResult>> potentialGridDim =
-      foreachThreadOp.getPermutedNumThreads(rewriter);
+      foreachThreadOp.getPermutedNumThreads(rewriter, mapping);
 
   if (failed(potentialGridDim) ||
       llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) {
@@ -193,7 +205,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
 
   // Step 2. RAUW thread indices to thread ops.
   SmallVector<Value> threadIndices =
-      *foreachThreadOp.getPermutedThreadIndices();
+      *foreachThreadOp.getPermutedThreadIndices(mapping);
   assert(blockOps.size() == 3 && "3 block id ops are required");
   for (auto [blockIdx, blockOp] : llvm::zip(threadIndices, blockOps)) {
     Value val = blockIdx;
@@ -230,7 +242,8 @@ DiagnosedSilenceableFailure mlir::transform::gpu::findTopLevelForeachThreadOp(
 }
 
 /// This is a helper that is only used in
-/// rewriteTopLevelForeachThreadToGpuBlocks. It generates GPU dialects block_id.
+/// rewriteTopLevelForeachThreadToGpuBlocks. It generates GPU dialects
+/// block_id.
 static void generateGpuBlockIds(RewriterBase &rewriter,
                                 scf::ForeachThreadOp foreachOp,
                                 SmallVectorImpl<Value> &blockOps) {
@@ -335,7 +348,18 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     return failureHelper(
         "scf.foreach_thread with rank > 3 does not lower to gpu.thread_id");
 
-  auto potentialBlockDim = foreachThreadOp.getPermutedNumThreads(rewriter);
+  SmallVector<int64_t> mapping;
+  if (!foreachThreadOp.getMapping().has_value())
+    return failureHelper("mapping must be present");
+  for (DeviceMappingAttrInterface map : *foreachThreadOp.getMapping()) {
+    if (auto threadMap = map.dyn_cast<GPUThreadMappingAttr>()) {
+      mapping.push_back((int64_t)threadMap.getThread());
+    } else {
+      return failureHelper("mapping must be #gpu.thread<x/y/z/>");
+    }
+  }
+  FailureOr<SmallVector<OpFoldResult>> potentialBlockDim =
+      foreachThreadOp.getPermutedNumThreads(rewriter, mapping);
   if (failed(potentialBlockDim) ||
       llvm::any_of(*potentialBlockDim, [](OpFoldResult ofr) {
         return !getConstantIntValue(ofr).has_value();
@@ -365,8 +389,8 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     if (blockDim > globalBlockDim) {
       return failureHelper(
           "The requested GPU threads are fewer than the number of loop trip "
-          "counts. Try to tile scf.foreach_thread before mapping or set small "
-          "blockDim.");
+          "counts. Try to tile scf.foreach_thread before mapping or set "
+          "small blockDim.");
     }
     if (blockDim == globalBlockDim)
       continue;
@@ -400,7 +424,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
 
   // Step 4. RAUW thread indices to thread ops.
   SmallVector<Value> threadIndices =
-      *foreachThreadOp.getPermutedThreadIndices();
+      *foreachThreadOp.getPermutedThreadIndices(mapping);
   for (auto [threadIdx, threadOp] : llvm::zip(threadIndices, threadOps)) {
     Value val = threadIdx;
     Value op = threadOp;
index 6b8ca91..7b720a7 100644 (file)
@@ -1321,19 +1321,21 @@ void transform::TileOp::getEffects(
 // TileToForeachThreadOp
 //===----------------------------------------------------------------------===//
 
-void transform::TileToForeachThreadOp::build(
-    OpBuilder &builder, OperationState &result, Value target,
-    ArrayRef<int64_t> staticTileSizes, transform::TileSizesSpec,
-    ArrayRef<int64_t> threadDimMapping) {
+void transform::TileToForeachThreadOp::build(OpBuilder &builder,
+                                             OperationState &result,
+                                             Value target,
+                                             ArrayRef<int64_t> staticTileSizes,
+                                             transform::TileSizesSpec,
+                                             ArrayRef<int64_t> mapping) {
   return build(builder, result, target,
                getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
-               TileSizesSpec(), threadDimMapping);
+               TileSizesSpec(), mapping);
 }
 
 void transform::TileToForeachThreadOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec,
-    ArrayRef<int64_t> threadDimMapping) {
+    ArrayRef<int64_t> mapping) {
   SmallVector<int64_t> staticTileSizes;
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes,
@@ -1344,28 +1346,29 @@ void transform::TileToForeachThreadOp::build(
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
-  ArrayAttr threadDimMappingAttr;
-  if (!threadDimMapping.empty())
-    threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping);
+  ArrayAttr mappingAttr;
+  if (!mapping.empty())
+    mappingAttr = builder.getI64ArrayAttr(mapping);
   build(builder, result, TypeRange{operationType, operationType}, target,
         /*numThreads=*/ValueRange{}, dynamicTileSizes,
-        /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr,
-        threadDimMappingAttr);
+        /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mappingAttr);
 }
 
-void transform::TileToForeachThreadOp::build(
-    OpBuilder &builder, OperationState &result, Value target,
-    ArrayRef<int64_t> staticNumThreads, transform::NumThreadsSpec,
-    ArrayRef<int64_t> threadDimMapping) {
+void transform::TileToForeachThreadOp::build(OpBuilder &builder,
+                                             OperationState &result,
+                                             Value target,
+                                             ArrayRef<int64_t> staticNumThreads,
+                                             transform::NumThreadsSpec,
+                                             ArrayRef<int64_t> mapping) {
   return build(builder, result, target,
                getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
-               NumThreadsSpec(), threadDimMapping);
+               NumThreadsSpec(), mapping);
 }
 
 void transform::TileToForeachThreadOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec,
-    ArrayRef<int64_t> threadDimMapping) {
+    ArrayRef<int64_t> mapping) {
   SmallVector<int64_t> staticNumThreads;
   SmallVector<Value> dynamicNumThreads;
   dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
@@ -1376,19 +1379,19 @@ void transform::TileToForeachThreadOp::build(
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
-  ArrayAttr threadDimMappingAttr;
-  if (!threadDimMapping.empty())
-    threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping);
+  ArrayAttr mappingAttr;
+  if (!mapping.empty())
+    mappingAttr = builder.getI64ArrayAttr(mapping);
   build(builder, result, TypeRange{operationType, operationType}, target,
         dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr,
-        /*staticTileSizes=*/ArrayAttr(), threadDimMappingAttr);
+        /*staticTileSizes=*/ArrayAttr(), mappingAttr);
 }
 
 DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
     RewriterBase &rewriter, transform::TransformState &state,
     TransformOpInterface transformOp, ArrayRef<Operation *> targets,
     ArrayRef<OpFoldResult> mixedNumThreads,
-    ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping,
+    ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> mapping,
     SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
   if (targets.empty())
     return DiagnosedSilenceableFailure(success());
@@ -1457,19 +1460,13 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
       return diag;
     }
     rewriter.setInsertionPoint(tilableOp);
-    auto maybeThreadDimMappingAttr = threadDimMapping;
-    auto dimMapping = llvm::to_vector(
-        maybeThreadDimMappingAttr
-            ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
-            : ArrayRef<int64_t>{});
-
     FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
     if (!mixedNumThreads.empty()) {
       tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
-                                                   numThreads, dimMapping);
+                                                   numThreads, mapping);
     } else {
       tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
-          rewriter, tilableOp, tileSizes, dimMapping);
+          rewriter, tilableOp, tileSizes, mapping);
     }
 
     if (failed(tilingResult))
@@ -1494,7 +1491,7 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
 
   DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl(
       rewriter, state, cast<TransformOpInterface>(getOperation()), targets,
-      getMixedNumThreads(), getMixedTileSizes(), getThreadDimMapping(), tileOps,
+      getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps,
       tiledOps);
 
   if (!diag.succeeded())
index 5937da3..a32e9f7 100644 (file)
@@ -215,7 +215,7 @@ static OpFoldResult buildMin(OpBuilder &b, Location loc,
 /// tiling is specified by the number of tiles/threads `numThreads` and the
 /// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
 /// not specified, then  it is derived from `numThreads` as `ceilDiv(dimSize[i],
-/// numThreads[i])`. If non-empty, the `threadDimMapping` is added as an
+/// numThreads[i])`. If non-empty, the `mapping` is added as an
 /// attribute to the resulting `scf.foreach_thread`. A zero tile sizes indicate
 /// that the dimension is not tiled, and can be thought of as tiling by the full
 /// size of data.
@@ -226,7 +226,7 @@ static OpFoldResult buildMin(OpBuilder &b, Location loc,
 static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
     RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
     Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
-    ArrayRef<int64_t> threadDimMapping, bool omitTileOffsetBoundsCheck) {
+    Optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
   Location loc = op->getLoc();
   OpBuilder::InsertionGuard g(b);
   SmallVector<Range> loopRanges = op.getIterationDomain(b);
@@ -256,7 +256,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
   // version because we require the use of RewriterBase in the body, so we
   // manually move the insertion point to the body below.
   scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, dest, ValueRange(materializedNonZeroNumThreads), threadDimMapping);
+      loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
 
   // Fill out the ForeachThreadOp body.
   b.setInsertionPointToStart(foreachThreadOp.getBody(0));
@@ -363,16 +363,16 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
 FailureOr<ForeachThreadTilingResult>
 linalg::tileToForeachThreadOp(RewriterBase &b, TilingInterface op,
                               ArrayRef<OpFoldResult> numThreads,
-                              ArrayRef<int64_t> threadDimMapping) {
+                              Optional<ArrayAttr> mapping) {
   return tileToForeachThreadOpImpl(b, op, numThreads, /*nominalTileSizes=*/None,
-                                   threadDimMapping,
+                                   mapping,
                                    /*omitTileOffsetBoundsCheck=*/false);
 }
 
 FailureOr<ForeachThreadTilingResult>
-linalg::tileToForeachThreadOpUsingTileSizes(
-    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> tileSizes,
-    ArrayRef<int64_t> threadDimMapping) {
+linalg::tileToForeachThreadOpUsingTileSizes(RewriterBase &b, TilingInterface op,
+                                            ArrayRef<OpFoldResult> tileSizes,
+                                            Optional<ArrayAttr> mapping) {
   SmallVector<Range> loopRanges = op.getIterationDomain(b);
   unsigned nLoops = loopRanges.size();
   SmallVector<OpFoldResult> numThreads;
@@ -388,8 +388,7 @@ linalg::tileToForeachThreadOpUsingTileSizes(
     numThreads.push_back(numTiles);
   }
   return tileToForeachThreadOpImpl(b, op, numThreads,
-                                   /*nominalTileSizes=*/tileSizes,
-                                   threadDimMapping,
+                                   /*nominalTileSizes=*/tileSizes, mapping,
                                    /*omitTileOffsetBoundsCheck=*/true);
 }
 
index 6255d9e..a043491 100644 (file)
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRSCFDialect
   SCF.cpp
+  DeviceMappingInterface.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/SCF
diff --git a/mlir/lib/Dialect/SCF/IR/DeviceMappingInterface.cpp b/mlir/lib/Dialect/SCF/IR/DeviceMappingInterface.cpp
new file mode 100644 (file)
index 0000000..a90c638
--- /dev/null
@@ -0,0 +1,17 @@
+//===- DeviceMappingInterface.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Table-generated class definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.cpp.inc"
index 5f1a20c..e32f671 100644 (file)
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::scf;
@@ -1111,6 +1114,12 @@ LogicalResult ForeachThreadOp::verify() {
     if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType())
       return emitOpError("type mismatch between ")
              << i << "-th output and corresponding block argument";
+  if (getMapping().has_value())
+    for (auto map : getMapping().value()) {
+      if (!isa<DeviceMappingAttrInterface>(map))
+        return emitOpError()
+               << getMappingAttrName() << " is not device mapping attribute";
+    }
 
   return success();
 }
@@ -1200,11 +1209,14 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
 void ForeachThreadOp::build(mlir::OpBuilder &builder,
                             mlir::OperationState &result, ValueRange outputs,
                             ValueRange numThreads,
-                            ArrayRef<int64_t> threadDimMapping) {
+                            Optional<ArrayAttr> mapping) {
   result.addOperands(numThreads);
   result.addOperands(outputs);
-  result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name),
-                      builder.getI64ArrayAttr(threadDimMapping));
+  if (mapping.has_value()) {
+    result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name),
+                        mapping.value());
+  }
+
   result.addAttribute(
       "operand_segment_sizes",
       builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
@@ -1231,12 +1243,12 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder,
 // Builder that takes a bodyBuilder lambda.
 void ForeachThreadOp::build(
     mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs,
-    ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
+    ValueRange numThreads, ArrayRef<Attribute> mapping,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
   result.addOperands(numThreads);
   result.addOperands(outputs);
-  result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name),
-                      builder.getI64ArrayAttr(threadDimMapping));
+  result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name),
+                      builder.getArrayAttr(mapping));
   result.addAttribute(
       "operand_segment_sizes",
       builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
@@ -1290,51 +1302,51 @@ static FailureOr<SmallVector<T>> permute(const SmallVector<T> &vals,
   SmallVector<T> result(vals.size());
   SmallVector<bool> seen(vals.size());
   for (auto [idx, val] : llvm::zip(perm, vals)) {
-    // Already seen, invalid thread_dim_mapping.
+    // Already seen, invalid mapping.
     if (seen[idx])
       return failure();
     result[idx] = val;
     seen[idx] = true;
   }
-  // Some not seen, invalid thread_dim_mapping.
+  // Some not seen, invalid mapping.
   if (!llvm::all_of(seen, [](bool b) { return b; }))
     return failure();
   return result;
 }
 
-/// Helper to get apply the `thread_dim_mapping` permutation of a
+/// Helper to get apply the `mapping` permutation of a
 /// `foreachThreadOp` to `values`.
 template <typename T>
 static FailureOr<SmallVector<T>>
 getValuesPermutedByThreadMapping(scf::ForeachThreadOp foreachThreadOp,
-                                 const SmallVector<T> &values) {
+                                 const SmallVector<T> &values,
+                                 ArrayRef<int64_t> mapping) {
   // Apply mapping permutation if specified.
-  auto mapping = foreachThreadOp.getThreadDimMapping();
-  if (mapping && !mapping.empty()) {
-    auto maybePermuted = permute(values, extractFromI64ArrayAttr(mapping));
-    if (failed(maybePermuted))
-      return foreachThreadOp->emitError("invalid permutation");
-    return *maybePermuted;
-  }
+  FailureOr<SmallVector<T>> maybePermuted = permute(values, mapping);
+  if (failed(maybePermuted))
+    return foreachThreadOp->emitError("invalid permutation");
+  return *maybePermuted;
   return values;
 }
 
-/// Return the thread indices in the order specified by the thread_dim_mapping
-/// attribute. Return failure is thread_dim_mapping is not a valid permutation.
-FailureOr<SmallVector<Value>> ForeachThreadOp::getPermutedThreadIndices() {
+/// Return the thread indices in the order specified by the mapping
+/// attribute. Return failure is mapping is not a valid permutation.
+FailureOr<SmallVector<Value>>
+ForeachThreadOp::getPermutedThreadIndices(ArrayRef<int64_t> mapping) {
   SmallVector<Value> threadCountValues = this->getThreadIndices();
   threadCountValues.resize(3, Value());
-  return getValuesPermutedByThreadMapping(*this, threadCountValues);
+  return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping);
 }
 
 /// Return the number of threads in the order specified by the
-/// thread_dim_mapping attribute.
-/// Return failure is thread_dim_mapping is not a valid permutation.
+/// mapping attribute.
+/// Return failure is mapping is not a valid permutation.
 FailureOr<SmallVector<OpFoldResult>>
-ForeachThreadOp::getPermutedNumThreads(OpBuilder &b) {
+ForeachThreadOp::getPermutedNumThreads(OpBuilder &b,
+                                       ArrayRef<int64_t> mapping) {
   SmallVector<OpFoldResult> threadCountValues = this->getNumThreads();
   threadCountValues.resize(3, b.getIndexAttr(1));
-  return getValuesPermutedByThreadMapping(*this, threadCountValues);
+  return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping);
 }
 
 ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
index 2771857..dcf2fe1 100644 (file)
@@ -1141,10 +1141,11 @@ struct ForeachThreadOpInterface
     // Create new ForeachThreadOp without any results and drop the automatically
     // introduced terminator.
     rewriter.setInsertionPoint(foreachThreadOp);
-    auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
+    ForeachThreadOp newForeachThreadOp;
+    newForeachThreadOp = rewriter.create<ForeachThreadOp>(
         foreachThreadOp.getLoc(), /*outputs=*/ValueRange(),
-        foreachThreadOp.getNumThreads(),
-        extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
+        foreachThreadOp.getNumThreads(), foreachThreadOp.getMapping());
+
     newForeachThreadOp.getBody()->getTerminator()->erase();
 
     // Move over block contents of the old op.
index f61ed8f..de50c59 100644 (file)
@@ -26,7 +26,7 @@ func.func @map_nested_foreach_to_threads_excessive_threads(%x: memref<2 x 32 x f
         %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-     }  {thread_dim_mapping = [1, 0, 2]}
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
 
@@ -38,7 +38,7 @@ func.func @map_nested_foreach_to_threads_excessive_threads(%x: memref<2 x 32 x f
         %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-     }  {thread_dim_mapping = [1, 0, 2]}
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
 
@@ -67,7 +67,7 @@ func.func @map_nested_foreach_to_threads_fewer_threads(%x: memref<2 x 32 x f32>,
         %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-     }  {thread_dim_mapping = [1, 0, 2]}
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
     gpu.terminator
   }
 
@@ -79,7 +79,7 @@ func.func @map_nested_foreach_to_threads_fewer_threads(%x: memref<2 x 32 x f32>,
         %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-     }  {thread_dim_mapping = [1, 0, 2]}
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
     gpu.terminator
   }
 
@@ -106,7 +106,7 @@ func.func @map_nested_foreach_to_threads_dynamic_trip_count(%x: memref<2 x 32 x
         %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-     }  {thread_dim_mapping = [1, 0, 2]}
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
     gpu.terminator
   }
   return %y : memref<2 x 32 x f32>
@@ -131,7 +131,7 @@ func.func @map_nested_foreach_to_threads_4d_loop(%x: memref<2x32x32x32xf32>, %y:
     scf.foreach_thread (%i, %j, %k, %l) in (%c2, %c32,%c32,%c32) {
         %4 = memref.load %x[%i, %j, %k, %l] : memref<2x32x32x32xf32>        
         memref.store %4, %y[%i, %j, %k, %l] : memref<2x32x32x32xf32>
-     }  {thread_dim_mapping = [1, 0, 2]}
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
     gpu.terminator
   }
   return %y : memref<2x32x32x32xf32>
@@ -197,14 +197,14 @@ func.func @map_foreach_to_blocks_not_unique(%x: memref<2 x 32 x f32>, %y: memref
         %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-     }  {thread_dim_mapping = [1, 0, 2]}
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
 
     scf.foreach_thread (%i, %j) in (%c7, %c9) {
         %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32>
         %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-     }  {thread_dim_mapping = [1, 0, 2]}
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
     gpu.terminator
   }
 
@@ -232,14 +232,14 @@ func.func @map_foreach_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-  }  {thread_dim_mapping = [0, 1, 2]}
+  }  { mapping = [#gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>] }
 
   scf.foreach_thread (%i, %j) in (%c7, %c9) {
       %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32>
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-  }  {thread_dim_mapping = [1, 0, 2]}
+  }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
   
   return %y : memref<2 x 32 x f32>
 }
@@ -261,7 +261,7 @@ func.func @map_foreach_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-  }  {thread_dim_mapping = [0, 1, 2]}
+  }  { mapping = [#gpu.block<x>, #gpu.block<y>, #gpu.block<z>] }
   return %y : memref<2 x 32 x f32>
 }
 
@@ -273,4 +273,3 @@ transform.sequence failures(propagate) {
 }
 
 // -----
-
index d4ff7ff..eb7208b 100644 (file)
@@ -24,7 +24,7 @@ func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream
         %5 = memref.load %y[%i, %j] : !type
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : !type
-     }  {thread_dim_mapping = [0, 1, 2]}
+     }  { mapping = [#gpu.block<x>, #gpu.block<y>, #gpu.block<z>]}
     gpu.terminator
   }
   return %y : !type
@@ -33,7 +33,7 @@ func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
-  transform.gpu.map_foreach_to_blocks %funcop { blockDim = [12, 9, 1]}
+  transform.gpu.map_foreach_to_blocks %funcop { gridDim = [12, 9]}
 }
 
 // -----
@@ -73,12 +73,12 @@ func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g
         %5 = memref.load %y[%i, %j] : !type
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : !type
-     }  {thread_dim_mapping = [1, 0, 2]}
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>]}
      scf.foreach_thread (%i) in (%c12) {
         %7 = memref.load %t[%i] : !type1d
         %8 = arith.addf %alpha, %7 : f32
         memref.store %8, %t[%i] : !type1d
-     }  {thread_dim_mapping = [0, 1, 2]}
+     }  {mapping = [#gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>] }
     gpu.terminator
   }
   return %y : !type
@@ -87,7 +87,7 @@ func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1] }
+  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9] }
 }
 
 // -----
@@ -118,8 +118,8 @@ func.func @saxpy4d(%x: !type4d, %y: !type4d, %alpha : f32) -> !type4d {
       %5 = memref.load %y[%i, %j, %k, %l] : !type4d
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j, %k, %l] : !type4d
-    }  {thread_dim_mapping = [1, 0, 2]}
-  }  {thread_dim_mapping = [0, 1, 2]}
+    }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+  }  { mapping = [#gpu.block<x>, #gpu.block<y>, #gpu.block<z>] }
   return %y : !type4d
 }
 
@@ -151,7 +151,7 @@ func.func @saxpy2d_no_barrier(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %
         %5 = memref.load %y[%i, %j] : !type
         %6 = math.fma %alpha, %4, %5 : f32
         memref.store %6, %y[%i, %j] : !type
-     }  {thread_dim_mapping = [1, 0, 2]}
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
     gpu.terminator
   }
   return %y : !type
index ce0afdf..f015cdf 100644 (file)
@@ -26,7 +26,7 @@ module {
   // CHECK-NEXT:     tensor.parallel_insert_slice %[[RES]] into %[[C_BLK]]{{.*}} :
   // CHECK-SAME:       tensor<?x?xf32> into tensor<?x?xf32>
   // CHECK-NEXT:   }
-  // CHECK-NEXT: } {thread_dim_mapping = [1, 0]}
+  // CHECK-NEXT: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
     %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
                       outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
     return %0 : tensor<?x?xf32>
@@ -35,7 +35,7 @@ module {
   transform.sequence failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-    %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapped to dims [1, 0])
+    %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapping = [ #gpu.thread<y>, #gpu.thread<x> ] )
   }
 }
 
@@ -177,7 +177,7 @@ module {
   transform.sequence failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-    %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] (mapped to dims [0])
+    %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] ( mapping = [#gpu.thread<x>])
   }
 }
 // CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2)>
index 69c4ef4..65b1f09 100644 (file)
@@ -628,7 +628,7 @@ func.func @same_enclosing_repetitive_region(%2: tensor<320xf32>,
       // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
       tensor.parallel_insert_slice %8 into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
     }
-  } {thread_dim_mapping = []}
+  } 
   return %4 : tensor<320xf32>
 }
 
index a533783..17dca3f 100644 (file)
@@ -128,7 +128,7 @@ func.func @scf_foreach_thread_out_of_place(%in: tensor<100xf32>,
         tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
           tensor<1xf32> into tensor<100xf32>
       }
-  // CHECK: } {thread_dim_mapping = [5]}
-  } {thread_dim_mapping = [5]}
+  // CHECK: } {mapping = [#gpu.thread<x>]}
+  } {mapping = [#gpu.thread<x>]}
   return
 }
index e563838..18413e8 100644 (file)
@@ -338,12 +338,12 @@ func.func @elide_terminator() -> () {
   %num_threads = arith.constant 100 : index
 
   //      CHECK:    scf.foreach_thread
-  // CHECK-NEXT:  } {thread_dim_mapping = [42]}
+  // CHECK-NEXT:  } {mapping = [#gpu.thread<x>]}
   // CHECK-NEXT:  return
   scf.foreach_thread (%thread_idx) in (%num_threads) {
     scf.foreach_thread.perform_concurrently {
     }
-  } {thread_dim_mapping = [42]}
+  } {mapping = [#gpu.thread<x>]}
   return
 }
 
index 461da29..82ea071 100644 (file)
@@ -217,7 +217,7 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
     Location loc = op.getLoc();
     auto foreachOp = rewriter.create<scf::ForeachThreadOp>(
         loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(),
-        /*threadDimMapping=*/ArrayRef<int64_t>{},
+        /*mapping=*/ArrayRef<Attribute>{},
         [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
           unsigned numThreadIdRegionArgs =
               helper.getIterationSpaceSizes().size();
index 5260beb..3d77545 100644 (file)
@@ -1833,7 +1833,10 @@ gentbl_cc_library(
 
 td_library(
     name = "SCFTdFiles",
-    srcs = ["include/mlir/Dialect/SCF/IR/SCFOps.td"],
+    srcs = [
+        "include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td",
+        "include/mlir/Dialect/SCF/IR/SCFOps.td",
+    ],
     includes = ["include"],
     deps = [
         ":ControlFlowInterfacesTdFiles",
@@ -1871,6 +1874,32 @@ gentbl_cc_library(
 )
 
 gentbl_cc_library(
+    name = "SCFDeviceMappingInterfacesIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-attr-interface-decls"],
+            "include/mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.h.inc",
+        ),
+        (
+            ["-gen-attr-interface-defs"],
+            "include/mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.cpp.inc",
+        ),
+        (
+            ["-gen-attrdef-decls"],
+            "include/mlir/Dialect/SCF/IR/DeviceMappingAttributes.h.inc",
+        ),
+        (
+            ["-gen-attrdef-defs"],
+            "include/mlir/Dialect/SCF/IR/DeviceMappingAttributes.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td",
+    deps = [":SCFTdFiles"],
+)
+
+gentbl_cc_library(
     name = "SCFPassIncGen",
     strip_include_prefix = "include",
     tbl_outs = [
@@ -2794,6 +2823,7 @@ cc_library(
         ":MemRefDialect",
         ":ParallelCombiningOpInterface",
         ":Pass",
+        ":SCFDeviceMappingInterfacesIncGen",
         ":SCFIncGen",
         ":SCFPassIncGen",
         ":Support",
@@ -3623,6 +3653,7 @@ td_library(
         "include/mlir/Dialect/GPU/IR/GPUBase.td",
         "include/mlir/Dialect/GPU/IR/GPUOps.td",
         "include/mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td",
+        "include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td",
     ],
     includes = ["include"],
     deps = [
@@ -3632,11 +3663,33 @@ td_library(
         ":InferIntRangeInterfaceTdFiles",
         ":LLVMOpsTdFiles",
         ":OpBaseTdFiles",
+        ":SCFTdFiles",
         ":SideEffectInterfacesTdFiles",
     ],
 )
 
 gentbl_cc_library(
+    name = "GPUDeviceMapperEnumsGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-enum-decls"],
+            "include/mlir/Dialect/GPU/TransformOps/GPUDeviceMapperEnums.h.inc",
+        ),
+        (
+            ["-gen-enum-defs"],
+            "include/mlir/Dialect/GPU/TransformOps/GPUDeviceMapperEnums.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td",
+    deps = [
+        ":GPUOpsTdFiles",
+        ":OpBaseTdFiles",
+    ],
+)
+
+gentbl_cc_library(
     name = "GPUBaseIncGen",
     strip_include_prefix = "include",
     tbl_outs = [
@@ -3725,6 +3778,7 @@ cc_library(
         ":InferTypeOpInterface",
         ":LLVMDialect",
         ":MemRefDialect",
+        ":SCFDialect",
         ":SideEffectInterfaces",
         "//llvm:Support",
     ],
@@ -7694,6 +7748,7 @@ td_library(
     includes = ["include"],
     deps = [
         ":PDLDialectTdFiles",
+        ":SCFTdFiles",
         ":TransformDialectTdFiles",
     ],
 )
@@ -7781,6 +7836,7 @@ gentbl_cc_library(
     td_file = "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td",
     deps = [
         ":LinalgTransformOpsTdFiles",
+        ":SCFDeviceMappingInterfacesIncGen",
     ],
 )