[mlir][transform] Add PackedOrDynamicIndexList helper
authorMatthias Springer <springerm@google.com>
Mon, 19 Dec 2022 06:57:46 +0000 (07:57 +0100)
committerMatthias Springer <springerm@google.com>
Mon, 19 Dec 2022 07:08:04 +0000 (08:08 +0100)
This customer parser/printer is similar to DynamicIndexList, but has special syntax for the case where one handle represents the entire list.

Example:
```
// Regular index list
[10, 20, %val]

// Packed handle (no square parentheses)
%val
```

Differential Revision: https://reviews.llvm.org/D138825

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Transform/Utils/Utils.h [new file with mode: 0644]
mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/CMakeLists.txt
mlir/lib/Dialect/Transform/Utils/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Dialect/Transform/Utils/Utils.cpp [new file with mode: 0644]
mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index cd330ec..b6a9bdf 100644 (file)
@@ -1054,6 +1054,8 @@ def TileToForeachThreadOp :
   let arguments = (ins PDL_Operation:$target,
                    Variadic<PDL_Operation>:$num_threads,
                    Variadic<PDL_Operation>:$tile_sizes,
+                   Optional<PDL_Operation>:$packed_num_threads,
+                   Optional<PDL_Operation>:$packed_tile_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
@@ -1085,10 +1087,12 @@ def TileToForeachThreadOp :
 
   let assemblyFormat = [{
     $target oilist(
-        `num_threads` custom<DynamicIndexList>($num_threads,
-                                               $static_num_threads) |
-         `tile_sizes` custom<DynamicIndexList>($tile_sizes,
-                                               $static_tile_sizes))
+        `num_threads` custom<PackedOrDynamicIndexList>($packed_num_threads,
+                                                       $num_threads,
+                                                       $static_num_threads) |
+         `tile_sizes` custom<PackedOrDynamicIndexList>($packed_tile_sizes,
+                                                       $tile_sizes,
+                                                       $static_tile_sizes))
     (`(` `mapping` `=` $mapping^ `)`)? attr-dict
   }];
   let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Transform/Utils/Utils.h b/mlir/include/mlir/Dialect/Transform/Utils/Utils.h
new file mode 100644 (file)
index 0000000..ced6f12
--- /dev/null
@@ -0,0 +1,52 @@
+//===- Utils.h - Transform dialect utilities --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORMS_UTILS_UTILS_H
+#define MLIR_DIALECT_TRANSFORMS_UTILS_UTILS_H
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+class OpAsmPrinter;
+
+namespace transform {
+class TransformState;
+
+/// Printer hook for custom directive in assemblyFormat.
+///
+///   custom<PackedOrDynamicIndexList>($packed, $values, $integers)
+///
+/// where `values` are variadic Index values, `integers` is an `I64ArrayAttr`
+/// and `packed` is a single transform dialect handle who's mapped payload ops
+/// have a single Index result and represent the index list. Either `packed`
+/// or the other two parameters may be specified.
+///
+/// This allows idiomatic printing of mixed value and integer attributes in a
+/// list or with a single handle. E.g., `[%arg0, 7, 42, %arg42]` or just `%h`.
+void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+                                   Value packed, OperandRange values,
+                                   ArrayRef<int64_t> integers);
+
+/// Pasrer hook for custom directive in assemblyFormat.
+///
+///   custom<PackedOrDynamicIndexList>($packed, $values, $integers)
+///
+/// See `printPackedOrDynamicIndexList` for details.
+ParseResult parsePackedOrDynamicIndexList(
+    OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers);
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORMS_UTILS_UTILS_H
index fd15ac8..7e72ebf 100644 (file)
@@ -18,5 +18,6 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
   MLIRSCFDialect
   MLIRSideEffectInterfaces
   MLIRTransformDialect
+  MLIRTransformDialectUtils
   MLIRVectorDialect
   )
index dd0f2e5..5f1a4ec 100644 (file)
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
+#include "mlir/Dialect/Transform/Utils/Utils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
@@ -1523,6 +1524,8 @@ void transform::TileToForeachThreadOp::build(
         /*target=*/target,
         /*num_threads=*/ValueRange{},
         /*tile_sizes=*/dynamicTileSizes,
+        /*packed_num_threads=*/Value(),
+        /*packed_tile_sizes=*/Value(),
         /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
         /*static_tile_sizes=*/staticTileSizesAttr,
         /*mapping=*/mapping);
@@ -1558,38 +1561,70 @@ void transform::TileToForeachThreadOp::build(
         /*target=*/target,
         /*num_threads=*/dynamicNumThreads,
         /*tile_sizes=*/ValueRange{},
+        /*packed_num_threads=*/Value(),
+        /*packed_tile_sizes=*/Value(),
         /*static_num_threads=*/staticNumThreadsAttr,
         /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
         /*mapping=*/mapping);
 }
 
