Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / rnn / rnn_aux.hpp
1 /*******************************************************************************
2  * Copyright 2018 Intel Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *******************************************************************************/
16
17 #include "rnn/rnn.hpp"
18 #include <assert.h>
19 #include <stdlib.h>
20
21 namespace rnn {
22
23 typedef enum {
24     rnn_forward = 0,
25     rnn_backward,
26 } rnn_propagation_t;
27
28 typedef enum {
29     left2right = 0,
30     right2left,
31 } rnn_iter_direction_t;
32
33 typedef enum {
34     bottom2top = 0,
35     top2bottom,
36 } rnn_layer_direction_t;
37
38 typedef enum { action_copy = 0, action_sum, action_concat } rnn_action_t;
39
40 void init_buffer(float *buf, int size, float value);
41
42 float logistic(float x);
43 float dlogistic(float x);
44 float relu(float x);
45 float drelu(float x);
46 float dtanhf(float x);
47 float one_m_square(float x);
48 float x_m_square(float x);
49
50 int compare_dat(const rnn_prb_t *p, rnn_data_kind_t kind, dnn_mem_t &mem_dt,
51         dnn_mem_t &mem_fp, res_t *r, bool final_compare);
52
53 int compare_input(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
54         res_t *r, bool final_compare);
55 int compare_states(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
56         res_t *r, bool final_compare);
57 int compare_weights_input(const rnn_prb_t *p, dnn_mem_t &mem_dt,
58         dnn_mem_t &mem_fp, res_t *r, bool final_compare);
59 int compare_weights_states(const rnn_prb_t *p, dnn_mem_t &mem_dt,
60         dnn_mem_t &mem_fp, res_t *r, bool final_compare);
61 int compare_bias(const rnn_prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
62         res_t *r, bool final_compare);
63 int compare_dst_last_layer(const rnn_prb_t *p, dnn_mem_t &mem_dt,
64         dnn_mem_t &mem_fp, res_t *r, bool final_compare);
65 int compare_dst_last_iteration(const rnn_prb_t *p, dnn_mem_t &mem_dt,
66         dnn_mem_t &mem_fp, res_t *r, bool final_compare);
67 };