From: Wanchao Liang Date: Fri, 5 Apr 2019 00:00:46 +0000 (-0700) Subject: add Fast-RNN to AI-PEP X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~407 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=07efee395c2e3c6ba9bf54581fcce2aac91a54c1;p=platform%2Fupstream%2Fpytorch.git add Fast-RNN to AI-PEP Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18885 Reviewed By: hl475 Differential Revision: D14728854 fbshipit-source-id: 7e7a2946929551963f7c938e3d82a260a9efdfbd --- diff --git a/benchmarks/fastrnns/bench.py b/benchmarks/fastrnns/bench.py index 71cad4a..cdda25b 100644 --- a/benchmarks/fastrnns/bench.py +++ b/benchmarks/fastrnns/bench.py @@ -110,6 +110,36 @@ def print_stderr(*args, **kwargs): return print(*args, **kwargs) +def print_json_oss_format(results): + oss_results = {} + for group_name, group_val in results.items(): + oss_results[group_name] = {} + for model_name, run_time in group_val.items(): + # Output for OSS + oss_results[group_name][model_name] = run_time['avg'] + + print(json.dumps(oss_results)) + + +def print_json_pep_format(num_iters, results): + # print the AI-PEP format json string for each model + for group_name, group_val in results.items(): + for model_name, run_time in group_val.items(): + # Output for AI-PEP + print("Caffe2Observer " + json.dumps( + { + "type": "NET", + "metric": group_name + "-" + model_name, + "unit": "ms", + "num_runs": num_iters, + "summary": { + "mean": run_time['avg'], + "stdev": run_time['std'] + } + } + )) + + def bench(rnn_runners, group_name, print_json=False, sep=' ', **params): print_stderr(print_header(sep=sep)) results = {} @@ -124,8 +154,8 @@ def bench(rnn_runners, group_name, print_json=False, sep=' ', **params): raise return { - group_name: {k: v.avg_fwd for k, v in results.items()}, - group_name + '-backward': {k: v.avg_bwd for k, v in results.items()}, + group_name: {k: {"avg": v.avg_fwd, "std": v.std_fwd} for k, v in results.items()}, + group_name + '-backward': {k: {"avg": v.avg_bwd, "std": v.std_bwd} for k, v in results.items()}, } @@ -157,7 +187,7 @@ if __name__ == '__main__': 'Note that some of these run really slowly ' 'and that the `seqLength` flag will be ignored.') parser.add_argument('--sep', default=' ', type=str) - parser.add_argument('--print-json', action='store_true') + parser.add_argument('--print-json', nargs='?', default=None, const='oss') parser.add_argument('--rnns', nargs='*', help='What to run. cudnn, aten, jit, etc') parser.add_argument('--cnns', nargs='*', @@ -197,5 +227,7 @@ if __name__ == '__main__': if 'cnns' in args.group: results.update(bench_group(cnns, 'ResNet', 'resnet', bench_args)) - if args.print_json: - print(json.dumps(results)) + if args.print_json == 'oss': + print_json_oss_format(results) + elif args.print_json == 'pep': + print_json_pep_format(args.nloops, results)