[mlir] Removed TanHOp lowering from ConvertStandardToLLVM since there is no reasonabl...
authorMarcel Koester <marcel.koester@dfki.de>
Tue, 3 Mar 2020 10:07:14 +0000 (11:07 +0100)
committerMarcel Koester <marcel.koester@dfki.de>
Wed, 25 Mar 2020 15:43:45 +0000 (16:43 +0100)
Summary: The current ConvertStandardToLLVM phase lowers the standard TanHOp to function calls to external tanh symbols. However, this leads to misunderstandings since these external symbols are not defined anywhere. This commit removes the TanHOp lowering functionality from ConvertStandardToLLVM, adapts the LowerGpuOpsToNVVMOps and LowerGpuOpsToROCDLOps passes and adjusts the affected test cases.

Reviewers: mravishankar, herhut

Subscribers: jholewinski, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, csigg, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D75509

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

index eb5628d..c7bbb6d 100644 (file)
@@ -95,24 +95,6 @@ private:
   const std::string f64Func;
 };
 
-namespace gpu {
-/// Returns a predicate to be used with addDynamicallyLegalOp. The predicate
-/// returns false for calls to the provided intrinsics and true otherwise.
-inline std::function<bool(Operation *)>
-filterIllegalLLVMIntrinsics(ArrayRef<StringRef> intrinsics, MLIRContext *ctx) {
-  SmallVector<StringRef, 4> illegalIds(intrinsics.begin(), intrinsics.end());
-  return [illegalIds](Operation *op) -> bool {
-    LLVM::CallOp callOp = dyn_cast<LLVM::CallOp>(op);
-    if (!callOp || !callOp.callee())
-      return true;
-    StringRef callee = callOp.callee().getValue();
-    return !llvm::any_of(illegalIds, [callee](StringRef intrinsic) {
-      return callee.equals(intrinsic);
-    });
-  };
-}
-} // namespace gpu
-
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
index e929caa..18aeeb8 100644 (file)
@@ -279,8 +279,6 @@ public:
                         LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op>();
     target.addIllegalOp<FuncOp>();
     target.addLegalDialect<NVVM::NVVMDialect>();
-    target.addDynamicallyLegalOp<mlir::LLVM::CallOp>(
-        gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext()));
     // TODO(csigg): Remove once we support replacing non-root ops.
     target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
     if (failed(applyPartialConversion(m, target, patterns, &converter)))
index 238821e..79fb377 100644 (file)
@@ -71,8 +71,6 @@ public:
     target.addLegalDialect<LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
     target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::FAbsOp, LLVM::FCeilOp,
                         LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op>();
-    target.addDynamicallyLegalOp<LLVM::CallOp>(
-        gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext()));
     target.addIllegalOp<FuncOp>();
     if (failed(applyPartialConversion(m, target, patterns, &converter)))
       signalPassFailure();
index e353e93..d37a773 100644 (file)
@@ -1737,56 +1737,6 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
   }
 };
 
-// A `tanh` is converted into a call to the `tanh` function.
-struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
-  using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    using LLVMFuncOpT = LLVM::LLVMFuncOp;
-    using LLVMTypeT = LLVM::LLVMType;
-
-    OperandAdaptor<TanhOp> transformed(operands);
-    LLVMTypeT operandType =
-        transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
-
-    if (!operandType)
-      return failure();
-
-    std::string functionName;
-    if (operandType.isFloatTy())
-      functionName = "tanhf";
-    else if (operandType.isDoubleTy())
-      functionName = "tanh";
-    else
-      return failure();
-
-    // Get a reference to the tanh function, inserting it if necessary.
-    Operation *tanhFunc =
-        SymbolTable::lookupNearestSymbolFrom(op, functionName);
-
-    LLVMFuncOpT tanhLLVMFunc;
-    if (tanhFunc) {
-      tanhLLVMFunc = cast<LLVMFuncOpT>(tanhFunc);
-    } else {
-      PatternRewriter::InsertionGuard insertGuard(rewriter);
-      auto module = op->getParentOfType<ModuleOp>();
-      rewriter.setInsertionPointToStart(module.getBody());
-      tanhLLVMFunc = rewriter.create<LLVMFuncOpT>(
-          module.getLoc(), functionName,
-          LLVMTypeT::getFunctionTy(operandType, operandType,
-                                   /*isVarArg=*/false));
-    }
-
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
-        op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc),
-        transformed.operand());
-    return success();
-  }
-};
-
 struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
   using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
 
@@ -2833,7 +2783,6 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       SqrtOpLowering,
       SubFOpLowering,
       SubIOpLowering,
-      TanhOpLowering,
       TruncateIOpLowering,
       UnsignedDivIOpLowering,
       UnsignedRemIOpLowering,
@@ -3022,6 +2971,7 @@ mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
     : ConversionTarget(ctx) {
   this->addLegalDialect<LLVM::LLVMDialect>();
   this->addIllegalOp<LLVM::DialectCastOp>();
+  this->addIllegalOp<TanhOp>();
 }
 
 std::unique_ptr<OpPassBase<ModuleOp>>
