[mlir] Make `LocalAliasAnalysis` extesible
authorIvan Butygin <ivan.butygin@gmail.com>
Mon, 19 Dec 2022 21:26:07 +0000 (22:26 +0100)
committerIvan Butygin <ivan.butygin@gmail.com>
Wed, 21 Dec 2022 13:15:35 +0000 (14:15 +0100)
This is an alternative to https://reviews.llvm.org/D138761 . Instead of adding ad-hoc attributes to existing `LocalAliasAnalysis`, expose `aliasImpl` method so user can override it.

Differential Revision: https://reviews.llvm.org/D140348

mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h
mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
mlir/test/Analysis/test-alias-analysis-extending.mlir [new file with mode: 0644]
mlir/test/lib/Analysis/TestAliasAnalysis.cpp

index 9669a3c..0601240 100644 (file)
@@ -28,6 +28,10 @@ public:
 
   /// Return the modify-reference behavior of `op` on `location`.
   ModRefResult getModRef(Operation *op, Value location);
+
+protected:
+  /// Given the two values, return their aliasing behavior.
+  virtual AliasResult aliasImpl(Value lhs, Value rhs);
 };
 } // namespace mlir
 
index 1b7dee9..73ddd81 100644 (file)
@@ -246,7 +246,7 @@ getAllocEffectFor(Value value,
 }
 
 /// Given the two values, return their aliasing behavior.
-static AliasResult aliasImpl(Value lhs, Value rhs) {
+AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
   if (lhs == rhs)
     return AliasResult::MustAlias;
   Operation *lhsAllocScope = nullptr, *rhsAllocScope = nullptr;
diff --git a/mlir/test/Analysis/test-alias-analysis-extending.mlir b/mlir/test/Analysis/test-alias-analysis-extending.mlir
new file mode 100644 (file)
index 0000000..20f2a7a
--- /dev/null
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-alias-analysis-extending))' -split-input-file -allow-unregistered-dialect 2>&1 | FileCheck %s
+
+// CHECK-LABEL: Testing : "restrict"
+// CHECK-DAG: func.region0#0 <-> func.region0#1: NoAlias
+
+// CHECK-DAG: view1#0 <-> view2#0: NoAlias
+// CHECK-DAG: view1#0 <-> func.region0#0: MustAlias
+// CHECK-DAG: view1#0 <-> func.region0#1: NoAlias
+// CHECK-DAG: view2#0 <-> func.region0#0: NoAlias
+// CHECK-DAG: view2#0 <-> func.region0#1: MustAlias
+func.func @restrict(%arg: memref<?xf32>, %arg1: memref<?xf32> {local_alias_analysis.restrict}) attributes {test.ptr = "func"} {
+  %0 = memref.subview %arg[0][2][1] {test.ptr = "view1"} : memref<?xf32> to memref<2xf32>
+  %1 = memref.subview %arg1[0][2][1] {test.ptr = "view2"} : memref<?xf32> to memref<2xf32>
+  return
+}
index 04b2bc3..b563be4 100644 (file)
@@ -13,6 +13,8 @@
 
 #include "TestAliasAnalysis.h"
 #include "mlir/Analysis/AliasAnalysis.h"
+#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
+#include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
@@ -149,14 +151,76 @@ struct TestAliasAnalysisModRefPass
 } // namespace
 
 //===----------------------------------------------------------------------===//
+// Testing LocalAliasAnalysis extending
+//===----------------------------------------------------------------------===//
+
+/// Check if value is function argument.
+static bool isFuncArg(Value val) {
+  auto blockArg = val.dyn_cast<BlockArgument>();
+  if (!blockArg)
+    return false;
+
+  return mlir::isa_and_nonnull<FunctionOpInterface>(
+      blockArg.getOwner()->getParentOp());
+}
+
+/// Check if value has "restrict" attribute. Value must be a function argument.
+static bool isRestrict(Value val) {
+  auto blockArg = val.cast<BlockArgument>();
+  auto func =
+      mlir::cast<FunctionOpInterface>(blockArg.getOwner()->getParentOp());
+  return !!func.getArgAttr(blockArg.getArgNumber(),
+                           "local_alias_analysis.restrict");
+}
+
+namespace {
+/// LocalAliasAnalysis extended to support "restrict" attreibute.
+class LocalAliasAnalysisRestrict : public LocalAliasAnalysis {
+protected:
+  AliasResult aliasImpl(Value lhs, Value rhs) override {
+    if (lhs == rhs)
+      return AliasResult::MustAlias;
+
+    // Assume no aliasing if both values are function arguments and any of them
+    // have restrict attr.
+    if (isFuncArg(lhs) && isFuncArg(rhs))
+      if (isRestrict(lhs) || isRestrict(rhs))
+        return AliasResult::NoAlias;
+
+    return LocalAliasAnalysis::aliasImpl(lhs, rhs);
+  }
+};
+
+/// This pass tests adding additional analysis impls to the AliasAnalysis.
+struct TestAliasAnalysisExtendingPass
+    : public test::TestAliasAnalysisBase,
+      PassWrapper<TestAliasAnalysisExtendingPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisExtendingPass)
+
+  StringRef getArgument() const final {
+    return "test-alias-analysis-extending";
+  }
+  StringRef getDescription() const final {
+    return "Test alias analysis extending.";
+  }
+  void runOnOperation() override {
+    AliasAnalysis aliasAnalysis(getOperation());
+    aliasAnalysis.addAnalysisImplementation(LocalAliasAnalysisRestrict());
+    runAliasAnalysisOnOperation(getOperation(), aliasAnalysis);
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
 // Pass Registration
 //===----------------------------------------------------------------------===//
 
 namespace mlir {
 namespace test {
 void registerTestAliasAnalysisPass() {
-  PassRegistration<TestAliasAnalysisPass>();
+  PassRegistration<TestAliasAnalysisExtendingPass>();
   PassRegistration<TestAliasAnalysisModRefPass>();
+  PassRegistration<TestAliasAnalysisPass>();
 }
 } // namespace test
 } // namespace mlir