# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name, unused-argument, too-many-arguments
+# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments
"""Backend compiler related feature registration"""
from __future__ import absolute_import
def _get_out_depth():  # output depth, derived from the constant shape of the weight (inputs[1])
weight_shape = get_const_tuple(inputs[1].shape)
+ # NHWC layout: kernel_layout starts with "HW", so the depth is axis 2 times axis 3
if kernel_layout.startswith("HW"):
return weight_shape[2] * weight_shape[3]
- return weight_shape[0] * weight_shape[1]
-
+ # NCHW layout: a plain 4-D weight gives the depth as axis 0 times axis 1.
+ # in ARM CPU contrib_spatial_pack schedule the weight is pre-packed to 5-D (C, M, _, _, VC)
if len(weight_shape) == 4:
return weight_shape[0] * weight_shape[1]
else:
assert len(weight_shape) == 5
C, M, _, _, VC = weight_shape
return C * VC * M
if groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding,
["depthwise_conv2d_nchw", [1, 512, 32, 32, "float32"], \
[512, 1, 3, 3, "float32"], [1, 1], [1, 1], [1, 1], "float32"], \
{"i": 743640, "t": "contrib_spatial_pack", "c": null, \
- "e": [["tile_co", "sp", [512, 1]], ["tile_oh", "sp", [8, 1]], \
+ "e": [["tile_co", "sp", [32, 16]], ["tile_oh", "sp", [8, 1]], \
["tile_ow", "sp", [1, 8]], \
["reorder_0", "re", [0, 1, 2, 3, 4, 5, 8, 6, 7]], \
["reorder_1", "re", [0, 1, 2, 3, 6, 4, 5]], \