[ROCm] Adding pass to lower GPU Dialect to ROCDL Dialect.
authorDeven Desai <36858332+deven-amd@users.noreply.github.com>
Wed, 2 Oct 2019 08:50:03 +0000 (01:50 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 2 Oct 2019 08:50:30 +0000 (01:50 -0700)
This is a follow-up to the PRtensorflow/mlir#146 which introduced the ROCDL Dialect. This PR introduces a pass to lower GPU Dialect to the ROCDL Dialect. As with the previous PR, this one builds on the work done by @whchung, and addresses most of the review comments in the original PR.

Closes tensorflow/mlir#154

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/154 from deven-amd:deven-lower-gpu-to-rocdl 809893e08236da5ab6a38e3459692fa04247773d
PiperOrigin-RevId: 272390729

mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h [new file with mode: 0644]
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp [new file with mode: 0644]
mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir [new file with mode: 0644]
mlir/tools/mlir-opt/CMakeLists.txt

diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
new file mode 100644 (file)
index 0000000..54cda41
--- /dev/null
@@ -0,0 +1,32 @@
+//===- GPUToROCDLPass.h - Convert GPU kernel to ROCDL dialect ---*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
+#define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
+
+#include <memory>
+
+namespace mlir {
+
+class ModuleOp;
+template <typename OpT> class OpPassBase;
+
+/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts.
+std::unique_ptr<OpPassBase<ModuleOp>> createLowerGpuOpsToROCDLOpsPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
index 6c14f54..d907335 100644 (file)
@@ -2,6 +2,7 @@ add_subdirectory(LoopsToGPU)
 add_subdirectory(ControlFlowToCFG)
 add_subdirectory(GPUToCUDA)
 add_subdirectory(GPUToNVVM)
+add_subdirectory(GPUToROCDL)
 add_subdirectory(GPUToSPIRV)
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
new file mode 100644 (file)
index 0000000..3c97e5c
--- /dev/null
@@ -0,0 +1,10 @@
+add_llvm_library(MLIRGPUtoROCDLTransforms
+  LowerGpuOpsToROCDLOps.cpp
+  )
+target_link_libraries(MLIRGPUtoROCDLTransforms
+  LLVMSupport
+  MLIRGPU
+  MLIRLLVMIR
+  MLIRROCDLIR
+  MLIRPass
+  )
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
new file mode 100644 (file)
index 0000000..1cff75d
--- /dev/null
@@ -0,0 +1,148 @@
+//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to generate ROCDLIR operations for higher-level
+// GPU operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/ADT/StringSwitch.h"
+
+using namespace mlir;
+
+namespace {
+
+// Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension
+// that Op operates on.  Op is assumed to return an `std.index` value and
+// XOp, YOp and ZOp are assumed to return an `llvm.i32` value.  Depending on
+// `indexBitwidth`, sign-extend or truncate the resulting value to match the
+// bitwidth expected by the consumers of the value.
+template <typename Op, typename XOp, typename YOp, typename ZOp>
+struct GPUIndexIntrinsicOpLowering : public LLVMOpLowering {
+private:
+  enum dimension { X = 0, Y = 1, Z = 2, invalid };
+  unsigned indexBitwidth;
+
+  static dimension dimensionToIndex(Op op) {
+    return llvm::StringSwitch<dimension>(op.dimension())
+        .Case("x", X)
+        .Case("y", Y)
+        .Case("z", Z)
+        .Default(invalid);
+  }
+
+  static unsigned getIndexBitWidth(LLVMTypeConverter &type_converter) {
+    auto dialect = type_converter.getDialect();
+    return dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
+  }
+
+public:
+  explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(Op::getOperationName(),
+                       lowering_.getDialect()->getContext(), lowering_),
+        indexBitwidth(getIndexBitWidth(lowering_)) {}
+
+  // Convert the kernel arguments to an LLVM type, preserve the rest.
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto dialect = lowering.getDialect();
+    Value *newOp;
+    switch (dimensionToIndex(cast<Op>(op))) {
+    case X:
+      newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+      break;
+    case Y:
+      newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+      break;
+    case Z:
+      newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+      break;
+    default:
+      return matchFailure();
+    }
+
+    if (indexBitwidth > 32) {
+      newOp = rewriter.create<LLVM::SExtOp>(
+          loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+    } else if (indexBitwidth < 32) {
+      newOp = rewriter.create<LLVM::TruncOp>(
+          loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+    }
+
+    rewriter.replaceOp(op, {newOp});
+    return matchSuccess();
+  }
+};
+
+// A pass that replaces all occurences of GPU device operations with their
+// corresponding ROCDL equivalent.
+//
+// This pass only handles device code and is not meant to be run on GPU host
+// code.
+class LowerGpuOpsToROCDLOpsPass : public ModulePass<LowerGpuOpsToROCDLOpsPass> {
+public:
+  void runOnModule() override {
+    ModuleOp m = getModule();
+    if (!m.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelModuleAttrName()))
+      return;
+
+    OwningRewritePatternList patterns;
+    LLVMTypeConverter converter(m.getContext());
+    populateStdToLLVMConversionPatterns(converter, patterns);
+    patterns.insert<
+        GPUIndexIntrinsicOpLowering<gpu::ThreadId, ROCDL::ThreadIdXOp,
+                                    ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>,
+        GPUIndexIntrinsicOpLowering<gpu::BlockDim, ROCDL::BlockDimXOp,
+                                    ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>,
+        GPUIndexIntrinsicOpLowering<gpu::BlockId, ROCDL::BlockIdXOp,
+                                    ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>,
+        GPUIndexIntrinsicOpLowering<gpu::GridDim, ROCDL::GridDimXOp,
+                                    ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
+        converter);
+
+    ConversionTarget target(getContext());
+    target.addLegalDialect<LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
+    target.addDynamicallyLegalOp<FuncOp>(
+        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+    if (failed(applyPartialConversion(m, target, patterns, &converter)))
+      signalPassFailure();
+  }
+};
+
+} // anonymous namespace
+
+std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerGpuOpsToROCDLOpsPass() {
+  return std::make_unique<LowerGpuOpsToROCDLOpsPass>();
+}
+
+static PassRegistration<LowerGpuOpsToROCDLOpsPass>
+    pass("lower-gpu-ops-to-rocdl-ops",
+         "Generate ROCDL operations for gpu operations");
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
new file mode 100644 (file)
index 0000000..9857325
--- /dev/null
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -lower-gpu-ops-to-rocdl-ops | FileCheck %s
+
+module attributes {gpu.kernel_module} {
+  // CHECK-LABEL: func @gpu_index_ops()
+  func @gpu_index_ops()
+      attributes { gpu.kernel } {
+    // CHECK: rocdl.workitem.id.x : !llvm.i32
+    %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index)
+    // CHECK: rocdl.workitem.id.y : !llvm.i32
+    %tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index)
+    // CHECK: rocdl.workitem.id.z : !llvm.i32
+    %tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index)
+
+    // CHECK: rocdl.workgroup.dim.x : !llvm.i32
+    %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index)
+    // CHECK: rocdl.workgroup.dim.y : !llvm.i32
+    %bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index)
+    // CHECK: rocdl.workgroup.dim.z : !llvm.i32
+    %bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index)
+
+    // CHECK: rocdl.workgroup.id.x : !llvm.i32
+    %bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index)
+    // CHECK: rocdl.workgroup.id.y : !llvm.i32
+    %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index)
+    // CHECK: rocdl.workgroup.id.z : !llvm.i32
+    %bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index)
+
+    // CHECK: rocdl.grid.dim.x : !llvm.i32
+    %gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index)
+    // CHECK: rocdl.grid.dim.y : !llvm.i32
+    %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index)
+    // CHECK: rocdl.grid.dim.z : !llvm.i32
+    %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
+
+    std.return
+  }
+}
index 196edd8..75f5cf7 100644 (file)
@@ -24,6 +24,7 @@ set(LIBS
   MLIRFxpMathOps
   MLIRGPU
   MLIRGPUtoNVVMTransforms
+  MLIRGPUtoROCDLTransforms
   MLIRGPUtoSPIRVTransforms
   MLIRLinalg
   MLIRLLVMIR