Add custom lowering of ExpOp for NVVM and ROCM.
author Alexander Belyaev <pifon@google.com>
Thu, 24 Oct 2019 08:41:25 +0000 (01:41 -0700)
committer A. Unique TensorFlower <gardener@tensorflow.org>
Thu, 24 Oct 2019 08:41:57 +0000 (01:41 -0700)
PiperOrigin-RevId: 276440911

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h [new file with mode: 0644]
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

index 9b8df74..67fccec 100644 (file)
@@ -66,6 +66,7 @@ public:
 
   /// Utilities to identify types.
   bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
+  bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); }
   bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); }
   bool isIntegerTy(unsigned bitwidth) {
     return getUnderlyingType()->isIntegerTy(bitwidth);
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
new file mode 100644 (file)
index 0000000..0622dc6
--- /dev/null
@@ -0,0 +1,103 @@
+//===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_LLVM_LLVM_PROJECTS_GOOGLE_MLIR_LIB_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
+#define THIRD_PARTY_LLVM_LLVM_PROJECTS_GOOGLE_MLIR_LIB_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
+
+namespace mlir {
+
+template <typename SourceOp>
+struct OpToFuncCallLowering : public LLVMOpLowering {
+public:
+  explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func,
+                                StringRef f64Func)
+      : LLVMOpLowering(SourceOp::getOperationName(),
+                       lowering_.getDialect()->getContext(), lowering_),
+        f32Func(f32Func), f64Func(f64Func) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    using LLVM::LLVMFuncOp;
+    using LLVM::LLVMType;
+
+    static_assert(
+        std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
+        "expected single result op");
+
+    LLVMType resultType = lowering.convertType(op->getResult(0)->getType())
+                              .template cast<LLVM::LLVMType>();
+    LLVMType funcType = getFunctionType(resultType, operands);
+    const std::string funcName = getFunctionName(resultType);
+    if (funcName.empty()) {
+      return matchFailure();
+    }
+
+    LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
+    auto callOp = rewriter.create<LLVM::CallOp>(
+        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
+    rewriter.replaceOp(op, {callOp.getResult(0)});
+    return matchSuccess();
+  }
+
+private:
+  LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
+                                 ArrayRef<Value *> operands) const {
+    using LLVM::LLVMType;
+    SmallVector<LLVMType, 1> operandTypes;
+    for (Value *operand : operands) {
+      operandTypes.push_back(operand->getType().cast<LLVMType>());
+    }
+    return LLVMType::getFunctionTy(resultType, operandTypes,
+                                   /*isVarArg=*/false);
+  }
+
+  StringRef getFunctionName(LLVM::LLVMType type) const {
+    if (type.isFloatTy())
+      return f32Func;
+    if (type.isDoubleTy())
+      return f64Func;
+    return "";
+  }
+
+  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName,
+                                     LLVM::LLVMType funcType,
+                                     Operation *op) const {
+    using LLVM::LLVMFuncOp;
+
+    LLVMFuncOp funcOp =
+        op->getParentOfType<ModuleOp>().lookupSymbol<LLVMFuncOp>(funcName);
+    if (funcOp)
+      return funcOp;
+
+    mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
+    return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType, llvm::None);
+  }
+
+  const std::string f32Func;
+  const std::string f64Func;
+};
+
+} // namespace mlir
+
+#endif // THIRD_PARTY_LLVM_LLVM_PROJECTS_GOOGLE_MLIR_LIB_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
index b051514..e053ab4 100644 (file)
@@ -30,6 +30,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 
 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
 
 using namespace mlir;
 
@@ -479,9 +480,11 @@ public:
         GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                     NVVM::GridDimYOp, NVVM::GridDimZOp>,
         GPUAllReduceOpLowering>(converter);
-
+    patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
+                                                 "__nv_exp");
     ConversionTarget target(getContext());
     target.addIllegalDialect<gpu::GPUDialect>();
+    target.addIllegalOp<LLVM::ExpOp>();
     target.addLegalDialect<LLVM::LLVMDialect>();
     target.addLegalDialect<NVVM::NVVMDialect>();
     // TODO(csigg): Remove once we support replacing non-root ops.
index 2ea587e..ea421f5 100644 (file)
@@ -29,6 +29,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 
 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
 
 using namespace mlir;
 
@@ -59,9 +60,12 @@ public:
         GPUIndexIntrinsicOpLowering<gpu::GridDimOp, ROCDL::GridDimXOp,
                                     ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
         converter);
+    patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__ocml_exp_f32",
+                                                 "__ocml_exp_f64");
 
     ConversionTarget target(getContext());
     target.addLegalDialect<LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
+    target.addIllegalOp<LLVM::ExpOp>();
     target.addDynamicallyLegalOp<FuncOp>(
         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
     if (failed(applyPartialConversion(m, target, patterns, &converter)))