loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
}
+/// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
+static Value createFPConstant(Location loc, Type srcType, Type dstType,
+ PatternRewriter &rewriter, double value) {
+ if (auto vecType = srcType.dyn_cast<VectorType>()) {
+ auto floatType = vecType.getElementType().cast<FloatType>();
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, dstType,
+ SplatElementsAttr::get(vecType,
+ rewriter.getFloatAttr(floatType, value)));
+ }
+ auto floatType = srcType.cast<FloatType>();
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, dstType, rewriter.getFloatAttr(floatType, value));
+}
+
/// Utility function for bitfiled ops:
/// - `BitFieldInsert`
/// - `BitFieldSExtract`
}
};
+class InverseSqrtPattern
+ : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcType = op.getType();
+ auto dstType = typeConverter.convertType(srcType);
+ if (!dstType)
+ return failure();
+
+ Location loc = op.getLoc();
+ Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
+ Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
+ rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
+ return success();
+ }
+};
+
/// Converts `spv.Load` and `spv.Store` to LLVM dialect.
template <typename SPIRVop>
class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> {
}
};
+/// Convert `spv.Tanh` to
+///
+/// exp(2x) - 1
+/// -----------
+/// exp(2x) + 1
+///
+class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcType = tanhOp.getType();
+ auto dstType = typeConverter.convertType(srcType);
+ if (!dstType)
+ return failure();
+
+ Location loc = tanhOp.getLoc();
+ Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
+ Value multiplied =
+ rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
+ Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
+ Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
+ Value numerator =
+ rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
+ Value denominator =
+ rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
+ rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
+ denominator);
+ return success();
+ }
+};
+
class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
public:
using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
- DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>, TanPattern,
+ DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
+ InverseSqrtPattern, TanPattern, TanhPattern,
// Logical ops
DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
%0 = spv.GLSL.Tan %arg0 : f32
return
}
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.Tanh
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @tanh
+func @tanh(%arg0: f32) {
+ // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f32) : !llvm.float
+ // CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}} : !llvm.float
+ // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%[[X2]]) : (!llvm.float) -> !llvm.float
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+ // CHECK: %[[T0:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : !llvm.float
+ // CHECK: %[[T1:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : !llvm.float
+ // CHECK: llvm.fdiv %[[T0]], %[[T1]] : !llvm.float
+ %0 = spv.GLSL.Tanh %arg0 : f32
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.InverseSqrt
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @inverse_sqrt
+func @inverse_sqrt(%arg0: f32) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+ // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%{{.*}}) : (!llvm.float) -> !llvm.float
+ // CHECK: llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.float
+ %0 = spv.GLSL.InverseSqrt %arg0 : f32
+ return
+}