[mlir] Add exp2 conversion to llvm.intr.exp2
authorAaron Smith <aaron.smith@microsoft.com>
Sun, 29 Mar 2020 08:23:08 +0000 (01:23 -0700)
committerAaron Smith <aaron.smith@microsoft.com>
Sun, 29 Mar 2020 08:23:08 +0000 (01:23 -0700)
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
mlir/test/Target/llvmir-intrinsics.mlir

index 53dea5b..954683b 100644 (file)
@@ -771,6 +771,7 @@ class LLVM_TernarySameArgsIntrinsicOp<string func, list<OpTrait> traits = []> :
 def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
 def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
 def LLVM_ExpOp : LLVM_UnaryIntrinsicOp<"exp">;
+def LLVM_Exp2Op : LLVM_UnaryIntrinsicOp<"exp2">;
 def LLVM_FAbsOp : LLVM_UnaryIntrinsicOp<"fabs">;
 def LLVM_FCeilOp : LLVM_UnaryIntrinsicOp<"ceil">;
 def LLVM_FMAOp : LLVM_TernarySameArgsIntrinsicOp<"fma">;
index 9ca5465..df32e89 100644 (file)
@@ -924,6 +924,14 @@ def ExpOp : FloatUnaryOp<"exp"> {
 }
 
 //===----------------------------------------------------------------------===//
+// ExpOp
+//===----------------------------------------------------------------------===//
+
+def Exp2Op : FloatUnaryOp<"exp2"> {
+  let summary = "base-2 exponential of the specified value";
+}
+
+//===----------------------------------------------------------------------===//
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
 
index a6c0d32..38e4854 100644 (file)
@@ -1222,6 +1222,7 @@ using CopySignOpLowering =
 using CosOpLowering = VectorConvertToLLVMPattern<CosOp, LLVM::CosOp>;
 using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
 using ExpOpLowering = VectorConvertToLLVMPattern<ExpOp, LLVM::ExpOp>;
+using Exp2OpLowering = VectorConvertToLLVMPattern<Exp2Op, LLVM::Exp2Op>;
 using Log10OpLowering = VectorConvertToLLVMPattern<Log10Op, LLVM::Log10Op>;
 using Log2OpLowering = VectorConvertToLLVMPattern<Log2Op, LLVM::Log2Op>;
 using LogOpLowering = VectorConvertToLLVMPattern<LogOp, LLVM::LogOp>;
@@ -2649,6 +2650,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       DialectCastOpLowering,
       DivFOpLowering,
       ExpOpLowering,
+      Exp2OpLowering,
       LogOpLowering,
       Log10OpLowering,
       Log2OpLowering,
index 2b22571..9c8d47d 100644 (file)
@@ -485,18 +485,20 @@ func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
   %12 = xor %arg2, %arg3 : i32
 // CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
   %13 = std.exp %arg0 : f32
-// CHECK-NEXT: %14 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
-  %14 = constant 7.9e-01 : f64
-// CHECK-NEXT: %15 = llvm.shl %arg2, %arg3 : !llvm.i32
-  %15 = shift_left %arg2, %arg3 : i32
-// CHECK-NEXT: %16 = llvm.ashr %arg2, %arg3 : !llvm.i32
-  %16 = shift_right_signed %arg2, %arg3 : i32
-// CHECK-NEXT: %17 = llvm.lshr %arg2, %arg3 : !llvm.i32
-  %17 = shift_right_unsigned %arg2, %arg3 : i32
+// CHECK-NEXT: %14 = "llvm.intr.exp2"(%arg0) : (!llvm.float) -> !llvm.float
+  %14 = std.exp2 %arg0 : f32
+// CHECK-NEXT: %15 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
+  %15 = constant 7.9e-01 : f64
+// CHECK-NEXT: %16 = llvm.shl %arg2, %arg3 : !llvm.i32
+  %16 = shift_left %arg2, %arg3 : i32
+// CHECK-NEXT: %17 = llvm.ashr %arg2, %arg3 : !llvm.i32
+  %17 = shift_right_signed %arg2, %arg3 : i32
+// CHECK-NEXT: %18 = llvm.lshr %arg2, %arg3 : !llvm.i32
+  %18 = shift_right_unsigned %arg2, %arg3 : i32
 // CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
-  %18 = std.sqrt %arg0 : f32
+  %19 = std.sqrt %arg0 : f32
 // CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double
-  %19 = std.sqrt %arg4 : f64
+  %20 = std.sqrt %arg4 : f64
   return %0, %4 : f32, i32
 }
 
index 7be5e5f..c332bc2 100644 (file)
@@ -27,6 +27,15 @@ llvm.func @exp_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
   llvm.return
 }
 
+// CHECK-LABEL: @exp2_test
+llvm.func @exp2_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
+  // CHECK: call float @llvm.exp2.f32
+  "llvm.intr.exp2"(%arg0) : (!llvm.float) -> !llvm.float
+  // CHECK: call <8 x float> @llvm.exp2.v8f32
+  "llvm.intr.exp2"(%arg1) : (!llvm<"<8 x float>">) -> !llvm<"<8 x float>">
+  llvm.return
+}
+
 // CHECK-LABEL: @log_test
 llvm.func @log_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
   // CHECK: call float @llvm.log.f32