[mlir][SPIRV] Add lowering for math.log1p operation to SPIR-V dialect.
authorMaheshRavishankar <ravishankarm@google.com>
Thu, 3 Jun 2021 23:25:56 +0000 (16:25 -0700)
committerMaheshRavishankar <ravishankarm@google.com>
Thu, 3 Jun 2021 23:27:19 +0000 (16:27 -0700)
Differential Revision: https://reviews.llvm.org/D103635

mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

index 3851bac..9a43214 100644 (file)
@@ -317,6 +317,28 @@ public:
   }
 };
 
+/// Converts math.log1p to SPIR-V ops.
+///
+/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
+/// these operations.
+class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
+public:
+  using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::Log1pOp operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(operands.size() == 1);
+    Location loc = operation.getLoc();
+    auto type =
+        this->getTypeConverter()->convertType(operation.operand().getType());
+    auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
+    auto onePlus = rewriter.create<spirv::FAddOp>(loc, one, operands[0]);
+    rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
+    return success();
+  }
+};
+
 /// Converts std.remi_signed to SPIR-V ops.
 ///
 /// This cannot be merged into the template unary/binary pattern due to
@@ -1347,7 +1369,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
       UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
       UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
-      SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern,
+      Log1pOpPattern, SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern,
 
       // Comparison patterns
       BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,
index c9ce74b..e20062d 100644 (file)
@@ -53,6 +53,10 @@ func @float32_unary_scalar(%arg0: f32) {
   %3 = math.exp %arg0 : f32
   // CHECK: spv.GLSL.Log %{{.*}}: f32
   %4 = math.log %arg0 : f32
+  // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
+  // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
+  // CHECK: spv.GLSL.Log %[[ADDONE]]
+  %40 = math.log1p %arg0 : f32
   // CHECK: spv.FNegate %{{.*}}: f32
   %5 = negf %arg0 : f32
   // CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32