}
};
-// A `tanh` is converted into a call to the `tanh` function.
-struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
- using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
-
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
-
- using LLVMFuncOpT = LLVM::LLVMFuncOp;
- using LLVMTypeT = LLVM::LLVMType;
-
- OperandAdaptor<TanhOp> transformed(operands);
- LLVMTypeT operandType =
- transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
-
- if (!operandType)
- return failure();
-
- std::string functionName;
- if (operandType.isFloatTy())
- functionName = "tanhf";
- else if (operandType.isDoubleTy())
- functionName = "tanh";
- else
- return failure();
-
- // Get a reference to the tanh function, inserting it if necessary.
- Operation *tanhFunc =
- SymbolTable::lookupNearestSymbolFrom(op, functionName);
-
- LLVMFuncOpT tanhLLVMFunc;
- if (tanhFunc) {
- tanhLLVMFunc = cast<LLVMFuncOpT>(tanhFunc);
- } else {
- PatternRewriter::InsertionGuard insertGuard(rewriter);
- auto module = op->getParentOfType<ModuleOp>();
- rewriter.setInsertionPointToStart(module.getBody());
- tanhLLVMFunc = rewriter.create<LLVMFuncOpT>(
- module.getLoc(), functionName,
- LLVMTypeT::getFunctionTy(operandType, operandType,
- /*isVarArg=*/false));
- }
-
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(
- op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc),
- transformed.operand());
- return success();
- }
-};
-
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
SqrtOpLowering,
SubFOpLowering,
SubIOpLowering,
- TanhOpLowering,
TruncateIOpLowering,
UnsignedDivIOpLowering,
UnsignedRemIOpLowering,
: ConversionTarget(ctx) {
this->addLegalDialect<LLVM::LLVMDialect>();
this->addIllegalOp<LLVM::DialectCastOp>();
+ this->addIllegalOp<TanhOp>();
}
std::unique_ptr<OpPassBase<ModuleOp>>
// CHECK-NEXT: %2 = llvm.icmp "slt" %arg2, %1 : !llvm.i32
%2 = cmpi "slt", %arg2, %1 : i32
// CHECK-NEXT: %3 = llvm.sdiv %arg2, %arg3 : !llvm.i32
- %4 = divi_signed %arg2, %arg3 : i32
+ %3 = divi_signed %arg2, %arg3 : i32
// CHECK-NEXT: %4 = llvm.udiv %arg2, %arg3 : !llvm.i32
- %5 = divi_unsigned %arg2, %arg3 : i32
+ %4 = divi_unsigned %arg2, %arg3 : i32
// CHECK-NEXT: %5 = llvm.srem %arg2, %arg3 : !llvm.i32
- %6 = remi_signed %arg2, %arg3 : i32
+ %5 = remi_signed %arg2, %arg3 : i32
// CHECK-NEXT: %6 = llvm.urem %arg2, %arg3 : !llvm.i32
- %7 = remi_unsigned %arg2, %arg3 : i32
+ %6 = remi_unsigned %arg2, %arg3 : i32
// CHECK-NEXT: %7 = llvm.select %2, %arg2, %arg3 : !llvm.i1, !llvm.i32
- %8 = select %2, %arg2, %arg3 : i32
+ %7 = select %2, %arg2, %arg3 : i32
// CHECK-NEXT: %8 = llvm.fdiv %arg0, %arg1 : !llvm.float
- %9 = divf %arg0, %arg1 : f32
+ %8 = divf %arg0, %arg1 : f32
// CHECK-NEXT: %9 = llvm.frem %arg0, %arg1 : !llvm.float
- %10 = remf %arg0, %arg1 : f32
+ %9 = remf %arg0, %arg1 : f32
// CHECK-NEXT: %10 = llvm.and %arg2, %arg3 : !llvm.i32
- %11 = and %arg2, %arg3 : i32
+ %10 = and %arg2, %arg3 : i32
// CHECK-NEXT: %11 = llvm.or %arg2, %arg3 : !llvm.i32
- %12 = or %arg2, %arg3 : i32
+ %11 = or %arg2, %arg3 : i32
// CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : !llvm.i32
- %13 = xor %arg2, %arg3 : i32
+ %12 = xor %arg2, %arg3 : i32
// CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
- %14 = std.exp %arg0 : f32
-// CHECK-NEXT: %14 = llvm.call @tanhf(%arg0) : (!llvm.float) -> !llvm.float
- %15 = std.tanh %arg0 : f32
-// CHECK-NEXT: %15 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
- %16 = constant 7.9e-01 : f64
-// CHECK-NEXT: %16 = llvm.call @tanh(%15) : (!llvm.double) -> !llvm.double
- %17 = std.tanh %16 : f64
-// CHECK-NEXT: %17 = llvm.shl %arg2, %arg3 : !llvm.i32
- %18 = shift_left %arg2, %arg3 : i32
-// CHECK-NEXT: %18 = llvm.ashr %arg2, %arg3 : !llvm.i32
- %19 = shift_right_signed %arg2, %arg3 : i32
-// CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32
- %20 = shift_right_unsigned %arg2, %arg3 : i32
+ %13 = std.exp %arg0 : f32
+// CHECK-NEXT: %14 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
+ %14 = constant 7.9e-01 : f64
+// CHECK-NEXT: %15 = llvm.shl %arg2, %arg3 : !llvm.i32
+ %15 = shift_left %arg2, %arg3 : i32
+// CHECK-NEXT: %16 = llvm.ashr %arg2, %arg3 : !llvm.i32
+ %16 = shift_right_signed %arg2, %arg3 : i32
+// CHECK-NEXT: %17 = llvm.lshr %arg2, %arg3 : !llvm.i32
+ %17 = shift_right_unsigned %arg2, %arg3 : i32
// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
- %21 = std.sqrt %arg0 : f32
+ %18 = std.sqrt %arg0 : f32
// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double
- %22 = std.sqrt %arg4 : f64
+ %19 = std.sqrt %arg4 : f64
return %0, %4 : f32, i32
}
// -----
-module {
- func @check_tanh_func_added_only_once_to_symbol_table(%f: f32, %lf: f64) -> () {
- %f0 = std.tanh %f : f32
- %f1 = std.tanh %f0 : f32
- %lf0 = std.tanh %lf : f64
- %lf1 = std.tanh %lf0 : f64
- return
- }
-// CHECK: module {
-// CHECK: llvm.func @tanh(!llvm.double) -> !llvm.double
-// CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float
-// CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table
-}
-
-// -----
-
// CHECK-LABEL: func @atomic_rmw
func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
atomic_rmw "assign" %fval, %F[%i] : (f32, memref<10xf32>) -> f32