[mlir][llvm] Improve LoadOp and StoreOp import.
authorTobias Gysi <tobias.gysi@nextsilicon.com>
Mon, 13 Feb 2023 07:12:09 +0000 (08:12 +0100)
committerTobias Gysi <tobias.gysi@nextsilicon.com>
Mon, 13 Feb 2023 07:42:43 +0000 (08:42 +0100)
The revision supports importing the volatile keyword and nontemporal
metadata for the LoadOp and StoreOp. Additionally, it updates the
builders and uses an assembly format for printing and parsing.

The operation type still requires custom parse and print methods
due to the current handling of typed and opaque pointers.

Reviewed By: Dinistro

Differential Revision: https://reviews.llvm.org/D143714

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Target/LLVMIR/Import/instructions.ll
mlir/test/Target/LLVMIR/Import/metadata-loop.ll

index 8e4b834..d79e511 100644 (file)
@@ -189,11 +189,10 @@ class MemoryOpBase {
   }];
   code setNonTemporalMetadataCode = [{
     if ($nontemporal) {
-      llvm::Module *module = builder.GetInsertBlock()->getModule();
       llvm::MDNode *metadata = llvm::MDNode::get(
           inst->getContext(), llvm::ConstantAsMetadata::get(
               builder.getInt32(1)));
-      inst->setMetadata(module->getMDKindID("nontemporal"), metadata);
+      inst->setMetadata(llvm::LLVMContext::MD_nontemporal, metadata);
     }
   }];
   code setAccessGroupsMetadataCode = [{
@@ -355,6 +354,10 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpBase {
                    UnitAttr:$nontemporal);
   let results = (outs LLVM_LoadableType:$res);
   string llvmInstName = "Load";
+  let assemblyFormat = [{
+    (`volatile` $volatile_^)? $addr attr-dict `:`
+    custom<LoadType>(type($addr), type($res))
+  }];
   string llvmBuilder = [{
     auto *inst = builder.CreateLoad($_resultType, $addr, $volatile_);
   }] # setAlignmentCode
@@ -365,9 +368,12 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpBase {
      # [{
     $res = inst;
   }];
-  // FIXME: Import attributes.
   string mlirBuilder = [{
-    $res = $_builder.create<LLVM::LoadOp>($_location, $_resultType, $addr);
+    auto *loadInst = cast<llvm::LoadInst>(inst);
+    unsigned alignment = loadInst->getAlign().value();
+    $res = $_builder.create<LLVM::LoadOp>($_location, $_resultType, $addr,
+        alignment, loadInst->isVolatile(),
+        loadInst->hasMetadata(llvm::LLVMContext::MD_nontemporal));
   }];
   let builders = [
     OpBuilder<(ins "Value":$addr, CArg<"unsigned", "0">:$alignment,
@@ -378,9 +384,10 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpBase {
                      "when the pointer type is opaque");
       build($_builder, $_state, type, addr, alignment, isVolatile, isNonTemporal);
     }]>,
-    OpBuilder<(ins "Type":$t, "Value":$addr,
+    OpBuilder<(ins "Type":$type, "Value":$addr,
       CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
-      CArg<"bool", "false">:$isNonTemporal)>,];
+      CArg<"bool", "false">:$isNonTemporal)>
+  ];
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
@@ -395,6 +402,10 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpBase {
                    OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
                    UnitAttr:$nontemporal);
   string llvmInstName = "Store";
+  let assemblyFormat = [{
+    (`volatile` $volatile_^)? $value `,` $addr attr-dict `:`
+    custom<StoreType>(type($value), type($addr))
+  }];
   string llvmBuilder = [{
     auto *inst = builder.CreateStore($value, $addr, $volatile_);
   }] # setAlignmentCode