-// Given a list of OpFoldResults that are either index attrs or op
-// handles, return a list of OpFoldResults where all op handles are
-// replaced with the first (and only) OpResult of that payload op. (There
-// must be exactly one mapped payload op and it must have exactly one
-// index result.)
+/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
+/// to exactly one op with one index result, return that value.
 static DiagnosedSilenceableFailure unpackPDLOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
   for (OpFoldResult ofr : ofrs) {
-    // Don't try to unpack non-PDL operation.
-    if (ofr.is<Attribute>() ||
-        !ofr.get<Value>().getType().isa<pdl::OperationType>()) {
+    if (ofr.is<Attribute>()) {
+      if (!ofr.get<Attribute>().isa<IntegerAttr>())
+        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
       result.push_back(ofr);
       continue;
     }
     ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>());
-    for (Operation *op : payloadOps) {
-      if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
-        DiagnosedSilenceableFailure diag =
-            transformOp.emitSilenceableError()
-            << "payload op must have exactly 1 index result";
-        diag.attachNote(op->getLoc())
-            << "has " << op->getNumResults() << " results";
-        return diag;
-      }
-      result.push_back(op->getResult(0));
+    if (payloadOps.size() != 1) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "handle must be mapped to exactly one payload op";
+      diag.attachNote(ofr.get<Value>().getLoc())
+          << "mapped to " << payloadOps.size() << " payload ops";
+      return diag;
     }
+
+    Operation *op = payloadOps[0];
+    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "payload op must have exactly 1 index result";
+      diag.attachNote(op->getLoc())
+          << "has " << op->getNumResults() << " results";
+      return diag;
+    }
+    result.push_back(op->getResult(0));
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+// Given a list of OpFoldResults that are either index attrs or op
+// handles, return a list of OpFoldResults where all op handles are
+// replaced with the first (and only) OpResult of that payload op. (There
+// must be exactly one mapped payload op and it must have exactly one
+// index result.)
+static DiagnosedSilenceableFailure
+unpackPDLOperations(transform::TransformState &state,
+                    TransformOpInterface transformOp,
+                    SmallVector<OpFoldResult> &result, Value packedHandle) {
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(packedHandle);
+  for (Operation *op : payloadOps) {
+    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "payload op must have exactly 1 index result";
+      diag.attachNote(op->getLoc())
+          << "has " << op->getNumResults() << " results";
+      return diag;
+    }
+    result.push_back(op->getResult(0));
   }
 
   return DiagnosedSilenceableFailure::success();
@@ -1604,21 +1639,6 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
   if (targets.empty())
     return DiagnosedSilenceableFailure::success();
 
-  // getMixedNumThreads are OpFoldResults[index attributes or PDL operation].
-  // Convert to OpFoldResults[index attributes or payload op].
-  SmallVector<OpFoldResult> numThreads;
-  DiagnosedSilenceableFailure status =
-      unpackPDLOperations(state, transformOp, numThreads, mixedNumThreads);
-  if (!status.succeeded())
-    return status;
-
-  // getMixedTileSizes are OpFoldResults[index attributes or PDL operation].
-  // Convert to OpFoldResults[index attributes or payload op].
-  SmallVector<OpFoldResult> tileSizes;
-  status = unpackPDLOperations(state, transformOp, tileSizes, mixedTileSizes);
-  if (!status.succeeded())
-    return status;
-
   // Transform all targets one by one.
   for (Operation *target : targets) {
     auto tilableOp = dyn_cast<TilingInterface>(target);
@@ -1633,10 +1653,10 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
     FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
     if (!mixedNumThreads.empty()) {
       tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
-                                                   numThreads, mapping);
+                                                   mixedNumThreads, mapping);
     } else {
       tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
-          rewriter, tilableOp, tileSizes, mapping);
+          rewriter, tilableOp, mixedTileSizes, mapping);
     }
 
     if (failed(tilingResult))
