Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / include / backend / basic / train / TrainableTensor.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_BACKEND_BASIC_TRAIN_TRAINABLE_TENSOR_H__
18 #define __ONERT_BACKEND_BASIC_TRAIN_TRAINABLE_TENSOR_H__
19
20 #include "backend/train/ITrainableTensor.h"
21
22 #include "backend/basic/Tensor.h"
23
24 namespace onert
25 {
26 namespace backend
27 {
28 namespace basic
29 {
30 namespace train
31 {
32
33 class TrainableTensor : public backend::train::ITrainableTensor
34 {
35 public:
36   TrainableTensor() = delete;
37   virtual ~TrainableTensor() = default;
38
39 public:
40   TrainableTensor(const ir::OperandInfo &info, const ir::Layout layout)
41     : ITrainableTensor{info}, _tensor{info, layout, nullptr}, _opt_vars{}
42   {
43     // DO NOTHING
44   }
45
46 public:
47   /**
48    * @brief Set the Buffer object. This method is called for static and non-const tensor
49    */
50   void setBuffer(uint8_t *buffer) { _tensor.setBuffer(buffer); }
51
52 public:
53   uint8_t *buffer() const override { return _tensor.buffer(); }
54   /**
55    * @brief Get dimension by index
56    *
57    * @param index Index to get diemension
58    * @return size_t Dimension at index
59    * @note N : dimension(0)
60    *       H : dimension(1)
61    *       W : dimension(2)
62    *       C : dimension(3)
63    */
64   size_t total_size() const override { return _tensor.total_size(); }
65   size_t calcOffset(const ir::Coordinates &coords) const override
66   {
67     return _tensor.calcOffset(coords);
68   }
69   ir::Layout layout() const override { return _tensor.layout(); }
70   ir::DataType data_type() const override { return _tensor.data_type(); }
71   bool is_constant() const override { return _tensor.is_constant(); }
72   bool is_dynamic() const override { return _tensor.is_dynamic(); }
73   ir::Shape getShape() const override { return _tensor.getShape(); };
74   const ir::OperandInfo &get_info() { return _tensor.get_info(); }
75
76 public:
77   std::vector<ITensor *> optVars() override;
78   void appendOptVar(std::unique_ptr<Tensor> opt_var) { _opt_vars.emplace_back(std::move(opt_var)); }
79
80 public:
81   void fillBuffer(const std::shared_ptr<ir::Data> &data);
82
83 private:
84   using ITensor::setShape;
85   using ITensor::set_dynamic;
86   using ITensor::applyShape;
87
88 protected:
89   Tensor _tensor;
90   std::vector<std::unique_ptr<Tensor>> _opt_vars; //< Optimizer variables
91 };
92
93 } // namespace train
94 } // namespace basic
95 } // namespace backend
96 } // namespace onert
97
98 #endif // __ONERT_BACKEND_BASIC_TRAIN_TRAINABLE_TENSOR_H__