From b259c26eb0717b6962c2f75ac23a44491b27810d Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 12 Nov 2019 11:59:34 -0800 Subject: [PATCH] Add support for OpPhi in loop header block During deserialization, the loop header block will be moved into the spv.loop's region. If the loop header block has block arguments, we need to make sure it is correctly carried over to the block where the new spv.loop resides. During serialization, we need to make sure block arguments from the spv.loop's entry block are not silently dropped. PiperOrigin-RevId: 280021777 --- .../mlir/Dialect/SPIRV/SPIRVControlFlowOps.td | 5 +++ mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 12 ++++++ .../Dialect/SPIRV/Serialization/Deserializer.cpp | 37 +++++++++++++---- .../lib/Dialect/SPIRV/Serialization/Serializer.cpp | 11 ++++++ mlir/test/Dialect/SPIRV/Serialization/loop.mlir | 46 +++++++++++++++++++++- 5 files changed, 102 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index 8de2aeb..1e41fa0 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -300,7 +300,12 @@ def SPV_LoopOp : SPV_Op<"loop", [InFunctionScope]> { let regions = (region AnyRegion:$body); + let builders = [OpBuilder<"Builder *builder, OperationState &state">]; + let extraClassDeclaration = [{ + // Returns the entry block. + Block *getEntryBlock(); + // Returns the loop header block. Block *getHeaderBlock(); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 3c1563e..9d76e56 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1442,6 +1442,13 @@ static LogicalResult verify(spirv::LoadOp loadOp) { // spv.loop //===----------------------------------------------------------------------===// +void spirv::LoopOp::build(Builder *builder, OperationState &state) { + state.addAttribute("loop_control", + builder->getI32IntegerAttr( + static_cast(spirv::LoopControl::None))); + state.addRegion(); +} + static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) { // TODO(antiagainst): support loop control properly Builder builder = parser.getBuilder(); @@ -1557,6 +1564,11 @@ static LogicalResult verify(spirv::LoopOp loopOp) { return success(); } +Block *spirv::LoopOp::getEntryBlock() { + assert(!body().empty() && "op region should not be empty!"); + return &body().front(); +} + Block *spirv::LoopOp::getHeaderBlock() { assert(!body().empty() && "op region should not be empty!"); // The second block is the loop header block. diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 40b5318..11509bb 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -1700,9 +1700,8 @@ spirv::LoopOp ControlFlowStructurizer::createLoopOp() { // merge block so that the newly created LoopOp will be inserted there. OpBuilder builder(&mergeBlock->front()); - auto control = builder.getI32IntegerAttr( - static_cast(spirv::LoopControl::None)); - auto loopOp = builder.create(location, control); + // TODO(antiagainst): handle loop control properly + auto loopOp = builder.create(location); loopOp.addEntryAndMergeBlock(); return loopOp; @@ -1810,10 +1809,25 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { headerBlock->replaceAllUsesWith(mergeBlock); if (isLoop) { + // The loop selection/loop header block may have block arguments. Since now + // we place the selection/loop op inside the old merge block, we need to + // make sure the old merge block has the same block argument list. + assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported"); + for (BlockArgument *blockArg : headerBlock->getArguments()) { + mergeBlock->addArgument(blockArg->getType()); + } + + // If the loop header block has block arguments, make sure the spv.branch op + // matches. + SmallVector blockArgs; + if (!headerBlock->args_empty()) + blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; + // The loop entry block should have a unconditional branch jumping to the // loop header block. builder.setInsertionPointToEnd(&body.front()); - builder.create(location, mapper.lookupOrNull(headerBlock)); + builder.create(location, mapper.lookupOrNull(headerBlock), + ArrayRef(blockArgs)); } // All the blocks cloned into the SelectionOp/LoopOp's region can now be @@ -1901,16 +1915,23 @@ LogicalResult Deserializer::structurizeControlFlow() { for (const auto &info : blockMergeInfo) { auto *headerBlock = info.first; - LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n"); + LLVM_DEBUG(headerBlock->print(llvm::dbgs())); const auto &mergeInfo = info.second; + auto *mergeBlock = mergeInfo.mergeBlock; - auto *continueBlock = mergeInfo.continueBlock; assert(mergeBlock && "merge block cannot be nullptr"); - LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << "\n"); + if (!mergeBlock->args_empty()) + return emitError(unknownLoc, "OpPhi in loop merge block unimplemented"); + LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n"); + LLVM_DEBUG(mergeBlock->print(llvm::dbgs())); + + auto *continueBlock = mergeInfo.continueBlock; if (continueBlock) { LLVM_DEBUG(llvm::dbgs() - << "[cf] continue block " << continueBlock << "\n"); + << "[cf] continue block " << continueBlock << ":\n"); + LLVM_DEBUG(continueBlock->print(llvm::dbgs())); } if (failed(ControlFlowStructurizer::structurize(unknownLoc, headerBlock, diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 805a339..0213420 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -1515,6 +1515,17 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { // afterwards. encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID}); + // We omit the LoopOp's entry block and start serialization from the loop + // header block. The entry block should not contain any additional ops other + // than a single spv.Branch that jumps to the loop header block. However, + // the spv.Branch can contain additional block arguments. Those block + // arguments must come from out of the loop using implicit capture. We will + // need to query the for the value sent and the for the incoming + // parent block. For the latter, we need to make sure this block is + // registered. The value sent should come from the block this loop resides in. + blockIDMap[loopOp.getEntryBlock()] = + getBlockID(loopOp.getOperation()->getBlock()); + // Emit the loop header block, which dominates all other blocks, first. We // need to emit an OpLoopMerge instruction before the loop header block's // terminator. diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir index 88d9a14..e89708f 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s +// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s // Single loop @@ -61,6 +61,50 @@ spv.module "Logical" "GLSL450" { // ----- +spv.module "Logical" "GLSL450" { + spv.globalVariable @GV1 bind(0, 0) : !spv.ptr [0]>, StorageBuffer> + spv.globalVariable @GV2 bind(0, 1) : !spv.ptr [0]>, StorageBuffer> + func @loop_kernel() { + %0 = spv._address_of @GV1 : !spv.ptr [0]>, StorageBuffer> + %1 = spv.constant 0 : i32 + %2 = spv.AccessChain %0[%1] : !spv.ptr [0]>, StorageBuffer> + %3 = spv._address_of @GV2 : !spv.ptr [0]>, StorageBuffer> + %5 = spv.AccessChain %3[%1] : !spv.ptr [0]>, StorageBuffer> + %6 = spv.constant 4 : i32 + %7 = spv.constant 42 : i32 + %8 = spv.constant 2 : i32 +// CHECK: spv.Branch ^bb1(%{{.*}} : i32) +// CHECK-NEXT: ^bb1(%[[OUTARG:.*]]: i32): +// CHECK-NEXT: spv.loop { + spv.loop { +// CHECK-NEXT: spv.Branch ^bb1(%[[OUTARG]] : i32) + spv.Branch ^header(%6 : i32) +// CHECK-NEXT: ^bb1(%[[HEADARG:.*]]: i32): + ^header(%9: i32): + %10 = spv.SLessThan %9, %7 : i32 +// CHECK: spv.BranchConditional %{{.*}}, ^bb2, ^bb3 + spv.BranchConditional %10, ^body, ^merge +// CHECK-NEXT: ^bb2: // pred: ^bb1 + ^body: + %11 = spv.AccessChain %2[%9] : !spv.ptr, StorageBuffer> + %12 = spv.Load "StorageBuffer" %11 : f32 + %13 = spv.AccessChain %5[%9] : !spv.ptr, StorageBuffer> + spv.Store "StorageBuffer" %13, %12 : f32 +// CHECK: %[[ADD:.*]] = spv.IAdd + %14 = spv.IAdd %9, %8 : i32 +// CHECK-NEXT: spv.Branch ^bb1(%[[ADD]] : i32) + spv.Branch ^header(%14 : i32) +// CHECK-NEXT: ^bb3: + ^merge: +// CHECK-NEXT: spv._merge + spv._merge + } + spv.Return + } + spv.EntryPoint "GLCompute" @loop_kernel + spv.ExecutionMode @loop_kernel "LocalSize", 1, 1, 1 +} attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]} + // TODO(antiagainst): re-enable this after fixing the assertion failure. // Nested loop -- 2.7.4