@@ -1653,16 +1673,35 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
     transform::TransformResults &transformResults,
     transform::TransformState &state) {
   IRRewriter rewriter(getContext());
+  auto transformOp = cast<TransformOpInterface>(getOperation());
   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
 
   // Result payload ops.
   SmallVector<Operation *> tileOps;
   SmallVector<Operation *> tiledOps;
 
+  // Unpack handles.
+  SmallVector<OpFoldResult> mixedNumThreads;
+  DiagnosedSilenceableFailure status =
+      getPackedNumThreads()
+          ? unpackPDLOperations(state, transformOp, mixedNumThreads,
+                                getPackedNumThreads())
+          : unpackPDLOperations(state, transformOp, mixedNumThreads,
+                                getMixedNumThreads());
+  if (!status.succeeded())
+    return status;
+  SmallVector<OpFoldResult> mixedTileSizes;
+  status = getPackedTileSizes()
+               ? unpackPDLOperations(state, transformOp, mixedTileSizes,
+                                     getPackedTileSizes())
+               : unpackPDLOperations(state, transformOp, mixedTileSizes,
+                                     getMixedTileSizes());
+  if (!status.succeeded())
+    return status;
+
   DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl(
-      rewriter, state, cast<TransformOpInterface>(getOperation()), targets,
-      getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps,
-      tiledOps);
+      rewriter, state, transformOp, targets, mixedNumThreads, mixedTileSizes,
+      getMapping(), tileOps, tiledOps);
 
   if (!diag.succeeded()) {
     transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
@@ -1695,8 +1734,19 @@ SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() {
 }
 
 LogicalResult TileToForeachThreadOp::verify() {
-  if (getMixedNumThreads().empty() == getMixedTileSizes().empty())
-    return emitOpError("either num_threads or tile_sizes must be specified");
+  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
+                       static_cast<int>(getPackedNumThreads() != Value());
+  if (numThreadsSpec > 1)
+    return emitOpError(
+        "num_threads and packed_num_threads are mutually exclusive");
+  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
+                      static_cast<int>(getPackedTileSizes() != Value());
+  if (tileSizesSpec > 1)
+    return emitOpError(
+        "tile_sizes and packed_tile_sizes are mutually exclusive");
+  if (numThreadsSpec == 0 && tileSizesSpec == 0)
+    return emitOpError(
+        "either (packed_)num_threads or (packed_)tile_sizes must be specified");
   return success();
 }
 
index 9f57627..31167e6 100644 (file)
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Transform/Utils/CMakeLists.txt b/mlir/lib/Dialect/Transform/Utils/CMakeLists.txt
new file mode 100644 (file)
index 0000000..eadcbab
--- /dev/null
@@ -0,0 +1,10 @@
+add_mlir_dialect_library(MLIRTransformDialectUtils
+  Utils.cpp
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Transform
+  LINK_LIBS PUBLIC
+  MLIRDialectUtils
+  MLIRIR
+  MLIRSupport
+  MLIRTransformDialect
+)
diff --git a/mlir/lib/Dialect/Transform/Utils/Utils.cpp b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
new file mode 100644 (file)
index 0000000..6740790
--- /dev/null
@@ -0,0 +1,44 @@
+//===- Utils.cpp - Transform dialect utilities ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/Utils/Utils.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+void transform::printPackedOrDynamicIndexList(OpAsmPrinter &printer,
+                                              Operation *op, Value packed,
+                                              OperandRange values,
+                                              ArrayRef<int64_t> integers) {
+  if (packed) {
+    assert(values.empty() && integers.empty() && "expected no values/integers");
+    printer << packed;
+    return;
+  }
+  printDynamicIndexList(printer, op, values, integers);
+}
+
+ParseResult transform::parsePackedOrDynamicIndexList(
+    OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers) {
+  OpAsmParser::UnresolvedOperand packedOperand;
+  if (parser.parseOptionalOperand(packedOperand).has_value()) {
+    packed.emplace(std::move(packedOperand));
+    integers = parser.getBuilder().getDenseI64ArrayAttr({});
+    return success();
+  }
+  return parseDynamicIndexList(parser, values, integers);
+}
index 4dda150..53a7828 100644 (file)
@@ -78,7 +78,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
   %sz = transform.structured.match ops{["test.dummy"]} in %arg1
-  %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz]
+  %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes %sz
 }
 
 // -----
index 3a35393..e3c8c93 100644 (file)
@@ -8219,6 +8219,7 @@ cc_library(
         ":SideEffectInterfaces",
         ":TilingInterface",
         ":TransformDialect",
+        ":TransformDialectUtils",
         ":TransformUtils",
         "//llvm:Support",
     ],
@@ -9096,6 +9097,21 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "TransformDialectUtils",
+    srcs = ["lib/Dialect/Transform/Utils/Utils.cpp"],
+    hdrs = ["include/mlir/Dialect/Transform/Utils/Utils.h"],
+    includes = ["include"],
+    deps = [
+        ":DialectUtils",
+        ":IR",
+        ":Support",
+        ":TransformDialect",
+        ":ViewLikeInterface",
+        "//llvm:Support",
+    ],
+)
+
 td_library(
     name = "ComplexOpsTdFiles",
     srcs = [