[nnkit] Extract TensorContext (#992)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 14 Aug 2018 02:32:55 +0000 (11:32 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 14 Aug 2018 02:32:55 +0000 (11:32 +0900)
This commit extracts the file-local CaffeTensorContext class out of the nnkit
Caffe backend (Module.cpp) into a reusable public header as
nnkit::support::caffe::TensorContext.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/nnkit/backends/caffe/Module.cpp
contrib/nnkit/libs/support/caffe/include/nnkit/support/caffe/TensorContext.h [new file with mode: 0644]

index 7ce5002..adba778 100644 (file)
 #include "nnkit/support/caffe/BlobContext.h"
 #include "nnkit/support/caffe/InputBlobContext.h"
 #include "nnkit/support/caffe/OutputBlobContext.h"
-
-#include <nncc/core/ADT/tensor/LexicalLayout.h>
-#include <nncc/core/ADT/tensor/Overlay.h>
+#include "nnkit/support/caffe/TensorContext.h"
 
 #include <caffe/caffe.hpp>
 
 using namespace nnkit::support::caffe;
 
-namespace
-{
-
-template <typename DType> nncc::core::ADT::tensor::Shape shape(const caffe::Blob<DType> &blob)
-{
-  nncc::core::ADT::tensor::Shape shape;
-
-  const uint32_t rank = blob.shape().size();
-
-  shape.resize(rank);
-  for (uint32_t axis = 0; axis < rank; ++axis)
-  {
-    shape.dim(axis) = blob.shape(axis);
-  }
-
-  return shape;
-}
-
-}
-
-#include <nnkit/TensorContext.h>
-
-#include <type_traits>
-#include <stdexcept>
-
-namespace
-{
-
-template <typename DType> class CaffeTensorContext final : public nnkit::TensorContext
-{
-public:
-  CaffeTensorContext(BlobContext<DType> &blobs) : _blobs(blobs)
-  {
-  }
-
-public:
-  uint32_t size(void) const override
-  {
-    return _blobs.size();
-  }
-
-  std::string name(uint32_t n) const override
-  {
-    return _blobs.name(n);
-  }
-
-  nncc::core::ADT::tensor::Shape shape(uint32_t n) const override
-  {
-    return ::shape(*_blobs.blob(n));
-  }
-
-  // Float (fp32) tensor support
-  bool isFloatTensor(uint32_t n) const override { return std::is_same<DType, float>::value; }
-
-  void getMutableFloatTensor(uint32_t n, const TensorContext::TypedAccessor<float> &f) override
-  {
-    if (!std::is_same<DType, float>::value)
-    {
-      throw std::runtime_error{"type mismatch"};
-    }
-
-    using nncc::core::ADT::tensor::LexicalLayout;
-    using nncc::core::ADT::tensor::make_overlay;
-
-    auto base = _blobs.region(n);
-    auto view = make_overlay<float, LexicalLayout>(shape(n), base);
-
-    f(*this, n, view);
-  }
-
-  void getConstFloatTensor(uint32_t n, const TensorContext::TypedReader<float> &f) const override
-  {
-    if (!std::is_same<DType, float>::value)
-    {
-      throw std::runtime_error{"type mismatch"};
-    }
-
-    using nncc::core::ADT::tensor::LexicalLayout;
-    using nncc::core::ADT::tensor::make_overlay;
-
-    auto base = _blobs.region(n);
-    auto view = make_overlay<float, LexicalLayout>(shape(n), base);
-
-    f(*this, n, view);
-  }
-
-private:
-  BlobContext<DType> &_blobs;
-};
-
-}
-
 #include <nnkit/Backend.h>
 
 #include <memory>
@@ -128,7 +34,7 @@ private:
 void FloatCaffeBackend::prepare(const std::function<void (nnkit::TensorContext &)> &f)
 {
   InputBlobContext<float> blobs(*_net);
-  CaffeTensorContext<float> tensors(blobs);
+  TensorContext<float> tensors(blobs); // nnkit::support::caffe::TensorContext adapter over the input blobs
   f(tensors);
 }
 
@@ -140,7 +46,7 @@ void FloatCaffeBackend::run(void)
 void FloatCaffeBackend::teardown(const std::function<void (nnkit::TensorContext &)> &f)
 {
   OutputBlobContext<float> blobs(*_net);
-  CaffeTensorContext<float> tensors(blobs);
+  TensorContext<float> tensors(blobs); // nnkit::support::caffe::TensorContext adapter over the output blobs
   f(tensors);
 }
 
diff --git a/contrib/nnkit/libs/support/caffe/include/nnkit/support/caffe/TensorContext.h b/contrib/nnkit/libs/support/caffe/include/nnkit/support/caffe/TensorContext.h
new file mode 100644 (file)
index 0000000..21236d9
--- /dev/null
@@ -0,0 +1,104 @@
+#ifndef __NNKIT_SUPPORT_CAFFE_TENSOR_CONTEXT_H__ // NOTE(review): double-underscore identifiers are reserved in C++; kept for project consistency
+#define __NNKIT_SUPPORT_CAFFE_TENSOR_CONTEXT_H__
+
+#include "nnkit/support/caffe/BlobContext.h"
+
+#include <nnkit/TensorContext.h>
+
+#include <nncc/core/ADT/tensor/LexicalLayout.h>
+#include <nncc/core/ADT/tensor/Overlay.h>
+
+#include <type_traits>
+#include <stdexcept>
+
+namespace nnkit
+{
+namespace support
+{
+namespace caffe
+{
+
+template <typename DType> class TensorContext final : public nnkit::TensorContext // adapts a Caffe BlobContext<DType> to nnkit's generic TensorContext interface
+{
+public:
+  TensorContext(BlobContext<DType> &blobs) : _blobs(blobs) // borrows blobs; the caller must keep it alive for this object's lifetime
+  {
+    // DO NOTHING
+  }
+
+private:
+  static nncc::core::ADT::tensor::Shape shapeOf(const ::caffe::Blob<DType> &blob) // converts a Caffe blob's shape vector into an nncc tensor Shape
+  {
+    nncc::core::ADT::tensor::Shape shape;
+
+    const uint32_t rank = blob.shape().size();
+
+    shape.resize(rank);
+    for (uint32_t axis = 0; axis < rank; ++axis)
+    {
+      shape.dim(axis) = blob.shape(axis); // copy each dimension extent, axis by axis
+    }
+
+    return shape;
+  }
+
+public:
+  uint32_t size(void) const override // number of tensors = number of blobs in the underlying context
+  {
+    return _blobs.size();
+  }
+
+  std::string name(uint32_t n) const override // name of the n-th blob
+  {
+    return _blobs.name(n);
+  }
+
+  nncc::core::ADT::tensor::Shape shape(uint32_t n) const override // shape of the n-th blob
+  {
+    return shapeOf(*_blobs.blob(n));
+  }
+
+  // Float (fp32) tensor support
+  bool isFloatTensor(uint32_t n) const override { return std::is_same<DType, float>::value; } // same answer for every n: depends only on DType
+
+  void getMutableFloatTensor(uint32_t n, const TensorContext::TypedAccessor<float> &f) override
+  {
+    if (!std::is_same<DType, float>::value) // runtime guard: non-float instantiations reject float access
+    {
+      throw std::runtime_error{"type mismatch"};
+    }
+
+    using nncc::core::ADT::tensor::LexicalLayout;
+    using nncc::core::ADT::tensor::make_overlay;
+
+    auto base = _blobs.region(n);
+    auto view = make_overlay<float, LexicalLayout>(shape(n), base); // zero-copy view over the blob's raw region, LexicalLayout ordering
+
+    f(*this, n, view);
+  }
+
+  void getConstFloatTensor(uint32_t n, const TensorContext::TypedReader<float> &f) const override
+  {
+    if (!std::is_same<DType, float>::value) // runtime guard: non-float instantiations reject float access
+    {
+      throw std::runtime_error{"type mismatch"};
+    }
+
+    using nncc::core::ADT::tensor::LexicalLayout;
+    using nncc::core::ADT::tensor::make_overlay;
+
+    auto base = _blobs.region(n);
+    auto view = make_overlay<float, LexicalLayout>(shape(n), base); // zero-copy view over the blob's raw region, LexicalLayout ordering
+
+    f(*this, n, view);
+  }
+
+private:
+  BlobContext<DType> &_blobs; // non-owning reference
+};
+
+} // namespace caffe
+} // namespace support
+} // namespace nnkit
+
+#endif // __NNKIT_SUPPORT_CAFFE_TENSOR_CONTEXT_H__