[mlir] add types to the transform dialect
authorAlex Zinenko <zinenko@google.com>
Tue, 4 Oct 2022 11:49:21 +0000 (11:49 +0000)
committerAlex Zinenko <zinenko@google.com>
Tue, 11 Oct 2022 09:55:07 +0000 (09:55 +0000)
Introduce a type system for the transform dialect. A transform IR type
captures the expectations of the transform IR on the payload IR
operations that are being transformed, such as being of a certain kind
or implementing an interface that enables the transformation. This
provides stricter checking and better readability of the transform IR
than using the catch-all "handle" type.

This change implements the basic support for a type system amendable to
dialect extensions and adds a drop-in replacement for the unrestricted
"handle" type. The actual switch of transform dialect ops to that type
will happen in a separate commit.

See https://discourse.llvm.org/t/rfc-type-system-for-the-transform-dialect/65702

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D135164

24 files changed:
mlir/docs/Dialects/Transform.md
mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformTypes.h [new file with mode: 0644]
mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td [new file with mode: 0644]
mlir/lib/Dialect/Transform/IR/CMakeLists.txt
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformTypes.cpp [new file with mode: 0644]
mlir/test/Dialect/Transform/ops.mlir
mlir/test/Dialect/Transform/test-dialect-injection.mlir
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/CMakeLists.txt
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

index b5f5eec..070bfb6 100644 (file)
@@ -2,6 +2,8 @@
 
 [TOC]
 
+[include "Dialects/TransformTypes.md"]
+
 [include "Dialects/TransformOps.md"]
 
 ## Bufferization Transform Operations
@@ -16,4 +18,6 @@
 
 [include "Dialects/LinalgStructuredTransformOps.md"]
 
+[include "Dialects/TransformTypeInterfaces.md"]
+
 [include "Dialects/TransformOpInterfaces.md"]
index 7039b23..4fea396 100644 (file)
@@ -8,6 +8,13 @@ mlir_tablegen(TransformDialect.cpp.inc -gen-dialect-defs -dialect=transform)
 add_public_tablegen_target(MLIRTransformDialectIncGen)
 add_dependencies(mlir-headers MLIRTransformDialectIncGen)
 
+set(LLVM_TARGET_DEFINITIONS TransformTypes.td)
+mlir_tablegen(TransformTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(TransformTypes.cpp.inc -gen-typedef-defs)
+add_public_tablegen_target(MLIRTransformTypesIncGen)
+add_dependencies(mlir-headers MLIRTransformTypesIncGen)
+add_mlir_doc(TransformTypes TransformTypes Dialects/ -gen-typedef-docs)
+
 set(LLVM_TARGET_DEFINITIONS TransformAttrs.td)
 mlir_tablegen(TransformDialectEnums.h.inc -gen-enum-decls)
 mlir_tablegen(TransformDialectEnums.cpp.inc -gen-enum-defs)
@@ -17,5 +24,13 @@ add_dependencies(mlir-headers MLIRTransformDialectEnumIncGen)
 add_mlir_dialect(TransformOps transform)
 add_mlir_doc(TransformOps TransformOps Dialects/ -gen-dialect-doc -dialect=transform)
 
+# Contrary to what the name claims, this only produces the _op_ interface.
 add_mlir_interface(TransformInterfaces)
 add_mlir_doc(TransformInterfaces TransformOpInterfaces Dialects/ -gen-op-interface-docs)
+
+set(LLVM_TARGET_DEFINITIONS TransformInterfaces.td)
+mlir_tablegen(TransformTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(TransformTypeInterfaces.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIRTransformDialectTypeInterfacesIncGen)
+add_dependencies(mlir-headers MLIRTransformDialectTypeInterfacesIncGen)
+add_mlir_doc(TransformInterfaces TransformTypeInterfaces Dialects/ -gen-type-interface-docs)
index aea8f4b..055cf71 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringMap.h"
 
 namespace mlir {
@@ -37,6 +38,11 @@ static inline void checkImplementsTransformInterface(MLIRContext *context) {
          "ops injected into the transform dialect must implement "
          "MemoryEffectsOpInterface");
 }
+
+/// Asserts that the type provided as template argument implements the
+/// TransformTypeInterface. This must be a dynamic assertion since interface
+/// implementations may be registered at runtime.
+void checkImplementsTransformTypeInterface(TypeID typeID, MLIRContext *context);
 } // namespace detail
 #endif // NDEBUG
 } // namespace transform
