From 7b5132dae8e0d0645aacfb41c09cbc3ef33a330e Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Tue, 9 Nov 2021 14:15:07 +0100 Subject: [PATCH] [fir] Add complex operations conversion from FIR LLVM IR This patch add conversion for primitive operations on complex types. - fir.addc - fir.subc - fir.mulc - fir.divc - fir.negc This adds also the type conversion for !fir.complex type. This patch is part of the upstreaming effort from fir-dev branch. This patch was updated to avoid failure on windows buildbot. Flang codegen does not support windows target so we force the test to use a known target instead. Reviewed By: kiranchandramohan, rovka Differential Revision: https://reviews.llvm.org/D113434 Co-authored-by: Jean Perier Co-authored-by: Eric Schweitz --- flang/include/flang/Optimizer/CodeGen/CGPasses.td | 4 + flang/lib/Optimizer/CodeGen/CodeGen.cpp | 188 +++++++++++++++++++++- flang/lib/Optimizer/CodeGen/Target.cpp | 7 + flang/lib/Optimizer/CodeGen/Target.h | 3 + flang/lib/Optimizer/CodeGen/TypeConverter.h | 53 +++++- flang/test/Fir/convert-to-llvm.fir | 140 +++++++++++++++- flang/test/Fir/types-to-llvm.fir | 28 ++++ 7 files changed, 415 insertions(+), 8 deletions(-) diff --git a/flang/include/flang/Optimizer/CodeGen/CGPasses.td b/flang/include/flang/Optimizer/CodeGen/CGPasses.td index 51b1218..2474409 100644 --- a/flang/include/flang/Optimizer/CodeGen/CGPasses.td +++ b/flang/include/flang/Optimizer/CodeGen/CGPasses.td @@ -24,6 +24,10 @@ def FIRToLLVMLowering : Pass<"fir-to-llvm-ir", "mlir::ModuleOp"> { }]; let constructor = "::fir::createFIRToLLVMPass()"; let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let options = [ + Option<"forcedTargetTriple", "target", "std::string", /*default=*/"", + "Override module's target triple."> + ]; } def CodeGenRewrite : Pass<"cg-rewrite"> { diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 4a0a726..db95017 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -14,6 +14,7 @@ #include "PassDetail.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Support/FIRContext.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -487,6 +488,175 @@ struct InsertOnRangeOpConversion return success(); } }; + +static mlir::Type getComplexEleTy(mlir::Type complex) { + if (auto cc = complex.dyn_cast()) + return cc.getElementType(); + return complex.cast().getElementType(); +} + +// +// Primitive operations on Complex types +// + +/// Generate inline code for complex addition/subtraction +template +mlir::LLVM::InsertValueOp complexSum(OPTY sumop, mlir::ValueRange opnds, + mlir::ConversionPatternRewriter &rewriter, + fir::LLVMTypeConverter &lowering) { + mlir::Value a = opnds[0]; + mlir::Value b = opnds[1]; + auto loc = sumop.getLoc(); + auto ctx = sumop.getContext(); + auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0)); + auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1)); + mlir::Type eleTy = lowering.convertType(getComplexEleTy(sumop.getType())); + mlir::Type ty = lowering.convertType(sumop.getType()); + auto x0 = rewriter.create(loc, eleTy, a, c0); + auto y0 = rewriter.create(loc, eleTy, a, c1); + auto x1 = rewriter.create(loc, eleTy, b, c0); + auto y1 = rewriter.create(loc, eleTy, b, c1); + auto rx = rewriter.create(loc, eleTy, x0, x1); + auto ry = rewriter.create(loc, eleTy, y0, y1); + auto r0 = rewriter.create(loc, ty); + auto r1 = rewriter.create(loc, ty, r0, rx, c0); + return rewriter.create(loc, ty, r1, ry, c1); +} + +struct AddcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::AddcOp addc, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + // given: (x + iy) + (x' + iy') + // result: (x + x') + i(y + y') + auto r = complexSum(addc, adaptor.getOperands(), + rewriter, lowerTy()); + rewriter.replaceOp(addc, r.getResult()); + return success(); + } +}; + +struct SubcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::SubcOp subc, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + // given: (x + iy) - (x' + iy') + // result: (x - x') + i(y - y') + auto r = complexSum(subc, adaptor.getOperands(), + rewriter, lowerTy()); + rewriter.replaceOp(subc, r.getResult()); + return success(); + } +}; + +/// Inlined complex multiply +struct MulcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::MulcOp mulc, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + // TODO: Can we use a call to __muldc3 ? + // given: (x + iy) * (x' + iy') + // result: (xx'-yy')+i(xy'+yx') + mlir::Value a = adaptor.getOperands()[0]; + mlir::Value b = adaptor.getOperands()[1]; + auto loc = mulc.getLoc(); + auto *ctx = mulc.getContext(); + auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0)); + auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1)); + mlir::Type eleTy = convertType(getComplexEleTy(mulc.getType())); + mlir::Type ty = convertType(mulc.getType()); + auto x0 = rewriter.create(loc, eleTy, a, c0); + auto y0 = rewriter.create(loc, eleTy, a, c1); + auto x1 = rewriter.create(loc, eleTy, b, c0); + auto y1 = rewriter.create(loc, eleTy, b, c1); + auto xx = rewriter.create(loc, eleTy, x0, x1); + auto yx = rewriter.create(loc, eleTy, y0, x1); + auto xy = rewriter.create(loc, eleTy, x0, y1); + auto ri = rewriter.create(loc, eleTy, xy, yx); + auto yy = rewriter.create(loc, eleTy, y0, y1); + auto rr = rewriter.create(loc, eleTy, xx, yy); + auto ra = rewriter.create(loc, ty); + auto r1 = rewriter.create(loc, ty, ra, rr, c0); + auto r0 = rewriter.create(loc, ty, r1, ri, c1); + rewriter.replaceOp(mulc, r0.getResult()); + return success(); + } +}; + +/// Inlined complex division +struct DivcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + // TODO: Can we use a call to __divdc3 instead? + // Just generate inline code for now. + // given: (x + iy) / (x' + iy') + // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y' + mlir::Value a = adaptor.getOperands()[0]; + mlir::Value b = adaptor.getOperands()[1]; + auto loc = divc.getLoc(); + auto *ctx = divc.getContext(); + auto c0 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0)); + auto c1 = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1)); + mlir::Type eleTy = convertType(getComplexEleTy(divc.getType())); + mlir::Type ty = convertType(divc.getType()); + auto x0 = rewriter.create(loc, eleTy, a, c0); + auto y0 = rewriter.create(loc, eleTy, a, c1); + auto x1 = rewriter.create(loc, eleTy, b, c0); + auto y1 = rewriter.create(loc, eleTy, b, c1); + auto xx = rewriter.create(loc, eleTy, x0, x1); + auto x1x1 = rewriter.create(loc, eleTy, x1, x1); + auto yx = rewriter.create(loc, eleTy, y0, x1); + auto xy = rewriter.create(loc, eleTy, x0, y1); + auto yy = rewriter.create(loc, eleTy, y0, y1); + auto y1y1 = rewriter.create(loc, eleTy, y1, y1); + auto d = rewriter.create(loc, eleTy, x1x1, y1y1); + auto rrn = rewriter.create(loc, eleTy, xx, yy); + auto rin = rewriter.create(loc, eleTy, yx, xy); + auto rr = rewriter.create(loc, eleTy, rrn, d); + auto ri = rewriter.create(loc, eleTy, rin, d); + auto ra = rewriter.create(loc, ty); + auto r1 = rewriter.create(loc, ty, ra, rr, c0); + auto r0 = rewriter.create(loc, ty, r1, ri, c1); + rewriter.replaceOp(divc, r0.getResult()); + return success(); + } +}; + +/// Inlined complex negation +struct NegcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::NegcOp neg, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + // given: -(x + iy) + // result: -x - iy + auto *ctxt = neg.getContext(); + auto eleTy = convertType(getComplexEleTy(neg.getType())); + auto ty = convertType(neg.getType()); + auto loc = neg.getLoc(); + mlir::Value o0 = adaptor.getOperands()[0]; + auto c0 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(0)); + auto c1 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(1)); + auto rp = rewriter.create(loc, eleTy, o0, c0); + auto ip = rewriter.create(loc, eleTy, o0, c1); + auto nrp = rewriter.create(loc, eleTy, rp); + auto nip = rewriter.create(loc, eleTy, ip); + auto r = rewriter.create(loc, ty, o0, nrp, c0); + rewriter.replaceOpWithNewOp(neg, ty, r, nip, c1); + return success(); + } +}; + } // namespace namespace { @@ -501,15 +671,21 @@ public: mlir::ModuleOp getModule() { return getOperation(); } void runOnOperation() override final { + auto mod = getModule(); + if (!forcedTargetTriple.empty()) { + fir::setTargetTriple(mod, forcedTargetTriple); + } + auto *context = getModule().getContext(); fir::LLVMTypeConverter typeConverter{getModule()}; mlir::OwningRewritePatternList pattern(context); - pattern.insert< - AddrOfOpConversion, CallOpConversion, ExtractValueOpConversion, - HasValueOpConversion, GlobalOpConversion, InsertOnRangeOpConversion, - InsertValueOpConversion, SelectOpConversion, SelectRankOpConversion, - UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>( - typeConverter); + pattern.insert(typeConverter); mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, pattern); diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp index 78e52b9..0541de2 100644 --- a/flang/lib/Optimizer/CodeGen/Target.cpp +++ b/flang/lib/Optimizer/CodeGen/Target.cpp @@ -35,6 +35,13 @@ struct GenericTarget : public CodeGenSpecifics { using CodeGenSpecifics::CodeGenSpecifics; using AT = CodeGenSpecifics::Attributes; + mlir::Type complexMemoryType(mlir::Type eleTy) const override { + assert(fir::isa_real(eleTy)); + // { t, t } struct of 2 eleTy + mlir::TypeRange range = {eleTy, eleTy}; + return mlir::TupleType::get(eleTy.getContext(), range); + } + Marshalling boxcharArgumentType(mlir::Type eleTy, bool sret) const override { CodeGenSpecifics::Marshalling marshal; auto idxTy = mlir::IntegerType::get(eleTy.getContext(), S::defaultWidth); diff --git a/flang/lib/Optimizer/CodeGen/Target.h b/flang/lib/Optimizer/CodeGen/Target.h index eb9c93d..af4004c 100644 --- a/flang/lib/Optimizer/CodeGen/Target.h +++ b/flang/lib/Optimizer/CodeGen/Target.h @@ -65,6 +65,9 @@ public: CodeGenSpecifics() = delete; virtual ~CodeGenSpecifics() {} + /// Type presentation of a `complex` type value in memory. + virtual mlir::Type complexMemoryType(mlir::Type eleTy) const = 0; + /// Type representation of a `complex` type argument when passed by /// value. An argument value may need to be passed as a (safe) reference /// argument. diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.h b/flang/lib/Optimizer/CodeGen/TypeConverter.h index f8d65180..f4d252df 100644 --- a/flang/lib/Optimizer/CodeGen/TypeConverter.h +++ b/flang/lib/Optimizer/CodeGen/TypeConverter.h @@ -14,6 +14,7 @@ #define FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H #include "DescriptorModel.h" +#include "Target.h" #include "flang/Lower/Todo.h" // remove when TODO's are done #include "flang/Optimizer/Support/FIRContext.h" #include "flang/Optimizer/Support/KindMapping.h" @@ -28,7 +29,10 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter { public: LLVMTypeConverter(mlir::ModuleOp module) : mlir::LLVMTypeConverter(module.getContext()), - kindMapping(getKindMapping(module)) { + kindMapping(getKindMapping(module)), + specifics(CodeGenSpecifics::get(module.getContext(), + getTargetTriple(module), + getKindMapping(module))) { LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n"); // Each conversion should return a value of type mlir::Type. @@ -40,6 +44,10 @@ public: addConversion( [&](fir::RecordType derived) { return convertRecordType(derived); }); addConversion( + [&](fir::ComplexType cmplx) { return convertComplexType(cmplx); }); + addConversion( + [&](fir::RealType real) { return convertRealType(real.getFKind()); }); + addConversion( [&](fir::ReferenceType ref) { return convertPointerLike(ref); }); addConversion( [&](SequenceType sequence) { return convertSequenceType(sequence); }); @@ -140,6 +148,24 @@ public: /*isPacked=*/false)); } + // Use the target specifics to figure out how to map complex to LLVM IR. The + // use of complex values in function signatures is handled before conversion + // to LLVM IR dialect here. + // + // fir.complex | std.complex --> llvm<"{t,t}"> + template + mlir::Type convertComplexType(C cmplx) { + LLVM_DEBUG(llvm::dbgs() << "type convert: " << cmplx << '\n'); + auto eleTy = cmplx.getElementType(); + return convertType(specifics->complexMemoryType(eleTy)); + } + + // convert a front-end kind value to either a std or LLVM IR dialect type + // fir.real --> llvm.anyfloat where anyfloat is a kind mapping + mlir::Type convertRealType(fir::KindTy kind) { + return fromRealTypeID(kindMapping.getRealTypeID(kind), kind); + } + template mlir::Type convertPointerLike(A &ty) { mlir::Type eleTy = ty.getEleTy(); @@ -187,8 +213,33 @@ public: return mlir::LLVM::LLVMPointerType::get(baseTy); } + /// Convert llvm::Type::TypeID to mlir::Type + mlir::Type fromRealTypeID(llvm::Type::TypeID typeID, fir::KindTy kind) { + switch (typeID) { + case llvm::Type::TypeID::HalfTyID: + return mlir::FloatType::getF16(&getContext()); + case llvm::Type::TypeID::BFloatTyID: + return mlir::FloatType::getBF16(&getContext()); + case llvm::Type::TypeID::FloatTyID: + return mlir::FloatType::getF32(&getContext()); + case llvm::Type::TypeID::DoubleTyID: + return mlir::FloatType::getF64(&getContext()); + case llvm::Type::TypeID::X86_FP80TyID: + return mlir::FloatType::getF80(&getContext()); + case llvm::Type::TypeID::FP128TyID: + return mlir::FloatType::getF128(&getContext()); + default: + emitError(UnknownLoc::get(&getContext())) + << "unsupported type: !fir.real<" << kind << ">"; + return {}; + } + } + + KindMapping &getKindMap() { return kindMapping; } + private: KindMapping kindMapping; + std::unique_ptr specifics; }; } // namespace fir diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir index 9d8b9bb..cafac1d 100644 --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -1,4 +1,8 @@ -// RUN: fir-opt --split-input-file --fir-to-llvm-ir %s | FileCheck %s +// RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s | FileCheck %s +// RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=aarch64-unknown-linux-gnu" %s | FileCheck %s +// RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=i386-unknown-linux-gnu" %s | FileCheck %s +// RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=powerpc64le-unknown-linux-gn" %s | FileCheck %s + // Test simple global LLVM conversion @@ -376,3 +380,137 @@ func @test_call_return_val() -> i32 { // CHECK-NEXT: %0 = llvm.call @dummy_return_val() : () -> i32 // CHECK-NEXT: llvm.return %0 : i32 // CHECK-NEXT: } + +// ----- + +// Test FIR complex addition conversion +// given: (x + iy) + (x' + iy') +// result: (x + x') + i(y + y') + +func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> { + %c = fir.addc %a, %b : !fir.complex<16> + return %c : !fir.complex<16> +} + +// CHECK-LABEL: llvm.func @fir_complex_add( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(f128, f128)>, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.struct<(f128, f128)>) -> !llvm.struct<(f128, f128)> { +// CHECK: %[[X0:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[ADD_X0_X1:.*]] = llvm.fadd %[[X0]], %[[X1]] : f128 +// CHECK: %[[ADD_Y0_Y1:.*]] = llvm.fadd %[[Y0]], %[[Y1]] : f128 +// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)> +// CHECK: %{{.*}} = llvm.insertvalue %[[ADD_X0_X1]], %{{.*}}[0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %{{.*}} = llvm.insertvalue %[[ADD_Y0_Y1]], %{{.*}}[1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: llvm.return %{{.*}} : !llvm.struct<(f128, f128)> + +// ----- + +// Test FIR complex substraction conversion +// given: (x + iy) - (x' + iy') +// result: (x - x') + i(y - y') + +func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> { + %c = fir.subc %a, %b : !fir.complex<16> + return %c : !fir.complex<16> +} + +// CHECK-LABEL: llvm.func @fir_complex_sub( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(f128, f128)>, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.struct<(f128, f128)>) -> !llvm.struct<(f128, f128)> { +// CHECK: %[[X0:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[SUB_X0_X1:.*]] = llvm.fsub %[[X0]], %[[X1]] : f128 +// CHECK: %[[SUB_Y0_Y1:.*]] = llvm.fsub %[[Y0]], %[[Y1]] : f128 +// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)> +// CHECK: %{{.*}} = llvm.insertvalue %[[SUB_X0_X1]], %{{.*}}[0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %{{.*}} = llvm.insertvalue %[[SUB_Y0_Y1]], %{{.*}}[1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: llvm.return %{{.*}} : !llvm.struct<(f128, f128)> + +// ----- + +// Test FIR complex multiply conversion +// given: (x + iy) * (x' + iy') +// result: (xx'-yy')+i(xy'+yx') + +func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> { + %c = fir.mulc %a, %b : !fir.complex<16> + return %c : !fir.complex<16> +} + +// CHECK-LABEL: llvm.func @fir_complex_mul( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(f128, f128)>, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.struct<(f128, f128)>) -> !llvm.struct<(f128, f128)> { +// CHECK: %[[X0:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : f128 +// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : f128 +// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : f128 +// CHECK: %[[ADD:.*]] = llvm.fadd %[[MUL_X0_Y1]], %[[MUL_Y0_X1]] : f128 +// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : f128 +// CHECK: %[[SUB:.*]] = llvm.fsub %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : f128 +// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)> +// CHECK: %{{.*}} = llvm.insertvalue %[[SUB]], %{{.*}}[0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %{{.*}} = llvm.insertvalue %[[ADD]], %{{.*}}[1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: llvm.return %{{.*}} : !llvm.struct<(f128, f128)> + +// ----- + +// Test FIR complex division conversion +// given: (x + iy) / (x' + iy') +// result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y' + +func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> { + %c = fir.divc %a, %b : !fir.complex<16> + return %c : !fir.complex<16> +} + +// CHECK-LABEL: llvm.func @fir_complex_div( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(f128, f128)>, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.struct<(f128, f128)>) -> !llvm.struct<(f128, f128)> { +// CHECK: %[[X0:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : f128 +// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] : f128 +// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : f128 +// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : f128 +// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : f128 +// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] : f128 +// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] : f128 +// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : f128 +// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] : f128 +// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] : f128 +// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] : f128 +// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)> +// CHECK: %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: llvm.return %{{.*}} : !llvm.struct<(f128, f128)> + +// ----- + +// Test FIR complex negation conversion +// given: -(x + iy) +// result: -x - iy + +func @fir_complex_neg(%a: !fir.complex<16>) -> !fir.complex<16> { + %c = fir.negc %a : !fir.complex<16> + return %c : !fir.complex<16> +} + +// CHECK-LABEL: llvm.func @fir_complex_neg( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(f128, f128)>) -> !llvm.struct<(f128, f128)> { +// CHECK: %[[X:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[Y:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %[[NEGX:.*]] = llvm.fneg %[[X]] : f128 +// CHECK: %[[NEGY:.*]] = llvm.fneg %[[Y]] : f128 +// CHECK: %{{.*}} = llvm.insertvalue %[[NEGX]], %{{.*}}[0 : i32] : !llvm.struct<(f128, f128)> +// CHECK: %{{.*}} = llvm.insertvalue %[[NEGY]], %{{.*}}[1 : i32] : !llvm.struct<(f128, f128)> +// CHECK: llvm.return %{{.*}} : !llvm.struct<(f128, f128)> diff --git a/flang/test/Fir/types-to-llvm.fir b/flang/test/Fir/types-to-llvm.fir index 409e6da..8bd007e 100644 --- a/flang/test/Fir/types-to-llvm.fir +++ b/flang/test/Fir/types-to-llvm.fir @@ -72,3 +72,31 @@ func private @foo3(%arg0: !fir.logical<8>) func private @foo4(%arg0: !fir.logical<16>) // CHECK-LABEL: foo4 // CHECK-SAME: i128 + +// ----- + +// Test `!fir.complex` conversion. + +func private @foo0(%arg0: !fir.complex<2>) +// CHECK-LABEL: foo0 +// CHECK-SAME: !llvm.struct<(f16, f16)>) + +func private @foo1(%arg0: !fir.complex<3>) +// CHECK-LABEL: foo1 +// CHECK-SAME: !llvm.struct<(bf16, bf16)>) + +func private @foo2(%arg0: !fir.complex<4>) +// CHECK-LABEL: foo2 +// CHECK-SAME: !llvm.struct<(f32, f32)>) + +func private @foo3(%arg0: !fir.complex<8>) +// CHECK-LABEL: foo3 +// CHECK-SAME: !llvm.struct<(f64, f64)>) + +func private @foo4(%arg0: !fir.complex<10>) +// CHECK-LABEL: foo4 +// CHECK-SAME: !llvm.struct<(f80, f80)>) + +func private @foo5(%arg0: !fir.complex<16>) +// CHECK-LABEL: foo5 +// CHECK-SAME: !llvm.struct<(f128, f128)>) -- 2.7.4