Port mlir-cuda-runner to use dialect conversion framework.
authorStephan Herhut <herhut@google.com>
Wed, 28 Aug 2019 08:50:28 +0000 (01:50 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Aug 2019 08:50:57 +0000 (01:50 -0700)
Instead of lowering the program in two steps (Standard->LLVM followed
by GPU->NVVM), leading to invalid IR inbetween, the runner now uses
one pattern based rewrite step to go directly from Standard+GPU to
LLVM+NVVM.

PiperOrigin-RevId: 265861934

mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp

index f1c8601..35f2314 100644 (file)
 #include <memory>
 
 namespace mlir {
-struct FunctionPassBase;
+class LLVMTypeConverter;
+class ModulePassBase;
+class OwningRewritePatternList;
+
+/// Collect a set of patterns to convert from the GPU dialect to NVVM.
+void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
+                                         OwningRewritePatternList &patterns);
 
 /// Creates a pass that lowers GPU dialect operations to NVVM counterparts.
-std::unique_ptr<FunctionPassBase> createLowerGpuOpsToNVVMOpsPass();
+std::unique_ptr<ModulePassBase> createLowerGpuOpsToNVVMOpsPass();
 
 } // namespace mlir
 
index 4c39794..0328cf4 100644 (file)
@@ -35,6 +35,8 @@ namespace NVVM {
 class NVVMDialect : public Dialect {
 public:
   explicit NVVMDialect(MLIRContext *context);
+
+  static StringRef getDialectNamespace() { return "nvvm"; }
 };
 
 } // namespace NVVM
index 3ba3e43..ed7ebfb 100644 (file)
@@ -20,6 +20,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 #include "llvm/ADT/StringSwitch.h"
 
