Add support for providing a legality callback for dynamic legality in DialectConversion.
authorRiver Riddle <riverriddle@google.com>
Fri, 19 Jul 2019 01:20:03 +0000 (18:20 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 19 Jul 2019 18:40:19 +0000 (11:40 -0700)
This allows for providing specific handling for dynamically legal operations/dialects without overriding the general 'isDynamicallyLegal' hook. This also means that a derived ConversionTarget class need not always be defined when some operations are dynamically legal.

Example usage:

ConversionTarget target(...);
target.addDynamicallyLegalOp<ReturnOp>([](ReturnOp op) {
  return ...
};
target.addDynamicallyLegalDialect<StandardOpsDialect>([](Operation *op) {
  return ...
};

PiperOrigin-RevId: 258884753

mlir/include/mlir/Support/STLExtras.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/lib/TestDialect/TestPatterns.cpp

index f038b0f..3448b08 100644 (file)
@@ -113,6 +113,16 @@ struct detector<void_t<Op<Args...>>, Op, Args...> {
 template <template <class...> class Op, class... Args>
 using is_detected = typename detail::detector<void, Op, Args...>::value_t;
 
+/// Check if a Callable type can be invoked with the given set of arg types.
+namespace detail {
+template <typename Callable, typename... Args>
+using is_invocable =
+    decltype(std::declval<Callable &>()(std::declval<Args>()...));
+} // namespace detail
+
+template <typename Callable, typename... Args>
+using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
+
 //===----------------------------------------------------------------------===//
 //     Extra additions to <iterator>
 //===----------------------------------------------------------------------===//
index 68c6f12..5543c21 100644 (file)
@@ -312,16 +312,13 @@ public:
   /// The type used to store operation legality information.
   using LegalityMapTy = llvm::MapVector<OperationName, LegalizationAction>;
 
+  /// The signature of the callback used to determine if an operation is
+  /// dynamically legal on the target.
+  using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
+
   ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
   virtual ~ConversionTarget() = default;
 
-  /// Runs a custom legalization query for the given operation. This should
-  /// return true if the given operation is legal, otherwise false.
-  virtual bool isDynamicallyLegal(Operation *op) const {
-    llvm_unreachable(
-        "targets with custom legalization must override 'isDynamicallyLegal'");
-  }
-
   //===--------------------------------------------------------------------===//
   // Legality Registration
   //===--------------------------------------------------------------------===//
@@ -352,6 +349,26 @@ public:
     addDynamicallyLegalOp<OpT2, OpTs...>();
   }
 
+  /// Register the given operation as dynamically legal and set the dynamic
+  /// legalization callback to the one provided.
+  template <typename OpT>
+  void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
+    OperationName opName(OpT::getOperationName(), &ctx);
+    setOpAction(opName, LegalizationAction::Dynamic);
+    setLegalityCallback(opName, callback);
+  }
+  template <typename OpT, typename OpT2, typename... OpTs>
+  void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
+    addDynamicallyLegalOp<OpT>(callback);
+    addDynamicallyLegalOp<OpT2, OpTs...>(callback);
+  }
+  template <typename OpT, class Callable>
+  typename std::enable_if<!is_invocable<Callable, Operation *>::value>::type
+  addDynamicallyLegalOp(Callable &&callback) {
+    addDynamicallyLegalOp<OpT>(
+        [=](Operation *op) { return callback(cast<OpT>(op)); });
+  }
+
   /// Register the given operation as illegal, i.e. this operation is known to
   /// not be supported by this target.
   template <typename OpT> void addIllegalOp() {
@@ -384,9 +401,13 @@ public:
     SmallVector<StringRef, 2> dialectNames({name, names...});
     setDialectAction(dialectNames, LegalizationAction::Dynamic);
   }
-  template <typename... Args> void addDynamicallyLegalDialect() {
+  template <typename... Args>
+  void addDynamicallyLegalDialect(
+      llvm::Optional<DynamicLegalityCallbackFn> callback = llvm::None) {
     SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
     setDialectAction(dialectNames, LegalizationAction::Dynamic);
+    if (callback)
+      setLegalityCallback(dialectNames, *callback);
   }
 
   /// Register the operations of the given dialects as illegal, i.e.
@@ -408,15 +429,40 @@ public:
   /// Get the legality action for the given operation.
   llvm::Optional<LegalizationAction> getOpAction(OperationName op) const;
 
+  /// Return true if the given operation instance is legal on this target.
+  bool isLegal(Operation *op) const;
+
+protected:
+  /// Runs a custom legalization query for the given operation. This should
+  /// return true if the given operation is legal, otherwise false.
+  virtual bool isDynamicallyLegal(Operation *op) const {
+    llvm_unreachable(
+        "targets with custom legalization must override 'isDynamicallyLegal'");
+  }
+
 private:
+  /// Set the dynamic legality callback for the given operation.
+  void setLegalityCallback(OperationName name,
+                           const DynamicLegalityCallbackFn &callback);
+
+  /// Set the dynamic legality callback for the given dialects.
+  void setLegalityCallback(ArrayRef<StringRef> dialects,
+                           const DynamicLegalityCallbackFn &callback);
+
   /// A deterministic mapping of operation name to the specific legality action
   /// to take.
   LegalityMapTy legalOperations;
 
+  /// A set of dynamic legality callbacks for given operation names.
+  DenseMap<OperationName, DynamicLegalityCallbackFn> opLegalityFns;
+
   /// A deterministic mapping of dialect name to the specific legality action to
   /// take.
   llvm::StringMap<LegalizationAction> legalDialects;
 
+  /// A set of dynamic legality callbacks for given dialect names.
+  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
+
   /// The current context this target applies to.
   MLIRContext &ctx;
 };
index ed271b6..02ca31f 100644 (file)
@@ -782,20 +782,11 @@ OperationLegalizer::legalize(Operation *op,
   LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
                           << "\n");
 
-  // Check if this was marked legal by the target.
-  if (auto action = target.getOpAction(op->getName())) {
-    // Check if this operation is always legal.
-    if (*action == LegalizationAction::Legal)
-      return success();
-
-    // Otherwise, handle dynamic legalization.
-    if (*action == LegalizationAction::Dynamic) {
-      LLVM_DEBUG(llvm::dbgs() << "- Trying dynamic legalization.\n");
-      if (target.isDynamicallyLegal(op))
-        return success();
-    }
-
-    // Fallthough to see if a pattern can convert this into a legal operation.
+  // Check if this operation is legal on the target.
+  if (target.isLegal(op)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "-- Success : Operation marked legal by the target\n");
+    return success();
   }
 
   // Otherwise, we need to apply a legalization pattern to this operation.
@@ -1293,6 +1284,43 @@ auto ConversionTarget::getOpAction(OperationName op) const
   return llvm::None;
 }
 
