[spirv] Fix gen_spirv_dialect.py and add spv.Unreachable
authorLei Zhang <antiagainst@google.com>
Wed, 30 Oct 2019 12:40:47 +0000 (05:40 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 30 Oct 2019 12:41:18 +0000 (05:41 -0700)
This CL fixed gen_spirv_dialect.py to support nested delimiters when
chunking existing ODS entries in .td files and to allow ops without
correspondence in the spec. This is needed to pull in the definition
of OpUnreachable.

PiperOrigin-RevId: 277486465

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/Serialization/terminator.mlir
mlir/test/Dialect/SPIRV/control-flow-ops.mlir
mlir/utils/spirv/gen_spirv_dialect.py

index 77e457e..a6446e0 100644 (file)
@@ -174,6 +174,7 @@ def SPV_OC_OpBranch                 : I32EnumAttrCase<"OpBranch", 249>;
 def SPV_OC_OpBranchConditional      : I32EnumAttrCase<"OpBranchConditional", 250>;
 def SPV_OC_OpReturn                 : I32EnumAttrCase<"OpReturn", 253>;
 def SPV_OC_OpReturnValue            : I32EnumAttrCase<"OpReturnValue", 254>;
+def SPV_OC_OpUnreachable            : I32EnumAttrCase<"OpUnreachable", 255>;
 def SPV_OC_OpModuleProcessed        : I32EnumAttrCase<"OpModuleProcessed", 330>;
 
 def SPV_OpcodeAttr :
@@ -209,7 +210,7 @@ def SPV_OpcodeAttr :
       SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpPhi,
       SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
       SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
-      SPV_OC_OpModuleProcessed
+      SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed
       ]> {
     let returnType = "::mlir::spirv::Opcode";
     let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
index 92cb06f..5c9eab8 100644 (file)
@@ -373,6 +373,29 @@ def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> {
 
 // -----
 
+def SPV_UnreachableOp : SPV_Op<"Unreachable", [InFunctionScope, Terminator]> {
+  let summary = "Declares that this block is not reachable in the CFG.";
+
+  let description = [{
+    This instruction must be the last instruction in a block.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    unreachable-op ::= `spv.Unreachable`
+    ```
+  }];
+
+  let arguments = (ins);
+
+  let results = (outs);
+
+  let parser = [{ return parseNoIOOp(parser, result); }];
+  let printer = [{ printNoIOOp(getOperation(), p); }];
+}
+
+// -----
+
 def SPV_ReturnValueOp : SPV_Op<"ReturnValue", [InFunctionScope, Terminator]> {
   let summary = "Return a value from a function.";
 
index 8d32daa..85e22a5 100644 (file)
@@ -2286,6 +2286,26 @@ static void print(spirv::UndefOp undefOp, OpAsmPrinter &printer) {
 }
 
 //===----------------------------------------------------------------------===//
+// spv.Unreachable
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(spirv::UnreachableOp unreachableOp) {
+  auto *op = unreachableOp.getOperation();
+  auto *block = op->getBlock();
+  // Fast track: if this is in entry block, its invalid. Otherwise, if no
+  // predecessors, it's valid.
+  if (block->isEntryBlock())
+    return unreachableOp.emitOpError("cannot be used in reachable block");
+  if (block->hasNoPredecessors())
+    return success();
+
+  // TODO(antiagainst): further verification needs to analyze reachablility from
+  // the entry block.
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.Variable
 //===----------------------------------------------------------------------===//
 
index ba30986..502926a 100644 (file)
@@ -14,4 +14,14 @@ spv.module "Logical" "GLSL450" {
     // CHECK: spv.ReturnValue {{.*}} : i32
     spv.ReturnValue %1 : i32
   }
+
+  // CHECK-LABEL: @unreachable
+  func @unreachable() {
+    spv.Return
+  // CHECK-NOT: ^bb
+  ^bb1:
+    // Unreachable blocks will be dropped during serialization.
+    // CHECK-NOT: spv.Unreachable
+    spv.Unreachable
+  }
 }
index 11377ed..63e214a 100644 (file)
@@ -676,3 +676,37 @@ func @missing_entry_block() -> () {
   }
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.Unreachable
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @unreachable_no_pred
+func @unreachable_no_pred() {
+    spv.Return
+
+  ^next:
+    // CHECK: spv.Unreachable
+    spv.Unreachable
+}
+
+// CHECK-LABEL: func @unreachable_with_pred
+func @unreachable_with_pred() {
+    spv.Return
+
+  ^parent:
+    spv.Branch ^unreachable
+
+  ^unreachable:
+    // CHECK: spv.Unreachable
+    spv.Unreachable
+}
+
+// -----
+
+func @unreachable() {
+  // expected-error @+1 {{cannot be used in reachable block}}
+  spv.Unreachable
+}
index 22fdd94..1e1af82 100755 (executable)
@@ -505,6 +505,42 @@ def get_string_between(base, start, end):
   return '', split[0]
 
 
+def get_string_between_nested(base, start, end):
+  """Extracts a substring with a nested start and end from a string.
+
+  Arguments:
+    - base: string to extract from.
+    - start: string to use as the start of the substring.
+    - end: string to use as the end of the substring.
+
+  Returns:
+    - The substring if found
+    - The part of the base after end of the substring. Is the base string itself
+      if the substring wasnt found.
+  """
+  split = base.split(start, 1)
+  if len(split) == 2:
+    # Handle nesting delimiters
+    rest = split[1]
+    unmatched_start = 1
+    index = 0
+    while unmatched_start > 0 and index < len(rest):
+      if rest[index:].startswith(end):
+        unmatched_start -= 1
+        index += len(end)
+      elif rest[index:].startswith(start):
+        unmatched_start += 1
+        index += len(start)
+      else:
+        index += 1
+
+    assert index < len(rest), \
+           'cannot find end "{end}" while extracting substring '\
+           'starting with "{start}"'.format(start=start, end=end)
+    return rest[:index - len(end)].rstrip(end), rest[index:]
+  return '', split[0]
+
+
 def extract_td_op_info(op_def):
   """Extracts potentially manually specified sections in op's definition.
 
@@ -528,7 +564,7 @@ def extract_td_op_info(op_def):
   inst_category = inst_category[0] if len(inst_category) == 1 else 'Op'
 
   # Get category_args
-  op_tmpl_params = op_def.split('<', 1)[1].split('>', 1)[0]
+  op_tmpl_params = get_string_between_nested(op_def, '<', '>')[0]
   opstringname, rest = get_string_between(op_tmpl_params, '"', '"')
   category_args = rest.split('[', 1)[0]
 
@@ -587,10 +623,12 @@ def update_td_op_definitions(path, instructions, docs, filter_list,
   # For each existing op, extract the manually-written sections out to retain
   # them when re-generating the ops. Also append the existing ops to filter
   # list.
+  name_op_map = {}  # Map from opname to its existing ODS definition
   op_info_dict = {}
   for op in ops:
     info_dict = extract_td_op_info(op)
     opname = info_dict['opname']
+    name_op_map[opname] = op
     op_info_dict[opname] = info_dict
     filter_list.append(opname)
   filter_list = sorted(list(set(filter_list)))
@@ -598,11 +636,15 @@ def update_td_op_definitions(path, instructions, docs, filter_list,
   op_defs = []
   for opname in filter_list:
     # Find the grammar spec for this op
-    instruction = next(
-        inst for inst in instructions if inst['opname'] == opname)
-    op_defs.append(
-        get_op_definition(instruction, docs[opname],
-                          op_info_dict.get(opname, {})))
+    try:
+      instruction = next(
+          inst for inst in instructions if inst['opname'] == opname)
+      op_defs.append(
+          get_op_definition(instruction, docs[opname],
+                            op_info_dict.get(opname, {})))
+    except StopIteration:
+      # This is an op added by us; use the existing ODS definition.
+      op_defs.append(name_op_map[opname])
 
   # Substitute the old op definitions
   op_defs = [header] + op_defs + [footer]