Add NHWC support to Resize Operator (#15553)
authorDavid Carrillo Cisneros <davidca@fb.com>
Wed, 9 Jan 2019 00:04:41 +0000 (16:04 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 9 Jan 2019 00:44:17 +0000 (16:44 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15553

Add unit test and implementation of NHWC layout for Resize operator.

Also, add pragma parallel loop to old NCHWC layout.

Reviewed By: jspark1105

Differential Revision: D13540762

fbshipit-source-id: eebf252bf0d1efdff180a171d804181045f100a5

caffe2/operators/resize_op.cc
caffe2/operators/resize_op.h
caffe2/python/operator_test/resize_op_test.py

index 2982c76..99c1f93 100644 (file)
@@ -10,7 +10,7 @@
 
 namespace caffe2 {
 
-void resizeNearest2x(
+void resizeNearestNCHW2x(
     int batch_size,
     int num_channels,
     int input_height,
@@ -59,7 +59,7 @@ void resizeNearest2x(
 }
 
 template <>
-bool ResizeNearestOp<float, CPUContext>::RunOnDevice() {
+bool ResizeNearestOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
   const auto& X = Input(0);
 
   const int batch_size = X.dim32(0),
@@ -87,7 +87,7 @@ bool ResizeNearestOp<float, CPUContext>::RunOnDevice() {
 
   // Specialized implementation for fast 2x upsampling
   if (width_scale_ == 2.0 && height_scale_ == 2.0) {
-    resizeNearest2x(
+    resizeNearestNCHW2x(
         batch_size, num_channels, input_height, input_width, Xdata, Ydata);
     return true;
   }
@@ -110,7 +110,66 @@ bool ResizeNearestOp<float, CPUContext>::RunOnDevice() {
 }
 
 template <>
-bool ResizeNearestGradientOp<float, CPUContext>::RunOnDevice() {
+bool ResizeNearestOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
+  const auto& X = Input(0);
+
+  const int batch_size = X.dim32(0), input_height = X.dim32(1),
+            input_width = X.dim32(2), num_channels = X.dim32(3);
+  if (InputSize() == 2) {
+    const auto& scales = Input(1);
+    CAFFE_ENFORCE_EQ(scales.dim(), 1);
+    CAFFE_ENFORCE_EQ(scales.numel(), 2);
+    const float* scales_data = scales.data<float>();
+    height_scale_ = scales_data[0];
+    width_scale_ = scales_data[1];
+  }
+
+  int output_width = input_width * width_scale_;
+  int output_height = input_height * height_scale_;
+
+  const int output_width_stride = output_width * num_channels;
+  const int input_width_stride = input_width * num_channels;
+
+  auto* Y = Output(
+      0,
+      {batch_size, output_height, output_width, num_channels},
+      at::dtype<float>());
+
+  const float* Xdata = X.data<float>();
+  float* Ydata = Y->template mutable_data<float>();
+
+  for (int n = 0; n < batch_size; ++n) {
+    for (int y = 0; y < output_height; ++y) {
+      const int in_y = std::min((int)(y / height_scale_), (input_height - 1));
+      for (int x = 0; x < output_width; ++x) {
+        const int in_x = std::min((int)(x / width_scale_), (input_width - 1));
+        std::memcpy(
+            &Ydata[output_width_stride * y + num_channels * x],
+            &Xdata[input_width_stride * in_y + num_channels * in_x],
+            num_channels * sizeof(float));
+      }
+    }
+    Xdata += input_height * input_width_stride;
+    Ydata += output_height * output_width_stride;
+  }
+
+  return true;
+}
+
+template <>
+bool ResizeNearestOp<float, CPUContext>::RunOnDevice() {
+  switch (order_) {
+    case StorageOrder::NHWC:
+      return RunOnDeviceWithOrderNHWC();
+    case StorageOrder::NCHW:
+      return RunOnDeviceWithOrderNCHW();
+    default:
+      CAFFE_THROW("Unknown Storage order: ", order_);
+  }
+}
+
+template <>
+bool ResizeNearestGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
   const auto& dY = Input(0);
   const auto& X = Input(1);
 
@@ -159,6 +218,72 @@ bool ResizeNearestGradientOp<float, CPUContext>::RunOnDevice() {
   return true;
 }
 
+template <>
+bool ResizeNearestGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
+  const auto& dY = Input(0);
+  const auto& X = Input(1);
+
+  const auto inputDims = dY.sizes();
+  CAFFE_ENFORCE_EQ(4, inputDims.size());
+  const int batch_size = dY.dim32(0), input_height = dY.dim32(1),
+            input_width = dY.dim32(2), num_channels = dY.dim32(3);
+  const int output_height = X.dim32(1);
+  const int output_width = X.dim32(2);
+  if (InputSize() == 3) {
+    const auto& scales = Input(2);
+    CAFFE_ENFORCE_EQ(scales.dim(), 1);
+    CAFFE_ENFORCE_EQ(scales.numel(), 2);
+    const float* scales_data = scales.data<float>();
+    height_scale_ = scales_data[0];
+    width_scale_ = scales_data[1];
+  }
+  auto* dX = Output(
+      0,
+      {batch_size, output_height, output_width, num_channels},
+      at::dtype<float>());
+  math::Set<float, CPUContext>(
+      dX->numel(), 0.0f, dX->template mutable_data<float>(), &context_);
+
+  const int output_width_stride = output_width * num_channels;
+  const int input_width_stride = input_width * num_channels;
+
+  const float* dYdata = dY.data<float>();
+  float* dXdata = dX->template mutable_data<float>();
+
+  for (int n = 0; n < batch_size; ++n) {
+    for (int y = 0; y < input_height; ++y) {
+      const int out_y = std::min((int)(y / height_scale_), (output_height - 1));
+      for (int x = 0; x < input_width; ++x) {
+        const int out_x = std::min((int)(x / width_scale_), (output_width - 1));
+
+        float* dXdata_c0 =
+            dXdata + output_width_stride * out_y + num_channels * out_x;
+        const float* dYdata_c0 =
+            dYdata + input_width_stride * y + num_channels * x;
+
+        for (int c = 0; c < num_channels; ++c) {
+          dXdata_c0[c] += dYdata_c0[c];
+        }
+      }
+    }
+    dYdata += input_height * input_width_stride;
+    dXdata += output_height * output_width_stride;
+  }
+
+  return true;
+}
+
+template <>
+bool ResizeNearestGradientOp<float, CPUContext>::RunOnDevice() {
+  switch (order_) {
+    case StorageOrder::NHWC:
+      return RunOnDeviceWithOrderNHWC();
+    case StorageOrder::NCHW:
+      return RunOnDeviceWithOrderNCHW();
+    default:
+      CAFFE_THROW("Unknown Storage order: ", order_);
+  }
+}
 REGISTER_CPU_OPERATOR(ResizeNearest, ResizeNearestOp<float, CPUContext>);
 REGISTER_CPU_GRADIENT_OPERATOR(
     ResizeNearestGradient,
index 5e1e8c6..32fad09 100644 (file)
@@ -10,7 +10,11 @@ template <typename T, class Context>
 class ResizeNearestOp final : public Operator<Context> {
  public:
   ResizeNearestOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws), width_scale_(1), height_scale_(1) {
+      : Operator<Context>(operator_def, ws),
+        width_scale_(1),
+        height_scale_(1),
+        order_(StringToStorageOrder(
+            this->template GetSingleArgument<std::string>("order", "NCHW"))) {
     if (HasArgument("width_scale")) {
       width_scale_ = static_cast<T>(
           this->template GetSingleArgument<float>("width_scale", 1));
@@ -22,21 +26,31 @@ class ResizeNearestOp final : public Operator<Context> {
 
     CAFFE_ENFORCE_GT(width_scale_, 0);
     CAFFE_ENFORCE_GT(height_scale_, 0);
+
+    CAFFE_ENFORCE(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC);
   }
+
   USE_OPERATOR_CONTEXT_FUNCTIONS;
 
   bool RunOnDevice() override;
+  bool RunOnDeviceWithOrderNCHW();
+  bool RunOnDeviceWithOrderNHWC();
 
  protected:
   T width_scale_;
   T height_scale_;
+  StorageOrder order_;
 };
 
 template <typename T, class Context>
 class ResizeNearestGradientOp final : public Operator<Context> {
  public:
   ResizeNearestGradientOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws), width_scale_(1), height_scale_(1) {
+      : Operator<Context>(operator_def, ws),
+        width_scale_(1),
+        height_scale_(1),
+        order_(StringToStorageOrder(
+            this->template GetSingleArgument<std::string>("order", "NCHW"))) {
     width_scale_ = static_cast<T>(
         this->template GetSingleArgument<float>("width_scale", 1));
     height_scale_ = static_cast<T>(
@@ -44,14 +58,20 @@ class ResizeNearestGradientOp final : public Operator<Context> {
 
     CAFFE_ENFORCE_GT(width_scale_, 0);
     CAFFE_ENFORCE_GT(height_scale_, 0);
+
+    CAFFE_ENFORCE(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC);
   }
+
   USE_OPERATOR_CONTEXT_FUNCTIONS;
 
   bool RunOnDevice() override;
+  bool RunOnDeviceWithOrderNCHW();
+  bool RunOnDeviceWithOrderNHWC();
 
  protected:
   T width_scale_;
   T height_scale_;
+  StorageOrder order_;
 };
 
 } // namespace caffe2
index c349ac1..acbc3b2 100644 (file)
@@ -7,7 +7,9 @@ import hypothesis.strategies as st
 import unittest
 import caffe2.python.hypothesis_test_util as hu
 from caffe2.python import core
+from caffe2.proto import caffe2_pb2
 from hypothesis import given
+from hypothesis import assume
 
 
 class TestResize(hu.HypothesisTestCase):
@@ -18,11 +20,17 @@ class TestResize(hu.HypothesisTestCase):
            num_channels=st.integers(1, 4),
            batch_size=st.integers(1, 4),
            seed=st.integers(0, 65535),
+           order=st.sampled_from(["NCHW", "NHWC"]),
            **hu.gcs)
     def test_nearest(self, height_scale, width_scale, height, width,
-                     num_channels, batch_size, seed,
+                     num_channels, batch_size, seed, order,
                      gc, dc):
 
+        assume(order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
+        # NHWC currently only supported for CPU. Ignore other devices.
+        if order == "NHWC":
+            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
+
         np.random.seed(seed)
         op = core.CreateOperator(
             "ResizeNearest",
@@ -30,10 +38,13 @@ class TestResize(hu.HypothesisTestCase):
             ["Y"],
             width_scale=width_scale,
             height_scale=height_scale,
+            order=order,
         )
 
         X = np.random.rand(
             batch_size, num_channels, height, width).astype(np.float32)
+        if order == "NHWC":
+            X = X.transpose([0, 2, 3, 1])
 
         def ref(X):
             output_height = np.int32(height * height_scale)
@@ -48,7 +59,10 @@ class TestResize(hu.HypothesisTestCase):
             input_w_idxs = np.minimum(
                 output_w_idxs / width_scale, width - 1).astype(np.int32)
 
-            Y = X[:, :, input_h_idxs, input_w_idxs]
+            if order == "NCHW":
+                Y = X[:, :, input_h_idxs, input_w_idxs]
+            else:
+                Y = X[:, input_h_idxs, input_w_idxs, :]
 
             return Y,
 
@@ -63,9 +77,15 @@ class TestResize(hu.HypothesisTestCase):
            num_channels=st.integers(1, 4),
            batch_size=st.integers(1, 4),
            seed=st.integers(0, 65535),
+           order=st.sampled_from(["NCHW", "NHWC"]),
            **hu.gcs)
     def test_nearest_grad(self, height_scale, width_scale, height, width,
-                          num_channels, batch_size, seed, gc, dc):
+                          num_channels, batch_size, seed, order, gc, dc):
+
+        assume(order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
+        # NHWC currently only supported for CPU. Ignore other devices.
+        if order == "NHWC":
+            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
 
         np.random.seed(seed)
 
@@ -75,10 +95,14 @@ class TestResize(hu.HypothesisTestCase):
                            num_channels,
                            height,
                            width).astype(np.float32)
+
         dY = np.random.rand(batch_size,
                             num_channels,
                             output_height,
                             output_width).astype(np.float32)
+        if order == "NHWC":
+            X = X.transpose([0, 2, 3, 1])
+            dY = dY.transpose([0, 2, 3, 1])
 
         op = core.CreateOperator(
             "ResizeNearestGradient",
@@ -86,6 +110,7 @@ class TestResize(hu.HypothesisTestCase):
             ["dX"],
             width_scale=width_scale,
             height_scale=height_scale,
+            order=order,
         )
 
         def ref(dY, X):
@@ -95,8 +120,10 @@ class TestResize(hu.HypothesisTestCase):
                 for j in range(output_width):
                     input_i = np.minimum(i / height_scale, height - 1).astype(np.int32)
                     input_j = np.minimum(j / width_scale, width - 1).astype(np.int32)
-                    dX[:, :, input_i, input_j] += dY[:, :, i, j]
-
+                    if order == "NCHW":
+                        dX[:, :, input_i, input_j] += dY[:, :, i, j]
+                    else:
+                        dX[:, input_i, input_j, :] += dY[:, i, j, :]
             return dX,
 
         self.assertDeviceChecks(dc, op, [dY, X], [0])
@@ -109,20 +136,30 @@ class TestResize(hu.HypothesisTestCase):
            num_channels=st.integers(1, 4),
            batch_size=st.integers(1, 4),
            seed=st.integers(0, 65535),
+           order=st.sampled_from(["NCHW", "NHWC"]),
            **hu.gcs)
     def test_nearest_onnx(self, height_scale, width_scale, height, width,
-                     num_channels, batch_size, seed,
-                     gc, dc):
+                          num_channels, batch_size, seed, order,
+                          gc, dc):
+
+        assume(order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
+        # NHWC currently only supported for CPU. Ignore other devices.
+        if order == "NHWC":
+            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
 
         np.random.seed(seed)
         op = core.CreateOperator(
             "ResizeNearest",
             ["X", "scales"],
             ["Y"],
+            order=order,
         )
 
         X = np.random.rand(
             batch_size, num_channels, height, width).astype(np.float32)
+        if order == "NHWC":
+            X = X.transpose([0, 2, 3, 1])
+
         scales = np.array([height_scale, width_scale]).astype(np.float32)
 
         def ref(X, scales):
@@ -138,7 +175,10 @@ class TestResize(hu.HypothesisTestCase):
             input_w_idxs = np.minimum(
                 output_w_idxs / scales[1], width - 1).astype(np.int32)
 
-            Y = X[:, :, input_h_idxs, input_w_idxs]
+            if order == "NCHW":
+                Y = X[:, :, input_h_idxs, input_w_idxs]
+            else:
+                Y = X[:, input_h_idxs, input_w_idxs, :]
 
             return Y,
 
@@ -154,9 +194,15 @@ class TestResize(hu.HypothesisTestCase):
            num_channels=st.integers(1, 4),
            batch_size=st.integers(1, 4),
            seed=st.integers(0, 65535),
+           order=st.sampled_from(["NCHW", "NHWC"]),
            **hu.gcs)
     def test_nearest_onnx_grad(self, height_scale, width_scale, height, width,
-                          num_channels, batch_size, seed, gc, dc):
+                               num_channels, batch_size, seed, order, gc, dc):
+
+        assume(order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
+        # NHWC currently only supported for CPU. Ignore other devices.
+        if order == "NHWC":
+            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
 
         np.random.seed(seed)
 
@@ -170,12 +216,17 @@ class TestResize(hu.HypothesisTestCase):
                             num_channels,
                             output_height,
                             output_width).astype(np.float32)
+        if order == "NHWC":
+            X = X.transpose([0, 2, 3, 1])
+            dY = dY.transpose([0, 2, 3, 1])
+
         scales = np.array([height_scale, width_scale]).astype(np.float32)
 
         op = core.CreateOperator(
             "ResizeNearestGradient",
             ["dY", "X", "scales"],
             ["dX"],
+            order=order,
         )
 
         def ref(dY, X, scales):
@@ -185,7 +236,11 @@ class TestResize(hu.HypothesisTestCase):
                 for j in range(output_width):
                     input_i = np.minimum(i / scales[0], height - 1).astype(np.int32)
                     input_j = np.minimum(j / scales[1], width - 1).astype(np.int32)
-                    dX[:, :, input_i, input_j] += dY[:, :, i, j]
+
+                    if order == "NCHW":
+                        dX[:, :, input_i, input_j] += dY[:, :, i, j]
+                    else:
+                        dX[:, input_i, input_j, :] += dY[:, i, j, :]
 
             return dX,