[MLIR][LLVMDialect] Added volatile and nontemporal attributes to load/store
authorGeorge Mitenkov <georgemitenk0v@gmail.com>
Mon, 27 Jul 2020 07:19:48 +0000 (10:19 +0300)
committerGeorge Mitenkov <georgemitenk0v@gmail.com>
Mon, 27 Jul 2020 07:55:56 +0000 (10:55 +0300)
This patch introduces 2 new optional attributes to `llvm.load`
and `llvm.store` ops: `volatile` and `nontemporal`. These attributes
are translated into proper LLVM as a `volatile` marker and a metadata node
respectively. They are also helpful with SPIR-V to LLVM dialect conversion
since they are the mappings for `Volatile` and `NonTemporal` Memory Operands.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D84396

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Target/llvmir.mlir

index 5322e24..29d7fd9 100644 (file)
@@ -235,6 +235,19 @@ class MemoryOpWithAlignmentBase {
   }];
 }
 
+// Code definition that is used for nontemporal metadata creation.
+class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase {
+  code setNonTemporalMetadataCode = [{
+    if ($nontemporal) {
+      llvm::Module *module = builder.GetInsertBlock()->getModule();
+      llvm::MDNode *metadata = llvm::MDNode::get(
+          inst->getContext(), llvm::ConstantAsMetadata::get(
+              builder.getInt32(1)));
+      inst->setMetadata(module->getMDKindID("nontemporal"), metadata);
+    }
+  }];
+}
+
 // Memory-related operations.
 def LLVM_AllocaOp :
     MemoryOpWithAlignmentBase,
@@ -266,52 +279,49 @@ def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
   }];
 }
 def LLVM_LoadOp :
-    MemoryOpWithAlignmentBase,
+    MemoryOpWithAlignmentAndAttributes,
     LLVM_OneResultOp<"load">,
-    Arguments<(ins LLVM_Type:$addr, OptionalAttr<I64Attr>:$alignment)> {
+    Arguments<(ins LLVM_Type:$addr,
+                   OptionalAttr<I64Attr>:$alignment,
+                   UnitAttr:$volatile_,
+                   UnitAttr:$nontemporal)> {
   string llvmBuilder = [{
-    auto *inst = builder.CreateLoad($addr);
-  }] # setAlignmentCode # [{
+    auto *inst = builder.CreateLoad($addr, $volatile_);
+  }] # setAlignmentCode # setNonTemporalMetadataCode # [{
     $res = inst;
   }];
   let builders = [OpBuilder<
-    "OpBuilder &b, OperationState &result, Value addr, unsigned alignment = 0",
+    "OpBuilder &b, OperationState &result, Value addr, "
+    "unsigned alignment = 0, bool isVolatile = false, "
+    "bool isNonTemporal = false",
     [{
       auto type = addr.getType().cast<LLVM::LLVMType>().getPointerElementTy();
-      build(b, result, type, addr, alignment);
+      build(b, result, type, addr, alignment, isVolatile, isNonTemporal);
     }]>,
     OpBuilder<
     "OpBuilder &b, OperationState &result, Type t, Value addr, "
-    "unsigned alignment = 0",
-    [{
-      if (alignment == 0)
-        return build(b, result, t, addr, IntegerAttr());
-      build(b, result, t, addr, b.getI64IntegerAttr(alignment));
-    }]>];
+    "unsigned alignment = 0, bool isVolatile = false, "
+    "bool isNonTemporal = false">];
   let parser = [{ return parseLoadOp(parser, result); }];
   let printer = [{ printLoadOp(p, *this); }];
   let verifier = alignmentVerifierCode;
 }
 def LLVM_StoreOp :
-    MemoryOpWithAlignmentBase,
+    MemoryOpWithAlignmentAndAttributes,
     LLVM_ZeroResultOp<"store">,
     Arguments<(ins LLVM_Type:$value,
                    LLVM_Type:$addr,
-                   OptionalAttr<I64Attr>:$alignment)> {
+                   OptionalAttr<I64Attr>:$alignment,
+                   UnitAttr:$volatile_,
+                   UnitAttr:$nontemporal)> {
   string llvmBuilder = [{
-    auto *inst = builder.CreateStore($value, $addr);
-  }] # setAlignmentCode;
-  let builders = [
-    OpBuilder<
+    auto *inst = builder.CreateStore($value, $addr, $volatile_);
+  }] # setAlignmentCode # setNonTemporalMetadataCode;
+  let builders = [OpBuilder<
     "OpBuilder &b, OperationState &result, Value value, Value addr, "
-    "unsigned alignment = 0",
-    [{
-      if (alignment == 0)
-        return build(b, result, ArrayRef<Type>{}, value, addr, IntegerAttr());
-      build(b, result, ArrayRef<Type>{}, value, addr, 
-            b.getI64IntegerAttr(alignment));
-    }]
-  >];
+    "unsigned alignment = 0, bool isVolatile = false, "
+    "bool isNonTemporal = false">
+    ];
   let parser = [{ return parseStoreOp(parser, result); }];
   let printer = [{ printStoreOp(p, *this); }];
   let verifier = alignmentVerifierCode;
