[MLIR] Rework generate-test-checks.py to attach CHECK lines to the source (test)...
authorTim Shen <timshen@google.com>
Tue, 16 Jun 2020 02:41:03 +0000 (19:41 -0700)
committerTim Shen <timshen@google.com>
Tue, 16 Jun 2020 18:15:46 +0000 (11:15 -0700)
Summary:
This patch adds --source flag to indicate the source file. Then it tries to find insert
points in the source file and insert corresponding checks at those places.

Example output from Tensorflow XLA:

// -----

// CHECK-LABEL:   func @main.3(
// CHECK-SAME:                 %[[VAL_0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 : index},
// CHECK-SAME:                 %[[VAL_1:.*]]: memref<16xi8> {xla_lhlo.alloc = 0 : index, xla_lhlo.liveout = true}) {
// CHECK:           %[[VAL_2:.*]] = constant 0 : index
// CHECK:           %[[VAL_3:.*]] = constant 0 : index
// CHECK:           %[[VAL_4:.*]] = std.view %[[VAL_1]]{{\[}}%[[VAL_3]]][] : memref<16xi8> to memref<2x2xf32>
// CHECK:           "xla_lhlo.tanh"(%[[VAL_0]], %[[VAL_4]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK:           return
// CHECK:         }
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %res = "xla_hlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %res : tensor<2x2xf32>
}

Differential Revision: https://reviews.llvm.org/D81903

mlir/utils/generate-test-checks.py

index 5fac81b..e08f64e 100755 (executable)
@@ -56,6 +56,12 @@ class SSAVariableNamer:
   def pop_name_scope(self):
     self.scopes.pop()
 
+  def num_scopes(self):
+    return len(self.scopes)
+
+  def clear_counter(self):
+    self.name_counter = 0
+
 
 # Process a line of input that has been split at each SSA identifier '%'.
 def process_line(line_chunks, variable_namer):
@@ -87,6 +93,22 @@ def process_line(line_chunks, variable_namer):
   return output_line + '\n'
 
 
+def process_source_lines(source_lines, note, args):
+  source_split_re = re.compile(args.source_delim_regex)
+
+  source_segments = [[]]
+  for line in source_lines:
+    if line == note:
+      continue
+    if line.find(args.check_prefix) != -1:
+      continue
+    if source_split_re.search(line):
+      source_segments.append([])
+
+    source_segments[-1].append(line + '\n')
+  return source_segments
+
+
 # Pre-process a line of input to remove any character sequences that will be
 # problematic with FileCheck.
 def preprocess_line(line):
@@ -112,25 +134,51 @@ def main():
       '--output',
       nargs='?',
       type=argparse.FileType('w'),
-      default=sys.stdout)
+      default=None)
   parser.add_argument(
       'input',
       nargs='?',
       type=argparse.FileType('r'),
       default=sys.stdin)
+  parser.add_argument(
+      '--source', type=str,
+      help='Print each CHECK chunk before each delimeter line in the source'
+           'file, respectively. The delimeter lines are identified by '
+           '--source_delim_regex.')
+  parser.add_argument('--source_delim_regex', type=str, default='func @')
+  parser.add_argument(
+      '--starts_from_scope', type=int, default=1,
+      help='Omit the top specified level of content. For example, by default '
+           'it omits "module {"')
+  parser.add_argument('-i', '--inplace', action='store_true', default=False)
+
   args = parser.parse_args()
 
   # Open the given input file.
   input_lines = [l.rstrip() for l in args.input]
   args.input.close()
 
-  output_lines = []
-
   # Generate a note used for the generated check file.
   script_name = os.path.basename(__file__)
   autogenerated_note = (ADVERT + 'utils/' + script_name)
-  output_lines.append(autogenerated_note + '\n')
 
+  source_segments = None
+  if args.source:
+    source_segments = process_source_lines(
+        [l.rstrip() for l in open(args.source, 'r')],
+        autogenerated_note,
+        args
+    )
+
+  if args.inplace:
+    assert args.output is None
+    output = open(args.source, 'w')
+  elif args.output is None:
+    output = sys.stdout
+  else:
+    output = args.output
+
+  output_segments = [[]]
   # A map containing data used for naming SSA value names.
   variable_namer = SSAVariableNamer()
   for input_line in input_lines:
@@ -144,17 +192,25 @@ def main():
     if is_block:
       input_line = input_line.rsplit('//', 1)[0].rstrip()
 
-    # Top-level operations are heuristically the operations at nesting level 1.
-    is_toplevel_op = (not is_block and input_line.startswith('  ') and
-                      input_line[2] != ' ' and input_line[2] != '}')
+    cur_level = variable_namer.num_scopes()
 
     # If the line starts with a '}', pop the last name scope.
     if lstripped_input_line[0] == '}':
       variable_namer.pop_name_scope()
+      cur_level = variable_namer.num_scopes()
 
     # If the line ends with a '{', push a new name scope.
     if input_line[-1] == '{':
       variable_namer.push_name_scope()
+      if cur_level == args.starts_from_scope:
+        output_segments.append([])
+
+    # Omit lines at the near top level e.g. "module {".
+    if cur_level < args.starts_from_scope:
+      continue
+
+    if len(output_segments[-1]) == 0:
+      variable_namer.clear_counter()
 
     # Preprocess the input to remove any sequences that may be problematic with
     # FileCheck.
@@ -164,7 +220,7 @@ def main():
     ssa_split = input_line.split('%')
 
     # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
-    if not is_toplevel_op or not ssa_split[0]:
+    if len(output_segments[-1]) != 0 or not ssa_split[0]:
       output_line = '// ' + args.check_prefix + ': '
       # Pad to align with the 'LABEL' statements.
       output_line += (' ' * len('-LABEL'))
@@ -176,32 +232,40 @@ def main():
       output_line += process_line(ssa_split[1:], variable_namer)
 
     else:
-      # Append a newline to the output to separate the logical blocks.
-      output_lines.append('\n')
-      output_line = '// ' + args.check_prefix + '-LABEL: '
-
       # Output the first line chunk that does not contain an SSA name for the
       # label.
-      output_line += ssa_split[0] + '\n'
+      output_line = '// ' + args.check_prefix + '-LABEL: ' + ssa_split[0] + '\n'
 
-      # Process the rest of the input line on a separate check line.
-      if len(ssa_split) > 1:
+      # Process the rest of the input line on separate check lines.
+      for argument in ssa_split[1:]:
         output_line += '// ' + args.check_prefix + '-SAME:  '
 
         # Pad to align with the original position in the line.
         output_line += ' ' * len(ssa_split[0])
 
         # Process the rest of the line.
-        output_line += process_line(ssa_split[1:], variable_namer)
+        output_line += process_line([argument], variable_namer)
 
     # Append the output line.
-    output_lines.append(output_line)
+    output_segments[-1].append(output_line)
+
+  output.write(autogenerated_note + '\n')
 
   # Write the output.
-  for output_line in output_lines:
-    args.output.write(output_line)
-  args.output.write('\n')
-  args.output.close()
+  if source_segments:
+    assert len(output_segments) == len(source_segments)
+    for check_segment, source_segment in zip(output_segments, source_segments):
+      for line in check_segment:
+        output.write(line)
+      for line in source_segment:
+        output.write(line)
+  else:
+    for segment in output_segments:
+      output.write('\n')
+      for output_line in segment:
+        output.write(output_line)
+    output.write('\n')
+  output.close()
 
 
 if __name__ == '__main__':