s[C].parallel(parallel_axis)
s[CC].compute_at(s[C], ow_chunk)
+ # the ow axis in the cached block CC is the ow_block in C
_, ic_chunk, oh, ow, ic_block = s[CC].op.axis
kh, kw = s[CC].op.reduce_axis
- ow_chunk, ow_block = s[CC].split(ow, factor=tile_ow)
- s[CC].reorder(ic_chunk, oh, kh, kw, ow_block, ic_block)
+ s[CC].reorder(ic_chunk, oh, kh, kw, ow, ic_block)
s[CC].vectorize(ic_block)
- s[CC].unroll(ow_block)
+ s[CC].unroll(ow)
if C != O:
out_ndim = len(s[O].op.axis)