Split adaptive_pool2d_avg into sum and div (#4186)
authorYao Wang <kevinthesunwy@gmail.com>
Thu, 24 Oct 2019 15:37:56 +0000 (08:37 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 24 Oct 2019 15:37:56 +0000 (08:37 -0700)
topi/include/topi/nn/pooling.h
topi/python/topi/x86/pooling.py

index 289452e..ca35e6e 100644 (file)
@@ -492,7 +492,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
       return tvm::max(x(indices), { dheight, dwidth });  // NOLINT(*)
     }, "tensor", "adaptive_pool_max");
   } else if (pool_type == kAvgPool) {
-    return tvm::compute(out_shape, [&](const Array<Var>& output) {
+    auto pool_sum = tvm::compute(out_shape, [&](const Array<Var>& output) {
       Array<Expr> indices;
       for (const Var& var : output) indices.push_back(var);
       auto i_start_h = start_index(output[height_axis], out_height, height);
@@ -505,8 +505,20 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
       auto dwidth = tvm::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
       indices.Set(height_axis, i_start_h + dheight);
       indices.Set(width_axis, i_start_w + dwidth);
-      return tvm::sum(div(x(indices), divide_factor), { dheight, dwidth });
-    }, "tensor", "adaptive_pool_avg");
+      return tvm::sum(x(indices), { dheight, dwidth });
+    }, "tensor", "adaptive_pool_sum");
+
+    return tvm::compute(out_shape, [&](const Array<Var>& output) {
+      Array<Expr> indices;
+      for (const Var& var : output) indices.push_back(var);
+      auto i_start_h = start_index(output[height_axis], out_height, height);
+      auto i_end_h = end_index(output[height_axis], out_height, height);
+      auto i_start_w = start_index(output[width_axis], out_width, width);
+      auto i_end_w = end_index(output[width_axis], out_width, width);
+      auto divide_factor = tvm::cast(x->dtype, (i_end_h - i_start_h)
+                                               * (i_end_w - i_start_w));
+      return div(pool_sum(indices), divide_factor);
+    }, "tensor", kElementWise);
   } else {
     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
     return x;
index ac19b19..e9f832d 100644 (file)
@@ -147,6 +147,11 @@ def schedule_adaptive_pool(outs):
                     traverse(tensor.op)
         # schedule pool
         elif OP.tag.startswith('adaptive_pool'):
+            if OP != outs[0].op:
+                output = outs[0]
+                output_fused = s[output].fuse(output.op.axis[0], output.op.axis[1])
+                s[output].parallel(output_fused)
+
             Pool = OP.output(0)
             _parallel_sch(s[Pool], outs[0].shape)
         else: