Track peak memory usage (#65157)
authordriazati <driazati@users.noreply.github.com>
Tue, 21 Sep 2021 00:03:53 +0000 (17:03 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 21 Sep 2021 00:25:16 +0000 (17:25 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65157

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D31029049

Pulled By: driazati

fbshipit-source-id: 3e87e94e4872d118ad191aef2b77b8cefe90aeb6

16 files changed:
.github/templates/linux_ci_workflow.yml.j2
.github/workflows/generated-libtorch-linux-xenial-cuda10.2-py3.6-gcc7.yml
.github/workflows/generated-libtorch-linux-xenial-cuda11.3-py3.6-gcc7.yml
.github/workflows/generated-linux-bionic-cuda10.2-py3.9-gcc7.yml
.github/workflows/generated-linux-bionic-py3.6-clang9.yml
.github/workflows/generated-linux-bionic-py3.8-gcc9-coverage.yml
.github/workflows/generated-linux-xenial-cuda10.2-py3.6-gcc7.yml
.github/workflows/generated-linux-xenial-cuda11.3-py3.6-gcc7.yml
.github/workflows/generated-linux-xenial-py3.6-gcc5.4.yml
.github/workflows/generated-parallelnative-linux-xenial-py3.6-gcc5.4.yml
.github/workflows/generated-periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7.yml
.github/workflows/generated-periodic-linux-xenial-cuda11.1-py3.6-gcc7.yml
.github/workflows/generated-puretorch-linux-xenial-py3.6-gcc5.4.yml
test/test_import_stats.py [new file with mode: 0644]
test/test_import_time.py [deleted file]
tools/stats/scribe.py

index 9b6eb0e..635ad07 100644 (file)
@@ -171,7 +171,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
@@ -290,9 +289,15 @@ jobs:
       - name: Output disk space left
         run: |
           sudo df -H
+      !{{ common.parse_ref() }}
       - name: Test
         env:
           PR_NUMBER: ${{ github.event.pull_request.number }}
+          IS_GHA: 1
+          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
+          CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
+          CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
+          AWS_DEFAULT_REGION: us-east-1
         run: |
           if [[ $TEST_CONFIG == 'multigpu' ]]; then
             TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh
@@ -313,6 +318,11 @@ jobs:
             -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
             -e GITHUB_ACTIONS \
             -e IN_CI \
+            -e IS_GHA \
+            -e CIRCLE_BRANCH \
+            -e CIRCLE_SHA1 \
+            -e CIRCLE_PR_NUMBER \
+            -e AWS_DEFAULT_REGION \
             -e IN_WHEEL_TEST \
             -e SHARD_NUMBER \
             -e JOB_BASE_NAME \
@@ -351,7 +361,6 @@ jobs:
           python3 -mcodecov
       {%- endif %}
       !{{ common.upload_test_reports(name='linux') }}
-      !{{ common.parse_ref() }}
       !{{ common.upload_test_statistics(build_environment) }}
       !{{ common.teardown_ec2_linux() }}
 {% endblock %}
index 6da8ad5..78fe21d 100644 (file)
@@ -244,7 +244,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 6cf49bb..21b3f3d 100644 (file)
@@ -244,7 +244,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 36d0624..c69a349 100644 (file)
@@ -244,7 +244,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
@@ -421,9 +420,17 @@ jobs:
       - name: Output disk space left
         run: |
           sudo df -H
+      - name: Parse ref
+        id: parse-ref
+        run: .github/scripts/parse_ref.py
       - name: Test
         env:
           PR_NUMBER: ${{ github.event.pull_request.number }}
+          IS_GHA: 1
+          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
+          CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
+          CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
+          AWS_DEFAULT_REGION: us-east-1
         run: |
           if [[ $TEST_CONFIG == 'multigpu' ]]; then
             TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh
@@ -444,6 +451,11 @@ jobs:
             -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
             -e GITHUB_ACTIONS \
             -e IN_CI \
+            -e IS_GHA \
+            -e CIRCLE_BRANCH \
+            -e CIRCLE_SHA1 \
+            -e CIRCLE_PR_NUMBER \
+            -e AWS_DEFAULT_REGION \
             -e IN_WHEEL_TEST \
             -e SHARD_NUMBER \
             -e JOB_BASE_NAME \
@@ -512,9 +524,6 @@ jobs:
           if-no-files-found: error
           path:
             test-reports-*.zip
-      - name: Parse ref
-        id: parse-ref
-        run: .github/scripts/parse_ref.py
       - name: Display and upload test statistics (Click Me)
         if: always()
         # temporary hack: set CIRCLE_* vars, until we update
index 41620c3..f88a206 100644 (file)
@@ -244,7 +244,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
@@ -421,9 +420,17 @@ jobs:
       - name: Output disk space left
         run: |
           sudo df -H
+      - name: Parse ref
+        id: parse-ref
+        run: .github/scripts/parse_ref.py
       - name: Test
         env:
           PR_NUMBER: ${{ github.event.pull_request.number }}
+          IS_GHA: 1
+          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
+          CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
+          CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
+          AWS_DEFAULT_REGION: us-east-1
         run: |
           if [[ $TEST_CONFIG == 'multigpu' ]]; then
             TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh
@@ -444,6 +451,11 @@ jobs:
             -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
             -e GITHUB_ACTIONS \
             -e IN_CI \
+            -e IS_GHA \
+            -e CIRCLE_BRANCH \
+            -e CIRCLE_SHA1 \
+            -e CIRCLE_PR_NUMBER \
+            -e AWS_DEFAULT_REGION \
             -e IN_WHEEL_TEST \
             -e SHARD_NUMBER \
             -e JOB_BASE_NAME \
@@ -512,9 +524,6 @@ jobs:
           if-no-files-found: error
           path:
             test-reports-*.zip
-      - name: Parse ref
-        id: parse-ref
-        run: .github/scripts/parse_ref.py
       - name: Display and upload test statistics (Click Me)
         if: always()
         # temporary hack: set CIRCLE_* vars, until we update
index 10ee301..4d1cb20 100644 (file)
@@ -244,7 +244,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
@@ -421,9 +420,17 @@ jobs:
       - name: Output disk space left
         run: |
           sudo df -H
+      - name: Parse ref
+        id: parse-ref
+        run: .github/scripts/parse_ref.py
       - name: Test
         env:
           PR_NUMBER: ${{ github.event.pull_request.number }}
+          IS_GHA: 1
+          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
+          CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
+          CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
+          AWS_DEFAULT_REGION: us-east-1
         run: |
           if [[ $TEST_CONFIG == 'multigpu' ]]; then
             TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh
@@ -444,6 +451,11 @@ jobs:
             -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
             -e GITHUB_ACTIONS \
             -e IN_CI \
+            -e IS_GHA \
+            -e CIRCLE_BRANCH \
+            -e CIRCLE_SHA1 \
+            -e CIRCLE_PR_NUMBER \
+            -e AWS_DEFAULT_REGION \
             -e IN_WHEEL_TEST \
             -e SHARD_NUMBER \
             -e JOB_BASE_NAME \
@@ -516,9 +528,6 @@ jobs:
           if-no-files-found: error
           path:
             test-reports-*.zip
-      - name: Parse ref
-        id: parse-ref
-        run: .github/scripts/parse_ref.py
       - name: Display and upload test statistics (Click Me)
         if: always()
         # temporary hack: set CIRCLE_* vars, until we update
index ea34015..be9df9b 100644 (file)
@@ -244,7 +244,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
@@ -421,9 +420,17 @@ jobs:
       - name: Output disk space left
         run: |
           sudo df -H
+      - name: Parse ref
+        id: parse-ref
+        run: .github/scripts/parse_ref.py
       - name: Test
         env:
           PR_NUMBER: ${{ github.event.pull_request.number }}
+          IS_GHA: 1
+          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
+          CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
+          CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
+          AWS_DEFAULT_REGION: us-east-1
         run: |
           if [[ $TEST_CONFIG == 'multigpu' ]]; then
             TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh
@@ -444,6 +451,11 @@ jobs:
             -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
             -e GITHUB_ACTIONS \
             -e IN_CI \
+            -e IS_GHA \
+            -e CIRCLE_BRANCH \
+            -e CIRCLE_SHA1 \
+            -e CIRCLE_PR_NUMBER \
+            -e AWS_DEFAULT_REGION \
             -e IN_WHEEL_TEST \
             -e SHARD_NUMBER \
             -e JOB_BASE_NAME \
@@ -512,9 +524,6 @@ jobs:
           if-no-files-found: error
           path:
             test-reports-*.zip
-      - name: Parse ref
-        id: parse-ref
-        run: .github/scripts/parse_ref.py
       - name: Display and upload test statistics (Click Me)
         if: always()
         # temporary hack: set CIRCLE_* vars, until we update
index d6538f9..523cd73 100644 (file)
@@ -244,7 +244,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
@@ -421,9 +420,17 @@ jobs:
       - name: Output disk space left
         run: |
           sudo df -H
+      - name: Parse ref
+        id: parse-ref
+        run: .github/scripts/parse_ref.py
       - name: Test
         env:
           PR_NUMBER: ${{ github.event.pull_request.number }}
+          IS_GHA: 1
+          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
+          CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
+          CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
+          AWS_DEFAULT_REGION: us-east-1
         run: |
           if [[ $TEST_CONFIG == 'multigpu' ]]; then
             TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh
@@ -444,6 +451,11 @@ jobs:
             -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
             -e GITHUB_ACTIONS \
             -e IN_CI \
+            -e IS_GHA \
+            -e CIRCLE_BRANCH \
+            -e CIRCLE_SHA1 \
+            -e CIRCLE_PR_NUMBER \
+            -e AWS_DEFAULT_REGION \
             -e IN_WHEEL_TEST \
             -e SHARD_NUMBER \
             -e JOB_BASE_NAME \
@@ -512,9 +524,6 @@ jobs:
           if-no-files-found: error
           path:
             test-reports-*.zip
-      - name: Parse ref
-        id: parse-ref
-        run: .github/scripts/parse_ref.py
       - name: Display and upload test statistics (Click Me)
         if: always()
         # temporary hack: set CIRCLE_* vars, until we update
index 24a663c..8ec8061 100644 (file)
@@ -244,7 +244,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
@@ -421,9 +420,17 @@ jobs:
       - name: Output disk space left
         run: |
           sudo df -H
+      - name: Parse ref
+        id: parse-ref
+        run: .github/scripts/parse_ref.py
       - name: Test
         env:
           PR_NUMBER: ${{ github.event.pull_request.number }}
+          IS_GHA: 1
+          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
+          CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
+          CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
+          AWS_DEFAULT_REGION: us-east-1
         run: |
           if [[ $TEST_CONFIG == 'multigpu' ]]; then
             TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh
@@ -444,6 +451,11 @@ jobs:
             -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
             -e GITHUB_ACTIONS \
             -e IN_CI \
+            -e IS_GHA \
+            -e CIRCLE_BRANCH \
+            -e CIRCLE_SHA1 \
+            -e CIRCLE_PR_NUMBER \
+            -e AWS_DEFAULT_REGION \
             -e IN_WHEEL_TEST \
             -e SHARD_NUMBER \
             -e JOB_BASE_NAME \
@@ -512,9 +524,6 @@ jobs:
           if-no-files-found: error
           path:
             test-reports-*.zip
-      - name: Parse ref
-        id: parse-ref
-        run: .github/scripts/parse_ref.py
       - name: Display and upload test statistics (Click Me)
         if: always()
         # temporary hack: set CIRCLE_* vars, until we update
index 88728d5..6a9c44d 100644 (file)
@@ -244,7 +244,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
@@ -421,9 +420,17 @@ jobs:
       - name: Output disk space left
         run: |
           sudo df -H
+      - name: Parse ref
+        id: parse-ref
+        run: .github/scripts/parse_ref.py
       - name: Test
         env:
           PR_NUMBER: ${{ github.event.pull_request.number }}
+          IS_GHA: 1
+          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
+          CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
+          CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
+          AWS_DEFAULT_REGION: us-east-1
         run: |
           if [[ $TEST_CONFIG == 'multigpu' ]]; then
             TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh
@@ -444,6 +451,11 @@ jobs:
             -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
             -e GITHUB_ACTIONS \
             -e IN_CI \
+            -e IS_GHA \
+            -e CIRCLE_BRANCH \
+            -e CIRCLE_SHA1 \
+            -e CIRCLE_PR_NUMBER \
+            -e AWS_DEFAULT_REGION \
             -e IN_WHEEL_TEST \
             -e SHARD_NUMBER \
             -e JOB_BASE_NAME \
@@ -512,9 +524,6 @@ jobs:
           if-no-files-found: error
           path:
             test-reports-*.zip
-      - name: Parse ref
-        id: parse-ref
-        run: .github/scripts/parse_ref.py
       - name: Display and upload test statistics (Click Me)
         if: always()
         # temporary hack: set CIRCLE_* vars, until we update
index 3ad46ef..4691fe5 100644 (file)
@@ -242,7 +242,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
index 70dcd4d..aa9c87b 100644 (file)
@@ -242,7 +242,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
@@ -419,9 +418,17 @@ jobs:
       - name: Output disk space left
         run: |
           sudo df -H
+      - name: Parse ref
+        id: parse-ref
+        run: .github/scripts/parse_ref.py
       - name: Test
         env:
           PR_NUMBER: ${{ github.event.pull_request.number }}
+          IS_GHA: 1
+          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
+          CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
+          CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
+          AWS_DEFAULT_REGION: us-east-1
         run: |
           if [[ $TEST_CONFIG == 'multigpu' ]]; then
             TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh
@@ -442,6 +449,11 @@ jobs:
             -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
             -e GITHUB_ACTIONS \
             -e IN_CI \
+            -e IS_GHA \
+            -e CIRCLE_BRANCH \
+            -e CIRCLE_SHA1 \
+            -e CIRCLE_PR_NUMBER \
+            -e AWS_DEFAULT_REGION \
             -e IN_WHEEL_TEST \
             -e SHARD_NUMBER \
             -e JOB_BASE_NAME \
@@ -510,9 +522,6 @@ jobs:
           if-no-files-found: error
           path:
             test-reports-*.zip
-      - name: Parse ref
-        id: parse-ref
-        run: .github/scripts/parse_ref.py
       - name: Display and upload test statistics (Click Me)
         if: always()
         # temporary hack: set CIRCLE_* vars, until we update
index 353cd73..489f1e2 100644 (file)
@@ -244,7 +244,6 @@ jobs:
           CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
           CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
           CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
-          GITHUB_WORKFLOW_RUN_ID: '${{ github.run_id }}'
         run: |
           COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
           export COMMIT_TIME
diff --git a/test/test_import_stats.py b/test/test_import_stats.py
new file mode 100644 (file)
index 0000000..9c3ee51
--- /dev/null
@@ -0,0 +1,67 @@
+import subprocess
+import sys
+import unittest
+import pathlib
+
+from torch.testing._internal.common_utils import TestCase, run_tests, IS_LINUX, IS_IN_CI
+
+
+REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
+
+try:
+    # Just in case PyTorch was not built in 'develop' mode
+    sys.path.append(str(REPO_ROOT))
+    from tools.stats.scribe import rds_write, register_rds_schema
+except ImportError:
+    register_rds_schema = None
+    rds_write = None
+
+
+# these tests could eventually be changed to fail if the import/init
+# time is greater than a certain threshold, but for now we just use them
+# as a way to track the duration of `import torch` in our ossci-metrics
+# S3 bucket (see tools/stats/print_test_stats.py)
+class TestImportTime(TestCase):
+    def test_time_import_torch(self):
+        TestCase.runWithPytorchAPIUsageStderr("import torch")
+
+    def test_time_cuda_device_count(self):
+        TestCase.runWithPytorchAPIUsageStderr(
+            "import torch; torch.cuda.device_count()",
+        )
+
+    @unittest.skipIf(not IS_LINUX, "Memory test is only implemented for Linux")
+    @unittest.skipIf(not IS_IN_CI, "Memory test only runs in CI")
+    def test_peak_memory(self):
+        def profile(module, name):
+            command = f"import {module}; import resource; print(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)"
+            result = subprocess.run(
+                [sys.executable, "-c", command],
+                stdout=subprocess.PIPE,
+            )
+            max_rss = int(result.stdout.decode().strip())
+
+            return {
+                "test_name": name,
+                "peak_memory_bytes": max_rss,
+            }
+
+        data = profile("torch", "pytorch")
+        baseline = profile("sys", "baseline")
+        rds_write(
+            "import_stats", [data, baseline]
+        )
+
+
+if __name__ == "__main__":
+    if register_rds_schema and IS_IN_CI:
+        register_rds_schema(
+            "import_stats",
+            {
+                "test_name": "string",
+                "peak_memory_bytes": "int",
+                "time_ms": "int",
+            },
+        )
+
+    run_tests()
diff --git a/test/test_import_time.py b/test/test_import_time.py
deleted file mode 100644 (file)
index 38ce685..0000000
+++ /dev/null
@@ -1,19 +0,0 @@
-from torch.testing._internal.common_utils import TestCase, run_tests
-
-
-# these tests could eventually be changed to fail if the import/init
-# time is greater than a certain threshold, but for now we just use them
-# as a way to track the duration of `import torch` in our ossci-metrics
-# S3 bucket (see tools/stats/print_test_stats.py)
-class TestImportTime(TestCase):
-    def test_time_import_torch(self):
-        TestCase.runWithPytorchAPIUsageStderr('import torch')
-
-    def test_time_cuda_device_count(self):
-        TestCase.runWithPytorchAPIUsageStderr(
-            'import torch; torch.cuda.device_count()',
-        )
-
-
-if __name__ == '__main__':
-    run_tests()
index b67920a..3409430 100644 (file)
@@ -2,7 +2,7 @@ import base64
 import bz2
 import os
 import json
-from typing import Dict, Any, List, Union
+from typing import Dict, Any, List, Union, Optional
 
 
 _lambda_client = None
@@ -73,6 +73,14 @@ def invoke_rds(events: List[Dict[str, Any]]) -> Any:
 
 
 def register_rds_schema(table_name: str, schema: Dict[str, str]) -> None:
+    """
+    Register a table in RDS so it can be written to later on with 'rds_write'.
+    'schema' should be a mapping of field names -> types, where supported types
+    are 'int' and 'string'.
+
+    Metadata fields such as pr, ref, branch, workflow_id, and build_environment
+    will be added automatically.
+    """
     base = {
         "pr": "string",
         "ref": "string",
@@ -87,6 +95,9 @@ def register_rds_schema(table_name: str, schema: Dict[str, str]) -> None:
 
 
 def schema_from_sample(data: Dict[str, Any]) -> Dict[str, str]:
+    """
+    Extract a schema compatible with 'register_rds_schema' from data.
+    """
     schema = {}
     for key, value in data.items():
         if isinstance(value, str):
@@ -102,6 +113,24 @@ Query = Dict[str, Any]
 
 
 def rds_query(queries: Union[Query, List[Query]]) -> Any:
+    """
+    Execute a simple read query on RDS. Queries should be of the form below,
+    where everything except 'table_name' and 'fields' is optional.
+
+    {
+        "table_name": "my_table",
+        "fields": ["something", "something_else"],
+        "where": [
+            {
+                "field": "something",
+                "value": 10
+            }
+        ],
+        "group_by": ["something"],
+        "order_by": ["something"],
+        "limit": 5,
+    }
+    """
     if not isinstance(queries, list):
         queries = [queries]
 
@@ -113,6 +142,11 @@ def rds_query(queries: Union[Query, List[Query]]) -> Any:
 
 
 def rds_saved_query(query_names: Union[str, List[str]]) -> Any:
+    """
+    Execute a hardcoded RDS query by name. See
+    https://github.com/pytorch/test-infra/blob/main/aws/lambda/rds-proxy/lambda_function.py#L52
+    for available queries or submit a PR there to add a new one.
+    """
     if not isinstance(query_names, list):
         query_names = [query_names]
 
@@ -124,8 +158,19 @@ def rds_saved_query(query_names: Union[str, List[str]]) -> Any:
 
 
 def rds_write(
-    table_name: str, values_list: List[Dict[str, Any]], only_on_master: bool = True
+    table_name: str,
+    values_list: List[Dict[str, Any]],
+    only_on_master: bool = True,
+    only_on_jobs: Optional[List[str]] = None,
 ) -> None:
+    """
+    Note: Only works from GitHub Actions CI runners
+
+    Write a set of entries to a particular RDS table. 'table_name' should be
+    a table registered via 'register_rds_schema' prior to calling rds_write.
+    'values_list' should be a list of dictionaries that map field names to
+    values.
+    """
     sprint("Writing for", os.getenv("CIRCLE_PR_NUMBER"))
     is_master = os.getenv("CIRCLE_PR_NUMBER", "").strip() == ""
     if only_on_master and not is_master:
@@ -136,12 +181,17 @@ def rds_write(
     if pr is not None and pr.strip() == "":
         pr = None
 
+    build_environment = os.environ.get("BUILD_ENVIRONMENT", "").split()[0]
+    if only_on_jobs is not None and build_environment not in only_on_jobs:
+        sprint(f"Skipping write since {build_environment} is not in {only_on_jobs}")
+        return
+
     base = {
         "pr": pr,
         "ref": os.getenv("CIRCLE_SHA1"),
         "branch": os.getenv("CIRCLE_BRANCH"),
-        "workflow_id": os.getenv("GITHUB_WORKFLOW_RUN_ID"),
-        "build_environment": os.environ.get("BUILD_ENVIRONMENT", "").split()[0],
+        "workflow_id": os.getenv("GITHUB_RUN_ID"),
+        "build_environment": build_environment,
     }
 
     events = []