Imported Upstream version 1.4.0
[platform/core/ml/nnfw.git] / runtime / contrib / pure_arm_compute / src / internal / op / Lstm.h
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * @file    Lstm.h
19  * @ingroup COM_AI_RUNTIME
20  * @brief   This file defines internal::tflite::op::LSTM::Param struct
21  *          and internal::tflite::op::LSTM::Node class
22  */
23 #ifndef __INTERNAL_OP_LSTM_H__
24 #define __INTERNAL_OP_LSTM_H__
25
26 #include "internal/op/Node.h"
27
28 #include <cstdint>
29
30 namespace internal
31 {
32 namespace tflite
33 {
34 namespace op
35 {
36 namespace LSTM
37 {
38
39 /**
40  * @brief Struct to have indexes for operation parameter
41  */
42 struct Param
43 {
44   int32_t scratch_buffer_index;   /**< Index of scartch buffer */
45   int32_t output_state_out_index; /**< Index of output state out */
46   int32_t cell_state_out_index;   /**< Index of cell state out */
47   int32_t output_index;           /**< Index of output */
48
49   int32_t input_index;                       /**< Index of input */
50   int32_t input_to_input_weights_index;      /**< Index of input to input weights */
51   int32_t input_to_forget_weights_index;     /**< Index of input to forget weights */
52   int32_t input_to_cell_weights_index;       /**< Index of input to cell weights */
53   int32_t input_to_output_weights_index;     /**< Index of input to output weights */
54   int32_t recurrent_to_input_weights_index;  /**< Index of recurrent to input weights */
55   int32_t recurrent_to_forget_weights_index; /**< Index of recurrent to forget weights */
56   int32_t recurrent_to_cell_weights_index;   /**< Index of recurrent to cell weights */
57   int32_t recurrent_to_output_weights_index; /**< Index of recurrent to output weights */
58   int32_t cell_to_input_weights_index;       /**< Index of cell to input weights */
59   int32_t cell_to_forget_weights_index;      /**< Index of cell to forget weights */
60   int32_t cell_to_output_weights_index;      /**< Index of cell to output weights */
61   int32_t input_gate_bias_index;             /**< Index of input gate bias */
62   int32_t forget_gate_bias_index;            /**< Index of forget gate bias */
63   int32_t cell_bias_index;                   /**< Index of cell bias */
64   int32_t output_gate_bias_index;            /**< Index of output gate bias */
65   int32_t projection_weights_index;          /**< Index of projection weights */
66   int32_t projection_bias_index;             /**< Index of projection bias */
67   int32_t output_state_in_index;             /**< Index of output state in */
68   int32_t cell_state_in_index;               /**< Index of cell state in */
69   int32_t activation_index;                  /**< Index of activation */
70   int32_t cell_threshold_index;              /**< Index of cell threshold */
71   int32_t projection_threshold_index;        /**< Index of projection threshold */
72
73   /**
74    * @brief Construct as default
75    */
76   Param() = default;
77   /**
78    * @brief     Construct a new Param object with params
79    * @param[in] inputCount  Count of inputs
80    * @param[in] inputs      Pointer of inputs
81    * @param[in] outputCount Count of outputs
82    * @param[in] outputs     Pointer of outputs
83    */
84   Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs);
85 };
86
87 /**
88  * @brief Class to represent an operation of data structure
89  */
90 class Node final : public op::Node
91 {
92 public:
93   /**
94    * @brief     Construct a new Node object with param
95    * @param[in] param Param object that makes up a Node
96    */
97   Node(const Param &param) : _param(param)
98   {
99     // DO NOTHING
100   }
101
102 public:
103   /**
104    * @brief Destruct as default
105    */
106   virtual ~Node() = default;
107
108 public:
109   /**
110    * @brief  Get a reference of Param object
111    * @return Reference of Param object
112    */
113   const Param &param(void) const { return _param; }
114
115 public:
116   /**
117    * @brief  Visit this Node by NodeVisitor
118    * @return N/A
119    */
120   void accept(NodeVisitor &&) const override;
121
122 private:
123   const Param _param;
124 };
125
126 } // namespace LSTM
127 } // namespace op
128 } // namespace tflite
129 } // namespace internal
130
131 #endif // __INTERNAL_OP_LSTM_H__