Add update_any_test_checks.py convenience utility
authorNicolai Hähnle <nicolai.haehnle@amd.com>
Thu, 1 Dec 2022 12:40:13 +0000 (13:40 +0100)
committerNicolai Hähnle <nicolai.haehnle@amd.com>
Thu, 1 Dec 2022 16:25:53 +0000 (17:25 +0100)
Given a list of test files, this utility will run (optionally in
parallel) the corresponding update_*_test_checks tool for all given
tests that have automatically generated assertions.

Differential Revision: https://reviews.llvm.org/D139100

llvm/utils/update_any_test_checks.py [new file with mode: 0755]

diff --git a/llvm/utils/update_any_test_checks.py b/llvm/utils/update_any_test_checks.py
new file mode 100755 (executable)
index 0000000..e1c7a8f
--- /dev/null
@@ -0,0 +1,115 @@
+#!/usr/bin/env python3
+
+"""Dispatch to update_*_test_checks.py scripts automatically in bulk
+
+Given a list of test files, this script will invoke the correct
+update_test_checks-style script, skipping any tests which have not previously
+had assertions autogenerated.
+"""
+
+from __future__ import print_function
+
+import argparse
+import os
+import re
+import subprocess
+import sys
+from concurrent.futures import ThreadPoolExecutor
+
+RE_ASSERTIONS = re.compile(
+    r'NOTE: Assertions have been autogenerated by ([^\s]+)( UTC_ARGS:.*)?$')
+
+def find_utc_tool(search_path, utc_name):
+  """
+  Return the path to the given UTC tool in the search path, or None if not
+  found.
+  """
+  for path in search_path:
+    candidate = os.path.join(path, utc_name)
+    if os.path.isfile(candidate):
+      return candidate
+  return None
+
+def run_utc_tool(utc_name, utc_tool, testname):
+  result = subprocess.run([utc_tool, testname], stdout=subprocess.PIPE,
+                          stderr=subprocess.PIPE)
+  return (result.returncode, result.stdout, result.stderr)
+
+def main():
+  from argparse import RawTextHelpFormatter
+  parser = argparse.ArgumentParser(description=__doc__,
+                                   formatter_class=RawTextHelpFormatter)
+  parser.add_argument(
+      '--jobs', '-j', default=1, type=int,
+      help='Run the given number of jobs in parallel')
+  parser.add_argument(
+      '--utc-dir', nargs='*',
+      help='Additional directories to scan for update_*_test_checks scripts')
+  parser.add_argument('tests', nargs='+')
+  config = parser.parse_args()
+
+  script_name = os.path.basename(__file__)
+  if config.utc_dir:
+    utc_search_path = config.utc_dir[:]
+  else:
+    utc_search_path = []
+  utc_search_path.append(os.path.join(os.path.dirname(script_name),
+                                      os.path.pardir))
+
+  not_autogenerated = []
+  utc_tools = {}
+  have_error = False
+
+  with ThreadPoolExecutor(max_workers=config.jobs) as executor:
+    jobs = []
+
+    for testname in config.tests:
+      with open(testname, 'r') as f:
+        header = f.readline().strip()
+        m = RE_ASSERTIONS.search(header)
+        if m is None:
+          not_autogenerated.append(testname)
+          continue
+
+        utc_name = m.group(1)
+        if utc_name not in utc_tools:
+          utc_tools[utc_name] = find_utc_tool(utc_search_path, utc_name)
+          if not utc_tools[utc_name]:
+            print(f"{utc_name}: not found (used in {testname})",
+                  file=sys.stderr)
+            have_error = True
+            continue
+
+        future = executor.submit(run_utc_tool, utc_name, utc_tools[utc_name],
+                                 testname)
+        jobs.append((testname, future))
+
+    for testname, future in jobs:
+      return_code, stdout, stderr = future.result()
+
+      print(f"Update {testname}")
+      stdout = stdout.decode(errors='replace')
+      if stdout:
+        print(stdout, end='')
+        if not stdout.endswith('\n'):
+          print()
+
+      stderr = stderr.decode(errors='replace')
+      if stderr:
+        print(stderr, end='')
+        if not stderr.endswith('\n'):
+          print()
+      if return_code != 0:
+        print(f"Return code: {return_code}")
+        have_error = True
+
+  if have_error:
+    sys.exit(1)
+
+  if not_autogenerated:
+    print("Tests without autogenerated assertions:")
+    for testname in not_autogenerated:
+      print(f"  {testname}")
+
+if __name__ == '__main__':
+  main()