d649f7e74df0319eb42883823775848666b23543
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / mvn.hpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "../C/mvn.h"
20 #include "primitive.hpp"
21
22 namespace cldnn {
23 /// @addtogroup cpp_api C++ API
24 /// @{
25 /// @addtogroup cpp_topology Network Topology
26 /// @{
27 /// @addtogroup cpp_primitives Primitives
28 /// @{
29
30 /// @brief Mean Variance Normalization primitive.
31 /// @details Normalizes the input to have 0-mean and/or unit (1) variance.
32 struct mvn : public primitive_base<mvn, CLDNN_PRIMITIVE_DESC(mvn)> {
33     CLDNN_DECLARE_PRIMITIVE(mvn)
34
35     /// @brief Constructs mvn primitive.
36     /// @param id This primitive id.
37     /// @param input Input primitive id.
38     /// @param across_channels Determines if the normalization is done across or within channels. Default is within channels.'
39     /// @param normalize_variance Determines if normalize variance is applied. Default is true.
40     /// @param epsilon Epsilon for not dividing by zero while normalizing.
41     mvn(const primitive_id& id,
42         const primitive_id& input,
43         const bool across_channels = false,
44         const bool normalize_variance = true,
45         const float epsilon = 1e-10f,
46         const padding& output_padding = padding())
47         : primitive_base(id, {input}, output_padding),
48           across_channels(across_channels),
49           normalize_variance(normalize_variance),
50           epsilon(epsilon) {}
51
52     /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{mvn}
53     mvn(const dto* dto)
54         : primitive_base(dto),
55           across_channels(dto->across_channels != 0),
56           normalize_variance(dto->normalize_variance != 0),
57           epsilon(dto->epsilon) {}
58
59     /// @brief Determines if the normalization is done across or within channels.
60     bool across_channels;
61     /// @brief Determines if normalize variance is applied.
62     bool normalize_variance;
63     /// @brief Epsilon for not dividing by zero while normalizing.
64     float epsilon;
65
66 protected:
67     void update_dto(dto& dto) const override {
68         dto.across_channels = across_channels;
69         dto.normalize_variance = normalize_variance;
70         dto.epsilon = epsilon;
71     }
72 };
73 /// @}
74 /// @}
75 /// @}
76 }  // namespace cldnn