#include "nnkit/support/caffe/BlobContext.h"
#include "nnkit/support/caffe/InputBlobContext.h"
#include "nnkit/support/caffe/OutputBlobContext.h"
-
-#include <nncc/core/ADT/tensor/LexicalLayout.h>
-#include <nncc/core/ADT/tensor/Overlay.h>
+#include "nnkit/support/caffe/TensorContext.h"
#include <caffe/caffe.hpp>
using namespace nnkit::support::caffe;
-namespace
-{
-
-template <typename DType> nncc::core::ADT::tensor::Shape shape(const caffe::Blob<DType> &blob)
-{
- nncc::core::ADT::tensor::Shape shape;
-
- const uint32_t rank = blob.shape().size();
-
- shape.resize(rank);
- for (uint32_t axis = 0; axis < rank; ++axis)
- {
- shape.dim(axis) = blob.shape(axis);
- }
-
- return shape;
-}
-
-}
-
-#include <nnkit/TensorContext.h>
-
-#include <type_traits>
-#include <stdexcept>
-
-namespace
-{
-
-template <typename DType> class CaffeTensorContext final : public nnkit::TensorContext
-{
-public:
- CaffeTensorContext(BlobContext<DType> &blobs) : _blobs(blobs)
- {
- }
-
-public:
- uint32_t size(void) const override
- {
- return _blobs.size();
- }
-
- std::string name(uint32_t n) const override
- {
- return _blobs.name(n);
- }
-
- nncc::core::ADT::tensor::Shape shape(uint32_t n) const override
- {
- return ::shape(*_blobs.blob(n));
- }
-
- // Float (fp32) tensor support
- bool isFloatTensor(uint32_t n) const override { return std::is_same<DType, float>::value; }
-
- void getMutableFloatTensor(uint32_t n, const TensorContext::TypedAccessor<float> &f) override
- {
- if (!std::is_same<DType, float>::value)
- {
- throw std::runtime_error{"type mismatch"};
- }
-
- using nncc::core::ADT::tensor::LexicalLayout;
- using nncc::core::ADT::tensor::make_overlay;
-
- auto base = _blobs.region(n);
- auto view = make_overlay<float, LexicalLayout>(shape(n), base);
-
- f(*this, n, view);
- }
-
- void getConstFloatTensor(uint32_t n, const TensorContext::TypedReader<float> &f) const override
- {
- if (!std::is_same<DType, float>::value)
- {
- throw std::runtime_error{"type mismatch"};
- }
-
- using nncc::core::ADT::tensor::LexicalLayout;
- using nncc::core::ADT::tensor::make_overlay;
-
- auto base = _blobs.region(n);
- auto view = make_overlay<float, LexicalLayout>(shape(n), base);
-
- f(*this, n, view);
- }
-
-private:
- BlobContext<DType> &_blobs;
-};
-
-}
-
#include <nnkit/Backend.h>
#include <memory>
void FloatCaffeBackend::prepare(const std::function<void (nnkit::TensorContext &)> &f)
{
InputBlobContext<float> blobs(*_net);
- CaffeTensorContext<float> tensors(blobs);
+ TensorContext<float> tensors(blobs);
f(tensors);
}
void FloatCaffeBackend::teardown(const std::function<void (nnkit::TensorContext &)> &f)
{
OutputBlobContext<float> blobs(*_net);
- CaffeTensorContext<float> tensors(blobs);
+ TensorContext<float> tensors(blobs);
f(tensors);
}
--- /dev/null
+#ifndef __NNKIT_SUPPORT_CAFFE_TENSOR_CONTEXT_H__
+#define __NNKIT_SUPPORT_CAFFE_TENSOR_CONTEXT_H__
+
+#include "nnkit/support/caffe/BlobContext.h"
+
+#include <nnkit/TensorContext.h>
+
+#include <nncc/core/ADT/tensor/LexicalLayout.h>
+#include <nncc/core/ADT/tensor/Overlay.h>
+
+#include <type_traits>
+#include <stdexcept>
+
+namespace nnkit
+{
+namespace support
+{
+namespace caffe
+{
+
+template <typename DType> class TensorContext final : public nnkit::TensorContext
+{
+public:
+ TensorContext(BlobContext<DType> &blobs) : _blobs(blobs)
+ {
+ // DO NOTHING
+ }
+
+private:
+ static nncc::core::ADT::tensor::Shape shapeOf(const ::caffe::Blob<DType> &blob)
+ {
+ nncc::core::ADT::tensor::Shape shape;
+
+ const uint32_t rank = blob.shape().size();
+
+ shape.resize(rank);
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ {
+ shape.dim(axis) = blob.shape(axis);
+ }
+
+ return shape;
+ }
+
+public:
+ uint32_t size(void) const override
+ {
+ return _blobs.size();
+ }
+
+ std::string name(uint32_t n) const override
+ {
+ return _blobs.name(n);
+ }
+
+ nncc::core::ADT::tensor::Shape shape(uint32_t n) const override
+ {
+ return shapeOf(*_blobs.blob(n));
+ }
+
+ // Float (fp32) tensor support
+ bool isFloatTensor(uint32_t n) const override { return std::is_same<DType, float>::value; }
+
+ void getMutableFloatTensor(uint32_t n, const TensorContext::TypedAccessor<float> &f) override
+ {
+ if (!std::is_same<DType, float>::value)
+ {
+ throw std::runtime_error{"type mismatch"};
+ }
+
+ using nncc::core::ADT::tensor::LexicalLayout;
+ using nncc::core::ADT::tensor::make_overlay;
+
+ auto base = _blobs.region(n);
+ auto view = make_overlay<float, LexicalLayout>(shape(n), base);
+
+ f(*this, n, view);
+ }
+
+ void getConstFloatTensor(uint32_t n, const TensorContext::TypedReader<float> &f) const override
+ {
+ if (!std::is_same<DType, float>::value)
+ {
+ throw std::runtime_error{"type mismatch"};
+ }
+
+ using nncc::core::ADT::tensor::LexicalLayout;
+ using nncc::core::ADT::tensor::make_overlay;
+
+ auto base = _blobs.region(n);
+ auto view = make_overlay<float, LexicalLayout>(shape(n), base);
+
+ f(*this, n, view);
+ }
+
+private:
+ BlobContext<DType> &_blobs;
+};
+
+} // namespace caffe
+} // namespace support
+} // namespace nnkit
+
+#endif // __NNKIT_SUPPORT_CAFFE_TENSOR_CONTEXT_H__