[mlir][math] Fix lowering of AbsIOp
authorJeff Niu <jeff@modular.com>
Fri, 12 Aug 2022 15:30:55 +0000 (11:30 -0400)
committerJeff Niu <jeff@modular.com>
Fri, 12 Aug 2022 16:10:15 +0000 (12:10 -0400)
The LLVM intrinsic has a bool flag `is_int_min_poison` that needs to be
set.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D131785

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir

index 937a2fa..ca989bf 100644 (file)
@@ -37,10 +37,13 @@ class LLVM_TernarySameArgsIntrinsicOp<string func, list<Trait> traits = []> :
 class LLVM_CountZerosIntrinsicOp<string func, list<Trait> traits = []> :
     LLVM_OneResultIntrOp<func, [], [0],
            !listconcat([NoSideEffect], traits)> {
-  let arguments = (ins LLVM_Type:$in, I<1>:$zero_undefined);
+  let arguments = (ins LLVM_Type:$in, I1:$zero_undefined);
+}
+
+def LLVM_AbsOp : LLVM_OneResultIntrOp<"abs", [], [0], [NoSideEffect]> {
+  let arguments = (ins LLVM_Type:$in, I1:$is_int_min_poison);
 }
 
-def LLVM_AbsOp : LLVM_UnaryIntrinsicOp<"abs">;
 def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
 def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
 def LLVM_ExpOp : LLVM_UnaryIntrinsicOp<"exp">;
index 1cd24cd..cb34982 100644 (file)
@@ -19,7 +19,6 @@ using namespace mlir;
 
 namespace {
 using AbsFOpLowering = VectorConvertToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
-using AbsIOpLowering = VectorConvertToLLVMPattern<math::AbsIOp, LLVM::AbsOp>;
 using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
 using CopySignOpLowering =
     VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
@@ -41,11 +40,11 @@ using RoundOpLowering =
 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
 
-// A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
+// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
 template <typename MathOp, typename LLVMOp>
-struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
+struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
   using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
-  using Super = CountOpLowering<MathOp, LLVMOp>;
+  using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
 
   LogicalResult
   matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
@@ -83,9 +82,10 @@ struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
 };
 
 using CountLeadingZerosOpLowering =
-    CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
+    IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
 using CountTrailingZerosOpLowering =
-    CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
+    IntOpWithFlagLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
+using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
 
 // A `expm1` is converted into `exp - 1`.
 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
index fc91a55..64e2018 100644 (file)
@@ -10,13 +10,20 @@ func.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) {
   %2 = math.sqrt %arg0 : f32
   // CHECK: = "llvm.intr.sqrt"(%{{.*}}) : (f64) -> f64
   %3 = math.sqrt %arg4 : f64
-  // CHECK: = "llvm.intr.abs"(%{{.*}}) : (i32) -> i32
-  %4 = math.absi %arg2 : i32
   func.return
 }
 
 // -----
 
+func.func @absi(%arg0: i32) -> i32 {
+  // CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false
+  // CHECK: = "llvm.intr.abs"(%{{.*}}, %[[FALSE]]) : (i32, i1) -> i32
+  %0 = math.absi %arg0 : i32
+  return %0 : i32
+}
+
+// -----
+
 // CHECK-LABEL: func @log1p(
 // CHECK-SAME: f32
 func.func @log1p(%arg0 : f32) {