/// depending on the element type that Op operates upon. The function
/// declaration is added in case it was not added before.
///
+/// If the input values are of f16 type, the value is first casted to f32, the
+/// function called and then the result casted back.
+///
/// Example with NVVM:
/// %exp_f32 = std.exp %arg_f32 : f32
///
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
- LLVMType resultType = typeConverter.convertType(op->getResult(0).getType())
- .template cast<LLVM::LLVMType>();
- LLVMType funcType = getFunctionType(resultType, operands);
- StringRef funcName = getFunctionName(resultType);
+ static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+ SourceOp>::value,
+ "expected op with same operand and result types");
+
+ SmallVector<Value, 1> castedOperands;
+ for (Value operand : operands)
+ castedOperands.push_back(maybeCast(operand, rewriter));
+
+ LLVMType resultType =
+ castedOperands.front().getType().cast<LLVM::LLVMType>();
+ LLVMType funcType = getFunctionType(resultType, castedOperands);
+ StringRef funcName = getFunctionName(funcType.getFunctionResultType());
if (funcName.empty())
return failure();
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
auto callOp = rewriter.create<LLVM::CallOp>(
- op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
- rewriter.replaceOp(op, {callOp.getResult(0)});
+ op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp),
+ castedOperands);
+
+ if (resultType == operands.front().getType()) {
+ rewriter.replaceOp(op, {callOp.getResult(0)});
+ return success();
+ }
+
+ Value truncated = rewriter.create<LLVM::FPTruncOp>(
+ op->getLoc(), operands.front().getType(), callOp.getResult(0));
+ rewriter.replaceOp(op, {truncated});
return success();
}
private:
+ Value maybeCast(Value operand, PatternRewriter &rewriter) const {
+ LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
+ if (!type.isHalfTy())
+ return operand;
+
+ return rewriter.create<LLVM::FPExtOp>(
+ operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()),
+ operand);
+ }
+
LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
ArrayRef<Value> operands) const {
using LLVM::LLVMType;
// CHECK: llvm.func @__nv_tanhf(!llvm.float) -> !llvm.float
// CHECK: llvm.func @__nv_tanh(!llvm.double) -> !llvm.double
// CHECK-LABEL: func @gpu_tanh
- func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = std.tanh %arg_f16 : f16
+ // CHECK: llvm.fpext %{{.*}} : !llvm.half to !llvm.float
+ // CHECK-NEXT: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
+ // CHECK-NEXT: llvm.fptrunc %{{.*}} : !llvm.float to !llvm.half
%result32 = std.tanh %arg_f32 : f32
// CHECK: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
%result64 = std.tanh %arg_f64 : f64
// CHECK: llvm.call @__nv_tanh(%{{.*}}) : (!llvm.double) -> !llvm.double
- std.return %result32, %result64 : f32, f64
+ std.return %result16, %result32, %result64 : f16, f32, f64
}
}