Add NHWC order support in the cost inference function of 3d conv (#19170)
authorSummer Deng <summerdeng@fb.com>
Mon, 15 Apr 2019 23:43:58 +0000 (16:43 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 15 Apr 2019 23:47:22 +0000 (16:47 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19170

As title
The quantized resnext3d model in production got the following failures without the fix:

```
 Caffe2 operator Int8ConvRelu logging error: [enforce fail at conv_pool_op_base.h:463] order == StorageOrder::NCHW. 1 vs 2. Conv3D only supports NCHW on the production quantized model
```

Reviewed By: jspark1105

Differential Revision: D14894276

fbshipit-source-id: ef97772277f322ed45215e382c3b4a3702e47e59

caffe2/operators/conv_pool_op_base.h

index 5fb07d2..c1dd289 100644 (file)
@@ -460,15 +460,25 @@ class ConvPoolOpBase : public Operator<Context> {
     N = X.dims(0);
     if (X.dims_size() == 5) {
       // 3D convolution
-      CAFFE_ENFORCE_EQ(order, StorageOrder::NCHW, "Conv3D only supports NCHW");
-      Y_t = Y.dims(2);
-      Y_h = Y.dims(3);
-      Y_w = Y.dims(4);
-      kernel_t = W.dims(2);
-      kernel_h = W.dims(3);
-      kernel_w = W.dims(4);
-      in_channels = W.dims(1);
-      out_channels = W.dims(0);
+      if (order == StorageOrder::NHWC) {
+        Y_t = Y.dims(1);
+        Y_h = Y.dims(2);
+        Y_w = Y.dims(3);
+        kernel_t = W.dims(1);
+        kernel_h = W.dims(2);
+        kernel_w = W.dims(3);
+        in_channels = W.dims(4);
+        out_channels = W.dims(0);
+      } else {
+        Y_t = Y.dims(2);
+        Y_h = Y.dims(3);
+        Y_w = Y.dims(4);
+        kernel_t = W.dims(2);
+        kernel_h = W.dims(3);
+        kernel_w = W.dims(4);
+        in_channels = W.dims(1);
+        out_channels = W.dims(0);
+      }
     } else if (X.dims_size() == 4) {
       // 2D convolution
       CAFFE_ENFORCE_EQ(W.dims_size(), 4, "Conv2D should have 4D filter tensor");