From: Lei Zhang Date: Tue, 22 Oct 2019 00:31:32 +0000 (-0700) Subject: [spirv] Allow block arguments on spv.Branch(Conditional) X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d9fe892e4228909908cf9b0f6dd80d9b56c060b4;p=platform%2Fupstream%2Fllvm.git [spirv] Allow block arguments on spv.Branch(Conditional) We will use block arguments as the way to model SPIR-V OpPhi in the SPIR-V dialect. This CL also adds a few useful helper methods to both ops to get the block arguments. Also added tests for branch weight (de)serialization. PiperOrigin-RevId: 275960797 --- diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index 9ab12e8..e4dd5b7 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -46,23 +46,29 @@ def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> { ``` {.ebnf} branch-op ::= `spv.Branch` successor + successor ::= bb-id branch-use-list? + branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)` ``` For example: ``` spv.Branch ^target + spv.Branch ^target(%0, %1: i32, f32) ``` }]; - let arguments = (ins); + let arguments = (ins + Variadic:$block_arguments + ); let results = (outs); let builders = [ OpBuilder< - "Builder *, OperationState &state, Block *successor", [{ - state.addSuccessor(successor, {}); + "Builder *, OperationState &state, " + "Block *successor, ArrayRef arguments = {}", [{ + state.addSuccessor(successor, arguments); }] > ]; @@ -70,7 +76,13 @@ def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> { let skipDefaultBuilders = 1; let extraClassDeclaration = [{ + /// Returns the branch target block. Block *getTarget() { return getOperation()->getSuccessor(0); } + + /// Returns the block arguments. + operand_range getBlockArguments() { + return getOperation()->getSuccessorOperands(0); + } }]; let autogenSerialization = 0; @@ -105,17 +117,21 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> { branch-conditional-op ::= `spv.BranchConditional` ssa-use (`[` integer-literal, integer-literal `]`)? `,` successor `,` successor + successor ::= bb-id branch-use-list? + branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)` ``` For example: ``` spv.BranchConditional %condition, ^true_branch, ^false_branch + spv.BranchConditional %condition, ^true_branch(%0: i32), ^false_branch(%1: i32) ``` }]; let arguments = (ins SPV_Bool:$condition, + Variadic:$branch_arguments, OptionalAttr:$branch_weights ); @@ -124,12 +140,13 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> { let builders = [ OpBuilder< "Builder *builder, OperationState &state, Value *condition, " - "Block *trueBranch, Block *falseBranch, " - "Optional> weights", + "Block *trueBlock, ArrayRef trueArguments, " + "Block *falseBlock, ArrayRef falseArguments, " + "Optional> weights = {}", [{ state.addOperands(condition); - state.addSuccessor(trueBranch, {}); - state.addSuccessor(falseBranch, {}); + state.addSuccessor(trueBlock, trueArguments); + state.addSuccessor(falseBlock, falseArguments); if (weights) { auto attr = builder->getI32ArrayAttr({static_cast(weights->first), @@ -145,12 +162,57 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> { let autogenSerialization = 0; let extraClassDeclaration = [{ - // Branch indices into the successor list. + /// Branch indices into the successor list. enum { kTrueIndex = 0, kFalseIndex = 1 }; + /// Returns the target block for the true branch. Block *getTrueBlock() { return getOperation()->getSuccessor(kTrueIndex); } + /// Returns the target block for the false branch. Block *getFalseBlock() { return getOperation()->getSuccessor(kFalseIndex); } + + /// Returns the number of arguments to the true target block. + unsigned getNumTrueBlockArguments() { + return getNumSuccessorOperands(kTrueIndex); + } + + /// Returns the number of arguments to the false target block. + unsigned getNumFalseBlockArguments() { + return getNumSuccessorOperands(kFalseIndex); + } + + // Iterator and range support for true target block arguments. + operand_iterator true_block_argument_begin() { + return operand_begin() + getTrueBlockArgumentIndex(); + } + operand_iterator true_block_argument_end() { + return true_block_argument_begin() + getNumTrueBlockArguments(); + } + operand_range getTrueBlockArguments() { + return {true_block_argument_begin(), true_block_argument_end()}; + } + + // Iterator and range support for false target block arguments. + operand_iterator false_block_argument_begin() { + return true_block_argument_end(); + } + operand_iterator false_block_argument_end() { + return false_block_argument_begin() + getNumFalseBlockArguments(); + } + operand_range getFalseBlockArguments() { + return {false_block_argument_begin(), false_block_argument_end()}; + } + + private: + /// Gets the index of the first true block argument in the operand list. + unsigned getTrueBlockArgumentIndex() { + return 1; // Omit the first argument, which is the condition. + } + + /// Gets the index of the first false block argument in the operand list. + unsigned getFalseBlockArgumentIndex() { + return getTrueBlockArgumentIndex() + getNumTrueBlockArguments(); + } }]; } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 8e7673b..ac9e469 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -1489,8 +1489,10 @@ Deserializer::processBranchConditional(ArrayRef operands) { weights = std::make_pair(operands[3], operands[4]); } - opBuilder.create(unknownLoc, condition, trueBlock, - falseBlock, weights); + opBuilder.create( + unknownLoc, condition, trueBlock, + /*trueArguments=*/ArrayRef(), falseBlock, + /*falseArguments=*/ArrayRef(), weights); return success(); } diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir index a23500d..88d9a14 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir @@ -22,8 +22,8 @@ spv.module "Logical" "GLSL450" { %val0 = spv.Load "Function" %var : i32 // CHECK-NEXT: spv.SLessThan %cmp = spv.SLessThan %val0, %count : i32 -// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb2, ^bb4 - spv.BranchConditional %cmp, ^body, ^merge +// CHECK-NEXT: spv.BranchConditional %{{.*}} [1, 1], ^bb2, ^bb4 + spv.BranchConditional %cmp [1, 1], ^body, ^merge // CHECK-NEXT: ^bb2: ^body: diff --git a/mlir/test/Dialect/SPIRV/Serialization/selection.mlir b/mlir/test/Dialect/SPIRV/Serialization/selection.mlir index 31797e7..676cd17 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/selection.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/selection.mlir @@ -15,8 +15,8 @@ spv.module "Logical" "GLSL450" { // CHECK-NEXT: spv.constant 0 // CHECK-NEXT: spv.Variable spv.selection { -// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb1, ^bb2 - spv.BranchConditional %cond, ^then, ^else +// CHECK-NEXT: spv.BranchConditional %{{.*}} [5, 10], ^bb1, ^bb2 + spv.BranchConditional %cond [5, 10], ^then, ^else // CHECK-NEXT: ^bb1: ^then: diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir index c9b1d22..11377ed 100644 --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -13,6 +13,16 @@ func @branch() -> () { // ----- +func @branch_argument() -> () { + %zero = spv.constant 0 : i32 + // CHECK: spv.Branch ^bb1(%{{.*}}, %{{.*}} : i32, i32) + spv.Branch ^next(%zero, %zero: i32, i32) +^next(%arg0: i32, %arg1: i32): + spv.Return +} + +// ----- + func @missing_accessor() -> () { spv.Branch // expected-error @+1 {{expected block name}} @@ -32,16 +42,6 @@ func @wrong_accessor_count() -> () { // ----- -func @accessor_argument_disallowed() -> () { - %zero = spv.constant 0 : i32 - // expected-error @+1 {{requires zero operands}} - "spv.Branch"()[^next(%zero : i32)] : () -> () -^next(%arg: i32): - spv.Return -} - -// ----- - //===----------------------------------------------------------------------===// // spv.BranchConditional //===----------------------------------------------------------------------===// @@ -60,6 +60,24 @@ func @cond_branch() -> () { // ----- +func @cond_branch_argument() -> () { + %true = spv.constant true + %zero = spv.constant 0 : i32 + // CHECK: spv.BranchConditional %{{.*}}, ^bb1(%{{.*}}, %{{.*}} : i32, i32), ^bb2 + spv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1 +^true1(%arg0: i32, %arg1: i32): + // CHECK: spv.BranchConditional %{{.*}}, ^bb3, ^bb4(%{{.*}}, %{{.*}} : i32, i32) + spv.BranchConditional %true, ^true2, ^false2(%zero, %zero: i32, i32) +^false1: + spv.Return +^true2: + spv.Return +^false2(%arg3: i32, %arg4: i32): + spv.Return +} + +// ----- + func @cond_branch_with_weights() -> () { %true = spv.constant true // CHECK: spv.BranchConditional %{{.*}} [5, 10] @@ -108,18 +126,6 @@ func @wrong_accessor_count() -> () { // ----- -func @accessor_argument_disallowed() -> () { - %true = spv.constant true - // expected-error @+1 {{requires a single operand}} - "spv.BranchConditional"(%true)[^one(%true : i1), ^two] : (i1) -> () -^one(%arg : i1): - spv.Return -^two: - spv.Return -} - -// ----- - func @wrong_number_of_weights() -> () { %true = spv.constant true // expected-error @+1 {{must have exactly two branch weights}}