From f7bffad5a7ccc820ba87de12c10a6a5c4dc81d6e Mon Sep 17 00:00:00 2001 From: Ehsan Toosi Date: Thu, 12 Dec 2019 09:24:43 -0800 Subject: [PATCH] Added lowering of `std.tanh` to llvm function call to `tanh` and `tanhf`. Closes tensorflow/mlir#312 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/312 from dfki-ehna:tanh 9e89b072ff91ff390ad739501745114feb3ac856 PiperOrigin-RevId: 285205674 --- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 51 ++++++++++++++++++++++ .../StandardToLLVM/convert-to-llvmir.mlir | 24 +++++++++- 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 933c0cf..508868d 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1200,6 +1200,56 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { bool useAlloca; }; +// A `tanh` is converted into a call to the `tanh` function. +struct TanhOpLowering : public LLVMLegalizationPattern { + using LLVMLegalizationPattern::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + + using LLVMFuncOpT = LLVM::LLVMFuncOp; + using LLVMTypeT = LLVM::LLVMType; + + OperandAdaptor transformed(operands); + LLVMTypeT operandType = + transformed.operand()->getType().dyn_cast_or_null(); + + if (!operandType) + return matchFailure(); + + std::string functionName; + if (operandType.isFloatTy()) + functionName = "tanhf"; + else if (operandType.isDoubleTy()) + functionName = "tanh"; + else + return matchFailure(); + + // Get a reference to the tanh function, inserting it if necessary. + Operation *tanhFunc = + SymbolTable::lookupNearestSymbolFrom(op, functionName); + + LLVMFuncOpT tanhLLVMFunc; + if (tanhFunc) { + tanhLLVMFunc = cast(tanhFunc); + } else { + PatternRewriter::InsertionGuard insertGuard(rewriter); + auto module = op->getParentOfType(); + rewriter.setInsertionPointToStart(module.getBody()); + tanhLLVMFunc = rewriter.create( + module.getLoc(), functionName, + LLVMTypeT::getFunctionTy(operandType, operandType, + /*isVarArg=*/false)); + } + + rewriter.replaceOpWithNewOp( + op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc), + transformed.operand()); + return matchSuccess(); + } +}; + struct MemRefCastOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; @@ -2038,6 +2088,7 @@ void mlir::populateStdToLLVMConversionPatterns( SubFOpLowering, SubIOpLowering, SubViewOpLowering, + TanhOpLowering, TruncateIOpLowering, ViewOpLowering, XOrOpLowering, diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index a224480..9d8b047 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s +// RUN: mlir-opt -convert-std-to-llvm %s -split-input-file | FileCheck %s // CHECK-LABEL: func @empty() { // CHECK-NEXT: llvm.return @@ -423,6 +423,12 @@ func @ops(f32, f32, i32, i32) -> (f32, i32) { %13 = xor %arg2, %arg3 : i32 // CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float %14 = std.exp %arg0 : f32 +// CHECK-NEXT: %14 = llvm.call @tanhf(%arg0) : (!llvm.float) -> !llvm.float + %15 = std.tanh %arg0 : f32 +// CHECK-NEXT: %15 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double + %16 = constant 7.9e-01 : f64 +// CHECK-NEXT: %16 = llvm.call @tanh(%15) : (!llvm.double) -> !llvm.double + %17 = std.tanh %16 : f64 return %0, %4 : f32, i32 } @@ -788,3 +794,19 @@ func @subview_const_stride(%0 : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>, %ar memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to memref (d0 * 4 + d1 * 2 + s0)> return } + +// ----- + +module { + func @check_tanh_func_added_only_once_to_symbol_table(%f: f32, %lf: f64) -> () { + %f0 = std.tanh %f : f32 + %f1 = std.tanh %f0 : f32 + %lf0 = std.tanh %lf : f64 + %lf1 = std.tanh %lf0 : f64 + return + } +// CHECK: module { +// CHECK: llvm.func @tanh(!llvm.double) -> !llvm.double +// CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float +// CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table +} -- 2.7.4