Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / ade / ade / source / memory_descriptor_view.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include "memory/memory_descriptor_view.hpp"
7
8 #include <vector>
9
10 #include <util/algorithm.hpp>
11 #include <util/range.hpp>
12 #include <util/zip_range.hpp>
13
14 #include "memory/memory_descriptor.hpp"
15
16 namespace ade
17 {
18
19 struct MemoryDescriptorView::Connector final
20 {
21     // We use linear searches here because number of listeners and views usually
22     // will be very small
23     struct ListenerDesc final
24     {
25         MemoryDescriptorView* owner;
26         std::vector<IMemoryDescriptorViewListener*> listeners;
27     };
28
29     std::vector<ListenerDesc> listeners;
30
31     struct OwnerComparator final
32     {
33         const MemoryDescriptorView* owner;
34         bool operator()(const ListenerDesc& desc) const
35         {
36             ASSERT(nullptr != owner);
37             ASSERT(nullptr != desc.owner);
38             return owner == desc.owner;
39         }
40     };
41
42     void addListener(MemoryDescriptorView* view, IMemoryDescriptorViewListener* listener)
43     {
44         ASSERT(nullptr != view);
45         ASSERT(nullptr != listener);
46         ASSERT(!contains(view, listener));
47         findDesc(view).listeners.push_back(listener);
48     }
49
50     void removeListener(MemoryDescriptorView* view, IMemoryDescriptorViewListener* listener)
51     {
52         ASSERT(nullptr != view);
53         ASSERT(nullptr != listener);
54         ASSERT(contains(view, listener));
55         auto& desc = findDesc(view);
56         util::unstable_erase(desc.listeners, util::find(desc.listeners, listener));
57         ASSERT(!contains(view, listener));
58     }
59
60     void onDestroy(MemoryDescriptorView* view)
61     {
62         ASSERT(nullptr != view);
63         auto it = util::find_if(listeners, OwnerComparator{view});
64         if (listeners.end() != it)
65         {
66             for (auto& listener: it->listeners)
67             {
68                 ASSERT(nullptr != listener);
69                 listener->destroy();
70             }
71             util::unstable_erase(listeners, it);
72         }
73     }
74
75     bool contains(const MemoryDescriptorView* view, const IMemoryDescriptorViewListener* listener) const
76     {
77         ASSERT(nullptr != view);
78         ASSERT(nullptr != listener);
79         auto it = util::find_if(listeners, OwnerComparator{view});
80         if (listeners.end() == it)
81         {
82             return false;
83         }
84         return it->listeners.end() != util::find(it->listeners, listener);
85     }
86
87     ListenerDesc& findDesc(MemoryDescriptorView* view)
88     {
89         ASSERT(nullptr != view);
90         auto it = util::find_if(listeners, OwnerComparator{view});
91         if(listeners.end() != it)
92         {
93             return *it;
94         }
95         listeners.push_back(ListenerDesc{view, {}});
96         return listeners.back();
97     }
98
99     std::size_t listenersCount() const
100     {
101         std::size_t ret = 0;
102         for (auto& desc: listeners)
103         {
104             ret += desc.listeners.size();
105         }
106         return ret;
107     }
108
109     ~Connector()
110     {
111         ASSERT(0 == listenersCount());
112     }
113 };
114
115
116 MemoryDescriptorView::MemoryDescriptorView()
117 {
118
119 }
120
121 MemoryDescriptorView::MemoryDescriptorView(MemoryDescriptor& descriptor,
122                                            const memory::DynMdSpan& span,
123                                            RetargetableState retargetable):
124     m_parent(&descriptor),
125     m_span(span),
126     m_retargetable(retargetable),
127     m_connector(std::make_shared<Connector>())
128 {
129     checkSpans(descriptor);
130 }
131
132 MemoryDescriptorView::MemoryDescriptorView(MemoryDescriptorView& parent,
133                                            const memory::DynMdSpan& span):
134     m_parent_view(&parent),
135     m_span(span),
136     m_retargetable(parent.retargetableState()),
137     m_connector(parent.m_connector)
138 {
139
140 }
141
142 MemoryDescriptorView::~MemoryDescriptorView()
143 {
144     if (nullptr != m_connector)
145     {
146         m_connector->onDestroy(this);
147     }
148 }
149
150 void MemoryDescriptorView::retarget(MemoryDescriptor& newParent,
151                                     const memory::DynMdSpan& newSpan)
152 {
153     ASSERT(isRetargetable());
154     ASSERT(nullptr != m_parent);
155     ASSERT(nullptr == m_parent_view);
156     ASSERT(nullptr != m_connector);
157     const auto size = m_span.size();
158     ASSERT(newSpan.size() == size);
159     checkSpans(newParent);
160
161     for (auto& desc: m_connector->listeners)
162     {
163         auto owner = desc.owner;
164         ASSERT(nullptr != owner);
165         for (auto listener: desc.listeners)
166         {
167             ASSERT(nullptr != listener);
168             const auto origSpan = owner->span();
169             const auto origin = origSpan.origin();
170             const auto updatedSpan = util::make_span(origin, origSpan.size());
171             if (owner == this)
172             {
173                 ASSERT(updatedSpan == newSpan);
174             }
175             listener->retarget(*m_parent, origSpan, newParent, updatedSpan);
176         }
177     }
178
179     m_span = newSpan;
180     m_parent = &newParent;
181
182     for (auto& desc: m_connector->listeners)
183     {
184         for (auto listener: desc.listeners)
185         {
186             ASSERT(nullptr != listener);
187             listener->retargetComplete();
188         }
189     }
190 }
191
192 MemoryDescriptorView::RetargetableState MemoryDescriptorView::retargetableState() const
193 {
194     return m_retargetable;
195 }
196
197 bool MemoryDescriptorView::isRetargetable() const
198 {
199     return m_retargetable == Retargetable;
200 }
201
202 void MemoryDescriptorView::addListener(IMemoryDescriptorViewListener* listener)
203 {
204     ASSERT(nullptr != listener);
205     ASSERT(nullptr != m_connector);
206     m_connector->addListener(this, listener);
207 }
208
209 void MemoryDescriptorView::removeListener(IMemoryDescriptorViewListener* listener)
210 {
211     ASSERT(nullptr != listener);
212     ASSERT(nullptr != m_connector);
213     m_connector->removeListener(this, listener);
214 }
215
216 memory::DynMdSpan MemoryDescriptorView::span() const
217 {
218     ASSERT(nullptr != *this);
219     if (nullptr != m_parent_view)
220     {
221         return m_span + m_parent_view->span().origin();
222     }
223     return m_span;
224 }
225
226 memory::DynMdSize MemoryDescriptorView::size() const
227 {
228     ASSERT(nullptr != *this);
229     return m_span.size();
230 }
231
232 std::size_t MemoryDescriptorView::elementSize() const
233 {
234     ASSERT(nullptr != getDescriptor());
235     return getDescriptor()->elementSize();
236 }
237
238 MemoryDescriptor* MemoryDescriptorView::getDescriptor()
239 {
240     if (nullptr != m_parent_view)
241     {
242         return m_parent_view->getDescriptor();
243     }
244     return m_parent;
245 }
246
247 const MemoryDescriptor* MemoryDescriptorView::getDescriptor() const
248 {
249     if (nullptr != m_parent_view)
250     {
251         return m_parent_view->getDescriptor();
252     }
253     return m_parent;
254 }
255
256 MemoryDescriptorView* MemoryDescriptorView::getParentView()
257 {
258     return m_parent_view;
259 }
260
261 const MemoryDescriptorView* MemoryDescriptorView::getParentView() const
262 {
263     return m_parent_view;
264 }
265
266 memory::DynMdView<void> MemoryDescriptorView::getExternalView() const
267 {
268     auto parent = getDescriptor();
269     ASSERT(nullptr != parent);
270     auto data = parent->getExternalView();
271     if (nullptr == data)
272     {
273         return nullptr;
274     }
275     return data.slice(span());
276 }
277
278 MemoryDescriptorView::AccessHandle MemoryDescriptorView::access(const memory::DynMdSpan& span, MemoryAccessType accessType)
279 {
280     ASSERT(nullptr != getDescriptor());
281     return getDescriptor()->access(span + this->span().origin(), accessType);
282 }
283
284 void MemoryDescriptorView::commit(MemoryDescriptorView::AccessHandle handle)
285 {
286     ASSERT(nullptr != getDescriptor());
287     getDescriptor()->commit(handle);
288 }
289
290 bool operator==(std::nullptr_t, const MemoryDescriptorView& ref)
291 {
292     return ref.getDescriptor() == nullptr;
293 }
294
295 bool operator==(const MemoryDescriptorView& ref, std::nullptr_t)
296 {
297     return ref.getDescriptor() == nullptr;
298 }
299
300 bool operator!=(std::nullptr_t, const MemoryDescriptorView& ref)
301 {
302     return ref.getDescriptor() != nullptr;
303 }
304
305 bool operator!=(const MemoryDescriptorView& ref, std::nullptr_t)
306 {
307     return ref.getDescriptor() != nullptr;
308 }
309
310 void MemoryDescriptorView::checkSpans(MemoryDescriptor& descriptor) const
311 {
312     ASSERT(descriptor.dimensions().dims_count() == m_span.dims_count());
313     for (auto i: util::iota(m_span.dims_count()))
314     {
315         auto& val = m_span[i];
316         ASSERT(val.begin >= 0);
317         ASSERT(val.end <= descriptor.dimensions()[i]);
318     }
319 }
320
321 void* getViewDataPtr(ade::MemoryDescriptorView& view, std::size_t offset)
322 {
323     ASSERT(nullptr != view);
324     auto data = view.getExternalView().mem;
325     ASSERT(nullptr != data);
326     const auto newSize = data.size - offset;
327     ASSERT(newSize > 0);
328     return data.Slice(offset, newSize).data;
329 }
330
331 void copyFromViewMemory(void* dst, ade::MemoryDescriptorView& view)
332 {
333     ASSERT(nullptr != dst);
334     ASSERT(nullptr != view);
335     copyFromViewMemory(dst, view.getExternalView());
336 }
337
338 void copyToViewMemory(const void* src, ade::MemoryDescriptorView& view)
339 {
340     ASSERT(nullptr != src);
341     ASSERT(nullptr != view);
342     copyToViewMemory(src, view.getExternalView());
343 }
344
345 void copyFromViewMemory(void* dst, ade::memory::DynMdView<void> view)
346 {
347     ASSERT(nullptr != dst);
348     ASSERT(nullptr != view);
349     const auto size = view.sizeInBytes();
350     util::raw_copy(view.mem, util::memory_range(dst, size));
351 }
352
353 void copyToViewMemory(const void* src, ade::memory::DynMdView<void> view)
354 {
355     ASSERT(nullptr != src);
356     ASSERT(nullptr != view);
357     const auto size = view.sizeInBytes();
358     util::raw_copy(util::memory_range(src, size), view.mem);
359 }
360
361 }