index f448427..cf7a5d9 100644 (file)
@@ -31,6 +31,9 @@
 using namespace mlir;
 using namespace mlir::LLVM;
 
+static constexpr const char kVolatileAttrName[] = "volatile_";
+static constexpr const char kNonTemporalAttrName[] = "nontemporal";
+
 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
 
 //===----------------------------------------------------------------------===//
@@ -178,12 +181,28 @@ CondBrOp::getMutableSuccessorOperands(unsigned index) {
 }
 
 //===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::LoadOp.
+// Builder, printer and parser for for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//
 
+void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
+                   Value addr, unsigned alignment, bool isVolatile,
+                   bool isNonTemporal) {
+  result.addOperands(addr);
+  result.addTypes(t);
+  if (isVolatile)
+    result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
+  if (isNonTemporal)
+    result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
+  if (alignment != 0)
+    result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
+}
+
 static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
-  p << op.getOperationName() << ' ' << op.addr();
-  p.printOptionalAttrDict(op.getAttrs());
+  p << op.getOperationName() << ' ';
+  if (op.volatile_())
+    p << "volatile ";
+  p << op.addr();
+  p.printOptionalAttrDict(op.getAttrs(), {kVolatileAttrName});
   p << " : " << op.addr().getType();
 }
 
@@ -201,12 +220,15 @@ static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
   return llvmTy.getPointerElementTy();
 }
 
-// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
+// <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::OperandType addr;
   Type type;
   llvm::SMLoc trailingTypeLoc;
 
+  if (succeeded(parser.parseOptionalKeyword("volatile")))
+    result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
+
   if (parser.parseOperand(addr) ||
       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
@@ -220,21 +242,41 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
 }
 
 //===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::StoreOp.
+// Builder, printer and parser for LLVM::StoreOp.
 //===----------------------------------------------------------------------===//
 
+void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
+                    Value addr, unsigned alignment, bool isVolatile,
+                    bool isNonTemporal) {
+  result.addOperands({value, addr});
+  result.addTypes(ArrayRef<Type>{});
+  if (isVolatile)
+    result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
+  if (isNonTemporal)
+    result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
+  if (alignment != 0)
+    result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
+}
+
 static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {
-  p << op.getOperationName() << ' ' << op.value() << ", " << op.addr();
-  p.printOptionalAttrDict(op.getAttrs());
+  p << op.getOperationName() << ' ';
+  if (op.volatile_())
+    p << "volatile ";
+  p << op.value() << ", " << op.addr();
+  p.printOptionalAttrDict(op.getAttrs(), {kVolatileAttrName});
   p << " : " << op.addr().getType();
 }
 
-// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
+// <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
+//                 attribute-dict? `:` type
 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::OperandType addr, value;
   Type type;
   llvm::SMLoc trailingTypeLoc;
 
+  if (succeeded(parser.parseOptionalKeyword("volatile")))
+    result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
+
   if (parser.parseOperand(value) || parser.parseComma() ||
       parser.parseOperand(addr) ||
       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
index 954b5b1..d6180cb 100644 (file)
@@ -1266,3 +1266,32 @@ llvm.func @cond_br_weights(%cond : !llvm.i1, %arg0 : !llvm.i32,  %arg1 : !llvm.i
 }
 
 // CHECK: ![[NODE]] = !{!"branch_weights", i32 5, i32 10}
+
+// -----
+
+llvm.func @volatile_store_and_load() {
+  %val = llvm.mlir.constant(5 : i32) : !llvm.i32
+  %size = llvm.mlir.constant(1 : i64) : !llvm.i64
+  %0 = llvm.alloca %size x !llvm.i32 : (!llvm.i64) -> (!llvm<"i32*">)
+  // CHECK: store volatile i32 5, i32* %{{.*}}
+  llvm.store volatile %val, %0 : !llvm<"i32*">
+  // CHECK: %{{.*}} = load volatile i32, i32* %{{.*}}
+  %1 = llvm.load volatile %0: !llvm<"i32*">
+  llvm.return
+}
+
+// -----
+
+// Check that nontemporal attribute is exported as metadata node.
+llvm.func @nontemoral_store_and_load() {
+  %val = llvm.mlir.constant(5 : i32) : !llvm.i32
+  %size = llvm.mlir.constant(1 : i64) : !llvm.i64
+  %0 = llvm.alloca %size x !llvm.i32 : (!llvm.i64) -> (!llvm<"i32*">)
+  // CHECK: !nontemporal ![[NODE:[0-9]+]]
+  llvm.store %val, %0 {nontemporal} : !llvm<"i32*">
+  // CHECK: !nontemporal ![[NODE]]
+  %1 = llvm.load %0 {nontemporal} : !llvm<"i32*">
+  llvm.return
+}
+
+// CHECK: ![[NODE]] = !{i32 1}