import torch
import datetime
-from .runner import get_rnn_runners
+from .runner import get_nn_runners
PY3 = sys.version_info >= (3, 0)
params = dict(seqLength=seqLength, numLayers=numLayers,
inputSize=inputSize, hiddenSize=hiddenSize,
miniBatch=miniBatch, device=device, seed=seed)
- for name, creator, context in get_rnn_runners(*rnns):
+ for name, creator, context in get_nn_runners(*rnns):
with context():
run_rnn(name, creator, nloops, **params)
time.sleep(sleep_between_seconds)
def full_profile(rnns, **args):
- args['internal_run'] = True
profile_args = []
for k, v in args.items():
profile_args.append('--{}={}'.format(k, v))
profile_args.append('--rnns {}'.format(' '.join(rnns)))
+ profile_args.append('--internal_run')
outpath = nvprof_output_filename(rnns, **args)
# if internal_run, we actually run the rnns.
# if not internal_run, we shell out to nvprof with internal_run=T
- parser.add_argument('--internal_run', default=False, type=bool,
+ parser.add_argument('--internal_run', default=False, action='store_true',
help='Don\'t use this')
args = parser.parse_args()
if args.rnns is None: