[mlir:PDL] Add support for DialectConversion with pattern configurations
authorRiver Riddle <riddleriver@gmail.com>
Thu, 8 Sep 2022 23:59:39 +0000 (16:59 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 8 Nov 2022 09:57:57 +0000 (01:57 -0800)
Up until now PDL(L) has not supported dialect conversion because we had no
way of remapping values or integrating with type conversions. This commit
rectifies that by adding a new "pattern configuration" concept to PDL. This
essentially allows for attaching external configurations to patterns, which
can hook into pattern events (for now just the scope of a rewrite, but we
could also pass configs to native rewrites as well). This allows for injecting
the type converter into the conversion pattern rewriter.

Differential Revision: https://reviews.llvm.org/D133142

19 files changed:
mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/include/mlir/Transforms/DialectConversion.pdll [new file with mode: 0644]
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/lib/Rewrite/ByteCode.h
mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
mlir/lib/Rewrite/PatternApplicator.cpp
mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Transforms/test-dialect-conversion-pdll.mlir [new file with mode: 0644]
mlir/test/lib/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/TestDialectConversion.cpp [new file with mode: 0644]
mlir/test/lib/Transforms/TestDialectConversion.pdll [new file with mode: 0644]
mlir/test/lib/Transforms/lit.local.cfg [new file with mode: 0644]
mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
mlir/tools/mlir-opt/mlir-opt.cpp

index 8e8517e..54033ff 100644 (file)
 #ifndef MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
 #define MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
 
-#include <memory>
+#include "mlir/Support/LLVM.h"
 
 namespace mlir {
 class ModuleOp;
+class Operation;
 template <typename OpT>
 class OperationPass;
+class PDLPatternConfigSet;
 
 #define GEN_PASS_DECL_CONVERTPDLTOPDLINTERP
 #include "mlir/Conversion/Passes.h.inc"
@@ -26,6 +28,12 @@ class OperationPass;
 /// Creates and returns a pass to convert PDL ops to PDL interpreter ops.
 std::unique_ptr<OperationPass<ModuleOp>> createPDLToPDLInterpPass();
 
+/// Creates and returns a pass to convert PDL ops to PDL interpreter ops.
+/// `configMap` holds a map of the configurations for each pattern being
+/// compiled.
+std::unique_ptr<OperationPass<ModuleOp>> createPDLToPDLInterpPass(
+    DenseMap<Operation *, PDLPatternConfigSet *> &configMap);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
index 151fab8..e257b67 100644 (file)
@@ -600,10 +600,16 @@ public:
 class PatternRewriter : public RewriterBase {
 public:
   using RewriterBase::RewriterBase;
+
+  /// A hook used to indicate if the pattern rewriter can recover from failure
+  /// during the rewrite stage of a pattern. For example, if the pattern
+  /// rewriter supports rollback, it may progress smoothly even if IR was
+  /// changed during the rewrite.
+  virtual bool canRecoverFromRewriteFailure() const { return false; }
 };
 
 //===----------------------------------------------------------------------===//
-// PDLPatternModule
+// PDL Patterns
 //===----------------------------------------------------------------------===//
 
 //===----------------------------------------------------------------------===//
@@ -797,6 +803,108 @@ protected:
 };
 
 //===----------------------------------------------------------------------===//
+// PDLPatternConfig
+
+/// An individual configuration for a pattern, which can be accessed by native
+/// functions via the PDLPatternConfigSet. This allows for injecting additional
+/// configuration into PDL patterns that is specific to certain compilation
+/// flows.
+class PDLPatternConfig {
+public:
+  virtual ~PDLPatternConfig() = default;
+
+  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
+  /// pattern. These can be used to setup any specific state necessary for the
+  /// rewrite.
+  virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
+  virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
+
+  /// Return the TypeID that represents this configuration.
+  TypeID getTypeID() const { return id; }
+
+protected:
+  PDLPatternConfig(TypeID id) : id(id) {}
+
+private:
+  TypeID id;
+};
+
+/// This class provides a base class for users implementing a type of pattern
+/// configuration.
+template <typename T>
+class PDLPatternConfigBase : public PDLPatternConfig {
+public:
+  /// Support LLVM style casting.
+  static bool classof(const PDLPatternConfig *config) {
+    return config->getTypeID() == getConfigID();
+  }
+
+  /// Return the type id used for this configuration.
+  static TypeID getConfigID() { return TypeID::get<T>(); }
+
+protected:
+  PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
+};
+
+/// This class contains a set of configurations for a specific pattern.
+/// Configurations are uniqued by TypeID, meaning that only one configuration of
+/// each type is allowed.
+class PDLPatternConfigSet {
+public:
+  PDLPatternConfigSet() = default;
+
+  /// Construct a set with the given configurations.
+  template <typename... ConfigsT>
+  PDLPatternConfigSet(ConfigsT &&...configs) {
+    (addConfig(std::forward<ConfigsT>(configs)), ...);
+  }
+
+  /// Get the configuration defined by the given type. Asserts that the
+  /// configuration of the provided type exists.
+  template <typename T>
+  const T &get() const {
+    const T *config = tryGet<T>();
+    assert(config && "configuration not found");
+    return *config;
+  }
+
+  /// Get the configuration defined by the given type, returns nullptr if the
+  /// configuration does not exist.
+  template <typename T>
+  const T *tryGet() const {
+    for (const auto &configIt : configs)
+      if (const T *config = dyn_cast<T>(configIt.get()))
+        return config;
+    return nullptr;
+  }
+
+  /// Notify the configurations within this set at the beginning or end of a
+  /// rewrite of a matched pattern.
+  void notifyRewriteBegin(PatternRewriter &rewriter) {
+    for (const auto &config : configs)
+      config->notifyRewriteBegin(rewriter);
+  }
+  void notifyRewriteEnd(PatternRewriter &rewriter) {
+    for (const auto &config : configs)
+      config->notifyRewriteEnd(rewriter);
+  }
+
+protected:
+  /// Add a configuration to the set.
+  template <typename T>
+  void addConfig(T &&config) {
+    assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
+    configs.emplace_back(
+        std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
+  }
+
+  /// The set of configurations for this pattern. This uses a vector instead of
+  /// a map with the expectation that the number of configurations per set is
+  /// small (<= 1).
+  SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
+};
+
+//===----------------------------------------------------------------------===//
 // PDLPatternModule
 
 /// A generic PDL pattern constraint function. This function applies a
@@ -807,9 +915,11 @@ using PDLConstraintFunction =
 /// A native PDL rewrite function. This function performs a rewrite on the
 /// given set of values. Any results from this rewrite that should be passed
 /// back to PDL should be added to the provided result list. This method is only
-/// invoked when the corresponding match was successful.
-using PDLRewriteFunction =
-    std::function<void(PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
+/// invoked when the corresponding match was successful. Returns failure if an
+/// invariant of the rewrite was broken (certain rewriters may recover from
+/// partial pattern application).
+using PDLRewriteFunction = std::function<LogicalResult(
+    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
 
 namespace detail {
 namespace pdl_function_builder {
@@ -1034,6 +1144,13 @@ struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
     results.push_back(types);
   }
 };
+template <unsigned N>
+struct ProcessPDLValue<SmallVector<Type, N>> {
+  static void processAsResult(PatternRewriter &, PDLResultList &results,
+                              SmallVector<Type, N> values) {
+    results.push_back(TypeRange(values));
+  }
+};
 
 //===----------------------------------------------------------------------===//
 // Value
@@ -1061,6 +1178,13 @@ struct ProcessPDLValue<ResultRange> {
     results.push_back(values);
   }
 };
+template <unsigned N>
+struct ProcessPDLValue<SmallVector<Value, N>> {
+  static void processAsResult(PatternRewriter &, PDLResultList &results,
+                              SmallVector<Value, N> values) {
+    results.push_back(ValueRange(values));
+  }
+};
 
 //===----------------------------------------------------------------------===//
 // PDL Function Builder: Argument Handling
@@ -1111,28 +1235,49 @@ void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
 
 /// Store a single result within the result list.
 template <typename T>
-static void processResults(PatternRewriter &rewriter, PDLResultList &results,
-                           T &&value) {
+static LogicalResult processResults(PatternRewriter &rewriter,
+                                    PDLResultList &results, T &&value) {
   ProcessPDLValue<T>::processAsResult(rewriter, results,
                                       std::forward<T>(value));
+  return success();
 }
 
 /// Store a std::pair<> as individual results within the result list.
 template <typename T1, typename T2>
-static void processResults(PatternRewriter &rewriter, PDLResultList &results,
-                           std::pair<T1, T2> &&pair) {
-  processResults(rewriter, results, std::move(pair.first));
-  processResults(rewriter, results, std::move(pair.second));
+static LogicalResult processResults(PatternRewriter &rewriter,
+                                    PDLResultList &results,
+                                    std::pair<T1, T2> &&pair) {
+  if (failed(processResults(rewriter, results, std::move(pair.first))) ||
+      failed(processResults(rewriter, results, std::move(pair.second))))
+    return failure();
+  return success();
 }
 
 /// Store a std::tuple<> as individual results within the result list.
 template <typename... Ts>
-static void processResults(PatternRewriter &rewriter, PDLResultList &results,
-                           std::tuple<Ts...> &&tuple) {
+static LogicalResult processResults(PatternRewriter &rewriter,
+                                    PDLResultList &results,
+                                    std::tuple<Ts...> &&tuple) {
   auto applyFn = [&](auto &&...args) {
-    (processResults(rewriter, results, std::move(args)), ...);
+    return (succeeded(processResults(rewriter, results, std::move(args))) &&
+            ...);
   };
-  std::apply(applyFn, std::move(tuple));
+  return success(std::apply(applyFn, std::move(tuple)));
+}
+
+/// Handle LogicalResult propagation.
+inline LogicalResult processResults(PatternRewriter &rewriter,
+                                    PDLResultList &results,
+                                    LogicalResult &&result) {
+  return result;
+}
+template <typename T>
+static LogicalResult processResults(PatternRewriter &rewriter,
+                                    PDLResultList &results,
+                                    FailureOr<T> &&result) {
+  if (failed(result))
+    return failure();
+  return processResults(rewriter, results, std::move(*result));
 }
 
 //===----------------------------------------------------------------------===//
@@ -1192,23 +1337,26 @@ buildConstraintFn(ConstraintFnT &&constraintFn) {
 /// This overload handles the case of no return values.
 template <typename PDLFnT, std::size_t... I,
           typename FnTraitsT = llvm::function_traits<PDLFnT>>
-std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value>
+std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
+                 LogicalResult>
 processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
                             PDLResultList &, ArrayRef<PDLValue> values,
                             std::index_sequence<I...>) {
   fn(rewriter,
      (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
          values[I]))...);
+  return success();
 }
 /// This overload handles the case of return values, which need to be packaged
 /// into the result list.
 template <typename PDLFnT, std::size_t... I,
           typename FnTraitsT = llvm::function_traits<PDLFnT>>
-std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value>
+std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
+                 LogicalResult>
 processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
                             PDLResultList &results, ArrayRef<PDLValue> values,
                             std::index_sequence<I...>) {
-  processResults(
+  return processResults(
       rewriter, results,
       fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
                         processAsArg(values[I]))...));
@@ -1240,14 +1388,17 @@ buildRewriteFn(RewriteFnT &&rewriteFn) {
         std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
                                  1>();
     assertArgs<RewriteFnT>(rewriter, values, argIndices);
-    processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
-                                argIndices);
+    return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
+                                       argIndices);
   };
 }
 
 } // namespace pdl_function_builder
 } // namespace detail
 
