[ Mixed Tensor ] Enable FP32 unittest cases
[platform/core/ml/nntrainer.git] / nntrainer / tensor / memory_data.h
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2022 Jiho Chu <jiho.chu@samsung.com>
4  *
5  * @file   memory_data.h
6  * @date   14 Oct 2022
7  * @see    https://github.com/nnstreamer/nntrainer
8  * @author Jiho Chu <jiho.chu@samsung.com>
9  * @bug    No known bugs except for NYI items
10  * @brief  MemoryData class
11  *
12  */
13
14 #ifndef __MEMORY_DATA_H__
15 #define __MEMORY_DATA_H__
16
17 #include <functional>
18
19 namespace nntrainer {
20
21 using MemoryDataValidateCallback = std::function<void(unsigned int)>;
22
23 /**
24  * @brief  MemoryData Class
25  */
26 class MemoryData {
27 public:
28   /**
29    * @brief  Constructor of Memory Data
30    * @param[in] addr Memory data
31    */
32   explicit MemoryData(void *addr) :
33     valid(true),
34     id(0),
35     address(addr),
36     validate_cb([](unsigned int) {}),
37     invalidate_cb([](unsigned int) {}) {}
38
39   /**
40    * @brief  Constructor of Memory Data
41    * @param[in] mem_id validate callback.
42    * @param[in] v_cb validate callback.
43    * @param[in] i_cb invalidate callback.
44    */
45   explicit MemoryData(unsigned int mem_id, MemoryDataValidateCallback v_cb,
46                       MemoryDataValidateCallback i_cb) :
47     valid(false),
48     id(mem_id),
49     address(nullptr),
50     validate_cb(v_cb),
51     invalidate_cb(i_cb) {}
52
53   /**
54    * @brief  Deleted constructor of Memory Data
55    */
56   explicit MemoryData() = delete;
57
58   /**
59    * @brief  Constructor of MemoryData
60    */
61   explicit MemoryData(MemoryDataValidateCallback v_cb,
62                       MemoryDataValidateCallback i_cb) = delete;
63   /**
64    * @brief  Constructor of MemoryData
65    */
66   explicit MemoryData(void *addr, MemoryDataValidateCallback v_cb,
67                       MemoryDataValidateCallback i_cb) = delete;
68
69   /**
70    * @brief  Destructor of Memory Data
71    */
72   virtual ~MemoryData() = default;
73
74   /**
75    * @brief  Set address
76    */
77   void setAddr(void *addr) { address = addr; }
78
79   /**
80    * @brief  Get address
81    */
82   template <typename T = float> T *getAddr() const {
83     return static_cast<T *>(address);
84   }
85
86   /**
87    * @brief  Validate memory data
88    */
89   void validate() {
90     if (valid)
91       return;
92     validate_cb(id);
93   }
94
95   /**
96    * @brief  Invalidate memory data
97    */
98   void invalidate() {
99     if (!valid)
100       return;
101     invalidate_cb(id);
102   }
103
104   /**
105    * @brief  Set valid
106    */
107   void setValid(bool v) { valid = v; }
108
109 private:
110   bool valid;
111   unsigned int id;
112   void *address;
113   MemoryDataValidateCallback validate_cb;
114   MemoryDataValidateCallback invalidate_cb;
115 };
116
117 } // namespace nntrainer
118
119 #endif /* __MEMORY_DATA_H__ */