Change the DebugAction paradigm to delegate the control to the handler
authorMehdi Amini <joker.eph@gmail.com>
Tue, 18 Oct 2022 23:03:48 +0000 (23:03 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 Mar 2023 14:58:26 +0000 (15:58 +0100)
At the moment, we invoke `shouldExecute()` that way:

```
if (manager.shouldExecute<DebugAction>(currentOp) {
  // apply a transformation
  …
}
```

In this sequence, the manager isn’t involved in the actual execution
of the action and can’t develop rich instrumentations. Instead the API
could let the control to the handler itself:

```
// Execute the action under the control of the manager
manager.execute<DebugAction>(currentOp, [&]() {
  // apply the transformation in this callback
  …
});
```

This inversion of control (by injecting a callback) allows handlers to
implement potentially new interesting features: for example, snapshot
the IR before and after the action, or record an action execution time.
More importantly, it will allow to capture the nesting execution of
actions.

On the other side: handlers receives now a DebugAction object that wraps
generic information (tag and description especially) as well as
action-specific data.

Finally, the DebugActionManager is now enabled in release builds as
well.

Differential Revision: https://reviews.llvm.org/D144808

mlir/include/mlir/Support/DebugAction.h
mlir/include/mlir/Support/DebugCounter.h
mlir/lib/Support/DebugCounter.cpp
mlir/unittests/Support/DebugActionTest.cpp
mlir/unittests/Support/DebugCounterTest.cpp

index e1dc25e..aa44028 100644 (file)
@@ -8,9 +8,7 @@
 //
 // This file contains definitions for the debug action framework. This framework
 // allows for external entities to control certain actions taken by the compiler
-// by registering handler functions. A debug action handler provides the
-// internal implementation for the various queries on a debug action, such as
-// whether it should execute or not.
+// by registering handler functions.
 //
 //===----------------------------------------------------------------------===//
 
 
 namespace mlir {
 
+/// This class represents the base class of a debug action.
+class DebugActionBase {
+public:
+  virtual ~DebugActionBase() = default;
+
+  /// Return the unique action id of this action, use for casting
+  /// functionality.
+  TypeID getActionID() const { return actionID; }
+
+  StringRef getTag() const { return tag; }
+
+  StringRef getDescription() const { return desc; }
+
+  virtual void print(raw_ostream &os) const {
+    os << "Action \"" << tag << "\" : " << desc << "\n";
+  }
+
+protected:
+  DebugActionBase(TypeID actionID, StringRef tag, StringRef desc)
+      : actionID(actionID), tag(tag), desc(desc) {}
+
+  /// The type of the derived action class. This allows for detecting the
+  /// specific handler of a given action type.
+  TypeID actionID;
+  StringRef tag;
+  StringRef desc;
+};
+
 //===----------------------------------------------------------------------===//
 // DebugActionManager
 //===----------------------------------------------------------------------===//
@@ -74,11 +100,11 @@ public:
   public:
     GenericHandler() : HandlerBase(TypeID::get<GenericHandler>()) {}
 
-    /// This hook allows for controlling whether an action should execute or
-    /// not. It should return failure if the handler could not process the
-    /// action, passing it to the next registered handler.
-    virtual FailureOr<bool> shouldExecute(StringRef actionTag,
-                                          StringRef description) {
+    /// This hook allows for controlling the execution of an action. It should
+    /// return failure if the handler could not process the action, or whether
+    /// the `transform` was executed or not.
+    virtual FailureOr<bool> execute(function_ref<void()> transform,
+                                    const DebugActionBase &action) {
       return failure();
     }
 
@@ -90,10 +116,7 @@ public:
 
   /// Register the given action handler with the manager.
   void registerActionHandler(std::unique_ptr<HandlerBase> handler) {
-    // The manager is always disabled if built without debug.
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
     actionHandlers.emplace_back(std::move(handler));
-#endif
   }
   template <typename T>
   void registerActionHandler() {
@@ -104,31 +127,35 @@ public:
   // Action Queries
   //===--------------------------------------------------------------------===//
 
-  /// Returns true if the given action type should be executed, false otherwise.
-  /// `Args` are a set of parameters used by handlers of `ActionType` to
-  /// determine if the action should be executed.
+  /// Dispatch an action represented by the `transform` callback. If no handler
+  /// is found, the `transform` callback is invoked directly.
+  /// Return true if the action was executed, false otherwise.
   template <typename ActionType, typename... Args>
-  bool shouldExecute(Args &&...args) {
-    // The manager is always disabled if built without debug.
-#if !LLVM_ENABLE_ABI_BREAKING_CHECKS
-    return true;
-#else
-    // Invoke the `shouldExecute` method on the provided handler.
-    auto shouldExecuteFn = [&](auto *handler, auto &&...handlerParams) {
-      return handler->shouldExecute(
-          std::forward<decltype(handlerParams)>(handlerParams)...);
+  bool execute(function_ref<void()> transform, Args &&...args) {
+    if (actionHandlers.empty()) {
+      transform();
+      return true;
+    }
+
+    // Invoke the `execute` method on the provided handler.
+    auto executeFn = [&](auto *handler, auto &&...handlerParams) {
+      return handler->execute(
+          transform,
+          ActionType(std::forward<decltype(handlerParams)>(handlerParams)...));
     };
     FailureOr<bool> result = dispatchToHandler<ActionType, bool>(
-        shouldExecuteFn, std::forward<Args>(args)...);
+        executeFn, std::forward<Args>(args)...);
+    // no handler found, execute the transform directly.
+    if (failed(result)) {
+      transform();
+      return true;
+    }
 
-    // If the action wasn't handled, execute the action by default.
-    return succeeded(result) ? *result : true;
-#endif
+    // Return the result of the handler.
+    return *result;
   }
 
 private:
-// The manager is always disabled if built without debug.
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
   //===--------------------------------------------------------------------===//
   // Query to Handler Dispatch
   //===--------------------------------------------------------------------===//
@@ -145,16 +172,13 @@ private:
                   "cannot execute action with the given set of parameters");
 
     // Process any generic or action specific handlers.
-    // TODO: We currently just pick the first handler that gives us a result,
-    // but in the future we may want to employ a reduction over all of the
-    // values returned.
-    for (std::unique_ptr<HandlerBase> &it : llvm::reverse(actionHandlers)) {
+    // The first handler that gives us a result is the one that we will return.
+    for (std::unique_ptr<HandlerBase> &it : reverse(actionHandlers)) {
       FailureOr<ResultT> result = failure();
       if (auto *handler = dyn_cast<typename ActionType::Handler>(&*it)) {
         result = handlerCallback(handler, std::forward<Args>(args)...);
       } else if (auto *genericHandler = dyn_cast<GenericHandler>(&*it)) {
-        result = handlerCallback(genericHandler, ActionType::getTag(),
-                                 ActionType::getDescription());
+        result = handlerCallback(genericHandler, std::forward<Args>(args)...);
       }
 
       // If the handler succeeded, return the result. Otherwise, try a new
@@ -167,7 +191,6 @@ private:
 
   /// The set of action handlers that have been registered with the manager.
   SmallVector<std::unique_ptr<HandlerBase>> actionHandlers;
-#endif
 };
 
 //===----------------------------------------------------------------------===//
@@ -191,17 +214,27 @@ private:
 /// instances of this action. The parameters to its query methods map 1-1 to the
 /// types on the action type.
 template <typename Derived, typename... ParameterTs>
-class DebugAction {
+class DebugAction : public DebugActionBase {
 public:
+  DebugAction()
+      : DebugActionBase(TypeID::get<Derived>(), Derived::getTag(),
+                        Derived::getDescription()) {}
+
+  /// Provide classof to allow casting between action types.
+  static bool classof(const DebugActionBase *action) {
+    return action->getActionID() == TypeID::get<Derived>();
+  }
+
   class Handler : public DebugActionManager::HandlerBase {
   public:
     Handler() : HandlerBase(TypeID::get<Derived>()) {}
 
-    /// This hook allows for controlling whether an action should execute or
-    /// not. `parameters` correspond to the set of values provided by the
+    /// This hook allows for controlling the execution of an action.
+    /// `parameters` correspond to the set of values provided by the
     /// action as context. It should return failure if the handler could not
     /// process the action, passing it to the next registered handler.
-    virtual FailureOr<bool> shouldExecute(ParameterTs... parameters) {
+    virtual FailureOr<bool> execute(function_ref<void()> transform,
+                                    const Derived &action) {
       return failure();
     }
 
index 5b6c015..83fd69d 100644 (file)
@@ -38,7 +38,8 @@ public:
                   int64_t countToStopAfter);
 
   /// Register a counter with the specified name.
-  FailureOr<bool> shouldExecute(StringRef tag, StringRef description) final;
+  FailureOr<bool> execute(llvm::function_ref<void()> transform,
+                          const DebugActionBase &action) final;
 
   /// Print the counters that have been registered with this instance to the
   /// provided output stream.
index 44bcdf4..a587b43 100644 (file)
@@ -62,9 +62,9 @@ void DebugCounter::addCounter(StringRef actionTag, int64_t countToSkip,
 }
 
 // Register a counter with the specified name.
-FailureOr<bool> DebugCounter::shouldExecute(StringRef tag,
-                                            StringRef description) {
-  auto counterIt = counters.find(tag);
+FailureOr<bool> DebugCounter::execute(llvm::function_ref<void()> transform,
+                                      const DebugActionBase &action) {
+  auto counterIt = counters.find(action.getTag());
   if (counterIt == counters.end())
     return true;
 
index be2ca2e..1e80a0e 100644 (file)
@@ -10,9 +10,6 @@
 #include "mlir/Support/TypeID.h"
 #include "gmock/gmock.h"
 
-// DebugActionManager is only enabled in DEBUG mode.
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-
 using namespace mlir;
 
 namespace {
@@ -30,6 +27,8 @@ struct OtherSimpleAction : DebugAction<OtherSimpleAction> {
 };
 struct ParametricAction : DebugAction<ParametricAction, bool> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ParametricAction)
+  ParametricAction(bool executeParam) : executeParam(executeParam) {}
+  bool executeParam;
   static StringRef getTag() { return "param-action"; }
   static StringRef getDescription() { return "param-action-description"; }
 };
@@ -40,21 +39,25 @@ TEST(DebugActionTest, GenericHandler) {
   // A generic handler that always executes the simple action, but not the
   // parametric action.
   struct GenericHandler : DebugActionManager::GenericHandler {
-    FailureOr<bool> shouldExecute(StringRef tag, StringRef desc) final {
-      if (tag == SimpleAction::getTag()) {
+    FailureOr<bool> execute(llvm::function_ref<void()> transform,
+                            const DebugActionBase &action) final {
+      StringRef desc = action.getDescription();
+      if (isa<SimpleAction>(action)) {
         EXPECT_EQ(desc, SimpleAction::getDescription());
+        transform();
         return true;
       }
 
-      EXPECT_EQ(tag, ParametricAction::getTag());
+      EXPECT_TRUE(isa<ParametricAction>(action));
       EXPECT_EQ(desc, ParametricAction::getDescription());
       return false;
     }
   };
   manager.registerActionHandler<GenericHandler>();
 
-  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
-  EXPECT_FALSE(manager.shouldExecute<ParametricAction>(true));
+  auto noOp = []() { return; };
+  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
+  EXPECT_FALSE(manager.execute<ParametricAction>(noOp, true));
 }
 
 TEST(DebugActionTest, ActionSpecificHandler) {
@@ -62,17 +65,25 @@ TEST(DebugActionTest, ActionSpecificHandler) {
 
   // Handler that simply uses the input as the decider.
   struct ActionSpecificHandler : ParametricAction::Handler {
-    FailureOr<bool> shouldExecute(bool shouldExecuteParam) final {
-      return shouldExecuteParam;
+    FailureOr<bool> execute(llvm::function_ref<void()> transform,
+                            const ParametricAction &action) final {
+      if (action.executeParam)
+        transform();
+      return action.executeParam;
     }
   };
   manager.registerActionHandler<ActionSpecificHandler>();
 
-  EXPECT_TRUE(manager.shouldExecute<ParametricAction>(true));
-  EXPECT_FALSE(manager.shouldExecute<ParametricAction>(false));
+  int count = 0;
+  auto incCount = [&]() { count++; };
+  EXPECT_TRUE(manager.execute<ParametricAction>(incCount, true));
+  EXPECT_EQ(count, 1);
+  EXPECT_FALSE(manager.execute<ParametricAction>(incCount, false));
+  EXPECT_EQ(count, 1);
 
   // There is no handler for the simple action, so it is always executed.
-  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
+  EXPECT_TRUE(manager.execute<SimpleAction>(incCount));
+  EXPECT_EQ(count, 2);
 }
 
 TEST(DebugActionTest, DebugCounterHandler) {
@@ -80,17 +91,24 @@ TEST(DebugActionTest, DebugCounterHandler) {
 
   // Handler that uses the number of action executions as the decider.
   struct DebugCounterHandler : SimpleAction::Handler {
-    FailureOr<bool> shouldExecute() final { return numExecutions++ < 3; }
+    FailureOr<bool> execute(llvm::function_ref<void()> transform,
+                            const SimpleAction &action) final {
+      bool shouldExecute = numExecutions++ < 3;
+      if (shouldExecute)
+        transform();
+      return shouldExecute;
+    }
     unsigned numExecutions = 0;
   };
   manager.registerActionHandler<DebugCounterHandler>();
 
   // Check that the action is executed 3 times, but no more after.
-  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
-  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
-  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
-  EXPECT_FALSE(manager.shouldExecute<SimpleAction>());
-  EXPECT_FALSE(manager.shouldExecute<SimpleAction>());
+  auto noOp = []() { return; };
+  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
+  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
+  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
+  EXPECT_FALSE(manager.execute<SimpleAction>(noOp));
+  EXPECT_FALSE(manager.execute<SimpleAction>(noOp));
 }
 
 TEST(DebugActionTest, NonOverlappingActionSpecificHandlers) {
@@ -98,17 +116,24 @@ TEST(DebugActionTest, NonOverlappingActionSpecificHandlers) {
 
   // One handler returns true and another returns false
   struct SimpleActionHandler : SimpleAction::Handler {
-    FailureOr<bool> shouldExecute() final { return true; }
+    FailureOr<bool> execute(llvm::function_ref<void()> transform,
+                            const SimpleAction &action) final {
+      transform();
+      return true;
+    }
   };
   struct OtherSimpleActionHandler : OtherSimpleAction::Handler {
-    FailureOr<bool> shouldExecute() final { return false; }
+    FailureOr<bool> execute(llvm::function_ref<void()> transform,
+                            const OtherSimpleAction &action) final {
+      transform();
+      return false;
+    }
   };
   manager.registerActionHandler<SimpleActionHandler>();
   manager.registerActionHandler<OtherSimpleActionHandler>();
-  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
-  EXPECT_FALSE(manager.shouldExecute<OtherSimpleAction>());
+  auto noOp = []() { return; };
+  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
+  EXPECT_FALSE(manager.execute<OtherSimpleAction>(noOp));
 }
 
 } // namespace
-
-#endif
index bf8d027..c46550b 100644 (file)
@@ -12,9 +12,6 @@
 
 using namespace mlir;
 
-// DebugActionManager is only enabled in DEBUG mode.
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-
 namespace {
 
 struct CounterAction : public DebugAction<CounterAction> {
@@ -31,16 +28,16 @@ TEST(DebugCounterTest, CounterTest) {
   DebugActionManager manager;
   manager.registerActionHandler(std::move(counter));
 
+  auto noOp = []() { return; };
+
   // The first execution is skipped.
-  EXPECT_FALSE(manager.shouldExecute<CounterAction>());
+  EXPECT_FALSE(manager.execute<CounterAction>(noOp));
 
   // The counter stops after 3 successful executions.
-  EXPECT_TRUE(manager.shouldExecute<CounterAction>());
-  EXPECT_TRUE(manager.shouldExecute<CounterAction>());
-  EXPECT_TRUE(manager.shouldExecute<CounterAction>());
-  EXPECT_FALSE(manager.shouldExecute<CounterAction>());
+  EXPECT_TRUE(manager.execute<CounterAction>(noOp));
+  EXPECT_TRUE(manager.execute<CounterAction>(noOp));
+  EXPECT_TRUE(manager.execute<CounterAction>(noOp));
+  EXPECT_FALSE(manager.execute<CounterAction>(noOp));
 }
 
 } // namespace
-
-#endif