index 699ea31..9c072a6 100644 (file)
@@ -407,43 +407,39 @@ func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
 // CHECK-NEXT:  %2 = llvm.icmp "slt" %arg2, %1 : !llvm.i32
   %2 = cmpi "slt", %arg2, %1 : i32
 // CHECK-NEXT:  %3 = llvm.sdiv %arg2, %arg3 : !llvm.i32
-  %4 = divi_signed %arg2, %arg3 : i32
+  %3 = divi_signed %arg2, %arg3 : i32
 // CHECK-NEXT:  %4 = llvm.udiv %arg2, %arg3 : !llvm.i32
-  %5 = divi_unsigned %arg2, %arg3 : i32
+  %4 = divi_unsigned %arg2, %arg3 : i32
 // CHECK-NEXT:  %5 = llvm.srem %arg2, %arg3 : !llvm.i32
-  %6 = remi_signed %arg2, %arg3 : i32
+  %5 = remi_signed %arg2, %arg3 : i32
 // CHECK-NEXT:  %6 = llvm.urem %arg2, %arg3 : !llvm.i32
-  %7 = remi_unsigned %arg2, %arg3 : i32
+  %6 = remi_unsigned %arg2, %arg3 : i32
 // CHECK-NEXT:  %7 = llvm.select %2, %arg2, %arg3 : !llvm.i1, !llvm.i32
-  %8 = select %2, %arg2, %arg3 : i32
+  %7 = select %2, %arg2, %arg3 : i32
 // CHECK-NEXT:  %8 = llvm.fdiv %arg0, %arg1 : !llvm.float
-  %9 = divf %arg0, %arg1 : f32
+  %8 = divf %arg0, %arg1 : f32
 // CHECK-NEXT:  %9 = llvm.frem %arg0, %arg1 : !llvm.float
-  %10 = remf %arg0, %arg1 : f32
+  %9 = remf %arg0, %arg1 : f32
 // CHECK-NEXT: %10 = llvm.and %arg2, %arg3 : !llvm.i32
-  %11 = and %arg2, %arg3 : i32
+  %10 = and %arg2, %arg3 : i32
 // CHECK-NEXT: %11 = llvm.or %arg2, %arg3 : !llvm.i32
-  %12 = or %arg2, %arg3 : i32
+  %11 = or %arg2, %arg3 : i32
 // CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : !llvm.i32
-  %13 = xor %arg2, %arg3 : i32
+  %12 = xor %arg2, %arg3 : i32
 // CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
-  %14 = std.exp %arg0 : f32
-// CHECK-NEXT: %14 = llvm.call @tanhf(%arg0) : (!llvm.float) -> !llvm.float
-  %15 = std.tanh %arg0 : f32
-// CHECK-NEXT: %15 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
-  %16 = constant 7.9e-01 : f64
-// CHECK-NEXT: %16 = llvm.call @tanh(%15) : (!llvm.double) -> !llvm.double
-  %17 = std.tanh %16 : f64
-// CHECK-NEXT: %17 = llvm.shl %arg2, %arg3 : !llvm.i32
-  %18 = shift_left %arg2, %arg3 : i32
-// CHECK-NEXT: %18 = llvm.ashr %arg2, %arg3 : !llvm.i32
-  %19 = shift_right_signed %arg2, %arg3 : i32
-// CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32
-  %20 = shift_right_unsigned %arg2, %arg3 : i32
+  %13 = std.exp %arg0 : f32
+// CHECK-NEXT: %14 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
+  %14 = constant 7.9e-01 : f64
+// CHECK-NEXT: %15 = llvm.shl %arg2, %arg3 : !llvm.i32
+  %15 = shift_left %arg2, %arg3 : i32
+// CHECK-NEXT: %16 = llvm.ashr %arg2, %arg3 : !llvm.i32
+  %16 = shift_right_signed %arg2, %arg3 : i32
+// CHECK-NEXT: %17 = llvm.lshr %arg2, %arg3 : !llvm.i32
+  %17 = shift_right_unsigned %arg2, %arg3 : i32
 // CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
-  %21 = std.sqrt %arg0 : f32
+  %18 = std.sqrt %arg0 : f32
 // CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double
-  %22 = std.sqrt %arg4 : f64
+  %19 = std.sqrt %arg4 : f64
   return %0, %4 : f32, i32
 }
 
@@ -853,22 +849,6 @@ func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1)
 
 // -----
 
-module {
-  func @check_tanh_func_added_only_once_to_symbol_table(%f: f32, %lf: f64) -> () {
-    %f0 = std.tanh %f : f32
-    %f1 = std.tanh %f0 : f32
-    %lf0 = std.tanh %lf : f64
-    %lf1 = std.tanh %lf0 : f64
-    return
-  }
-// CHECK: module {
-// CHECK: llvm.func @tanh(!llvm.double) -> !llvm.double
-// CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float
-// CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table
-}
-
-// -----
-
 // CHECK-LABEL: func @atomic_rmw
 func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
   atomic_rmw "assign" %fval, %F[%i] : (f32, memref<10xf32>) -> f32