#include "TBAABuilder.h"
#include "flang/Optimizer/Dialect/FIRType.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
else
tbaaTagSym = getDataAccessTag(baseFIRType, accessFIRType, gep);
- if (tbaaTagSym)
- op->setAttr(LLVMDialect::getTBAAAttrName(),
- ArrayAttr::get(op->getContext(), tbaaTagSym));
+ if (!tbaaTagSym)
+ return;
+
+ auto tbaaAttr = ArrayAttr::get(op->getContext(), tbaaTagSym);
+ llvm::TypeSwitch<Operation *>(op)
+ .Case<LoadOp, StoreOp>([&](auto memOp) { memOp.setTbaaAttr(tbaaAttr); })
+ .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
}
} // namespace fir
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(10 : i32) : i32
// CHECK: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]][0, 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<struct<()>>>
-// CHECK: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<ptr<struct<()>>>
+// CHECK: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<ptr<struct<()>>>
// CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[VAL_9:.*]] = llvm.getelementptr %[[VAL_0]][0, 7, 0, 2] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK: %[[VAL_10:.*]] = llvm.load %[[VAL_9]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK: %[[VAL_10:.*]] = llvm.load %[[VAL_9]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
// CHECK: %[[VAL_11:.*]] = llvm.mul %[[VAL_4]], %[[VAL_10]] : i64
// CHECK: %[[VAL_12:.*]] = llvm.add %[[VAL_11]], %[[VAL_8]] : i64
// CHECK: %[[VAL_13:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
// CHECK: %[[VAL_16:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[VAL_17:.*]] = llvm.mlir.constant(-1 : i32) : i32
// CHECK: %[[VAL_18:.*]] = llvm.getelementptr %[[VAL_0]][0, 8] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_20:.*]] = llvm.getelementptr %[[VAL_0]][0, 1] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK: %[[VAL_21:.*]] = llvm.load %[[VAL_20]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK: %[[VAL_21:.*]] = llvm.load %[[VAL_20]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
// CHECK: %[[VAL_22:.*]] = llvm.getelementptr %[[VAL_0]][0, 4] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i32>
-// CHECK: %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i32>
+// CHECK: %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i32>
// CHECK: %[[VAL_24:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>
// CHECK: %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_21]], %[[VAL_24]][1] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>
// CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(20180515 : i32) : i32
// CHECK: %[[VAL_40:.*]] = llvm.insertvalue %[[VAL_39]], %[[VAL_38]][7] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>
// CHECK: %[[VAL_41:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<struct<()>> to !llvm.ptr<struct<()>>
// CHECK: %[[VAL_42:.*]] = llvm.insertvalue %[[VAL_41]], %[[VAL_40]][0] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>
-// CHECK: llvm.store %[[VAL_42]], %[[VAL_2]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>
+// CHECK: llvm.store %[[VAL_42]], %[[VAL_2]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>
// CHECK: %[[VAL_43:.*]] = llvm.getelementptr %[[VAL_2]][0, 4] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_44:.*]] = llvm.load %[[VAL_43]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i8>
+// CHECK: %[[VAL_44:.*]] = llvm.load %[[VAL_43]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i8>
// CHECK: %[[VAL_45:.*]] = llvm.icmp "eq" %[[VAL_44]], %[[VAL_3]] : i8
// CHECK: llvm.cond_br %[[VAL_45]], ^bb1, ^bb2
// CHECK: ^bb1:
// CHECK: %[[VAL_46:.*]] = llvm.getelementptr %[[VAL_2]][0, 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<i32>>
-// CHECK: %[[VAL_47:.*]] = llvm.load %[[VAL_46]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i32>>
-// CHECK: llvm.store %[[VAL_5]], %[[VAL_47]] {llvm.tbaa = [@__flang_tbaa::@[[DATAT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK: %[[VAL_47:.*]] = llvm.load %[[VAL_46]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i32>>
+// CHECK: llvm.store %[[VAL_5]], %[[VAL_47]] {tbaa = [@__flang_tbaa::@[[DATAT:tag_[0-9]*]]]} : !llvm.ptr<i32>
// CHECK: llvm.br ^bb2
// CHECK: ^bb2:
// CHECK: llvm.return
// CHECK: %[[VAL_8:.*]] = llvm.mlir.addressof @_QQcl.2E2F64756D6D792E66393000 : !llvm.ptr<array<12 x i8>>
// CHECK: %[[VAL_9:.*]] = llvm.bitcast %[[VAL_8]] : !llvm.ptr<array<12 x i8>> to !llvm.ptr<i8>
// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_9]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr<i8>, i32) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_7]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
-// CHECK: llvm.store %[[VAL_11]], %[[VAL_3]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
+// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_7]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
+// CHECK: llvm.store %[[VAL_11]], %[[VAL_3]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
// CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>, i64) -> !llvm.ptr<i64>
-// CHECK: %[[VAL_13:.*]] = llvm.load %[[VAL_12]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK: %[[VAL_13:.*]] = llvm.load %[[VAL_12]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 1] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>, i64) -> !llvm.ptr<i64>
-// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
// CHECK: %[[VAL_16:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 2] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>, i64) -> !llvm.ptr<i64>
-// CHECK: %[[VAL_17:.*]] = llvm.load %[[VAL_16]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK: %[[VAL_17:.*]] = llvm.load %[[VAL_16]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
// CHECK: %[[VAL_18:.*]] = llvm.getelementptr %[[VAL_3]][0, 8] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[VAL_21:.*]] = llvm.mlir.constant(-1 : i32) : i32
// CHECK: %[[VAL_22:.*]] = llvm.getelementptr %[[VAL_3]][0, 1] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK: %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK: %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
// CHECK: %[[VAL_24:.*]] = llvm.getelementptr %[[VAL_3]][0, 4] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i32>
-// CHECK: %[[VAL_25:.*]] = llvm.load %[[VAL_24]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i32>
+// CHECK: %[[VAL_25:.*]] = llvm.load %[[VAL_24]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i32>
// CHECK: %[[VAL_26:.*]] = llvm.getelementptr %[[VAL_3]][0, 8] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_27:.*]] = llvm.load %[[VAL_26]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_27:.*]] = llvm.load %[[VAL_26]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_28:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
// CHECK: %[[VAL_29:.*]] = llvm.insertvalue %[[VAL_23]], %[[VAL_28]][1] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
// CHECK: %[[VAL_30:.*]] = llvm.mlir.constant(20180515 : i32) : i32
// CHECK: %[[VAL_43:.*]] = llvm.bitcast %[[VAL_27]] : !llvm.ptr<i8> to !llvm.ptr<i8>
// CHECK: %[[VAL_44:.*]] = llvm.insertvalue %[[VAL_43]], %[[VAL_42]][8] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
// CHECK: %[[VAL_45:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, 0, 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK: %[[VAL_46:.*]] = llvm.load %[[VAL_45]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK: %[[VAL_46:.*]] = llvm.load %[[VAL_45]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
// CHECK: %[[VAL_47:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, 0, 1] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK: %[[VAL_48:.*]] = llvm.load %[[VAL_47]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK: %[[VAL_48:.*]] = llvm.load %[[VAL_47]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
// CHECK: %[[VAL_49:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, 0, 2] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK: %[[VAL_50:.*]] = llvm.load %[[VAL_49]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK: %[[VAL_50:.*]] = llvm.load %[[VAL_49]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
// CHECK: %[[VAL_51:.*]] = llvm.getelementptr %[[VAL_3]][0, 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<struct<()>>>
-// CHECK: %[[VAL_52:.*]] = llvm.load %[[VAL_51]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<struct<()>>>
+// CHECK: %[[VAL_52:.*]] = llvm.load %[[VAL_51]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<struct<()>>>
// CHECK: %[[VAL_53:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[VAL_54:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_55:.*]] = llvm.icmp "eq" %[[VAL_48]], %[[VAL_53]] : i64
// CHECK: %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
// CHECK: %[[VAL_60:.*]] = llvm.bitcast %[[VAL_52]] : !llvm.ptr<struct<()>> to !llvm.ptr<struct<()>>
// CHECK: %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_60]], %[[VAL_59]][0] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
-// CHECK: llvm.store %[[VAL_61]], %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
+// CHECK: llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
// CHECK: %[[VAL_62:.*]] = llvm.bitcast %[[VAL_1]] : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>> to !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>
// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_62]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr<i8>, !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>) -> i1
// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr<i8>) -> i32
// CHECK-LABEL: llvm.func @tbaa(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> i32 {
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<i32>
-// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
// CHECK: llvm.return %[[VAL_2]] : i32
// CHECK: }
// CHECK-LABEL: llvm.func @tbaa(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> i1 {
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<i32>
-// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_4:.*]] = llvm.icmp "ne" %[[VAL_2]], %[[VAL_3]] : i32
// CHECK: llvm.return %[[VAL_4]] : i1
// CHECK-LABEL: llvm.func @tbaa(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<f32>, i64, i32, i8, i8, i8, i8)>>) -> i32 {
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 1] : (!llvm.ptr<struct<(ptr<f32>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<i32>
-// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
// CHECK: llvm.return %[[VAL_2]] : i32
// CHECK: }
// CHECK-LABEL: llvm.func @tbaa(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> i1 {
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<i32>
-// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[VAL_4:.*]] = llvm.and %[[VAL_2]], %[[VAL_3]] : i32
// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_4:.*]] = llvm.sub %[[VAL_1]], %[[VAL_2]] : i64
// CHECK: %[[VAL_5:.*]] = llvm.mul %[[VAL_4]], %[[VAL_2]] : i64
// CHECK: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]][0, 7, 0, 2] : (!llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>>) -> !llvm.ptr<i64>
-// CHECK: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i64>
+// CHECK: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i64>
// CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_5]], %[[VAL_7]] : i64
// CHECK: %[[VAL_9:.*]] = llvm.add %[[VAL_8]], %[[VAL_3]] : i64
// CHECK: %[[VAL_10:.*]] = llvm.getelementptr %[[VAL_0]][0, 0] : (!llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>>) -> !llvm.ptr<ptr<i32>>
-// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i32>>
+// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i32>>
// CHECK: %[[VAL_12:.*]] = llvm.bitcast %[[VAL_11]] : !llvm.ptr<i32> to !llvm.ptr<i8>
// CHECK: %[[VAL_13:.*]] = llvm.getelementptr %[[VAL_12]]{{\[}}%[[VAL_9]]] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
// CHECK: %[[VAL_14:.*]] = llvm.bitcast %[[VAL_13]] : !llvm.ptr<i8> to !llvm.ptr<i32>
let extraClassDeclaration = [{
/// Name of the data layout attributes.
static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; }
- static StringRef getNoAliasScopesAttrName() { return "noalias_scopes"; }
- static StringRef getAliasScopesAttrName() { return "alias_scopes"; }
static StringRef getLoopAttrName() { return "llvm.loop"; }
- static StringRef getAccessGroupsAttrName() { return "access_groups"; }
- static StringRef getTBAAAttrName() { return "llvm.tbaa"; }
/// Names of llvm parameter attributes.
static StringRef getAlignAttrName() { return "llvm.align"; }
OptionalAttr<SymbolRefArrayAttr>:$access_groups,
OptionalAttr<SymbolRefArrayAttr>:$alias_scopes,
OptionalAttr<SymbolRefArrayAttr>:$noalias_scopes,
+ OptionalAttr<SymbolRefArrayAttr>:$tbaa,
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
UnitAttr:$nontemporal);
let results = (outs LLVM_LoadableType:$res);
OptionalAttr<SymbolRefArrayAttr>:$access_groups,
OptionalAttr<SymbolRefArrayAttr>:$alias_scopes,
OptionalAttr<SymbolRefArrayAttr>:$noalias_scopes,
+ OptionalAttr<SymbolRefArrayAttr>:$tbaa,
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
UnitAttr:$nontemporal);
string llvmInstName = "Store";
/// in these blocks.
void forgetMapping(Region ®ion);
- /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
- /// dialect access group operation.
- llvm::MDNode *getAccessGroup(Operation &opInst,
+ /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
+ /// LLVM dialect access group operation.
+ llvm::MDNode *getAccessGroup(Operation *op,
SymbolRefAttr accessGroupRef) const;
- /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
- /// dialect alias scope operation
- llvm::MDNode *getAliasScope(Operation &opInst,
- SymbolRefAttr aliasScopeRef) const;
+ /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
+ /// LLVM dialect alias scope operation
+ llvm::MDNode *getAliasScope(Operation *op, SymbolRefAttr aliasScopeRef) const;
// Sets LLVM metadata for memory operations that are in a parallel loop.
void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst);
/// metadata nodes for them and their domains.
LogicalResult createAliasScopeMetadata();
- /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
- /// dialect TBAATagOp operation.
- llvm::MDNode *getTBAANode(Operation &memOp, SymbolRefAttr tagRef) const;
+ /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
+ /// LLVM dialect TBAATagOp operation.
+ llvm::MDNode *getTBAANode(Operation *op, SymbolRefAttr tagRef) const;
/// Process tbaa LLVM Metadata operations and create LLVM
/// metadata nodes for them.
// Builder, printer and parser for for LLVM::LoadOp.
//===----------------------------------------------------------------------===//
-LogicalResult verifySymbolAttribute(
- Operation *op, StringRef attributeName,
+/// Verifies the given array attribute contains symbol references and checks the
+/// referenced symbol types using the provided verification function.
+LogicalResult verifyMemOpSymbolRefs(
+ Operation *op, StringRef name, ArrayAttr symbolRefs,
llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)>
verifySymbolType) {
- if (Attribute attribute = op->getAttr(attributeName)) {
- // Verify that the attribute is a symbol ref array attribute,
- // because this constraint is not verified for all attribute
- // names processed here (e.g. 'tbaa'). This verification
- // is redundant in some cases.
- if (!(attribute.isa<ArrayAttr>() &&
- llvm::all_of(attribute.cast<ArrayAttr>(), [&](Attribute attr) {
- return attr && attr.isa<SymbolRefAttr>();
- })))
- return op->emitOpError("attribute '")
- << attributeName
- << "' failed to satisfy constraint: symbol ref array attribute";
-
- for (SymbolRefAttr symbolRef :
- attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
- StringAttr metadataName = symbolRef.getRootReference();
- StringAttr symbolName = symbolRef.getLeafReference();
- // We want @metadata::@symbol, not just @symbol
- if (metadataName == symbolName) {
- return op->emitOpError() << "expected '" << symbolRef
- << "' to specify a fully qualified reference";
- }
- auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
- op->getParentOp(), metadataName);
- if (!metadataOp)
- return op->emitOpError()
- << "expected '" << symbolRef << "' to reference a metadata op";
- Operation *symbolOp =
- SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
- if (!symbolOp)
- return op->emitOpError()
- << "expected '" << symbolRef << "' to be a valid reference";
- if (failed(verifySymbolType(symbolOp, symbolRef))) {
- return failure();
- }
+ assert(symbolRefs && "expected a non-null attribute");
+
+ // Verify that the attribute is a symbol ref array attribute,
+ // because this constraint is not verified for all attribute
+ // names processed here (e.g. 'tbaa'). This verification
+ // is redundant in some cases.
+ if (!llvm::all_of(symbolRefs, [](Attribute attr) {
+ return attr && attr.isa<SymbolRefAttr>();
+ }))
+ return op->emitOpError("attribute '")
+ << name
+ << "' failed to satisfy constraint: symbol ref array attribute";
+
+ for (SymbolRefAttr symbolRef : symbolRefs.getAsRange<SymbolRefAttr>()) {
+ StringAttr metadataName = symbolRef.getRootReference();
+ StringAttr symbolName = symbolRef.getLeafReference();
+ // We want @metadata::@symbol, not just @symbol
+ if (metadataName == symbolName) {
+ return op->emitOpError() << "expected '" << symbolRef
+ << "' to specify a fully qualified reference";
+ }
+ auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
+ op->getParentOp(), metadataName);
+ if (!metadataOp)
+ return op->emitOpError()
+ << "expected '" << symbolRef << "' to reference a metadata op";
+ Operation *symbolOp =
+ SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
+ if (!symbolOp)
+ return op->emitOpError()
+ << "expected '" << symbolRef << "' to be a valid reference";
+ if (failed(verifySymbolType(symbolOp, symbolRef))) {
+ return failure();
}
}
+
return success();
}
-// Verifies that metadata ops are wired up properly.
+/// Verifies the given array attribute contains symbol references that point to
+/// metadata operations of the given type.
template <typename OpTy>
-static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) {
+static LogicalResult
+verifyMemOpSymbolRefsPointTo(Operation *op, StringRef name,
+ std::optional<ArrayAttr> symbolRefs) {
+ if (!symbolRefs)
+ return success();
+
auto verifySymbolType = [op](Operation *symbolOp,
SymbolRefAttr symbolRef) -> LogicalResult {
if (!isa<OpTy>(symbolOp)) {
}
return success();
};
-
- return verifySymbolAttribute(op, attributeName, verifySymbolType);
+ return verifyMemOpSymbolRefs(op, name, *symbolRefs, verifySymbolType);
}
-static LogicalResult verifyMemoryOpMetadata(Operation *op) {
- // access_groups
- if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>(
- op, LLVMDialect::getAccessGroupsAttrName())))
+/// Verifies the types of the metadata operations referenced by aliasing and
+/// access group metadata.
+template <typename OpTy>
+LogicalResult verifyMemOpMetadata(OpTy memOp) {
+ if (failed(verifyMemOpSymbolRefsPointTo<LLVM::AccessGroupMetadataOp>(
+ memOp, memOp.getAccessGroupsAttrName(), memOp.getAccessGroups())))
return failure();
- // alias_scopes
- if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
- op, LLVMDialect::getAliasScopesAttrName())))
+ if (failed(verifyMemOpSymbolRefsPointTo<LLVM::AliasScopeMetadataOp>(
+ memOp, memOp.getAliasScopesAttrName(), memOp.getAliasScopes())))
return failure();
- // noalias_scopes
- if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
- op, LLVMDialect::getNoAliasScopesAttrName())))
+ if (failed(verifyMemOpSymbolRefsPointTo<LLVM::AliasScopeMetadataOp>(
+ memOp, memOp.getNoaliasScopesAttrName(), memOp.getNoaliasScopes())))
return failure();
- // tbaa
- if (failed(verifyOpMetadata<LLVM::TBAATagOp>(op,
- LLVMDialect::getTBAAAttrName())))
+ if (failed(verifyMemOpSymbolRefsPointTo<LLVM::TBAATagOp>(
+ memOp, memOp.getTbaaAttrName(), memOp.getTbaa())))
return failure();
return success();
}
-LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); }
+LogicalResult LoadOp::verify() { return verifyMemOpMetadata(*this); }
void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
Value addr, unsigned alignment, bool isVolatile,
// Builder, printer and parser for LLVM::StoreOp.
//===----------------------------------------------------------------------===//
-LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); }
+LogicalResult StoreOp::verify() { return verifyMemOpMetadata(*this); }
void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
Value addr, unsigned alignment, bool isVolatile,
return convertibleMetadata;
}
+namespace {
+/// Helper class to attach metadata attributes to specific operation types. It
+/// specializes TypeSwitch to take an Operation and return a LogicalResult.
+template <typename... OpTys>
+struct AttributeSetter {
+ AttributeSetter(Operation *op) : op(op) {}
+
+ /// Calls `attachFn` on the provided Operation if it has one of
+ /// the given operation types. Returns failure otherwise.
+ template <typename CallableT>
+ LogicalResult apply(CallableT &&attachFn) {
+ return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+ .Case<OpTys...>([&attachFn](auto concreteOp) {
+ attachFn(concreteOp);
+ return success();
+ })
+ .Default([&](auto) { return failure(); });
+ }
+
+private:
+ Operation *op;
+};
+} // namespace
+
/// Converts the given profiling metadata `node` to an MLIR profiling attribute
/// and attaches it to the imported operation if the translation succeeds.
/// Returns failure otherwise.
branchWeights.push_back(branchWeight->getZExtValue());
}
- // Attach the branch weights to the operations that support it.
- return llvm::TypeSwitch<Operation *, LogicalResult>(op)
- .Case<CondBrOp, SwitchOp, CallOp, InvokeOp>([&](auto branchWeightOp) {
+ return AttributeSetter<CondBrOp, SwitchOp, CallOp, InvokeOp>(op).apply(
+ [&](auto branchWeightOp) {
branchWeightOp.setBranchWeightsAttr(
builder.getI32VectorAttr(branchWeights));
- return success();
- })
- .Default([op](auto) {
- return op->emitWarning()
- << op->getName() << " does not support branch weights";
});
}
if (!tbaaTagSym)
return failure();
- op->setAttr(LLVMDialect::getTBAAAttrName(),
- ArrayAttr::get(op->getContext(), tbaaTagSym));
- return success();
+ return AttributeSetter<LoadOp, StoreOp>(op).apply([&](auto memOp) {
+ memOp.setTbaaAttr(ArrayAttr::get(memOp.getContext(), tbaaTagSym));
+ });
}
/// Looks up all the symbol references pointing to the access group operations
SmallVector<Attribute> accessGroupAttrs(accessGroups->begin(),
accessGroups->end());
- op->setAttr(LLVMDialect::getAccessGroupsAttrName(),
- ArrayAttr::get(op->getContext(), accessGroupAttrs));
- return success();
+ return AttributeSetter<LoadOp, StoreOp>(op).apply([&](auto memOp) {
+ memOp.setAccessGroupsAttr(
+ ArrayAttr::get(memOp.getContext(), accessGroupAttrs));
+ });
}
/// Converts the given loop metadata node to an MLIR loop annotation attribute
llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
for (SymbolRefAttr accessGroupRef : parallelAccessGroups)
parallelAccess.push_back(
- moduleTranslation.getAccessGroup(*op, accessGroupRef));
+ moduleTranslation.getAccessGroup(op, accessGroupRef));
metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess));
}
}
llvm::MDNode *
-ModuleTranslation::getAccessGroup(Operation &opInst,
+ModuleTranslation::getAccessGroup(Operation *op,
SymbolRefAttr accessGroupRef) const {
auto metadataName = accessGroupRef.getRootReference();
auto accessGroupName = accessGroupRef.getLeafReference();
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
- opInst.getParentOp(), metadataName);
+ op->getParentOp(), metadataName);
auto *accessGroupOp =
SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
return accessGroupMetadataMapping.lookup(accessGroupOp);
void ModuleTranslation::setAccessGroupsMetadata(Operation *op,
llvm::Instruction *inst) {
- auto accessGroups =
- op->getAttrOfType<ArrayAttr>(LLVMDialect::getAccessGroupsAttrName());
- if (accessGroups && !accessGroups.empty()) {
+ auto populateGroupsMetadata = [&](std::optional<ArrayAttr> groupRefs) {
+ if (!groupRefs || groupRefs->empty())
+ return;
+
llvm::Module *module = inst->getModule();
- SmallVector<llvm::Metadata *> metadatas;
- for (SymbolRefAttr accessGroupRef :
- accessGroups.getAsRange<SymbolRefAttr>())
- metadatas.push_back(getAccessGroup(*op, accessGroupRef));
-
- llvm::MDNode *unionMD = nullptr;
- if (metadatas.size() == 1)
- unionMD = llvm::cast<llvm::MDNode>(metadatas.front());
- else if (metadatas.size() >= 2)
- unionMD = llvm::MDNode::get(module->getContext(), metadatas);
-
- inst->setMetadata(module->getMDKindID("llvm.access.group"), unionMD);
- }
+ SmallVector<llvm::Metadata *> groupMDs;
+ for (SymbolRefAttr groupRef : groupRefs->getAsRange<SymbolRefAttr>())
+ groupMDs.push_back(getAccessGroup(op, groupRef));
+
+ llvm::MDNode *node = nullptr;
+ if (groupMDs.size() == 1)
+ node = llvm::cast<llvm::MDNode>(groupMDs.front());
+ else if (groupMDs.size() >= 2)
+ node = llvm::MDNode::get(module->getContext(), groupMDs);
+
+ inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
+ };
+
+ llvm::TypeSwitch<Operation *>(op)
+ .Case<LoadOp, StoreOp>(
+ [&](auto memOp) { populateGroupsMetadata(memOp.getAccessGroups()); })
+ .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
}
LogicalResult ModuleTranslation::createAliasScopeMetadata() {
}
llvm::MDNode *
-ModuleTranslation::getAliasScope(Operation &opInst,
+ModuleTranslation::getAliasScope(Operation *op,
SymbolRefAttr aliasScopeRef) const {
StringAttr metadataName = aliasScopeRef.getRootReference();
StringAttr scopeName = aliasScopeRef.getLeafReference();
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
- opInst.getParentOp(), metadataName);
+ op->getParentOp(), metadataName);
Operation *aliasScopeOp =
SymbolTable::lookupNearestSymbolFrom(metadataOp, scopeName);
return aliasScopeMetadataMapping.lookup(aliasScopeOp);
void ModuleTranslation::setAliasScopeMetadata(Operation *op,
llvm::Instruction *inst) {
- auto populateScopeMetadata = [this, op, inst](StringRef attrName,
- StringRef llvmMetadataName) {
- auto scopes = op->getAttrOfType<ArrayAttr>(attrName);
- if (!scopes || scopes.empty())
+ auto populateScopeMetadata = [&](std::optional<ArrayAttr> scopeRefs,
+ unsigned kind) {
+ if (!scopeRefs || scopeRefs->empty())
return;
llvm::Module *module = inst->getModule();
SmallVector<llvm::Metadata *> scopeMDs;
- for (SymbolRefAttr scopeRef : scopes.getAsRange<SymbolRefAttr>())
- scopeMDs.push_back(getAliasScope(*op, scopeRef));
- llvm::MDNode *unionMD = llvm::MDNode::get(module->getContext(), scopeMDs);
- inst->setMetadata(module->getMDKindID(llvmMetadataName), unionMD);
+ for (SymbolRefAttr scopeRef : scopeRefs->getAsRange<SymbolRefAttr>())
+ scopeMDs.push_back(getAliasScope(op, scopeRef));
+ llvm::MDNode *node = llvm::MDNode::get(module->getContext(), scopeMDs);
+ inst->setMetadata(kind, node);
};
- populateScopeMetadata(LLVMDialect::getAliasScopesAttrName(), "alias.scope");
- populateScopeMetadata(LLVMDialect::getNoAliasScopesAttrName(), "noalias");
+ llvm::TypeSwitch<Operation *>(op)
+ .Case<LoadOp, StoreOp>([&](auto memOp) {
+ populateScopeMetadata(memOp.getAliasScopes(),
+ llvm::LLVMContext::MD_alias_scope);
+ populateScopeMetadata(memOp.getNoaliasScopes(),
+ llvm::LLVMContext::MD_noalias);
+ })
+ .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
}
-llvm::MDNode *ModuleTranslation::getTBAANode(Operation &memOp,
+llvm::MDNode *ModuleTranslation::getTBAANode(Operation *op,
SymbolRefAttr tagRef) const {
StringAttr metadataName = tagRef.getRootReference();
StringAttr tagName = tagRef.getLeafReference();
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
- memOp.getParentOp(), metadataName);
+ op->getParentOp(), metadataName);
Operation *tagOp = SymbolTable::lookupNearestSymbolFrom(metadataOp, tagName);
return tbaaMetadataMapping.lookup(tagOp);
}
void ModuleTranslation::setTBAAMetadata(Operation *op,
llvm::Instruction *inst) {
- auto tbaa = op->getAttrOfType<ArrayAttr>(LLVMDialect::getTBAAAttrName());
- if (!tbaa || tbaa.empty())
- return;
- // LLVM IR currently does not support attaching more than one
- // TBAA access tag to a memory accessing instruction.
- // It may be useful to support this in future, but for the time being
- // just ignore the metadata if MLIR operation has multiple access tags.
- if (tbaa.size() > 1) {
- op->emitWarning() << "TBAA access tags were not translated, because LLVM "
- "IR only supports a single tag per instruction";
- return;
- }
- SymbolRefAttr tagRef = tbaa[0].cast<SymbolRefAttr>();
- llvm::MDNode *tagNode = getTBAANode(*op, tagRef);
- inst->setMetadata(llvm::LLVMContext::MD_tbaa, tagNode);
+ auto populateTBAAMetadata = [&](std::optional<ArrayAttr> tagRefs) {
+ if (!tagRefs || tagRefs->empty())
+ return;
+
+ // LLVM IR currently does not support attaching more than one
+ // TBAA access tag to a memory accessing instruction.
+ // It may be useful to support this in future, but for the time being
+ // just ignore the metadata if MLIR operation has multiple access tags.
+ if (tagRefs->size() > 1) {
+ op->emitWarning() << "TBAA access tags were not translated, because LLVM "
+ "IR only supports a single tag per instruction";
+ return;
+ }
+
+ SymbolRefAttr tagRef = (*tagRefs)[0].cast<SymbolRefAttr>();
+ llvm::MDNode *node = getTBAANode(op, tagRef);
+ inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
+ };
+
+ llvm::TypeSwitch<Operation *>(op)
+ .Case<LoadOp, StoreOp>(
+ [&](auto memOp) { populateTBAAMetadata(memOp.getTbaa()); })
+ .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
}
LogicalResult ModuleTranslation::createTBAAMetadata() {
llvm.func @tbaa(%arg0: !llvm.ptr) {
%0 = llvm.mlir.constant(1 : i8) : i8
// expected-error@below {{expected '@tbaa_tag_1' to specify a fully qualified reference}}
- llvm.store %0, %arg0 {llvm.tbaa = [@tbaa_tag_1]} : i8, !llvm.ptr
+ llvm.store %0, %arg0 {tbaa = [@tbaa_tag_1]} : i8, !llvm.ptr
llvm.return
}
}
llvm.func @tbaa(%arg0: !llvm.ptr) {
%0 = llvm.mlir.constant(1 : i8) : i8
- // expected-error@below {{attribute 'llvm.tbaa' failed to satisfy constraint: symbol ref array attribute}}
- llvm.store %0, %arg0 {llvm.tbaa = ["sym"]} : i8, !llvm.ptr
+ // expected-error@below {{attribute 'tbaa' failed to satisfy constraint: symbol ref array attribute}}
+ llvm.store %0, %arg0 {tbaa = ["sym"]} : i8, !llvm.ptr
llvm.return
}
llvm.func @tbaa(%arg0: !llvm.ptr) {
%0 = llvm.mlir.constant(1 : i8) : i8
// expected-error@below {{expected '@metadata::@group1' to resolve to a llvm.tbaa_tag}}
- llvm.store %0, %arg0 {llvm.tbaa = [@metadata::@group1]} : i8, !llvm.ptr
+ llvm.store %0, %arg0 {tbaa = [@metadata::@group1]} : i8, !llvm.ptr
llvm.return
}
llvm.metadata @metadata {
llvm.func @tbaa(%arg0: !llvm.ptr) {
%0 = llvm.mlir.constant(1 : i8) : i8
// expected-error@below {{expected '@metadata::@sym' to be a valid reference}}
- llvm.store %0, %arg0 {llvm.tbaa = [@metadata::@sym]} : i8, !llvm.ptr
+ llvm.store %0, %arg0 {tbaa = [@metadata::@sym]} : i8, !llvm.ptr
llvm.return
}
llvm.metadata @metadata {
llvm.func @tbaa(%arg0: !llvm.ptr) {
%0 = llvm.mlir.constant(1 : i8) : i8
// expected-error@below {{expected '@tbaa::@sym' to reference a metadata op}}
- llvm.store %0, %arg0 {llvm.tbaa = [@tbaa::@sym]} : i8, !llvm.ptr
+ llvm.store %0, %arg0 {tbaa = [@tbaa::@sym]} : i8, !llvm.ptr
llvm.return
}
; // -----
-; CHECK: import-failure.ll
-; CHECK-SAME: warning: llvm.func does not support branch weights
; CHECK: import-failure.ll:{{.*}} warning: unhandled function metadata: !0 = !{!"branch_weights", i32 64}
define void @cond_br(i1 %arg) !prof !0 {
ret void
%1 = llvm.mlir.constant(1 : i32) : i32
%2 = llvm.getelementptr inbounds %arg1[%0, 1] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg2_t", (i64, i64)>
// CHECK: load i64, ptr %{{.*}},{{.*}}!tbaa ![[LTAG:[0-9]*]]
- %3 = llvm.load %2 {llvm.tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> i64
+ %3 = llvm.load %2 {tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> i64
%4 = llvm.trunc %3 : i64 to i32
%5 = llvm.getelementptr inbounds %arg0[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg1_t", (i32, i32)>
// CHECK: store i32 %{{.*}}, ptr %{{.*}},{{.*}}!tbaa ![[STAG:[0-9]*]]
- llvm.store %4, %5 {llvm.tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr
+ llvm.store %4, %5 {tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr
llvm.return
}
}
%1 = llvm.mlir.constant(1 : i32) : i32
%2 = llvm.getelementptr inbounds %arg1[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg2_t", (f32, f32)>
// CHECK: load float, ptr %{{.*}},{{.*}}!tbaa ![[LTAG:[0-9]*]]
- %3 = llvm.load %2 {llvm.tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> f32
+ %3 = llvm.load %2 {tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> f32
%4 = llvm.fptosi %3 : f32 to i32
%5 = llvm.getelementptr inbounds %arg0[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg1_t", (i32, i32)>
// CHECK: store i32 %{{.*}}, ptr %{{.*}},{{.*}}!tbaa ![[STAG:[0-9]*]]
- llvm.store %4, %5 {llvm.tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr
+ llvm.store %4, %5 {tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr
llvm.return
}
}