[mlir][complex] Add pow/sqrt/tanh ops and lowering to libm
authorBenjamin Kramer <benny.kra@googlemail.com>
Fri, 13 May 2022 15:12:11 +0000 (17:12 +0200)
committerBenjamin Kramer <benny.kra@googlemail.com>
Wed, 18 May 2022 12:03:14 +0000 (14:03 +0200)
Lowering through libm gives us a baseline version, even though it's not
going to be particularly fast. This is similar to what we do for some
math dialect ops.

Differential Revision: https://reviews.llvm.org/D125550

mlir/include/mlir/Conversion/ComplexToLibm/ComplexToLibm.h [new file with mode: 0644]
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/ComplexToLibm/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp [new file with mode: 0644]
mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir [new file with mode: 0644]
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

diff --git a/mlir/include/mlir/Conversion/ComplexToLibm/ComplexToLibm.h b/mlir/include/mlir/Conversion/ComplexToLibm/ComplexToLibm.h
new file mode 100644 (file)
index 0000000..e86d8e2
--- /dev/null
@@ -0,0 +1,27 @@
+//===- ComplexToLibm.h - Utils to convert from the complex dialect --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_COMPLEXTOLIBM_COMPLEXTOLIBM_H_
+#define MLIR_CONVERSION_COMPLEXTOLIBM_COMPLEXTOLIBM_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+template <typename T>
+class OperationPass;
+
+/// Populate the given list with patterns that convert from Complex to Libm
+/// calls.
+void populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
+                                             PatternBenefit benefit);
+
+/// Create a pass to convert Complex operations to libm calls.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertComplexToLibmPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOLIBM_COMPLEXTOLIBM_H_
index ba6b75e..5ffdc2d 100644 (file)
@@ -17,6 +17,7 @@
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
index d0122f9..41e7b29 100644 (file)
@@ -197,6 +197,21 @@ def ConvertComplexToLLVM : Pass<"convert-complex-to-llvm"> {
 }
 
 //===----------------------------------------------------------------------===//
+// ComplexToLibm
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
+  let summary = "Convert Complex dialect to libm calls";
+  let description = [{
+    This pass converts supported Complex ops to libm calls.
+  }];
+  let constructor = "mlir::createConvertComplexToLibmPass()";
+  let dependentDialects = [
+    "func::FuncDialect",
+  ];
+}
+
+//===----------------------------------------------------------------------===//
 // ComplexToStandard
 //===----------------------------------------------------------------------===//
 
index b215d0d..a87af6f 100644 (file)
@@ -347,6 +347,24 @@ def NotEqualOp : Complex_Op<"neq",
 }
 
 //===----------------------------------------------------------------------===//