@@ -402,16 +413,18 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpBase {
      # setAccessGroupsMetadataCode
      # setAliasScopeMetadataCode
      # setTBAAMetadataCode;
-  // FIXME: Import attributes.
   string mlirBuilder = [{
-    $_op = $_builder.create<LLVM::StoreOp>($_location, $value, $addr);
+    auto *storeInst = cast<llvm::StoreInst>(inst);
+    unsigned alignment = storeInst->getAlign().value();
+    $_op = $_builder.create<LLVM::StoreOp>($_location, $value, $addr,
+        alignment, storeInst->isVolatile(),
+        storeInst->hasMetadata(llvm::LLVMContext::MD_nontemporal));
   }];
   let builders = [
     OpBuilder<(ins "Value":$value, "Value":$addr,
       CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
       CArg<"bool", "false">:$isNonTemporal)>
-    ];
-  let hasCustomAssemblyFormat = 1;
+  ];
   let hasVerifier = 1;
 }
 
index 1dc501f..c19d268 100644 (file)
@@ -44,8 +44,6 @@ using mlir::LLVM::linkage::getMaxEnumValForLinkage;
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
 
-static constexpr const char kVolatileAttrName[] = "volatile_";
-static constexpr const char kNonTemporalAttrName[] = "nontemporal";
 static constexpr const char kElemTypeAttrName[] = "elem_type";
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
@@ -665,7 +663,7 @@ Type LLVM::GEPOp::getSourceElementType() {
 }
 
 //===----------------------------------------------------------------------===//
-// Builder, printer and parser for for LLVM::LoadOp.
+// LoadOp
 //===----------------------------------------------------------------------===//
 
 /// Verifies the given array attribute contains symbol references and checks the
@@ -759,29 +757,13 @@ LogicalResult verifyMemOpMetadata(OpTy memOp) {
 
 LogicalResult LoadOp::verify() { return verifyMemOpMetadata(*this); }
 
-void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
+void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
                    Value addr, unsigned alignment, bool isVolatile,
                    bool isNonTemporal) {
-  result.addOperands(addr);
-  result.addTypes(t);
-  if (isVolatile)
-    result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
-  if (isNonTemporal)
-    result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
-  if (alignment != 0)
-    result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
-}
-
-void LoadOp::print(OpAsmPrinter &p) {
-  p << ' ';
-  if (getVolatile_())
-    p << "volatile ";
-  p << getAddr();
-  p.printOptionalAttrDict((*this)->getAttrs(),
-                          {kVolatileAttrName, kElemTypeAttrName});
-  p << " : " << getAddr().getType();
-  if (getAddr().getType().cast<LLVMPointerType>().isOpaque())
-    p << " -> " << getType();
+  build(builder, state, type, addr, /*access_groups=*/nullptr,
+        /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+        alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
+        isNonTemporal);
 }
 
 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return
@@ -797,105 +779,85 @@ getLoadStoreElementType(OpAsmParser &parser, Type type, SMLoc trailingTypeLoc) {
   return llvmTy.getElementType();
 }
 
