From b4869f2fa71f977db94f0e7645711a169c845410 Mon Sep 17 00:00:00 2001 From: Min-Yih Hsu Date: Wed, 3 Aug 2022 14:08:33 -0700 Subject: [PATCH] [mlir][LLVMIR] Fix incorrect result type from llvm.fcmp If any of the operands for FCmpOp is a vector, returns a vector, rather than an i1 type result. Differential Revision: https://reviews.llvm.org/D134449 --- mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp | 14 +++++++++++++- mlir/test/Dialect/LLVMIR/roundtrip.mlir | 10 ++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index df1dd6a..9c47d36 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -907,8 +907,20 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { Value rhs = processValue(inst->getOperand(1)); if (!lhs || !rhs) return failure(); + + if (lhs.getType() != rhs.getType()) + return failure(); + + Type boolType = b.getI1Type(); + Type resType = boolType; + if (LLVM::isCompatibleVectorType(lhs.getType())) { + unsigned numElements = + LLVM::getVectorNumElements(lhs.getType()).getFixedValue(); + resType = VectorType::get({numElements}, boolType); + } + instMap[inst] = b.create( - loc, b.getI1Type(), + loc, resType, getFCmpPredicate(cast(inst)->getPredicate()), lhs, rhs); return success(); } diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 37a7905..ad26fd7 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -449,7 +449,7 @@ llvm.func @useInlineAsm(%arg0: i32) { } // CHECK-LABEL: @fastmathFlags -func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32) { +func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: vector<2 x f32>, %arg4: vector<2 x f32>) { // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 // CHECK: {{.*}} = llvm.fsub %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 // CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 @@ -461,8 +461,14 @@ func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32) { %3 = llvm.fdiv %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 %4 = llvm.frem %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 -// CHECK: {{.*}} = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[SCALAR_PRED0:.+]] = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 %5 = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %{{.*}} = llvm.add %[[SCALAR_PRED0]], %[[SCALAR_PRED0]] : i1 + %typecheck_5 = llvm.add %5, %5 : i1 +// CHECK: %[[VEC_PRED0:.+]] = llvm.fcmp "oeq" %arg3, %arg4 {fastmathFlags = #llvm.fastmath} : vector<2xf32> + %vcmp = llvm.fcmp "oeq" %arg3, %arg4 {fastmathFlags = #llvm.fastmath} : vector<2xf32> +// CHECK: %{{.*}} = llvm.add %[[VEC_PRED0]], %[[VEC_PRED0]] : vector<2xi1> + %typecheck_vcmp = llvm.add %vcmp, %vcmp : vector<2xi1> // CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : f32 %6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : f32 -- 2.7.4