N = X.dims(0);
if (X.dims_size() == 5) {
// 3D convolution
- CAFFE_ENFORCE_EQ(order, StorageOrder::NCHW, "Conv3D only supports NCHW");
- Y_t = Y.dims(2);
- Y_h = Y.dims(3);
- Y_w = Y.dims(4);
- kernel_t = W.dims(2);
- kernel_h = W.dims(3);
- kernel_w = W.dims(4);
- in_channels = W.dims(1);
- out_channels = W.dims(0);
+ if (order == StorageOrder::NHWC) {
+ Y_t = Y.dims(1);
+ Y_h = Y.dims(2);
+ Y_w = Y.dims(3);
+ kernel_t = W.dims(1);
+ kernel_h = W.dims(2);
+ kernel_w = W.dims(3);
+ in_channels = W.dims(4);
+ out_channels = W.dims(0);
+ } else {
+ Y_t = Y.dims(2);
+ Y_h = Y.dims(3);
+ Y_w = Y.dims(4);
+ kernel_t = W.dims(2);
+ kernel_h = W.dims(3);
+ kernel_w = W.dims(4);
+ in_channels = W.dims(1);
+ out_channels = W.dims(0);
+ }
} else if (X.dims_size() == 4) {
// 2D convolution
CAFFE_ENFORCE_EQ(W.dims_size(), 4, "Conv2D should have 4D filter tensor");