Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / rnn / bench_rnn.cpp
1 /*******************************************************************************
2  * Copyright 2018 Intel Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *******************************************************************************/
16
17 #include <float.h>
18 #include <math.h>
19 #include <stdio.h>
20 #include <stdlib.h>
21 #include <string.h>
22
23 #include "mkldnn.h"
24
25 #include "mkldnn_common.hpp"
26 #include "mkldnn_debug.hpp"
27 #include "mkldnn_memory.hpp"
28
29 #include "rnn/rnn.hpp"
30
31 namespace rnn {
32
33 /* global driver parameters */
34 mkldnn_prop_kind_t prop = mkldnn_forward;
35 alg_t alg = VANILLA_RNN;
36 mkldnn_rnn_direction_t direction = mkldnn_unidirectional_left2right;
37 activation_t activation = RELU;
38 const char *perf_template = "perf,%n,%d,,,%-t,,%0t,";
39 const dt_conf_t *cfg = conf_f32;
40 policy_t scale_policy = NONE;
41 attr_t attr;
42 bool allow_unimpl = false;
43 int mb = 0;
44
45 void reset_parameters() {
46     cfg = conf_f32;
47     attr = attr_t();
48     prop = mkldnn_forward;
49     alg = VANILLA_RNN;
50     direction = mkldnn_unidirectional_left2right;
51     activation = RELU;
52     scale_policy = NONE;
53     allow_unimpl = false;
54     mb = 0;
55 }
56
57 int bench(int argc, char **argv, bool main_bench) {
58     for (int arg = 0; arg < argc; ++arg) {
59         if (!strncmp("--batch=", argv[arg], 8))
60             SAFE(batch(argv[arg] + 8, bench), CRIT);
61         else if (!strncmp("--prop=", argv[arg], 7)) {
62             dir_t dir = str2dir(argv[arg] + 7);
63             if (dir == FWD_D)
64                 prop = mkldnn_forward;
65             else if (dir == BWD_DW)
66                 prop = mkldnn_backward;
67             else
68                 assert("unknown dir");
69         } else if (!strncmp("--alg=", argv[arg], 6))
70             alg = str2alg(argv[arg] + 6);
71         else if (!strncmp("--cfg=", argv[arg], 6))
72             cfg = str2cfg(argv[arg] + 6);
73         else if (!strncmp("--attr=", argv[arg], 7))
74             SAFE(str2attr(&attr, argv[arg] + 7), CRIT);
75         else if (!strncmp("--direction=", argv[arg], 12))
76             direction = str2direction(argv[arg] + 12);
77         else if (!strncmp("--activation=", argv[arg], 13))
78             activation = str2activation(argv[arg] + 13);
79         else if (!strncmp("--allow-unimpl=", argv[arg], 15))
80             allow_unimpl = str2bool(argv[arg] + 15);
81         else if (!strncmp("--scaling=", argv[arg], 10))
82             scale_policy = str2policy(argv[arg] + 10);
83         else if (!strncmp("--reset", argv[arg], 7))
84             reset_parameters();
85         else if (!strncmp("--perf-template=", argv[arg], 16))
86             perf_template = argv[arg] + 16;
87         else if (!strncmp("--mb=", argv[arg], 5))
88             mb = atoi(argv[arg] + 5);
89         else if (!strncmp("-v", argv[arg], 2))
90             verbose = atoi(argv[arg] + 2);
91         else if (!strncmp("--verbose=", argv[arg], 10))
92             verbose = atoi(argv[arg] + 10);
93         else {
94             rnn_desc_t d;
95             if (str2desc(&d, argv[arg]) == FAIL) {
96                 fprintf(stderr, "driver: unknown option: `%s`, exiting...\n",
97                         argv[arg]);
98                 exit(2);
99             }
100             if (cfg != conf_f32 && alg != VANILLA_LSTM) {
101                 fprintf(stderr,
102                         "driver: configuration ``%s` is supported for LSTM "
103                         "cell only, exiting...\n",
104                         cfg2str(cfg));
105                 exit(2);
106             }
107             if (cfg != conf_f32 && scale_policy == NONE) {
108                 fprintf(stderr,
109                         "driver: configuration ``%s` requires scale policy to "
110                         "be COMMON or PER_OC, exiting...\n",
111                         cfg2str(cfg));
112                 exit(2);
113             }
114             check(&d);
115         }
116     }
117     return OK;
118 }
119
120 void check(rnn_desc_t *d) {
121     const rnn_prb_t p(*d, cfg, prop, alg, direction, activation, attr,
122         scale_policy, mb);
123     res_t res{};
124     char pstr[max_prb_len];
125
126     int status = rnn::doit(&p, &res);
127
128     prb2str(&p, &res, pstr);
129     bool want_perf_report = false;
130
131     parse_result(res, want_perf_report, allow_unimpl, status, pstr);
132
133     if (bench_mode & PERF)
134         perf_report(&p, &res, pstr);
135
136     benchdnn_stat.tests++;
137 }
138
139 } // namespace rnn