Add capture of "IRUnits" as context for an MLIR Action
authorMehdi Amini <joker.eph@gmail.com>
Mon, 20 Mar 2023 12:40:37 +0000 (13:40 +0100)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 Mar 2023 12:40:55 +0000 (13:40 +0100)
IRUnit is defined as:

  using IRUnit = PointerUnion<Operation *, Region *, Block *, Value>;

The tracing::Action is extended to take an ArrayRef<IRUnit> as context to
describe an Action. It is demonstrated in the "ActionLogging" observer.

Reviewed By: rriddle, Mogball

Differential Revision: https://reviews.llvm.org/D144814

mlir/include/mlir/Debug/Observers/ActionLogging.h
mlir/include/mlir/IR/Action.h
mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/IR/Unit.h [new file with mode: 0644]
mlir/lib/Debug/Observers/ActionLogging.cpp
mlir/lib/IR/CMakeLists.txt
mlir/lib/IR/Unit.cpp [new file with mode: 0644]
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassDetail.h
mlir/test/Pass/action-logging.mlir

index ff280c5..bd1d565 100644 (file)
@@ -22,9 +22,9 @@ namespace tracing {
 /// on the provided stream.
 struct ActionLogger : public ExecutionContext::Observer {
   ActionLogger(raw_ostream &os, bool printActions = true,
-               bool printBreakpoints = true)
-      : os(os), printActions(printActions), printBreakpoints(printBreakpoints) {
-  }
+               bool printBreakpoints = true, bool printIRUnits = true)
+      : os(os), printActions(printActions), printBreakpoints(printBreakpoints),
+        printIRUnits(printIRUnits) {}
 
   void beforeExecute(const ActionActiveStack *action, Breakpoint *breakpoint,
                      bool willExecute) override;
@@ -34,6 +34,7 @@ private:
   raw_ostream &os;
   bool printActions;
   bool printBreakpoints;
+  bool printIRUnits;
 };
 
 } // namespace tracing
index 569d428..9359324 100644 (file)
@@ -15,6 +15,7 @@
 #ifndef MLIR_IR_ACTION_H
 #define MLIR_IR_ACTION_H
 
+#include "mlir/IR/Unit.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -51,11 +52,19 @@ public:
     os << "Action \"" << getTag() << "\"";
   }
 
+  /// Return the set of IR units that are associated with this action.
+  virtual ArrayRef<IRUnit> getContextIRUnits() const { return irUnits; }
+
 protected:
-  Action(TypeID actionID) : actionID(actionID) {}
+  Action(TypeID actionID, ArrayRef<IRUnit> irUnits)
+      : actionID(actionID), irUnits(irUnits) {}
 
   /// The type of the derived action class, used for `isa`/`dyn_cast`.
   TypeID actionID;
+
+  /// Set of IR units (operations, regions, blocks, values) that are associated
+  /// with this action.
+  ArrayRef<IRUnit> irUnits;
 };
 
 /// CRTP Implementation of an action. This class provides a base class for
