Fix logic to determine master vs PR (#65155)
author driazati <driazati@users.noreply.github.com>
Tue, 21 Sep 2021 00:03:53 +0000 (17:03 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 21 Sep 2021 00:25:14 +0000 (17:25 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65155

Previously this logic mishandled empty strings, which caused the hook to write on every job, not just on `master`, regardless of the `only_on_master` flag.

Test Plan: see `[scribe] Skipping RDS write on PR` in the logs for `linux-xenial-cuda11.3-py3.6-gcc7`

Reviewed By: malfet

Differential Revision: D31029048

Pulled By: driazati

fbshipit-source-id: 77c4a60e443d8fc19990755a3a346576afee86d8

15 files changed:
.github/templates/linux_ci_workflow.yml.j2
.github/workflows/generated-libtorch-linux-xenial-cuda10.2-py3.6-gcc7.yml
.github/workflows/generated-libtorch-linux-xenial-cuda11.3-py3.6-gcc7.yml
.github/workflows/generated-linux-bionic-cuda10.2-py3.9-gcc7.yml
.github/workflows/generated-linux-bionic-py3.6-clang9.yml
.github/workflows/generated-linux-bionic-py3.8-gcc9-coverage.yml
.github/workflows/generated-linux-xenial-cuda10.2-py3.6-gcc7.yml
.github/workflows/generated-linux-xenial-cuda11.3-py3.6-gcc7.yml
.github/workflows/generated-linux-xenial-py3.6-gcc5.4.yml
.github/workflows/generated-parallelnative-linux-xenial-py3.6-gcc5.4.yml
.github/workflows/generated-periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7.yml
.github/workflows/generated-periodic-linux-xenial-cuda11.1-py3.6-gcc7.yml
.github/workflows/generated-puretorch-linux-xenial-py3.6-gcc5.4.yml
tools/stats/scribe.py
tools/stats/upload_binary_size_to_scuba.py

index 97af532..9b6eb0e 100644 (file)
@@ -171,6 +171,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 78fe21d..6da8ad5 100644 (file)
@@ -244,6 +244,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 21b3f3d..6cf49bb 100644 (file)
@@ -244,6 +244,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 243132f..36d0624 100644 (file)
@@ -244,6 +244,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 9c25413..41620c3 100644 (file)
@@ -244,6 +244,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index a70e81b..10ee301 100644 (file)
@@ -244,6 +244,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 68a3631..ea34015 100644 (file)
@@ -244,6 +244,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index dbbb4a5..d6538f9 100644 (file)
@@ -244,6 +244,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index ce3d774..24a663c 100644 (file)
@@ -244,6 +244,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 899249d..88728d5 100644 (file)
@@ -244,6 +244,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 4691fe5..3ad46ef 100644 (file)
@@ -242,6 +242,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 608b957..70dcd4d 100644 (file)
@@ -242,6 +242,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 489f1e2..353cd73 100644 (file)
@@ -244,6 +244,7 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
+          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 0ea30ce..b67920a 100644 (file)
@@ -8,6 +8,9 @@ from typing import Dict, Any, List, Union
 _lambda_client = None
 
 
+IS_GHA = os.getenv("IS_GHA", "0") == "1"
+
+
 def sprint(*args: Any) -> None:
     print("[scribe]", *args)
 
@@ -24,9 +27,7 @@ def aws_lambda() -> Any:
 
 
 def invoke_lambda(name: str, payload: Any) -> Any:
-    res = aws_lambda().invoke(
-        FunctionName=name, Payload=json.dumps(payload).encode()
-    )
+    res = aws_lambda().invoke(FunctionName=name, Payload=json.dumps(payload).encode())
     payload = str(res["Payload"].read().decode())
     if res.get("FunctionError"):
         raise Exception(payload)
@@ -64,6 +65,10 @@ def _send_to_scribe_via_http(access_token: str, logs: str) -> str:
 
 
 def invoke_rds(events: List[Dict[str, Any]]) -> Any:
+    if not IS_GHA:
+        sprint(f"Not invoking RDS lambda outside GitHub Actions:\n{events}")
+        return
+
     return invoke_lambda("rds-proxy", events)
 
 
@@ -73,11 +78,10 @@ def register_rds_schema(table_name: str, schema: Dict[str, str]) -> None:
         "ref": "string",
         "branch": "string",
         "workflow_id": "string",
+        "build_environment": "string",
     }
 
-    event = [
-        {"create_table": {"table_name": table_name, "fields": {**schema, **base}}}
-    ]
+    event = [{"create_table": {"table_name": table_name, "fields": {**schema, **base}}}]
 
     invoke_rds(event)
 
@@ -122,16 +126,22 @@ def rds_saved_query(query_names: Union[str, List[str]]) -> Any:
 def rds_write(
     table_name: str, values_list: List[Dict[str, Any]], only_on_master: bool = True
 ) -> None:
-    sprint("Writing for ", os.getenv("CIRCLE_PR_NUMBER"))
-    if not only_on_master and os.getenv("CIRCLE_PR_NUMBER"):
+    sprint("Writing for", os.getenv("CIRCLE_PR_NUMBER"))
+    is_master = os.getenv("CIRCLE_PR_NUMBER", "").strip() == ""
+    if only_on_master and not is_master:
         sprint("Skipping RDS write on PR")
         return
 
+    pr = os.getenv("CIRCLE_PR_NUMBER", None)
+    if pr is not None and pr.strip() == "":
+        pr = None
+
     base = {
-        "pr": os.getenv("CIRCLE_PR_NUMBER"),
+        "pr": pr,
         "ref": os.getenv("CIRCLE_SHA1"),
         "branch": os.getenv("CIRCLE_BRANCH"),
-        "workflow_id": os.getenv("CIRCLE_WORKFLOW_ID"),
+        "workflow_id": os.getenv("GITHUB_WORKFLOW_RUN_ID"),
+        "build_environment": os.environ.get("BUILD_ENVIRONMENT", "").split()[0],
     }
 
     events = []
@@ -140,5 +150,5 @@ def rds_write(
             {"write": {"table_name": table_name, "values": {**values, **base}}}
         )
 
-    print("Wrote stats for", table_name)
+    sprint("Wrote stats for", table_name)
     invoke_rds(events)
index 51d6555..0b53bdf 100644 (file)
@@ -29,11 +29,7 @@ def get_size(file_dir: str) -> int:
 
 
 def base_data() -> Dict[str, Any]:
-    build_env_split = os.environ.get("BUILD_ENVIRONMENT", "").split()
-    build_environment = build_env_split[0]
-
     return {
-        "build_environment": build_environment,
         "run_duration_seconds": int(
             time.time() - os.path.getmtime(os.path.realpath(__file__))
         ),
@@ -176,6 +172,7 @@ if __name__ == "__main__":
                 }
                 data.append({**base_data(), **library_data})
             rds_write("binary_size", data)
+            print(json.dumps(data, indent=2))
         else:
             print("checking dir: " + file_dir)
             size = get_size(file_dir)