/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "coco/IR/Data.h"
19 #include <nncc/core/ADT/kernel/NCHWLayout.h>
20 #include <nncc/core/ADT/kernel/Overlay.h>
25 using namespace nncc::core::ADT;
27 using std::make_unique;
34 void allocate(const coco::Bag *b, uint32_t elemsize)
36 auto buffer = make_unique<std::vector<uint8_t>>();
37 buffer->resize(b->size() * elemsize);
39 _data[b] = std::move(buffer);
42 void release(const coco::Bag *b) { _data.erase(b); }
45 uint8_t *at(const coco::Bag *b)
47 auto it = _data.find(b);
49 if (it != _data.end())
51 return it->second->data();
58 uint32_t size(const coco::Bag *b) const
60 auto it = _data.find(b);
62 if (it != _data.end())
64 return it->second->size();
71 std::map<const coco::Bag *, std::unique_ptr<std::vector<uint8_t>>> _data;
78 template <typename T> class KernelOverlay : public kernel::Reader<T>, public kernel::Accessor<T>
81 KernelOverlay(T *base, const coco::KernelObject *object) : _base{base}, _object{object}
87 T at(uint32_t nth, uint32_t ch, uint32_t row, uint32_t col) const override
89 assert(_object->layout() != nullptr);
90 auto offset = _object->layout()->at(nth, ch, row, col);
91 return *(_base + offset.value());
95 T &at(uint32_t nth, uint32_t ch, uint32_t row, uint32_t col) override
97 assert(_object->layout() != nullptr);
98 auto offset = _object->layout()->at(nth, ch, row, col);
99 return *(_base + offset.value());
104 const coco::KernelObject *_object;
111 template <typename T> class PlainWeightContextImpl final : public coco::PlainWeightContext<T>
114 PlainWeightContextImpl(BlobContext *blob) : _blob{blob}
120 PlainWeightContextImpl(const PlainWeightContextImpl &) = delete;
121 PlainWeightContextImpl(PlainWeightContextImpl &&) = delete;
124 coco::Span<T> allocate(const coco::Bag *bag) override
126 assert(bag != nullptr);
127 _blob->allocate(bag, sizeof(T));
131 coco::Span<T> weight(const coco::Bag *b) override
133 // TODO Check type later
134 if (auto data = _blob->at(b))
136 uint32_t byte_size = _blob->size(b);
137 assert(byte_size % sizeof(T) == 0);
138 uint32_t elem_size = static_cast<uint32_t>(byte_size / sizeof(T));
140 return coco::Span<T>{reinterpret_cast<T *>(data), elem_size};
143 return coco::Span<T>{nullptr, 0};
147 std::unique_ptr<kernel::Accessor<T>> access(const coco::KernelObject *o) override
150 assert(b != nullptr);
152 if (auto base = reinterpret_cast<T *>(_blob->at(b)))
154 return make_unique<KernelOverlay<T>>(base, o);
161 std::unique_ptr<kernel::Reader<T>> read(const coco::KernelObject *o) const override
164 assert(b != nullptr);
166 if (auto base = reinterpret_cast<T *>(_blob->at(b)))
168 return make_unique<KernelOverlay<T>>(base, o);
175 BlobContext *const _blob;
181 struct DataImpl final : public coco::Data
183 std::unique_ptr<BlobContext> _blob;
184 std::unique_ptr<PlainWeightContextImpl<float>> _fp32;
186 bool allocated(const coco::Bag *b) const override { return _blob->at(b) != nullptr; }
188 void release(const coco::Bag *b) override
190 assert(allocated(b));
194 coco::PlainWeightContext<float> *f32(void) override { return _fp32.get(); }
195 const coco::PlainWeightContext<float> *f32(void) const override { return _fp32.get(); }
202 std::unique_ptr<Data> Data::create(void)
204 auto blob = make_unique<BlobContext>();
205 auto fp32 = make_unique<PlainWeightContextImpl<float>>(blob.get());
207 auto data = make_unique<DataImpl>();
209 data->_blob = std::move(blob);
210 data->_fp32 = std::move(fp32);