[MLIR][SPIRVToLLVM] Conversion for inverse sqrt and tanh
authorGeorge Mitenkov <georgemitenk0v@gmail.com>
Thu, 30 Jul 2020 07:24:10 +0000 (10:24 +0300)
committerGeorge Mitenkov <georgemitenk0v@gmail.com>
Thu, 30 Jul 2020 07:50:48 +0000 (10:50 +0300)
This is a second patch on conversion of GLSL ops to LLVM dialect.
It introduces patterns to convert `spv.InverseSqrt` and `spv.Tanh`.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D84633

mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir

index 803b05a..58d160d 100644 (file)
@@ -91,6 +91,21 @@ static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
       loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
 }
 
+/// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
+static Value createFPConstant(Location loc, Type srcType, Type dstType,
+                              PatternRewriter &rewriter, double value) {
+  if (auto vecType = srcType.dyn_cast<VectorType>()) {
+    auto floatType = vecType.getElementType().cast<FloatType>();
+    return rewriter.create<LLVM::ConstantOp>(
+        loc, dstType,
+        SplatElementsAttr::get(vecType,
+                               rewriter.getFloatAttr(floatType, value)));
+  }
+  auto floatType = srcType.cast<FloatType>();
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, dstType, rewriter.getFloatAttr(floatType, value));
+}
+
 /// Utility function for bitfiled ops:
 ///   - `BitFieldInsert`
 ///   - `BitFieldSExtract`
@@ -590,6 +605,27 @@ public:
   }
 };
 
+class InverseSqrtPattern
+    : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = op.getType();
+    auto dstType = typeConverter.convertType(srcType);
+    if (!dstType)
+      return failure();
+
+    Location loc = op.getLoc();
+    Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
+    Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
+    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
+    return success();
+  }
+};
+
 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
 template <typename SPIRVop>
 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> {
@@ -821,6 +857,40 @@ public:
   }
 };
 
+/// Convert `spv.Tanh` to
+///
+///   exp(2x) - 1
+///   -----------
+///   exp(2x) + 1
+///
+class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = tanhOp.getType();
+    auto dstType = typeConverter.convertType(srcType);
+    if (!dstType)
+      return failure();
+
+    Location loc = tanhOp.getLoc();
+    Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
+    Value multiplied =
+        rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
+    Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
+    Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
+    Value numerator =
+        rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
+    Value denominator =
+        rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
+    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
+                                              denominator);
+    return success();
+  }
+};
+
 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
 public:
   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
@@ -1052,7 +1122,8 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
       DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
       DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
-      DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>, TanPattern,
+      DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
+      InverseSqrtPattern, TanPattern, TanhPattern,
 
       // Logical ops
       DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
index 1907619..ab501b8 100644 (file)
@@ -103,3 +103,33 @@ func @tan(%arg0: f32) {
        %0 = spv.GLSL.Tan %arg0 : f32
        return
 }
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.Tanh
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @tanh
+func @tanh(%arg0: f32) {
+       // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f32) : !llvm.float
+  // CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}} : !llvm.float
+  // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%[[X2]]) : (!llvm.float) -> !llvm.float
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+  // CHECK: %[[T0:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : !llvm.float
+  // CHECK: %[[T1:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : !llvm.float
+  // CHECK: llvm.fdiv %[[T0]], %[[T1]] : !llvm.float
+       %0 = spv.GLSL.Tanh %arg0 : f32
+       return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.InverseSqrt
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @inverse_sqrt
+func @inverse_sqrt(%arg0: f32) {
+       // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%{{.*}}) : (!llvm.float) -> !llvm.float
+       // CHECK: llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.float
+       %0 = spv.GLSL.InverseSqrt %arg0 : f32
+       return
+}