Delete ActionManager and replace it with a simple callback on the Context
authorMehdi Amini <joker.eph@gmail.com>
Sat, 4 Feb 2023 22:16:35 +0000 (14:16 -0800)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 7 Mar 2023 07:25:34 +0000 (08:25 +0100)
The concept of the ActionManager acts as a sort of "Hub" that can receive
various types of action and dispatch them to a set of registered handlers.
One handler will handle the action or it'll cascade to other handlers.

This model does not really fit the current evolution of the Action tracing
and debugging: we can't foresee a good case where this behavior compose with
the use-case behind the handlers. Instead we simplify it with a single
callback installed on the Context.

Differential Revision: https://reviews.llvm.org/D144811

mlir/include/mlir/Debug/Counter.h
mlir/include/mlir/IR/Action.h
mlir/include/mlir/IR/MLIRContext.h
mlir/lib/Debug/DebugCounter.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
mlir/unittests/Debug/DebugCounterTest.cpp
mlir/unittests/IR/ActionTest.cpp [deleted file]
mlir/unittests/IR/CMakeLists.txt

index 2515599..622a87c 100644 (file)
@@ -26,10 +26,10 @@ namespace tracing {
 /// a counter for an action with `skip=47` and `count=2`, would skip the
 /// first 47 executions, then execute twice, and finally prevent any further
 /// executions.
-class DebugCounter : public ActionManager::GenericHandler {
+class DebugCounter {
 public:
   DebugCounter();
-  ~DebugCounter() override;
+  ~DebugCounter();
 
   /// Add a counter for the given action tag. `countToSkip` is the number
   /// of counter executions to skip before enabling execution of the action.
@@ -38,9 +38,8 @@ public:
   void addCounter(StringRef actionTag, int64_t countToSkip,
                   int64_t countToStopAfter);
 
-  /// Register a counter with the specified name.
-  FailureOr<bool> execute(llvm::function_ref<void()> transform,
-                          const Action &action) final;
+  /// Entry point for handling actions.
+  void operator()(llvm::function_ref<void()> transform, const Action &action);
 
   /// Print the counters that have been registered with this instance to the
   /// provided output stream.
@@ -50,6 +49,9 @@ public:
   static void registerCLOptions();
 
 private:
+  // Returns true if the next action matching this tag should be executed.
+  bool shouldExecute(StringRef tag);
+
   /// Apply the registered CL options to this debug counter instance.
   void applyCLOptions();
 
index f53b5ae..569d428 100644 (file)
@@ -31,9 +31,7 @@ namespace tracing {
 /// An action is a specific action that is to be taken by the compiler,
 /// that can be toggled and controlled by an external user. There are no
 /// constraints on the granularity of an action, it could be as simple as
-/// "perform this fold" and as complex as "run this pass pipeline". Via template
-/// parameters `ParameterTs`, a user may provide the set of argument types that
-/// are provided when handling a query on this action.
+/// "perform this fold" and as complex as "run this pass pipeline".
 ///
 /// This class represents the base class of the ActionImpl class (see below).
 /// This holds the template-invariant elements of the Action class.
@@ -60,151 +58,13 @@ protected:
   TypeID actionID;
 };
 
-//===----------------------------------------------------------------------===//
-// ActionManager
-//===----------------------------------------------------------------------===//
-
-/// This class represents manages actions, and orchestrates the
-/// communication between action queries and action handlers. An action handler
-/// is either an action specific handler, i.e. a derived class of
-/// `MyActionType::Handler`, or a generic handler, i.e. a derived class of
-/// `ActionManager::GenericHandler`. For more details on action specific
-/// handlers, see the definition of `Action::Handler` below. For more
-/// details on generic handlers, see `ActionManager::GenericHandler` below.
-class ActionManager {
-public:
-  //===--------------------------------------------------------------------===//
-  // Handlers
-  //===--------------------------------------------------------------------===//
-
-  /// This class represents the base class of an action handler.
-  class HandlerBase {
-  public:
-    virtual ~HandlerBase() = default;
-
-    /// Return the unique handler id of this handler, use for casting
-    /// functionality.
-    TypeID getHandlerID() const { return handlerID; }
-
-  protected:
-    HandlerBase(TypeID handlerID) : handlerID(handlerID) {}
-
-    /// The type of the derived handler class. This allows for detecting if a
-    /// handler can handle a given action type.
-    TypeID handlerID;
-  };
-
-  /// This class represents a generic action handler. A generic handler allows
-  /// for handling any action type. Handlers of this type are useful for
-  /// implementing general functionality that doesn't necessarily need to
-  /// interpret the exact action parameters, or can rely on an external
-  /// interpreter (such as the user). Given that these handlers are generic,
-  /// they take a set of opaque parameters that try to map the context of the
-  /// action type in a generic way.
-  class GenericHandler : public HandlerBase {
-  public:
-    GenericHandler() : HandlerBase(TypeID::get<GenericHandler>()) {}
-
-    /// This hook allows for controlling the execution of an action. It should
-    /// return failure if the handler could not process the action, or whether
-    /// the `transform` was executed or not.
-    virtual FailureOr<bool> execute(function_ref<void()> transform,
-                                    const Action &action) {
-      return failure();
-    }
-
-    /// Provide classof to allow casting between handler types.
-    static bool classof(const ActionManager::HandlerBase *handler) {
-      return handler->getHandlerID() == TypeID::get<GenericHandler>();
-    }
-  };
-
-  /// Register the given action handler with the manager.
-  void registerActionHandler(std::unique_ptr<HandlerBase> handler) {
-    actionHandlers.emplace_back(std::move(handler));
-  }
-  template <typename T>
-  void registerActionHandler() {
-    registerActionHandler(std::make_unique<T>());
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Action Queries
-  //===--------------------------------------------------------------------===//
-
-  /// Dispatch an action represented by the `transform` callback. If no handler
-  /// is found, the `transform` callback is invoked directly.
-  /// Return true if the action was executed, false otherwise.
-  template <typename ActionType, typename... Args>
-  bool execute(function_ref<void()> transform, Args &&...args) {
-    if (actionHandlers.empty()) {
-      transform();
-      return true;
-    }
-
-    // Invoke the `execute` method on the provided handler.
-    auto executeFn = [&](auto *handler, auto &&...handlerParams) {
-      return handler->execute(
-          transform,
-          ActionType(std::forward<decltype(handlerParams)>(handlerParams)...));
-    };
-    FailureOr<bool> result = dispatchToHandler<ActionType, bool>(
-        executeFn, std::forward<Args>(args)...);
-    // no handler found, execute the transform directly.
-    if (failed(result)) {
-      transform();
-      return true;
-    }
-
-    // Return the result of the handler.
-    return *result;
-  }
-
-private:
-  //===--------------------------------------------------------------------===//
-  // Query to Handler Dispatch
-  //===--------------------------------------------------------------------===//
-
-  /// Dispath a given callback on any handlers that are able to process queries
-  /// on the given action type. This method returns failure if no handlers could
-  /// process the action, or success(with a result) if a handler processed the
-  /// action.
-  template <typename ActionType, typename ResultT, typename HandlerCallbackT,
-            typename... Args>
-  FailureOr<ResultT> dispatchToHandler(HandlerCallbackT &&handlerCallback,
-                                       Args &&...args) {
-    static_assert(ActionType::template canHandleWith<Args...>(),
-                  "cannot execute action with the given set of parameters");
-
-    // Process any generic or action specific handlers.
-    // The first handler that gives us a result is the one that we will return.
-    for (std::unique_ptr<HandlerBase> &it : reverse(actionHandlers)) {
-      FailureOr<ResultT> result = failure();
-      if (auto *handler = dyn_cast<typename ActionType::Handler>(&*it)) {
-        result = handlerCallback(handler, std::forward<Args>(args)...);
-      } else if (auto *genericHandler = dyn_cast<GenericHandler>(&*it)) {
-        result = handlerCallback(genericHandler, std::forward<Args>(args)...);
-      }
-
-      // If the handler succeeded, return the result. Otherwise, try a new
-      // handler.
-      if (succeeded(result))
-        return result;
-    }
-    return failure();
-  }
-
-  /// The set of action handlers that have been registered with the manager.
-  SmallVector<std::unique_ptr<HandlerBase>> actionHandlers;
-};
-
 /// CRTP Implementation of an action. This class provides a base class for
 /// implementing specific actions.
 ///  Derived classes are expected to provide the following:
 ///   * static constexpr StringLiteral tag = "...";
 ///     - This method returns a unique string identifier, similar to a command
 ///       line flag or DEBUG_TYPE.
-template <typename Derived, typename... ParameterTs>
+template <typename Derived>
 class ActionImpl : public Action {
 public:
   ActionImpl() : Action(TypeID::get<Derived>()) {}
@@ -216,37 +76,6 @@ public:
 
   /// Forward tag access to the derived class.
   StringRef getTag() const final { return Derived::tag; }
-
-  class Handler : public ActionManager::HandlerBase {
-  public:
-    Handler() : HandlerBase(TypeID::get<Derived>()) {}
-
-    /// This hook allows for controlling the execution of an action.
-    /// `parameters` correspond to the set of values provided by the
-    /// action as context. It should return failure if the handler could not
-    /// process the action, passing it to the next registered handler.
-    virtual FailureOr<bool> execute(function_ref<void()> transform,
-                                    const Derived &action) {
-      return failure();
-    }
-
-    /// Provide classof to allow casting between handler types.
-    static bool classof(const ActionManager::HandlerBase *handler) {
-      return handler->getHandlerID() == TypeID::get<Derived>();
-    }
-  };
-
-private:
-  /// Returns true if the action can be handled within the given set of
-  /// parameter types.
-  template <typename... CallerParameterTs>
-  static constexpr bool canHandleWith() {
-    return std::is_invocable_v<function_ref<void(ParameterTs...)>,
-                               CallerParameterTs...>;
-  }
-
-  /// Allow access to `canHandleWith`.
-  friend class ActionManager;
 };
 
 } // namespace tracing
index 317e83c..cc13447 100644 (file)
@@ -21,7 +21,7 @@ class ThreadPool;
 
 namespace mlir {
 namespace tracing {
-class ActionManager;
+class Action;
 }
 class DiagnosticEngine;
 class Dialect;
@@ -217,9 +217,6 @@ public:
   /// instances. This should not be used directly.
   StorageUniquer &getAttributeUniquer();
 
-  /// Returns the manager of debug actions within the context.
-  tracing::ActionManager &getActionManager();
-
   /// These APIs are tracking whether the context will be used in a
   /// multithreading environment: this has no effect other than enabling
   /// assertions on misuses of some APIs.
@@ -242,10 +239,57 @@ public:
   /// (attributes, operations, types, etc.).
   llvm::hash_code getRegistryHash();
 
+  //===--------------------------------------------------------------------===//
+  // Action API
+  //===--------------------------------------------------------------------===//
+
+  /// Signatures for the action handler that can be registered with the context.
+  using HandlerTy =
+      std::function<void(function_ref<void()>, const tracing::Action &)>;
+
+  /// Register a handler for handling actions that are dispatched through this
+  /// context. A nullptr handler can be set to disable a previously set handler.
+  void registerActionHandler(HandlerTy handler);
+
+  /// Return true if a valid ActionHandler is set.
+  bool hasActionHandler();
+
+  /// Dispatch the provided action to the handler if any, or just execute it.
+  void executeAction(function_ref<void()> actionFn,
+                     const tracing::Action &action) {
+    if (LLVM_UNLIKELY(hasActionHandler()))
+      executeActionInternal(actionFn, action);
+    else
+      actionFn();
+  }
+
+  /// Dispatch the provided action to the handler if any, or just execute it.
+  template <typename ActionTy, typename... Args>
+  void executeAction(function_ref<void()> actionFn, Args &&...args) {
+    if (LLVM_UNLIKELY(hasActionHandler()))
+      executeActionInternal<ActionTy, Args...>(actionFn,
+                                               std::forward<Args>(args)...);
+    else
+      actionFn();
+  }
+
 private:
   /// Return true if the given dialect is currently loading.
   bool isDialectLoading(StringRef dialectNamespace);
 
+  /// Internal helper for the dispatch method.
+  void executeActionInternal(function_ref<void()> actionFn,
+                             const tracing::Action &action);
+
+  /// Internal helper for the dispatch method. We get here after checking that
+  /// there is a handler, for the purpose of keeping this code out-of-line. and
+  /// avoid calling the ctor for the Action unnecessarily.
+  template <typename ActionTy, typename... Args>
+  LLVM_ATTRIBUTE_NOINLINE void
+  executeActionInternal(function_ref<void()> actionFn, Args &&...args) {
+    executeActionInternal(actionFn, ActionTy(std::forward<Args>(args)...));
+  }
+
   const std::unique_ptr<MLIRContextImpl> impl;
 
   MLIRContext(const MLIRContext &) = delete;
index 195ce95..a8747f9 100644 (file)
@@ -62,10 +62,14 @@ void DebugCounter::addCounter(StringRef actionTag, int64_t countToSkip,
   counters.try_emplace(actionTag, countToSkip, countToStopAfter);
 }
 
-// Register a counter with the specified name.
-FailureOr<bool> DebugCounter::execute(llvm::function_ref<void()> transform,
-                                      const Action &action) {
-  auto counterIt = counters.find(action.getTag());
+void DebugCounter::operator()(llvm::function_ref<void()> transform,
+                              const Action &action) {
+  if (shouldExecute(action.getTag()))
+    transform();
+}
+
+bool DebugCounter::shouldExecute(StringRef tag) {
+  auto counterIt = counters.find(tag);
   if (counterIt == counters.end())
     return true;
 
index 291b1cf..29d49ed 100644 (file)
@@ -31,6 +31,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/Mutex.h"
 #include "llvm/Support/RWMutex.h"
@@ -123,8 +124,10 @@ public:
   // Debugging
   //===--------------------------------------------------------------------===//
 
-  /// An action manager for use within the context.
-  tracing::ActionManager actionManager;
+  /// An action handler for handling actions that are dispatched through this
+  /// context.
+  std::function<void(function_ref<void()>, const tracing::Action &)>
+      actionHandler;
 
   //===--------------------------------------------------------------------===//
   // Diagnostics
@@ -345,13 +348,22 @@ static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
 }
 
 //===----------------------------------------------------------------------===//
-// Debugging
+// Action Handling
 //===----------------------------------------------------------------------===//
 
-tracing::ActionManager &MLIRContext::getActionManager() {
-  return getImpl().actionManager;
+void MLIRContext::registerActionHandler(HandlerTy handler) {
+  getImpl().actionHandler = std::move(handler);
 }
 
+/// Dispatch the provided action to the handler if any, or just execute it.
+void MLIRContext::executeActionInternal(function_ref<void()> actionFn,
+                                        const tracing::Action &action) {
+  assert(getImpl().actionHandler);
+  getImpl().actionHandler(actionFn, action);
+}
+
+bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; }
+
 //===----------------------------------------------------------------------===//
 // Diagnostic Handlers
 //===----------------------------------------------------------------------===//
index 03fcd95..fb4fbee 100644 (file)
@@ -212,7 +212,7 @@ static LogicalResult processBuffer(raw_ostream &os,
   context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects());
   if (config.shouldVerifyDiagnostics())
     context.printOpOnDiagnostic(false);
-  context.getActionManager().registerActionHandler<tracing::DebugCounter>();
+  context.registerActionHandler(tracing::DebugCounter());
 
   // If we are in verify diagnostics mode then we have a lot of work to do,
   // otherwise just perform the actions without worrying about it.
index 4ba7d3c..eeeb538 100644 (file)
@@ -21,23 +21,29 @@ struct CounterAction : public ActionImpl<CounterAction> {
 };
 
 TEST(DebugCounterTest, CounterTest) {
-  std::unique_ptr<DebugCounter> counter = std::make_unique<DebugCounter>();
-  counter->addCounter(CounterAction::tag, /*countToSkip=*/1,
-                      /*countToStopAfter=*/3);
+  DebugCounter counter;
+  counter.addCounter(CounterAction::tag, /*countToSkip=*/1,
+                     /*countToStopAfter=*/3);
 
-  ActionManager manager;
-  manager.registerActionHandler(std::move(counter));
-
-  auto noOp = []() { return; };
+  int count = 0;
+  auto noOp = [&]() {
+    ++count;
+    return;
+  };
 
   // The first execution is skipped.
-  EXPECT_FALSE(manager.execute<CounterAction>(noOp));
+  counter(noOp, CounterAction{});
+  EXPECT_EQ(count, 0);
 
   // The counter stops after 3 successful executions.
-  EXPECT_TRUE(manager.execute<CounterAction>(noOp));
-  EXPECT_TRUE(manager.execute<CounterAction>(noOp));
-  EXPECT_TRUE(manager.execute<CounterAction>(noOp));
-  EXPECT_FALSE(manager.execute<CounterAction>(noOp));
+  counter(noOp, CounterAction{});
+  EXPECT_EQ(count, 1);
+  counter(noOp, CounterAction{});
+  EXPECT_EQ(count, 2);
+  counter(noOp, CounterAction{});
+  EXPECT_EQ(count, 3);
+  counter(noOp, CounterAction{});
+  EXPECT_EQ(count, 3);
 }
 
 } // namespace
diff --git a/mlir/unittests/IR/ActionTest.cpp b/mlir/unittests/IR/ActionTest.cpp
deleted file mode 100644 (file)
index 5432f65..0000000
+++ /dev/null
@@ -1,134 +0,0 @@
-//===- ActionTest.cpp - Debug Action Tests ---------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/IR/Action.h"
-#include "mlir/Support/TypeID.h"
-#include "gmock/gmock.h"
-
-using namespace mlir;
-using namespace mlir::tracing;
-
-namespace {
-struct SimpleAction : ActionImpl<SimpleAction> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SimpleAction)
-  static constexpr StringLiteral tag = "simple-action";
-};
-struct OtherSimpleAction : ActionImpl<OtherSimpleAction> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OtherSimpleAction)
-  static constexpr StringLiteral tag = "other-simple-action";
-};
-struct ParametricAction : ActionImpl<ParametricAction, bool> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ParametricAction)
-  ParametricAction(bool executeParam) : executeParam(executeParam) {}
-  bool executeParam;
-  static constexpr StringLiteral tag = "param-action";
-};
-
-TEST(ActionTest, GenericHandler) {
-  ActionManager manager;
-
-  // A generic handler that always executes the simple action, but not the
-  // parametric action.
-  struct GenericHandler : ActionManager::GenericHandler {
-    FailureOr<bool> execute(llvm::function_ref<void()> transform,
-                            const Action &action) final {
-      StringRef tag = action.getTag();
-      if (isa<SimpleAction>(action)) {
-        EXPECT_EQ(tag, SimpleAction::tag);
-        transform();
-        return true;
-      }
-
-      EXPECT_TRUE(isa<ParametricAction>(action));
-      return false;
-    }
-  };
-  manager.registerActionHandler<GenericHandler>();
-
-  auto noOp = []() { return; };
-  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
-  EXPECT_FALSE(manager.execute<ParametricAction>(noOp, true));
-}
-
-TEST(ActionTest, ActionSpecificHandler) {
-  ActionManager manager;
-
-  // Handler that simply uses the input as the decider.
-  struct ActionSpecificHandler : ParametricAction::Handler {
-    FailureOr<bool> execute(llvm::function_ref<void()> transform,
-                            const ParametricAction &action) final {
-      if (action.executeParam)
-        transform();
-      return action.executeParam;
-    }
-  };
-  manager.registerActionHandler<ActionSpecificHandler>();
-
-  int count = 0;
-  auto incCount = [&]() { count++; };
-  EXPECT_TRUE(manager.execute<ParametricAction>(incCount, true));
-  EXPECT_EQ(count, 1);
-  EXPECT_FALSE(manager.execute<ParametricAction>(incCount, false));
-  EXPECT_EQ(count, 1);
-
-  // There is no handler for the simple action, so it is always executed.
-  EXPECT_TRUE(manager.execute<SimpleAction>(incCount));
-  EXPECT_EQ(count, 2);
-}
-
-TEST(ActionTest, DebugCounterHandler) {
-  ActionManager manager;
-
-  // Handler that uses the number of action executions as the decider.
-  struct DebugCounterHandler : SimpleAction::Handler {
-    FailureOr<bool> execute(llvm::function_ref<void()> transform,
-                            const SimpleAction &action) final {
-      bool shouldExecute = numExecutions++ < 3;
-      if (shouldExecute)
-        transform();
-      return shouldExecute;
-    }
-    unsigned numExecutions = 0;
-  };
-  manager.registerActionHandler<DebugCounterHandler>();
-
-  // Check that the action is executed 3 times, but no more after.
-  auto noOp = []() { return; };
-  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
-  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
-  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
-  EXPECT_FALSE(manager.execute<SimpleAction>(noOp));
-  EXPECT_FALSE(manager.execute<SimpleAction>(noOp));
-}
-
-TEST(ActionTest, NonOverlappingActionSpecificHandlers) {
-  ActionManager manager;
-
-  // One handler returns true and another returns false
-  struct SimpleActionHandler : SimpleAction::Handler {
-    FailureOr<bool> execute(llvm::function_ref<void()> transform,
-                            const SimpleAction &action) final {
-      transform();
-      return true;
-    }
-  };
-  struct OtherSimpleActionHandler : OtherSimpleAction::Handler {
-    FailureOr<bool> execute(llvm::function_ref<void()> transform,
-                            const OtherSimpleAction &action) final {
-      transform();
-      return false;
-    }
-  };
-  manager.registerActionHandler<SimpleActionHandler>();
-  manager.registerActionHandler<OtherSimpleActionHandler>();
-  auto noOp = []() { return; };
-  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
-  EXPECT_FALSE(manager.execute<OtherSimpleAction>(noOp));
-}
-
-} // namespace
index 04a86bc..7d49283 100644 (file)
@@ -1,5 +1,4 @@
 add_mlir_unittest(MLIRIRTests
-  ActionTest.cpp
   AdaptorTest.cpp
   AttributeTest.cpp
   DialectTest.cpp