[mlir][transform] Make `tile_to_foreach_thread_op` builder to use ArrayAttr
authorGuray Ozen <guray.ozen@gmail.com>
Sat, 12 Nov 2022 12:02:02 +0000 (13:02 +0100)
committerGuray Ozen <guray.ozen@gmail.com>
Sat, 12 Nov 2022 18:27:25 +0000 (19:27 +0100)
D137413 clarified `scf_foreach_thread` thread mapping nicely. `tile_to_foreach_thread_op` is one of the op that generates `scf_foreach_thread`, however, its builders are still having integer array.

This is bug fix of potential problem.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D137891

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

index b8638f1..b92ed19 100644 (file)
@@ -842,22 +842,22 @@ def TileToForeachThreadOp :
                    "ArrayRef<int64_t>":$staticTileSizes,
                    CArg<"::mlir::transform::TileSizesSpec", 
                         "::mlir::transform::TileSizesSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
+                   CArg<"ArrayAttr", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedTileSizes,
                    CArg<"::mlir::transform::TileSizesSpec", 
                         "::mlir::transform::TileSizesSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
+                   CArg<"ArrayAttr", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<int64_t>":$staticNumThreads,
                    CArg<"::mlir::transform::NumThreadsSpec", 
                         "::mlir::transform::NumThreadsSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
+                   CArg<"ArrayAttr", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedNumThreads,
                    CArg<"::mlir::transform::NumThreadsSpec", 
                         "::mlir::transform::NumThreadsSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
+                   CArg<"ArrayAttr", "{}">:$mapping)>,
   ];
 
   let assemblyFormat = [{
index 7b720a7..cdd4e15 100644 (file)
@@ -1326,7 +1326,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
                                              Value target,
                                              ArrayRef<int64_t> staticTileSizes,
                                              transform::TileSizesSpec,
-                                             ArrayRef<int64_t> mapping) {
+                                             ArrayAttr mapping) {
   return build(builder, result, target,
                getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
                TileSizesSpec(), mapping);
@@ -1335,7 +1335,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
 void transform::TileToForeachThreadOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec,
-    ArrayRef<int64_t> mapping) {
+    ArrayAttr mapping) {
   SmallVector<int64_t> staticTileSizes;
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes,
@@ -1346,12 +1346,9 @@ void transform::TileToForeachThreadOp::build(
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
-  ArrayAttr mappingAttr;
-  if (!mapping.empty())
-    mappingAttr = builder.getI64ArrayAttr(mapping);
   build(builder, result, TypeRange{operationType, operationType}, target,
         /*numThreads=*/ValueRange{}, dynamicTileSizes,
-        /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mappingAttr);
+        /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mapping);
 }
 
 void transform::TileToForeachThreadOp::build(OpBuilder &builder,
@@ -1359,7 +1356,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
                                              Value target,
                                              ArrayRef<int64_t> staticNumThreads,
                                              transform::NumThreadsSpec,
-                                             ArrayRef<int64_t> mapping) {
+                                             ArrayAttr mapping) {
   return build(builder, result, target,
                getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
                NumThreadsSpec(), mapping);
@@ -1368,7 +1365,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
 void transform::TileToForeachThreadOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec,
-    ArrayRef<int64_t> mapping) {
+    ArrayAttr mapping) {
   SmallVector<int64_t> staticNumThreads;
   SmallVector<Value> dynamicNumThreads;
   dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
@@ -1379,12 +1376,9 @@ void transform::TileToForeachThreadOp::build(
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
-  ArrayAttr mappingAttr;
-  if (!mapping.empty())
-    mappingAttr = builder.getI64ArrayAttr(mapping);
   build(builder, result, TypeRange{operationType, operationType}, target,
         dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr,
-        /*staticTileSizes=*/ArrayAttr(), mappingAttr);
+        /*staticTileSizes=*/ArrayAttr(), mapping);
 }
 
 DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(