[spirv] Allow block arguments on spv.Branch(Conditional)
authorLei Zhang <antiagainst@google.com>
Tue, 22 Oct 2019 00:31:32 +0000 (17:31 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 22 Oct 2019 00:32:00 +0000 (17:32 -0700)
We will use block arguments as the way to model SPIR-V OpPhi in
the SPIR-V dialect.

This CL also adds a few useful helper methods to both ops to
get the block arguments.

Also added tests for branch weight (de)serialization.

PiperOrigin-RevId: 275960797

mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/test/Dialect/SPIRV/Serialization/loop.mlir
mlir/test/Dialect/SPIRV/Serialization/selection.mlir
mlir/test/Dialect/SPIRV/control-flow-ops.mlir

index 9ab12e8..e4dd5b7 100644 (file)
@@ -46,23 +46,29 @@ def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
 
     ``` {.ebnf}
     branch-op ::= `spv.Branch` successor
+    successor ::= bb-id branch-use-list?
+    branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)`
     ```
 
     For example:
 
     ```
     spv.Branch ^target
+    spv.Branch ^target(%0, %1: i32, f32)
     ```
   }];
 
-  let arguments = (ins);
+  let arguments = (ins
+    Variadic<AnyType>:$block_arguments
+  );
 
   let results = (outs);
 
   let builders = [
     OpBuilder<
-      "Builder *, OperationState &state, Block *successor", [{
-        state.addSuccessor(successor, {});
+      "Builder *, OperationState &state, "
+      "Block *successor, ArrayRef<Value *> arguments = {}", [{
+        state.addSuccessor(successor, arguments);
       }]
     >
   ];
@@ -70,7 +76,13 @@ def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
   let skipDefaultBuilders = 1;
 
   let extraClassDeclaration = [{
+    /// Returns the branch target block.
     Block *getTarget() { return getOperation()->getSuccessor(0); }
+
+    /// Returns the block arguments.
+    operand_range getBlockArguments() {
+      return getOperation()->getSuccessorOperands(0);
+    }
   }];
 
   let autogenSerialization = 0;
@@ -105,17 +117,21 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
     branch-conditional-op ::= `spv.BranchConditional` ssa-use
                               (`[` integer-literal, integer-literal `]`)?
                               `,` successor `,` successor
+    successor ::= bb-id branch-use-list?
+    branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)`
     ```
 
     For example:
 
     ```
     spv.BranchConditional %condition, ^true_branch, ^false_branch
+    spv.BranchConditional %condition, ^true_branch(%0: i32), ^false_branch(%1: i32)
     ```
   }];
 
   let arguments = (ins
     SPV_Bool:$condition,
+    Variadic<AnyType>:$branch_arguments,
     OptionalAttr<I32ArrayAttr>:$branch_weights
   );
 
@@ -124,12 +140,13 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
   let builders = [
     OpBuilder<
       "Builder *builder, OperationState &state, Value *condition, "
-      "Block *trueBranch, Block *falseBranch, "
-      "Optional<std::pair<uint32_t, uint32_t>> weights",
+      "Block *trueBlock, ArrayRef<Value *> trueArguments, "
+      "Block *falseBlock, ArrayRef<Value *> falseArguments, "
+      "Optional<std::pair<uint32_t, uint32_t>> weights = {}",
       [{
         state.addOperands(condition);
-        state.addSuccessor(trueBranch, {});
-        state.addSuccessor(falseBranch, {});
+        state.addSuccessor(trueBlock, trueArguments);
+        state.addSuccessor(falseBlock, falseArguments);
         if (weights) {
           auto attr =
               builder->getI32ArrayAttr({static_cast<int32_t>(weights->first),
@@ -145,12 +162,57 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
   let autogenSerialization = 0;
 
   let extraClassDeclaration = [{
-    // Branch indices into the successor list.
+    /// Branch indices into the successor list.
     enum { kTrueIndex = 0, kFalseIndex = 1 };
 
+    /// Returns the target block for the true branch.
     Block *getTrueBlock() { return getOperation()->getSuccessor(kTrueIndex); }
 
+    /// Returns the target block for the false branch.
     Block *getFalseBlock() { return getOperation()->getSuccessor(kFalseIndex); }
+
+    /// Returns the number of arguments to the true target block.
+    unsigned getNumTrueBlockArguments() {
+      return getNumSuccessorOperands(kTrueIndex);
+    }
+
+    /// Returns the number of arguments to the false target block.
+    unsigned getNumFalseBlockArguments() {
+      return getNumSuccessorOperands(kFalseIndex);
+    }
+
+    // Iterator and range support for true target block arguments.
+    operand_iterator true_block_argument_begin() {
+      return operand_begin() + getTrueBlockArgumentIndex();
+    }
+    operand_iterator true_block_argument_end() {
+      return true_block_argument_begin() + getNumTrueBlockArguments();
+    }
+    operand_range getTrueBlockArguments() {
+      return {true_block_argument_begin(), true_block_argument_end()};
+    }
+
+    // Iterator and range support for false target block arguments.
+    operand_iterator false_block_argument_begin() {
+      return true_block_argument_end();
+    }
+    operand_iterator false_block_argument_end() {
+      return false_block_argument_begin() + getNumFalseBlockArguments();
+    }
+    operand_range getFalseBlockArguments() {
+      return {false_block_argument_begin(), false_block_argument_end()};
+    }
+
+  private:
+    /// Gets the index of the first true block argument in the operand list.
+    unsigned getTrueBlockArgumentIndex() {
+      return 1; // Omit the first argument, which is the condition.
+    }
+
+    /// Gets the index of the first false block argument in the operand list.
+    unsigned getFalseBlockArgumentIndex() {
+      return getTrueBlockArgumentIndex() + getNumTrueBlockArguments();
+    }
   }];
 }
 
index 8e7673b..ac9e469 100644 (file)
@@ -1489,8 +1489,10 @@ Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
     weights = std::make_pair(operands[3], operands[4]);
   }
 
-  opBuilder.create<spirv::BranchConditionalOp>(unknownLoc, condition, trueBlock,
-                                               falseBlock, weights);
+  opBuilder.create<spirv::BranchConditionalOp>(
+      unknownLoc, condition, trueBlock,
+      /*trueArguments=*/ArrayRef<Value *>(), falseBlock,
+      /*falseArguments=*/ArrayRef<Value *>(), weights);
 
   return success();
 }
index a23500d..88d9a14 100644 (file)
@@ -22,8 +22,8 @@ spv.module "Logical" "GLSL450" {
       %val0 = spv.Load "Function" %var : i32
 // CHECK-NEXT:     spv.SLessThan
       %cmp = spv.SLessThan %val0, %count : i32
-// CHECK-NEXT:     spv.BranchConditional %{{.*}}, ^bb2, ^bb4
-      spv.BranchConditional %cmp, ^body, ^merge
+// CHECK-NEXT:     spv.BranchConditional %{{.*}} [1, 1], ^bb2, ^bb4
+      spv.BranchConditional %cmp [1, 1], ^body, ^merge
 
 // CHECK-NEXT:   ^bb2:
     ^body:
index 31797e7..676cd17 100644 (file)
@@ -15,8 +15,8 @@ spv.module "Logical" "GLSL450" {
 // CHECK-NEXT:     spv.constant 0
 // CHECK-NEXT:     spv.Variable
     spv.selection {
-// CHECK-NEXT: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
-      spv.BranchConditional %cond, ^then, ^else
+// CHECK-NEXT: spv.BranchConditional %{{.*}} [5, 10], ^bb1, ^bb2
+      spv.BranchConditional %cond [5, 10], ^then, ^else
 
 // CHECK-NEXT:   ^bb1:
     ^then:
index c9b1d22..11377ed 100644 (file)
@@ -13,6 +13,16 @@ func @branch() -> () {
 
 // -----
 
+func @branch_argument() -> () {
+  %zero = spv.constant 0 : i32
+  // CHECK: spv.Branch ^bb1(%{{.*}}, %{{.*}} : i32, i32)
+  spv.Branch ^next(%zero, %zero: i32, i32)
+^next(%arg0: i32, %arg1: i32):
+  spv.Return
+}
+
+// -----
+
 func @missing_accessor() -> () {
   spv.Branch
   // expected-error @+1 {{expected block name}}
@@ -32,16 +42,6 @@ func @wrong_accessor_count() -> () {
 
 // -----
 
-func @accessor_argument_disallowed() -> () {
-  %zero = spv.constant 0 : i32
-  // expected-error @+1 {{requires zero operands}}
-  "spv.Branch"()[^next(%zero : i32)] : () -> ()
-^next(%arg: i32):
-  spv.Return
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // spv.BranchConditional
 //===----------------------------------------------------------------------===//
@@ -60,6 +60,24 @@ func @cond_branch() -> () {
 
 // -----
 
+func @cond_branch_argument() -> () {
+  %true = spv.constant true
+  %zero = spv.constant 0 : i32
+  // CHECK: spv.BranchConditional %{{.*}}, ^bb1(%{{.*}}, %{{.*}} : i32, i32), ^bb2
+  spv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1
+^true1(%arg0: i32, %arg1: i32):
+  // CHECK: spv.BranchConditional %{{.*}}, ^bb3, ^bb4(%{{.*}}, %{{.*}} : i32, i32)
+  spv.BranchConditional %true, ^true2, ^false2(%zero, %zero: i32, i32)
+^false1:
+  spv.Return
+^true2:
+  spv.Return
+^false2(%arg3: i32, %arg4: i32):
+  spv.Return
+}
+
+// -----
+
 func @cond_branch_with_weights() -> () {
   %true = spv.constant true
   // CHECK: spv.BranchConditional %{{.*}} [5, 10]
@@ -108,18 +126,6 @@ func @wrong_accessor_count() -> () {
 
 // -----
 
-func @accessor_argument_disallowed() -> () {
-  %true = spv.constant true
-  // expected-error @+1 {{requires a single operand}}
-  "spv.BranchConditional"(%true)[^one(%true : i1), ^two] : (i1) -> ()
-^one(%arg : i1):
-  spv.Return
-^two:
-  spv.Return
-}
-
-// -----
-
 func @wrong_number_of_weights() -> () {
   %true = spv.constant true
   // expected-error @+1 {{must have exactly two branch weights}}