+/// Return if the given operation instance is legal on this target.
+bool ConversionTarget::isLegal(Operation *op) const {
+  auto action = getOpAction(op->getName());
+
+  // Handle dynamic legality.
+  if (action == LegalizationAction::Dynamic) {
+    // Check for callbacks on the operation or dialect.
+    auto opFn = opLegalityFns.find(op->getName());
+    if (opFn != opLegalityFns.end())
+      return opFn->second(op);
+    auto dialectFn = dialectLegalityFns.find(op->getName().getDialect());
+    if (dialectFn != dialectLegalityFns.end())
+      return dialectFn->second(op);
+
+    // Otherwise, invoke the hook on the derived instance.
+    return isDynamicallyLegal(op);
+  }
+
+  // Otherwise, the operation is only legal if it was marked 'Legal'.
+  return action == LegalizationAction::Legal;
+}
+
+/// Set the dynamic legality callback for the given operation.
+void ConversionTarget::setLegalityCallback(
+    OperationName name, const DynamicLegalityCallbackFn &callback) {
+  assert(callback && "expected valid legality callback");
+  opLegalityFns[name] = callback;
+}
+
+/// Set the dynamic legality callback for the given dialects.
+void ConversionTarget::setLegalityCallback(
+    ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
+  assert(callback && "expected valid legality callback");
+  for (StringRef dialect : dialects)
+    dialectLegalityFns[dialect] = callback;
+}
+
 //===----------------------------------------------------------------------===//
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
index 410536c..d452edb 100644 (file)
@@ -182,19 +182,6 @@ struct TestTypeConverter : public TypeConverter {
   }
 };
 
-struct TestConversionTarget : public ConversionTarget {
-  TestConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
-    addLegalOp<LegalOpA, TestValidOp>();
-    addDynamicallyLegalOp<TestReturnOp>();
-    addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
-  }
-  bool isDynamicallyLegal(Operation *op) const final {
-    // Don't allow F32 operands.
-    return llvm::none_of(op->getOperandTypes(),
-                         [](Type type) { return type.isF32(); });
-  }
-};
-
 struct TestLegalizePatternDriver
     : public ModulePass<TestLegalizePatternDriver> {
   void runOnModule() override {
@@ -204,8 +191,17 @@ struct TestLegalizePatternDriver
                        TestDropOp, TestPassthroughInvalidOp,
                        TestSplitReturnType>::build(patterns, &getContext());
 
+    // Define the conversion target used for the test.
+    ConversionTarget target(getContext());
+    target.addLegalOp<LegalOpA, TestValidOp>();
+    target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
+    target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
+      // Don't allow F32 operands.
+      return llvm::none_of(op.getOperandTypes(),
+                           [](Type type) { return type.isF32(); });
+    });
+
     TestTypeConverter converter;
-    TestConversionTarget target(getContext());
     (void)applyPartialConversion(getModule(), target, std::move(patterns),
                                  &converter);
   }