1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
21 #include "ngraph/attribute_visitor.hpp"
22 #include "ngraph/op/add.hpp"
23 #include "ngraph/op/clamp.hpp"
24 #include "ngraph/op/multiply.hpp"
25 #include "ngraph/op/subtract.hpp"
26 #include "ngraph/op/util/rnn_cell_base.hpp"
27 #include "ngraph/util.hpp"
30 using namespace ngraph;
// Return a copy of `vs` in which every string has been lower-cased.
//
// The input vector is NOT modified; a new vector is returned by value.
// Characters are mapped with the classic ("C") locale so the result does not
// depend on the process-wide locale (locale-specific case rules, such as the
// Turkish dotted/dotless i, could otherwise alter ASCII activation names).
static std::vector<std::string> to_lower_case(const std::vector<std::string>& vs)
{
    const std::locale& classic = std::locale::classic();
    std::vector<std::string> res(vs);
    for (std::string& s : res)
    {
        std::transform(s.begin(), s.end(), s.begin(), [&classic](char c) {
            return std::tolower(c, classic);
        });
    }
    return res;
}
// Default constructor for RNNCellBase.
// NOTE(review): this constructor's initializer list and body are not visible
// here. Presumably it is used when the attributes are filled in afterwards
// through visit_attributes() — TODO confirm. Also confirm that m_hidden_size
// and m_clip receive defined values, since clip() reads m_clip.
40 op::util::RNNCellBase::RNNCellBase()
// Construct the state shared by all RNN cell ops.
//
// hidden_size       - stored as-is in m_hidden_size.
// activations       - activation function names; stored lower-cased (via
//                     to_lower_case) in m_activations.
// activations_alpha - optional per-activation alpha values; applied by
//                     get_activation_function() only when a value exists for
//                     the requested index.
// activations_beta  - optional per-activation beta values, handled the same way.
// NOTE(review): m_clip is read by clip() and registered in visit_attributes(),
// but no clip parameter or m_clip initializer is visible here. Presumably a
// clip parameter follows hidden_size — confirm against the class declaration.
46 op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
48 const vector<string>& activations,
49 const vector<float>& activations_alpha,
50 const vector<float>& activations_beta)
51 : m_hidden_size(hidden_size)
53 , m_activations(to_lower_case(activations))
54 , m_activations_alpha(activations_alpha)
55 , m_activations_beta(activations_beta)
59 bool ngraph::op::util::RNNCellBase::visit_attributes(AttributeVisitor& visitor)
61 visitor.on_attribute("hidden_size", m_hidden_size);
62 visitor.on_attribute("activations", m_activations);
63 visitor.on_attribute("activations_alpha", m_activations_alpha);
64 visitor.on_attribute("activations_beta", m_activations_beta);
65 visitor.on_attribute("clip", m_clip);
69 op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size_t idx) const
71 // Normalize activation function case.
72 std::string func_name = m_activations.at(idx);
74 std::transform(func_name.begin(), func_name.end(), func_name.begin(), [&loc](char c) {
75 return std::tolower(c, loc);
78 op::util::ActivationFunction afunc = get_activation_func_by_name(func_name);
80 // Set activation functions parameters (if any)
81 if (m_activations_alpha.size() > idx)
83 afunc.set_alpha(m_activations_alpha.at(idx));
85 if (m_activations_beta.size() > idx)
87 afunc.set_beta(m_activations_beta.at(idx));
93 shared_ptr<Node> op::util::RNNCellBase::add(const Output<Node>& lhs, const Output<Node>& rhs)
95 return {make_shared<op::v1::Add>(lhs, rhs)};
98 shared_ptr<Node> op::util::RNNCellBase::sub(const Output<Node>& lhs, const Output<Node>& rhs)
100 return {make_shared<op::v1::Subtract>(lhs, rhs)};
103 shared_ptr<Node> op::util::RNNCellBase::mul(const Output<Node>& lhs, const Output<Node>& rhs)
105 return {make_shared<op::v1::Multiply>(lhs, rhs)};
// Either return the node producing `data` unchanged, or wrap `data` in an
// op::Clamp node with bounds -m_clip and m_clip.
// NOTE(review): the condition choosing between the two returns below is not
// visible here. Presumably it detects a "clipping disabled" value such as
// m_clip == 0 — confirm before relying on it.
// NOTE(review): the pass-through path returns a shared_ptr<Node>, which carries
// no output index. If `data` is not output 0 of a multi-output node, callers
// may end up using a different output — verify.
108 shared_ptr<Node> op::util::RNNCellBase::clip(const Output<Node>& data) const
112 return data.get_node_shared_ptr();
115 return make_shared<op::Clamp>(data, -m_clip, m_clip);