@@ -120,6 +126,18 @@ protected:
     });
   }
 
+  /// Injects the types into the Transform dialect. The types must implement
+  /// the TransformTypeInterface and the implementation must be already
+  /// available when the type is injected. Furthermore, the types must provide
+  /// a `getMnemonic` static method returning an object convertible to
+  /// `StringRef` that is unique across all injected types.
+  template <typename... TypeTys>
+  void registerTypes() {
+    opInitializers.push_back([](TransformDialect *transformDialect) {
+      transformDialect->addTypesChecked<TypeTys...>();
+    });
+  }
+
   /// Declares that this Transform dialect extension depends on the dialect
   /// provided as template parameter. When the Transform dialect is loaded,
   /// dependent dialects will be loaded as well. This is intended for dialects
@@ -182,6 +200,25 @@ private:
   bool buildOnly;
 };
 
+template <typename Type>
+void TransformDialect::addTypeIfNotRegistered() {
+  // Use the address of the parse method as a proxy for identifying whether we
+  // are registering the same type class for the same mnemonic.
+  StringRef mnemonic = Type::getMnemonic();
+  auto [it, inserted] = typeParsingHooks.try_emplace(mnemonic, Type::parse);
+  if (!inserted) {
+    const ExtensionTypeParsingHook &parsingHook = it->getValue();
+    if (*parsingHook.target<mlir::Type (*)(AsmParser &)>() != &Type::parse)
+      reportDuplicateTypeRegistration(mnemonic);
+  }
+  typePrintingHooks.try_emplace(
+      TypeID::get<Type>(), +[](mlir::Type type, AsmPrinter &printer) {
+        printer << Type::getMnemonic();
+        cast<Type>(type).print(printer);
+      });
+  addTypes<Type>();
+}
+
 /// A wrapper for transform dialect extensions that forces them to be
 /// constructed in the build-only mode.
 template <typename DerivedTy>
