--- /dev/null
+//===- ComplexToLibm.h - Utils to convert from the complex dialect --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_COMPLEXTOLIBM_COMPLEXTOLIBM_H_
+#define MLIR_CONVERSION_COMPLEXTOLIBM_COMPLEXTOLIBM_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+template <typename T>
+class OperationPass;
+
+/// Populate the given list with patterns that convert from Complex to Libm
+/// calls.
+void populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit);
+
+/// Create a pass to convert Complex operations to libm calls.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertComplexToLibmPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOLIBM_COMPLEXTOLIBM_H_
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
}
//===----------------------------------------------------------------------===//
+// ComplexToLibm
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
+ let summary = "Convert Complex dialect to libm calls";
+ let description = [{
+ This pass converts supported Complex ops to libm calls.
+ }];
+ let constructor = "mlir::createConvertComplexToLibmPass()";
+ let dependentDialects = [
+ "func::FuncDialect",
+ ];
+}
+
+//===----------------------------------------------------------------------===//
// ComplexToStandard
//===----------------------------------------------------------------------===//
}
//===----------------------------------------------------------------------===//
+// PowOp
+//===----------------------------------------------------------------------===//
+
+def PowOp : ComplexArithmeticOp<"pow"> {
+ let summary = "complex power function";
+ let description = [{
+ The `sqrt` operation takes a complex number raises it to the given complex
+ exponent.
+
+ Example:
+
+ ```mlir
+ %a = complex.pow %b, %c : complex<f32>
+ ```
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//
}
//===----------------------------------------------------------------------===//
+// SqrtOp
+//===----------------------------------------------------------------------===//
+
+def SqrtOp : ComplexUnaryOp<"sqrt", [SameOperandsAndResultType]> {
+ let summary = "complex square root";
+ let description = [{
+ The `sqrt` operation takes a complex number and returns its square root.
+
+ Example:
+
+ ```mlir
+ %a = complex.sqrt %b : complex<f32>
+ ```
+ }];
+
+ let results = (outs Complex<AnyFloat>:$result);
+}
+
+//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
}];
}
+//===----------------------------------------------------------------------===//
+// TanhOp
+//===----------------------------------------------------------------------===//
+
+def TanhOp : ComplexUnaryOp<"tanh", [SameOperandsAndResultType]> {
+ let summary = "complex hyperbolic tangent";
+ let description = [{
+ The `tanh` operation takes a complex number and returns its hyperbolic
+ tangent.
+
+ Example:
+
+ ```mlir
+ %a = complex.tanh %b : complex<f32>
+ ```
+ }];
+
+ let results = (outs Complex<AnyFloat>:$result);
+}
+
#endif // COMPLEX_OPS
add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexToLLVM)
+add_subdirectory(ComplexToLibm)
add_subdirectory(ComplexToStandard)
add_subdirectory(ControlFlowToLLVM)
add_subdirectory(ControlFlowToSPIRV)
--- /dev/null
+add_mlir_conversion_library(MLIRComplexToLibm
+ ComplexToLibm.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToLibm
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRFunc
+ MLIRComplex
+ MLIRTransformUtils
+ )
--- /dev/null
+//===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+// Pattern to convert scalar complex operations to calls to libm functions.
+// Additionally the libm function signatures are declared.
+template <typename Op>
+struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
+public:
+ using OpRewritePattern<Op>::OpRewritePattern;
+ ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
+ StringRef doubleFunc, PatternBenefit benefit)
+ : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+ doubleFunc(doubleFunc){};
+
+ LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+
+private:
+ std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+template <typename Op>
+LogicalResult
+ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
+ PatternRewriter &rewriter) const {
+ auto module = SymbolTable::getNearestSymbolTable(op);
+ auto type = op.getType().template cast<ComplexType>();
+ Type elementType = type.getElementType();
+ if (!elementType.isa<Float32Type, Float64Type>())
+ return failure();
+
+ auto name =
+ elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
+ auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+ SymbolTable::lookupSymbolIn(module, name));
+ // Forward declare function if it hasn't already been
+ if (!opFunc) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+ auto opFunctionTy = FunctionType::get(
+ rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+ opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
+ opFunctionTy);
+ opFunc.setPrivate();
+ }
+ assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
+
+ rewriter.replaceOpWithNewOp<func::CallOp>(op, name, type, op->getOperands());
+
+ return success();
+}
+
+void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(),
+ "cpowf", "cpow", benefit);
+ patterns.add<ScalarOpToLibmCall<complex::SqrtOp>>(patterns.getContext(),
+ "csqrtf", "csqrt", benefit);
+ patterns.add<ScalarOpToLibmCall<complex::TanhOp>>(patterns.getContext(),
+ "ctanhf", "ctanh", benefit);
+}
+
+namespace {
+struct ConvertComplexToLibmPass
+ : public ConvertComplexToLibmBase<ConvertComplexToLibmPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertComplexToLibmPass::runOnOperation() {
+ auto module = getOperation();
+
+ RewritePatternSet patterns(&getContext());
+ populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<func::FuncDialect>();
+ target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp>();
+ if (failed(applyPartialConversion(module, target, std::move(patterns))))
+ signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertComplexToLibmPass() {
+ return std::make_unique<ConvertComplexToLibmPass>();
+}
--- /dev/null
+// RUN: mlir-opt %s -convert-complex-to-libm -canonicalize | FileCheck %s
+
+// CHECK-DAG: @cpowf(complex<f32>, complex<f32>) -> complex<f32>
+// CHECK-DAG: @cpow(complex<f64>, complex<f64>) -> complex<f64>
+// CHECK-DAG: @csqrtf(complex<f32>) -> complex<f32>
+// CHECK-DAG: @csqrt(complex<f64>) -> complex<f64>
+// CHECK-DAG: @ctanhf(complex<f32>) -> complex<f32>
+// CHECK-DAG: @ctanh(complex<f64>) -> complex<f64>
+
+// CHECK-LABEL: func @cpow_caller
+// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
+// CHECK-SAME: %[[DOUBLE:.*]]: complex<f64>
+func.func @cpow_caller(%float: complex<f32>, %double: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cpowf(%[[FLOAT]], %[[FLOAT]]) : (complex<f32>, complex<f32>) -> complex<f32>
+ %float_result = complex.pow %float, %float : complex<f32>
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cpow(%[[DOUBLE]], %[[DOUBLE]]) : (complex<f64>, complex<f64>) -> complex<f64>
+ %double_result = complex.pow %double, %double : complex<f64>
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : complex<f32>, complex<f64>
+}
+
+// CHECK-LABEL: func @csqrt_caller
+// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
+// CHECK-SAME: %[[DOUBLE:.*]]: complex<f64>
+func.func @csqrt_caller(%float: complex<f32>, %double: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @csqrtf(%[[FLOAT]]) : (complex<f32>) -> complex<f32>
+ %float_result = complex.sqrt %float : complex<f32>
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @csqrt(%[[DOUBLE]]) : (complex<f64>) -> complex<f64>
+ %double_result = complex.sqrt %double : complex<f64>
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : complex<f32>, complex<f64>
+}
+
+// CHECK-LABEL: func @ctanh_caller
+// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
+// CHECK-SAME: %[[DOUBLE:.*]]: complex<f64>
+func.func @ctanh_caller(%float: complex<f32>, %double: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @ctanhf(%[[FLOAT]]) : (complex<f32>) -> complex<f32>
+ %float_result = complex.tanh %float : complex<f32>
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @ctanh(%[[DOUBLE]]) : (complex<f64>) -> complex<f64>
+ %double_result = complex.tanh %double : complex<f64>
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : complex<f32>, complex<f64>
+}
":AsyncToLLVM",
":BufferizationToMemRef",
":ComplexToLLVM",
+ ":ComplexToLibm",
":ComplexToStandard",
":ControlFlowToLLVM",
":ControlFlowToSPIRV",
],
)
-
cc_library(
name = "TensorToSPIRV",
srcs = glob([
":BufferizationTransforms",
":ComplexDialect",
":ComplexToLLVM",
+ ":ComplexToLibm",
":ControlFlowOps",
":ConversionPasses",
":DLTIDialect",
)
cc_library(
+ name = "ComplexToLibm",
+ srcs = glob([
+ "lib/Conversion/ComplexToLibm/*.cpp",
+ "lib/Conversion/ComplexToLibm/*.h",
+ ]) + [":ConversionPassDetail"],
+ hdrs = glob([
+ "include/mlir/Conversion/ComplexToLibm/*.h",
+ ]),
+ includes = ["include"],
+ deps = [
+ ":ComplexDialect",
+ ":ConversionPassIncGen",
+ ":DialectUtils",
+ ":FuncDialect",
+ ":IR",
+ ":Pass",
+ ":Support",
+ ":Transforms",
+ "//llvm:Core",
+ "//llvm:Support",
+ ],
+)
+
+cc_library(
name = "ComplexToStandard",
srcs = glob([
"lib/Conversion/ComplexToStandard/*.cpp",