+//===----------------------------------------------------------------------===//
+// PDLPatternModule
+
 /// This class contains all of the necessary data for a set of PDL patterns, or
 /// pattern rewrites specified in the form of the PDL dialect. This PDL module
 /// contained by this pattern may contain any number of `pdl.pattern`
@@ -1256,9 +1407,17 @@ class PDLPatternModule {
 public:
   PDLPatternModule() = default;
 
-  /// Construct a PDL pattern with the given module.
-  PDLPatternModule(OwningOpRef<ModuleOp> pdlModule)
-      : pdlModule(std::move(pdlModule)) {}
+  /// Construct a PDL pattern with the given module and configurations.
+  PDLPatternModule(OwningOpRef<ModuleOp> module)
+      : pdlModule(std::move(module)) {}
+  template <typename... ConfigsT>
+  PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
+      : PDLPatternModule(std::move(module)) {
+    auto configSet = std::make_unique<PDLPatternConfigSet>(
+        std::forward<ConfigsT>(patternConfigs)...);
+    attachConfigToPatterns(*pdlModule, *configSet);
+    configs.emplace_back(std::move(configSet));
+  }
 
   /// Merge the state in `other` into this pattern module.
   void mergeIn(PDLPatternModule &&other);
@@ -1344,6 +1503,14 @@ public:
     return rewriteFunctions;
   }
 
