[mlir][LLVM] Handle access groups during inlining
authorMarkus Böck <markus.bock+llvm@nextsilicon.com>
Thu, 20 Jul 2023 07:53:38 +0000 (09:53 +0200)
committerMarkus Böck <markus.bock+llvm@nextsilicon.com>
Thu, 20 Jul 2023 08:45:15 +0000 (10:45 +0200)
Handling access groups is luckily rather trivial: Any access groups from the call instruction are simply appended to any memory operations.
This is similar to one of the steps when handling alias scopes.
This patch nevertheless implements it as a separate function purely for readability purposes as it uses a different interface than alias scopes.

Differential Revision: https://reviews.llvm.org/D155795

mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
mlir/test/Dialect/LLVMIR/inlining.mlir

index 61f13b3..c2e319e 100644 (file)
@@ -238,6 +238,27 @@ static void handleAliasScopes(Operation *call,
   appendCallOpAliasScopes(call, inlinedBlocks);
 }
 
+/// Appends any access groups of the call operation to any inlined memory
+/// operation.
+static void handleAccessGroups(Operation *call,
+                               iterator_range<Region::iterator> inlinedBlocks) {
+  auto callAccessGroupInterface = dyn_cast<LLVM::AccessGroupOpInterface>(call);
+  if (!callAccessGroupInterface)
+    return;
+
+  auto accessGroups = callAccessGroupInterface.getAccessGroupsOrNull();
+  if (!accessGroups)
+    return;
+
+  // Simply append the call op's access groups to any operation implementing
+  // AccessGroupOpInterface.
+  for (Block &block : inlinedBlocks)
+    for (auto accessGroupOpInterface :
+         block.getOps<LLVM::AccessGroupOpInterface>())
+      accessGroupOpInterface.setAccessGroups(concatArrayAttr(
+          accessGroupOpInterface.getAccessGroupsOrNull(), accessGroups));
+}
+
 /// If `requestedAlignment` is higher than the alignment specified on `alloca`,
 /// realigns `alloca` if this does not exceed the natural stack alignment.
 /// Returns the post-alignment of `alloca`, whether it was realigned or not.
@@ -433,16 +454,6 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
   bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final {
     if (isPure(op))
       return true;
-    // Some attributes on memory operations require handling during
-    // inlining. Since this is not yet implemented, refuse to inline memory
-    // operations that have any of these attributes.
-    if (auto iface = dyn_cast<LLVM::AccessGroupOpInterface>(op)) {
-      if (iface.getAccessGroupsOrNull()) {
-        LLVM_DEBUG(llvm::dbgs()
-                   << "Cannot inline: unhandled access group metadata\n");
-        return false;
-      }
-    }
     // clang-format off
     if (isa<LLVM::AllocaOp,
             LLVM::AssumeOp,
@@ -525,6 +536,7 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
       iterator_range<Region::iterator> inlinedBlocks) const override {
     handleInlinedAllocas(call, inlinedBlocks);
     handleAliasScopes(call, inlinedBlocks);
+    handleAccessGroups(call, inlinedBlocks);
   }
 
   // Keeping this (immutable) state on the interface allows us to look up
index 7ad92ad..b22bfb4 100644 (file)
@@ -53,42 +53,6 @@ func.func @test_inline(%ptr : !llvm.ptr) -> i32 {
 }
 
 // -----
-
-#group = #llvm.access_group<id = distinct[0]<>>
-
-llvm.func @inlinee(%ptr : !llvm.ptr) -> i32 {
-  %0 = llvm.load %ptr { access_groups = [#group] } : !llvm.ptr -> i32
-  llvm.return %0 : i32
-}
-
-// CHECK-LABEL: func @test_not_inline
-llvm.func @test_not_inline(%ptr : !llvm.ptr) -> i32 {
-  // CHECK-NEXT: llvm.call @inlinee
-  %0 = llvm.call @inlinee(%ptr) : (!llvm.ptr) -> (i32)
-  llvm.return %0 : i32
-}
-
-// -----
-
-#group = #llvm.access_group<id = distinct[0]<>>
-
-func.func private @with_mem_attr(%ptr : !llvm.ptr) {
-  %0 = llvm.mlir.constant(42 : i32) : i32
-  // Do not inline load/store operations that carry attributes requiring
-  // handling while inlining, until this is supported by the inliner.
-  llvm.store %0, %ptr { access_groups = [#group] }: i32, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: func.func @test_not_inline
-// CHECK-NEXT: call @with_mem_attr
-// CHECK-NEXT: return
-func.func @test_not_inline(%ptr : !llvm.ptr) {
-  call @with_mem_attr(%ptr) : (!llvm.ptr) -> ()
-  return
-}
-
-// -----
 // Check that llvm.return is correctly handled
 
 func.func @func(%arg0 : i32) -> i32  {
@@ -584,3 +548,47 @@ llvm.func @test_disallow_arg_attr(%ptr : !llvm.ptr) {
   llvm.call @disallowed_arg_attr(%ptr) : (!llvm.ptr) -> ()
   llvm.return
 }
+
+// -----
+
+#callee = #llvm.access_group<id = distinct[0]<>>
+#caller = #llvm.access_group<id = distinct[1]<>>
+
+llvm.func @inlinee(%ptr : !llvm.ptr) -> i32 {
+  %0 = llvm.load %ptr { access_groups = [#callee] } : !llvm.ptr -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-DAG: #[[$CALLEE:.*]] = #llvm.access_group<id = {{.*}}>
+// CHECK-DAG: #[[$CALLER:.*]] = #llvm.access_group<id = {{.*}}>
+
+// CHECK-LABEL: func @caller
+// CHECK: llvm.load
+// CHECK-SAME: access_groups = [#[[$CALLEE]], #[[$CALLER]]]
+llvm.func @caller(%ptr : !llvm.ptr) -> i32 {
+  %0 = llvm.call @inlinee(%ptr) { access_groups = [#caller] } : (!llvm.ptr) -> (i32)
+  llvm.return %0 : i32
+}
+
+// -----
+
+#caller = #llvm.access_group<id = distinct[1]<>>
+
+llvm.func @inlinee(%ptr : !llvm.ptr) -> i32 {
+  %0 = llvm.load %ptr : !llvm.ptr -> i32
+  llvm.return %0 : i32
+}
+
+// CHECK-DAG: #[[$CALLER:.*]] = #llvm.access_group<id = {{.*}}>
+
+// CHECK-LABEL: func @caller
+// CHECK: llvm.load
+// CHECK-SAME: access_groups = [#[[$CALLER]]]
+// CHECK: llvm.store
+// CHECK-SAME: access_groups = [#[[$CALLER]]]
+llvm.func @caller(%ptr : !llvm.ptr) -> i32 {
+  %c5 = llvm.mlir.constant(5 : i32) : i32
+  %0 = llvm.call @inlinee(%ptr) { access_groups = [#caller] } : (!llvm.ptr) -> (i32)
+  llvm.store %c5, %ptr { access_groups = [#caller] } : i32, !llvm.ptr
+  llvm.return %0 : i32
+}