@@ -67,7 +76,8 @@ protected:
 template <typename Derived>
 class ActionImpl : public Action {
 public:
-  ActionImpl() : Action(TypeID::get<Derived>()) {}
+  ActionImpl(ArrayRef<IRUnit> irUnits = {})
+      : Action(TypeID::get<Derived>(), irUnits) {}
 
   /// Provide classof to allow casting between action types.
   static bool classof(const Action *action) {
index cc13447..d9e140b 100644 (file)
@@ -11,6 +11,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/TypeID.h"
+#include "llvm/ADT/ArrayRef.h"
 #include <functional>
 #include <memory>
 #include <vector>
@@ -265,9 +266,10 @@ public:
 
   /// Dispatch the provided action to the handler if any, or just execute it.
   template <typename ActionTy, typename... Args>
-  void executeAction(function_ref<void()> actionFn, Args &&...args) {
+  void executeAction(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
+                     Args &&...args) {
     if (LLVM_UNLIKELY(hasActionHandler()))
-      executeActionInternal<ActionTy, Args...>(actionFn,
+      executeActionInternal<ActionTy, Args...>(actionFn, irUnits,
                                                std::forward<Args>(args)...);
     else
       actionFn();
@@ -286,8 +288,10 @@ private:
   /// avoid calling the ctor for the Action unnecessarily.
   template <typename ActionTy, typename... Args>
   LLVM_ATTRIBUTE_NOINLINE void
-  executeActionInternal(function_ref<void()> actionFn, Args &&...args) {
-    executeActionInternal(actionFn, ActionTy(std::forward<Args>(args)...));
+  executeActionInternal(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
+                        Args &&...args) {
+    executeActionInternal(actionFn,
+                          ActionTy(irUnits, std::forward<Args>(args)...));
   }
 
   const std::unique_ptr<MLIRContextImpl> impl;
diff --git a/mlir/include/mlir/IR/Unit.h b/mlir/include/mlir/IR/Unit.h
new file mode 100644 (file)
index 0000000..033dab5
--- /dev/null
@@ -0,0 +1,42 @@
+//===- Unit.h -  IR Unit definition--------------------*- C++ -*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_UNIT_H
+#define MLIR_IR_UNIT_H
+
+#include "mlir/IR/OperationSupport.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace llvm {
+class raw_ostream;
+} // namespace llvm
+namespace mlir {
+class Operation;
+class Region;
+class Block;
+class Value;
+
+/// IRUnit is a union of the different types of IR objects that consistute the
+/// IR structure (other than Type and Attribute), that is Operation, Region, and
+/// Block.
+class IRUnit : public PointerUnion<Operation *, Region *, Block *, Value> {
+public:
+  using PointerUnion::PointerUnion;
+
+  /// Print the IRUnit to the given stream.
+  void print(raw_ostream &os,
+             OpPrintingFlags flags =
+                 OpPrintingFlags().skipRegions().useLocalScope()) const;
+};
+
+raw_ostream &operator<<(raw_ostream &os, const IRUnit &unit);
+
+} // end namespace mlir
+
+#endif // MLIR_IR_UNIT_H
index 9826adf..7e7c5ac 100644 (file)
@@ -7,9 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Debug/Observers/ActionLogging.h"
+#include "mlir/IR/Action.h"
 #include "llvm/Support/Threading.h"
-#include <sstream>
-#include <thread>
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 using namespace mlir::tracing;
@@ -22,6 +22,10 @@ void ActionLogger::beforeExecute(const ActionActiveStack *action,
                                  Breakpoint *breakpoint, bool willExecute) {
   SmallVector<char> name;
   llvm::get_thread_name(name);
+  if (name.empty()) {
+    llvm::raw_svector_ostream os(name);
+    os << llvm::get_threadid();
+  }
   os << "[thread " << name << "] ";
   if (willExecute)
     os << "begins ";
@@ -29,21 +33,30 @@ void ActionLogger::beforeExecute(const ActionActiveStack *action,
     os << "skipping ";
   if (printBreakpoints) {
     if (breakpoint)
-      os << " (on breakpoint: " << *breakpoint << ") ";
+      os << "(on breakpoint: " << *breakpoint << ") ";
     else
-      os << " (no breakpoint) ";
+      os << "(no breakpoint) ";
   }
   os << "Action ";
   if (printActions)
     action->getAction().print(os);
   else
     os << action->getAction().getTag();
+  if (printIRUnits) {
+    os << " (";
+    interleaveComma(action->getAction().getContextIRUnits(), os);
+    os << ")";
+  }
   os << "`\n";
 }
 
 void ActionLogger::afterExecute(const ActionActiveStack *action) {
   SmallVector<char> name;
   llvm::get_thread_name(name);
+  if (name.empty()) {
+    llvm::raw_svector_ostream os(name);
+    os << llvm::get_threadid();
+  }
   os << "[thread " << name << "] completed `" << action->getAction().getTag()
      << "`\n";
 }
index 8b4fb42..4377ebe 100644 (file)
@@ -32,6 +32,7 @@ add_mlir_library(MLIRIR
   Types.cpp
   TypeRange.cpp
   TypeUtilities.cpp
+  Unit.cpp
   Value.cpp
   ValueRange.cpp
   Verifier.cpp
diff --git a/mlir/lib/IR/Unit.cpp b/mlir/lib/IR/Unit.cpp
new file mode 100644 (file)
index 0000000..7da714f
--- /dev/null
@@ -0,0 +1,63 @@
+//===- Unit.cpp - Support for manipulating IR Unit ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Unit.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Region.h"
+#include "llvm/Support/raw_ostream.h"
+#include <iterator>
+#include <sstream>
+
+using namespace mlir;
+
+static void printOp(llvm::raw_ostream &os, Operation *op,
+                    OpPrintingFlags &flags) {
+  if (!op) {
+    os << "<Operation:nullptr>";
+    return;
+  }
+  op->print(os, flags);
+}
+
+static void printRegion(llvm::raw_ostream &os, Region *region,
+                        OpPrintingFlags &flags) {
+  if (!region) {
+    os << "<Region:nullptr>";
+    return;
+  }
+  os << "Region #" << region->getRegionNumber() << " for op ";
+  printOp(os, region->getParentOp(), flags);
+}
+
+static void printBlock(llvm::raw_ostream &os, Block *block,
+                       OpPrintingFlags &flags) {
+  Region *region = block->getParent();
+  Block *entry = &region->front();
+  int blockId = std::distance(entry->getIterator(), block->getIterator());
+  os << "Block #" << blockId << " for ";
+  bool shouldSkipRegions = flags.shouldSkipRegions();
+  printRegion(os, region, flags.skipRegions());
+  if (!shouldSkipRegions)
+    block->print(os);
+}
+
+void mlir::IRUnit::print(llvm::raw_ostream &os, OpPrintingFlags flags) const {
+  if (auto *op = this->dyn_cast<Operation *>())
+    return printOp(os, op, flags);
+  if (auto *region = this->dyn_cast<Region *>())
+    return printRegion(os, region, flags);
+  if (auto *block = this->dyn_cast<Block *>())
+    return printBlock(os, block, flags);
+  llvm_unreachable("unknown IRUnit");
+}
+
+llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os, const IRUnit &unit) {
+  unit.print(os);
+  return os;
+}
index 2b07898..e496a29 100644 (file)
@@ -482,7 +482,7 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
           pass->runOnOperation();
         passFailed = pass->passState->irAndPassFailed.getInt();
       },
-      *pass, op);
+      {op}, *pass);
 
   // Invalidate any non preserved analyses.
   am.invalidate(pass->passState->preservedAnalyses);
index 6fc46af..ca60cf2 100644 (file)
 #include "mlir/IR/Action.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/FormatVariadic.h"
 
 namespace mlir {
 /// Encapsulate the "action" of executing a single pass, used for the MLIR
 /// tracing infrastructure.
 struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
-  PassExecutionAction(const Pass &pass, Operation *op) : pass(pass), op(op) {}
+  using Base = tracing::ActionImpl<PassExecutionAction>;
+  PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass)
+      : Base(irUnits), pass(pass) {}
   static constexpr StringLiteral tag = "pass-execution-action";
   void print(raw_ostream &os) const override;
   const Pass &getPass() const { return pass; }
-  Operation *getOp() const { return op; }
+  Operation *getOp() const {
+    ArrayRef<IRUnit> irUnits = getContextIRUnits();
+    return irUnits.empty() ? nullptr : irUnits[0].dyn_cast<Operation *>();
+  }
 
 public:
   const Pass &pass;
index 943f05a..d10c64c 100644 (file)
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s --log-actions-to=- -canonicalize -test-module-pass | FileCheck %s
 
-// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `Canonicalizer` on Operation `builtin.module`
-// CHECK: [thread {{.*}}] completed `pass-execution-action`
-// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `(anonymous namespace)::TestModulePass` on Operation `builtin.module`
-// CHECK: [thread {{.*}}] completed `pass-execution-action`
+// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `Canonicalizer` on Operation `builtin.module` (module {...}
+// CHECK-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action`  running `{{.*}}TestModulePass` on Operation `builtin.module` (module {...}
+// CHECK-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-NOT: Action