From e22a2b9015efbfdec9b7b720a0773234a9b74b95 Mon Sep 17 00:00:00 2001 From: Junjie Bai Date: Fri, 29 Mar 2019 01:16:52 -0700 Subject: [PATCH] Minor fixes in fastrnns benchmarks Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18613 Reviewed By: wanchaol Differential Revision: D14681838 fbshipit-source-id: 60bd5c9b09398c74335f003cd21ea32dd1c45876 --- benchmarks/fastrnns/__init__.py | 1 - benchmarks/fastrnns/profile.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/benchmarks/fastrnns/__init__.py b/benchmarks/fastrnns/__init__.py index f32d4a0..f66d75a 100644 --- a/benchmarks/fastrnns/__init__.py +++ b/benchmarks/fastrnns/__init__.py @@ -1,6 +1,5 @@ from .cells import * from .factory import * -from .test import * # (output, next_state) = cell(input, state) seqLength = 100 diff --git a/benchmarks/fastrnns/profile.py b/benchmarks/fastrnns/profile.py index 3b0ec2b..049287c 100644 --- a/benchmarks/fastrnns/profile.py +++ b/benchmarks/fastrnns/profile.py @@ -6,7 +6,7 @@ import time import torch import datetime -from .runner import get_rnn_runners +from .runner import get_nn_runners PY3 = sys.version_info >= (3, 0) @@ -48,7 +48,7 @@ def profile(rnns, sleep_between_seconds=1, nloops=5, params = dict(seqLength=seqLength, numLayers=numLayers, inputSize=inputSize, hiddenSize=hiddenSize, miniBatch=miniBatch, device=device, seed=seed) - for name, creator, context in get_rnn_runners(*rnns): + for name, creator, context in get_nn_runners(*rnns): with context(): run_rnn(name, creator, nloops, **params) time.sleep(sleep_between_seconds) @@ -94,11 +94,11 @@ def nvprof(cmd, outpath): def full_profile(rnns, **args): - args['internal_run'] = True profile_args = [] for k, v in args.items(): profile_args.append('--{}={}'.format(k, v)) profile_args.append('--rnns {}'.format(' '.join(rnns))) + profile_args.append('--internal_run') outpath = nvprof_output_filename(rnns, **args) @@ -125,7 +125,7 @@ if __name__ == '__main__': # if internal_run, we actually run the rnns. # if not internal_run, we shell out to nvprof with internal_run=T - parser.add_argument('--internal_run', default=False, type=bool, + parser.add_argument('--internal_run', default=False, action='store_true', help='Don\'t use this') args = parser.parse_args() if args.rnns is None: -- 2.7.4