/// a counter for an action with `skip=47` and `count=2`, would skip the
/// first 47 executions, then execute twice, and finally prevent any further
/// executions.
-class DebugCounter : public ActionManager::GenericHandler {
+class DebugCounter {
public:
DebugCounter();
- ~DebugCounter() override;
+ ~DebugCounter();
/// Add a counter for the given action tag. `countToSkip` is the number
/// of counter executions to skip before enabling execution of the action.
void addCounter(StringRef actionTag, int64_t countToSkip,
int64_t countToStopAfter);
- /// Register a counter with the specified name.
- FailureOr<bool> execute(llvm::function_ref<void()> transform,
- const Action &action) final;
+ /// Entry point for handling actions.
+ void operator()(llvm::function_ref<void()> transform, const Action &action);
/// Print the counters that have been registered with this instance to the
/// provided output stream.
static void registerCLOptions();
private:
+ // Returns true if the next action matching this tag should be executed.
+ bool shouldExecute(StringRef tag);
+
/// Apply the registered CL options to this debug counter instance.
void applyCLOptions();
/// An action is a specific action that is to be taken by the compiler,
/// that can be toggled and controlled by an external user. There are no
/// constraints on the granularity of an action, it could be as simple as
-/// "perform this fold" and as complex as "run this pass pipeline". Via template
-/// parameters `ParameterTs`, a user may provide the set of argument types that
-/// are provided when handling a query on this action.
+/// "perform this fold" and as complex as "run this pass pipeline".
///
/// This class represents the base class of the ActionImpl class (see below).
/// This holds the template-invariant elements of the Action class.
TypeID actionID;
};
-//===----------------------------------------------------------------------===//
-// ActionManager
-//===----------------------------------------------------------------------===//
-
-/// This class represents manages actions, and orchestrates the
-/// communication between action queries and action handlers. An action handler
-/// is either an action specific handler, i.e. a derived class of
-/// `MyActionType::Handler`, or a generic handler, i.e. a derived class of
-/// `ActionManager::GenericHandler`. For more details on action specific
-/// handlers, see the definition of `Action::Handler` below. For more
-/// details on generic handlers, see `ActionManager::GenericHandler` below.
-class ActionManager {
-public:
- //===--------------------------------------------------------------------===//
- // Handlers
- //===--------------------------------------------------------------------===//
-
- /// This class represents the base class of an action handler.
- class HandlerBase {
- public:
- virtual ~HandlerBase() = default;
-
- /// Return the unique handler id of this handler, use for casting
- /// functionality.
- TypeID getHandlerID() const { return handlerID; }
-
- protected:
- HandlerBase(TypeID handlerID) : handlerID(handlerID) {}
-
- /// The type of the derived handler class. This allows for detecting if a
- /// handler can handle a given action type.
- TypeID handlerID;
- };
-
- /// This class represents a generic action handler. A generic handler allows
- /// for handling any action type. Handlers of this type are useful for
- /// implementing general functionality that doesn't necessarily need to
- /// interpret the exact action parameters, or can rely on an external
- /// interpreter (such as the user). Given that these handlers are generic,
- /// they take a set of opaque parameters that try to map the context of the
- /// action type in a generic way.
- class GenericHandler : public HandlerBase {
- public:
- GenericHandler() : HandlerBase(TypeID::get<GenericHandler>()) {}
-
- /// This hook allows for controlling the execution of an action. It should
- /// return failure if the handler could not process the action, or whether
- /// the `transform` was executed or not.
- virtual FailureOr<bool> execute(function_ref<void()> transform,
- const Action &action) {
- return failure();
- }
-
- /// Provide classof to allow casting between handler types.
- static bool classof(const ActionManager::HandlerBase *handler) {
- return handler->getHandlerID() == TypeID::get<GenericHandler>();
- }
- };
-
- /// Register the given action handler with the manager.
- void registerActionHandler(std::unique_ptr<HandlerBase> handler) {
- actionHandlers.emplace_back(std::move(handler));
- }
- template <typename T>
- void registerActionHandler() {
- registerActionHandler(std::make_unique<T>());
- }
-
- //===--------------------------------------------------------------------===//
- // Action Queries
- //===--------------------------------------------------------------------===//
-
- /// Dispatch an action represented by the `transform` callback. If no handler
- /// is found, the `transform` callback is invoked directly.
- /// Return true if the action was executed, false otherwise.
- template <typename ActionType, typename... Args>
- bool execute(function_ref<void()> transform, Args &&...args) {
- if (actionHandlers.empty()) {
- transform();
- return true;
- }
-
- // Invoke the `execute` method on the provided handler.
- auto executeFn = [&](auto *handler, auto &&...handlerParams) {
- return handler->execute(
- transform,
- ActionType(std::forward<decltype(handlerParams)>(handlerParams)...));
- };
- FailureOr<bool> result = dispatchToHandler<ActionType, bool>(
- executeFn, std::forward<Args>(args)...);
- // no handler found, execute the transform directly.
- if (failed(result)) {
- transform();
- return true;
- }
-
- // Return the result of the handler.
- return *result;
- }
-
-private:
- //===--------------------------------------------------------------------===//
- // Query to Handler Dispatch
- //===--------------------------------------------------------------------===//
-
- /// Dispath a given callback on any handlers that are able to process queries
- /// on the given action type. This method returns failure if no handlers could
- /// process the action, or success(with a result) if a handler processed the
- /// action.
- template <typename ActionType, typename ResultT, typename HandlerCallbackT,
- typename... Args>
- FailureOr<ResultT> dispatchToHandler(HandlerCallbackT &&handlerCallback,
- Args &&...args) {
- static_assert(ActionType::template canHandleWith<Args...>(),
- "cannot execute action with the given set of parameters");
-
- // Process any generic or action specific handlers.
- // The first handler that gives us a result is the one that we will return.
- for (std::unique_ptr<HandlerBase> &it : reverse(actionHandlers)) {
- FailureOr<ResultT> result = failure();
- if (auto *handler = dyn_cast<typename ActionType::Handler>(&*it)) {
- result = handlerCallback(handler, std::forward<Args>(args)...);
- } else if (auto *genericHandler = dyn_cast<GenericHandler>(&*it)) {
- result = handlerCallback(genericHandler, std::forward<Args>(args)...);
- }
-
- // If the handler succeeded, return the result. Otherwise, try a new
- // handler.
- if (succeeded(result))
- return result;
- }
- return failure();
- }
-
- /// The set of action handlers that have been registered with the manager.
- SmallVector<std::unique_ptr<HandlerBase>> actionHandlers;
-};
-
/// CRTP Implementation of an action. This class provides a base class for
/// implementing specific actions.
/// Derived classes are expected to provide the following:
/// * static constexpr StringLiteral tag = "...";
/// - This method returns a unique string identifier, similar to a command
/// line flag or DEBUG_TYPE.
-template <typename Derived, typename... ParameterTs>
+template <typename Derived>
class ActionImpl : public Action {
public:
ActionImpl() : Action(TypeID::get<Derived>()) {}
/// Forward tag access to the derived class.
StringRef getTag() const final { return Derived::tag; }
-
- class Handler : public ActionManager::HandlerBase {
- public:
- Handler() : HandlerBase(TypeID::get<Derived>()) {}
-
- /// This hook allows for controlling the execution of an action.
- /// `parameters` correspond to the set of values provided by the
- /// action as context. It should return failure if the handler could not
- /// process the action, passing it to the next registered handler.
- virtual FailureOr<bool> execute(function_ref<void()> transform,
- const Derived &action) {
- return failure();
- }
-
- /// Provide classof to allow casting between handler types.
- static bool classof(const ActionManager::HandlerBase *handler) {
- return handler->getHandlerID() == TypeID::get<Derived>();
- }
- };
-
-private:
- /// Returns true if the action can be handled within the given set of
- /// parameter types.
- template <typename... CallerParameterTs>
- static constexpr bool canHandleWith() {
- return std::is_invocable_v<function_ref<void(ParameterTs...)>,
- CallerParameterTs...>;
- }
-
- /// Allow access to `canHandleWith`.
- friend class ActionManager;
};
} // namespace tracing
namespace mlir {
namespace tracing {
-class ActionManager;
+class Action;
}
class DiagnosticEngine;
class Dialect;
/// instances. This should not be used directly.
StorageUniquer &getAttributeUniquer();
- /// Returns the manager of debug actions within the context.
- tracing::ActionManager &getActionManager();
-
/// These APIs are tracking whether the context will be used in a
/// multithreading environment: this has no effect other than enabling
/// assertions on misuses of some APIs.
/// (attributes, operations, types, etc.).
llvm::hash_code getRegistryHash();
+ //===--------------------------------------------------------------------===//
+ // Action API
+ //===--------------------------------------------------------------------===//
+
+ /// Signatures for the action handler that can be registered with the context.
+ using HandlerTy =
+ std::function<void(function_ref<void()>, const tracing::Action &)>;
+
+ /// Register a handler for handling actions that are dispatched through this
+ /// context. A nullptr handler can be set to disable a previously set handler.
+ void registerActionHandler(HandlerTy handler);
+
+ /// Return true if a valid ActionHandler is set.
+ bool hasActionHandler();
+
+ /// Dispatch the provided action to the handler if any, or just execute it.
+ void executeAction(function_ref<void()> actionFn,
+ const tracing::Action &action) {
+ if (LLVM_UNLIKELY(hasActionHandler()))
+ executeActionInternal(actionFn, action);
+ else
+ actionFn();
+ }
+
+ /// Dispatch the provided action to the handler if any, or just execute it.
+ template <typename ActionTy, typename... Args>
+ void executeAction(function_ref<void()> actionFn, Args &&...args) {
+ if (LLVM_UNLIKELY(hasActionHandler()))
+ executeActionInternal<ActionTy, Args...>(actionFn,
+ std::forward<Args>(args)...);
+ else
+ actionFn();
+ }
+
private:
/// Return true if the given dialect is currently loading.
bool isDialectLoading(StringRef dialectNamespace);
+ /// Internal helper for the dispatch method.
+ void executeActionInternal(function_ref<void()> actionFn,
+ const tracing::Action &action);
+
+ /// Internal helper for the dispatch method. We get here after checking that
+ /// there is a handler, for the purpose of keeping this code out-of-line. and
+ /// avoid calling the ctor for the Action unnecessarily.
+ template <typename ActionTy, typename... Args>
+ LLVM_ATTRIBUTE_NOINLINE void
+ executeActionInternal(function_ref<void()> actionFn, Args &&...args) {
+ executeActionInternal(actionFn, ActionTy(std::forward<Args>(args)...));
+ }
+
const std::unique_ptr<MLIRContextImpl> impl;
MLIRContext(const MLIRContext &) = delete;
counters.try_emplace(actionTag, countToSkip, countToStopAfter);
}
-// Register a counter with the specified name.
-FailureOr<bool> DebugCounter::execute(llvm::function_ref<void()> transform,
- const Action &action) {
- auto counterIt = counters.find(action.getTag());
+void DebugCounter::operator()(llvm::function_ref<void()> transform,
+ const Action &action) {
+ if (shouldExecute(action.getTag()))
+ transform();
+}
+
+bool DebugCounter::shouldExecute(StringRef tag) {
+ auto counterIt = counters.find(tag);
if (counterIt == counters.end())
return true;
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/RWMutex.h"
// Debugging
//===--------------------------------------------------------------------===//
- /// An action manager for use within the context.
- tracing::ActionManager actionManager;
+ /// An action handler for handling actions that are dispatched through this
+ /// context.
+ std::function<void(function_ref<void()>, const tracing::Action &)>
+ actionHandler;
//===--------------------------------------------------------------------===//
// Diagnostics
}
//===----------------------------------------------------------------------===//
-// Debugging
+// Action Handling
//===----------------------------------------------------------------------===//
-tracing::ActionManager &MLIRContext::getActionManager() {
- return getImpl().actionManager;
+void MLIRContext::registerActionHandler(HandlerTy handler) {
+ getImpl().actionHandler = std::move(handler);
}
+/// Dispatch the provided action to the handler if any, or just execute it.
+void MLIRContext::executeActionInternal(function_ref<void()> actionFn,
+ const tracing::Action &action) {
+ assert(getImpl().actionHandler);
+ getImpl().actionHandler(actionFn, action);
+}
+
+bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; }
+
//===----------------------------------------------------------------------===//
// Diagnostic Handlers
//===----------------------------------------------------------------------===//
context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects());
if (config.shouldVerifyDiagnostics())
context.printOpOnDiagnostic(false);
- context.getActionManager().registerActionHandler<tracing::DebugCounter>();
+ context.registerActionHandler(tracing::DebugCounter());
// If we are in verify diagnostics mode then we have a lot of work to do,
// otherwise just perform the actions without worrying about it.
};
TEST(DebugCounterTest, CounterTest) {
- std::unique_ptr<DebugCounter> counter = std::make_unique<DebugCounter>();
- counter->addCounter(CounterAction::tag, /*countToSkip=*/1,
- /*countToStopAfter=*/3);
+ DebugCounter counter;
+ counter.addCounter(CounterAction::tag, /*countToSkip=*/1,
+ /*countToStopAfter=*/3);
- ActionManager manager;
- manager.registerActionHandler(std::move(counter));
-
- auto noOp = []() { return; };
+ int count = 0;
+ auto noOp = [&]() {
+ ++count;
+ return;
+ };
// The first execution is skipped.
- EXPECT_FALSE(manager.execute<CounterAction>(noOp));
+ counter(noOp, CounterAction{});
+ EXPECT_EQ(count, 0);
// The counter stops after 3 successful executions.
- EXPECT_TRUE(manager.execute<CounterAction>(noOp));
- EXPECT_TRUE(manager.execute<CounterAction>(noOp));
- EXPECT_TRUE(manager.execute<CounterAction>(noOp));
- EXPECT_FALSE(manager.execute<CounterAction>(noOp));
+ counter(noOp, CounterAction{});
+ EXPECT_EQ(count, 1);
+ counter(noOp, CounterAction{});
+ EXPECT_EQ(count, 2);
+ counter(noOp, CounterAction{});
+ EXPECT_EQ(count, 3);
+ counter(noOp, CounterAction{});
+ EXPECT_EQ(count, 3);
}
} // namespace
+++ /dev/null
-//===- ActionTest.cpp - Debug Action Tests ---------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/IR/Action.h"
-#include "mlir/Support/TypeID.h"
-#include "gmock/gmock.h"
-
-using namespace mlir;
-using namespace mlir::tracing;
-
-namespace {
-struct SimpleAction : ActionImpl<SimpleAction> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SimpleAction)
- static constexpr StringLiteral tag = "simple-action";
-};
-struct OtherSimpleAction : ActionImpl<OtherSimpleAction> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OtherSimpleAction)
- static constexpr StringLiteral tag = "other-simple-action";
-};
-struct ParametricAction : ActionImpl<ParametricAction, bool> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ParametricAction)
- ParametricAction(bool executeParam) : executeParam(executeParam) {}
- bool executeParam;
- static constexpr StringLiteral tag = "param-action";
-};
-
-TEST(ActionTest, GenericHandler) {
- ActionManager manager;
-
- // A generic handler that always executes the simple action, but not the
- // parametric action.
- struct GenericHandler : ActionManager::GenericHandler {
- FailureOr<bool> execute(llvm::function_ref<void()> transform,
- const Action &action) final {
- StringRef tag = action.getTag();
- if (isa<SimpleAction>(action)) {
- EXPECT_EQ(tag, SimpleAction::tag);
- transform();
- return true;
- }
-
- EXPECT_TRUE(isa<ParametricAction>(action));
- return false;
- }
- };
- manager.registerActionHandler<GenericHandler>();
-
- auto noOp = []() { return; };
- EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
- EXPECT_FALSE(manager.execute<ParametricAction>(noOp, true));
-}
-
-TEST(ActionTest, ActionSpecificHandler) {
- ActionManager manager;
-
- // Handler that simply uses the input as the decider.
- struct ActionSpecificHandler : ParametricAction::Handler {
- FailureOr<bool> execute(llvm::function_ref<void()> transform,
- const ParametricAction &action) final {
- if (action.executeParam)
- transform();
- return action.executeParam;
- }
- };
- manager.registerActionHandler<ActionSpecificHandler>();
-
- int count = 0;
- auto incCount = [&]() { count++; };
- EXPECT_TRUE(manager.execute<ParametricAction>(incCount, true));
- EXPECT_EQ(count, 1);
- EXPECT_FALSE(manager.execute<ParametricAction>(incCount, false));
- EXPECT_EQ(count, 1);
-
- // There is no handler for the simple action, so it is always executed.
- EXPECT_TRUE(manager.execute<SimpleAction>(incCount));
- EXPECT_EQ(count, 2);
-}
-
-TEST(ActionTest, DebugCounterHandler) {
- ActionManager manager;
-
- // Handler that uses the number of action executions as the decider.
- struct DebugCounterHandler : SimpleAction::Handler {
- FailureOr<bool> execute(llvm::function_ref<void()> transform,
- const SimpleAction &action) final {
- bool shouldExecute = numExecutions++ < 3;
- if (shouldExecute)
- transform();
- return shouldExecute;
- }
- unsigned numExecutions = 0;
- };
- manager.registerActionHandler<DebugCounterHandler>();
-
- // Check that the action is executed 3 times, but no more after.
- auto noOp = []() { return; };
- EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
- EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
- EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
- EXPECT_FALSE(manager.execute<SimpleAction>(noOp));
- EXPECT_FALSE(manager.execute<SimpleAction>(noOp));
-}
-
-TEST(ActionTest, NonOverlappingActionSpecificHandlers) {
- ActionManager manager;
-
- // One handler returns true and another returns false
- struct SimpleActionHandler : SimpleAction::Handler {
- FailureOr<bool> execute(llvm::function_ref<void()> transform,
- const SimpleAction &action) final {
- transform();
- return true;
- }
- };
- struct OtherSimpleActionHandler : OtherSimpleAction::Handler {
- FailureOr<bool> execute(llvm::function_ref<void()> transform,
- const OtherSimpleAction &action) final {
- transform();
- return false;
- }
- };
- manager.registerActionHandler<SimpleActionHandler>();
- manager.registerActionHandler<OtherSimpleActionHandler>();
- auto noOp = []() { return; };
- EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
- EXPECT_FALSE(manager.execute<OtherSimpleAction>(noOp));
-}
-
-} // namespace
add_mlir_unittest(MLIRIRTests
- ActionTest.cpp
AdaptorTest.cpp
AttributeTest.cpp
DialectTest.cpp