From: 박종현/동작제어Lab(SR)/Staff Engineer/삼성전자
Date: Tue, 14 Aug 2018 02:32:55 +0000 (+0900)
Subject: [nnkit] Extract TensorContext (#992)
X-Git-Tag: nncc_backup~2187
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4ac487ba592e0cb8c5ab27f2b82fcabed6129eeb;p=platform%2Fcore%2Fml%2Fnnfw.git

[nnkit] Extract TensorContext (#992)

This commit extracts CaffeTensorContext (in the nnkit Caffe backend) as
nnkit::support::caffe::TensorContext.

Signed-off-by: Jonghyun Park
---
diff --git a/contrib/nnkit/backends/caffe/Module.cpp b/contrib/nnkit/backends/caffe/Module.cpp
index 7ce5002..adba778 100644
--- a/contrib/nnkit/backends/caffe/Module.cpp
+++ b/contrib/nnkit/backends/caffe/Module.cpp
@@ -1,106 +1,12 @@
 #include "nnkit/support/caffe/BlobContext.h"
 #include "nnkit/support/caffe/InputBlobContext.h"
 #include "nnkit/support/caffe/OutputBlobContext.h"
-
-#include <...>
-#include <...>
+#include "nnkit/support/caffe/TensorContext.h"
 
 #include <...>
 
 using namespace nnkit::support::caffe;
 
-namespace
-{
-
-template <typename DType> nncc::core::ADT::tensor::Shape shape(const caffe::Blob<DType> &blob)
-{
-  nncc::core::ADT::tensor::Shape shape;
-
-  const uint32_t rank = blob.shape().size();
-
-  shape.resize(rank);
-  for (uint32_t axis = 0; axis < rank; ++axis)
-  {
-    shape.dim(axis) = blob.shape(axis);
-  }
-
-  return shape;
-}
-
-}
-
-#include <nnkit/TensorContext.h>
-
-#include <nncc/core/ADT/tensor/LexicalLayout.h>
-#include <nncc/core/ADT/tensor/Overlay.h>
-
-namespace
-{
-
-template <typename DType> class CaffeTensorContext final : public nnkit::TensorContext
-{
-public:
-  CaffeTensorContext(BlobContext<DType> &blobs) : _blobs(blobs)
-  {
-  }
-
-public:
-  uint32_t size(void) const override
-  {
-    return _blobs.size();
-  }
-
-  std::string name(uint32_t n) const override
-  {
-    return _blobs.name(n);
-  }
-
-  nncc::core::ADT::tensor::Shape shape(uint32_t n) const override
-  {
-    return ::shape(*_blobs.blob(n));
-  }
-
-  // Float (fp32) tensor support
-  bool isFloatTensor(uint32_t n) const override { return std::is_same<DType, float>::value; }
-
-  void getMutableFloatTensor(uint32_t n, const TensorContext::TypedAccessor<float> &f) override
-  {
-    if (!std::is_same<DType, float>::value)
-    {
-      throw std::runtime_error{"type mismatch"};
-    }
-
-    using nncc::core::ADT::tensor::LexicalLayout;
-    using nncc::core::ADT::tensor::make_overlay;
-
-    auto base = _blobs.region(n);
-    auto view = make_overlay<float, LexicalLayout>(shape(n), base);
-
-    f(*this, n, view);
-  }
-
-  void getConstFloatTensor(uint32_t n, const TensorContext::TypedReader<float> &f) const override
-  {
-    if (!std::is_same<DType, float>::value)
-    {
-      throw std::runtime_error{"type mismatch"};
-    }
-
-    using nncc::core::ADT::tensor::LexicalLayout;
-    using nncc::core::ADT::tensor::make_overlay;
-
-    auto base = _blobs.region(n);
-    auto view = make_overlay<float, LexicalLayout>(shape(n), base);
-
-    f(*this, n, view);
-  }
-
-private:
-  BlobContext<DType> &_blobs;
-};
-
-}
-
 #include <...>
 #include <...>
 
@@ -128,7 +34,7 @@ private:
 void FloatCaffeBackend::prepare(const std::function<void (nnkit::TensorContext &)> &f)
 {
   InputBlobContext<float> blobs(*_net);
-  CaffeTensorContext<float> tensors(blobs);
+  TensorContext<float> tensors(blobs);
 
   f(tensors);
 }
@@ -140,7 +46,7 @@ void FloatCaffeBackend::run(void)
 void FloatCaffeBackend::teardown(const std::function<void (nnkit::TensorContext &)> &f)
 {
   OutputBlobContext<float> blobs(*_net);
-  CaffeTensorContext<float> tensors(blobs);
+  TensorContext<float> tensors(blobs);
 
   f(tensors);
 }
diff --git a/contrib/nnkit/libs/support/caffe/include/nnkit/support/caffe/TensorContext.h b/contrib/nnkit/libs/support/caffe/include/nnkit/support/caffe/TensorContext.h
new file mode 100644
index 0000000..21236d9
--- /dev/null
+++ b/contrib/nnkit/libs/support/caffe/include/nnkit/support/caffe/TensorContext.h
@@ -0,0 +1,104 @@
+#ifndef __NNKIT_SUPPORT_CAFFE_TENSOR_CONTEXT_H__
+#define __NNKIT_SUPPORT_CAFFE_TENSOR_CONTEXT_H__
+
+#include "nnkit/support/caffe/BlobContext.h"
+
+#include <nnkit/TensorContext.h>
+
+#include <nncc/core/ADT/tensor/LexicalLayout.h>
+#include <nncc/core/ADT/tensor/Overlay.h>
+
+#include <type_traits>
+#include <stdexcept>
+
+namespace nnkit
+{
+namespace support
+{
+namespace caffe
+{
+
+template <typename DType> class TensorContext final : public nnkit::TensorContext
+{
+public:
+  TensorContext(BlobContext<DType> &blobs) : _blobs(blobs)
+  {
+    // DO NOTHING
+  }
+
+private:
+  static nncc::core::ADT::tensor::Shape shapeOf(const ::caffe::Blob<DType> &blob)
+  {
+    nncc::core::ADT::tensor::Shape shape;
+
+    const uint32_t rank = blob.shape().size();
+
+    shape.resize(rank);
+    for (uint32_t axis = 0; axis < rank; ++axis)
+    {
+      shape.dim(axis) = blob.shape(axis);
+    }
+
+    return shape;
+  }
+
+public:
+  uint32_t size(void) const override
+  {
+    return _blobs.size();
+  }
+
+  std::string name(uint32_t n) const override
+  {
+    return _blobs.name(n);
+  }
+
+  nncc::core::ADT::tensor::Shape shape(uint32_t n) const override
+  {
+    return shapeOf(*_blobs.blob(n));
+  }
+
+  // Float (fp32) tensor support
+  bool isFloatTensor(uint32_t n) const override { return std::is_same<DType, float>::value; }
+
+  void getMutableFloatTensor(uint32_t n, const TensorContext::TypedAccessor<float> &f) override
+  {
+    if (!std::is_same<DType, float>::value)
+    {
+      throw std::runtime_error{"type mismatch"};
+    }
+
+    using nncc::core::ADT::tensor::LexicalLayout;
+    using nncc::core::ADT::tensor::make_overlay;
+
+    auto base = _blobs.region(n);
+    auto view = make_overlay<float, LexicalLayout>(shape(n), base);
+
+    f(*this, n, view);
+  }
+
+  void getConstFloatTensor(uint32_t n, const TensorContext::TypedReader<float> &f) const override
+  {
+    if (!std::is_same<DType, float>::value)
+    {
+      throw std::runtime_error{"type mismatch"};
+    }
+
+    using nncc::core::ADT::tensor::LexicalLayout;
+    using nncc::core::ADT::tensor::make_overlay;
+
+    auto base = _blobs.region(n);
+    auto view = make_overlay<float, LexicalLayout>(shape(n), base);
+
+    f(*this, n, view);
+  }
+
+private:
+  BlobContext<DType> &_blobs;
+};
+
+} // namespace caffe
+} // namespace support
+} // namespace nnkit
+
+#endif // __NNKIT_SUPPORT_CAFFE_TENSOR_CONTEXT_H__
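
Note: the following is a minimal usage sketch, not part of the commit above. It shows how a float Caffe backend can expose its input blobs through the extracted nnkit::support::caffe::TensorContext, mirroring what FloatCaffeBackend::prepare in Module.cpp does internally; the free function prepare_float_inputs and the directly-passed caffe::Net are illustrative assumptions.

// Hypothetical sketch (assumption, not part of this commit): wrap a Caffe
// network's input blobs in the extracted TensorContext and hand them to an
// nnkit-style callback.
#include "nnkit/support/caffe/InputBlobContext.h"
#include "nnkit/support/caffe/TensorContext.h"

#include <nnkit/TensorContext.h>

#include <caffe/caffe.hpp>

#include <functional>

void prepare_float_inputs(::caffe::Net<float> &net,
                          const std::function<void (nnkit::TensorContext &)> &f)
{
  using namespace nnkit::support::caffe;

  InputBlobContext<float> blobs(net);  // enumerate the network's input blobs
  TensorContext<float> tensors(blobs); // adapt them to nnkit's TensorContext interface
  f(tensors);                          // the callback reads/writes tensors through the context
}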