Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / frontend / onnx_import / src / op / gru.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <string>
18 #include <vector>
19
20 #include "gru.hpp"
21 #include "ngraph/builder/split.hpp"
22 #include "ngraph/shape.hpp"
23 #include "onnx_import/core/null_node.hpp"
24 #include "onnx_import/default_opset.hpp"
25 #include "onnx_import/utils/recurrent.hpp"
26
27 namespace ngraph
28 {
29     namespace onnx_import
30     {
31         namespace op
32         {
33             namespace set_1
34             {
35                 namespace
36                 {
37                     struct GRUInputMap : public recurrent::OpInputMap
38                     {
39                         GRUInputMap(const Node& node, std::size_t gates_count)
40                             : OpInputMap(node, gates_count)
41                         {
42                             bool linear_before_reset = static_cast<bool>(
43                                 node.get_attribute_value<std::int64_t>("linear_before_reset", 0));
44
45                             // Override bias, since we need separated W and R biases for `h` gate.
46                             if (linear_before_reset)
47                             {
48                                 const auto& ng_inputs = node.get_ng_inputs();
49                                 const auto el_type = ng_inputs.at(0).get_element_type();
50
51                                 if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
52                                 {
53                                     NGRAPH_SUPPRESS_DEPRECATED_START
54
55                                     auto bias = ng_inputs.at(3);
56                                     // gates_count * 2 since B is: [Wb, Rb]
57                                     const int split_parts = 2 * 3;
58                                     const auto split_bias =
59                                         builder::opset1::split(bias, split_parts, 1);
60                                     const auto wr_z_bias = split_bias.at(0) + split_bias.at(3);
61                                     const auto wr_r_bias = split_bias.at(1) + split_bias.at(4);
62                                     // The result has shape: [num_directions, 4 * hidden_size]
63                                     // and data layout:
64                                     //       [
65                                     //          [Wb_z + Rb_z],
66                                     //          [Wb_r + Rb_r],
67                                     //          [Wb_h],
68                                     //          [Rb_h],
69                                     //          // num_directions times
70                                     //       ]
71                                     m_map[recurrent::OpInput::B] =
72                                         std::make_shared<default_opset::Concat>(
73                                             OutputVector{wr_z_bias,
74                                                          wr_r_bias,
75                                                          split_bias.at(2),
76                                                          split_bias.at(5)},
77                                             1);
78                                     NGRAPH_SUPPRESS_DEPRECATED_END
79                                 }
80                                 else
81                                 {
82                                     const std::size_t hidden_size =
83                                         m_map[recurrent::OpInput::R].get_shape().back();
84                                     const std::size_t num_directions =
85                                         m_map[recurrent::OpInput::W].get_shape().front();
86
87                                     m_map[recurrent::OpInput::B] =
88                                         std::make_shared<default_opset::Constant>(
89                                             el_type,
90                                             Shape{num_directions, (gates_count + 1) * hidden_size},
91                                             0.f);
92                                 }
93                             }
94                         }
95
96                         virtual ~GRUInputMap() = default;
97                     };
98
99                     struct GRUAttributes : public recurrent::OpAttributes
100                     {
101                         GRUAttributes(const Node& node)
102                             : OpAttributes(node)
103                             , m_linear_before_reset{static_cast<bool>(
104                                   node.get_attribute_value<std::int64_t>("linear_before_reset", 0))}
105                         {
106                             m_activations = node.get_attribute_value<std::vector<std::string>>(
107                                 "activations", {"sigmoid", "tanh"});
108                         }
109
110                         virtual ~GRUAttributes() = default;
111
112                         bool m_linear_before_reset;
113                     };
114                 }
115
116                 OutputVector gru(const Node& node)
117                 {
118                     constexpr std::size_t gates_count = 3;
119                     GRUInputMap input_map{node, gates_count};
120                     GRUAttributes attributes{node};
121
122                     recurrent::RecurrentSequence sequence_op(input_map, attributes.m_direction);
123                     auto results =
124                         sequence_op.run_sequence([&attributes](const recurrent::OpInputMap& args,
125                                                                const Output<ngraph::Node>& in_Xt,
126                                                                const Output<ngraph::Node> H_t) {
127
128                             const GRUInputMap& gru_args = dynamic_cast<const GRUInputMap&>(args);
129
130                             return std::make_shared<default_opset::GRUCell>(
131                                 in_Xt,
132                                 H_t,
133                                 gru_args.at(recurrent::OpInput::W),
134                                 gru_args.at(recurrent::OpInput::R),
135                                 gru_args.at(recurrent::OpInput::B),
136                                 attributes.m_hidden_size,
137                                 attributes.m_activations,
138                                 attributes.m_activations_alpha,
139                                 attributes.m_activations_beta,
140                                 attributes.m_clip_threshold,
141                                 attributes.m_linear_before_reset);
142                         });
143                     return results;
144                 }
145
146             } // namespace set_1
147
148         } // namespace op
149
150     } // namespace onnx_import
151
152 } // namespace ngraph