[mlir][llvm] Store memory op metadata using op attributes.
authorTobias Gysi <tobias.gysi@nextsilicon.com>
Fri, 10 Feb 2023 14:22:54 +0000 (15:22 +0100)
committerTobias Gysi <tobias.gysi@nextsilicon.com>
Fri, 10 Feb 2023 14:27:25 +0000 (15:27 +0100)
The revision introduces operation attributes to store tbaa metadata on
load and store operations rather than relying using dialect attributes.
At the same time, the change also ensures the provided getters and
setters instead are used instead of a string based lookup. The latter
is done for the tbaa, access groups, and alias scope attributes.

The goal of this change is to ensure the metadata attributes are only
placed on operations that have the corresponding operation attributes.
This is imported since only these operations later on translate these
attributes to LLVM IR. Dialect attributes placed on other operations
are lost during the translation.

Reviewed By: vzakhari, Dinistro

Differential Revision: https://reviews.llvm.org/D143654

12 files changed:
flang/lib/Optimizer/CodeGen/TBAABuilder.cpp
flang/test/Fir/tbaa.fir
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir
mlir/test/Target/LLVMIR/Import/import-failure.ll
mlir/test/Target/LLVMIR/tbaa.mlir

index 2d206ed..c420818 100644 (file)
@@ -12,6 +12,7 @@
 
 #include "TBAABuilder.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -159,9 +160,13 @@ void TBAABuilder::attachTBAATag(Operation *op, Type baseFIRType,
   else
     tbaaTagSym = getDataAccessTag(baseFIRType, accessFIRType, gep);
 
-  if (tbaaTagSym)
-    op->setAttr(LLVMDialect::getTBAAAttrName(),
-                ArrayAttr::get(op->getContext(), tbaaTagSym));
+  if (!tbaaTagSym)
+    return;
+
+  auto tbaaAttr = ArrayAttr::get(op->getContext(), tbaaTagSym);
+  llvm::TypeSwitch<Operation *>(op)
+      .Case<LoadOp, StoreOp>([&](auto memOp) { memOp.setTbaaAttr(tbaaAttr); })
+      .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
 }
 
 } // namespace fir
