Add Conversion to lower loop::ForOp to spirv::LoopOp.
authorMahesh Ravishankar <ravishankarm@google.com>
Tue, 12 Nov 2019 19:32:54 +0000 (11:32 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 12 Nov 2019 19:33:27 +0000 (11:33 -0800)
loop::ForOp can be lowered to the structured control flow represented
by spirv::LoopOp by making the continue block of the spirv::LoopOp the
loop latch and the merge block the exit block. The resulting
spirv::LoopOp has a single back edge from the continue to header
block, and a single exit from header to merge.
PiperOrigin-RevId: 280015614

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/test/Conversion/GPUToSPIRV/loop.mlir [new file with mode: 0644]

index 4f73b9b..ece4269 100644 (file)
@@ -21,6 +21,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Pass/Pass.h"
@@ -29,6 +30,16 @@ using namespace mlir;
 
 namespace {
 
+/// Pattern to convert a loop::ForOp within kernel functions into spirv::LoopOp.
+class ForOpConversion final : public SPIRVOpLowering<loop::ForOp> {
+public:
+  using SPIRVOpLowering<loop::ForOp>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
 /// builin variables.
 template <typename OpTy, spirv::BuiltIn builtin>
@@ -51,8 +62,79 @@ public:
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
+
 } // namespace
 
+PatternMatchResult
+ForOpConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                 ConversionPatternRewriter &rewriter) const {
+  // loop::ForOp can be lowered to the structured control flow represented by
+  // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
+  // latch and the merge block the exit block. The resulting spirv::LoopOp has a
+  // single back edge from the continue to header block, and a single exit from
+  // header to merge.
+  auto forOp = cast<loop::ForOp>(op);
+  loop::ForOpOperandAdaptor forOperands(operands);
+  auto loc = op->getLoc();
+  auto loopControl = rewriter.getI32IntegerAttr(
+      static_cast<uint32_t>(spirv::LoopControl::None));
+  auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
+  loopOp.addEntryAndMergeBlock();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  // Create the block for the header.
+  auto header = new Block();
+  // Insert the header.
+  loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
+
+  // Create the new induction variable to use.
+  BlockArgument *newIndVar =
+      header->addArgument(forOperands.lowerBound()->getType());
+  Block *body = forOp.getBody();
+
+  // Apply signature conversion to the body of the forOp. It has a single block,
+  // with argument which is the induction variable. That has to be replaced with
+  // the new induction variable.
+  TypeConverter::SignatureConversion signatureConverter(
+      body->getNumArguments());
+  signatureConverter.remapInput(0, newIndVar);
+  rewriter.applySignatureConversion(&forOp.getOperation()->getRegion(0),
+                                    signatureConverter);
+
+  // Delete the loop terminator.
+  rewriter.eraseOp(body->getTerminator());
+
+  // Move the blocks from the forOp into the loopOp. This is the body of the
+  // loopOp.
+  rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
+                              std::next(loopOp.body().begin(), 2));
+
+  // Branch into it from the entry.
+  rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
+  rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound());
+
+  // Generate the rest of the loop header.
+  rewriter.setInsertionPointToEnd(header);
+  auto mergeBlock = loopOp.getMergeBlock();
+  auto cmpOp = rewriter.create<spirv::SLessThanOp>(
+      loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
+  rewriter.create<spirv::BranchConditionalOp>(
+      loc, cmpOp, body, ArrayRef<Value *>(), mergeBlock, ArrayRef<Value *>());
+
+  // Generate instructions to increment the step of the induction variable and
+  // branch to the header.
+  Block *continueBlock = loopOp.getContinueBlock();
+  rewriter.setInsertionPointToEnd(continueBlock);
+
+  // Add the step to the induction variable and branch to the header.
+  Value *updatedIndVar = rewriter.create<spirv::IAddOp>(
+      loc, newIndVar->getType(), newIndVar, forOperands.step());
+  rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
+
+  rewriter.eraseOp(forOp);
+  return matchSuccess();
+}
+
 template <typename OpTy, spirv::BuiltIn builtin>
 PatternMatchResult LaunchConfigConversion<OpTy, builtin>::matchAndRewrite(
     Operation *op, ArrayRef<Value *> operands,
@@ -148,7 +230,7 @@ void GPUToSPIRVPass::runOnModule() {
   SPIRVTypeConverter typeConverter(&basicTypeConverter);
   OwningRewritePatternList patterns;
   patterns.insert<
-      KernelFnConversion,
+      ForOpConversion, KernelFnConversion,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
diff --git a/mlir/test/Conversion/GPUToSPIRV/loop.mlir b/mlir/test/Conversion/GPUToSPIRV/loop.mlir
new file mode 100644 (file)
index 0000000..870f4e3
--- /dev/null
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
+
+module attributes {gpu.container_module} {
+  func @loop(%arg0 : memref<10xf32>, %arg1 : memref<10xf32>) {
+    %c0 = constant 1 : index
+    "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "loop_kernel", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, memref<10xf32>) -> ()
+    return
+  }
+
+  module @kernels attributes {gpu.kernel_module} {
+    func @loop_kernel(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>)
+    attributes {gpu.kernel} {
+      // CHECK: [[LB:%.*]] = spv.constant 4 : i32
+      %lb = constant 4 : index
+      // CHECK: [[UB:%.*]] = spv.constant 42 : i32
+      %ub = constant 42 : index
+      // CHECK: [[STEP:%.*]] = spv.constant 2 : i32
+      %step = constant 2 : index
+      // CHECK:      spv.loop {
+      // CHECK-NEXT:   spv.Branch [[HEADER:\^.*]]([[LB]] : i32)
+      // CHECK:      [[HEADER]]([[INDVAR:%.*]]: i32):
+      // CHECK:        [[CMP:%.*]] = spv.SLessThan [[INDVAR]], [[UB]] : i32
+      // CHECK:        spv.BranchConditional [[CMP]], [[BODY:\^.*]], [[MERGE:\^.*]]
+      // CHECK:      [[BODY]]:
+      // CHECK:        spv.AccessChain {{%.*}}{{\[}}[[INDVAR]]{{\]}} : {{.*}}
+      // CHECK:        spv.AccessChain {{%.*}}{{\[}}[[INDVAR]]{{\]}} : {{.*}}
+      // CHECK:        [[INCREMENT:%.*]] = spv.IAdd [[INDVAR]], [[STEP]] : i32
+      // CHECK:        spv.Branch [[HEADER]]([[INCREMENT]] : i32)
+      // CHECK:      [[MERGE]]
+      // CHECK:        spv._merge
+      // CHECK:      }
+      loop.for %arg4 = %lb to %ub step %step {
+        %1 = load %arg2[%arg4] : memref<10xf32>
+        store %1, %arg3[%arg4] : memref<10xf32>
+      }
+      return
+    }
+  }
+}
\ No newline at end of file