index a4f7276..dd2b55f 100644 (file)
@@ -315,6 +315,23 @@ def Transform_Dialect : Dialect {
       const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
       getPDLConstraintHooks() const;
 
+      /// Parses a type registered by this dialect or one of its extensions.
+      ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
+
+      /// Prints a type registered by this dialect or one of its extensions.
+      void printType(::mlir::Type type,
+                     ::mlir::DialectAsmPrinter &printer) const override;
+
+      /// Parser callback for an individual type registered by this dialect or
+      /// its extensions.
+      using ExtensionTypeParsingHook =
+          std::function<::mlir::Type (::mlir::AsmParser &)>;
+
+      /// Printer callback for an individual type registered by this dialect or
+      /// its extensions.
+      using ExtensionTypePrintingHook =
+          std::function<void (::mlir::Type, ::mlir::AsmPrinter &)>;
+
     private:
       template <typename OpTy>
       void addOperationIfNotRegistered() {
@@ -344,6 +361,28 @@ def Transform_Dialect : Dialect {
         #endif // NDEBUG
       }
 
+      /// Registers the types specified as template parameters with the
+      /// Transform dialect. Checks that they meet the requirements for
+      /// Transform IR types.
+      template <typename... TypeTys>
+      void addTypesChecked() {
+        (addTypeIfNotRegistered<TypeTys>(), ...);
+
+        #ifndef NDEBUG
+        (detail::checkImplementsTransformTypeInterface(
+            TypeID::get<TypeTys>(), getContext()), ...);
+        #endif // NDEBUG
+      }
+
+      /// Implementation of the type registration for a single type, should
+      /// not be called directly, use addTypesChecked instead.
+      template <typename Type>
+      void addTypeIfNotRegistered();
+
+      /// Reports a repeated registration error of a type with the given
+      /// mnemonic.
+      [[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);
+
       template <typename, typename...>
       friend class TransformDialectExtension;
 
@@ -352,9 +391,23 @@ def Transform_Dialect : Dialect {
       void mergeInPDLMatchHooks(
           ::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns);
 
+      //===----------------------------------------------------------------===//
+      // Data fields
+      //===----------------------------------------------------------------===//
+
       /// A container for PDL constraint function that can be used by
       /// operations in this dialect.
-      PDLPatternModule pdlMatchHooks;
+      ::mlir::PDLPatternModule pdlMatchHooks;
+
+      /// A map from type mnemonic to its parsing function for the remainder of
+      /// the syntax. The parser has access to the mnemonic, so it is used for
+      /// further dispatch.
+      ::llvm::StringMap<ExtensionTypeParsingHook> typeParsingHooks;
+
+      /// A map from type TypeID to its printing function. No need to do string
+      /// lookups when the type is fully constructed.
+      ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
+      typePrintingHooks;
   }];
 }
 
index abd10fe..25f61d6 100644 (file)
@@ -12,6 +12,7 @@
 #include "mlir/IR/OpDefinition.h"
 
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ScopeExit.h"
 
 namespace mlir {
@@ -279,13 +280,16 @@ public:
   /// list of operations in the payload IR. The arguments must be defined in
   /// blocks of the currently processed transform IR region, typically after a
   /// region scope is defined.
-  void mapBlockArguments(BlockArgument argument,
-                         ArrayRef<Operation *> operations) {
+  ///
+  /// Returns failure if the payload does not satisfy the conditions associated
+  /// with the type of the handle value.
+  LogicalResult mapBlockArguments(BlockArgument argument,
+                                  ArrayRef<Operation *> operations) {
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
     assert(argument.getParentRegion() == regionStack.back() &&
            "mapping block arguments from a region other than the active one");
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-    setPayloadOps(argument, operations);
+    return setPayloadOps(argument, operations);
   }
 
   // Forward declarations to support limited visibility.
@@ -478,7 +482,10 @@ private:
   /// is invalid given the transformation "consumes" the handle as expressed
   /// by side effects. Practically, a transformation consuming a handle means
   /// that the associated payload operation may no longer exist.
-  void setPayloadOps(Value value, ArrayRef<Operation *> targets);
+  ///
+  /// Returns failure if the payload does not satisfy the conditions associated
+  /// with the type of the handle value.
+  LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
 
   /// Forgets the payload IR ops associated with the given transform IR value.
   void removePayloadOps(Value value);
@@ -488,8 +495,12 @@ private:
   /// expected to return the modified operation or nullptr. In the latter case,
   /// the corresponding operation is no longer associated with the transform IR
   /// value.
-  void updatePayloadOps(Value value,
-                        function_ref<Operation *(Operation *)> callback);
+  ///
+  /// Returns failure if the payload does not satisfy the conditions associated
+  /// with the type of the handle value.
+  LogicalResult
+  updatePayloadOps(Value value,
+                   function_ref<Operation *(Operation *)> callback);
 
   /// If the operand is a handle consumed by the operation, i.e. has the "free"
   /// memory effect associated with it, identifies other handles that are
@@ -574,9 +585,9 @@ namespace detail {
 /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
 /// to either the list of operations associated with its operand or the root of
 /// the payload IR, depending on what is available in the context.
-void mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
-                                                  Operation *op,
-                                                  Region &region);
+LogicalResult
+mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
+                                             Operation *op, Region &region);
 
 /// Verification hook for PossibleTopLevelTransformOpTrait.
 LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
@@ -613,17 +624,17 @@ public:
   /// Sets up the mapping between the entry block of the given region of this op
   /// and the relevant list of Payload IR operations in the given state. The
   /// state is expected to be already scoped at the region of this operation.
-  void mapBlockArguments(TransformState &state, Region &region) {
+  LogicalResult mapBlockArguments(TransformState &state, Region &region) {
     assert(region.getParentOp() == this->getOperation() &&
            "op comes from the wrong region");
-    detail::mapPossibleTopLevelTransformOpBlockArguments(
+    return detail::mapPossibleTopLevelTransformOpBlockArguments(
         state, this->getOperation(), region);
   }
-  void mapBlockArguments(TransformState &state) {
+  LogicalResult mapBlockArguments(TransformState &state) {
     assert(
         this->getOperation()->getNumRegions() == 1 &&
         "must indicate the region to map if the operation has more than one");
-    mapBlockArguments(state, this->getOperation()->getRegion(0));
+    return mapBlockArguments(state, this->getOperation()->getRegion(0));
   }
 };
 
index 9d55e78..e2ebf1d 100644 (file)
@@ -81,6 +81,31 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
   }];
 }
 
+def TransformTypeInterface : TypeInterface<"TransformTypeInterface"> {
+  let description = [{
+    Types that can be used for Transform dialect handle values. Such types
+    define the properties of Payload IR operations associated with the handle.
+    A user of such a handle can assume that these properties have been verified
+    for any Payload IR operation associated with it.
+  }];
+
+  let cppNamespace = "::mlir::transform";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Checks if the given list of associated Payload IR operations satisfy
+        the conditions defined by this type. If not, produces a silenceable
+        error at the specified location.
+      }],
+      /*returnType=*/"::mlir::DiagnosedSilenceableFailure",
+      /*name=*/"checkPayload",
+      /*arguments=*/(ins "::mlir::Location":$loc,
+                         "::mlir::ArrayRef<::mlir::Operation *>":$payload)
+    >
+  ];
+}
+
 def FunctionalStyleTransformOpTrait
     : NativeOpTrait<"FunctionalStyleTransformOpTrait"> {
   let cppNamespace = "::mlir::transform";
index 29fa720..b9f4eed 100644 (file)
 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
 
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
-#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 
 namespace mlir {
index 817c1c6..3fa6692 100644 (file)
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
 
+include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
@@ -96,6 +97,23 @@ def AlternativesOp : TransformDialectOp<"alternatives",
   let hasVerifier = 1;
 }
 
+def CastOp : TransformDialectOp<"cast",
+    [TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<CastOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  // TODO: temporarily fallback support for casting from PDL_Operation type.
+  let arguments = (ins AnyType:$input);
+  let results = (outs AnyType:$output);
+  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+      ::mlir::Operation *target,
+      ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
 def ForeachOp : TransformDialectOp<"foreach",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.h b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.h
new file mode 100644 (file)
index 0000000..a8b6ee0
--- /dev/null
@@ -0,0 +1,26 @@
+//===- TransformTypes.h - Transform dialect types ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES_H
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES_H
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class DiagnosedSilenceableFailure;
+class Operation;
+class Type;
+} // namespace mlir
+
+#include "mlir/Dialect/Transform/IR/TransformTypeInterfaces.h.inc"
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Transform/IR/TransformTypes.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES_H
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
new file mode 100644 (file)
index 0000000..6ece316
--- /dev/null
@@ -0,0 +1,26 @@
+//===- TransformTypes.td - Transform dialect types ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+
+def Transform_AnyOpType : TypeDef<Transform_Dialect, "AnyOp",
+    [DeclareTypeInterfaceMethods<TransformTypeInterface>]> {
+  let description = [{
+    Transform IR handle that can be associated with a list of arbitrary
+    Payload IR operations.
+  }];
+  let mnemonic = "any_op";
+  let assemblyFormat = "";
+}
+
+#endif  // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES
index 3041d20..fe4c827 100644 (file)
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTransformDialect
   TransformDialect.cpp
   TransformInterfaces.cpp
   TransformOps.cpp
+  TransformTypes.cpp
 
   DEPENDS
   MLIRTransformDialectIncGen
@@ -9,6 +10,7 @@ add_mlir_dialect_library(MLIRTransformDialect
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRParser
   MLIRPDLDialect
   MLIRPDLInterpDialect
   MLIRRewrite
index 42234ec..adf9f0f 100644 (file)
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/DialectImplementation.h"
 
 using namespace mlir;
 
 #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
 
+#ifndef NDEBUG
+void transform::detail::checkImplementsTransformTypeInterface(
+    TypeID typeID, MLIRContext *context) {
+  const auto &abstractType = AbstractType::lookup(typeID, context);
+  assert(abstractType.hasInterface(TransformTypeInterface::getInterfaceID()));
+}
+#endif // NDEBUG
+
 void transform::TransformDialect::initialize() {
-  // Using the checked version to enable the same assertions as for the ops from
-  // extensions.
+  // Using the checked versions to enable the same assertions as for the ops
+  // from extensions.
   addOperationsChecked<
 #define GET_OP_LIST
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
       >();
+  addTypesChecked<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
+      >();
 }
 
 void transform::TransformDialect::mergeInPDLMatchHooks(
@@ -36,4 +50,36 @@ transform::TransformDialect::getPDLConstraintHooks() const {
   return pdlMatchHooks.getConstraintFunctions();
 }
 
+Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
+  StringRef keyword;
+  SMLoc loc = parser.getCurrentLocation();
+  if (failed(parser.parseKeyword(&keyword)))
+    return nullptr;
+
+  auto it = typeParsingHooks.find(keyword);
+  if (it == typeParsingHooks.end()) {
+    parser.emitError(loc) << "unknown type mnemonic: " << keyword;
+    return nullptr;
+  }
+
+  return it->getValue()(parser);
+}
+
+void transform::TransformDialect::printType(Type type,
+                                            DialectAsmPrinter &printer) const {
+  auto it = typePrintingHooks.find(type.getTypeID());
+  assert(it != typePrintingHooks.end() && "printing unknown type");
+  it->getSecond()(type, printer);
+}
+
+void transform::TransformDialect::reportDuplicateTypeRegistration(
+    StringRef mnemonic) {
+  std::string buffer;
+  llvm::raw_string_ostream msg(buffer);
+  msg << "error: extensible dialect type '" << mnemonic
+      << "' is already registered with a different implementation";
+  msg.flush();
+  llvm::report_fatal_error(StringRef(buffer));
+}
+
 #include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc"
index 176e93a..3de6206 100644 (file)
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/ADT/STLExtras.h"
@@ -60,14 +61,22 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
   return success(found);
 }
 
-void transform::TransformState::setPayloadOps(Value value,
-                                              ArrayRef<Operation *> targets) {
+LogicalResult
+transform::TransformState::setPayloadOps(Value value,
+                                         ArrayRef<Operation *> targets) {
   assert(value != kTopLevelValue &&
          "attempting to reset the transformation root");
 
   // TODO: this may go now
   if (value.use_empty())
-    return;
+    return success();
+
+  if (auto iface = value.getType().dyn_cast<TransformTypeInterface>()) {
+    DiagnosedSilenceableFailure result =
+        iface.checkPayload(value.getLoc(), targets);
+    if (failed(result.checkAndReport()))
+      return failure();
+  }
 
   // Setting new payload for the value without cleaning it first is a misuse of
   // the API, assert here.
@@ -80,6 +89,8 @@ void transform::TransformState::setPayloadOps(Value value,
 
   for (Operation *op : targets)
     mappings.reverse[op].push_back(value);
+
+  return success();
 }
 
 void transform::TransformState::dropReverseMapping(Mappings &mappings,
@@ -100,7 +111,7 @@ void transform::TransformState::removePayloadOps(Value value) {
   mappings.direct.erase(value);
 }
 
-void transform::TransformState::updatePayloadOps(
+LogicalResult transform::TransformState::updatePayloadOps(
     Value value, function_ref<Operation *(Operation *)> callback) {
   Mappings &mappings = getMapping(value);
   auto it = mappings.direct.find(value);
@@ -117,7 +128,15 @@ void transform::TransformState::updatePayloadOps(
     }
   }
 
+  if (auto iface = value.getType().dyn_cast<TransformTypeInterface>()) {
+    DiagnosedSilenceableFailure result =
+        iface.checkPayload(value.getLoc(), updated);
+    if (failed(result.checkAndReport()))
+      return failure();
+  }
+
   std::swap(association, updated);
+  return success();
 }
 
 void transform::TransformState::recordHandleInvalidationOne(
@@ -253,7 +272,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     assert(result.getDefiningOp() == transform.getOperation() &&
            "payload IR association for a value other than the result of the "
            "current transform op");
-    setPayloadOps(result, results.get(result.getResultNumber()));
+    if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
+      return DiagnosedSilenceableFailure::definiteFailure();
   }
 
   printOnFailureRAII.release();
@@ -278,9 +298,12 @@ transform::TransformState::Extension::replacePayloadOp(Operation *op,
     return failure();
 
   for (Value handle : handles) {
-    state.updatePayloadOps(handle, [&](Operation *current) {
-      return current == op ? replacement : current;
-    });
+    LogicalResult result =
+        state.updatePayloadOps(handle, [&](Operation *current) {
+          return current == op ? replacement : current;
+        });
+    if (failed(result))
+      return failure();
   }
   return success();
 }
@@ -317,7 +340,7 @@ transform::TransformResults::get(unsigned resultNumber) const {
 // Utilities for PossibleTopLevelTransformOpTrait.
 //===----------------------------------------------------------------------===//
 
-void transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
+LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
     TransformState &state, Operation *op, Region &region) {
   SmallVector<Operation *> targets;
   if (op->getNumOperands() != 0)
@@ -325,7 +348,7 @@ void transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
   else
     targets.push_back(state.getTopLevel());
 
-  state.mapBlockArguments(region.front().getArgument(0), targets);
+  return state.mapBlockArguments(region.front().getArgument(0), targets);
 }
 
 LogicalResult
index 126500f..f888a85 100644 (file)
@@ -15,6 +15,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 
@@ -226,7 +227,8 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
       for (Operation *clone : clones)
         clone->erase();
     });
-    state.mapBlockArguments(reg.front().getArgument(0), clones);
+    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
+      return DiagnosedSilenceableFailure::definiteFailure();
 
     bool failed = false;
     for (Operation &transform : reg.front().without_terminator()) {
@@ -292,6 +294,35 @@ LogicalResult transform::AlternativesOp::verify() {
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
+transform::CastOp::applyToOne(Operation *target,
+                              SmallVectorImpl<Operation *> &results,
+                              transform::TransformState &state) {
+  results.push_back(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::CastOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsPayload(effects);
+  consumesHandle(getInput(), effects);
+  producesHandle(getOutput(), effects);
+}
+
+bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  assert(inputs.size() == 1 && "expected one input");
+  assert(outputs.size() == 1 && "expected one output");
+  return llvm::all_of(
+      std::initializer_list<Type>{inputs.front(), outputs.front()},
+      [](Type ty) {
+        return ty.isa<pdl::OperationType, transform::TransformTypeInterface>();
+      });
+}
+
+//===----------------------------------------------------------------------===//
+// ForeachOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
 transform::ForeachOp::apply(transform::TransformResults &results,
                             transform::TransformState &state) {
   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
@@ -299,7 +330,8 @@ transform::ForeachOp::apply(transform::TransformResults &results,
 
   for (Operation *op : payloadOps) {
     auto scope = state.make_region_scope(getBody());
-    state.mapBlockArguments(getIterationVariable(), {op});
+    if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
+      return DiagnosedSilenceableFailure::definiteFailure();
 
     // Execute loop body.
     for (Operation &transform : getBody().front().without_terminator()) {
@@ -572,7 +604,8 @@ transform::SequenceOp::apply(transform::TransformResults &results,
                              transform::TransformState &state) {
   // Map the entry block argument to the list of operations.
   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
-  mapBlockArguments(state);
+  if (failed(mapBlockArguments(state)))
+    return DiagnosedSilenceableFailure::definiteFailure();
 
   // Apply the sequenced ops one by one.
   for (Operation &transform : getBodyBlock()->without_terminator()) {
@@ -766,7 +799,8 @@ transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
 
   auto scope = state.make_region_scope(getBody());
-  mapBlockArguments(state);
+  if (failed(mapBlockArguments(state)))
+    return DiagnosedSilenceableFailure::definiteFailure();
   return state.applyTransform(transformOp);
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
new file mode 100644 (file)
index 0000000..ab72125
--- /dev/null
@@ -0,0 +1,37 @@
+//===- TransformTypes.cpp - Transform Dialect Type Definitions ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Compiler.h"
+
+using namespace mlir;
+
+#include "mlir/Dialect/Transform/IR/TransformTypeInterfaces.cpp.inc"
+
+// These are automatically generated by ODS but are not used as the Transform
+// dialect uses a different dispatch mechanism to support dialect extensions.
+LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
+generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
+LLVM_ATTRIBUTE_UNUSED static LogicalResult
+generatedTypePrinter(Type def, AsmPrinter &printer);
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
+
+DiagnosedSilenceableFailure
+transform::AnyOpType::checkPayload(Location loc,
+                                   ArrayRef<Operation *> payload) const {
+  return DiagnosedSilenceableFailure::success();
+}
index d905a3e..31e0d14 100644 (file)
@@ -58,3 +58,10 @@ transform.sequence failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
   }
 }
+
+// CHECK: transform.sequence
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  // CHECK: cast %{{.*}} : !pdl.operation to !transform.any_op
+  %0 = cast %arg0: !pdl.operation to !transform.any_op
+}
index 4474666..f1c2762 100644 (file)
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s | FileCheck %s
 
-// These ops are defined by a test extension but should be okay to roundtrip.
+// These types and ops are defined by a test extension but should be okay to
+// roundtrip.
 
 // CHECK: transform.test_transform_op
 transform.test_transform_op
@@ -10,3 +11,7 @@ transform.test_transform_op
 
 // CHECK: transform.test_consume_operand_if_matches_param_or_fail %{{.*}}[42]
 transform.test_consume_operand_if_matches_param_or_fail %0[42]
+
+// Ensure that the extension type is roundtripped correctly.
+// CHECK: transform.cast %{{.*}} : !pdl.operation to !transform.test_dialect_op
+%1 = transform.cast %0: !pdl.operation to !transform.test_dialect_op
index ddfd645..7f7b4c4 100644 (file)
@@ -798,5 +798,46 @@ transform.sequence failures(suppress) {
   // Silenceable failure and all handles are now empty.
   %h_2:3 = split_handles %muli_2 in [3]
   // expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %h_2#0
+  transform.test_print_number_of_associated_payload_ir_ops %h_2#0
+}
+
+// -----
+
+"test.some_op"() : () -> ()
+"other_dialect.other_op"() : () -> ()
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @some : benefit(1) {
+    %0 = pdl.operation "test.some_op"
+    pdl.rewrite %0 with "transform.dialect"
+  }
+
+  sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @some in %arg1
+    %2 = transform.cast %0 : !pdl.operation to !transform.test_dialect_op
+    transform.cast %2 : !transform.test_dialect_op to !pdl.operation
+  }
+}
+
+// -----
+
+"test.some_op"() : () -> ()
+"other_dialect.other_op"() : () -> ()
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @other : benefit(1) {
+    %0 = pdl.operation "other_dialect.other_op"
+    pdl.rewrite %0 with "transform.dialect"
+  }
+
+  sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @other in %arg1
+    // expected-error @below {{expected the payload operation to belong to the 'test' dialect}}
+    %2 = transform.cast %0 : !pdl.operation to !transform.test_dialect_op
+    transform.cast %2 : !transform.test_dialect_op to !pdl.operation
+  }
 }
index 6d74721..0119fb0 100644 (file)
@@ -1,6 +1,8 @@
 set(LLVM_TARGET_DEFINITIONS TestTransformDialectExtension.td)
 mlir_tablegen(TestTransformDialectExtension.h.inc -gen-op-decls)
 mlir_tablegen(TestTransformDialectExtension.cpp.inc -gen-op-defs)
+mlir_tablegen(TestTransformDialectExtensionTypes.h.inc -gen-typedef-decls -typedefs-dialect=transform)
+mlir_tablegen(TestTransformDialectExtensionTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=transform)
 add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen)
 
 add_mlir_library(MLIRTestTransformDialect
index 3994f0e..6a39a2e 100644 (file)
@@ -17,6 +17,8 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Compiler.h"
 
 using namespace mlir;
 
@@ -310,6 +312,22 @@ mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure::success();
 }
 
+DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
+    Location loc, ArrayRef<Operation *> payload) const {
+  if (payload.empty())
+    return DiagnosedSilenceableFailure::success();
+
+  for (Operation *op : payload) {
+    if (op->getName().getDialectNamespace() != "test") {
+      Diagnostic diag(loc, DiagnosticSeverity::Error);
+      diag << "expected the payload operation to belong to the 'test' dialect";
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    }
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
@@ -327,6 +345,10 @@ public:
 #define GET_OP_LIST
 #include "TestTransformDialectExtension.cpp.inc"
                          >();
+    registerTypes<
+#define GET_TYPEDEF_LIST
+#include "TestTransformDialectExtensionTypes.cpp.inc"
+        >();
   }
 };
 } // namespace