+  /// Return the set of the registered pattern configs.
+  SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
+    return std::move(configs);
+  }
+  DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
+    return std::move(configMap);
+  }
+
   /// Clear out the patterns and functions within this module.
   void clear() {
     pdlModule = nullptr;
@@ -1352,9 +1519,17 @@ public:
   }
 
 private:
+  /// Attach the given pattern config set to the patterns defined within the
+  /// given module.
+  void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
+
   /// The module containing the `pdl.pattern` operations.
   OwningOpRef<ModuleOp> pdlModule;
 
+  /// The set of configuration sets referenced by patterns within `pdlModule`.
+  SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
+  DenseMap<Operation *, PDLPatternConfigSet *> configMap;
+
   /// The external functions referenced from within the PDL module.
   llvm::StringMap<PDLConstraintFunction> constraintFunctions;
   llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
index 6045b22..5980949 100644 (file)
@@ -574,6 +574,11 @@ public:
   // PatternRewriter Hooks
   //===--------------------------------------------------------------------===//
 
+  /// Indicate that the conversion rewriter can recover from rewrite failure.
+  /// Recovery is supported via rollback, allowing for continued processing of
+  /// patterns even if a failure is encountered during the rewrite step.
+  bool canRecoverFromRewriteFailure() const override { return true; }
+
   /// PatternRewriter hook for replacing the results of an operation when the
   /// given functor returns true.
   void replaceOpWithIf(
@@ -892,6 +897,35 @@ private:
 };
 
 //===----------------------------------------------------------------------===//
+// PDL Configuration
+//===----------------------------------------------------------------------===//
+
+/// A PDL configuration that is used to supported dialect conversion
+/// functionality.
+class PDLConversionConfig final
+    : public PDLPatternConfigBase<PDLConversionConfig> {
+public:
+  PDLConversionConfig(TypeConverter *converter) : converter(converter) {}
+  ~PDLConversionConfig() final = default;
+
+  /// Return the type converter used by this configuration, which may be nullptr
+  /// if no type conversions are expected.
+  TypeConverter *getTypeConverter() const { return converter; }
+
+  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
+  /// pattern.
+  void notifyRewriteBegin(PatternRewriter &rewriter) final;
+  void notifyRewriteEnd(PatternRewriter &rewriter) final;
+
+private:
+  /// An optional type converter to use for the pattern.
+  TypeConverter *converter;
+};
+
+/// Register the dialect conversion PDL functions with the given pattern set.
+void registerConversionPDLFunctions(RewritePatternSet &patterns);
+
+//===----------------------------------------------------------------------===//
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.pdll b/mlir/include/mlir/Transforms/DialectConversion.pdll
new file mode 100644 (file)
index 0000000..9c6ce7a
--- /dev/null
@@ -0,0 +1,30 @@
+//===- DialectConversion.pdll - DialectConversion PDLL Support -*- PDLL -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines various utilities for interacting with dialect conversion
+// within PDLL.
+//
+//===----------------------------------------------------------------------===//
+
+/// This rewrite returns the converted value of `value`, whose type is defined 
+/// by the type converted specified in the `PDLConversionConfig` of the current
+/// pattern.
+Rewrite convertValue(value: Value) -> Value;
+
+/// This rewrite returns the converted values of `values`, whose type is defined 
+/// by the type converted specified in the `PDLConversionConfig` of the current
+/// pattern.
+Rewrite convertValues(values: ValueRange) -> ValueRange;
+
+/// This rewrite returns the converted type of `type` as defined by the type
+/// converted specified in the `PDLConversionConfig` of the current pattern.
+Rewrite convertType(type: Type) -> Type;
+
+/// This rewrite returns the converted types of `types` as defined by the type
+/// converted specified in the `PDLConversionConfig` of the current pattern.
+Rewrite convertTypes(types: TypeRange) -> TypeRange;
index 301fa68..987e7a3 100644 (file)
@@ -37,7 +37,8 @@ namespace {
 /// given module containing PDL pattern operations.
 struct PatternLowering {
 public:
-  PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule);
+  PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
+                  DenseMap<Operation *, PDLPatternConfigSet *> *configMap);
 
   /// Generate code for matching and rewriting based on the pattern operations
   /// within the module.
@@ -140,13 +141,19 @@ private:
   /// The set of operation values whose whose location will be used for newly
   /// generated operations.
   SetVector<Value> locOps;
+
+  /// A mapping between pattern operations and the corresponding configuration
+  /// set.
+  DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
 };
 } // namespace
 
-PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc,
-                                 ModuleOp rewriterModule)
+PatternLowering::PatternLowering(
+    pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
+    DenseMap<Operation *, PDLPatternConfigSet *> *configMap)
     : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
