#include "PoolingSpec.h"
+#include <map>
#include <cassert>
PoolingSpec::PoolingSpec(const ::caffe::PoolingParameter ¶m) : _param(param)
// DO NOTHING
}
+PoolingMethod PoolingSpec::method(void) const
+{
+ if (!_param.has_pool())
+ {
+ // Default pooling method is MAX
+ // Reference: http://caffe.berkeleyvision.org/tutorial/layers/pooling.html
+ return PoolingMethod::Max;
+ }
+
+ std::map<::caffe::PoolingParameter_PoolMethod, PoolingMethod> methods;
+
+ // NOTE STOCHASTIC Pooling is not supported, yet
+ // TODO Support STOCHASTIC Pooling
+ methods[::caffe::PoolingParameter_PoolMethod_MAX] = PoolingMethod::Max;
+ methods[::caffe::PoolingParameter_PoolMethod_AVE] = PoolingMethod::Avg;
+
+ assert(_param.has_pool());
+ return methods.at(_param.pool());
+}
+
uint32_t PoolingSpec::window_height(void) const
{
// NOTE Global pooling is not supported, yet
#include <nncc/core/ADT/tensor/Shape.h>
+enum class PoolingMethod
+{
+ Max,
+ Avg
+};
+
class PoolingSpec
{
public:
void ifm_shape(const nncc::core::ADT::tensor::Shape &shape) { _ifm_shape = shape; }
public:
+ PoolingMethod method(void) const;
+
+public:
uint32_t window_height(void) const;
uint32_t window_width(void) const;
namespace
{
+bool from_txt(std::istream &is, ::caffe::PoolingParameter &pooling)
+{
+ ::google::protobuf::io::IstreamInputStream iis{&is};
+ return google::protobuf::TextFormat::Parse(&iis, &pooling);
+}
+
+template <typename T> bool from_txt(const std::string &txt, T &out)
+{
+ std::stringstream ss{txt};
+ return from_txt(ss, out);
+}
+
class SequentialBuilder
{
public:
ASSERT_EQ(expected, obtained);
}
}
+
+TEST_F(PoolingSpecTest, method_none)
+{
+ const char *prototxt = "";
+
+ ::caffe::PoolingParameter param;
+ from_txt(prototxt, param);
+
+ PoolingSpec spec{param};
+
+ ASSERT_EQ(spec.method(), PoolingMethod::Max);
+}
+
+TEST_F(PoolingSpecTest, method_max)
+{
+ const char *prototxt = "pool: MAX";
+
+ ::caffe::PoolingParameter param;
+ from_txt(prototxt, param);
+
+ PoolingSpec spec{param};
+
+ ASSERT_EQ(spec.method(), PoolingMethod::Max);
+}
+
+TEST_F(PoolingSpecTest, method_avg)
+{
+ const char *prototxt = "pool: AVE";
+
+ ::caffe::PoolingParameter param;
+ from_txt(prototxt, param);
+
+ PoolingSpec spec{param};
+
+ ASSERT_EQ(spec.method(), PoolingMethod::Avg);
+}