@@ -334,6 +356,16 @@ public:
 #define GET_OP_CLASSES
 #include "TestTransformDialectExtension.cpp.inc"
 
+// These are automatically generated by ODS but are not used as the Transform
+// dialect uses a different dispatch mechanism to support dialect extensions.
+LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
+generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
+LLVM_ATTRIBUTE_UNUSED static LogicalResult
+generatedTypePrinter(Type def, AsmPrinter &printer);
+
+#define GET_TYPEDEF_CLASSES
+#include "TestTransformDialectExtensionTypes.cpp.inc"
+
 void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
   registry.addExtensions<TestTransformDialectExtension>();
 }
index c38693f..e5785fd 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/OpImplementation.h"
 
 namespace mlir {
@@ -25,6 +26,9 @@ class DialectRegistry;
 #define GET_OP_CLASSES
 #include "TestTransformDialectExtension.h.inc"
 
+#define GET_TYPEDEF_CLASSES
+#include "TestTransformDialectExtensionTypes.h.inc"
+
 namespace test {
 /// Registers the test extension to the Transform dialect.
 void registerTestTransformDialectExtension(::mlir::DialectRegistry &registry);
index fda71bc..af296f3 100644 (file)
 #ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
 #define MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
 
+include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 
+def TestTransformTestDialectHandleType
+  : TypeDef<Transform_Dialect, "TestDialectOp",
+      [DeclareTypeInterfaceMethods<TransformTypeInterface>]> {
+  let description = [{Handle pointing to an op from the Test dialect.}];
+  let mnemonic = "test_dialect_op";
+  let assemblyFormat = "";
+}
+
 def TestProduceParamOrForwardOperandOp
   : Op<Transform_Dialect, "test_produce_param_or_forward_operand",
        [DeclareOpInterfaceMethods<TransformOpInterface>]> {
index 87c4061..fd21506 100644 (file)
@@ -8396,6 +8396,7 @@ td_library(
     name = "TransformDialectTdFiles",
     srcs = glob(["include/mlir/Dialect/Transform/IR/*.td"]),
     deps = [
+        ":CastInterfacesTdFiles",
         ":ControlFlowInterfacesTdFiles",
         ":OpBaseTdFiles",
         ":PDLDialectTdFiles",
@@ -8441,6 +8442,18 @@ gentbl_cc_library(
             ],
             "include/mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc",
         ),
+        (
+            [
+                "-gen-type-interface-decls",
+            ],
+            "include/mlir/Dialect/Transform/IR/TransformTypeInterfaces.h.inc",
+        ),
+        (
+            [
+                "-gen-type-interface-defs",
+            ],
+            "include/mlir/Dialect/Transform/IR/TransformTypeInterfaces.cpp.inc",
+        ),
     ],
     tblgen = ":mlir-tblgen",
     td_file = "include/mlir/Dialect/Transform/IR/TransformInterfaces.td",
@@ -8487,6 +8500,24 @@ gentbl_cc_library(
     deps = [":TransformDialectTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "TransformTypesIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-typedef-decls"],
+            "include/mlir/Dialect/Transform/IR/TransformTypes.h.inc",
+        ),
+        (
+            ["-gen-typedef-defs"],
+            "include/mlir/Dialect/Transform/IR/TransformTypes.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Transform/IR/TransformTypes.td",
+    deps = [":TransformDialectTdFiles"],
+)
+
 cc_library(
     name = "TransformDialect",
     srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]),
@@ -8503,6 +8534,7 @@ cc_library(
         ":TransformDialectIncGen",
         ":TransformDialectInterfacesIncGen",
         ":TransformOpsIncGen",
+        ":TransformTypesIncGen",
         "//llvm:Support",
     ],
 )
index 841123e..02ae5f1 100644 (file)
@@ -266,6 +266,20 @@ gentbl_cc_library(
             ["-gen-op-defs"],
             "lib/Dialect/Transform/TestTransformDialectExtension.cpp.inc",
         ),
+        (
+            [
+                "-gen-typedef-decls",
+                "-typedefs-dialect=transform",
+            ],
+            "lib/Dialect/Transform/TestTransformDialectExtensionTypes.h.inc",
+        ),
+        (
+            [
+                "-gen-typedef-defs",
+                "-typedefs-dialect=transform",
+            ],
+            "lib/Dialect/Transform/TestTransformDialectExtensionTypes.cpp.inc",
+        ),
     ],
     tblgen = "//mlir:mlir-tblgen",
     td_file = "lib/Dialect/Transform/TestTransformDialectExtension.td",
@@ -284,6 +298,7 @@ cc_library(
     includes = ["lib/Dialect/Transform"],
     deps = [
         ":TestTransformDialectExtensionIncGen",
+        "//llvm:Support",
         "//mlir:IR",
         "//mlir:PDLDialect",
         "//mlir:Pass",