[enco] Build AvgPool2D op from caffe model (#1341)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 5 Sep 2018 06:19:38 +0000 (15:19 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 5 Sep 2018 06:19:38 +0000 (15:19 +0900)
With this commit, enco caffe frontend is now able to build a simple
AvgPool2D op from caffe model.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/enco/frontend/caffe/src/Frontend.cpp

index b424e9d..ab9ab95 100644 (file)
@@ -295,17 +295,39 @@ enco::Bundle Frontend::load(void) const
       ofm_obj->bag(ofm_bag);
       ofm_obj->reorder<feature::CHWLayout>();
 
-      // TODO Support other pooling method
-      assert(!param.has_pool() || (param.pool() == ::caffe::PoolingParameter_PoolMethod_MAX));
+      using PoolingOpBuilder = std::function<coco::Op *(coco::Module * m, const PoolingSpec &spec)>;
 
-      // Create a MaxPool2D op
-      auto op = m->entity()->op()->create<coco::MaxPool2D>();
+      std::map<PoolingMethod, PoolingOpBuilder> builders;
 
-      op->window()->vertical(spec.window_height());
-      op->window()->horizontal(spec.window_width());
+      // MaxPool2D op builder
+      builders[PoolingMethod::Max] = [](coco::Module *m, const PoolingSpec &spec) {
+        auto op = m->entity()->op()->create<coco::MaxPool2D>();
 
-      op->stride()->vertical(spec.vertical_stride());
-      op->stride()->horizontal(spec.horizontal_stride());
+        op->window()->vertical(spec.window_height());
+        op->window()->horizontal(spec.window_width());
+
+        op->stride()->vertical(spec.vertical_stride());
+        op->stride()->horizontal(spec.horizontal_stride());
+
+        return op;
+      };
+
+      // AvgPool2D op builder
+      builders[PoolingMethod::Avg] = [](coco::Module *m, const PoolingSpec &spec) {
+        auto op = m->entity()->op()->create<coco::AvgPool2D>();
+
+        op->window()->vertical(spec.window_height());
+        op->window()->horizontal(spec.window_height());
+
+        assert(spec.vertical_stride() == 1);
+        assert(spec.horizontal_stride() == 1);
+
+        return op;
+      };
+
+      // Create a pooling op
+      auto builder = builders.at(spec.method());
+      auto op = builder(m.get(), spec);
 
       // Create a UnitF instruction
       auto ins = m->entity()->instr()->create<coco::UnitF>();