From 8e90f0a2221093a373af4f8a65884dce3d9be6bb Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Tue, 26 Jun 2018 16:23:48 +0900 Subject: [PATCH] [nnkit] Add Caffe backend (#367) This commit adds 'caffe' backend implementation which can be loaded by nni toolchain. Signed-off-by: Jonghyun Park --- contrib/nnkit/CMakeLists.txt | 1 + contrib/nnkit/backends/CMakeLists.txt | 1 + contrib/nnkit/backends/caffe/CMakeLists.txt | 9 ++ contrib/nnkit/backends/caffe/Module.cpp | 223 ++++++++++++++++++++++++++++ 4 files changed, 234 insertions(+) create mode 100644 contrib/nnkit/backends/CMakeLists.txt create mode 100644 contrib/nnkit/backends/caffe/CMakeLists.txt create mode 100644 contrib/nnkit/backends/caffe/Module.cpp diff --git a/contrib/nnkit/CMakeLists.txt b/contrib/nnkit/CMakeLists.txt index d4281ab..b3ac99f 100644 --- a/contrib/nnkit/CMakeLists.txt +++ b/contrib/nnkit/CMakeLists.txt @@ -10,4 +10,5 @@ macro(nnkit_add_backend PREFIX) endmacro(nnkit_add_backend) add_subdirectory(libs) +add_subdirectory(backends) add_subdirectory(tools) diff --git a/contrib/nnkit/backends/CMakeLists.txt b/contrib/nnkit/backends/CMakeLists.txt new file mode 100644 index 0000000..5ea6cda --- /dev/null +++ b/contrib/nnkit/backends/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectories() diff --git a/contrib/nnkit/backends/caffe/CMakeLists.txt b/contrib/nnkit/backends/caffe/CMakeLists.txt new file mode 100644 index 0000000..3481bfb --- /dev/null +++ b/contrib/nnkit/backends/caffe/CMakeLists.txt @@ -0,0 +1,9 @@ +nncc_find_package(Caffe QUIET) + +if(NOT Caffe_FOUND) + return() +endif(NOT Caffe_FOUND) + +add_library(nnkit_caffe_backend SHARED Module.cpp) +target_link_libraries(nnkit_caffe_backend nnkit_intf_backend) +target_link_libraries(nnkit_caffe_backend caffe) diff --git a/contrib/nnkit/backends/caffe/Module.cpp b/contrib/nnkit/backends/caffe/Module.cpp new file mode 100644 index 0000000..f664fa3 --- /dev/null +++ b/contrib/nnkit/backends/caffe/Module.cpp @@ -0,0 +1,223 @@ +#include +#include + +#include +#include + +#include + +namespace +{ + +template nncc::core::ADT::tensor::Shape shape(const caffe::Blob &blob) +{ + nncc::core::ADT::tensor::Shape shape; + + const uint32_t rank = blob.shape().size(); + + shape.resize(rank); + for (uint32_t axis = 0; axis < rank; ++axis) + { + shape.dim(axis) = blob.shape(axis); + } + + return shape; +} + +template struct BlobContext +{ + virtual ~BlobContext() = default; + + virtual uint32_t size(void) const = 0; + + virtual std::string name(uint32_t n) const = 0; + virtual caffe::Blob *blob(uint32_t n) = 0; + + std::unique_ptr> region(uint32_t n) + { + auto b = blob(n); + auto count = b->count(); + auto data = b->mutable_cpu_data(); + + return nncc::foundation::make_unique>(data, count); + } +}; + +template class InputBlobContext final : public BlobContext +{ +public: + InputBlobContext(caffe::Net &net) : _net(net) + { + // DO NOTHING + } + +public: + uint32_t size(void) const override + { + return _net.num_inputs(); + } + + std::string name(uint32_t n) const override + { + return _net.blob_names().at(_net.input_blob_indices().at(n)); + } + + caffe::Blob *blob(uint32_t n) + { + return _net.input_blobs().at(n); + } + +private: + caffe::Net &_net; +}; + +template class OutputBlobContext final : public BlobContext +{ +public: + OutputBlobContext(caffe::Net &net) : _net(net) + { + // DO NOTHING + } + +public: + uint32_t size(void) const override + { + return _net.num_outputs(); + } + + std::string name(uint32_t n) const override + { + return _net.blob_names().at(_net.output_blob_indices().at(n)); + } + + caffe::Blob *blob(uint32_t n) + { + return _net.output_blobs().at(n); + } + +private: + caffe::Net &_net; +}; + +} + +#include + +namespace +{ + +class FloatCaffeTensorContext final : public nnkit::TensorContext +{ +public: + FloatCaffeTensorContext(BlobContext &blobs) : _blobs(blobs) + { + } + +public: + uint32_t size(void) const override + { + return _blobs.size(); + } + + std::string name(uint32_t n) const override + { + return _blobs.name(n); + } + + nncc::core::ADT::tensor::Shape shape(uint32_t n) const override + { + return ::shape(*_blobs.blob(n)); + } + + // Float (fp32) tensor support + bool isFloatTensor(uint32_t n) const override { return true; } + void getMutableFloatTensor(uint32_t n, const TensorContext::TypedAccessor &f) override + { + using nncc::core::ADT::tensor::LexicalLayout; + using nncc::core::ADT::tensor::make_view; + + auto span = _blobs.region(n); + auto view = make_view(shape(n), std::move(span)); + + f(*this, n, view); + } + + void getConstFloatTensor(uint32_t n, const TensorContext::TypedReader &f) const override + { + using nncc::core::ADT::tensor::LexicalLayout; + using nncc::core::ADT::tensor::make_view; + + auto span = _blobs.region(n); + auto view = make_view(shape(n), std::move(span)); + + f(*this, n, view); + } + +private: + BlobContext &_blobs; +}; + +} + +#include + +namespace +{ + +class FloatCaffeBackend final : public nnkit::Backend +{ +public: + FloatCaffeBackend(const std::string &prototxt) : _net{prototxt, caffe::TEST} + { + // DO NOTHING + } + +public: + FloatCaffeBackend(const std::string &prototxt, const std::string &caffemodel) : _net{prototxt, caffe::TEST} + { + _net.CopyTrainedLayersFrom(caffemodel); + } + +public: + void prepare(const std::function &f) override; + void run(void) override; + void teardown(const std::function &f) override; + +private: + caffe::Net _net; +}; + +void FloatCaffeBackend::prepare(const std::function &f) +{ + InputBlobContext blobs(_net); + FloatCaffeTensorContext tensors(blobs); + f(tensors); +} + +void FloatCaffeBackend::run(void) +{ + _net.Forward(); +} + +void FloatCaffeBackend::teardown(const std::function &f) +{ + OutputBlobContext blobs(_net); + FloatCaffeTensorContext tensors(blobs); + f(tensors); +} + +} // namespace + +#include +#include + +extern "C" std::unique_ptr make_backend(const nnkit::CmdlineArguments &args) +{ + if (args.size() == 1) + { + // TODO Select Float/Double based on command-line arguments + return nncc::foundation::make_unique(args.at(0)); + } + + return nncc::foundation::make_unique(args.at(0), args.at(1)); +} -- 2.7.4