ff55de1993f61a40c554f5fe3c4ed9c1c97ee9b0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / exec / feature / nchw / View.h
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_EXEC_FEATURE_NCHW_VIEW_H__
18 #define __ONERT_EXEC_FEATURE_NCHW_VIEW_H__
19
20 #include "../Reader.h"
21
22 #include "backend/ITensor.h"
23 #include "ir/Shape.h"
24 #include "util/logging.h"
25
26 #include <cassert>
27
28 namespace onert
29 {
30 namespace exec
31 {
32 namespace feature
33 {
34 namespace nchw
35 {
36
37 template <typename T> class View final : public feature::Reader<T>
38 {
39 public:
40   // Construct for buffer of model inputs
41   View(const ir::FeatureShape &shape, T *ptr, size_t len)
42       : _shape{shape}, _ptr{reinterpret_cast<uint8_t *>(ptr)}, _len{len}
43   {
44     assert(shape.N * shape.C * shape.H * shape.W * sizeof(T) == len);
45
46     _strides.W = sizeof(T);
47     _strides.H = shape.W * sizeof(T);
48     _strides.C = shape.W * shape.H * sizeof(T);
49     _strides.N = shape.W * shape.H * shape.C * sizeof(T);
50   }
51
52   // Construct for backend tensor
53   View(::onert::backend::ITensor *tensor)
54       : _ptr{tensor->buffer() + tensor->calcOffset({0, 0, 0, 0})}, _len{tensor->total_size()}
55   {
56     assert(tensor->layout() == ir::Layout::NCHW);
57
58     const auto start_offset = tensor->calcOffset({0, 0, 0, 0});
59     _strides.W = tensor->dimension(3) == 1 ? 0 : tensor->calcOffset({0, 0, 0, 1}) - start_offset;
60     _strides.H = tensor->dimension(2) == 1 ? 0 : tensor->calcOffset({0, 0, 1, 0}) - start_offset;
61     _strides.C = tensor->dimension(1) == 1 ? 0 : tensor->calcOffset({0, 1, 0, 0}) - start_offset;
62     _strides.N = tensor->dimension(0) == 1 ? 0 : tensor->calcOffset({1, 0, 0, 0}) - start_offset;
63
64     _shape.W = tensor->dimension(3);
65     _shape.H = tensor->dimension(2);
66     _shape.C = tensor->dimension(1);
67     _shape.N = tensor->dimension(0);
68   }
69
70 public:
71   T at(uint32_t ch, uint32_t row, uint32_t col) const override
72   {
73     const auto offset = feature_index_to_byte_offset(0, ch, row, col);
74
75     T *ptr = reinterpret_cast<T *>(_ptr + offset);
76
77     return *ptr;
78   }
79   T at(uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) const override
80   {
81     const auto offset = feature_index_to_byte_offset(batch, ch, row, col);
82
83     T *ptr = reinterpret_cast<T *>(_ptr + offset);
84
85     return *ptr;
86   }
87
88 public:
89   T &at(uint32_t ch, uint32_t row, uint32_t col)
90   {
91     const auto offset = feature_index_to_byte_offset(0, ch, row, col);
92
93     T *ptr = reinterpret_cast<T *>(_ptr + offset);
94
95     return *ptr;
96   }
97   T &at(uint32_t batch, uint32_t ch, uint32_t row, uint32_t col)
98   {
99     const auto offset = feature_index_to_byte_offset(batch, ch, row, col);
100
101     T *ptr = reinterpret_cast<T *>(_ptr + offset);
102
103     return *ptr;
104   }
105
106 private:
107   size_t feature_index_to_byte_offset(uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) const
108   {
109     assert(1u * _shape.N > batch); // shape.N > batch
110     assert(1u * _shape.C > ch);    // shape.C > ch
111     assert(1u * _shape.H > row);   // shape.H > row
112     assert(1u * _shape.W > col);   // shape.W > col
113
114     uint32_t res = 0;
115     res += batch * _strides.N;
116     res += ch * _strides.C;
117     res += row * _strides.H;
118     res += col * _strides.W;
119
120     return res;
121   }
122
123 private:
124   // TODO Remove _shape
125   ir::FeatureShape _shape;
126   using Strides = ir::FeatureShape;
127   Strides _strides;
128   uint8_t *_ptr;
129   size_t _len;
130 };
131
132 } // namespace nchw
133 } // namespace feature
134 } // namespace exec
135 } // namespace onert
136
137 #endif // __ONERT_EXEC_FEATURE_NCHW_VIEW_H__