llvm::StructType *convertMemRefType(MemRefType type);
/// \}
+ /// Convert a list of types to an LLVM type suitable for being returned from a
+ /// function. If the list is empty, return VoidTy. If it
+ /// contains one element, return the converted element. Otherwise, create an
+ /// LLVM StructType containing all the given types in order.
+ llvm::Type *getPackedResultType(ArrayRef<Type> types);
+
/// Get an a constant value of `indexType`.
inline llvm::Constant *getIndexConstant(int64_t value);
/// instruction) on success and nullptr on error.
llvm::Value *emitMemRefDealloc(ConstOpPointer<DeallocOp> deallocOp);
+ /// Create a single LLVM value of struct type that includes the list of
+ /// given MLIR values. The `values` list must contain at least 2 elements.
+ llvm::Value *packValues(ArrayRef<const SSAValue *> values);
+ /// Extract a list of `num` LLVM values from a `value` of struct type.
+ SmallVector<llvm::Value *, 4> unpackValues(llvm::Value *value, unsigned num);
+
llvm::DenseMap<const Function *, llvm::Function *> functionMapping;
llvm::DenseMap<const SSAValue *, llvm::Value *> valueMapping;
llvm::DenseMap<const BasicBlock *, llvm::BasicBlock *> blockMapping;
}
}
+// Helper function for lambdas below.
+static bool isTypeNull(llvm::Type *type) { return type == nullptr; }
+
+// If `types` has more than one type, pack them into an LLVM StructType,
+// otherwise just convert the type.
+llvm::Type *ModuleLowerer::getPackedResultType(ArrayRef<Type> types) {
+ // Convert result types one by one and check for errors.
+ auto resultTypes =
+ functional::map([this](Type t) { return convertType(t); }, types);
+ if (llvm::any_of(resultTypes, isTypeNull))
+ return nullptr;
+
+ // LLVM does not support tuple returns. If there are more than 2 results,
+ // pack them into an LLVM struct type.
+ if (resultTypes.empty())
+ return llvm::Type::getVoidTy(llvmContext);
+ if (resultTypes.size() == 1)
+ return resultTypes.front();
+ return llvm::StructType::get(llvmContext, resultTypes);
+}
+
+// Function types are converted to LLVM Function types by recursively converting
+// argument and result types. If MLIR Function has zero results, the LLVM
+// Function has one VoidType result. If MLIR Function has more than one result,
+// they are into an LLVM StructType in their order of appearance.
llvm::FunctionType *ModuleLowerer::convertFunctionType(FunctionType type) {
- // TODO(zinenko): convert tuple to LLVM structure types
- assert(type.getNumResults() <= 1 && "NYI: tuple returns");
- auto resultType = type.getNumResults() == 0
- ? llvm::Type::getVoidTy(llvmContext)
- : convertType(type.getResult(0));
+ llvm::Type *resultType = getPackedResultType(type.getResults());
if (!resultType)
return nullptr;
- auto argTypes =
- functional::map([this](Type inputType) { return convertType(inputType); },
- type.getInputs());
- if (std::any_of(argTypes.begin(), argTypes.end(),
- [](const llvm::Type *t) { return t == nullptr; }))
+ // Convert argument types one by one and check for errors.
+ auto argTypes = functional::map([this](Type t) { return convertType(t); },
+ type.getInputs());
+ if (llvm::any_of(argTypes, isTypeNull))
return nullptr;
return llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false);
return builder.CreateCall(freeFunc, data);
}
+// Create an undef struct value and insert individual values into it.
+llvm::Value *ModuleLowerer::packValues(ArrayRef<const SSAValue *> values) {
+ assert(values.size() > 1 && "cannot pack less than 2 values");
+
+ auto types =
+ functional::map([](const SSAValue *v) { return v->getType(); }, values);
+ llvm::Type *packedType = getPackedResultType(types);
+
+ llvm::Value *packed = llvm::UndefValue::get(packedType);
+ for (auto indexedValue : llvm::enumerate(values)) {
+ packed = builder.CreateInsertValue(
+ packed, valueMapping.lookup(indexedValue.value()),
+ indexedValue.index());
+ }
+ return packed;
+}
+
+// Emit extract value instructions to unpack the struct.
+SmallVector<llvm::Value *, 4> ModuleLowerer::unpackValues(llvm::Value *value,
+ unsigned num) {
+ SmallVector<llvm::Value *, 4> unpacked;
+ unpacked.reserve(num);
+ for (unsigned i = 0; i < num; ++i)
+ unpacked.push_back(builder.CreateExtractValue(value, i));
+ return unpacked;
+}
+
static llvm::CmpInst::Predicate getLLVMCmpPredicate(CmpIPredicate p) {
switch (p) {
case CmpIPredicate::EQ:
[this](const SSAValue *value) { return valueMapping.lookup(value); },
callOp->getOperands());
auto numResults = callOp->getNumResults();
- // TODO(zinenko): support tuple returns
- assert(numResults <= 1 && "NYI: tuple returns");
-
llvm::Value *result =
builder.CreateCall(functionMapping[callOp->getCallee()], operands);
- if (numResults == 1)
+ if (numResults == 1) {
valueMapping[callOp->getResult(0)] = result;
+ } else if (numResults > 1) {
+ auto unpacked = unpackValues(result, numResults);
+ for (auto indexedValue : llvm::enumerate(unpacked)) {
+ valueMapping[callOp->getResult(indexedValue.index())] =
+ indexedValue.value();
+ }
+ }
return false;
}
// Terminators.
if (auto returnInst = inst.dyn_cast<ReturnOp>()) {
unsigned numOperands = returnInst->getNumOperands();
- // TODO(zinenko): support tuple returns
- assert(numOperands <= 1u && "NYI: tuple returns");
-
- if (numOperands == 0)
+ if (numOperands == 0) {
builder.CreateRetVoid();
- else
+ } else if (numOperands == 1) {
builder.CreateRet(valueMapping[returnInst->getOperand(0)]);
+ } else {
+ llvm::Value *packed =
+ packValues(llvm::to_vector<4>(returnInst->getOperands()));
+ if (!packed)
+ return true;
+ builder.CreateRet(packed);
+ }
return false;
}
// CHECK-NEXT: ret i64 %6
return %d0123 : index
}
+
+extfunc @get_i64() -> (i64)
+extfunc @get_f32() -> (f32)
+extfunc @get_memref() -> (memref<42x?x10x?xf32>)
+
+// CHECK-LABEL: define { i64, float, { float*, i64, i64 } } @multireturn() {
+cfgfunc @multireturn() -> (i64, f32, memref<42x?x10x?xf32>) {
+bb0:
+ %0 = call @get_i64() : () -> (i64)
+ %1 = call @get_f32() : () -> (f32)
+ %2 = call @get_memref() : () -> (memref<42x?x10x?xf32>)
+// CHECK: %{{[0-9]+}} = insertvalue { i64, float, { float*, i64, i64 } } undef, i64 %{{[0-9]+}}, 0
+// CHECK-NEXT: %{{[0-9]+}} = insertvalue { i64, float, { float*, i64, i64 } } %{{[0-9]+}}, float %{{[0-9]+}}, 1
+// CHECK-NEXT: %{{[0-9]+}} = insertvalue { i64, float, { float*, i64, i64 } } %{{[0-9]+}}, { float*, i64, i64 } %{{[0-9]+}}, 2
+// CHECK-NEXT: ret { i64, float, { float*, i64, i64 } } %{{[0-9]+}}
+ return %0, %1, %2 : i64, f32, memref<42x?x10x?xf32>
+}
+
+
+// CHECK-LABEL: define void @multireturn_caller() {
+cfgfunc @multireturn_caller() {
+bb0:
+// CHECK-NEXT: %1 = call { i64, float, { float*, i64, i64 } } @multireturn()
+// CHECK-NEXT: [[ret0:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 0
+// CHECK-NEXT: [[ret1:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 1
+// CHECK-NEXT: [[ret2:%[0-9]+]] = extractvalue { i64, float, { float*, i64, i64 } } %1, 2
+ %0 = call @multireturn() : () -> (i64, f32, memref<42x?x10x?xf32>)
+ %1 = constant 42 : i64
+// CHECK: add i64 [[ret0]], 42
+ %2 = addi %0#0, %1 : i64
+ %3 = constant 42.0 : f32
+// CHECK: fadd float [[ret1]], 4.200000e+01
+ %4 = addf %0#1, %3 : f32
+ %5 = constant 0 : index
+// CHECK: extractvalue { float*, i64, i64 } [[ret2]], 0
+ %6 = load %0#2 [%5, %5, %5, %5] : memref<42x?x10x?xf32>
+ return
+}