From d2ce435dba36ca2575073d9d51e34b77ffad2e27 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 24 Oct 2019 01:41:25 -0700 Subject: [PATCH] Add custom lowering of ExpOp for NVVM and ROCM. PiperOrigin-RevId: 276440911 --- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h | 1 + .../Conversion/GPUCommon/OpToFuncCallLowering.h | 103 +++++++++++++++++++++ .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 5 +- .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 4 + 4 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index 9b8df74..67fccec 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -66,6 +66,7 @@ public: /// Utilities to identify types. bool isFloatTy() { return getUnderlyingType()->isFloatTy(); } + bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); } bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); } bool isIntegerTy(unsigned bitwidth) { return getUnderlyingType()->isIntegerTy(bitwidth); diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h new file mode 100644 index 0000000..0622dc6 --- /dev/null +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -0,0 +1,103 @@ +//===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef THIRD_PARTY_LLVM_LLVM_PROJECTS_GOOGLE_MLIR_LIB_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ +#define THIRD_PARTY_LLVM_LLVM_PROJECTS_GOOGLE_MLIR_LIB_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Module.h" + +namespace mlir { + +template +struct OpToFuncCallLowering : public LLVMOpLowering { +public: + explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func, + StringRef f64Func) + : LLVMOpLowering(SourceOp::getOperationName(), + lowering_.getDialect()->getContext(), lowering_), + f32Func(f32Func), f64Func(f64Func) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + using LLVM::LLVMFuncOp; + using LLVM::LLVMType; + + static_assert( + std::is_base_of, SourceOp>::value, + "expected single result op"); + + LLVMType resultType = lowering.convertType(op->getResult(0)->getType()) + .template cast(); + LLVMType funcType = getFunctionType(resultType, operands); + const std::string funcName = getFunctionName(resultType); + if (funcName.empty()) { + return matchFailure(); + } + + LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); + auto callOp = rewriter.create( + op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands); + rewriter.replaceOp(op, {callOp.getResult(0)}); + return matchSuccess(); + } + +private: + LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType, + ArrayRef operands) const { + using LLVM::LLVMType; + SmallVector operandTypes; + for (Value *operand : operands) { + operandTypes.push_back(operand->getType().cast()); + } + return LLVMType::getFunctionTy(resultType, operandTypes, + /*isVarArg=*/false); + } + + StringRef getFunctionName(LLVM::LLVMType type) const { + if (type.isFloatTy()) + return f32Func; + if (type.isDoubleTy()) + return f64Func; + return ""; + } + + LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, + LLVM::LLVMType funcType, + Operation *op) const { + using LLVM::LLVMFuncOp; + + LLVMFuncOp funcOp = + op->getParentOfType().lookupSymbol(funcName); + if (funcOp) + return funcOp; + + mlir::OpBuilder b(op->getParentOfType()); + return b.create(op->getLoc(), funcName, funcType, llvm::None); + } + + const std::string f32Func; + const std::string f64Func; +}; + +} // namespace mlir + +#endif // THIRD_PARTY_LLVM_LLVM_PROJECTS_GOOGLE_MLIR_LIB_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index b051514..e053ab4 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -30,6 +30,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "../GPUCommon/IndexIntrinsicsOpLowering.h" +#include "../GPUCommon/OpToFuncCallLowering.h" using namespace mlir; @@ -479,9 +480,11 @@ public: GPUIndexIntrinsicOpLowering, GPUAllReduceOpLowering>(converter); - + patterns.insert>(converter, "__nv_expf", + "__nv_exp"); ConversionTarget target(getContext()); target.addIllegalDialect(); + target.addIllegalOp(); target.addLegalDialect(); target.addLegalDialect(); // TODO(csigg): Remove once we support replacing non-root ops. diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 2ea587e..ea421f5 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -29,6 +29,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "../GPUCommon/IndexIntrinsicsOpLowering.h" +#include "../GPUCommon/OpToFuncCallLowering.h" using namespace mlir; @@ -59,9 +60,12 @@ public: GPUIndexIntrinsicOpLowering>( converter); + patterns.insert>(converter, "_ocml_exp_f32", + "_ocml_exp_f64"); ConversionTarget target(getContext()); target.addLegalDialect(); + target.addIllegalOp(); target.addDynamicallyLegalOp( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); if (failed(applyPartialConversion(m, target, patterns, &converter))) -- 2.7.4