[mlir][gpu] Add support for f16 when lowering to nvvm intrinsics
authorStephan Herhut <herhut@google.com>
Tue, 9 Jun 2020 15:20:53 +0000 (17:20 +0200)
committerStephan Herhut <herhut@google.com>
Tue, 9 Jun 2020 17:33:45 +0000 (19:33 +0200)
Summary:
The NVVM target only provides implementations for tanh etc. on f32 and
f64 operands. To also support f16, we now insert operations to extend to f32
and truncate back to f16 around the intrinsic call.

Differential Revision: https://reviews.llvm.org/D81473

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

index c7bbb6d..58b5f1d 100644 (file)
@@ -20,6 +20,9 @@ namespace mlir {
 /// depending on the element type that Op operates upon. The function
 /// declaration is added in case it was not added before.
 ///
+/// If the input values are of f16 type, the value is first casted to f32, the
+/// function called and then the result casted back.
+///
 /// Example with NVVM:
 ///   %exp_f32 = std.exp %arg_f32 : f32
 ///
@@ -44,21 +47,48 @@ public:
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
 
-    LLVMType resultType = typeConverter.convertType(op->getResult(0).getType())
-                              .template cast<LLVM::LLVMType>();
-    LLVMType funcType = getFunctionType(resultType, operands);
-    StringRef funcName = getFunctionName(resultType);
+    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+                                  SourceOp>::value,
+                  "expected op with same operand and result types");
+
+    SmallVector<Value, 1> castedOperands;
+    for (Value operand : operands)
+      castedOperands.push_back(maybeCast(operand, rewriter));
+
+    LLVMType resultType =
+        castedOperands.front().getType().cast<LLVM::LLVMType>();
+    LLVMType funcType = getFunctionType(resultType, castedOperands);
+    StringRef funcName = getFunctionName(funcType.getFunctionResultType());
     if (funcName.empty())
       return failure();
 
     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
     auto callOp = rewriter.create<LLVM::CallOp>(
-        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
-    rewriter.replaceOp(op, {callOp.getResult(0)});
+        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp),
+        castedOperands);
+
+    if (resultType == operands.front().getType()) {
+      rewriter.replaceOp(op, {callOp.getResult(0)});
+      return success();
+    }
+
+    Value truncated = rewriter.create<LLVM::FPTruncOp>(
+        op->getLoc(), operands.front().getType(), callOp.getResult(0));
+    rewriter.replaceOp(op, {truncated});
     return success();
   }
 
 private:
+  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
+    LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
+    if (!type.isHalfTy())
+      return operand;
+
+    return rewriter.create<LLVM::FPExtOp>(
+        operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()),
+        operand);
+  }
+
   LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
                                  ArrayRef<Value> operands) const {
     using LLVM::LLVMType;
index f05c9af..925615c 100644 (file)
@@ -219,12 +219,16 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_tanhf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_tanh(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_tanh
-  func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = std.tanh %arg_f16 : f16
+    // CHECK: llvm.fpext %{{.*}} : !llvm.half to !llvm.float
+    // CHECK-NEXT: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
+    // CHECK-NEXT: llvm.fptrunc %{{.*}} : !llvm.float to !llvm.half
     %result32 = std.tanh %arg_f32 : f32
     // CHECK: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.tanh %arg_f64 : f64
     // CHECK: llvm.call @__nv_tanh(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return %result32, %result64 : f32, f64
+    std.return %result16, %result32, %result64 : f16, f32, f64
   }
 }