From: Guray Ozen Date: Thu, 10 Nov 2022 16:55:49 +0000 (+0100) Subject: [mlir] Introduce device mapper attribute for `thread_dim_map` and `mapped to dims` X-Git-Tag: upstream/17.0.6~27960 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6663f3470417523141ee87923840d17b35d6b4c7;p=platform%2Fupstream%2Fllvm.git [mlir] Introduce device mapper attribute for `thread_dim_map` and `mapped to dims` `scf.foreach_thread` defines mapping its loops to processors via an integer array, see an example below. A lowering can use this mapping. However, expressing mapping as an integer array is very confusing, especially when there are multiple levels of parallelism. In addition, the op does not verify the integer array. This change introduces device mapping attribute to make mapping descriptive and verifiable. Then it makes GPU transform dialect use it. ``` scf.foreach_thread (%i, %j) in (%c1, %c2) { scf.foreach_thread (%i2, %j2) in (%c1, %c2) {...} { thread_dim_mapping = [0, 1]} } { thread_dim_mapping = [0, 1]} ``` It first introduces a `DeviceMappingInterface` which is an attribute interface. `scf.foreach_thread` defines its mapping via this interface. A lowering must define its attributes and implement this interface as well. This way gives us a clear validation. The change also introduces two new attributes (`#gpu.thread` and `#gpu.block` ). After this change, the above code prints as below, as seen here, this way clarifies the loop mappings. The change also implements consuming of these two new attribute by the transform dialect. Transform dialect binds the outermost loops to the thread blocks and innermost loops to threads. 
``` scf.foreach_thread (%i, %j) in (%c1, %c2) { scf.foreach_thread (%i2, %j2) in (%c1, %c2) {...} { thread_dim_mapping = [#gpu.thread, #gpu.thread]} } { thread_dim_mapping = [#gpu.block, #gpu.block]} ``` Reviewed By: ftynse, nicolasvasilache Differential Revision: https://reviews.llvm.org/D137413 --- diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h index 63eaa41..4d0208b 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h @@ -172,6 +172,8 @@ void addAsyncDependency(Operation *op, Value token); #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.h.inc" +#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.h.inc" diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index aef31af..cadc685 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -16,6 +16,7 @@ include "mlir/Dialect/DLTI/DLTIBase.td" include "mlir/Dialect/GPU/IR/GPUBase.td" include "mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td" +include "mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt index c99f3df..0cbcc3f 100644 --- a/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt @@ -4,3 +4,8 @@ mlir_tablegen(GPUTransformOps.cpp.inc -gen-op-defs) add_public_tablegen_target(MLIRGPUTransformOpsIncGen) add_mlir_doc(GPUTransformOps GPUTransformOps Dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS GPUDeviceMappingAttr.td) +mlir_tablegen(GPUDeviceMapperEnums.h.inc -gen-enum-decls) 
+mlir_tablegen(GPUDeviceMapperEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRGPUDeviceMapperEnumsGen) diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td new file mode 100644 index 0000000..a93353d --- /dev/null +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td @@ -0,0 +1,65 @@ +//===-- GPUDeviceMappingAttr.td - Attribute definition -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the attribute used to map loops to gpu. +// +//===----------------------------------------------------------------------===// + +#ifndef GPU_DEVICE_MAPPING_ATTR +#define GPU_DEVICE_MAPPING_ATTR + +include "mlir/Dialect/GPU/IR/GPUBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" + +def DimX : I64EnumAttrCase<"DimX", 0, "x">; +def DimY : I64EnumAttrCase<"DimY", 1, "y">; +def DimZ : I64EnumAttrCase<"DimZ", 2, "z">; + +def ThreadsEnum : I64EnumAttr<"Threads", "threads for loop mapping", [ + DimX, DimY, DimZ]> { + let cppNamespace = "::mlir::gpu"; +} + +def GPUThreadMappingAttr + : GPU_Attr<"GPUThreadMapping", "thread", [ DeviceMappingAttrInterface ]> { + let parameters = (ins + EnumParameter:$thread + ); + let assemblyFormat = "`<` params `>`"; + let description = [{ + An attribute that allows defining thread parallelism for GPU devices. + + Thread (aka work item) are grouped into a thread blocks where block may be + described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates + that thread parallelism is desired. It can be consumed by lowering to + generate GPU. 
def BlocksEnum : I64EnumAttr<"Blocks", "blocks for loop mapping", [ + DimX, DimY, DimZ]> {
// body 2 - } + } {mapping = [#gpu.thread]} gpu.terminator } ``` diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h index 2583875..ac1af6c 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -47,7 +47,7 @@ DiagnosedSilenceableFailure tileToForeachThreadOpImpl( RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef targets, ArrayRef mixedNumThreads, - ArrayRef mixedTileSizes, Optional threadDimMapping, + ArrayRef mixedTileSizes, Optional mapping, SmallVector &tileOps, SmallVector &tiledOps); } // namespace transform diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 347def6..b8638f1 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -13,6 +13,7 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" +include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" @@ -792,7 +793,7 @@ def TileToForeachThreadOp : a valid tiling specification (i.e. that only tiles parallel dimensions, e.g. in the Linalg case). - If non-empty, the `thread_dim_mapping` is added as an attribute to the + If non-empty, the `mapping` is added as an attribute to the resulting `scf.foreach_thread`. 
#### Return modes @@ -832,7 +833,7 @@ def TileToForeachThreadOp : Variadic:$tile_sizes, DefaultValuedAttr:$static_num_threads, DefaultValuedAttr:$static_tile_sizes, - OptionalAttr:$thread_dim_mapping); + OptionalAttr:$mapping); let results = (outs PDL_Operation:$foreach_thread_op, PDL_Operation:$tiled_op); @@ -841,22 +842,22 @@ def TileToForeachThreadOp : "ArrayRef":$staticTileSizes, CArg<"::mlir::transform::TileSizesSpec", "::mlir::transform::TileSizesSpec()">, - CArg<"ArrayRef", "{}">:$threadDimMapping)>, + CArg<"ArrayRef", "{}">:$mapping)>, OpBuilder<(ins "Value":$target, "ArrayRef":$mixedTileSizes, CArg<"::mlir::transform::TileSizesSpec", "::mlir::transform::TileSizesSpec()">, - CArg<"ArrayRef", "{}">:$threadDimMapping)>, + CArg<"ArrayRef", "{}">:$mapping)>, OpBuilder<(ins "Value":$target, "ArrayRef":$staticNumThreads, CArg<"::mlir::transform::NumThreadsSpec", "::mlir::transform::NumThreadsSpec()">, - CArg<"ArrayRef", "{}">:$threadDimMapping)>, + CArg<"ArrayRef", "{}">:$mapping)>, OpBuilder<(ins "Value":$target, "ArrayRef":$mixedNumThreads, CArg<"::mlir::transform::NumThreadsSpec", "::mlir::transform::NumThreadsSpec()">, - CArg<"ArrayRef", "{}">:$threadDimMapping)>, + CArg<"ArrayRef", "{}">:$mapping)>, ]; let assemblyFormat = [{ @@ -867,7 +868,7 @@ def TileToForeachThreadOp : `tile_sizes` custom($tile_sizes, $static_tile_sizes, "ShapedType::kDynamicSize")) - (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict + (`(` `mapping` `=` $mapping^ `)`)? 
attr-dict }]; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index e5e28e6c4..758386f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -423,7 +423,7 @@ computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, /// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`, applying /// tiling by `numThreads`. -/// If non-empty, the `threadDimMapping` is added as an attribute to the +/// If non-empty, the `mapping` is added as an attribute to the /// resulting `scf.foreach_thread`. /// Zero tile sizes indicate that the dimension is not tiled, and can be /// thought of as tiling by the full size of data. It is the user's @@ -436,14 +436,14 @@ struct ForeachThreadTilingResult { FailureOr tileToForeachThreadOp(RewriterBase &builder, TilingInterface op, ArrayRef numThreads, - ArrayRef threadDimMapping = {}); + Optional mapping); /// Same as `tileToForeachThreadOp`, but calculate the number of threads /// required using the given tileSizes. FailureOr tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op, ArrayRef tileSizes, - ArrayRef threadDimMapping = {}); + Optional mapping); /// All indices returned by IndexOp should be invariant with respect to /// tiling. 
Therefore, if an operation is tiled, we have to transform the diff --git a/mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt index 804d29d..1b6f45b 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt @@ -1,3 +1,10 @@ add_mlir_dialect(SCFOps scf Ops) add_mlir_doc(SCFOps SCFDialect Dialects/ -gen-dialect-doc) +set(LLVM_TARGET_DEFINITIONS DeviceMappingInterface.td) +mlir_tablegen(DeviceMappingAttrInterface.h.inc -gen-attr-interface-decls) +mlir_tablegen(DeviceMappingAttrInterface.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(DeviceMappingAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(DeviceMappingAttributes.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRDeviceMappingInterfacesIncGen) +add_dependencies(mlir-generic-headers MLIRDeviceMappingInterfacesIncGen) \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.h b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.h new file mode 100644 index 0000000..0a7fdbb --- /dev/null +++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.h @@ -0,0 +1,22 @@ +//===- DeviceMappingInterface.h - -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definitions of the device mapping interface defined in +// `DeviceMappingInterface.td`. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DEVICEMAPPINGINTERFACE_H +#define MLIR_DEVICEMAPPINGINTERFACE_H + +#include "mlir/IR/OpDefinition.h" + +/// Include the generated interface declarations. 
+#include "mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.h.inc" + +#endif // MLIR_DEVICEMAPPINGINTERFACE_H diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td new file mode 100644 index 0000000..2d2cafc --- /dev/null +++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td @@ -0,0 +1,43 @@ +//===- DeviceMappingInterface.td - Device mapping interfaces*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interfaces for the device mapping specification for the loops. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DEVICEMAPPINGINTERFACE +#define MLIR_DEVICEMAPPINGINTERFACE + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Attribute interfaces +//===----------------------------------------------------------------------===// + +def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + Attribute interface describing how to map a region to a processing unit. + + It is intended to be a generic mechanism for binding regions to execution + units of an actual or virtual device. Each device first expresses its own + mappings, and those mappings must implement this interface. These mappings + can be used by the device-specific code generators and the desired regions + can be connected to the given processing unit. + + Currently, `scf.foreach_thread` uses this interface to express the mapping + of the loops it contains to the GPU's parallelism units such as threads and + thread blocks. 
+ }]; +} + +def DeviceMappingArrayAttr : + TypedArrayAttrBase { } + +#endif // MLIR_DEVICEMAPPINGINTERFACE diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h index 12675c8..84b0ad3 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SCF_SCF_H #define MLIR_DIALECT_SCF_SCF_H +#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/RegionKindInterface.h" diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index d576ca2..3fa890b93 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -16,6 +16,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/IR/RegionKindInterface.td" +include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" include "mlir/Interfaces/ParallelCombiningOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -378,14 +379,14 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [ application per thread. Further lowerings are responsible for specifying how this is materialized on concrete hardware resources. - An optional thread_dim_mapping index array attribute specifies for each - virtual thread dimension, how it remaps 1-1 to a set of concrete processing + An optional `mapping` is an attribute array that specifies processing units + with their dimension, how it remaps 1-1 to a set of concrete processing element resources (e.g. a CUDA grid dimension or a level of concrete nested - async parallelism). At this time, the specification is backend-dependent and - is not verified by the op, beyond being an index array attribute. 
- It is the reponsibility of the lowering to interpret the index array in the - context of the concrete target the op is lowered to, or to ignore it when - the specification is ill-formed or unsupported for a particular target. + async parallelism). It is expressed via any attribute that implements the + device mapping interface. It is the reponsibility of the lowering mechanism + to interpret the `mapping` attributes in the context of the concrete target + the op is lowered to, or to ignore it when the specification is ill-formed + or unsupported for a particular target. The only allowed terminator is `scf.foreach_thread.perform_concurrently`. `scf.foreach_thread` returns one value per `shared_out` operand. The @@ -440,11 +441,12 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [ // ``` - Example with thread_dim_mapping attribute: + Example with mapping attribute: ```mlir // - // Sequential context. + // Sequential context. Here `mapping` is expressed as GPU thread mapping + // attributes // %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in (%num_threads_1, %numthread_id_2) shared_outs(...) @@ -456,7 +458,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [ scf.foreach_thread.perform_concurrently { ... } - } { thread_dim_mapping = [1, 0] } + } { mapping = [#gpu.thread, #gpu.thread] } // Implicit synchronization point. // Sequential context. // @@ -480,7 +482,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [ }]; let arguments = (ins Variadic:$num_threads, Variadic:$outputs, - DefaultValuedAttr:$thread_dim_mapping); + OptionalAttr:$mapping); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); @@ -493,10 +495,10 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [ let builders = [ // Bodyless builder, outputs must be specified. 
OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads, - CArg<"ArrayRef", "{}">:$thread_dim_mapping)>, + "Optional":$mapping)>, // Builder that takes a bodyBuilder lambda. OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads, - "ArrayRef":$thread_dim_mapping, + "ArrayRef":$mapping, "function_ref":$bodyBuilder)> ]; let extraClassDeclaration = [{ @@ -535,14 +537,14 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [ } /// Return the thread indices in the order specified by the - /// thread_dim_mapping attribute. Return failure is - /// thread_dim_mapping is not a valid permutation. - FailureOr> getPermutedThreadIndices(); + /// given mapping argument. Return failure is + /// mapping is not a valid permutation. + FailureOr> getPermutedThreadIndices(ArrayRef mapping); /// Return the number of threads in the order specified by the - /// thread_dim_mapping attribute. - /// Return failure is thread_dim_mapping is not a valid permutation. - FailureOr> getPermutedNumThreads(OpBuilder &b); + /// given mapping argument. + /// Return failure is mapping is not a valid permutation. 
+ FailureOr> getPermutedNumThreads(OpBuilder &b, ArrayRef mapping); // The ensureTerminator method generated by SingleBlockImplicitTerminator is // unaware of the fact that our terminator also needs a region to be diff --git a/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt index 1563afb..6ed0047 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt @@ -3,10 +3,13 @@ add_mlir_dialect_library(MLIRGPUTransformOps ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU/TransformOps + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces DEPENDS MLIRGPUTransformOpsIncGen - + MLIRDeviceMappingInterfacesIncGen + MLIRGPUDeviceMapperEnumsGen + LINK_LIBS PUBLIC MLIRIR MLIRGPUTransforms diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index 280dbf0..23420e8 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -166,8 +166,20 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl( // Step 0. Outline the compute workload region and set up the workload // operands. + SmallVector mapping; + if (!foreachThreadOp.getMapping().has_value()) + return transformOp.emitSilenceableError() << "mapping must be present"; + for (DeviceMappingAttrInterface map : *foreachThreadOp.getMapping()) { + if (auto blockMap = map.dyn_cast()) { + mapping.push_back((int64_t)blockMap.getBlock()); + } else { + return transformOp.emitSilenceableError() + << "mapping must be #gpu.block"; + } + } + FailureOr> potentialGridDim = - foreachThreadOp.getPermutedNumThreads(rewriter); + foreachThreadOp.getPermutedNumThreads(rewriter, mapping); if (failed(potentialGridDim) || llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) { @@ -193,7 +205,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl( // Step 2. 
RAUW thread indices to thread ops. SmallVector threadIndices = - *foreachThreadOp.getPermutedThreadIndices(); + *foreachThreadOp.getPermutedThreadIndices(mapping); assert(blockOps.size() == 3 && "3 block id ops are required"); for (auto [blockIdx, blockOp] : llvm::zip(threadIndices, blockOps)) { Value val = blockIdx; @@ -230,7 +242,8 @@ DiagnosedSilenceableFailure mlir::transform::gpu::findTopLevelForeachThreadOp( } /// This is a helper that is only used in -/// rewriteTopLevelForeachThreadToGpuBlocks. It generates GPU dialects block_id. +/// rewriteTopLevelForeachThreadToGpuBlocks. It generates GPU dialects +/// block_id. static void generateGpuBlockIds(RewriterBase &rewriter, scf::ForeachThreadOp foreachOp, SmallVectorImpl &blockOps) { @@ -335,7 +348,18 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads( return failureHelper( "scf.foreach_thread with rank > 3 does not lower to gpu.thread_id"); - auto potentialBlockDim = foreachThreadOp.getPermutedNumThreads(rewriter); + SmallVector mapping; + if (!foreachThreadOp.getMapping().has_value()) + return failureHelper("mapping must be present"); + for (DeviceMappingAttrInterface map : *foreachThreadOp.getMapping()) { + if (auto threadMap = map.dyn_cast()) { + mapping.push_back((int64_t)threadMap.getThread()); + } else { + return failureHelper("mapping must be #gpu.thread"); + } + } + FailureOr> potentialBlockDim = + foreachThreadOp.getPermutedNumThreads(rewriter, mapping); if (failed(potentialBlockDim) || llvm::any_of(*potentialBlockDim, [](OpFoldResult ofr) { return !getConstantIntValue(ofr).has_value(); @@ -365,8 +389,8 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads( if (blockDim > globalBlockDim) { return failureHelper( "The requested GPU threads are fewer than the number of loop trip " - "counts. Try to tile scf.foreach_thread before mapping or set small " - "blockDim."); + "counts. 
Try to tile scf.foreach_thread before mapping or set " + "small blockDim."); } if (blockDim == globalBlockDim) continue; @@ -400,7 +424,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads( // Step 4. RAUW thread indices to thread ops. SmallVector threadIndices = - *foreachThreadOp.getPermutedThreadIndices(); + *foreachThreadOp.getPermutedThreadIndices(mapping); for (auto [threadIdx, threadOp] : llvm::zip(threadIndices, threadOps)) { Value val = threadIdx; Value op = threadOp; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 6b8ca91..7b720a7 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1321,19 +1321,21 @@ void transform::TileOp::getEffects( // TileToForeachThreadOp //===----------------------------------------------------------------------===// -void transform::TileToForeachThreadOp::build( - OpBuilder &builder, OperationState &result, Value target, - ArrayRef staticTileSizes, transform::TileSizesSpec, - ArrayRef threadDimMapping) { +void transform::TileToForeachThreadOp::build(OpBuilder &builder, + OperationState &result, + Value target, + ArrayRef staticTileSizes, + transform::TileSizesSpec, + ArrayRef mapping) { return build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), - TileSizesSpec(), threadDimMapping); + TileSizesSpec(), mapping); } void transform::TileToForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedTileSizes, transform::TileSizesSpec, - ArrayRef threadDimMapping) { + ArrayRef mapping) { SmallVector staticTileSizes; SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes, @@ -1344,28 +1346,29 @@ void transform::TileToForeachThreadOp::build( MLIRContext *ctx = builder.getContext(); auto operationType = 
pdl::OperationType::get(ctx); auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes); - ArrayAttr threadDimMappingAttr; - if (!threadDimMapping.empty()) - threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping); + ArrayAttr mappingAttr; + if (!mapping.empty()) + mappingAttr = builder.getI64ArrayAttr(mapping); build(builder, result, TypeRange{operationType, operationType}, target, /*numThreads=*/ValueRange{}, dynamicTileSizes, - /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, - threadDimMappingAttr); + /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mappingAttr); } -void transform::TileToForeachThreadOp::build( - OpBuilder &builder, OperationState &result, Value target, - ArrayRef staticNumThreads, transform::NumThreadsSpec, - ArrayRef threadDimMapping) { +void transform::TileToForeachThreadOp::build(OpBuilder &builder, + OperationState &result, + Value target, + ArrayRef staticNumThreads, + transform::NumThreadsSpec, + ArrayRef mapping) { return build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)), - NumThreadsSpec(), threadDimMapping); + NumThreadsSpec(), mapping); } void transform::TileToForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedNumThreads, transform::NumThreadsSpec, - ArrayRef threadDimMapping) { + ArrayRef mapping) { SmallVector staticNumThreads; SmallVector dynamicNumThreads; dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, @@ -1376,19 +1379,19 @@ void transform::TileToForeachThreadOp::build( MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads); - ArrayAttr threadDimMappingAttr; - if (!threadDimMapping.empty()) - threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping); + ArrayAttr mappingAttr; + if (!mapping.empty()) + mappingAttr = builder.getI64ArrayAttr(mapping); build(builder, result, 
TypeRange{operationType, operationType}, target, dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr, - /*staticTileSizes=*/ArrayAttr(), threadDimMappingAttr); + /*staticTileSizes=*/ArrayAttr(), mappingAttr); } DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef targets, ArrayRef mixedNumThreads, - ArrayRef mixedTileSizes, Optional threadDimMapping, + ArrayRef mixedTileSizes, Optional mapping, SmallVector &tileOps, SmallVector &tiledOps) { if (targets.empty()) return DiagnosedSilenceableFailure(success()); @@ -1457,19 +1460,13 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( return diag; } rewriter.setInsertionPoint(tilableOp); - auto maybeThreadDimMappingAttr = threadDimMapping; - auto dimMapping = llvm::to_vector( - maybeThreadDimMappingAttr - ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) - : ArrayRef{}); - FailureOr tilingResult = failure(); if (!mixedNumThreads.empty()) { tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp, - numThreads, dimMapping); + numThreads, mapping); } else { tilingResult = linalg::tileToForeachThreadOpUsingTileSizes( - rewriter, tilableOp, tileSizes, dimMapping); + rewriter, tilableOp, tileSizes, mapping); } if (failed(tilingResult)) @@ -1494,7 +1491,7 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply( DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl( rewriter, state, cast(getOperation()), targets, - getMixedNumThreads(), getMixedTileSizes(), getThreadDimMapping(), tileOps, + getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps, tiledOps); if (!diag.succeeded()) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 5937da3..a32e9f7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -215,7 +215,7 @@ 
static OpFoldResult buildMin(OpBuilder &b, Location loc, /// tiling is specified by the number of tiles/threads `numThreads` and the /// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is /// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i], -/// numThreads[i])`. If non-empty, the `threadDimMapping` is added as an +/// numThreads[i])`. If non-empty, the `mapping` is added as an /// attribute to the resulting `scf.foreach_thread`. A zero tile sizes indicate /// that the dimension is not tiled, and can be thought of as tiling by the full /// size of data. @@ -226,7 +226,7 @@ static OpFoldResult buildMin(OpBuilder &b, Location loc, static FailureOr tileToForeachThreadOpImpl( RewriterBase &b, TilingInterface op, ArrayRef numThreads, Optional> nominalTileSizes, - ArrayRef threadDimMapping, bool omitTileOffsetBoundsCheck) { + Optional mapping, bool omitTileOffsetBoundsCheck) { Location loc = op->getLoc(); OpBuilder::InsertionGuard g(b); SmallVector loopRanges = op.getIterationDomain(b); @@ -256,7 +256,7 @@ static FailureOr tileToForeachThreadOpImpl( // version because we require the use of RewriterBase in the body, so we // manually move the insertion point to the body below. scf::ForeachThreadOp foreachThreadOp = b.create( - loc, dest, ValueRange(materializedNonZeroNumThreads), threadDimMapping); + loc, dest, ValueRange(materializedNonZeroNumThreads), mapping); // Fill out the ForeachThreadOp body. 
b.setInsertionPointToStart(foreachThreadOp.getBody(0)); @@ -363,16 +363,16 @@ static FailureOr tileToForeachThreadOpImpl( FailureOr linalg::tileToForeachThreadOp(RewriterBase &b, TilingInterface op, ArrayRef numThreads, - ArrayRef threadDimMapping) { + Optional mapping) { return tileToForeachThreadOpImpl(b, op, numThreads, /*nominalTileSizes=*/None, - threadDimMapping, + mapping, /*omitTileOffsetBoundsCheck=*/false); } FailureOr -linalg::tileToForeachThreadOpUsingTileSizes( - RewriterBase &b, TilingInterface op, ArrayRef tileSizes, - ArrayRef threadDimMapping) { +linalg::tileToForeachThreadOpUsingTileSizes(RewriterBase &b, TilingInterface op, + ArrayRef tileSizes, + Optional mapping) { SmallVector loopRanges = op.getIterationDomain(b); unsigned nLoops = loopRanges.size(); SmallVector numThreads; @@ -388,8 +388,7 @@ linalg::tileToForeachThreadOpUsingTileSizes( numThreads.push_back(numTiles); } return tileToForeachThreadOpImpl(b, op, numThreads, - /*nominalTileSizes=*/tileSizes, - threadDimMapping, + /*nominalTileSizes=*/tileSizes, mapping, /*omitTileOffsetBoundsCheck=*/true); } diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt index 6255d9e..a043491 100644 --- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRSCFDialect SCF.cpp + DeviceMappingInterface.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/SCF diff --git a/mlir/lib/Dialect/SCF/IR/DeviceMappingInterface.cpp b/mlir/lib/Dialect/SCF/IR/DeviceMappingInterface.cpp new file mode 100644 index 0000000..a90c638 --- /dev/null +++ b/mlir/lib/Dialect/SCF/IR/DeviceMappingInterface.cpp @@ -0,0 +1,17 @@ +//===- DeviceMappingInterface.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Table-generated class definitions +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.cpp.inc" diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 5f1a20c..e32f671 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -12,13 +12,16 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::scf; @@ -1111,6 +1114,12 @@ LogicalResult ForeachThreadOp::verify() { if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType()) return emitOpError("type mismatch between ") << i << "-th output and corresponding block argument"; + if (getMapping().has_value()) + for (auto map : getMapping().value()) { + if (!isa(map)) + return emitOpError() + << getMappingAttrName() << " is not device mapping attribute"; + } return success(); } @@ -1200,11 +1209,14 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser, void ForeachThreadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs, ValueRange numThreads, - ArrayRef threadDimMapping) { + 
Optional mapping) { result.addOperands(numThreads); result.addOperands(outputs); - result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name), - builder.getI64ArrayAttr(threadDimMapping)); + if (mapping.has_value()) { + result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name), + mapping.value()); + } + result.addAttribute( "operand_segment_sizes", builder.getDenseI32ArrayAttr({static_cast(numThreads.size()), @@ -1231,12 +1243,12 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder, // Builder that takes a bodyBuilder lambda. void ForeachThreadOp::build( mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs, - ValueRange numThreads, ArrayRef threadDimMapping, + ValueRange numThreads, ArrayRef mapping, function_ref bodyBuilder) { result.addOperands(numThreads); result.addOperands(outputs); - result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name), - builder.getI64ArrayAttr(threadDimMapping)); + result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name), + builder.getArrayAttr(mapping)); result.addAttribute( "operand_segment_sizes", builder.getDenseI32ArrayAttr({static_cast(numThreads.size()), @@ -1290,51 +1302,51 @@ static FailureOr> permute(const SmallVector &vals, SmallVector result(vals.size()); SmallVector seen(vals.size()); for (auto [idx, val] : llvm::zip(perm, vals)) { - // Already seen, invalid thread_dim_mapping. + // Already seen, invalid mapping. if (seen[idx]) return failure(); result[idx] = val; seen[idx] = true; } - // Some not seen, invalid thread_dim_mapping. + // Some not seen, invalid mapping. if (!llvm::all_of(seen, [](bool b) { return b; })) return failure(); return result; } -/// Helper to get apply the `thread_dim_mapping` permutation of a +/// Helper to get apply the `mapping` permutation of a /// `foreachThreadOp` to `values`. 
template static FailureOr> getValuesPermutedByThreadMapping(scf::ForeachThreadOp foreachThreadOp, - const SmallVector &values) { + const SmallVector &values, + ArrayRef mapping) { // Apply mapping permutation if specified. - auto mapping = foreachThreadOp.getThreadDimMapping(); - if (mapping && !mapping.empty()) { - auto maybePermuted = permute(values, extractFromI64ArrayAttr(mapping)); - if (failed(maybePermuted)) - return foreachThreadOp->emitError("invalid permutation"); - return *maybePermuted; - } + FailureOr> maybePermuted = permute(values, mapping); + if (failed(maybePermuted)) + return foreachThreadOp->emitError("invalid permutation"); + return *maybePermuted; return values; } -/// Return the thread indices in the order specified by the thread_dim_mapping -/// attribute. Return failure is thread_dim_mapping is not a valid permutation. -FailureOr> ForeachThreadOp::getPermutedThreadIndices() { +/// Return the thread indices in the order specified by the mapping +/// attribute. Return failure is mapping is not a valid permutation. +FailureOr> +ForeachThreadOp::getPermutedThreadIndices(ArrayRef mapping) { SmallVector threadCountValues = this->getThreadIndices(); threadCountValues.resize(3, Value()); - return getValuesPermutedByThreadMapping(*this, threadCountValues); + return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping); } /// Return the number of threads in the order specified by the -/// thread_dim_mapping attribute. -/// Return failure is thread_dim_mapping is not a valid permutation. +/// mapping attribute. +/// Return failure is mapping is not a valid permutation. 
FailureOr> -ForeachThreadOp::getPermutedNumThreads(OpBuilder &b) { +ForeachThreadOp::getPermutedNumThreads(OpBuilder &b, + ArrayRef mapping) { SmallVector threadCountValues = this->getNumThreads(); threadCountValues.resize(3, b.getIndexAttr(1)); - return getValuesPermutedByThreadMapping(*this, threadCountValues); + return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping); } ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) { diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 2771857..dcf2fe1 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1141,10 +1141,11 @@ struct ForeachThreadOpInterface // Create new ForeachThreadOp without any results and drop the automatically // introduced terminator. rewriter.setInsertionPoint(foreachThreadOp); - auto newForeachThreadOp = rewriter.create( + ForeachThreadOp newForeachThreadOp; + newForeachThreadOp = rewriter.create( foreachThreadOp.getLoc(), /*outputs=*/ValueRange(), - foreachThreadOp.getNumThreads(), - extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping())); + foreachThreadOp.getNumThreads(), foreachThreadOp.getMapping()); + newForeachThreadOp.getBody()->getTerminator()->erase(); // Move over block contents of the old op. 
diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir index f61ed8f..de50c59 100644 --- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir @@ -26,7 +26,7 @@ func.func @map_nested_foreach_to_threads_excessive_threads(%x: memref<2 x 32 x f %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } @@ -38,7 +38,7 @@ func.func @map_nested_foreach_to_threads_excessive_threads(%x: memref<2 x 32 x f %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } @@ -67,7 +67,7 @@ func.func @map_nested_foreach_to_threads_fewer_threads(%x: memref<2 x 32 x f32>, %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } gpu.terminator } @@ -79,7 +79,7 @@ func.func @map_nested_foreach_to_threads_fewer_threads(%x: memref<2 x 32 x f32>, %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } gpu.terminator } @@ -106,7 +106,7 @@ func.func @map_nested_foreach_to_threads_dynamic_trip_count(%x: memref<2 x 32 x %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } gpu.terminator } return %y : memref<2 x 32 x 
f32> @@ -131,7 +131,7 @@ func.func @map_nested_foreach_to_threads_4d_loop(%x: memref<2x32x32x32xf32>, %y: scf.foreach_thread (%i, %j, %k, %l) in (%c2, %c32,%c32,%c32) { %4 = memref.load %x[%i, %j, %k, %l] : memref<2x32x32x32xf32> memref.store %4, %y[%i, %j, %k, %l] : memref<2x32x32x32xf32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } gpu.terminator } return %y : memref<2x32x32x32xf32> @@ -197,14 +197,14 @@ func.func @map_foreach_to_blocks_not_unique(%x: memref<2 x 32 x f32>, %y: memref %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } scf.foreach_thread (%i, %j) in (%c7, %c9) { %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32> %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } gpu.terminator } @@ -232,14 +232,14 @@ func.func @map_foreach_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [0, 1, 2]} + } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } scf.foreach_thread (%i, %j) in (%c7, %c9) { %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32> %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } return %y : memref<2 x 32 x f32> } @@ -261,7 +261,7 @@ func.func @map_foreach_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : 
f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [0, 1, 2]} + } { mapping = [#gpu.block, #gpu.block, #gpu.block] } return %y : memref<2 x 32 x f32> } @@ -273,4 +273,3 @@ transform.sequence failures(propagate) { } // ----- - diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir index d4ff7ff..eb7208b 100644 --- a/mlir/test/Dialect/GPU/transform-gpu.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu.mlir @@ -24,7 +24,7 @@ func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } {thread_dim_mapping = [0, 1, 2]} + } { mapping = [#gpu.block, #gpu.block, #gpu.block]} gpu.terminator } return %y : !type @@ -33,7 +33,7 @@ func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 - transform.gpu.map_foreach_to_blocks %funcop { blockDim = [12, 9, 1]} + transform.gpu.map_foreach_to_blocks %funcop { gridDim = [12, 9]} } // ----- @@ -73,12 +73,12 @@ func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread]} scf.foreach_thread (%i) in (%c12) { %7 = memref.load %t[%i] : !type1d %8 = arith.addf %alpha, %7 : f32 memref.store %8, %t[%i] : !type1d - } {thread_dim_mapping = [0, 1, 2]} + } {mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } gpu.terminator } return %y : !type @@ -87,7 +87,7 @@ func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 - 
transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1] } + transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9] } } // ----- @@ -118,8 +118,8 @@ func.func @saxpy4d(%x: !type4d, %y: !type4d, %alpha : f32) -> !type4d { %5 = memref.load %y[%i, %j, %k, %l] : !type4d %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j, %k, %l] : !type4d - } {thread_dim_mapping = [1, 0, 2]} - } {thread_dim_mapping = [0, 1, 2]} + } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } + } { mapping = [#gpu.block, #gpu.block, #gpu.block] } return %y : !type4d } @@ -151,7 +151,7 @@ func.func @saxpy2d_no_barrier(%x: !type, %y: !type, %t: !type1d, %alpha : f32, % %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } gpu.terminator } return %y : !type diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir index ce0afdf..f015cdf 100644 --- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir @@ -26,7 +26,7 @@ module { // CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C_BLK]]{{.*}} : // CHECK-SAME: tensor into tensor // CHECK-NEXT: } - // CHECK-NEXT: } {thread_dim_mapping = [1, 0]} + // CHECK-NEXT: } {mapping = [#gpu.thread, #gpu.thread]} %0 = linalg.matmul ins(%A, %B : tensor, tensor) outs(%C : tensor) -> (tensor) return %0 : tensor @@ -35,7 +35,7 @@ module { transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapped to dims [1, 0]) + %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapping = [ #gpu.thread, #gpu.thread ] ) } } @@ -177,7 +177,7 @@ module { transform.sequence 
failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] (mapped to dims [0]) + %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] ( mapping = [#gpu.thread]) } } // CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2)> diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir index 69c4ef4..65b1f09 100644 --- a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir @@ -628,7 +628,7 @@ func.func @same_enclosing_repetitive_region(%2: tensor<320xf32>, // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]} tensor.parallel_insert_slice %8 into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32> } - } {thread_dim_mapping = []} + } return %4 : tensor<320xf32> } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir index a533783..17dca3f 100644 --- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir @@ -128,7 +128,7 @@ func.func @scf_foreach_thread_out_of_place(%in: tensor<100xf32>, tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> } - // CHECK: } {thread_dim_mapping = [5]} - } {thread_dim_mapping = [5]} + // CHECK: } {mapping = [#gpu.thread]} + } {mapping = [#gpu.thread]} return } diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir index e563838..18413e8 100644 --- a/mlir/test/Dialect/SCF/ops.mlir +++ b/mlir/test/Dialect/SCF/ops.mlir @@ -338,12 +338,12 @@ func.func @elide_terminator() -> () { %num_threads = arith.constant 100 : index // CHECK: scf.foreach_thread - // CHECK-NEXT: } 
{thread_dim_mapping = [42]} + // CHECK-NEXT: } {mapping = [#gpu.thread]} // CHECK-NEXT: return scf.foreach_thread (%thread_idx) in (%num_threads) { scf.foreach_thread.perform_concurrently { } - } {thread_dim_mapping = [42]} + } {mapping = [#gpu.thread]} return } diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp index 461da29..82ea071 100644 --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -217,7 +217,7 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfForeach Location loc = op.getLoc(); auto foreachOp = rewriter.create( loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(), - /*threadDimMapping=*/ArrayRef{}, + /*mapping=*/ArrayRef{}, [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) { unsigned numThreadIdRegionArgs = helper.getIterationSpaceSizes().size(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 5260beb..3d77545 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1833,7 +1833,10 @@ gentbl_cc_library( td_library( name = "SCFTdFiles", - srcs = ["include/mlir/Dialect/SCF/IR/SCFOps.td"], + srcs = [ + "include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td", + "include/mlir/Dialect/SCF/IR/SCFOps.td", + ], includes = ["include"], deps = [ ":ControlFlowInterfacesTdFiles", @@ -1871,6 +1874,32 @@ gentbl_cc_library( ) gentbl_cc_library( + name = "SCFDeviceMappingInterfacesIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-attr-interface-decls"], + "include/mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.h.inc", + ), + ( + ["-gen-attr-interface-defs"], + "include/mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.cpp.inc", + ), + ( + ["-gen-attrdef-decls"], + "include/mlir/Dialect/SCF/IR/DeviceMappingAttributes.h.inc", + ), + ( + 
["-gen-attrdef-defs"], + "include/mlir/Dialect/SCF/IR/DeviceMappingAttributes.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td", + deps = [":SCFTdFiles"], +) + +gentbl_cc_library( name = "SCFPassIncGen", strip_include_prefix = "include", tbl_outs = [ @@ -2794,6 +2823,7 @@ cc_library( ":MemRefDialect", ":ParallelCombiningOpInterface", ":Pass", + ":SCFDeviceMappingInterfacesIncGen", ":SCFIncGen", ":SCFPassIncGen", ":Support", @@ -3623,6 +3653,7 @@ td_library( "include/mlir/Dialect/GPU/IR/GPUBase.td", "include/mlir/Dialect/GPU/IR/GPUOps.td", "include/mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td", + "include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td", ], includes = ["include"], deps = [ @@ -3632,11 +3663,33 @@ td_library( ":InferIntRangeInterfaceTdFiles", ":LLVMOpsTdFiles", ":OpBaseTdFiles", + ":SCFTdFiles", ":SideEffectInterfacesTdFiles", ], ) gentbl_cc_library( + name = "GPUDeviceMapperEnumsGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-enum-decls"], + "include/mlir/Dialect/GPU/TransformOps/GPUDeviceMapperEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "include/mlir/Dialect/GPU/TransformOps/GPUDeviceMapperEnums.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td", + deps = [ + ":GPUOpsTdFiles", + ":OpBaseTdFiles", + ], +) + +gentbl_cc_library( name = "GPUBaseIncGen", strip_include_prefix = "include", tbl_outs = [ @@ -3725,6 +3778,7 @@ cc_library( ":InferTypeOpInterface", ":LLVMDialect", ":MemRefDialect", + ":SCFDialect", ":SideEffectInterfaces", "//llvm:Support", ], @@ -7694,6 +7748,7 @@ td_library( includes = ["include"], deps = [ ":PDLDialectTdFiles", + ":SCFTdFiles", ":TransformDialectTdFiles", ], ) @@ -7781,6 +7836,7 @@ gentbl_cc_library( td_file = "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td", deps = [ ":LinalgTransformOpsTdFiles", + ":SCFDeviceMappingInterfacesIncGen", ], )