-      rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
+      rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
+      configMap(configMap) {}
 
 void PatternLowering::lower(ModuleOp module) {
   PredicateUniquer predicateUniquer;
@@ -589,10 +596,14 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
       rootKindAttr = builder.getStringAttr(*rootKind);
 
   builder.setInsertionPointToEnd(currentBlock);
-  builder.create<pdl_interp::RecordMatchOp>(
+  auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
       rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
       failureBlockStack.back());
+
+  // Set the config of the lowered match to the parent pattern.
+  if (configMap)
+    configMap->try_emplace(matchOp, configMap->lookup(pattern));
 }
 
 SymbolRefAttr PatternLowering::generateRewriter(
@@ -922,7 +933,14 @@ void PatternLowering::generateOperationResultTypeRewriter(
 namespace {
 struct PDLToPDLInterpPass
     : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
+  PDLToPDLInterpPass() = default;
+  PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
+  PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
+      : configMap(&configMap) {}
   void runOnOperation() final;
+
+  /// A map containing the configuration for each pattern.
+  DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
 };
 } // namespace
 
@@ -946,15 +964,24 @@ void PDLToPDLInterpPass::runOnOperation() {
       module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
 
   // Generate the code for the patterns within the module.
-  PatternLowering generator(matcherFunc, rewriterModule);
+  PatternLowering generator(matcherFunc, rewriterModule, configMap);
   generator.lower(module);
 
   // After generation, delete all of the pattern operations.
   for (pdl::PatternOp pattern :
-       llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
+       llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
+    // Drop the now dead config mappings.
+    if (configMap)
+      configMap->erase(pattern);
+
     pattern.erase();
+  }
 }
 
 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
   return std::make_unique<PDLToPDLInterpPass>();
 }
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass(
+    DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
+  return std::make_unique<PDLToPDLInterpPass>(configMap);
+}
index 494d90f..d2de65e 100644 (file)
@@ -158,11 +158,15 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
   if (!other.pdlModule)
     return;
 
-  // Steal the functions of the other module.
+  // Steal the functions and config of the other module.
   for (auto &it : other.constraintFunctions)
     registerConstraintFunction(it.first(), std::move(it.second));
   for (auto &it : other.rewriteFunctions)
     registerRewriteFunction(it.first(), std::move(it.second));
+  for (auto &it : other.configs)
+    configs.emplace_back(std::move(it));
+  for (auto &it : other.configMap)
+    configMap.insert(it);
 
   // Steal the other state if we have no patterns.
   if (!pdlModule) {
@@ -176,6 +180,18 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
                                 other.pdlModule->getBody()->getOperations());
 }
 
+void PDLPatternModule::attachConfigToPatterns(ModuleOp module,
+                                              PDLPatternConfigSet &configSet) {
+  // Attach the configuration to the symbols within the module. We only add
+  // to symbols to avoid hardcoding any specific operation names here (given
+  // that we don't depend on any PDL dialect). We can't use
+  // cast<SymbolOpInterface> here because patterns may be optional symbols.
+  module->walk([&](Operation *op) {
+    if (op->hasTrait<SymbolOpInterface::Trait>())
+      configMap[op] = &configSet;
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // Function Registry
 
index 388d6dc..9cc51da 100644 (file)
@@ -34,21 +34,23 @@ using namespace mlir::detail;
 //===----------------------------------------------------------------------===//
 
 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
+                                              PDLPatternConfigSet *configSet,
                                               ByteCodeAddr rewriterAddr) {
+  PatternBenefit benefit = matchOp.getBenefit();
+  MLIRContext *ctx = matchOp.getContext();
+
+  // Collect the set of generated operations.
   SmallVector<StringRef, 8> generatedOps;
   if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
     generatedOps =
         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
 
-  PatternBenefit benefit = matchOp.getBenefit();
-  MLIRContext *ctx = matchOp.getContext();
-
   // Check to see if this is pattern matches a specific operation type.
   if (Optional<StringRef> rootKind = matchOp.getRootKind())
-    return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
+    return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
                               generatedOps);
-  return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
-                            generatedOps);
+  return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
+                            benefit, ctx, generatedOps);
 }
 
 //===----------------------------------------------------------------------===//
@@ -194,14 +196,15 @@ public:
             ByteCodeField &maxValueRangeMemoryIndex,
             ByteCodeField &maxLoopLevel,
             llvm::StringMap<PDLConstraintFunction> &constraintFns,
-            llvm::StringMap<PDLRewriteFunction> &rewriteFns)
+            llvm::StringMap<PDLRewriteFunction> &rewriteFns,
+            const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
         rewriterByteCode(rewriterByteCode), patterns(patterns),
         maxValueMemoryIndex(maxValueMemoryIndex),
         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
