Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / rnn / bench_rnn.cpp
1 /*******************************************************************************
2  * Copyright 2018 Intel Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *******************************************************************************/
16
17 #include <float.h>
18 #include <math.h>
19 #include <stdio.h>
20 #include <stdlib.h>
21 #include <string.h>
22
23 #include "mkldnn.h"
24
25 #include "mkldnn_common.hpp"
26 #include "mkldnn_debug.hpp"
27 #include "mkldnn_memory.hpp"
28
29 #include "rnn/input_rnn.hpp"
30 #include "rnn/rnn.hpp"
31
32 namespace rnn {
33
34 int bench(int argc, char **argv) {
35     // !!?? TODO: check consistence of direction, dir ...
36     mkldnn_prop_kind_t direction = mkldnn_forward;
37     dir_t dir = FWD_D;
38     for (int arg = 0; arg < argc; ++arg) {
39         if (!strncmp("--dir=", argv[arg], 6)) {
40             dir = str2dir(argv[arg] + 6);
41             if (dir == FWD_D)
42                 direction = mkldnn_forward;
43             else if (dir == BWD_D)
44                 direction = mkldnn_backward;
45             else
46                 assert("unknown dir");
47         }
48     }
49     const int num_r = sizeof(rnns) / sizeof(rnns[0]);
50
51     for (int r = 0; r < num_r; ++r) {
52         const rnn_prb_t p(rnns[r], conf_f32, direction);
53         check(&p);
54     }
55
56     return OK;
57 }
58
59 void check(const rnn_prb_t *p) {
60     res_t res{};
61     char pstr[max_prb_len];
62     prb2str(p, &res, pstr);
63
64     int status = rnn::doit(p, &res);
65
66     prb2str(p, &res, pstr);
67     bool want_perf_report = false;
68
69     parse_result(res, want_perf_report, false, status, pstr);
70
71     if (bench_mode & PERF)
72         perf_report(p, &res, pstr);
73
74     benchdnn_stat.tests++;
75 }
76
77 } // namespace rnn