[spirv] Allow return ops to be in control flow ops
authorLei Zhang <antiagainst@google.com>
Sat, 5 Oct 2019 03:08:05 +0000 (20:08 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 5 Oct 2019 03:08:52 +0000 (20:08 -0700)
Use `getParentOfType<FunctionOp>()` instead of `cast<FuncOp>(getParentOp())`
to avoid crash when return ops are used inside spv.selection/spv.loop.

PiperOrigin-RevId: 273006041

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/control-flow-ops.mlir

index 800771c..0b558bf 100644 (file)
@@ -1745,7 +1745,7 @@ static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(spirv::ReturnOp returnOp) {
-  auto funcOp = cast<FuncOp>(returnOp.getParentOp());
+  auto funcOp = returnOp.getParentOfType<FuncOp>();
   auto numOutputs = funcOp.getType().getNumResults();
   if (numOutputs != 0)
     return returnOp.emitOpError("cannot be used in functions returning value")
@@ -1774,7 +1774,7 @@ static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) {
 }
 
 static LogicalResult verify(spirv::ReturnValueOp retValOp) {
-  auto funcOp = cast<FuncOp>(retValOp.getParentOp());
+  auto funcOp = retValOp.getParentOfType<FuncOp>();
   auto numFnResults = funcOp.getType().getNumResults();
   if (numFnResults != 1)
     return retValOp.emitOpError(
index f8dfa47..8c5b4c4 100644 (file)
@@ -459,6 +459,38 @@ func @only_allowed_in_last_block() -> () {
 // spv.Return
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: func @in_selection
+func @in_selection(%cond : i1) -> () {
+  spv.selection {
+    spv.BranchConditional %cond, ^then, ^merge
+  ^then:
+    // CHECK: spv.Return
+    spv.Return
+  ^merge:
+    spv._merge
+  }
+  spv.Return
+}
+
+// CHECK-LABEL: func @in_loop
+func @in_loop(%cond : i1) -> () {
+  spv.loop {
+    spv.Branch ^header
+  ^header:
+    spv.BranchConditional %cond, ^body, ^merge
+  ^body:
+    // CHECK: spv.Return
+    spv.Return
+  ^continue:
+    spv.Branch ^header
+  ^merge:
+    spv._merge
+  }
+  spv.Return
+}
+
+// -----
+
 "foo.function"() ({
   // expected-error @+1 {{op must appear in a 'func' block}}
   spv.Return
@@ -486,6 +518,40 @@ func @ret_val() -> (i32) {
   spv.ReturnValue %0 : i32
 }
 
+// CHECK-LABEL: func @in_selection
+func @in_selection(%cond : i1) -> (i32) {
+  spv.selection {
+    spv.BranchConditional %cond, ^then, ^merge
+  ^then:
+    %zero = spv.constant 0 : i32
+    // CHECK: spv.ReturnValue
+    spv.ReturnValue %zero : i32
+  ^merge:
+    spv._merge
+  }
+  %one = spv.constant 1 : i32
+  spv.ReturnValue %one : i32
+}
+
+// CHECK-LABEL: func @in_loop
+func @in_loop(%cond : i1) -> (i32) {
+  spv.loop {
+    spv.Branch ^header
+  ^header:
+    spv.BranchConditional %cond, ^body, ^merge
+  ^body:
+    %zero = spv.constant 0 : i32
+    // CHECK: spv.ReturnValue
+    spv.ReturnValue %zero : i32
+  ^continue:
+    spv.Branch ^header
+  ^merge:
+    spv._merge
+  }
+  %one = spv.constant 1 : i32
+  spv.ReturnValue %one : i32
+}
+
 // -----
 
 "foo.function"() ({