-        maxLoopLevel(maxLoopLevel) {
+        maxLoopLevel(maxLoopLevel), configMap(configMap) {
     for (const auto &it : llvm::enumerate(constraintFns))
       constraintToMemIndex.try_emplace(it.value().first(), it.index());
     for (const auto &it : llvm::enumerate(rewriteFns))
@@ -328,6 +331,9 @@ private:
   ByteCodeField &maxTypeRangeMemoryIndex;
   ByteCodeField &maxValueRangeMemoryIndex;
   ByteCodeField &maxLoopLevel;
+
+  /// A map of pattern configurations.
+  const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
 };
 
 /// This class provides utilities for writing a bytecode stream.
@@ -969,7 +975,8 @@ void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
   ByteCodeField patternIndex = patterns.size();
   patterns.emplace_back(PDLByteCodePattern::create(
-      op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
+      op, configMap.lookup(op),
+      rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
   writer.append(OpCode::RecordMatch, patternIndex,
                 SuccessorRange(op.getOperation()), op.getMatchedOps());
   writer.appendPDLValueList(op.getInputs());
@@ -1014,13 +1021,16 @@ void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
 // PDLByteCode
 //===----------------------------------------------------------------------===//
 
-PDLByteCode::PDLByteCode(ModuleOp module,
-                         llvm::StringMap<PDLConstraintFunction> constraintFns,
-                         llvm::StringMap<PDLRewriteFunction> rewriteFns) {
+PDLByteCode::PDLByteCode(
+    ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
+    const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
+    llvm::StringMap<PDLConstraintFunction> constraintFns,
+    llvm::StringMap<PDLRewriteFunction> rewriteFns)
+    : configs(std::move(configs)) {
   Generator generator(module.getContext(), uniquedData, matcherByteCode,
                       rewriterByteCode, patterns, maxValueMemoryIndex,
                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
-                      maxLoopLevel, constraintFns, rewriteFns);
+                      maxLoopLevel, constraintFns, rewriteFns, configMap);
   generator.generate(module);
 
   // Initialize the external functions.
@@ -1076,14 +1086,15 @@ public:
   /// Start executing the code at the current bytecode index. `matches` is an
   /// optional field provided when this function is executed in a matching
   /// context.
-  void execute(PatternRewriter &rewriter,
-               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
-               Optional<Location> mainRewriteLoc = {});
+  LogicalResult
+  execute(PatternRewriter &rewriter,
+          SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
+          Optional<Location> mainRewriteLoc = {});
 
 private:
   /// Internal implementation of executing each of the bytecode commands.
   void executeApplyConstraint(PatternRewriter &rewriter);
-  void executeApplyRewrite(PatternRewriter &rewriter);
+  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
   void executeAreEqual();
   void executeAreRangesEqual();
   void executeBranch();
@@ -1345,7 +1356,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
   selectJump(succeeded(constraintFn(rewriter, args)));
 }
 
-void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
+LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
   SmallVector<PDLValue, 16> args;
@@ -1359,7 +1370,7 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
   // Execute the rewrite function.
   ByteCodeField numResults = read();
   ByteCodeRewriteResultList results(numResults);
-  rewriteFn(rewriter, results, args);
+  LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
 
   assert(results.getResults().size() == numResults &&
          "native PDL rewrite function returned unexpected number of results");
@@ -1395,6 +1406,13 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
     allocatedTypeRangeMemory.push_back(std::move(it));
   for (auto &it : results.getAllocatedValueRanges())
     allocatedValueRangeMemory.push_back(std::move(it));
+
+  // Process the result of the rewrite.
+  if (failed(rewriteResult)) {
+    LLVM_DEBUG(llvm::dbgs() << "  - Failed");
+    return failure();
+  }
+  return success();
 }
 
 void ByteCodeExecutor::executeAreEqual() {
@@ -2017,10 +2035,10 @@ void ByteCodeExecutor::executeSwitchTypes() {
   });
 }
 
-void ByteCodeExecutor::execute(
-    PatternRewriter &rewriter,
-    SmallVectorImpl<PDLByteCode::MatchResult> *matches,
-    Optional<Location> mainRewriteLoc) {
+LogicalResult
+ByteCodeExecutor::execute(PatternRewriter &rewriter,
+                          SmallVectorImpl<PDLByteCode::MatchResult> *matches,
+                          Optional<Location> mainRewriteLoc) {
   while (true) {
     // Print the location of the operation being executed.
     LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
@@ -2031,7 +2049,8 @@ void ByteCodeExecutor::execute(
       executeApplyConstraint(rewriter);
       break;
     case ApplyRewrite:
-      executeApplyRewrite(rewriter);
+      if (failed(executeApplyRewrite(rewriter)))
+        return failure();
       break;
     case AreEqual:
       executeAreEqual();
@@ -2078,7 +2097,7 @@ void ByteCodeExecutor::execute(
     case Finalize:
       executeFinalize();
       LLVM_DEBUG(llvm::dbgs() << "\n");
-      return;
+      return success();
     case ForEach:
       executeForEach();
       break;
@@ -2166,8 +2185,6 @@ void ByteCodeExecutor::execute(
   }
 }
 
-/// Run the pattern matcher on the given root operation, collecting the matched
-/// patterns in `matches`.
 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
                         SmallVectorImpl<MatchResult> &matches,
                         PDLByteCodeMutableState &state) const {
@@ -2181,7 +2198,8 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
       constraintFunctions, rewriteFunctions);
-  executor.execute(rewriter, &matches);
+  LogicalResult executeResult = executor.execute(rewriter, &matches);
+  assert(succeeded(executeResult) && "unexpected matcher execution failure");
 
   // Order the found matches by benefit.
   std::stable_sort(matches.begin(), matches.end(),
@@ -2190,9 +2208,13 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
                    });
 }
 
-/// Run the rewriter of the given pattern on the root operation `op`.
-void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
-                          PDLByteCodeMutableState &state) const {
+LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
+                                   const MatchResult &match,
+                                   PDLByteCodeMutableState &state) const {
+  auto *configSet = match.pattern->getConfigSet();
+  if (configSet)
+    configSet->notifyRewriteBegin(rewriter);
+
   // The arguments of the rewrite function are stored at the start of the
   // memory buffer.
   llvm::copy(match.values, state.memory.begin());
@@ -2204,5 +2226,24 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
       rewriterByteCode, state.currentPatternBenefits, patterns,
       constraintFunctions, rewriteFunctions);
-  executor.execute(rewriter, /*matches=*/nullptr, match.location);
+  LogicalResult result =
+      executor.execute(rewriter, /*matches=*/nullptr, match.location);
+
+  if (configSet)
+    configSet->notifyRewriteEnd(rewriter);
+
+  // If the rewrite failed, check if the pattern rewriter can recover. If it
+  // can, we can signal to the pattern applicator to keep trying patterns. If it
+  // doesn't, we need to bail. Bailing here should be fine, given that we have
+  // no means to propagate such a failure to the user, and it also indicates a
+  // bug in the user code (i.e. failable rewrites should not be used with
+  // pattern rewriters that don't support it).
+  if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
+    LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
+    llvm::report_fatal_error(
+        "Native PDL Rewrite failed, but the pattern "
+        "rewriter doesn't support recovery. Failable pattern rewrites should "
+        "not be used with pattern rewriters that do not support them.");
+  }
+  return result;
 }
