template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<void, Op, Args...>::value_t;
+/// Check if a Callable type can be invoked with the given set of arg types.
+namespace detail {
+template <typename Callable, typename... Args>
+using is_invocable =
+ decltype(std::declval<Callable &>()(std::declval<Args>()...));
+} // namespace detail
+
+template <typename Callable, typename... Args>
+using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
+
//===----------------------------------------------------------------------===//
// Extra additions to <iterator>
//===----------------------------------------------------------------------===//
/// The type used to store operation legality information.
using LegalityMapTy = llvm::MapVector<OperationName, LegalizationAction>;
+ /// The signature of the callback used to determine if an operation is
+ /// dynamically legal on the target.
+ using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
+
ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
virtual ~ConversionTarget() = default;
- /// Runs a custom legalization query for the given operation. This should
- /// return true if the given operation is legal, otherwise false.
- virtual bool isDynamicallyLegal(Operation *op) const {
- llvm_unreachable(
- "targets with custom legalization must override 'isDynamicallyLegal'");
- }
-
//===--------------------------------------------------------------------===//
// Legality Registration
//===--------------------------------------------------------------------===//
addDynamicallyLegalOp<OpT2, OpTs...>();
}
+ /// Register the given operation as dynamically legal and set the dynamic
+ /// legalization callback to the one provided.
+ template <typename OpT>
+ void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
+ OperationName opName(OpT::getOperationName(), &ctx);
+ setOpAction(opName, LegalizationAction::Dynamic);
+ setLegalityCallback(opName, callback);
+ }
+ template <typename OpT, typename OpT2, typename... OpTs>
+ void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
+ addDynamicallyLegalOp<OpT>(callback);
+ addDynamicallyLegalOp<OpT2, OpTs...>(callback);
+ }
+ template <typename OpT, class Callable>
+ typename std::enable_if<!is_invocable<Callable, Operation *>::value>::type
+ addDynamicallyLegalOp(Callable &&callback) {
+ addDynamicallyLegalOp<OpT>(
+ [=](Operation *op) { return callback(cast<OpT>(op)); });
+ }
+
/// Register the given operation as illegal, i.e. this operation is known to
/// not be supported by this target.
template <typename OpT> void addIllegalOp() {
SmallVector<StringRef, 2> dialectNames({name, names...});
setDialectAction(dialectNames, LegalizationAction::Dynamic);
}
- template <typename... Args> void addDynamicallyLegalDialect() {
+ template <typename... Args>
+ void addDynamicallyLegalDialect(
+ llvm::Optional<DynamicLegalityCallbackFn> callback = llvm::None) {
SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
setDialectAction(dialectNames, LegalizationAction::Dynamic);
+ if (callback)
+ setLegalityCallback(dialectNames, *callback);
}
/// Register the operations of the given dialects as illegal, i.e.
/// Get the legality action for the given operation.
llvm::Optional<LegalizationAction> getOpAction(OperationName op) const;
+ /// Return true if the given operation instance is legal on this target.
+ bool isLegal(Operation *op) const;
+
+protected:
+ /// Runs a custom legalization query for the given operation. This should
+ /// return true if the given operation is legal, otherwise false.
+ virtual bool isDynamicallyLegal(Operation *op) const {
+ llvm_unreachable(
+ "targets with custom legalization must override 'isDynamicallyLegal'");
+ }
+
private:
+ /// Set the dynamic legality callback for the given operation.
+ void setLegalityCallback(OperationName name,
+ const DynamicLegalityCallbackFn &callback);
+
+ /// Set the dynamic legality callback for the given dialects.
+ void setLegalityCallback(ArrayRef<StringRef> dialects,
+ const DynamicLegalityCallbackFn &callback);
+
/// A deterministic mapping of operation name to the specific legality action
/// to take.
LegalityMapTy legalOperations;
+ /// A set of dynamic legality callbacks for given operation names.
+ DenseMap<OperationName, DynamicLegalityCallbackFn> opLegalityFns;
+
/// A deterministic mapping of dialect name to the specific legality action to
/// take.
llvm::StringMap<LegalizationAction> legalDialects;
+ /// A set of dynamic legality callbacks for given dialect names.
+ llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
+
/// The current context this target applies to.
MLIRContext &ctx;
};
LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
<< "\n");
- // Check if this was marked legal by the target.
- if (auto action = target.getOpAction(op->getName())) {
- // Check if this operation is always legal.
- if (*action == LegalizationAction::Legal)
- return success();
-
- // Otherwise, handle dynamic legalization.
- if (*action == LegalizationAction::Dynamic) {
- LLVM_DEBUG(llvm::dbgs() << "- Trying dynamic legalization.\n");
- if (target.isDynamicallyLegal(op))
- return success();
- }
-
- // Fallthough to see if a pattern can convert this into a legal operation.
+ // Check if this operation is legal on the target.
+ if (target.isLegal(op)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "-- Success : Operation marked legal by the target\n");
+ return success();
}
// Otherwise, we need to apply a legalization pattern to this operation.
return llvm::None;
}
+/// Return if the given operation instance is legal on this target.
+bool ConversionTarget::isLegal(Operation *op) const {
+ auto action = getOpAction(op->getName());
+
+ // Handle dynamic legality.
+ if (action == LegalizationAction::Dynamic) {
+ // Check for callbacks on the operation or dialect.
+ auto opFn = opLegalityFns.find(op->getName());
+ if (opFn != opLegalityFns.end())
+ return opFn->second(op);
+ auto dialectFn = dialectLegalityFns.find(op->getName().getDialect());
+ if (dialectFn != dialectLegalityFns.end())
+ return dialectFn->second(op);
+
+ // Otherwise, invoke the hook on the derived instance.
+ return isDynamicallyLegal(op);
+ }
+
+ // Otherwise, the operation is only legal if it was marked 'Legal'.
+ return action == LegalizationAction::Legal;
+}
+
+/// Set the dynamic legality callback for the given operation.
+void ConversionTarget::setLegalityCallback(
+ OperationName name, const DynamicLegalityCallbackFn &callback) {
+ assert(callback && "expected valid legality callback");
+ opLegalityFns[name] = callback;
+}
+
+/// Set the dynamic legality callback for the given dialects.
+void ConversionTarget::setLegalityCallback(
+ ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
+ assert(callback && "expected valid legality callback");
+ for (StringRef dialect : dialects)
+ dialectLegalityFns[dialect] = callback;
+}
+
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
}
};
-struct TestConversionTarget : public ConversionTarget {
- TestConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
- addLegalOp<LegalOpA, TestValidOp>();
- addDynamicallyLegalOp<TestReturnOp>();
- addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
- }
- bool isDynamicallyLegal(Operation *op) const final {
- // Don't allow F32 operands.
- return llvm::none_of(op->getOperandTypes(),
- [](Type type) { return type.isF32(); });
- }
-};
-
struct TestLegalizePatternDriver
: public ModulePass<TestLegalizePatternDriver> {
void runOnModule() override {
TestDropOp, TestPassthroughInvalidOp,
TestSplitReturnType>::build(patterns, &getContext());
+ // Define the conversion target used for the test.
+ ConversionTarget target(getContext());
+ target.addLegalOp<LegalOpA, TestValidOp>();
+ target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
+ target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
+ // Don't allow F32 operands.
+ return llvm::none_of(op.getOperandTypes(),
+ [](Type type) { return type.isF32(); });
+ });
+
TestTypeConverter converter;
- TestConversionTarget target(getContext());
(void)applyPartialConversion(getModule(), target, std::move(patterns),
&converter);
}