add Fast-RNN to AI-PEP
authorWanchao Liang <wanchaol@fb.com>
Fri, 5 Apr 2019 00:00:46 +0000 (17:00 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 5 Apr 2019 00:04:21 +0000 (17:04 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18885

Reviewed By: hl475

Differential Revision: D14728854

fbshipit-source-id: 7e7a2946929551963f7c938e3d82a260a9efdfbd

benchmarks/fastrnns/bench.py

index 71cad4a..cdda25b 100644 (file)
@@ -110,6 +110,36 @@ def print_stderr(*args, **kwargs):
     return print(*args, **kwargs)
 
 
+def print_json_oss_format(results):
+    oss_results = {}
+    for group_name, group_val in results.items():
+        oss_results[group_name] = {}
+        for model_name, run_time in group_val.items():
+            # Output for OSS
+            oss_results[group_name][model_name] = run_time['avg']
+
+    print(json.dumps(oss_results))
+
+
+def print_json_pep_format(num_iters, results):
+    # print the AI-PEP format json string for each model
+    for group_name, group_val in results.items():
+        for model_name, run_time in group_val.items():
+            # Output for AI-PEP
+            print("Caffe2Observer " + json.dumps(
+                {
+                    "type": "NET",
+                    "metric": group_name + "-" + model_name,
+                    "unit": "ms",
+                    "num_runs": num_iters,
+                    "summary": {
+                        "mean": run_time['avg'],
+                        "stdev": run_time['std']
+                    }
+                }
+            ))
+
+
 def bench(rnn_runners, group_name, print_json=False, sep=' ', **params):
     print_stderr(print_header(sep=sep))
     results = {}
@@ -124,8 +154,8 @@ def bench(rnn_runners, group_name, print_json=False, sep=' ', **params):
                     raise
 
     return {
-        group_name: {k: v.avg_fwd for k, v in results.items()},
-        group_name + '-backward': {k: v.avg_bwd for k, v in results.items()},
+        group_name: {k: {"avg": v.avg_fwd, "std": v.std_fwd} for k, v in results.items()},
+        group_name + '-backward': {k: {"avg": v.avg_bwd, "std": v.std_bwd} for k, v in results.items()},
     }
 
 
@@ -157,7 +187,7 @@ if __name__ == '__main__':
                         'Note that some of these run really slowly '
                         'and that the `seqLength` flag will be ignored.')
     parser.add_argument('--sep', default=' ', type=str)
-    parser.add_argument('--print-json', action='store_true')
+    parser.add_argument('--print-json', nargs='?', default=None, const='oss')
     parser.add_argument('--rnns', nargs='*',
                         help='What to run. cudnn, aten, jit, etc')
     parser.add_argument('--cnns', nargs='*',
@@ -197,5 +227,7 @@ if __name__ == '__main__':
     if 'cnns' in args.group:
         results.update(bench_group(cnns, 'ResNet', 'resnet', bench_args))
 
-    if args.print_json:
-        print(json.dumps(results))
+    if args.print_json == 'oss':
+        print_json_oss_format(results)
+    elif args.print_json == 'pep':
+        print_json_pep_format(args.nloops, results)