ofm_obj->bag(ofm_bag);
ofm_obj->reorder<feature::CHWLayout>();
- // TODO Support other pooling method
- assert(!param.has_pool() || (param.pool() == ::caffe::PoolingParameter_PoolMethod_MAX));
+ using PoolingOpBuilder = std::function<coco::Op *(coco::Module * m, const PoolingSpec &spec)>;
- // Create a MaxPool2D op
- auto op = m->entity()->op()->create<coco::MaxPool2D>();
+ std::map<PoolingMethod, PoolingOpBuilder> builders;
- op->window()->vertical(spec.window_height());
- op->window()->horizontal(spec.window_width());
+ // MaxPool2D op builder
+ builders[PoolingMethod::Max] = [](coco::Module *m, const PoolingSpec &spec) {
+ auto op = m->entity()->op()->create<coco::MaxPool2D>();
- op->stride()->vertical(spec.vertical_stride());
- op->stride()->horizontal(spec.horizontal_stride());
+ op->window()->vertical(spec.window_height());
+ op->window()->horizontal(spec.window_width());
+
+ op->stride()->vertical(spec.vertical_stride());
+ op->stride()->horizontal(spec.horizontal_stride());
+
+ return op;
+ };
+
+ // AvgPool2D op builder
+ builders[PoolingMethod::Avg] = [](coco::Module *m, const PoolingSpec &spec) {
+ auto op = m->entity()->op()->create<coco::AvgPool2D>();
+
+ op->window()->vertical(spec.window_height());
+ op->window()->horizontal(spec.window_height());
+
+ assert(spec.vertical_stride() == 1);
+ assert(spec.horizontal_stride() == 1);
+
+ return op;
+ };
+
+ // Create a pooling op
+ auto builder = builders.at(spec.method());
+ auto op = builder(m.get(), spec);
// Create a UnitF instruction
auto ins = m->entity()->instr()->create<coco::UnitF>();