From 4207d13553ceeb9066fbc4f4ffc82b9c85431a30 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Wed, 4 Jul 2018 16:06:48 +0900 Subject: [PATCH] [nnkit] Introduce HDF5 import action (#380) This commit introduces HDF5 import action, which fills tensor with the values loaded from HDF5 file. Signed-off-by: Jonghyun Park --- contrib/nnkit/actions/HDF5/CMakeLists.txt | 5 ++ contrib/nnkit/actions/HDF5/Import.cpp | 79 +++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 contrib/nnkit/actions/HDF5/Import.cpp diff --git a/contrib/nnkit/actions/HDF5/CMakeLists.txt b/contrib/nnkit/actions/HDF5/CMakeLists.txt index ef9f408..1899972 100644 --- a/contrib/nnkit/actions/HDF5/CMakeLists.txt +++ b/contrib/nnkit/actions/HDF5/CMakeLists.txt @@ -8,3 +8,8 @@ add_library(nnkit_HDF5_export_action SHARED Export.cpp) target_include_directories(nnkit_HDF5_export_action PRIVATE ${HDF5_INCLUDE_DIRS}) target_link_libraries(nnkit_HDF5_export_action nnkit_intf_action) target_link_libraries(nnkit_HDF5_export_action ${HDF5_CXX_LIBRARIES}) + +add_library(nnkit_HDF5_import_action SHARED Import.cpp) +target_include_directories(nnkit_HDF5_import_action PRIVATE ${HDF5_INCLUDE_DIRS}) +target_link_libraries(nnkit_HDF5_import_action nnkit_intf_action) +target_link_libraries(nnkit_HDF5_import_action ${HDF5_CXX_LIBRARIES}) diff --git a/contrib/nnkit/actions/HDF5/Import.cpp b/contrib/nnkit/actions/HDF5/Import.cpp new file mode 100644 index 0000000..7bc76a1 --- /dev/null +++ b/contrib/nnkit/actions/HDF5/Import.cpp @@ -0,0 +1,79 @@ +#include + +#include +#include + +#include + +#include + +using nnkit::TensorContext; + +class HD5ImportAction final : public nnkit::Action +{ +public: + HD5ImportAction(const std::string &path) : _file{path, H5F_ACC_RDONLY} + { + // DO NOTHING + } + +public: + void run(TensorContext &ctx) override + { + for (uint32_t n = 0; n < ctx.size(); ++n) + { + using nncc::core::ADT::tensor::Accessor; + + auto fn = [this] (const TensorContext &ctx, uint32_t n, Accessor &t) + { + const auto name = ctx.name(n); + + auto dataset = _file.openDataSet(name); + + // TODO Support non-float tensors + assert(dataset.getDataType() == H5::PredType::IEEE_F32BE); + + // TODO Check whether shape is consistent + const auto shape = ctx.shape(n); + + std::vector buffer; + + using nncc::core::ADT::tensor::num_elements; + buffer.resize(num_elements(shape)); + + dataset.read(buffer.data(), H5::PredType::NATIVE_FLOAT); + + using nncc::core::ADT::tensor::range; + using nncc::core::ADT::tensor::Index; + using nncc::core::ADT::tensor::LexicalLayout; + + LexicalLayout layout{}; + + range(shape).iterate() << [&buffer, &t, &shape, &layout] (const Index &i) + { + t.at(i) = buffer[layout.offset(shape, i)]; + }; + }; + + try + { + ctx.getMutableFloatTensor(n, fn); + } + catch (const H5::FileIException &) + { + // Skip if data is not present in HDF5 file + } + } + } + +private: + H5::H5File _file; +}; + +#include +#include + +extern "C" std::unique_ptr make_action(const nnkit::CmdlineArguments &args) +{ + return nncc::foundation::make_unique(args.at(0)); +} -- 2.7.4