+// PowOp
+//===----------------------------------------------------------------------===//
+
+def PowOp : ComplexArithmeticOp<"pow"> {
+  let summary = "complex power function";
+  let description = [{
+    The `sqrt` operation takes a complex number raises it to the given complex
+    exponent.
+
+    Example:
+
+    ```mlir
+    %a = complex.pow %b, %c : complex<f32>
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
 // ReOp
 //===----------------------------------------------------------------------===//
 
@@ -410,6 +428,25 @@ def SinOp : ComplexUnaryOp<"sin", [SameOperandsAndResultType]> {
 }
 
 //===----------------------------------------------------------------------===//
+// SqrtOp
+//===----------------------------------------------------------------------===//
+
+def SqrtOp : ComplexUnaryOp<"sqrt", [SameOperandsAndResultType]> {
+  let summary = "complex square root";
+  let description = [{
+    The `sqrt` operation takes a complex number and returns its square root.
+
+    Example:
+
+    ```mlir
+    %a = complex.sqrt %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
+//===----------------------------------------------------------------------===//
 // SubOp
 //===----------------------------------------------------------------------===//
 
@@ -426,4 +463,24 @@ def SubOp : ComplexArithmeticOp<"sub"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TanhOp
+//===----------------------------------------------------------------------===//
+
+def TanhOp : ComplexUnaryOp<"tanh", [SameOperandsAndResultType]> {
+  let summary = "complex hyperbolic tangent";
+  let description = [{
+    The `tanh` operation takes a complex number and returns its hyperbolic
+    tangent.
+
+    Example:
+
+    ```mlir
+    %a = complex.tanh %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
 #endif // COMPLEX_OPS
index d0c0083..218c89a 100644 (file)
@@ -6,6 +6,7 @@ add_subdirectory(ArmNeon2dToIntr)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexToLLVM)
+add_subdirectory(ComplexToLibm)
 add_subdirectory(ComplexToStandard)
 add_subdirectory(ControlFlowToLLVM)
 add_subdirectory(ControlFlowToSPIRV)
diff --git a/mlir/lib/Conversion/ComplexToLibm/CMakeLists.txt b/mlir/lib/Conversion/ComplexToLibm/CMakeLists.txt
new file mode 100644 (file)
index 0000000..1941e21
--- /dev/null
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRComplexToLibm
+  ComplexToLibm.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToLibm
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRDialectUtils
+  MLIRFunc
+  MLIRComplex
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
new file mode 100644 (file)
index 0000000..456c74d
--- /dev/null
@@ -0,0 +1,101 @@
+//===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+// Pattern to convert scalar complex operations to calls to libm functions.
+// Additionally the libm function signatures are declared.
+template <typename Op>
+struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
+public:
+  using OpRewritePattern<Op>::OpRewritePattern;
+  ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
+                         StringRef doubleFunc, PatternBenefit benefit)
+      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+        doubleFunc(doubleFunc){};
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+
+private:
+  std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+template <typename Op>
+LogicalResult
+ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
+                                        PatternRewriter &rewriter) const {
+  auto module = SymbolTable::getNearestSymbolTable(op);
+  auto type = op.getType().template cast<ComplexType>();
+  Type elementType = type.getElementType();
+  if (!elementType.isa<Float32Type, Float64Type>())
+    return failure();
+
+  auto name =
+      elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
+  auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+      SymbolTable::lookupSymbolIn(module, name));
+  // Forward declare function if it hasn't already been
+  if (!opFunc) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+    auto opFunctionTy = FunctionType::get(
+        rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+    opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
+                                           opFunctionTy);
+    opFunc.setPrivate();
+  }
+  assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
+
+  rewriter.replaceOpWithNewOp<func::CallOp>(op, name, type, op->getOperands());
+
+  return success();
+}
+
+void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
+                                                   PatternBenefit benefit) {
+  patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(),
+                                                   "cpowf", "cpow", benefit);
+  patterns.add<ScalarOpToLibmCall<complex::SqrtOp>>(patterns.getContext(),
+                                                    "csqrtf", "csqrt", benefit);
+  patterns.add<ScalarOpToLibmCall<complex::TanhOp>>(patterns.getContext(),
+                                                    "ctanhf", "ctanh", benefit);
+}
+
+namespace {
+struct ConvertComplexToLibmPass
+    : public ConvertComplexToLibmBase<ConvertComplexToLibmPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertComplexToLibmPass::runOnOperation() {
+  auto module = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<func::FuncDialect>();
+  target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertComplexToLibmPass() {
+  return std::make_unique<ConvertComplexToLibmPass>();
+}
diff --git a/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
new file mode 100644 (file)
index 0000000..ca21f50
--- /dev/null
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -convert-complex-to-libm -canonicalize | FileCheck %s
+
+// CHECK-DAG: @cpowf(complex<f32>, complex<f32>) -> complex<f32>
+// CHECK-DAG: @cpow(complex<f64>, complex<f64>) -> complex<f64>
+// CHECK-DAG: @csqrtf(complex<f32>) -> complex<f32>
+// CHECK-DAG: @csqrt(complex<f64>) -> complex<f64>
+// CHECK-DAG: @ctanhf(complex<f32>) -> complex<f32>
+// CHECK-DAG: @ctanh(complex<f64>) -> complex<f64>
+
+// CHECK-LABEL: func @cpow_caller
+// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
+// CHECK-SAME: %[[DOUBLE:.*]]: complex<f64>
+func.func @cpow_caller(%float: complex<f32>, %double: complex<f64>) -> (complex<f32>, complex<f64>)  {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cpowf(%[[FLOAT]], %[[FLOAT]]) : (complex<f32>, complex<f32>) -> complex<f32>
+  %float_result = complex.pow %float, %float : complex<f32>
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cpow(%[[DOUBLE]], %[[DOUBLE]]) : (complex<f64>, complex<f64>) -> complex<f64>
+  %double_result = complex.pow %double, %double : complex<f64>
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : complex<f32>, complex<f64>
+}
+
+// CHECK-LABEL: func @csqrt_caller
+// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
+// CHECK-SAME: %[[DOUBLE:.*]]: complex<f64>
+func.func @csqrt_caller(%float: complex<f32>, %double: complex<f64>) -> (complex<f32>, complex<f64>)  {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @csqrtf(%[[FLOAT]]) : (complex<f32>) -> complex<f32>
+  %float_result = complex.sqrt %float : complex<f32>
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @csqrt(%[[DOUBLE]]) : (complex<f64>) -> complex<f64>
+  %double_result = complex.sqrt %double : complex<f64>
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : complex<f32>, complex<f64>
+}
+
+// CHECK-LABEL: func @ctanh_caller
+// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
+// CHECK-SAME: %[[DOUBLE:.*]]: complex<f64>
+func.func @ctanh_caller(%float: complex<f32>, %double: complex<f64>) -> (complex<f32>, complex<f64>)  {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @ctanhf(%[[FLOAT]]) : (complex<f32>) -> complex<f32>
+  %float_result = complex.tanh %float : complex<f32>
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @ctanh(%[[DOUBLE]]) : (complex<f64>) -> complex<f64>
+  %double_result = complex.tanh %double : complex<f64>
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : complex<f32>, complex<f64>
+}
index 93b2187..c9f659b 100644 (file)
@@ -2457,6 +2457,7 @@ cc_library(
         ":AsyncToLLVM",
         ":BufferizationToMemRef",
         ":ComplexToLLVM",
+        ":ComplexToLibm",
         ":ComplexToStandard",
         ":ControlFlowToLLVM",
         ":ControlFlowToSPIRV",
@@ -4708,7 +4709,6 @@ cc_library(
     ],
 )
 
-
 cc_library(
     name = "TensorToSPIRV",
     srcs = glob([
@@ -6189,6 +6189,7 @@ cc_library(
         ":BufferizationTransforms",
         ":ComplexDialect",
         ":ComplexToLLVM",
+        ":ComplexToLibm",
         ":ControlFlowOps",
         ":ConversionPasses",
         ":DLTIDialect",
@@ -8062,6 +8063,30 @@ cc_library(
 )
 
 cc_library(
+    name = "ComplexToLibm",
+    srcs = glob([
+        "lib/Conversion/ComplexToLibm/*.cpp",
+        "lib/Conversion/ComplexToLibm/*.h",
+    ]) + [":ConversionPassDetail"],
+    hdrs = glob([
+        "include/mlir/Conversion/ComplexToLibm/*.h",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":ComplexDialect",
+        ":ConversionPassIncGen",
+        ":DialectUtils",
+        ":FuncDialect",
+        ":IR",
+        ":Pass",
+        ":Support",
+        ":Transforms",
+        "//llvm:Core",
+        "//llvm:Support",
+    ],
+)
+
+cc_library(
     name = "ComplexToStandard",
     srcs = glob([
         "lib/Conversion/ComplexToStandard/*.cpp",