return tvm::max(x(indices), { dheight, dwidth }); // NOLINT(*)
}, "tensor", "adaptive_pool_max");
} else if (pool_type == kAvgPool) {
- return tvm::compute(out_shape, [&](const Array<Var>& output) {
+ auto pool_sum = tvm::compute(out_shape, [&](const Array<Var>& output) {
Array<Expr> indices;
for (const Var& var : output) indices.push_back(var);
auto i_start_h = start_index(output[height_axis], out_height, height);
auto dwidth = tvm::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
indices.Set(height_axis, i_start_h + dheight);
indices.Set(width_axis, i_start_w + dwidth);
- return tvm::sum(div(x(indices), divide_factor), { dheight, dwidth });
- }, "tensor", "adaptive_pool_avg");
+ return tvm::sum(x(indices), { dheight, dwidth });
+ }, "tensor", "adaptive_pool_sum");
+
+ return tvm::compute(out_shape, [&](const Array<Var>& output) {
+ Array<Expr> indices;
+ for (const Var& var : output) indices.push_back(var);
+ auto i_start_h = start_index(output[height_axis], out_height, height);
+ auto i_end_h = end_index(output[height_axis], out_height, height);
+ auto i_start_w = start_index(output[width_axis], out_width, width);
+ auto i_end_w = end_index(output[width_axis], out_width, width);
+ auto divide_factor = tvm::cast(x->dtype, (i_end_h - i_start_h)
+ * (i_end_w - i_start_w));
+ return div(pool_sum(indices), divide_factor);
+ }, "tensor", kElementWise);
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return x;