-// <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
-//                 (`->` type)?
-ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand addr;
-  Type type;
+/// Parses the LoadOp type either using the typed or opaque pointer format.
+// TODO: Drop once the typed pointer assembly format is not needed anymore.
+static ParseResult parseLoadType(OpAsmParser &parser, Type &type,
+                                 Type &elementType) {
   SMLoc trailingTypeLoc;
-
-  if (succeeded(parser.parseOptionalKeyword("volatile")))
-    result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
-
-  if (parser.parseOperand(addr) ||
-      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
-      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
-      parser.resolveOperand(addr, type, result.operands))
+  if (parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
     return failure();
 
-  std::optional<Type> elemTy =
+  std::optional<Type> pointerElementType =
       getLoadStoreElementType(parser, type, trailingTypeLoc);
-  if (!elemTy)
+  if (!pointerElementType)
     return failure();
-  if (*elemTy) {
-    result.addTypes(*elemTy);
+  if (*pointerElementType) {
+    elementType = *pointerElementType;
     return success();
   }
 
-  Type trailingType;
-  if (parser.parseArrow() || parser.parseType(trailingType))
+  if (parser.parseArrow() || parser.parseType(elementType))
     return failure();
-  result.addTypes(trailingType);
   return success();
 }
 
+/// Prints the LoadOp type either using the typed or opaque pointer format.
+// TODO: Drop once the typed pointer assembly format is not needed anymore.
+static void printLoadType(OpAsmPrinter &printer, Operation *op, Type type,
+                          Type elementType) {
+  printer << type;
+  auto pointerType = cast<LLVMPointerType>(type);
+  if (pointerType.isOpaque())
+    printer << " -> " << elementType;
+}
+
 //===----------------------------------------------------------------------===//
-// Builder, printer and parser for LLVM::StoreOp.
+// StoreOp
 //===----------------------------------------------------------------------===//
 
 LogicalResult StoreOp::verify() { return verifyMemOpMetadata(*this); }
 
-void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
+void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
                     Value addr, unsigned alignment, bool isVolatile,
                     bool isNonTemporal) {
-  result.addOperands({value, addr});
-  result.addTypes({});
-  if (isVolatile)
-    result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
-  if (isNonTemporal)
-    result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
-  if (alignment != 0)
-    result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
+  build(builder, state, value, addr, /*access_groups=*/nullptr,
+        /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+        alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
+        isNonTemporal);
 }
 
-void StoreOp::print(OpAsmPrinter &p) {
-  p << ' ';
-  if (getVolatile_())
-    p << "volatile ";
-  p << getValue() << ", " << getAddr();
-  p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName});
-  p << " : ";
-  if (getAddr().getType().cast<LLVMPointerType>().isOpaque())
-    p << getValue().getType() << ", ";
-  p << getAddr().getType();
-}
-
-// <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
-//                 attribute-dict? `:` type (`,` type)?
-ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand addr, value;
-  Type type;
+/// Parses the StoreOp type either using the typed or opaque pointer format.
+// TODO: Drop once the typed pointer assembly format is not needed anymore.
+static ParseResult parseStoreType(OpAsmParser &parser, Type &elementType,
+                                  Type &type) {
   SMLoc trailingTypeLoc;
-
-  if (succeeded(parser.parseOptionalKeyword("volatile")))
-    result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
-
-  if (parser.parseOperand(value) || parser.parseComma() ||
-      parser.parseOperand(addr) ||
-      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
-      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
+  if (parser.getCurrentLocation(&trailingTypeLoc) ||
+      parser.parseType(elementType))
     return failure();
 
-  Type operandType;
-  if (succeeded(parser.parseOptionalComma())) {
-    operandType = type;
-    if (parser.parseType(type))
-      return failure();
-  } else {
-    std::optional<Type> maybeOperandType =
-        getLoadStoreElementType(parser, type, trailingTypeLoc);
-    if (!maybeOperandType)
-      return failure();
-    operandType = *maybeOperandType;
-  }
+  if (succeeded(parser.parseOptionalComma()))
+    return parser.parseType(type);
 
-  if (parser.resolveOperand(value, operandType, result.operands) ||
-      parser.resolveOperand(addr, type, result.operands))
+  // Extract the element type from the pointer type.
+  type = elementType;
+  std::optional<Type> pointerElementType =
+      getLoadStoreElementType(parser, type, trailingTypeLoc);
+  if (!pointerElementType)
     return failure();
-
+  elementType = *pointerElementType;
   return success();
 }
 
+/// Prints the StoreOp type either using the typed or opaque pointer format.
+// TODO: Drop once the typed pointer assembly format is not needed anymore.
+static void printStoreType(OpAsmPrinter &printer, Operation *op,
+                           Type elementType, Type type) {
+  auto pointerType = cast<LLVMPointerType>(type);
+  if (pointerType.isOpaque())
+    printer << elementType << ", ";
+  printer << type;
+}
+
 //===----------------------------------------------------------------------===//
 // CallOp
 //===----------------------------------------------------------------------===//
index 14dbf07..cd54411 100644 (file)
@@ -251,7 +251,7 @@ define void @integer_arith(i32 %arg1, i32 %arg2, i64 %arg3, i64 %arg4) {
 ; CHECK-SAME:  %[[VEC:[a-zA-Z0-9]+]]
 ; CHECK-SAME:  %[[IDX:[a-zA-Z0-9]+]]
 define half @extract_element(ptr %vec, i32 %idx) {
-  ; CHECK:  %[[V1:.+]] = llvm.load %[[VEC]] : !llvm.ptr -> vector<4xf16>
+  ; CHECK:  %[[V1:.+]] = llvm.load %[[VEC]] {{.*}} : !llvm.ptr -> vector<4xf16>
   ; CHECK:  %[[V2:.+]] = llvm.extractelement %[[V1]][%[[IDX]] : i32] : vector<4xf16>
   ; CHECK:  llvm.return %[[V2]]
   %1 = load <4 x half>, ptr %vec
@@ -266,7 +266,7 @@ define half @extract_element(ptr %vec, i32 %idx) {
 ; CHECK-SAME:  %[[VAL:[a-zA-Z0-9]+]]
 ; CHECK-SAME:  %[[IDX:[a-zA-Z0-9]+]]
 define <4 x half> @insert_element(ptr %vec, half %val, i32 %idx) {
-  ; CHECK:  %[[V1:.+]] = llvm.load %[[VEC]] : !llvm.ptr -> vector<4xf16>
+  ; CHECK:  %[[V1:.+]] = llvm.load %[[VEC]] {{.*}} : !llvm.ptr -> vector<4xf16>
   ; CHECK:  %[[V2:.+]] = llvm.insertelement %[[VAL]], %[[V1]][%[[IDX]] : i32] : vector<4xf16>
   ; CHECK:  llvm.return %[[V2]]
   %1 = load <4 x half>, ptr %vec
@@ -352,13 +352,20 @@ define ptr @alloca(i64 %size) {
 ; CHECK-LABEL: @load_store
 ; CHECK-SAME:  %[[PTR:[a-zA-Z0-9]+]]
 define void @load_store(ptr %ptr) {
-  ; CHECK:  %[[V1:[0-9]+]] = llvm.load %[[PTR]] : !llvm.ptr -> f64
-  ; CHECK:  llvm.store %[[V1]], %[[PTR]] : f64, !llvm.ptr
+  ; CHECK:  %[[V1:[0-9]+]] = llvm.load %[[PTR]] {alignment = 8 : i64} : !llvm.ptr -> f64
+  ; CHECK:  %[[V2:[0-9]+]] = llvm.load volatile %[[PTR]] {alignment = 16 : i64, nontemporal} : !llvm.ptr -> f64
   %1 = load double, ptr %ptr
+  %2 = load volatile double, ptr %ptr, align 16, !nontemporal !0
+
+  ; CHECK:  llvm.store %[[V1]], %[[PTR]] {alignment = 8 : i64} : f64, !llvm.ptr
+  ; CHECK:  llvm.store volatile %[[V2]], %[[PTR]] {alignment = 16 : i64, nontemporal} : f64, !llvm.ptr
   store double %1, ptr %ptr
+  store volatile double %2, ptr %ptr, align 16, !nontemporal !0
   ret void
 }
 
+!0 = !{i32 1}
+
 ; // -----
 
 ; CHECK-LABEL: @atomic_rmw
index 1ddd5e2..9aecb13 100644 (file)
@@ -8,13 +8,12 @@
 ; CHECK: }
 
 ; CHECK-LABEL: llvm.func @access_group
-; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
 define void @access_group(ptr %arg1) {
-  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP0]], @__llvm_global_metadata::@[[$GROUP1]]]}
+  ; CHECK:  access_groups = [@__llvm_global_metadata::@[[$GROUP0]], @__llvm_global_metadata::@[[$GROUP1]]]
   %1 = load i32, ptr %arg1, !llvm.access.group !0
-  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP2]], @__llvm_global_metadata::@[[$GROUP0]]]}
+  ; CHECK:  access_groups = [@__llvm_global_metadata::@[[$GROUP2]], @__llvm_global_metadata::@[[$GROUP0]]]
   %2 = load i32, ptr %arg1, !llvm.access.group !1
-  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP3]]]}
+  ; CHECK:  access_groups = [@__llvm_global_metadata::@[[$GROUP3]]]
   %3 = load i32, ptr %arg1, !llvm.access.group !2
   ret void
 }