[mlir] Adds argument attributes for using LLVM's sret and byval attributes
authorEric Schweitz <eschweitz@nvidia.com>
Thu, 7 Jan 2021 20:50:20 +0000 (12:50 -0800)
committerEric Schweitz <eschweitz@nvidia.com>
Thu, 7 Jan 2021 20:52:14 +0000 (12:52 -0800)
to the conversion of LLVM IR dialect. These attributes are used in FIR to
support the lowering of Fortran using target-specific calling conventions.

Add roundtrip tests.

Add changes per review comments/concerns.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D94052

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/LLVMIR/func.mlir
mlir/test/Target/llvmir-invalid.mlir

index 7700867..492025b 100644 (file)
@@ -1102,6 +1102,22 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
           llvm::AttrBuilder().addAlignmentAttr(llvm::Align(attr.getInt())));
     }
 
+    if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.sret")) {
+      auto argTy = mlirArg.getType();
+      if (!argTy.isa<LLVM::LLVMPointerType>())
+        return func.emitError(
+            "llvm.sret attribute attached to LLVM non-pointer argument");
+      llvmArg.addAttr(llvm::Attribute::AttrKind::StructRet);
+    }
+
+    if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.byval")) {
+      auto argTy = mlirArg.getType();
+      if (!argTy.isa<LLVM::LLVMPointerType>())
+        return func.emitError(
+            "llvm.byval attribute attached to LLVM non-pointer argument");
+      llvmArg.addAttr(llvm::Attribute::AttrKind::ByVal);
+    }
+
     valueMapping[mlirArg] = &llvmArg;
     argIdx++;
   }
index 72e117d..d9d7e17 100644 (file)
@@ -87,6 +87,16 @@ module {
     llvm.return
   }
 
+  // CHECK: llvm.func @byvalattr(%{{.*}}: !llvm.ptr<i32> {llvm.byval})
+  llvm.func @byvalattr(%arg0: !llvm.ptr<i32> {llvm.byval}) {
+    llvm.return
+  }
+
+  // CHECK: llvm.func @sretattr(%{{.*}}: !llvm.ptr<i32> {llvm.sret})
+  llvm.func @sretattr(%arg0: !llvm.ptr<i32> {llvm.sret}) {
+    llvm.return
+  }
+
   // CHECK: llvm.func @variadic(...)
   llvm.func @variadic(...)
 
index 1411759..fcd98ef 100644 (file)
@@ -14,6 +14,19 @@ llvm.func @invalid_noalias(%arg0 : !llvm.float {llvm.noalias = true}) -> !llvm.f
 
 // -----
 
+// expected-error @+1 {{llvm.sret attribute attached to LLVM non-pointer argument}}
+llvm.func @invalid_noalias(%arg0 : !llvm.float {llvm.sret}) -> !llvm.float {
+  llvm.return %arg0 : !llvm.float
+}
+// -----
+
+// expected-error @+1 {{llvm.byval attribute attached to LLVM non-pointer argument}}
+llvm.func @invalid_noalias(%arg0 : !llvm.float {llvm.byval}) -> !llvm.float {
+  llvm.return %arg0 : !llvm.float
+}
+
+// -----
+
 // expected-error @+1 {{llvm.align attribute attached to LLVM non-pointer argument}}
 llvm.func @invalid_align(%arg0 : !llvm.float {llvm.align = 4}) -> !llvm.float {
   llvm.return %arg0 : !llvm.float