1d4c620535f6735299f1376ddf8ea0954376eedb
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / batch_normalization.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include "mkldnn.h"
19
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "utils.hpp"
23
24 using namespace mkldnn::impl;
25 using namespace mkldnn::impl::utils;
26 using namespace mkldnn::impl::status;
27 using namespace mkldnn::impl::prop_kind;
28 using namespace mkldnn::impl::alg_kind;
29 using namespace mkldnn::impl::types;
30
31 namespace {
32 status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc,
33         prop_kind_t prop_kind, const memory_desc_t *data_desc,
34         const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) {
35     bool args_ok = true
36         && !any_null(bnrm_desc, data_desc)
37         && one_of(prop_kind, forward_training, forward_inference,
38                 backward_data, backward)
39         && implication(prop_kind & backward, diff_data_desc != nullptr);
40     if (!args_ok) return invalid_arguments;
41
42     auto bd = batch_normalization_desc_t();
43     bd.primitive_kind = primitive_kind::batch_normalization;
44     bd.prop_kind = prop_kind;
45
46     bd.data_desc = *data_desc;
47     bd.diff_data_desc = zero_md();
48     if ( one_of(bd.prop_kind,backward_data, backward) )
49         bd.diff_data_desc = *diff_data_desc;
50
51     dims_t scaleshift_dims = { 2, data_desc->dims[1] };
52     mkldnn_memory_desc_init(&bd.data_scaleshift_desc, 2, scaleshift_dims,
53             data_desc->data_type, mkldnn_nc);
54     bd.diff_data_scaleshift_desc = zero_md();
55     if (bd.prop_kind == backward) {
56         mkldnn_memory_desc_init(&bd.diff_data_scaleshift_desc, 2,
57                 scaleshift_dims, data_desc->data_type, mkldnn_nc);
58     }
59
60     dims_t stats_dims = { data_desc->dims[1] };
61     mkldnn_memory_desc_init(&bd.mean_desc, 1, stats_dims,
62             data_desc->data_type, mkldnn_x);
63     mkldnn_memory_desc_init(&bd.variance_desc, 1, stats_dims,
64             data_desc->data_type, mkldnn_x);
65
66     bd.batch_norm_epsilon = epsilon;
67
68     unsigned bnorm_flags =
69         mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu;
70     if ((~bnorm_flags & flags) != 0) return invalid_arguments;
71
72     bd.flags = flags;
73
74     bool consistency = true
75         && utils::one_of(bd.data_desc.ndims, 2, 4, 5);
76     if (bd.prop_kind == backward_data)
77         consistency = consistency
78             && utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5)
79             && array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims,
80                     bd.diff_data_desc.ndims);
81     if (!consistency) return invalid_arguments;
82
83     *bnrm_desc = bd;
84     return success;
85 }
86 }
87
88 status_t mkldnn_batch_normalization_forward_desc_init(
89         batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
90         const memory_desc_t *data_desc, float epsilon, unsigned flags) {
91     if (!one_of(prop_kind, forward_training, forward_inference))
92         return invalid_arguments;
93     return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr,
94             epsilon, flags);
95 }
96
97 status_t mkldnn_batch_normalization_backward_desc_init(
98         batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
99         const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc,
100         float epsilon, unsigned flags) {
101     if (!one_of(prop_kind, backward, backward_data))
102         return invalid_arguments;
103     return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc,
104             epsilon, flags);
105 }
106
107 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s