# 1. Does not reuse the build artifact in other CI workflows
# 2. CI jobs are serialized because there is only one worker
import os
+import git # type: ignore[import]
import pathlib
import argparse
import subprocess
PYTHON_VERSION = "3.7"
TORCHBENCH_CONFIG_NAME = "config.yaml"
MAGIC_PREFIX = "RUN_TORCHBENCH:"
+MAGIC_TORCHBENCH_PREFIX = "TORCHBENCH_BRANCH:"
ABTEST_CONFIG_TEMPLATE = """# This config is automatically generated by run_torchbench.py
start: {control}
end: {treatment}
lines = map(lambda x: x.strip(), pf.read().splitlines())
magic_lines = list(filter(lambda x: x.startswith(MAGIC_PREFIX), lines))
if magic_lines:
- # Only the first magic line will be respected.
+ # Only the first magic line will be recognized.
model_list = list(map(lambda x: x.strip(), magic_lines[0][len(MAGIC_PREFIX):].split(",")))
# Shortcut: if model_list is ["ALL"], run all the tests
if model_list == ["ALL"]:
return []
return model_list
+def identify_torchbench_branch(torchbench_path: str, prbody_file: str) -> None:
+    """Check out the TorchBench branch requested in the PR body, if any.
+
+    The PR body file is scanned for a line starting with
+    ``MAGIC_TORCHBENCH_PREFIX``; the text after the prefix is taken as the
+    branch name. When no such line exists, or the name is empty, the function
+    returns without touching the repository.
+
+    Args:
+        torchbench_path: Path passed to ``git.Repo`` for the TorchBench checkout.
+        prbody_file: Path to a text file containing the PR body.
+
+    Raises:
+        RuntimeError: If git reports an error while fetching or checking out
+            the requested branch.
+    """
+    # Start from an empty name: a bare ``branch_name: str`` annotation leaves
+    # the local unbound, so the "not specified" check below would raise
+    # UnboundLocalError whenever the PR body has no magic line.
+    branch_name: str = ""
+    with open(prbody_file, "r") as pf:
+        for line in pf.read().splitlines():
+            stripped = line.strip()
+            if stripped.startswith(MAGIC_TORCHBENCH_PREFIX):
+                # Only the first magic line will be recognized.
+                branch_name = stripped[len(MAGIC_TORCHBENCH_PREFIX):].strip()
+                break
+    # If not specified, directly return without the branch checkout
+    if not branch_name:
+        return
+    try:
+        print(f"Checking out the TorchBench branch: {branch_name} ...")
+        repo = git.Repo(torchbench_path)
+        origin = repo.remotes.origin
+        origin.fetch(branch_name)
+        repo.create_head(branch_name, origin.refs[branch_name]).checkout()
+    except git.exc.GitCommandError as err:
+        # Keep the underlying git failure attached for easier debugging.
+        raise RuntimeError(f'{branch_name} doesn\'t exist in the pytorch/benchmark repository. Please double check.') from err
+
def run_torchbench(pytorch_path: str, torchbench_path: str, output_dir: str) -> None:
# Copy system environment so that we will not override
env = dict(os.environ)
if not models:
print("Can't parse the model filter from the pr body. Currently we only support allow-list.")
exit(1)
+ # Identify the TorchBench branch requested in the PR body, verify that it exists, and check it out
+ try:
+ identify_torchbench_branch(args.torchbench_path, args.pr_body)
+ except RuntimeError as e:
+ print(f"Identify TorchBench branch failed: {str(e)}")
+ exit(1)
print(f"Ready to run TorchBench with benchmark. Result will be saved in the directory: {output_dir}.")
# Run TorchBench with the generated config
torchbench_config = gen_abtest_config(args.pr_base_sha, args.pr_head_sha, models)
# shellcheck disable=SC1091
. "${HOME}"/anaconda3/etc/profile.d/conda.sh
conda activate pr-ci
- conda install -y numpy=1.17 requests=2.22 ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six dataclasses pillow pytest tabulate
+ conda install -y numpy=1.17 requests=2.22 ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six dataclasses pillow pytest tabulate gitpython
+ conda install -y -c pytorch-nightly torchtext torchvision
- name: Update self-hosted PyTorch
run: |
pushd "${HOME}"/pytorch
+ git remote prune origin
git fetch
popd
+ - name: Install TorchBench dependencies
+ run: |
+ # shellcheck disable=SC1091
+ . "${HOME}"/anaconda3/etc/profile.d/conda.sh
+ conda activate pr-ci
+ pushd "${PWD}"/benchmark
+ python install.py
- name: Run TorchBench
run: |
pushd "${HOME}"/pytorch