index e423ff2..4d43fe6 100644 (file)
@@ -38,19 +38,27 @@ using OwningOpRange = llvm::OwningArrayRef<Operation *>;
 class PDLByteCodePattern : public Pattern {
 public:
   static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
+                                   PDLPatternConfigSet *configSet,
                                    ByteCodeAddr rewriterAddr);
 
   /// Return the bytecode address of the rewriter for this pattern.
   ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
 
+  /// Return the configuration set for this pattern, or null if there is none.
+  PDLPatternConfigSet *getConfigSet() const { return configSet; }
+
 private:
   template <typename... Args>
-  PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs)
-      : Pattern(std::forward<Args>(patternArgs)...),
-        rewriterAddr(rewriterAddr) {}
+  PDLByteCodePattern(ByteCodeAddr rewriterAddr, PDLPatternConfigSet *configSet,
+                     Args &&...patternArgs)
+      : Pattern(std::forward<Args>(patternArgs)...), rewriterAddr(rewriterAddr),
+        configSet(configSet) {}
 
   /// The address of the rewriter for this pattern.
   ByteCodeAddr rewriterAddr;
+
+  /// The optional config set for this pattern.
+  PDLPatternConfigSet *configSet;
 };
 
 //===----------------------------------------------------------------------===//
@@ -148,6 +156,8 @@ public:
   /// Create a ByteCode instance from the given module containing operations in
   /// the PDL interpreter dialect.
   PDLByteCode(ModuleOp module,
+              SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
+              const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
               llvm::StringMap<PDLConstraintFunction> constraintFns,
               llvm::StringMap<PDLRewriteFunction> rewriteFns);
 
@@ -165,9 +175,9 @@ public:
              PDLByteCodeMutableState &state) const;
 
   /// Run the rewriter of the given pattern that was previously matched in
-  /// `match`.
-  void rewrite(PatternRewriter &rewriter, const MatchResult &match,
-               PDLByteCodeMutableState &state) const;
+  /// `match`. Returns if a failure was encountered during the rewrite.
+  LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match,
+                        PDLByteCodeMutableState &state) const;
 
 private:
   /// Execute the given byte code starting at the provided instruction `inst`.
@@ -177,6 +187,9 @@ private:
                        PDLByteCodeMutableState &state,
                        SmallVectorImpl<MatchResult> *matches) const;
 
+  /// The set of pattern configs referenced within the bytecode.
+  SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
+
   /// A vector containing pointers to uniqued data. The storage is intentionally
   /// opaque such that we can store a wide range of data types. The types of
   /// data stored here include:
index 7657825..7b83d10 100644 (file)
@@ -16,7 +16,9 @@
 
 using namespace mlir;
 
-static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
+static LogicalResult
+convertPDLToPDLInterp(ModuleOp pdlModule,
+                      DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
   // Skip the conversion if the module doesn't contain pdl.
   if (pdlModule.getOps<pdl::PatternOp>().empty())
     return success();
@@ -37,7 +39,7 @@ static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
   // mode.
   pdlPipeline.enableVerifier(false);
 #endif
-  pdlPipeline.addPass(createPDLToPDLInterpPass());
+  pdlPipeline.addPass(createPDLToPDLInterpPass(configMap));
   if (failed(pdlPipeline.run(pdlModule)))
     return failure();
 
@@ -123,13 +125,16 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
   ModuleOp pdlModule = pdlPatterns.getModule();
   if (!pdlModule)
     return;
-  if (failed(convertPDLToPDLInterp(pdlModule)))
+  DenseMap<Operation *, PDLPatternConfigSet *> configMap =
+      pdlPatterns.takeConfigMap();
+  if (failed(convertPDLToPDLInterp(pdlModule, configMap)))
     llvm::report_fatal_error(
         "failed to lower PDL pattern module to the PDL Interpreter");
 
   // Generate the pdl bytecode.
   impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
-      pdlModule, pdlPatterns.takeConstraintFunctions(),
+      pdlModule, pdlPatterns.takeConfigs(), configMap,
+      pdlPatterns.takeConstraintFunctions(),
       pdlPatterns.takeRewriteFunctions());
 }
 
index 686b8e2..499a850 100644 (file)
@@ -191,20 +191,21 @@ LogicalResult PatternApplicator::matchAndRewrite(
     Operation *dumpRootOp = getDumpRootOp(op);
 #endif
     if (pdlMatch) {
-      bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
-      result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));
+      result = bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
     } else {
-      const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
+      LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
+                              << bestPattern->getDebugName() << "\"\n");
 
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Trying to match \"" << pattern->getDebugName() << "\"\n");
+      const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
       result = pattern->matchAndRewrite(op, rewriter);
-      LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result "
-                              << succeeded(result) << "\n");
 
-      if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
-        result = failure();
+      LLVM_DEBUG(llvm::dbgs() << "\"" << bestPattern->getDebugName()
+                              << "\" result " << succeeded(result) << "\n");
     }
+
+    // Process the result of the pattern application.
+    if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
+      result = failure();
     if (succeeded(result)) {
       LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
       break;
index 4632533..cf6770a 100644 (file)
@@ -93,10 +93,12 @@ void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
   os << "} // end namespace\n\n";
 
   // Emit function to add the generated matchers to the pattern list.
-  os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
-        "::mlir::RewritePatternSet &patterns) {\n";
+  os << "template <typename... ConfigsT>\n"
+        "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
+        "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
   for (const auto &name : patternNames)
-    os << "  patterns.add<" << name << ">(patterns.getContext());\n";
+    os << "  patterns.add<" << name
+       << ">(patterns.getContext(), configs...);\n";
   os << "}\n";
 }
 