index ac94ea9..66261e6 100644 (file)
@@ -28,10 +28,10 @@ module {
 // CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:           %[[VAL_5:.*]] = llvm.mlir.constant(10 : i32) : i32
 // CHECK:           %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]][0, 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<struct<()>>>
-// CHECK:           %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<ptr<struct<()>>>
+// CHECK:           %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<ptr<struct<()>>>
 // CHECK:           %[[VAL_8:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:           %[[VAL_9:.*]] = llvm.getelementptr %[[VAL_0]][0, 7, 0, 2] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_10:.*]] = llvm.load %[[VAL_9]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_10:.*]] = llvm.load %[[VAL_9]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_11:.*]] = llvm.mul %[[VAL_4]], %[[VAL_10]]  : i64
 // CHECK:           %[[VAL_12:.*]] = llvm.add %[[VAL_11]], %[[VAL_8]]  : i64
 // CHECK:           %[[VAL_13:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
@@ -40,11 +40,11 @@ module {
 // CHECK:           %[[VAL_16:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:           %[[VAL_17:.*]] = llvm.mlir.constant(-1 : i32) : i32
 // CHECK:           %[[VAL_18:.*]] = llvm.getelementptr %[[VAL_0]][0, 8] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<i8>>
-// CHECK:           %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
+// CHECK:           %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
 // CHECK:           %[[VAL_20:.*]] = llvm.getelementptr %[[VAL_0]][0, 1] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_21:.*]] = llvm.load %[[VAL_20]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_21:.*]] = llvm.load %[[VAL_20]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_22:.*]] = llvm.getelementptr %[[VAL_0]][0, 4] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i32>
-// CHECK:           %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i32>
 // CHECK:           %[[VAL_24:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_21]], %[[VAL_24]][1] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_26:.*]] = llvm.mlir.constant(20180515 : i32) : i32
@@ -64,15 +64,15 @@ module {
 // CHECK:           %[[VAL_40:.*]] = llvm.insertvalue %[[VAL_39]], %[[VAL_38]][7] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_41:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<struct<()>> to !llvm.ptr<struct<()>>
 // CHECK:           %[[VAL_42:.*]] = llvm.insertvalue %[[VAL_41]], %[[VAL_40]][0] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>
-// CHECK:           llvm.store %[[VAL_42]], %[[VAL_2]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>
+// CHECK:           llvm.store %[[VAL_42]], %[[VAL_2]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>
 // CHECK:           %[[VAL_43:.*]] = llvm.getelementptr %[[VAL_2]][0, 4] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i8>
-// CHECK:           %[[VAL_44:.*]] = llvm.load %[[VAL_43]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i8>
+// CHECK:           %[[VAL_44:.*]] = llvm.load %[[VAL_43]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i8>
 // CHECK:           %[[VAL_45:.*]] = llvm.icmp "eq" %[[VAL_44]], %[[VAL_3]] : i8
 // CHECK:           llvm.cond_br %[[VAL_45]], ^bb1, ^bb2
 // CHECK:         ^bb1:
 // CHECK:           %[[VAL_46:.*]] = llvm.getelementptr %[[VAL_2]][0, 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<i32>>
-// CHECK:           %[[VAL_47:.*]] = llvm.load %[[VAL_46]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i32>>
-// CHECK:           llvm.store %[[VAL_5]], %[[VAL_47]] {llvm.tbaa = [@__flang_tbaa::@[[DATAT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_47:.*]] = llvm.load %[[VAL_46]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i32>>
+// CHECK:           llvm.store %[[VAL_5]], %[[VAL_47]] {tbaa = [@__flang_tbaa::@[[DATAT:tag_[0-9]*]]]} : !llvm.ptr<i32>
 // CHECK:           llvm.br ^bb2
 // CHECK:         ^bb2:
 // CHECK:           llvm.return
@@ -133,24 +133,24 @@ module {
 // CHECK:           %[[VAL_8:.*]] = llvm.mlir.addressof @_QQcl.2E2F64756D6D792E66393000 : !llvm.ptr<array<12 x i8>>
 // CHECK:           %[[VAL_9:.*]] = llvm.bitcast %[[VAL_8]] : !llvm.ptr<array<12 x i8>> to !llvm.ptr<i8>
 // CHECK:           %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_9]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr<i8>, i32) -> !llvm.ptr<i8>
-// CHECK:           %[[VAL_11:.*]] = llvm.load %[[VAL_7]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
-// CHECK:           llvm.store %[[VAL_11]], %[[VAL_3]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
+// CHECK:           %[[VAL_11:.*]] = llvm.load %[[VAL_7]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
+// CHECK:           llvm.store %[[VAL_11]], %[[VAL_3]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
 // CHECK:           %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>, i64) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_13:.*]] = llvm.load %[[VAL_12]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_13:.*]] = llvm.load %[[VAL_12]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 1] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>, i64) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_15:.*]] = llvm.load %[[VAL_14]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_15:.*]] = llvm.load %[[VAL_14]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_16:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 2] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>, i64) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_17:.*]] = llvm.load %[[VAL_16]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_17:.*]] = llvm.load %[[VAL_16]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_18:.*]] = llvm.getelementptr %[[VAL_3]][0, 8] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<i8>>
-// CHECK:           %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
+// CHECK:           %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
 // CHECK:           %[[VAL_20:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:           %[[VAL_21:.*]] = llvm.mlir.constant(-1 : i32) : i32
 // CHECK:           %[[VAL_22:.*]] = llvm.getelementptr %[[VAL_3]][0, 1] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_24:.*]] = llvm.getelementptr %[[VAL_3]][0, 4] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i32>
-// CHECK:           %[[VAL_25:.*]] = llvm.load %[[VAL_24]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_25:.*]] = llvm.load %[[VAL_24]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i32>
 // CHECK:           %[[VAL_26:.*]] = llvm.getelementptr %[[VAL_3]][0, 8] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<i8>>
-// CHECK:           %[[VAL_27:.*]] = llvm.load %[[VAL_26]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
+// CHECK:           %[[VAL_27:.*]] = llvm.load %[[VAL_26]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
 // CHECK:           %[[VAL_28:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_29:.*]] = llvm.insertvalue %[[VAL_23]], %[[VAL_28]][1] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_30:.*]] = llvm.mlir.constant(20180515 : i32) : i32
@@ -169,13 +169,13 @@ module {
 // CHECK:           %[[VAL_43:.*]] = llvm.bitcast %[[VAL_27]] : !llvm.ptr<i8> to !llvm.ptr<i8>
 // CHECK:           %[[VAL_44:.*]] = llvm.insertvalue %[[VAL_43]], %[[VAL_42]][8] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_45:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, 0, 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_46:.*]] = llvm.load %[[VAL_45]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_46:.*]] = llvm.load %[[VAL_45]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_47:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, 0, 1] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_48:.*]] = llvm.load %[[VAL_47]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_48:.*]] = llvm.load %[[VAL_47]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_49:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, 0, 2] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_50:.*]] = llvm.load %[[VAL_49]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_50:.*]] = llvm.load %[[VAL_49]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_51:.*]] = llvm.getelementptr %[[VAL_3]][0, 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<struct<()>>>
-// CHECK:           %[[VAL_52:.*]] = llvm.load %[[VAL_51]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<struct<()>>>
+// CHECK:           %[[VAL_52:.*]] = llvm.load %[[VAL_51]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<struct<()>>>
 // CHECK:           %[[VAL_53:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:           %[[VAL_54:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:           %[[VAL_55:.*]] = llvm.icmp "eq" %[[VAL_48]], %[[VAL_53]] : i64
@@ -185,7 +185,7 @@ module {
 // CHECK:           %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_60:.*]] = llvm.bitcast %[[VAL_52]] : !llvm.ptr<struct<()>> to !llvm.ptr<struct<()>>
 // CHECK:           %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_60]], %[[VAL_59]][0] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
-// CHECK:           llvm.store %[[VAL_61]], %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
+// CHECK:           llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
 // CHECK:           %[[VAL_62:.*]] = llvm.bitcast %[[VAL_1]] : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>> to !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>
 // CHECK:           %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_62]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr<i8>, !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>) -> i1
 // CHECK:           %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr<i8>) -> i32
@@ -253,7 +253,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i32 {
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> i32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<i32>
-// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
 // CHECK:           llvm.return %[[VAL_2]] : i32
 // CHECK:         }
 
@@ -275,7 +275,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> i1 {
 // CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<i32>
-// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(0 : i32) : i32
 // CHECK:           %[[VAL_4:.*]] = llvm.icmp "ne" %[[VAL_2]], %[[VAL_3]] : i32
 // CHECK:           llvm.return %[[VAL_4]] : i1
@@ -299,7 +299,7 @@ func.func @tbaa(%arg0: !fir.box<f32>) -> i32 {
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                               %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<f32>, i64, i32, i8, i8, i8, i8)>>) -> i32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 1] : (!llvm.ptr<struct<(ptr<f32>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<i32>
-// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
 // CHECK:           llvm.return %[[VAL_2]] : i32
 // CHECK:         }
 
@@ -321,7 +321,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> i1 {
 // CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<i32>
-// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK:           %[[VAL_4:.*]] = llvm.and %[[VAL_2]], %[[VAL_3]]  : i32
 // CHECK:           %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -353,11 +353,11 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<?xi32>>) {
 // CHECK:           %[[VAL_4:.*]] = llvm.sub %[[VAL_1]], %[[VAL_2]]  : i64
 // CHECK:           %[[VAL_5:.*]] = llvm.mul %[[VAL_4]], %[[VAL_2]]  : i64
 // CHECK:           %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]][0, 7, 0, 2] : (!llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_8:.*]] = llvm.mul %[[VAL_5]], %[[VAL_7]]  : i64
 // CHECK:           %[[VAL_9:.*]] = llvm.add %[[VAL_8]], %[[VAL_3]]  : i64
 // CHECK:           %[[VAL_10:.*]] = llvm.getelementptr %[[VAL_0]][0, 0] : (!llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>>) -> !llvm.ptr<ptr<i32>>
-// CHECK:           %[[VAL_11:.*]] = llvm.load %[[VAL_10]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i32>>
+// CHECK:           %[[VAL_11:.*]] = llvm.load %[[VAL_10]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i32>>
 // CHECK:           %[[VAL_12:.*]] = llvm.bitcast %[[VAL_11]] : !llvm.ptr<i32> to !llvm.ptr<i8>
 // CHECK:           %[[VAL_13:.*]] = llvm.getelementptr %[[VAL_12]]{{\[}}%[[VAL_9]]] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
 // CHECK:           %[[VAL_14:.*]] = llvm.bitcast %[[VAL_13]] : !llvm.ptr<i8> to !llvm.ptr<i32>
index 4fd7586..9c9ebfc 100644 (file)
@@ -35,11 +35,7 @@ def LLVM_Dialect : Dialect {
   let extraClassDeclaration = [{
     /// Name of the data layout attributes.
     static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; }
-    static StringRef getNoAliasScopesAttrName() { return "noalias_scopes"; }
-    static StringRef getAliasScopesAttrName() { return "alias_scopes"; }
     static StringRef getLoopAttrName() { return "llvm.loop"; }
-    static StringRef getAccessGroupsAttrName() { return "access_groups"; }
-    static StringRef getTBAAAttrName() { return "llvm.tbaa"; }
 
     /// Names of llvm parameter attributes.
     static StringRef getAlignAttrName() { return "llvm.align"; }
index 39b79f1..8e4b834 100644 (file)
@@ -350,6 +350,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpBase {
                    OptionalAttr<SymbolRefArrayAttr>:$access_groups,
                    OptionalAttr<SymbolRefArrayAttr>:$alias_scopes,
                    OptionalAttr<SymbolRefArrayAttr>:$noalias_scopes,
+                   OptionalAttr<SymbolRefArrayAttr>:$tbaa,
                    OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
                    UnitAttr:$nontemporal);
   let results = (outs LLVM_LoadableType:$res);
@@ -390,6 +391,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpBase {
                    OptionalAttr<SymbolRefArrayAttr>:$access_groups,
                    OptionalAttr<SymbolRefArrayAttr>:$alias_scopes,
                    OptionalAttr<SymbolRefArrayAttr>:$noalias_scopes,
+                   OptionalAttr<SymbolRefArrayAttr>:$tbaa,
                    OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
                    UnitAttr:$nontemporal);
   string llvmInstName = "Store";
index 59b29f5..2b08d96 100644 (file)
@@ -120,15 +120,14 @@ public:
   /// in these blocks.
   void forgetMapping(Region &region);
 
-  /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
-  /// dialect access group operation.
-  llvm::MDNode *getAccessGroup(Operation &opInst,
+  /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
+  /// LLVM dialect access group operation.
+  llvm::MDNode *getAccessGroup(Operation *op,
                                SymbolRefAttr accessGroupRef) const;
 
-  /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
-  /// dialect alias scope operation
-  llvm::MDNode *getAliasScope(Operation &opInst,
-                              SymbolRefAttr aliasScopeRef) const;
+  /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
+  /// LLVM dialect alias scope operation
+  llvm::MDNode *getAliasScope(Operation *op, SymbolRefAttr aliasScopeRef) const;
 
   // Sets LLVM metadata for memory operations that are in a parallel loop.
   void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst);
@@ -287,9 +286,9 @@ private:
   /// metadata nodes for them and their domains.
   LogicalResult createAliasScopeMetadata();
 
-  /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
-  /// dialect TBAATagOp operation.
-  llvm::MDNode *getTBAANode(Operation &memOp, SymbolRefAttr tagRef) const;
+  /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
+  /// LLVM dialect TBAATagOp operation.
+  llvm::MDNode *getTBAANode(Operation *op, SymbolRefAttr tagRef) const;
 
   /// Process tbaa LLVM Metadata operations and create LLVM
   /// metadata nodes for them.
index 1ef24c6..1dc501f 100644 (file)
@@ -668,53 +668,60 @@ Type LLVM::GEPOp::getSourceElementType() {
 // Builder, printer and parser for for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//
 
-LogicalResult verifySymbolAttribute(
-    Operation *op, StringRef attributeName,
+/// Verifies the given array attribute contains symbol references and checks the
+/// referenced symbol types using the provided verification function.
+LogicalResult verifyMemOpSymbolRefs(
+    Operation *op, StringRef name, ArrayAttr symbolRefs,
     llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)>
         verifySymbolType) {
-  if (Attribute attribute = op->getAttr(attributeName)) {
-    // Verify that the attribute is a symbol ref array attribute,
-    // because this constraint is not verified for all attribute
-    // names processed here (e.g. 'tbaa'). This verification
-    // is redundant in some cases.
-    if (!(attribute.isa<ArrayAttr>() &&
-          llvm::all_of(attribute.cast<ArrayAttr>(), [&](Attribute attr) {
-            return attr && attr.isa<SymbolRefAttr>();
-          })))
-      return op->emitOpError("attribute '")
-             << attributeName
-             << "' failed to satisfy constraint: symbol ref array attribute";
-
-    for (SymbolRefAttr symbolRef :
-         attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
-      StringAttr metadataName = symbolRef.getRootReference();
-      StringAttr symbolName = symbolRef.getLeafReference();
-      // We want @metadata::@symbol, not just @symbol
-      if (metadataName == symbolName) {
-        return op->emitOpError() << "expected '" << symbolRef
-                                 << "' to specify a fully qualified reference";
-      }
-      auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
-          op->getParentOp(), metadataName);
-      if (!metadataOp)
-        return op->emitOpError()
-               << "expected '" << symbolRef << "' to reference a metadata op";
-      Operation *symbolOp =
-          SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
-      if (!symbolOp)
-        return op->emitOpError()
-               << "expected '" << symbolRef << "' to be a valid reference";
-      if (failed(verifySymbolType(symbolOp, symbolRef))) {
-        return failure();
-      }
+  assert(symbolRefs && "expected a non-null attribute");
+
+  // Verify that the attribute is a symbol ref array attribute,
+  // because this constraint is not verified for all attribute
+  // names processed here (e.g. 'tbaa'). This verification
+  // is redundant in some cases.
+  if (!llvm::all_of(symbolRefs, [](Attribute attr) {
+        return attr && attr.isa<SymbolRefAttr>();
+      }))
+    return op->emitOpError("attribute '")
+           << name
+           << "' failed to satisfy constraint: symbol ref array attribute";
+
+  for (SymbolRefAttr symbolRef : symbolRefs.getAsRange<SymbolRefAttr>()) {
+    StringAttr metadataName = symbolRef.getRootReference();
+    StringAttr symbolName = symbolRef.getLeafReference();
+    // We want @metadata::@symbol, not just @symbol
+    if (metadataName == symbolName) {
+      return op->emitOpError() << "expected '" << symbolRef
+                               << "' to specify a fully qualified reference";
+    }
+    auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
+        op->getParentOp(), metadataName);
+    if (!metadataOp)
+      return op->emitOpError()
+             << "expected '" << symbolRef << "' to reference a metadata op";
+    Operation *symbolOp =
+        SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
+    if (!symbolOp)
+      return op->emitOpError()
+             << "expected '" << symbolRef << "' to be a valid reference";
+    if (failed(verifySymbolType(symbolOp, symbolRef))) {
+      return failure();
     }
   }
+
   return success();
 }
 
-// Verifies that metadata ops are wired up properly.
+/// Verifies the given array attribute contains symbol references that point to
+/// metadata operations of the given type.
 template <typename OpTy>
-static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) {
+static LogicalResult
+verifyMemOpSymbolRefsPointTo(Operation *op, StringRef name,
+                             std::optional<ArrayAttr> symbolRefs) {
+  if (!symbolRefs)
+    return success();
+
   auto verifySymbolType = [op](Operation *symbolOp,
                                SymbolRefAttr symbolRef) -> LogicalResult {
     if (!isa<OpTy>(symbolOp)) {
@@ -724,35 +731,33 @@ static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) {
     }
     return success();
   };
-
-  return verifySymbolAttribute(op, attributeName, verifySymbolType);
+  return verifyMemOpSymbolRefs(op, name, *symbolRefs, verifySymbolType);
 }
 
-static LogicalResult verifyMemoryOpMetadata(Operation *op) {
-  // access_groups
-  if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>(
-          op, LLVMDialect::getAccessGroupsAttrName())))
+/// Verifies the types of the metadata operations referenced by aliasing and
+/// access group metadata.
+template <typename OpTy>
+LogicalResult verifyMemOpMetadata(OpTy memOp) {
+  if (failed(verifyMemOpSymbolRefsPointTo<LLVM::AccessGroupMetadataOp>(
+          memOp, memOp.getAccessGroupsAttrName(), memOp.getAccessGroups())))
     return failure();
 
-  // alias_scopes
-  if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
-          op, LLVMDialect::getAliasScopesAttrName())))
+  if (failed(verifyMemOpSymbolRefsPointTo<LLVM::AliasScopeMetadataOp>(
+          memOp, memOp.getAliasScopesAttrName(), memOp.getAliasScopes())))
     return failure();
 
-  // noalias_scopes
-  if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
-          op, LLVMDialect::getNoAliasScopesAttrName())))
+  if (failed(verifyMemOpSymbolRefsPointTo<LLVM::AliasScopeMetadataOp>(
+          memOp, memOp.getNoaliasScopesAttrName(), memOp.getNoaliasScopes())))
     return failure();
 
-  // tbaa
-  if (failed(verifyOpMetadata<LLVM::TBAATagOp>(op,
-                                               LLVMDialect::getTBAAAttrName())))
+  if (failed(verifyMemOpSymbolRefsPointTo<LLVM::TBAATagOp>(
+          memOp, memOp.getTbaaAttrName(), memOp.getTbaa())))
     return failure();
 
   return success();
 }
 
-LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); }
+LogicalResult LoadOp::verify() { return verifyMemOpMetadata(*this); }
 
 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
                    Value addr, unsigned alignment, bool isVolatile,
@@ -828,7 +833,7 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
 // Builder, printer and parser for LLVM::StoreOp.
 //===----------------------------------------------------------------------===//
 
-LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); }
+LogicalResult StoreOp::verify() { return verifyMemOpMetadata(*this); }
 
 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Value addr, unsigned alignment, bool isVolatile,
index 207ef8b..7d36d40 100644 (file)
@@ -76,6 +76,30 @@ static ArrayRef<unsigned> getSupportedMetadataImpl() {
   return convertibleMetadata;
 }
 
+namespace {
+/// Helper class to attach metadata attributes to specific operation types. It
+/// specializes TypeSwitch to take an Operation and return a LogicalResult.
+template <typename... OpTys>
+struct AttributeSetter {
+  AttributeSetter(Operation *op) : op(op) {}
+
+  /// Calls `attachFn` on the provided Operation if it has one of
+  /// the given operation types. Returns failure otherwise.
+  template <typename CallableT>
+  LogicalResult apply(CallableT &&attachFn) {
+    return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+        .Case<OpTys...>([&attachFn](auto concreteOp) {
+          attachFn(concreteOp);
+          return success();
+        })
+        .Default([&](auto) { return failure(); });
+  }
+
+private:
+  Operation *op;
+};
+} // namespace
+
 /// Converts the given profiling metadata `node` to an MLIR profiling attribute
 /// and attaches it to the imported operation if the translation succeeds.
 /// Returns failure otherwise.
@@ -129,16 +153,10 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
     branchWeights.push_back(branchWeight->getZExtValue());
   }
 