-namespace mlir {
+using namespace mlir;
+
 namespace {
 
-// A pass that replaces all occurences of GPU operations with their
-// corresponding NVVM equivalent.
-//
-// This pass does not handle launching of kernels. Instead, it is meant to be
-// used on the body region of a launch or the body region of a kernel
-// function.
-class LowerGpuOpsToNVVMOpsPass : public FunctionPass<LowerGpuOpsToNVVMOpsPass> {
+// Rewriting that replaces the types of a LaunchFunc operation with their
+// LLVM counterparts.
+struct GPULaunchFuncOpLowering : public LLVMOpLowering {
+public:
+  explicit GPULaunchFuncOpLowering(LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(gpu::LaunchFuncOp::getOperationName(),
+                       lowering_.getDialect()->getContext(), lowering_) {}
+
+  // Convert the kernel arguments to an LLVM type, preserve the rest.
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.clone(*op)->setOperands(operands);
+    return rewriter.replaceOp(op, llvm::None), matchSuccess();
+  }
+};
+
+// Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension
+// that Op operates on.  Op is assumed to return an `std.index` value and
+// XOp, YOp and ZOp are assumed to return an `llvm.i32` value.  Depending on
+// `indexBitwidth`, sign-extend or truncate the resulting value to match the
+// bitwidth expected by the consumers of the value.
+template <typename Op, typename XOp, typename YOp, typename ZOp>
+struct GPUIndexIntrinsicOpLowering : public LLVMOpLowering {
 private:
   enum dimension { X = 0, Y = 1, Z = 2, invalid };
+  unsigned indexBitwidth;
 
-  template <typename T> dimension dimensionToIndex(T op) {
+  static dimension dimensionToIndex(Op op) {
     return llvm::StringSwitch<dimension>(op.dimension())
         .Case("x", X)
         .Case("y", Y)
@@ -51,89 +74,98 @@ private:
         .Default(invalid);
   }
 
-  // Helper that replaces Op with XOp, YOp, or ZOp dependeing on the dimension
-  // that Op operates on.  Op is assumed to return an `std.index` value and
-  // XOp, YOp and ZOp are assumed to return an `llvm.i32` value.  Depending on
-  // `indexBitwidth`, sign-extend or truncate the resulting value to match the
-  // bitwidth expected by the consumers of the value.
-  template <typename XOp, typename YOp, typename ZOp, class Op>
-  void replaceWithIntrinsic(Op operation, LLVM::LLVMDialect *dialect,
-                            unsigned indexBitwidth) {
-    assert(operation.getType().isIndex() &&
-           "expected an operation returning index");
-    OpBuilder builder(operation);
-    auto loc = operation.getLoc();
+  static unsigned getIndexBitWidth(LLVMTypeConverter &lowering) {
+    auto dialect = lowering.getDialect();
+    return dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
+  }
+
+public:
+  explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(Op::getOperationName(),
+                       lowering_.getDialect()->getContext(), lowering_),
+        indexBitwidth(getIndexBitWidth(lowering_)) {}
+
+  // Convert the kernel arguments to an LLVM type, preserve the rest.
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto dialect = lowering.getDialect();
     Value *newOp;
-    switch (dimensionToIndex(operation)) {
+    switch (dimensionToIndex(cast<Op>(op))) {
     case X:
-      newOp = builder.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+      newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
       break;
     case Y:
-      newOp = builder.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+      newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
       break;
     case Z:
-      newOp = builder.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+      newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
       break;
     default:
-      operation.emitError("Illegal dimension: " + operation.dimension());
-      signalPassFailure();
-      return;
+      return matchFailure();
     }
 
     if (indexBitwidth > 32) {
-      newOp = builder.create<LLVM::SExtOp>(
+      newOp = rewriter.create<LLVM::SExtOp>(
           loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
     } else if (indexBitwidth < 32) {
-      newOp = builder.create<LLVM::TruncOp>(
+      newOp = rewriter.create<LLVM::TruncOp>(
           loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
     }
-    operation.replaceAllUsesWith(newOp);
-    operation.erase();
+
+    rewriter.replaceOp(op, {newOp});
+    return matchSuccess();
   }
+};
 
+// A pass that replaces all occurences of GPU operations with their
+// corresponding NVVM equivalent.
+//
+// This pass does not handle launching of kernels. Instead, it is meant to be
+// used on the body region of a launch or the body region of a kernel
+// function.
+class LowerGpuOpsToNVVMOpsPass : public ModulePass<LowerGpuOpsToNVVMOpsPass> {
 public:
-  void runOnFunction() {
-    LLVM::LLVMDialect *llvmDialect =
-        getContext().getRegisteredDialect<LLVM::LLVMDialect>();
-    unsigned indexBitwidth =
-        llvmDialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
-    getFunction().walk([&](Operation *opInst) {
-      if (auto threadId = dyn_cast<gpu::ThreadId>(opInst)) {
-        replaceWithIntrinsic<NVVM::ThreadIdXOp, NVVM::ThreadIdYOp,
-                             NVVM::ThreadIdZOp>(threadId, llvmDialect,
-                                                indexBitwidth);
-        return;
-      }
-      if (auto blockDim = dyn_cast<gpu::BlockDim>(opInst)) {
-        replaceWithIntrinsic<NVVM::BlockDimXOp, NVVM::BlockDimYOp,
-                             NVVM::BlockDimZOp>(blockDim, llvmDialect,
-                                                indexBitwidth);
-        return;
-      }
-      if (auto blockId = dyn_cast<gpu::BlockId>(opInst)) {
-        replaceWithIntrinsic<NVVM::BlockIdXOp, NVVM::BlockIdYOp,
-                             NVVM::BlockIdZOp>(blockId, llvmDialect,
-                                               indexBitwidth);
-        return;
-      }
-      if (auto gridDim = dyn_cast<gpu::GridDim>(opInst)) {
-        replaceWithIntrinsic<NVVM::GridDimXOp, NVVM::GridDimYOp,
-                             NVVM::GridDimZOp>(gridDim, llvmDialect,
-                                               indexBitwidth);
-        return;
-      }
-    });
+  void runOnModule() override {
+    ModuleOp m = getModule();
+
+    OwningRewritePatternList patterns;
+    LLVMTypeConverter converter(m.getContext());
+    populateGpuToNVVMConversionPatterns(converter, patterns);
+
+    ConversionTarget target(getContext());
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    target.addLegalDialect<NVVM::NVVMDialect>();
+    target.addDynamicallyLegalOp<FuncOp>(
+        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+    if (failed(applyPartialConversion(m, target, patterns, &converter)))
+      signalPassFailure();
   }
 };
 
 } // anonymous namespace
 
-std::unique_ptr<FunctionPassBase> createLowerGpuOpsToNVVMOpsPass() {
+/// Collect a set of patterns to convert from the GPU dialect to NVVM.
+void mlir::populateGpuToNVVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  patterns
+      .insert<GPULaunchFuncOpLowering,
+              GPUIndexIntrinsicOpLowering<gpu::ThreadId, NVVM::ThreadIdXOp,
+                                          NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
+              GPUIndexIntrinsicOpLowering<gpu::BlockDim, NVVM::BlockDimXOp,
+                                          NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
+              GPUIndexIntrinsicOpLowering<gpu::BlockId, NVVM::BlockIdXOp,
+                                          NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
+              GPUIndexIntrinsicOpLowering<gpu::GridDim, NVVM::GridDimXOp,
+                                          NVVM::GridDimYOp, NVVM::GridDimZOp>>(
+          converter);
+}
+
+std::unique_ptr<ModulePassBase> mlir::createLowerGpuOpsToNVVMOpsPass() {
   return std::make_unique<LowerGpuOpsToNVVMOpsPass>();
 }
 
 static PassRegistration<LowerGpuOpsToNVVMOpsPass>
     pass("lower-gpu-ops-to-nvvm-ops",
          "Generate NVVM operations for gpu operations");
-
-} // namespace mlir
index 9bd9222..797b7bb 100644 (file)
@@ -30,6 +30,7 @@
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
 #include "mlir/Pass/Pass.h"
@@ -108,34 +109,39 @@ OwnedCubin compilePtxToCubin(const std::string ptx, FuncOp &function) {
 }
 
 namespace {
-struct GPULaunchFuncOpLowering : public LLVMOpLowering {
+// A pass that lowers all Standard and Gpu operations to LLVM dialect. It does
+// not lower the GPULaunch operation to actual code but dows translate the
+// signature of its kernel argument.
+class LowerStandardAndGpuToLLVMAndNVVM
+    : public ModulePass<LowerStandardAndGpuToLLVMAndNVVM> {
 public:
-  explicit GPULaunchFuncOpLowering(LLVMTypeConverter &lowering_)
-      : LLVMOpLowering(gpu::LaunchFuncOp::getOperationName(),
-                       lowering_.getDialect()->getContext(), lowering_) {}
-
-  // Convert the kernel arguments to an LLVM type, preserve the rest.
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.clone(*op)->setOperands(operands);
-    return rewriter.replaceOp(op, llvm::None), matchSuccess();
+  void runOnModule() override {
+    ModuleOp m = getModule();
+
+    OwningRewritePatternList patterns;
+    LLVMTypeConverter converter(m.getContext());
+    populateStdToLLVMConversionPatterns(converter, patterns);
+    populateGpuToNVVMConversionPatterns(converter, patterns);
+
+    ConversionTarget target(getContext());
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    target.addLegalDialect<NVVM::NVVMDialect>();
+    target.addLegalOp<ModuleOp>();
+    target.addLegalOp<ModuleTerminatorOp>();
+    target.addDynamicallyLegalOp<FuncOp>(
+        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+    if (failed(applyFullConversion(m, target, patterns, &converter)))
+      signalPassFailure();
   }
 };
 } // end anonymous namespace
 
 static LogicalResult runMLIRPasses(ModuleOp m) {
-  // As we gradually lower, the IR is inconsistent between passes. So do not
-  // verify inbetween.
-  PassManager pm(/*verifyPasses=*/false);
+  PassManager pm;
 
   pm.addPass(createGpuKernelOutliningPass());
-  pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter,
-                                          OwningRewritePatternList &patterns) {
-    populateStdToLLVMConversionPatterns(converter, patterns);
-    patterns.insert<GPULaunchFuncOpLowering>(converter);
-  }));
-  pm.addPass(createLowerGpuOpsToNVVMOpsPass());
+  pm.addPass(static_cast<std::unique_ptr<ModulePassBase>>(
+      std::make_unique<LowerStandardAndGpuToLLVMAndNVVM>()));
   pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));
   pm.addPass(createGenerateCubinAccessorPass());
   pm.addPass(createConvertGpuLaunchFuncToCudaCallsPass());
@@ -143,9 +149,6 @@ static LogicalResult runMLIRPasses(ModuleOp m) {
   if (failed(pm.run(m)))
     return failure();
 
-  if (failed(m.verify()))
-    return failure();
-
   return success();
 }