[mlir] LLVM lowering: don't use calling convention in op returns
authorAlex Zinenko <zinenko@google.com>
Wed, 12 Apr 2023 09:18:25 +0000 (09:18 +0000)
committerAlex Zinenko <zinenko@google.com>
Thu, 13 Apr 2023 10:56:56 +0000 (10:56 +0000)
Conversions to the LLVM dialect have an option to use the "bare pointer"
calling convention that converts memref types differently than the
default convention. It has crept into the conversion of operations that
are not related to calls but do require multiresult-to-struct packing.
Use a similar mechanism for the latter without using the calling
convention.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D148086

mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

index 6005751..1eb5661 100644 (file)
@@ -56,13 +56,19 @@ public:
                                 bool useBarePtrCallConv,
                                 SignatureConversion &result);
 
-  /// Convert a non-empty list of types to be returned from a function into a
-  /// supported LLVM IR type.  In particular, if more than one value is
-  /// returned, create an LLVM IR structure type with elements that correspond
-  /// to each of the MLIR types converted with `convertType`.
+  /// Convert a non-empty list of types to be returned from a function into an
+  /// LLVM-compatible type. In particular, if more than one value is returned,
+  /// create an LLVM dialect structure type with elements that correspond to
+  /// each of the types converted with `convertCallingConventionType`.
   Type packFunctionResults(TypeRange types,
                            bool useBarePointerCallConv = false);
 
+  /// Convert a non-empty list of types of values produced by an operation into
+  /// an LLVM-compatible type. In particular, if more than one value is
+  /// produced, create a literal structure with elements that correspond to each
+  /// of the LLVM-compatible types converted with `convertType`.
+  Type packOperationResults(TypeRange types);
+
   /// Convert a type in the context of the default or bare pointer calling
   /// convention. Calling convention sensitive types, such as MemRefType and
   /// UnrankedMemRefType, are converted following the specific rules for the
index d3983a3..e2dae40 100644 (file)
@@ -329,7 +329,7 @@ LogicalResult LLVM::detail::oneToOneRewrite(
   SmallVector<Type> resultTypes;
   if (numResults != 0) {
     resultTypes.push_back(
-        typeConverter.packFunctionResults(op->getResultTypes()));
+        typeConverter.packOperationResults(op->getResultTypes()));
     if (!resultTypes.back())
       return failure();
   }
index 833ea36..b8ed763 100644 (file)
@@ -496,10 +496,31 @@ void LLVMTypeConverter::promoteBarePtrsToDescriptors(
                                                     memrefTy, values[i]);
 }
 
-/// Convert a non-empty list of types to be returned from a function into a
-/// supported LLVM IR type.  In particular, if more than one value is returned,
-/// create an LLVM IR structure type with elements that correspond to each of
-/// the MLIR types converted with `convertType`.
+/// Convert a non-empty list of types of values produced by an operation into an
+/// LLVM-compatible type. In particular, if more than one value is
+/// produced, create a literal structure with elements that correspond to each
+/// of the types converted with `convertType`.
+Type LLVMTypeConverter::packOperationResults(TypeRange types) {
+  assert(!types.empty() && "expected non-empty list of type");
+  if (types.size() == 1)
+    return convertType(types[0]);
+
+  SmallVector<Type> resultTypes;
+  resultTypes.reserve(types.size());
+  for (Type type : types) {
+    Type converted = convertType(type);
+    if (!converted || !LLVM::isCompatibleType(converted))
+      return {};
+    resultTypes.push_back(converted);
+  }
+
+  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
+}
+
+/// Convert a non-empty list of types to be returned from a function into an
+/// LLVM-compatible type. In particular, if more than one value is returned,
+/// create an LLVM dialect structure type with elements that correspond to each
+/// of the types converted with `convertCallingConventionType`.
 Type LLVMTypeConverter::packFunctionResults(TypeRange types,
                                             bool useBarePtrCallConv) {
   assert(!types.empty() && "expected non-empty list of type");
@@ -508,7 +529,7 @@ Type LLVMTypeConverter::packFunctionResults(TypeRange types,
   if (types.size() == 1)
     return convertCallingConventionType(types.front(), useBarePtrCallConv);
 
-  SmallVector<Type, 8> resultTypes;
+  SmallVector<Type> resultTypes;
   resultTypes.reserve(types.size());
   for (auto t : types) {
     auto converted = convertCallingConventionType(t, useBarePtrCallConv);