Fix #58322: Handlers for debug actions with equal parameter types must not override...
authorTomás Longeri <tlongeri@google.com>
Sat, 22 Oct 2022 21:17:49 +0000 (14:17 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Sat, 22 Oct 2022 21:18:00 +0000 (14:18 -0700)
Also clean up redundant public access specifiers.

Reviewed By: mehdi_amini, rriddle

Differential Revision: https://reviews.llvm.org/D135924

mlir/docs/DebugActions.md
mlir/include/mlir/Support/DebugAction.h
mlir/unittests/Support/DebugActionTest.cpp
mlir/unittests/Support/DebugCounterTest.cpp

index 777be8e..a40459c 100644 (file)
@@ -54,10 +54,12 @@ rewrite patterns.
 /// * The Tag is specified via a static `StringRef getTag()` method.
 /// * The Description is specified via a static `StringRef getDescription()`
 ///   method.
-/// * The parameters for the action are provided via template parameters when
-///   inheriting from `DebugAction`.
+/// * `DebugAction` is a CRTP class, so the first template parameter is the
+///   action type class itself.
+/// * The parameters for the action are provided via additional template
+///   parameters when inheriting from `DebugAction`.
 struct ApplyPatternAction
-    : public DebugAction<Operation *, const Pattern &> {
+    : public DebugAction<ApplyPatternAction, Operation *, const Pattern &> {
   static StringRef getTag() { return "apply-pattern"; }
   static StringRef getDescription() {
     return "Control the application of rewrite patterns";
@@ -95,7 +97,7 @@ usage of the `shouldExecute` query is shown below:
 ```c++
 /// A debug action that allows for controlling the application of patterns.
 struct ApplyPatternAction
-    : public DebugAction<Operation *, const Pattern &> {
+    : public DebugAction<ApplyPatternAction, Operation *, const Pattern &> {
   static StringRef getTag() { return "apply-pattern"; }
   static StringRef getDescription() {
     return "Control the application of rewrite patterns";
index 41ec8b1..e1dc25e 100644 (file)
@@ -190,14 +190,12 @@ private:
 /// This class provides a handler class that can be derived from to handle
 /// instances of this action. The parameters to its query methods map 1-1 to the
 /// types on the action type.
-template <typename... ParameterTs>
+template <typename Derived, typename... ParameterTs>
 class DebugAction {
 public:
   class Handler : public DebugActionManager::HandlerBase {
   public:
-    Handler()
-        : HandlerBase(
-              TypeID::get<typename DebugAction<ParameterTs...>::Handler>()) {}
+    Handler() : HandlerBase(TypeID::get<Derived>()) {}
 
     /// This hook allows for controlling whether an action should execute or
     /// not. `parameters` correspond to the set of values provided by the
@@ -209,8 +207,7 @@ public:
 
     /// Provide classof to allow casting between handler types.
     static bool classof(const DebugActionManager::HandlerBase *handler) {
-      return handler->getHandlerID() ==
-             TypeID::get<typename DebugAction<ParameterTs...>::Handler>();
+      return handler->getHandlerID() == TypeID::get<Derived>();
     }
   };
 
index 4ecaf12..be2ca2e 100644 (file)
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Support/DebugAction.h"
+#include "mlir/Support/TypeID.h"
 #include "gmock/gmock.h"
 
 // DebugActionManager is only enabled in DEBUG mode.
 using namespace mlir;
 
 namespace {
-struct SimpleAction : public DebugAction<> {
+struct SimpleAction : DebugAction<SimpleAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SimpleAction)
   static StringRef getTag() { return "simple-action"; }
   static StringRef getDescription() { return "simple-action-description"; }
 };
-struct ParametricAction : public DebugAction<bool> {
+struct OtherSimpleAction : DebugAction<OtherSimpleAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OtherSimpleAction)
+  static StringRef getTag() { return "other-simple-action"; }
+  static StringRef getDescription() {
+    return "other-simple-action-description";
+  }
+};
+struct ParametricAction : DebugAction<ParametricAction, bool> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ParametricAction)
   static StringRef getTag() { return "param-action"; }
   static StringRef getDescription() { return "param-action-description"; }
 };
@@ -29,7 +39,7 @@ TEST(DebugActionTest, GenericHandler) {
 
   // A generic handler that always executes the simple action, but not the
   // parametric action.
-  struct GenericHandler : public DebugActionManager::GenericHandler {
+  struct GenericHandler : DebugActionManager::GenericHandler {
     FailureOr<bool> shouldExecute(StringRef tag, StringRef desc) final {
       if (tag == SimpleAction::getTag()) {
         EXPECT_EQ(desc, SimpleAction::getDescription());
@@ -51,7 +61,7 @@ TEST(DebugActionTest, ActionSpecificHandler) {
   DebugActionManager manager;
 
   // Handler that simply uses the input as the decider.
-  struct ActionSpecificHandler : public ParametricAction::Handler {
+  struct ActionSpecificHandler : ParametricAction::Handler {
     FailureOr<bool> shouldExecute(bool shouldExecuteParam) final {
       return shouldExecuteParam;
     }
@@ -69,7 +79,7 @@ TEST(DebugActionTest, DebugCounterHandler) {
   DebugActionManager manager;
 
   // Handler that uses the number of action executions as the decider.
-  struct DebugCounterHandler : public SimpleAction::Handler {
+  struct DebugCounterHandler : SimpleAction::Handler {
     FailureOr<bool> shouldExecute() final { return numExecutions++ < 3; }
     unsigned numExecutions = 0;
   };
@@ -83,6 +93,22 @@ TEST(DebugActionTest, DebugCounterHandler) {
   EXPECT_FALSE(manager.shouldExecute<SimpleAction>());
 }
 
+TEST(DebugActionTest, NonOverlappingActionSpecificHandlers) {
+  DebugActionManager manager;
+
+  // One handler returns true and another returns false
+  struct SimpleActionHandler : SimpleAction::Handler {
+    FailureOr<bool> shouldExecute() final { return true; }
+  };
+  struct OtherSimpleActionHandler : OtherSimpleAction::Handler {
+    FailureOr<bool> shouldExecute() final { return false; }
+  };
+  manager.registerActionHandler<SimpleActionHandler>();
+  manager.registerActionHandler<OtherSimpleActionHandler>();
+  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
+  EXPECT_FALSE(manager.shouldExecute<OtherSimpleAction>());
+}
+
 } // namespace
 
 #endif
index 7ca48cd..bf8d027 100644 (file)
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Support/DebugCounter.h"
+#include "mlir/Support/TypeID.h"
 #include "gmock/gmock.h"
 
 using namespace mlir;
@@ -16,7 +17,8 @@ using namespace mlir;
 
 namespace {
 
-struct CounterAction : public DebugAction<> {
+struct CounterAction : public DebugAction<CounterAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CounterAction)
   static StringRef getTag() { return "counter-action"; }
   static StringRef getDescription() { return "Test action for debug counters"; }
 };