Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / include / lstm_dynamic_input_inst.h
1 /*
2 // Copyright (c) 2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
#include "api_extension/lstm_dynamic_input.hpp"
#include "primitive_inst.h"
#include "error_handler.h"
#include <memory>
#include <string>
#include <utility>
24
25 namespace cldnn {
26
27 template <>
28 struct typed_program_node<lstm_dynamic_input> : public typed_program_node_base<lstm_dynamic_input> {
29     using parent = typed_program_node_base<lstm_dynamic_input>;
30
31 public:
32     typed_program_node(std::shared_ptr<primitive> prim, program_impl& prog) : parent(prim, prog) {}
33
34     program_node& input() const { return get_dependency(0); }
35     program_node& dyn_length() const { return get_dependency(1); }
36     program_node& weights() const { return get_dependency(2); }
37
38     program_node& bias() const {
39         CLDNN_ERROR_BOOL(id(), "Bias term", !bias_term(), "Trying to get non existing bias.");
40         return get_dependency(3);
41     }
42
43     int32_t direction() const { return weights().get_output_layout().size.feature[0]; }
44     bool dyn_length_term() const { return !get_primitive()->dyn_length.empty(); }
45     bool bias_term() const { return !get_primitive()->bias.empty(); }
46     bool weights_term() const { return !get_primitive()->weights.empty(); }
47 };
48
49 using lstm_dynamic_input_node = typed_program_node<lstm_dynamic_input>;
50
51 template <>
52 class typed_primitive_inst<lstm_dynamic_input> : public typed_primitive_inst_base<lstm_dynamic_input> {
53     using parent = typed_primitive_inst_base<lstm_dynamic_input>;
54
55 public:
56     static layout calc_output_layout(lstm_dynamic_input_node const& node);
57     static std::string to_string(lstm_dynamic_input_node const& node);
58
59 public:
60     typed_primitive_inst(network_impl& network, lstm_dynamic_input_node const& node);
61
62     memory_impl& dyn_length_memory() const { return dep_memory(1); }
63     memory_impl& weights_memory() const { return dep_memory(2); }
64     memory_impl& bias_memory() const {
65         CLDNN_ERROR_BOOL(id(), "Bias term", !bias_term(), "Trying to get non existing bias memory.");
66         return dep_memory(3);
67     }
68     int32_t direction() const { return node.direction(); }
69     bool bias_term() const { return node.bias_term(); }
70 };
71
72 using lstm_dynamic_input_inst = typed_primitive_inst<lstm_dynamic_input>;
73
74 }  // namespace cldnn