LLVM IR Lowering: support multi-value returns.
authorAlex Zinenko <zinenko@google.com>
Tue, 4 Dec 2018 14:16:26 +0000 (06:16 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:15:56 +0000 (14:15 -0700)
Unlike MLIR, LLVM IR does not support functions that return multiple values.
Simulate this by packing values into the LLVM structure type in the same order
as they appear in the MLIR return.  If the function returns only a single
value, return it directly without packing.

PiperOrigin-RevId: 223964886

mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
mlir/test/Target/llvmir.mlir

index 2e3ea537c59213f0d876f40fbecebf8619c65cb8..673485ade99f0c616e33cc413e2f7472422fe748 100644 (file)
@@ -79,6 +79,12 @@ private:
   llvm::StructType *convertMemRefType(MemRefType type);
   /// \}
 
+  /// Convert a list of types to an LLVM type suitable for being returned from a
+  /// function.  If the list is empty, return VoidTy.  If it
+  /// contains one element, return the converted element. Otherwise, create an
+  /// LLVM StructType containing all the given types in order.
+  llvm::Type *getPackedResultType(ArrayRef<Type> types);
+
   /// Get an a constant value of `indexType`.
   inline llvm::Constant *getIndexConstant(int64_t value);
 
@@ -118,6 +124,12 @@ private:
   /// instruction) on success and nullptr on error.
   llvm::Value *emitMemRefDealloc(ConstOpPointer<DeallocOp> deallocOp);
 
+  /// Create a single LLVM value of struct type that includes the list of
+  /// given MLIR values.  The `values` list must contain at least 2 elements.
+  llvm::Value *packValues(ArrayRef<const SSAValue *> values);
+  /// Extract a list of `num` LLVM values from a `value` of struct type.
+  SmallVector<llvm::Value *, 4> unpackValues(llvm::Value *value, unsigned num);
+
   llvm::DenseMap<const Function *, llvm::Function *> functionMapping;
   llvm::DenseMap<const SSAValue *, llvm::Value *> valueMapping;
   llvm::DenseMap<const BasicBlock *, llvm::BasicBlock *> blockMapping;
@@ -157,20 +169,40 @@ llvm::Type *ModuleLowerer::convertFloatType(FloatType type) {
   }
 }
 
+// Helper function for lambdas below.
+static bool isTypeNull(llvm::Type *type) { return type == nullptr; }
+
+// If `types` has more than one type, pack them into an LLVM StructType,
+// otherwise just convert the type.
+llvm::Type *ModuleLowerer::getPackedResultType(ArrayRef<Type> types) {
+  // Convert result types one by one and check for errors.
+  auto resultTypes =
+      functional::map([this](Type t) { return convertType(t); }, types);
+  if (llvm::any_of(resultTypes, isTypeNull))
+    return nullptr;
+
+  // LLVM does not support tuple returns.  If there are more than 2 results,
+  // pack them into an LLVM struct type.
+  if (resultTypes.empty())
+    return llvm::Type::getVoidTy(llvmContext);
+  if (resultTypes.size() == 1)
+    return resultTypes.front();
+  return llvm::StructType::get(llvmContext, resultTypes);
+}
+
+// Function types are converted to LLVM Function types by recursively converting
+// argument and result types.  If MLIR Function has zero results, the LLVM
+// Function has one VoidType result.  If MLIR Function has more than one result,
+// they are into an LLVM StructType in their order of appearance.
 llvm::FunctionType *ModuleLowerer::convertFunctionType(FunctionType type) {
-  // TODO(zinenko): convert tuple to LLVM structure types
-  assert(type.getNumResults() <= 1 && "NYI: tuple returns");
-  auto resultType = type.getNumResults() == 0
-                        ? llvm::Type::getVoidTy(llvmContext)
-                        : convertType(type.getResult(0));
+  llvm::Type *resultType = getPackedResultType(type.getResults());
   if (!resultType)
     return nullptr;
 
-  auto argTypes =
-      functional::map([this](Type inputType) { return convertType(inputType); },
-                      type.getInputs());
-  if (std::any_of(argTypes.begin(), argTypes.end(),
-                  [](const llvm::Type *t) { return t == nullptr; }))
+  // Convert argument types one by one and check for errors.
+  auto argTypes = functional::map([this](Type t) { return convertType(t); },
+                                  type.getInputs());
+  if (llvm::any_of(argTypes, isTypeNull))
     return nullptr;
 
   return llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false);
@@ -364,6 +396,33 @@ ModuleLowerer::emitMemRefDealloc(ConstOpPointer<DeallocOp> deallocOp) {
   return builder.CreateCall(freeFunc, data);
 }
 
