a09961a845328214a431a197cd91154267631d0a
[platform/core/ml/nnfw.git] / runtime / onert / core / src / exec / feature / nhwc / View.h
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_EXEC_FEATURE_NHWC_VIEW_H__
18 #define __ONERT_EXEC_FEATURE_NHWC_VIEW_H__
19
20 #include "../Reader.h"
21
22 #include <cassert>
23 #include <cstddef>
24
25 #include "backend/ITensor.h"
26 #include "ir/Shape.h"
27 #include "util/Utils.h"
28
29 namespace onert
30 {
31 namespace exec
32 {
33 namespace feature
34 {
35 namespace nhwc
36 {
37
38 template <typename T> class View final : public feature::Reader<T>
39 {
40 public:
41   // Construct for buffer of model inputs
42   View(const ir::FeatureShape &shape, T *ptr, size_t len)
43       : _shape{shape}, _ptr{reinterpret_cast<uint8_t *>(ptr)}, _len{len}
44   {
45     UNUSED_RELEASE(len); // Workaround for unused variable in release mode
46     assert(shape.N * shape.H * shape.W * shape.C * sizeof(T) == len);
47
48     // No padding
49     _strides.C = sizeof(T);
50     _strides.W = shape.C * sizeof(T);
51     _strides.H = shape.C * shape.W * sizeof(T);
52     _strides.N = shape.C * shape.W * shape.H * sizeof(T);
53   }
54
55   // Construct for backend tensor
56   View(backend::ITensor *tensor)
57       : _ptr{tensor->buffer() + tensor->calcOffset({0, 0, 0, 0})}, _len{tensor->total_size()}
58   {
59     assert(tensor->layout() == ir::Layout::NHWC);
60
61     const auto start_offset = tensor->calcOffset({0, 0, 0, 0});
62     _strides.C = tensor->dimension(3) == 1 ? 0 : tensor->calcOffset({0, 0, 0, 1}) - start_offset;
63     _strides.W = tensor->dimension(2) == 1 ? 0 : tensor->calcOffset({0, 0, 1, 0}) - start_offset;
64     _strides.H = tensor->dimension(1) == 1 ? 0 : tensor->calcOffset({0, 1, 0, 0}) - start_offset;
65     _strides.N = tensor->dimension(0) == 1 ? 0 : tensor->calcOffset({1, 0, 0, 0}) - start_offset;
66
67     _shape.C = tensor->dimension(3);
68     _shape.W = tensor->dimension(2);
69     _shape.H = tensor->dimension(1);
70     _shape.N = tensor->dimension(0);
71   }
72
73 public:
74   T at(uint32_t row, uint32_t col, uint32_t ch) const override
75   {
76     const auto offset = feature_index_to_byte_offset(0, row, col, ch);
77
78     const T *ptr = reinterpret_cast<const T *>(_ptr + offset);
79
80     return *ptr;
81   }
82   T at(uint32_t batch, uint32_t row, uint32_t col, uint32_t ch) const override
83   {
84     const auto offset = feature_index_to_byte_offset(batch, row, col, ch);
85
86     const T *ptr = reinterpret_cast<const T *>(_ptr + offset);
87
88     return *ptr;
89   }
90
91   T &at(uint32_t row, uint32_t col, uint32_t ch)
92   {
93     const auto offset = feature_index_to_byte_offset(0, row, col, ch);
94
95     T *ptr = reinterpret_cast<T *>(_ptr + offset);
96
97     return *ptr;
98   }
99
100   T &at(uint32_t batch, uint32_t row, uint32_t col, uint32_t ch)
101   {
102     const auto offset = feature_index_to_byte_offset(batch, row, col, ch);
103
104     T *ptr = reinterpret_cast<T *>(_ptr + offset);
105
106     return *ptr;
107   }
108
109 private:
110   size_t feature_index_to_byte_offset(uint32_t batch, uint32_t row, uint32_t col, uint32_t ch) const
111   {
112     assert(1u * _shape.N > batch); // shape.N > batch
113     assert(1u * _shape.H > row);   // shape.H > row
114     assert(1u * _shape.W > col);   // shape.W > col
115     assert(1u * _shape.C > ch);    // shape.C > ch
116
117     uint32_t res = 0;
118     res += batch * _strides.N;
119     res += row * _strides.H;
120     res += col * _strides.W;
121     res += ch * _strides.C;
122
123     return res;
124   }
125
126 private:
127   // TODO Remove _shape
128   ir::FeatureShape _shape;
129   using Strides = ir::FeatureShape;
130   Strides _strides;
131   uint8_t *_ptr;
132   size_t _len;
133 };
134
135 } // namespace nhwc
136 } // namespace feature
137 } // namespace exec
138 } // namespace onert
139
140 #endif // __ONERT_EXEC_FEATURE_NHWC_VIEW_H__