[mlir][Linalg] Better builders for transform ops
author: Nicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 14 Dec 2022 11:25:28 +0000 (03:25 -0800)
committer: Nicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 14 Dec 2022 14:22:52 +0000 (06:22 -0800)
Also adopt DenseI64ArrayAttr in those transform ops.

Differential Revision: https://reviews.llvm.org/D140009

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

index 9c80f56..1cac6b8 100644 (file)
@@ -721,14 +721,24 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
     ```
   }];
 
+  // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
   let results = (outs PDL_Operation:$for_op,
                       PDL_Operation:$fill_op,
                       PDL_Operation:$split_linalg_op,
                       PDL_Operation:$combining_linalg_op);
 
-  let assemblyFormat = "$target attr-dict";
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$staticTileSizes)>
+  ];
+
+  let assemblyFormat = [{
+    $target
+    `by` `tile_sizes` `=` $tile_sizes
+    attr-dict
+  }];
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -808,16 +818,31 @@ def TileReductionUsingForeachThreadOp :
     ```
   }];
 
+  // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
   let results = (outs PDL_Operation:$foreach_thread_op,
                       PDL_Operation:$fill_op,
                       PDL_Operation:$split_linalg_op,
                       PDL_Operation:$combining_linalg_op);
 
-  let assemblyFormat = "$target attr-dict";
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$staticNumThreads,
+                   "ArrayRef<int64_t>":$staticTileSizes,
+                   CArg<"ArrayAttr", "{}">:$mapping)>
+  ];
+
+  let assemblyFormat = [{
+    $target
+    `by`
+    (`num_threads` `=` $num_threads^)?
+    (`,` `tile_sizes` `=` $tile_sizes^)?
+    (`,` `mapping` `=` $mapping^)?
+    attr-dict
+  }];
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -825,6 +850,7 @@ def TileReductionUsingForeachThreadOp :
         ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
         ::mlir::transform::TransformState &state);
   }];
+
 }
 
 def TileOp : Op<Transform_Dialect, "structured.tile",
index 853321f..c8995e6 100644 (file)
@@ -1200,20 +1200,31 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
 // TileReductionUsingScfOp
 //===----------------------------------------------------------------------===//
 
+void transform::TileReductionUsingScfOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    ArrayRef<int64_t> staticTileSizes) {
+  // Convenience builder: wraps `staticTileSizes` in a DenseI64ArrayAttr and
+  // forwards to the default builder with four pdl::OperationType results.
+  // Delegating keeps this future-proof for mixed static-dynamic sizes and
+  // operand segment size attributes; bypassing it is reported as bug-prone.
+  // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
+  MLIRContext *ctx = builder.getContext();
+  auto opTy = pdl::OperationType::get(ctx);
+  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+  build(builder, result,
+        /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
+        /*target=*/target,
+        /*tile_sizes=*/staticTileSizesAttr);
+}
+
 DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
     linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
     transform::TransformState &state) {
   TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
-  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
-  SmallVector<OpFoldResult> sizes;
-  for (int64_t size : tileSizes) {
-    sizes.push_back(rewriter.getIndexAttr(size));
-  }
-
   FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
       rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
-      sizes);
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
 
   if (failed(result))
     return emitDefaultSilenceableFailure(target);
@@ -1228,14 +1239,37 @@ DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
 // TileReductionUsingForeachThreadOp
 //===----------------------------------------------------------------------===//
 
+void transform::TileReductionUsingForeachThreadOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
+    ArrayAttr mapping) {
+  // Convenience builder: wraps the static thread counts and tile sizes in
+  // DenseI64ArrayAttrs and forwards them, plus `mapping` unchanged, to the
+  // default builder with four pdl::OperationType results. Delegating keeps
+  // this future-proof for mixed static-dynamic sizes; bypassing is bug-prone.
+  // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
+  MLIRContext *ctx = builder.getContext();
+  auto opTy = pdl::OperationType::get(ctx);
+  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
+  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+  build(builder, result,
+        /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
+        /*target=*/target,
+        /*num_threads=*/staticNumThreadsAttr,
+        /*tile_sizes=*/staticTileSizesAttr,
+        /*mapping=*/mapping);
+}
+
 DiagnosedSilenceableFailure
 transform::TileReductionUsingForeachThreadOp::applyToOne(
     linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
     transform::TransformState &state) {
   TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
-  SmallVector<OpFoldResult> numThreads = getAsOpFoldResult(getNumThreads());
-  SmallVector<OpFoldResult> tileSizes = getAsOpFoldResult(getTileSizes());
+  SmallVector<OpFoldResult> numThreads =
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
+  SmallVector<OpFoldResult> tileSizes =
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
   FailureOr<linalg::ForeachThreadReductionTilingResult> result =
       linalg::tileReductionUsingForeachThread(
           rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
index 13aec82..3709306 100644 (file)
@@ -17,7 +17,8 @@ func.func @reduction_tile(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [0, 5] }
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0
+    by tile_sizes = [0, 5]
 }
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -71,7 +72,8 @@ func.func @reduction_tile_transpose(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>)
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [5, 0] }
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 
+    by tile_sizes = [5, 0]
 }
 
 //     CHECK: func @reduction_tile_transpose
@@ -107,7 +109,8 @@ func.func @reduction_tile_parallel(
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5] }
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0
+    by num_threads = [0, 5], tile_sizes = []
 }
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
@@ -159,7 +162,8 @@ func.func @matmul_tile_parallel(
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 0, 5] }
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0
+    by num_threads = [0, 0, 5], tile_sizes = []
 }
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
@@ -219,7 +223,7 @@ transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
   %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 
-    { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] }
+    by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>]
 }
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)>
@@ -285,7 +289,7 @@ transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
   %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 
-    { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] }
+    by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>]
   
   //      CHECK:     expecting fill
   // CHECK-NEXT:     linalg.fill