Allow analyses to provide a hook 'isInvalidated' to determine if they are truly inval...
authorRiver Riddle <riverriddle@google.com>
Tue, 3 Dec 2019 19:13:39 +0000 (11:13 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 3 Dec 2019 19:14:20 +0000 (11:14 -0800)
The hook has the following form:
*   `bool isInvalidated(const AnalysisManager::PreservedAnalyses &)`

Given a preserved analysis set, the analysis returns true if it should truly be
invalidated. This allows for more fine-tuned invalidation in cases where an
analysis wasn't explicitly marked preserved, but may be preserved(or
invalidated) based upon other properties; such as analyses sets.

PiperOrigin-RevId: 283582889

mlir/g3doc/WritingAPass.md
mlir/include/mlir/Pass/AnalysisManager.h
mlir/unittests/Pass/AnalysisManagerTest.cpp

index df0d153..1e4564a 100644 (file)
@@ -116,12 +116,20 @@ the following:
 *   Provide a valid constructor taking an `Operation*`.
 *   Must not modify the given operation.
 
-The base `OperationPass` class provide utilities for querying and preserving
-analyses for the current operation being processed. Using the example passes
-defined above, let's see some examples:
+An analysis may provide additional hooks to control various behavior:
+
+*   `bool isInvalidated(const AnalysisManager::PreservedAnalyses &)`
+
+Given a preserved analysis set, the analysis returns true if it should truly be
+invalidated. This allows for more fine-tuned invalidation in cases where an
+analysis wasn't explicitly marked preserved, but may be preserved(or
+invalidated) based upon other properties such as analyses sets.
 
 ### Querying Analyses
 
+The base `OperationPass` class provide utilities for querying and preserving
+analyses for the current operation being processed.
+
 *   OperationPass automatically provides the following utilities for querying
     analyses:
     *   `getAnalysis<>`
@@ -137,7 +145,7 @@ defined above, let's see some examples:
         -   Get an analysis for a given child operation, constructing it if
             necessary.
 
-A few example usages are shown below:
+Using the example passes defined above, let's see some examples:
 
 ```c++
 /// An interesting analysis.
index 163ecf6..6c37223 100644 (file)
@@ -76,9 +76,36 @@ private:
   SmallPtrSet<const void *, 2> preservedIDs;
 };
 
+namespace analysis_impl {
+/// Trait to check if T provides a static 'isInvalidated' method.
+template <typename T, typename... Args>
+using has_is_invalidated = decltype(std::declval<T &>().isInvalidated(
+    std::declval<const PreservedAnalyses &>()));
+
+/// Implementation of 'isInvalidated' if the analysis provides a definition.
+template <typename AnalysisT>
+std::enable_if_t<is_detected<has_is_invalidated, AnalysisT>::value, bool>
+isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) {
+  return analysis.isInvalidated(pa);
+}
+/// Default implementation of 'isInvalidated'.
+template <typename AnalysisT>
+std::enable_if_t<!is_detected<has_is_invalidated, AnalysisT>::value, bool>
+isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) {
+  return !pa.isPreserved<AnalysisT>();
+}
+} // end namespace analysis_impl
+
 /// The abstract polymorphic base class representing an analysis.
 struct AnalysisConcept {
   virtual ~AnalysisConcept() = default;
+
+  /// A hook used to query analyses for invalidation. Given a preserved analysis
+  /// set, returns true if it should truly be invalidated. This allows for more
+  /// fine-tuned invalidation in cases where an analysis wasn't explicitly
+  /// marked preserved, but may be preserved(or invalidated) based upon other
+  /// properties such as analyses sets.
+  virtual bool isInvalidated(const PreservedAnalyses &pa) = 0;
 };
 
 /// A derived analysis model used to hold a specific analysis object.
@@ -87,6 +114,12 @@ template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
   explicit AnalysisModel(Args &&... args)
       : analysis(std::forward<Args>(args)...) {}
 
+  /// A hook used to query analyses for invalidation.
+  bool isInvalidated(const PreservedAnalyses &pa) final {
+    return analysis_impl::isInvalidated(analysis, pa);
+  }
+
+  /// The actual analysis object.
   AnalysisT analysis;
 };
 
@@ -147,11 +180,11 @@ public:
 
   /// Invalidate any cached analyses based upon the given set of preserved
   /// analyses.
-  void invalidate(const detail::PreservedAnalyses &pa) {
-    // Remove any analyses not marked as preserved.
+  void invalidate(const PreservedAnalyses &pa) {
+    // Remove any analyses that were invalidated.
     for (auto it = analyses.begin(), e = analyses.end(); it != e;) {
       auto curIt = it++;
-      if (!pa.isPreserved(curIt->first))
+      if (curIt->second->isInvalidated(pa))
         analyses.erase(curIt);
     }
   }
@@ -170,7 +203,7 @@ struct NestedAnalysisMap {
   Operation *getOperation() const { return analyses.getOperation(); }
 
   /// Invalidate any non preserved analyses.
-  void invalidate(const detail::PreservedAnalyses &pa);
+  void invalidate(const PreservedAnalyses &pa);
 
   /// The cached analyses for nested operations.
   llvm::DenseMap<Operation *, std::unique_ptr<NestedAnalysisMap>> childAnalyses;
@@ -195,6 +228,8 @@ class AnalysisManager {
                                             const AnalysisManager *>;
 
 public:
+  using PreservedAnalyses = detail::PreservedAnalyses;
+
   // Query for a cached analysis on the given parent operation. The analysis may
   // not exist and if it does it may be out-of-date.
   template <typename AnalysisT>
@@ -240,7 +275,7 @@ public:
   AnalysisManager slice(Operation *op);
 
   /// Invalidate any non preserved analyses,
-  void invalidate(const detail::PreservedAnalyses &pa) { impl->invalidate(pa); }
+  void invalidate(const PreservedAnalyses &pa) { impl->invalidate(pa); }
 
   /// Clear any held analyses.
   void clear() {
index d55c47d..790ad9c 100644 (file)
@@ -114,4 +114,37 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
   EXPECT_FALSE(am.getCachedChildAnalysis<OtherAnalysis>(func1).hasValue());
 }
 
+/// Test analyses with custom invalidation logic.
+struct TestAnalysisSet {};
+
+struct CustomInvalidatingAnalysis {
+  CustomInvalidatingAnalysis(Operation *) {}
+
+  bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
+    return !pa.isPreserved<TestAnalysisSet>();
+  }
+};
+
+TEST(AnalysisManagerTest, CustomInvalidation) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Create a function and a module.
+  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+  ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+  AnalysisManager am = mam;
+
+  detail::PreservedAnalyses pa;
+
+  // Check that the analysis is invalidated properly.
+  am.getAnalysis<CustomInvalidatingAnalysis>();
+  am.invalidate(pa);
+  EXPECT_FALSE(am.getCachedAnalysis<CustomInvalidatingAnalysis>().hasValue());
+
+  // Check that the analysis is preserved properly.
+  am.getAnalysis<CustomInvalidatingAnalysis>();
+  pa.preserve<TestAnalysisSet>();
+  am.invalidate(pa);
+  EXPECT_TRUE(am.getCachedAnalysis<CustomInvalidatingAnalysis>().hasValue());
+}
 } // end namespace