From d93be483eaf5e22f4192325f9357821cbd2e934e Mon Sep 17 00:00:00 2001 From: Guray Ozen Date: Sat, 12 Nov 2022 13:02:02 +0100 Subject: [PATCH] [mlir][transform] Make `tile_to_foreach_thread_op` builder to use ArrayAttr D137413 clarified `scf_foreach_thread` thread mapping nicely. `tile_to_foreach_thread_op` is one of the op that generates `scf_foreach_thread`, however, its builders are still having integer array. This is bug fix of potential problem. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D137891 --- .../Dialect/Linalg/TransformOps/LinalgTransformOps.td | 8 ++++---- .../Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 18 ++++++------------ 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index b8638f1..b92ed19 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -842,22 +842,22 @@ def TileToForeachThreadOp : "ArrayRef":$staticTileSizes, CArg<"::mlir::transform::TileSizesSpec", "::mlir::transform::TileSizesSpec()">, - CArg<"ArrayRef", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)>, OpBuilder<(ins "Value":$target, "ArrayRef":$mixedTileSizes, CArg<"::mlir::transform::TileSizesSpec", "::mlir::transform::TileSizesSpec()">, - CArg<"ArrayRef", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)>, OpBuilder<(ins "Value":$target, "ArrayRef":$staticNumThreads, CArg<"::mlir::transform::NumThreadsSpec", "::mlir::transform::NumThreadsSpec()">, - CArg<"ArrayRef", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)>, OpBuilder<(ins "Value":$target, "ArrayRef":$mixedNumThreads, CArg<"::mlir::transform::NumThreadsSpec", "::mlir::transform::NumThreadsSpec()">, - CArg<"ArrayRef", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)>, ]; let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 7b720a7..cdd4e15 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1326,7 +1326,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder, Value target, ArrayRef staticTileSizes, transform::TileSizesSpec, - ArrayRef mapping) { + ArrayAttr mapping) { return build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), TileSizesSpec(), mapping); @@ -1335,7 +1335,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder, void transform::TileToForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedTileSizes, transform::TileSizesSpec, - ArrayRef mapping) { + ArrayAttr mapping) { SmallVector staticTileSizes; SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes, @@ -1346,12 +1346,9 @@ void transform::TileToForeachThreadOp::build( MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes); - ArrayAttr mappingAttr; - if (!mapping.empty()) - mappingAttr = builder.getI64ArrayAttr(mapping); build(builder, result, TypeRange{operationType, operationType}, target, /*numThreads=*/ValueRange{}, dynamicTileSizes, - /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mappingAttr); + /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mapping); } void transform::TileToForeachThreadOp::build(OpBuilder &builder, @@ -1359,7 +1356,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder, Value target, ArrayRef staticNumThreads, transform::NumThreadsSpec, - ArrayRef mapping) { + ArrayAttr mapping) { return build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)), NumThreadsSpec(), mapping); @@ -1368,7 +1365,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder, void transform::TileToForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedNumThreads, transform::NumThreadsSpec, - ArrayRef mapping) { + ArrayAttr mapping) { SmallVector staticNumThreads; SmallVector dynamicNumThreads; dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, @@ -1379,12 +1376,9 @@ void transform::TileToForeachThreadOp::build( MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads); - ArrayAttr mappingAttr; - if (!mapping.empty()) - mappingAttr = builder.getI64ArrayAttr(mapping); build(builder, result, TypeRange{operationType, operationType}, target, dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr, - /*staticTileSizes=*/ArrayAttr(), mappingAttr); + /*staticTileSizes=*/ArrayAttr(), mapping); } DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( -- 2.7.4