+// Create an undef struct value and insert individual values into it.
+llvm::Value *ModuleLowerer::packValues(ArrayRef<const SSAValue *> values) {
+  assert(values.size() > 1 && "cannot pack less than 2 values");
+
+  auto types =
+      functional::map([](const SSAValue *v) { return v->getType(); }, values);
+  llvm::Type *packedType = getPackedResultType(types);
+
+  llvm::Value *packed = llvm::UndefValue::get(packedType);
+  for (auto indexedValue : llvm::enumerate(values)) {
+    packed = builder.CreateInsertValue(
+        packed, valueMapping.lookup(indexedValue.value()),
+        indexedValue.index());
+  }
+  return packed;
+}
+
+// Emit extract value instructions to unpack the struct.
+SmallVector<llvm::Value *, 4> ModuleLowerer::unpackValues(llvm::Value *value,
+                                                          unsigned num) {
+  SmallVector<llvm::Value *, 4> unpacked;
+  unpacked.reserve(num);
+  for (unsigned i = 0; i < num; ++i)
+    unpacked.push_back(builder.CreateExtractValue(value, i));
+  return unpacked;
+}
+
 static llvm::CmpInst::Predicate getLLVMCmpPredicate(CmpIPredicate p) {
   switch (p) {
   case CmpIPredicate::EQ:
@@ -541,26 +600,34 @@ bool ModuleLowerer::convertInstruction(const Instruction &inst) {
         [this](const SSAValue *value) { return valueMapping.lookup(value); },
         callOp->getOperands());
     auto numResults = callOp->getNumResults();
-    // TODO(zinenko): support tuple returns
-    assert(numResults <= 1 && "NYI: tuple returns");
-
     llvm::Value *result =
         builder.CreateCall(functionMapping[callOp->getCallee()], operands);
-    if (numResults == 1)
+    if (numResults == 1) {
       valueMapping[callOp->getResult(0)] = result;
+    } else if (numResults > 1) {
+      auto unpacked = unpackValues(result, numResults);
+      for (auto indexedValue : llvm::enumerate(unpacked)) {
+        valueMapping[callOp->getResult(indexedValue.index())] =
+            indexedValue.value();
+      }
+    }
     return false;
   }
 
   // Terminators.
   if (auto returnInst = inst.dyn_cast<ReturnOp>()) {
     unsigned numOperands = returnInst->getNumOperands();
-    // TODO(zinenko): support tuple returns
-    assert(numOperands <= 1u && "NYI: tuple returns");
-
-    if (numOperands == 0)
+    if (numOperands == 0) {
       builder.CreateRetVoid();
-    else
+    } else if (numOperands == 1) {
       builder.CreateRet(valueMapping[returnInst->getOperand(0)]);
+    } else {
+      llvm::Value *packed =
+          packValues(llvm::to_vector<4>(returnInst->getOperands()));
+      if (!packed)
+        return true;
+      builder.CreateRet(packed);
+    }
 
     return false;
   }
index 3c476ab2fa423ef2eb1256a8a510a90c1ed0a179..7bd51e26cbc617892bf744e0b43837caf5c89615 100644 (file)
@@ -559,3 +559,41 @@ bb0(%arg0: memref<42x?x10x?xf32>):
 // CHECK-NEXT: ret i64 %6
   return %d0123 : index
 }
+
+extfunc @get_i64() -> (i64)
+extfunc @get_f32() -> (f32)
+extfunc @get_memref() -> (memref<42x?x10x?xf32>)
+
+// CHECK-LABEL: define { i64, float, { float*, i64, i64 } } @multireturn() {
+cfgfunc @multireturn() -> (i64, f32, memref<42x?x10x?xf32>) {
+bb0:
+  %0 = call @get_i64() : () -> (i64)
+  %1 = call @get_f32() : () -> (f32)
+  %2 = call @get_memref() : () -> (memref<42x?x10x?xf32>)
+// CHECK:        %{{[0-9]+}} = insertvalue { i64, float, { float*, i64, i64 } } undef, i64 %{{[0-9]+}}, 0
+// CHECK-NEXT:   %{{[0-9]+}} = insertvalue { i64, float, { float*, i64, i64 } } %{{[0-9]+}}, float %{{[0-9]+}}, 1
+// CHECK-NEXT:   %{{[0-9]+}} = insertvalue { i64, float, { float*, i64, i64 } } %{{[0-9]+}}, { float*, i64, i64 } %{{[0-9]+}}, 2
+// CHECK-NEXT:   ret { i64, float, { float*, i64, i64 } } %{{[0-9]+}}
+  return %0, %1, %2 : i64, f32, memref<42x?x10x?xf32>
+}
+
+
+// CHECK-LABEL: define void @multireturn_caller() {
+cfgfunc @multireturn_caller() {
+bb0:
+// CHECK-NEXT:   %1 = call { i64, float, { float*, i64, i64 } } @multireturn()
+// CHECK-NEXT:   [[ret0:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 0
+// CHECK-NEXT:   [[ret1:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 1
+// CHECK-NEXT:   [[ret2:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 2
+  %0 = call @multireturn() : () -> (i64, f32, memref<42x?x10x?xf32>)
+  %1 = constant 42 : i64
+// CHECK:   add i64 [[ret0]], 42
+  %2 = addi %0#0, %1 : i64
+  %3 = constant 42.0 : f32
+// CHECK:   fadd float [[ret1]], 4.200000e+01
+  %4 = addf %0#1, %3 : f32
+  %5 = constant 0 : index
+// CHECK:   extractvalue { float*, i64, i64 } [[ret2]], 0
+  %6 = load %0#2 [%5, %5, %5, %5] : memref<42x?x10x?xf32>
+  return
+}