@@ -104,14 +106,15 @@ void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
                        StringSet<> &nativeFunctions) {
   const char *patternClassStartStr = R"(
 struct {0} : ::mlir::PDLPatternModule {{
-  {0}(::mlir::MLIRContext *context)
+  template <typename... ConfigsT>
+  {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
     : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
 )";
   os << llvm::formatv(patternClassStartStr, patternName);
 
   os << "R\"mlir(";
   pattern->print(os, OpPrintingFlags().enableDebugInfo());
-  os << "\n    )mlir\", context)) {\n";
+  os << "\n    )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
 
   // Register any native functions used within the pattern.
   StringSet<> registeredNativeFunctions;
index 61bc4ff..616e437 100644 (file)
@@ -3273,6 +3273,76 @@ auto ConversionTarget::getOpInfo(OperationName op) const
 }
 
 //===----------------------------------------------------------------------===//
+// PDL Configuration
+//===----------------------------------------------------------------------===//
+
+void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
+  auto &rewriterImpl =
+      static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
+  rewriterImpl.currentTypeConverter = getTypeConverter();
+}
+
+void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
+  auto &rewriterImpl =
+      static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
+  rewriterImpl.currentTypeConverter = nullptr;
+}
+
+/// Remap the given value using the rewriter and the type converter in the
+/// provided config.
+static FailureOr<SmallVector<Value>>
+pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
+  SmallVector<Value> mappedValues;
+  if (failed(rewriter.getRemappedValues(values, mappedValues)))
+    return failure();
+  return std::move(mappedValues);
+}
+
+void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
+  patterns.getPDLPatterns().registerRewriteFunction(
+      "convertValue",
+      [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
+        auto results = pdllConvertValues(
+            static_cast<ConversionPatternRewriter &>(rewriter), value);
+        if (failed(results))
+          return failure();
+        return results->front();
+      });
+  patterns.getPDLPatterns().registerRewriteFunction(
+      "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
+        return pdllConvertValues(
+            static_cast<ConversionPatternRewriter &>(rewriter), values);
+      });
+  patterns.getPDLPatterns().registerRewriteFunction(
+      "convertType",
+      [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
+        auto &rewriterImpl =
+            static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
+        if (TypeConverter *converter = rewriterImpl.currentTypeConverter) {
+          if (Type newType = converter->convertType(type))
+            return newType;
+          return failure();
+        }
+        return type;
+      });
+  patterns.getPDLPatterns().registerRewriteFunction(
+      "convertTypes",
+      [](PatternRewriter &rewriter,
+         TypeRange types) -> FailureOr<SmallVector<Type>> {
+        auto &rewriterImpl =
+            static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
+        TypeConverter *converter = rewriterImpl.currentTypeConverter;
+        if (!converter)
+          return SmallVector<Type>(types);
+
+        SmallVector<Type> remappedTypes;
+        if (failed(converter->convertTypes(types, remappedTypes)))
+          return failure();
+        return std::move(remappedTypes);
+      });
+}
+
+//===----------------------------------------------------------------------===//
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Transforms/test-dialect-conversion-pdll.mlir b/mlir/test/Transforms/test-dialect-conversion-pdll.mlir
new file mode 100644 (file)
index 0000000..97c8dfc
--- /dev/null
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -test-dialect-conversion-pdll | FileCheck %s
+
+// CHECK-LABEL: @TestSingleConversion
+func.func @TestSingleConversion() {
+  // CHECK: %[[CAST:.*]] = "test.cast"() : () -> f64
+  // CHECK-NEXT: "test.return"(%[[CAST]]) : (f64) -> ()
+  %result = "test.cast"() : () -> (i64)
+  "test.return"(%result) : (i64) -> ()
+}
+
+// CHECK-LABEL: @TestLingeringConversion
+func.func @TestLingeringConversion() -> i64 {
+  // CHECK: %[[ORIG_CAST:.*]] = "test.cast"() : () -> f64
+  // CHECK: %[[MATERIALIZE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ORIG_CAST]] : f64 to i64
+  // CHECK-NEXT: return %[[MATERIALIZE_CAST]] : i64
+  %result = "test.cast"() : () -> (i64)
+  return %result : i64
+}
index 8672e87..0379dcd 100644 (file)
@@ -1,8 +1,18 @@
+add_mlir_pdll_library(MLIRTestDialectConversionPDLLPatternsIncGen
+  TestDialectConversion.pdll
+  TestDialectConversionPDLLPatterns.h.inc
+
+  EXTRA_INCLUDES
+  ${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test
+  ${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test
+  )
+
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
   TestCommutativityUtils.cpp
   TestConstantFold.cpp
   TestControlFlowSink.cpp
+  TestDialectConversion.cpp
   TestInlining.cpp
   TestIntRangeInference.cpp
   TestTopologicalSort.cpp
@@ -12,8 +22,12 @@ add_mlir_library(MLIRTestTransforms
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
 
+  DEPENDS
+  MLIRTestDialectConversionPDLLPatternsIncGen
+
   LINK_LIBS PUBLIC
   MLIRAnalysis
+  MLIRFuncDialect
   MLIRInferIntRangeInterface
   MLIRTestDialect
   MLIRTransforms
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp
new file mode 100644 (file)
index 0000000..996b7b9
--- /dev/null
@@ -0,0 +1,96 @@
+//===- TestDialectConversion.cpp - Test DialectConversion functionality ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace test;
+
+//===----------------------------------------------------------------------===//
+// Test PDLL Support
+//===----------------------------------------------------------------------===//
+
+#include "TestDialectConversionPDLLPatterns.h.inc"
+
+namespace {
+struct PDLLTypeConverter : public TypeConverter {
+  PDLLTypeConverter() {
+    addConversion(convertType);
+    addArgumentMaterialization(materializeCast);
+    addSourceMaterialization(materializeCast);
+  }
+
+  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
+    // Convert I64 to F64.
+    if (t.isSignlessInteger(64)) {
+      results.push_back(FloatType::getF64(t.getContext()));
+      return success();
+    }
+
+    // Otherwise, convert the type directly.
+    results.push_back(t);
+    return success();
+  }
+  /// Hook for materializing a conversion.
+  static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
+                                         ValueRange inputs, Location loc) {
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  }
+};
+
+struct TestDialectConversionPDLLPass
+    : public PassWrapper<TestDialectConversionPDLLPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectConversionPDLLPass)
+
+  StringRef getArgument() const final { return "test-dialect-conversion-pdll"; }
+  StringRef getDescription() const final {
+    return "Test DialectConversion PDLL functionality";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<pdl::PDLDialect, pdl_interp::PDLInterpDialect>();
+  }
+  LogicalResult initialize(MLIRContext *ctx) override {
+    // Build the pattern set within the `initialize` to avoid recompiling PDL
+    // patterns during each `runOnOperation` invocation.
+    RewritePatternSet patternList(ctx);
+    registerConversionPDLFunctions(patternList);
+    populateGeneratedPDLLPatterns(patternList, PDLConversionConfig(&converter));
+    patterns = std::move(patternList);
+    return success();
+  }
+
+  void runOnOperation() final {
+    mlir::ConversionTarget target(getContext());
+    target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
+    target.addDynamicallyLegalDialect<TestDialect>(
+        [this](Operation *op) { return converter.isLegal(op); });
+
+    if (failed(mlir::applyFullConversion(getOperation(), target, patterns)))
+      signalPassFailure();
+  }
+
+  FrozenRewritePatternSet patterns;
+  PDLLTypeConverter converter;
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestDialectConversionPasses() {
+  PassRegistration<TestDialectConversionPDLLPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.pdll b/mlir/test/lib/Transforms/TestDialectConversion.pdll
new file mode 100644 (file)
index 0000000..c29e852
--- /dev/null
@@ -0,0 +1,19 @@
+//===- TestPDLL.pdll - Test PDLL functionality ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestOps.td"
+#include "mlir/Transforms/DialectConversion.pdll"
+
+/// Change the result type of a producer.
+// FIXME: We shouldn't need to specify arguments for the result cast.
+Pattern => replace op<test.cast>(args: ValueRange) -> (results: TypeRange)
+  with op<test.cast>(args) -> (convertTypes(results));
+
+/// Pass through test.return conversion.
+Pattern => replace op<test.return>(args: ValueRange)
+  with op<test.return>(convertValues(args));
diff --git a/mlir/test/lib/Transforms/lit.local.cfg b/mlir/test/lib/Transforms/lit.local.cfg
new file mode 100644 (file)
index 0000000..8cfe5cd
--- /dev/null
@@ -0,0 +1 @@
+config.suffixes.remove('.pdll')
index 4dae177..f975307 100644 (file)
@@ -5,18 +5,19 @@
 // check that we handle overlap.
 
 // CHECK: struct GeneratedPDLLPattern0 : ::mlir::PDLPatternModule {
+// CHECK:  template <typename... ConfigsT>
 // CHECK:  : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
 // CHECK:  R"mlir(
 // CHECK:    pdl.pattern
 // CHECK:      operation "test.op"
-// CHECK:  )mlir", context))
+// CHECK:  )mlir", context), std::forward<ConfigsT>(configs)...)
 
 // CHECK: struct NamedPattern : ::mlir::PDLPatternModule {
 // CHECK:  : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
 // CHECK:  R"mlir(
 // CHECK:    pdl.pattern
 // CHECK:      operation "test.op2"
-// CHECK:  )mlir", context))
+// CHECK:  )mlir", context), std::forward<ConfigsT>(configs)...)
 
 // CHECK: struct GeneratedPDLLPattern1 : ::mlir::PDLPatternModule {
 
 // CHECK:  R"mlir(
 // CHECK:    pdl.pattern
 // CHECK:      operation "test.op3"
-// CHECK:  )mlir", context))
+// CHECK:  )mlir", context), std::forward<ConfigsT>(configs)...)
 
-// CHECK:      static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns) {
-// CHECK-NEXT:   patterns.add<GeneratedPDLLPattern0>(patterns.getContext());
-// CHECK-NEXT:   patterns.add<NamedPattern>(patterns.getContext());
-// CHECK-NEXT:   patterns.add<GeneratedPDLLPattern1>(patterns.getContext());
-// CHECK-NEXT:   patterns.add<GeneratedPDLLPattern2>(patterns.getContext());
+// CHECK:      static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {
+// CHECK-NEXT:   patterns.add<GeneratedPDLLPattern0>(patterns.getContext(), configs...);
+// CHECK-NEXT:   patterns.add<NamedPattern>(patterns.getContext(), configs...);
+// CHECK-NEXT:   patterns.add<GeneratedPDLLPattern1>(patterns.getContext(), configs...);
+// CHECK-NEXT:   patterns.add<GeneratedPDLLPattern2>(patterns.getContext(), configs...);
 // CHECK-NEXT: }
 
 Pattern => erase op<test.op>;
index 9eb0a47..1e53d85 100644 (file)
@@ -76,6 +76,7 @@ void registerTestDataLayoutQuery();
 void registerTestDeadCodeAnalysisPass();
 void registerTestDecomposeCallGraphTypes();
 void registerTestDiagnosticsPass();
+void registerTestDialectConversionPasses();
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
 void registerTestExpandMathPass();
@@ -170,6 +171,7 @@ void registerTestPasses() {
   mlir::test::registerTestConstantFold();
   mlir::test::registerTestControlFlowSink();
   mlir::test::registerTestDiagnosticsPass();
+  mlir::test::registerTestDialectConversionPasses();
 #if MLIR_CUDA_CONVERSIONS_ENABLED
   mlir::test::registerTestGpuSerializeToCubinPass();
 #endif