[enco] Decide pooling method from spec (#1339)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 5 Sep 2018 04:09:57 +0000 (13:09 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 5 Sep 2018 04:09:57 +0000 (13:09 +0900)
This commit revises PoolingSpec to correctly decide pooling method from
a model specification given as prototxt.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/enco/frontend/caffe/src/PoolingSpec.cpp
contrib/enco/frontend/caffe/src/PoolingSpec.h
contrib/enco/frontend/caffe/src/PoolingSpec.test.cpp

index 7c5915e..232ba4d 100644 (file)
@@ -1,5 +1,6 @@
 #include "PoolingSpec.h"
 
+#include <map>
 #include <cassert>
 
 PoolingSpec::PoolingSpec(const ::caffe::PoolingParameter &param) : _param(param)
@@ -7,6 +8,26 @@ PoolingSpec::PoolingSpec(const ::caffe::PoolingParameter &param) : _param(param)
   // DO NOTHING
 }
 
+PoolingMethod PoolingSpec::method(void) const
+{
+  if (!_param.has_pool())
+  {
+    // Default pooling method is MAX
+    // Reference: http://caffe.berkeleyvision.org/tutorial/layers/pooling.html
+    return PoolingMethod::Max;
+  }
+
+  std::map<::caffe::PoolingParameter_PoolMethod, PoolingMethod> methods;
+
+  // NOTE STOCHASTIC Pooling is not supported, yet
+  // TODO Support STOCHASTIC Pooling
+  methods[::caffe::PoolingParameter_PoolMethod_MAX] = PoolingMethod::Max;
+  methods[::caffe::PoolingParameter_PoolMethod_AVE] = PoolingMethod::Avg;
+
+  assert(_param.has_pool());
+  return methods.at(_param.pool());
+}
+
 uint32_t PoolingSpec::window_height(void) const
 {
   // NOTE Global pooling is not supported, yet
index d2ffa3a..13fd67e 100644 (file)
@@ -5,6 +5,12 @@
 
 #include <nncc/core/ADT/tensor/Shape.h>
 
+enum class PoolingMethod
+{
+  Max,
+  Avg
+};
+
 class PoolingSpec
 {
 public:
@@ -15,6 +21,9 @@ public:
   void ifm_shape(const nncc::core::ADT::tensor::Shape &shape) { _ifm_shape = shape; }
 
 public:
+  PoolingMethod method(void) const;
+
+public:
   uint32_t window_height(void) const;
   uint32_t window_width(void) const;
 
index 0372896..c1e3ad6 100644 (file)
@@ -23,6 +23,18 @@ using nncc::foundation::make_unique;
 namespace
 {
 
+bool from_txt(std::istream &is, ::caffe::PoolingParameter &pooling)
+{
+  ::google::protobuf::io::IstreamInputStream iis{&is};
+  return google::protobuf::TextFormat::Parse(&iis, &pooling);
+}
+
+template <typename T> bool from_txt(const std::string &txt, T &out)
+{
+  std::stringstream ss{txt};
+  return from_txt(ss, out);
+}
+
 class SequentialBuilder
 {
 public:
@@ -192,3 +204,39 @@ TEST_F(PoolingSpecTest, stride_for_all)
     ASSERT_EQ(expected, obtained);
   }
 }
+
+TEST_F(PoolingSpecTest, method_none)
+{
+  const char *prototxt = "";
+
+  ::caffe::PoolingParameter param;
+  from_txt(prototxt, param);
+
+  PoolingSpec spec{param};
+
+  ASSERT_EQ(spec.method(), PoolingMethod::Max);
+}
+
+TEST_F(PoolingSpecTest, method_max)
+{
+  const char *prototxt = "pool: MAX";
+
+  ::caffe::PoolingParameter param;
+  from_txt(prototxt, param);
+
+  PoolingSpec spec{param};
+
+  ASSERT_EQ(spec.method(), PoolingMethod::Max);
+}
+
+TEST_F(PoolingSpecTest, method_avg)
+{
+  const char *prototxt = "pool: AVE";
+
+  ::caffe::PoolingParameter param;
+  from_txt(prototxt, param);
+
+  PoolingSpec spec{param};
+
+  ASSERT_EQ(spec.method(), PoolingMethod::Avg);
+}