fc180998a864bf45b23fb8f0deed1db3adba0f01
[platform/upstream/dldt.git] / ngraph / core / src / op / util / rnn_cell_base.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <algorithm>
18 #include <iterator>
19 #include <locale>
20
21 #include "ngraph/attribute_visitor.hpp"
22 #include "ngraph/op/add.hpp"
23 #include "ngraph/op/clamp.hpp"
24 #include "ngraph/op/multiply.hpp"
25 #include "ngraph/op/subtract.hpp"
26 #include "ngraph/op/util/rnn_cell_base.hpp"
27 #include "ngraph/util.hpp"
28
29 using namespace std;
30 using namespace ngraph;
31
32 // Modify input vector in-place and return reference to modified vector.
33 static vector<string> to_lower_case(const vector<string>& vs)
34 {
35     vector<string> res(vs);
36     transform(begin(res), end(res), begin(res), [](string& s) { return to_lower(s); });
37     return res;
38 }
39
40 op::util::RNNCellBase::RNNCellBase()
41     : m_clip(0.f)
42     , m_hidden_size(0)
43 {
44 }
45
46 op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
47                                    float clip,
48                                    const vector<string>& activations,
49                                    const vector<float>& activations_alpha,
50                                    const vector<float>& activations_beta)
51     : m_hidden_size(hidden_size)
52     , m_clip(clip)
53     , m_activations(to_lower_case(activations))
54     , m_activations_alpha(activations_alpha)
55     , m_activations_beta(activations_beta)
56 {
57 }
58
59 bool ngraph::op::util::RNNCellBase::visit_attributes(AttributeVisitor& visitor)
60 {
61     visitor.on_attribute("hidden_size", m_hidden_size);
62     visitor.on_attribute("activations", m_activations);
63     visitor.on_attribute("activations_alpha", m_activations_alpha);
64     visitor.on_attribute("activations_beta", m_activations_beta);
65     visitor.on_attribute("clip", m_clip);
66     return true;
67 }
68
69 op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size_t idx) const
70 {
71     // Normalize activation function case.
72     std::string func_name = m_activations.at(idx);
73     std::locale loc;
74     std::transform(func_name.begin(), func_name.end(), func_name.begin(), [&loc](char c) {
75         return std::tolower(c, loc);
76     });
77
78     op::util::ActivationFunction afunc = get_activation_func_by_name(func_name);
79
80     // Set activation functions parameters (if any)
81     if (m_activations_alpha.size() > idx)
82     {
83         afunc.set_alpha(m_activations_alpha.at(idx));
84     }
85     if (m_activations_beta.size() > idx)
86     {
87         afunc.set_beta(m_activations_beta.at(idx));
88     }
89
90     return afunc;
91 }
92
93 shared_ptr<Node> op::util::RNNCellBase::add(const Output<Node>& lhs, const Output<Node>& rhs)
94 {
95     return {make_shared<op::v1::Add>(lhs, rhs)};
96 }
97
98 shared_ptr<Node> op::util::RNNCellBase::sub(const Output<Node>& lhs, const Output<Node>& rhs)
99 {
100     return {make_shared<op::v1::Subtract>(lhs, rhs)};
101 }
102
103 shared_ptr<Node> op::util::RNNCellBase::mul(const Output<Node>& lhs, const Output<Node>& rhs)
104 {
105     return {make_shared<op::v1::Multiply>(lhs, rhs)};
106 }
107
108 shared_ptr<Node> op::util::RNNCellBase::clip(const Output<Node>& data) const
109 {
110     if (m_clip == 0.f)
111     {
112         return data.get_node_shared_ptr();
113     }
114
115     return make_shared<op::Clamp>(data, -m_clip, m_clip);
116 }