-  // Attach the branch weights to the operations that support it.
-  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
-      .Case<CondBrOp, SwitchOp, CallOp, InvokeOp>([&](auto branchWeightOp) {
+  return AttributeSetter<CondBrOp, SwitchOp, CallOp, InvokeOp>(op).apply(
+      [&](auto branchWeightOp) {
         branchWeightOp.setBranchWeightsAttr(
             builder.getI32VectorAttr(branchWeights));
-        return success();
-      })
-      .Default([op](auto) {
-        return op->emitWarning()
-               << op->getName() << " does not support branch weights";
       });
 }
 
@@ -151,9 +169,9 @@ static LogicalResult setTBAAAttr(const llvm::MDNode *node, Operation *op,
   if (!tbaaTagSym)
     return failure();
 
-  op->setAttr(LLVMDialect::getTBAAAttrName(),
-              ArrayAttr::get(op->getContext(), tbaaTagSym));
-  return success();
+  return AttributeSetter<LoadOp, StoreOp>(op).apply([&](auto memOp) {
+    memOp.setTbaaAttr(ArrayAttr::get(memOp.getContext(), tbaaTagSym));
+  });
 }
 
 /// Looks up all the symbol references pointing to the access group operations
@@ -169,9 +187,10 @@ static LogicalResult setAccessGroupAttr(const llvm::MDNode *node, Operation *op,
 
   SmallVector<Attribute> accessGroupAttrs(accessGroups->begin(),
                                           accessGroups->end());
-  op->setAttr(LLVMDialect::getAccessGroupsAttrName(),
-              ArrayAttr::get(op->getContext(), accessGroupAttrs));
-  return success();
+  return AttributeSetter<LoadOp, StoreOp>(op).apply([&](auto memOp) {
+    memOp.setAccessGroupsAttr(
+        ArrayAttr::get(memOp.getContext(), accessGroupAttrs));
+  });
 }
 
 /// Converts the given loop metadata node to an MLIR loop annotation attribute
index 5864b48..b02433b 100644 (file)
@@ -210,7 +210,7 @@ llvm::MDNode *LoopAnnotationConversion::convert() {
         llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
     for (SymbolRefAttr accessGroupRef : parallelAccessGroups)
       parallelAccess.push_back(
-          moduleTranslation.getAccessGroup(*op, accessGroupRef));
+          moduleTranslation.getAccessGroup(op, accessGroupRef));
     metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess));
   }
 
