(torch.distributed.elastic) properly format traceback on error (#65041)
authorKiuk Chung <kiuk@fb.com>
Wed, 15 Sep 2021 19:48:28 +0000 (12:48 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 15 Sep 2021 19:50:21 +0000 (12:50 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65041

Fixes a bug introduced in https://github.com/pytorch/pytorch/pull/64036 where the traceback of the error handler is printed out rather than the traceback of the actual exception.

Fixes https://github.com/pytorch/pytorch/issues/60910
Closes https://github.com/pytorch/pytorch/issues/60910

BEFORE (note that the `py_callstack` is NOT the traceback of the RuntimeError):
```
**************************************************************************************************************************************************************************************************************************************************
                                                                                                              run_script_path FAILED
==================================================================================================================================================================================================================================================
Root Cause:
[0]:
  time: 2021-09-14_22:01:06
  rank: 0 (local_rank: 0)
  exitcode: 1 (pid: 1092727)
  error_file: /tmp/torchelastic_aeyvjbpe/none_8zuih7tj/attempt_0/0/error.json
  msg:
    {
      "message": "RuntimeError: rasing error since --throw was specified",
      "extraInfo": {
        "py_callstack": [
          "  File \"<string>\", line 1, in <module>\n",
          "  File \"/usr/local/fbcode/platform009/lib/python3.8/multiprocessing/spawn.py\", line 116, in spawn_main\n    exitcode = _main(fd, parent_sentinel)\n",
          "  File \"/usr/local/fbcode/platform009/lib/python3.8/multiprocessing/spawn.py\", line 129, in _main\n    return self._bootstrap(parent_sentinel)\n",
          "  File \"/usr/local/fbcode/platform009/lib/python3.8/multiprocessing/process.py\", line 315, in _bootstrap\n    self.run()\n",
          "  File \"/usr/local/fbcode/platform009/lib/python3.8/multiprocessing/process.py\", line 108, in run\n    self._target(*self._args, **self._kwargs)\n",
          "  File \"/data/users/kiuk/fbsource/fbcode/buck-out/dev/gen/caffe2/run#link-tree/torch/multiprocessing/spawn.py\", line 59, in _wrap\n    fn(i, *args)\n",
          "  File \"/data/users/kiuk/fbsource/fbcode/buck-out/dev/gen/caffe2/run#link-tree/torch/distributed/elastic/multiprocessing/api.py\", line 382, in _wrap\n    ret = record(fn)(*args_)\n",
          "  File \"/data/users/kiuk/fbsource/fbcode/buck-out/dev/gen/caffe2/run#link-tree/torch/distributed/elastic/multiprocessing/errors/__init__.py\", line 373, in wrapper\n    error_handler.record_exception(e)\n",
          "  File \"/data/users/kiuk/fbsource/fbcode/buck-out/dev/gen/caffe2/run#link-tree/torch/distributed/elastic/multiprocessing/errors/error_handler.py\", line 86, in record_exception\n    _write_error(e, self._get_error_file_path())\n",
          "  File \"/data/users/kiuk/fbsource/fbcode/buck-out/dev/gen/caffe2/run#link-tree/torch/distributed/elastic/multiprocessing/errors/error_handler.py\", line 26, in _write_error\n    \"py_callstack\": traceback.format_stack(),\n"
        ],
        "timestamp": "1631682066"
      }
    }

==================================================================================================================================================================================================================================================
Other Failures:
  <NO_OTHER_FAILURES>
**************************************************************************************************************************************************************************************************************************************************
```

AFTER (note the traceback is the traceback of the RuntimeError):
```
********************************************************************************
                             run_script_path FAILED
================================================================================
Root Cause:
[0]:
  time: 2021-09-14_21:49:25
  rank: 0 (local_rank: 0)
  exitcode: 1 (pid: 1014681)
  error_file: /tmp/torchelastic_q0zods2c/none_qwmz5dgj/attempt_0/0/error.json
  msg: Traceback (most recent call last):
    File "/data/users/kiuk/fbsource/fbcode/buck-out/dev/gen/caffe2/run#link-tree/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 361, in wrapper
      return f(*args, **kwargs)
    File "/data/users/kiuk/fbsource/fbcode/buck-out/dev/gen/caffe2/run#link-tree/torch/distributed/run.py", line 671, in run_script_path
      runpy.run_path(sys.argv[0], run_name="__main__")
    File "/usr/local/fbcode/platform009/lib/python3.8/runpy.py", line 265, in run_path
      return _run_module_code(code, init_globals, run_name,
    File "/usr/local/fbcode/platform009/lib/python3.8/runpy.py", line 97, in _run_module_code
      _run_code(code, mod_globals, init_globals,
    File "/usr/local/fbcode/platform009/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/kiuk/tmp/test.py", line 55, in <module>
      main()
    File "/data/users/kiuk/fbsource/fbcode/buck-out/dev/gen/caffe2/run#link-tree/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 361, in wrapper
      return f(*args, **kwargs)
    File "/home/kiuk/tmp/test.py", line 25, in main
      raise RuntimeError("rasing error since --throw was specified")
  RuntimeError: rasing error since --throw was specified

================================================================================
Other Failures:
  <NO_OTHER_FAILURES>
********************************************************************************
```

Test Plan:
(see summary for before and after)

`test.py` contents:
```
import argparse
import os
import sys

import torch
import torch.distributed as dist
import torch.nn.functional as F

from torch.distributed.elastic.multiprocessing.errors import record

def parse_args(argv):
    parser = argparse.ArgumentParser(description="test script")
    parser.add_argument("--init_method", type=str, default="env://")
    parser.add_argument("--backend", type=str, default="gloo")
    parser.add_argument("--throw", action="store_true", default=False)
    parser.add_argument("--exit", action="store_true", default=False)
    return parser.parse_args()

record
def main():
    args = parse_args(sys.argv[1:])

    if args.throw:
        raise RuntimeError("rasing error since --throw was specified")

    if args.exit:
        sys.exit(1)

    init_method=args.init_method
    backend=args.backend

    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])

    print(f"initializing `{backend}` process group with rank={rank}, world_size={world_size} at {init_method}")

    dist.init_process_group(
        backend=backend,
        init_method=init_method,
        world_size=world_size,
        rank=rank)

    print(f"successfully initialized process group with rank={dist.get_rank()}, world_size={dist.get_world_size()}")

    t = F.one_hot(torch.tensor(rank), num_classes=world_size)
    dist.all_reduce(t)
    derived_world_size = torch.sum(t).item()
    if derived_world_size != world_size:
        raise RuntimeError(f"derived world size: {derived_world_size} != actual world size: {world_size}")
    else:
        print(f"sucessfully derived world size: {derived_world_size} (expected: {world_size}). Exiting")

if __name__ == "__main__":
    main()
```

run it as:

```
$ python -m torch.distributed.run --nproc_per_node 2 test.py --throw
```

Reviewed By: cbalioglu

Differential Revision: D30953731

fbshipit-source-id: bbea04c59c2aec58969cf44d8e3723d5f8abe8a8

torch/distributed/elastic/multiprocessing/errors/__init__.py
torch/distributed/elastic/multiprocessing/errors/error_handler.py

index ab0e0f3..c1be1e0 100644 (file)
@@ -57,13 +57,14 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from functools import wraps
 from string import Template
-from typing import Callable, Dict, List, Optional, Tuple, TypeVar, Any
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 
 from torch.distributed.elastic.utils.logging import get_logger
 
 from .error_handler import ErrorHandler  # noqa: F401
 from .handlers import get_error_handler  # noqa: F401
 
+
 log = get_logger()
 
 
@@ -245,7 +246,7 @@ class ChildFailedError(Exception):
                 other_failures_fmt.append(fmt)
 
         # upper boundary on width
-        width = min(width, 250)
+        width = min(width, 80)
 
         return Template(_MSG_FORMAT_TEMPLATE).substitute(
             boarder=boarder_delim * width,
@@ -258,18 +259,22 @@ class ChildFailedError(Exception):
     def _format_failure(
         self, idx: int, rank: int, failure: ProcessFailure
     ) -> Tuple[str, int]:
-        if isinstance(failure.message, str):
-            msg = '"' + failure.message + '"'
-        else:
-            try:
-                dmp = json.dumps(failure.message, indent=2)
-            except ValueError:
-                msg = failure.message
-            else:
-                msg = os.linesep
-                # Indent by 4 chars.
-                for l in dmp.splitlines():
-                    msg += f"    {l}{os.linesep}"
+
+        # failure.message is either a str (when the failure does not generate a traceback - e.g. signals)
+        # or a dict (json) of the form
+        # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}}
+        # so the display logic is:
+        # 1. if failure.message is not a dict (it is a str) just show it as is
+        # 2. else try to get the traceback (py_callstack)
+        # 3.      if the traceback is not there, use the message
+        # 4.      if the message  is not there show <N/A>
+        msg = failure.message
+        if isinstance(failure.message, dict):
+            msg = (
+                failure.message.get("extraInfo", {})
+                .get("py_callstack", failure.message.get("message", "<N/A>"))
+                .replace("\n", "\n  ")  # to properly indent the traceback
+            )
 
         fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute(
             idx=idx,
index 2974355..74586e9 100644 (file)
@@ -23,7 +23,7 @@ def _write_error(e: BaseException, error_file: Optional[str]):
         "message": {
             "message": f"{type(e).__name__}: {e}",
             "extraInfo": {
-                "py_callstack": traceback.format_stack(),
+                "py_callstack": traceback.format_exc(),
                 "timestamp": str(int(time.time())),
             },
         }