[ARM CPU] Fix infer shape error of depthwise (#4384)
authorZhao Wu <wuzhaozju@gmail.com>
Wed, 27 Nov 2019 06:42:20 +0000 (14:42 +0800)
committerThierry Moreau <moreau@uw.edu>
Wed, 27 Nov 2019 06:42:20 +0000 (22:42 -0800)
* [ARM CPU] Fix contrib_spatial_pack error

* PyLint error fix

* diable no-else-return as other files

* Change the test case split OC not be 1 to cover 5D weight layout

python/tvm/relay/op/nn/_nn.py
tests/python/relay/test_op_level2.py

index cb2ecc9..08ef1bf 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-argument, too-many-arguments
+# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 
@@ -163,10 +163,17 @@ def compute_conv2d(attrs, inputs, out_type, target):
 
     def _get_out_depth():
         weight_shape = get_const_tuple(inputs[1].shape)
+        # NHWC layout
         if kernel_layout.startswith("HW"):
             return weight_shape[2] * weight_shape[3]
-        return weight_shape[0] * weight_shape[1]
-
+        # NCHW layout.
+        # in ARM CPU contrib_spatial_pack schedule, we will prepack weight layout
+        if len(weight_shape) == 4:
+            return weight_shape[0] * weight_shape[1]
+        else:
+            assert len(weight_shape) == 5
+            C, M, _, _, VC = weight_shape
+            return C * VC * M
     if groups == 1:
         out = topi.nn.conv2d(
             inputs[0], inputs[1], strides, padding,
index b54efaa..4099d19 100644 (file)
@@ -158,7 +158,7 @@ def test_conv2d_run():
                         ["depthwise_conv2d_nchw", [1, 512, 32, 32, "float32"], \
                         [512, 1, 3, 3, "float32"], [1, 1], [1, 1], [1, 1], "float32"], \
                         {"i": 743640, "t": "contrib_spatial_pack", "c": null, \
-                        "e": [["tile_co", "sp", [512, 1]], ["tile_oh", "sp", [8, 1]], \
+                        "e": [["tile_co", "sp", [32, 16]], ["tile_oh", "sp", [8, 1]], \
                         ["tile_ow", "sp", [1, 8]], \
                         ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 8, 6, 7]], \
                         ["reorder_1", "re", [0, 1, 2, 3, 6, 4, 5]], \