Make timeout in resnet50_trainer configurable (#17058)
author Junjie Bai <bai@in.tum.de>
Thu, 14 Feb 2019 00:57:30 +0000 (16:57 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 14 Feb 2019 01:03:48 +0000 (17:03 -0800)
Summary:
xw285cornell petrex dagamayank
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17058

Differential Revision: D14068458

Pulled By: bddppq

fbshipit-source-id: 15df4007859067a22df4c6c407df4121e19aaf97

caffe2/python/examples/resnet50_trainer.py

index 6f67135..c22ce5f 100644 (file)
@@ -149,13 +149,13 @@ def LoadModel(path, model, use_ideep):
     predict_init_net = core.Net(pred_utils.GetNet(
         meta_net_def, predictor_constants.PREDICT_INIT_NET_TYPE))
 
-    if use_ideep: 
+    if use_ideep:
         predict_init_net.RunAllOnIDEEP()
-    else: 
+    else:
         predict_init_net.RunAllOnGPU()
-    if use_ideep: 
+    if use_ideep:
         init_net.RunAllOnIDEEP()
-    else: 
+    else:
         init_net.RunAllOnGPU()
 
     assert workspace.RunNetOnce(predict_init_net)
@@ -191,7 +191,7 @@ def RunEpoch(
     for i in range(epoch_iters):
         # This timeout is required (temporarily) since CUDA-NCCL
         # operators might deadlock when synchronizing between GPUs.
-        timeout = 600.0 if i == 0 else 60.0
+        timeout = args.first_iter_timeout if i == 0 else args.timeout
         with timeout_guard.CompleteInTimeOrDie(timeout):
             t1 = time.time()
             workspace.RunNet(train_model.net.Proto().name)
@@ -673,6 +673,13 @@ def main():
     parser.add_argument("--distributed_interfaces", type=str, default="",
                         help="Network interfaces to use for distributed run")
 
+    parser.add_argument("--first_iter_timeout", type=int, default=600,
+                        help="Timeout (secs) of the first iteration "
+                        "(default: %(default)s)")
+    parser.add_argument("--timeout", type=int, default=60,
+                        help="Timeout (secs) of each (except the first) iteration "
+                        "(default: %(default)s)")
+
     args = parser.parse_args()
 
     Train(args)