index aa5627e..3834bf0 100644 (file)
@@ -986,12 +986,12 @@ LogicalResult ModuleTranslation::convertFunctions() {
 }
 
 llvm::MDNode *
-ModuleTranslation::getAccessGroup(Operation &opInst,
+ModuleTranslation::getAccessGroup(Operation *op,
                                   SymbolRefAttr accessGroupRef) const {
   auto metadataName = accessGroupRef.getRootReference();
   auto accessGroupName = accessGroupRef.getLeafReference();
   auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
-      opInst.getParentOp(), metadataName);
+      op->getParentOp(), metadataName);
   auto *accessGroupOp =
       SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
   return accessGroupMetadataMapping.lookup(accessGroupOp);
@@ -1010,23 +1010,28 @@ LogicalResult ModuleTranslation::createAccessGroupMetadata() {
 
 void ModuleTranslation::setAccessGroupsMetadata(Operation *op,
                                                 llvm::Instruction *inst) {
-  auto accessGroups =
-      op->getAttrOfType<ArrayAttr>(LLVMDialect::getAccessGroupsAttrName());
-  if (accessGroups && !accessGroups.empty()) {
+  auto populateGroupsMetadata = [&](std::optional<ArrayAttr> groupRefs) {
+    if (!groupRefs || groupRefs->empty())
+      return;
+
     llvm::Module *module = inst->getModule();
-    SmallVector<llvm::Metadata *> metadatas;
-    for (SymbolRefAttr accessGroupRef :
-         accessGroups.getAsRange<SymbolRefAttr>())
-      metadatas.push_back(getAccessGroup(*op, accessGroupRef));
-
-    llvm::MDNode *unionMD = nullptr;
-    if (metadatas.size() == 1)
-      unionMD = llvm::cast<llvm::MDNode>(metadatas.front());
-    else if (metadatas.size() >= 2)
-      unionMD = llvm::MDNode::get(module->getContext(), metadatas);
-
-    inst->setMetadata(module->getMDKindID("llvm.access.group"), unionMD);
-  }
+    SmallVector<llvm::Metadata *> groupMDs;
+    for (SymbolRefAttr groupRef : groupRefs->getAsRange<SymbolRefAttr>())
+      groupMDs.push_back(getAccessGroup(op, groupRef));
+
+    llvm::MDNode *node = nullptr;
+    if (groupMDs.size() == 1)
+      node = llvm::cast<llvm::MDNode>(groupMDs.front());
+    else if (groupMDs.size() >= 2)
+      node = llvm::MDNode::get(module->getContext(), groupMDs);
+
+    inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
+  };
+
+  llvm::TypeSwitch<Operation *>(op)
+      .Case<LoadOp, StoreOp>(
+          [&](auto memOp) { populateGroupsMetadata(memOp.getAccessGroups()); })
+      .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
 }
 
 LogicalResult ModuleTranslation::createAliasScopeMetadata() {
@@ -1067,12 +1072,12 @@ LogicalResult ModuleTranslation::createAliasScopeMetadata() {
 }
 
 llvm::MDNode *
-ModuleTranslation::getAliasScope(Operation &opInst,
+ModuleTranslation::getAliasScope(Operation *op,
                                  SymbolRefAttr aliasScopeRef) const {
   StringAttr metadataName = aliasScopeRef.getRootReference();
   StringAttr scopeName = aliasScopeRef.getLeafReference();
   auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
-      opInst.getParentOp(), metadataName);
+      op->getParentOp(), metadataName);
   Operation *aliasScopeOp =
       SymbolTable::lookupNearestSymbolFrom(metadataOp, scopeName);
   return aliasScopeMetadataMapping.lookup(aliasScopeOp);
@@ -1080,50 +1085,63 @@ ModuleTranslation::getAliasScope(Operation &opInst,
 
 void ModuleTranslation::setAliasScopeMetadata(Operation *op,
                                               llvm::Instruction *inst) {
-  auto populateScopeMetadata = [this, op, inst](StringRef attrName,
-                                                StringRef llvmMetadataName) {
-    auto scopes = op->getAttrOfType<ArrayAttr>(attrName);
-    if (!scopes || scopes.empty())
+  auto populateScopeMetadata = [&](std::optional<ArrayAttr> scopeRefs,
+                                   unsigned kind) {
+    if (!scopeRefs || scopeRefs->empty())
       return;
     llvm::Module *module = inst->getModule();
     SmallVector<llvm::Metadata *> scopeMDs;
-    for (SymbolRefAttr scopeRef : scopes.getAsRange<SymbolRefAttr>())
-      scopeMDs.push_back(getAliasScope(*op, scopeRef));
-    llvm::MDNode *unionMD = llvm::MDNode::get(module->getContext(), scopeMDs);
-    inst->setMetadata(module->getMDKindID(llvmMetadataName), unionMD);
+    for (SymbolRefAttr scopeRef : scopeRefs->getAsRange<SymbolRefAttr>())
+      scopeMDs.push_back(getAliasScope(op, scopeRef));
+    llvm::MDNode *node = llvm::MDNode::get(module->getContext(), scopeMDs);
+    inst->setMetadata(kind, node);
   };
 
-  populateScopeMetadata(LLVMDialect::getAliasScopesAttrName(), "alias.scope");
-  populateScopeMetadata(LLVMDialect::getNoAliasScopesAttrName(), "noalias");
+  llvm::TypeSwitch<Operation *>(op)
+      .Case<LoadOp, StoreOp>([&](auto memOp) {
+        populateScopeMetadata(memOp.getAliasScopes(),
+                              llvm::LLVMContext::MD_alias_scope);
+        populateScopeMetadata(memOp.getNoaliasScopes(),
+                              llvm::LLVMContext::MD_noalias);
+      })
+      .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
 }
 
-llvm::MDNode *ModuleTranslation::getTBAANode(Operation &memOp,
+llvm::MDNode *ModuleTranslation::getTBAANode(Operation *op,
                                              SymbolRefAttr tagRef) const {
   StringAttr metadataName = tagRef.getRootReference();
   StringAttr tagName = tagRef.getLeafReference();
   auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
-      memOp.getParentOp(), metadataName);
+      op->getParentOp(), metadataName);
   Operation *tagOp = SymbolTable::lookupNearestSymbolFrom(metadataOp, tagName);
   return tbaaMetadataMapping.lookup(tagOp);
 }
 
 void ModuleTranslation::setTBAAMetadata(Operation *op,
                                         llvm::Instruction *inst) {
-  auto tbaa = op->getAttrOfType<ArrayAttr>(LLVMDialect::getTBAAAttrName());
-  if (!tbaa || tbaa.empty())
-    return;
-  // LLVM IR currently does not support attaching more than one
-  // TBAA access tag to a memory accessing instruction.
-  // It may be useful to support this in future, but for the time being
-  // just ignore the metadata if MLIR operation has multiple access tags.
-  if (tbaa.size() > 1) {
-    op->emitWarning() << "TBAA access tags were not translated, because LLVM "
-                         "IR only supports a single tag per instruction";
-    return;
-  }
-  SymbolRefAttr tagRef = tbaa[0].cast<SymbolRefAttr>();
-  llvm::MDNode *tagNode = getTBAANode(*op, tagRef);
-  inst->setMetadata(llvm::LLVMContext::MD_tbaa, tagNode);
+  auto populateTBAAMetadata = [&](std::optional<ArrayAttr> tagRefs) {
+    if (!tagRefs || tagRefs->empty())
+      return;
+
+    // LLVM IR currently does not support attaching more than one
+    // TBAA access tag to a memory accessing instruction.
+    // It may be useful to support this in future, but for the time being
+    // just ignore the metadata if MLIR operation has multiple access tags.
+    if (tagRefs->size() > 1) {
+      op->emitWarning() << "TBAA access tags were not translated, because LLVM "
+                           "IR only supports a single tag per instruction";
+      return;
+    }
+
+    SymbolRefAttr tagRef = (*tagRefs)[0].cast<SymbolRefAttr>();
+    llvm::MDNode *node = getTBAANode(op, tagRef);
+    inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
+  };
+
+  llvm::TypeSwitch<Operation *>(op)
+      .Case<LoadOp, StoreOp>(
+          [&](auto memOp) { populateTBAAMetadata(memOp.getTbaa()); })
+      .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
 }
 
 LogicalResult ModuleTranslation::createTBAAMetadata() {
index 0513dc7..a747d59 100644 (file)
@@ -8,7 +8,7 @@ module {
   llvm.func @tbaa(%arg0: !llvm.ptr) {
     %0 = llvm.mlir.constant(1 : i8) : i8
     // expected-error@below {{expected '@tbaa_tag_1' to specify a fully qualified reference}}
-    llvm.store %0, %arg0 {llvm.tbaa = [@tbaa_tag_1]} : i8, !llvm.ptr
+    llvm.store %0, %arg0 {tbaa = [@tbaa_tag_1]} : i8, !llvm.ptr
     llvm.return
   }
 }
@@ -17,8 +17,8 @@ module {
 
 llvm.func @tbaa(%arg0: !llvm.ptr) {
   %0 = llvm.mlir.constant(1 : i8) : i8
-  // expected-error@below {{attribute 'llvm.tbaa' failed to satisfy constraint: symbol ref array attribute}}
-  llvm.store %0, %arg0 {llvm.tbaa = ["sym"]} : i8, !llvm.ptr
+  // expected-error@below {{attribute 'tbaa' failed to satisfy constraint: symbol ref array attribute}}
+  llvm.store %0, %arg0 {tbaa = ["sym"]} : i8, !llvm.ptr
   llvm.return
 }
 
@@ -28,7 +28,7 @@ module {
   llvm.func @tbaa(%arg0: !llvm.ptr) {
     %0 = llvm.mlir.constant(1 : i8) : i8
     // expected-error@below {{expected '@metadata::@group1' to resolve to a llvm.tbaa_tag}}
-    llvm.store %0, %arg0 {llvm.tbaa = [@metadata::@group1]} : i8, !llvm.ptr
+    llvm.store %0, %arg0 {tbaa = [@metadata::@group1]} : i8, !llvm.ptr
     llvm.return
   }
   llvm.metadata @metadata {
@@ -42,7 +42,7 @@ module {
   llvm.func @tbaa(%arg0: !llvm.ptr) {
     %0 = llvm.mlir.constant(1 : i8) : i8
     // expected-error@below {{expected '@metadata::@sym' to be a valid reference}}
-    llvm.store %0, %arg0 {llvm.tbaa = [@metadata::@sym]} : i8, !llvm.ptr
+    llvm.store %0, %arg0 {tbaa = [@metadata::@sym]} : i8, !llvm.ptr
     llvm.return
   }
   llvm.metadata @metadata {
@@ -54,7 +54,7 @@ module {
 llvm.func @tbaa(%arg0: !llvm.ptr) {
   %0 = llvm.mlir.constant(1 : i8) : i8
   // expected-error@below {{expected '@tbaa::@sym' to reference a metadata op}}
-  llvm.store %0, %arg0 {llvm.tbaa = [@tbaa::@sym]} : i8, !llvm.ptr
+  llvm.store %0, %arg0 {tbaa = [@tbaa::@sym]} : i8, !llvm.ptr
   llvm.return
 }
 
index 12a605a..3286d3b 100644 (file)
@@ -566,8 +566,6 @@ bb2:
 
 ; // -----
 
-; CHECK:      import-failure.ll
-; CHECK-SAME: warning: llvm.func does not support branch weights
 ; CHECK:      import-failure.ll:{{.*}} warning: unhandled function metadata: !0 = !{!"branch_weights", i32 64}
 define void @cond_br(i1 %arg) !prof !0 {
   ret void
index 26a96e4..84f27be 100644 (file)
@@ -16,11 +16,11 @@ module {
     %1 = llvm.mlir.constant(1 : i32) : i32
     %2 = llvm.getelementptr inbounds %arg1[%0, 1] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg2_t", (i64, i64)>
     // CHECK: load i64, ptr %{{.*}},{{.*}}!tbaa ![[LTAG:[0-9]*]]
-    %3 = llvm.load %2 {llvm.tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> i64
+    %3 = llvm.load %2 {tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> i64
     %4 = llvm.trunc %3 : i64 to i32
     %5 = llvm.getelementptr inbounds %arg0[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg1_t", (i32, i32)>
     // CHECK: store i32 %{{.*}}, ptr %{{.*}},{{.*}}!tbaa ![[STAG:[0-9]*]]
-    llvm.store %4, %5 {llvm.tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr
+    llvm.store %4, %5 {tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr
     llvm.return
   }
 }
@@ -60,11 +60,11 @@ module {
     %1 = llvm.mlir.constant(1 : i32) : i32
     %2 = llvm.getelementptr inbounds %arg1[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg2_t", (f32, f32)>
     // CHECK: load float, ptr %{{.*}},{{.*}}!tbaa ![[LTAG:[0-9]*]]
-    %3 = llvm.load %2 {llvm.tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> f32
+    %3 = llvm.load %2 {tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> f32
     %4 = llvm.fptosi %3 : f32 to i32
     %5 = llvm.getelementptr inbounds %arg0[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg1_t", (i32, i32)>
     // CHECK: store i32 %{{.*}}, ptr %{{.*}},{{.*}}!tbaa ![[STAG:[0-9]*]]
-    llvm.store %4, %5 {llvm.tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr
+    llvm.store %4, %5 {tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr
     llvm.return
   }
 }