Minor fixes in fastrnns benchmarks
authorJunjie Bai <jbai@fb.com>
Fri, 29 Mar 2019 08:16:52 +0000 (01:16 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 08:22:28 +0000 (01:22 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18613

Reviewed By: wanchaol

Differential Revision: D14681838

fbshipit-source-id: 60bd5c9b09398c74335f003cd21ea32dd1c45876

benchmarks/fastrnns/__init__.py
benchmarks/fastrnns/profile.py

index f32d4a0..f66d75a 100644 (file)
@@ -1,6 +1,5 @@
 from .cells import *
 from .factory import *
-from .test import *
 
 # (output, next_state) = cell(input, state)
 seqLength = 100
index 3b0ec2b..049287c 100644 (file)
@@ -6,7 +6,7 @@ import time
 import torch
 import datetime
 
-from .runner import get_rnn_runners
+from .runner import get_nn_runners
 
 PY3 = sys.version_info >= (3, 0)
 
@@ -48,7 +48,7 @@ def profile(rnns, sleep_between_seconds=1, nloops=5,
     params = dict(seqLength=seqLength, numLayers=numLayers,
                   inputSize=inputSize, hiddenSize=hiddenSize,
                   miniBatch=miniBatch, device=device, seed=seed)
-    for name, creator, context in get_rnn_runners(*rnns):
+    for name, creator, context in get_nn_runners(*rnns):
         with context():
             run_rnn(name, creator, nloops, **params)
             time.sleep(sleep_between_seconds)
@@ -94,11 +94,11 @@ def nvprof(cmd, outpath):
 
 
 def full_profile(rnns, **args):
-    args['internal_run'] = True
     profile_args = []
     for k, v in args.items():
         profile_args.append('--{}={}'.format(k, v))
     profile_args.append('--rnns {}'.format(' '.join(rnns)))
+    profile_args.append('--internal_run')
 
     outpath = nvprof_output_filename(rnns, **args)
 
@@ -125,7 +125,7 @@ if __name__ == '__main__':
 
     # if internal_run, we actually run the rnns.
     # if not internal_run, we shell out to nvprof with internal_run=T
-    parser.add_argument('--internal_run', default=False, type=bool,
+    parser.add_argument('--internal_run', default=False, action='store_true',
                         help='Don\'t use this')
     args = parser.parse_args()
     if args.rnns is None: