Support NHWC <--> HWNC permutation.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 13 Mar 2018 04:12:52 +0000 (21:12 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Mar 2018 04:20:35 +0000 (21:20 -0700)
PiperOrigin-RevId: 188824058

tensorflow/core/kernels/data_format_ops.cc
tensorflow/core/kernels/data_format_ops.h
tensorflow/python/ops/nn_test.py

index fa67545..bea3af9 100644 (file)
@@ -28,6 +28,15 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
+namespace {
+inline functor::DataFormat FormatNameToEnum(const string& name) {
+  if (name == "NHWC") return functor::DataFormat::NHWC;
+  if (name == "NCHW") return functor::DataFormat::NCHW;
+  if (name == "HWNC") return functor::DataFormat::HWNC;
+  return functor::DataFormat::UNKNOWN;
+}
+}  // namespace
+
 template <typename Device, typename T>
 class DataFormatDimMapOp : public OpKernel {
  public:
@@ -69,12 +78,15 @@ class DataFormatVecPermuteOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
     OP_REQUIRES(context,
                 (src_format == "NHWC" && dst_format == "NCHW") ||
-                    (src_format == "NCHW" && dst_format == "NHWC"),
+                    (src_format == "NCHW" && dst_format == "NHWC") ||
+                    (src_format == "NHWC" && dst_format == "HWNC") ||
+                    (src_format == "HWNC" && dst_format == "NHWC"),
                 errors::InvalidArgument(strings::StrCat(
-                    "Current implementation only supports NCHW-to-NHWC and "
-                    "NHWC-to-NCHW format conversion; got source format ",
+                    "Current implementation only supports NHWC<->NCHW and "
+                    "NHWC<->HWNC conversion; got source format ",
                     src_format, " and destination format ", dst_format)));
-    nhwc_to_nchw_ = (src_format == "NHWC") ? true : false;
+    src_format_ = FormatNameToEnum(src_format);
+    dst_format_ = FormatNameToEnum(dst_format);
   }
 
   void Compute(OpKernelContext* context) override {
@@ -106,11 +118,12 @@ class DataFormatVecPermuteOp : public OpKernel {
                    context->allocate_output(0, input.shape(), &output));
     functor::DataFormatVecPermute<Device, T>()(
         context->eigen_device<Device>(), input.flat<T>(), output->flat<T>(),
-        nhwc_to_nchw_);
+        src_format_, dst_format_);
   }
 
  private:
-  bool nhwc_to_nchw_;
+  functor::DataFormat src_format_;
+  functor::DataFormat dst_format_;
 };
 
 #define REGISTER_KERNEL(T)                                                \
@@ -143,11 +156,12 @@ TF_CALL_int32(DECLARE_GPU_SPECS);
 TF_CALL_int64(DECLARE_GPU_SPECS);
 #undef DECLARE_GPU_SPEC
 
-#define DECLARE_GPU_SPEC(T)                                \
-  template <>                                              \
-  void DataFormatVecPermute<GPUDevice, T>::operator()(     \
-      const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
-      typename TTypes<T>::Vec y, bool nhwc_to_nchw);       \
+#define DECLARE_GPU_SPEC(T)                                   \
+  template <>                                                 \
+  void DataFormatVecPermute<GPUDevice, T>::operator()(        \
+      const GPUDevice& d, typename TTypes<T>::ConstFlat x,    \
+      typename TTypes<T>::Vec y, const DataFormat src_format, \
+      const DataFormat dst_format);                           \
   extern template struct DataFormatVecPermute<GPUDevice, T>;
 #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
 TF_CALL_int32(DECLARE_GPU_SPECS);
index bf704cc..d27415e 100644 (file)
@@ -23,6 +23,13 @@ limitations under the License.
 namespace tensorflow {
 namespace functor {
 
+enum class DataFormat {
+  UNKNOWN = 0,
+  NHWC,
+  NCHW,
+  HWNC,
+};
+
 // Functor used by DataFormatDimMapOP to do the computations.
 template <typename Device, typename T>
 struct DataFormatDimMap {
@@ -97,15 +104,81 @@ struct VecPermuteNCHWToNHWC {
   }
 };
 
+template <typename T>
+struct VecPermuteNHWCToHWNC {
+  Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
+      typename TTypes<T>::ConstFlat input) const {
+    Eigen::DSizes<Eigen::DenseIndex, 1> result;
+    result[0] = input.dimension(0);
+    return result;
+  }
+  template <typename Output, typename Device>
+  void eval(typename TTypes<T>::ConstFlat input, Output& output,
+            const Device& d) const {
+    if (input.size() == 8) {
+      output.template chip<0>(0).device(d) = input.template chip<0>(2);
+      output.template chip<0>(1).device(d) = input.template chip<0>(3);
+      output.template chip<0>(2).device(d) = input.template chip<0>(4);
+      output.template chip<0>(3).device(d) = input.template chip<0>(5);
+      output.template chip<0>(4).device(d) = input.template chip<0>(0);
+      output.template chip<0>(5).device(d) = input.template chip<0>(1);
+      output.template chip<0>(6).device(d) = input.template chip<0>(6);
+      output.template chip<0>(7).device(d) = input.template chip<0>(7);
+    } else {
+      output.template chip<0>(0).device(d) = input.template chip<0>(1);
+      output.template chip<0>(1).device(d) = input.template chip<0>(2);
+      output.template chip<0>(2).device(d) = input.template chip<0>(0);
+      output.template chip<0>(3).device(d) = input.template chip<0>(3);
+    }
+  }
+};
+
+template <typename T>
+struct VecPermuteHWNCToNHWC {
+  Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
+      typename TTypes<T>::ConstFlat input) const {
+    Eigen::DSizes<Eigen::DenseIndex, 1> result;
+    result[0] = input.dimension(0);
+    return result;
+  }
+  template <typename Output, typename Device>
+  void eval(typename TTypes<T>::ConstFlat input, Output& output,
+            const Device& d) const {
+    if (input.size() == 8) {
+      output.template chip<0>(0).device(d) = input.template chip<0>(4);
+      output.template chip<0>(1).device(d) = input.template chip<0>(5);
+      output.template chip<0>(2).device(d) = input.template chip<0>(0);
+      output.template chip<0>(3).device(d) = input.template chip<0>(1);
+      output.template chip<0>(4).device(d) = input.template chip<0>(2);
+      output.template chip<0>(5).device(d) = input.template chip<0>(3);
+      output.template chip<0>(6).device(d) = input.template chip<0>(6);
+      output.template chip<0>(7).device(d) = input.template chip<0>(7);
+    } else {
+      output.template chip<0>(0).device(d) = input.template chip<0>(2);
+      output.template chip<0>(1).device(d) = input.template chip<0>(0);
+      output.template chip<0>(2).device(d) = input.template chip<0>(1);
+      output.template chip<0>(3).device(d) = input.template chip<0>(3);
+    }
+  }
+};
+
 // Functor used by DataFormatVecPermuteOp to do the computations.
 template <typename Device, typename T>
 struct DataFormatVecPermute {
   void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
-                  typename TTypes<T>::Flat y, bool nhwc_to_nchw) {
-    if (nhwc_to_nchw) {
+                  typename TTypes<T>::Flat y, const DataFormat src_format,
+                  const DataFormat dst_format) {
+    if (src_format == DataFormat::NHWC && dst_format == DataFormat::NCHW) {
       y.device(d) = x.customOp(VecPermuteNHWCToNCHW<T>());
-    } else {
+    } else if (src_format == DataFormat::NCHW &&
+               dst_format == DataFormat::NHWC) {
       y.device(d) = x.customOp(VecPermuteNCHWToNHWC<T>());
+    } else if (src_format == DataFormat::NHWC &&
+               dst_format == DataFormat::HWNC) {
+      y.device(d) = x.customOp(VecPermuteNHWCToHWNC<T>());
+    } else if (src_format == DataFormat::HWNC &&
+               dst_format == DataFormat::NHWC) {
+      y.device(d) = x.customOp(VecPermuteHWNCToNHWC<T>());
     }
   }
 };
index 21eea3d..af9dae2 100644 (file)
@@ -1049,6 +1049,22 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = sess.run(y)
       self.assertAllEqual(y_val, [7, 9, 3, 4])
 
+  def testNHWCToHWNC(self):
+    x_val = [7, 4, 9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC")
+    with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+      y_val = sess.run(y)
+      self.assertAllEqual(y_val, [4, 9, 7, 3])
+
+  def testHWNCToNHWC(self):
+    x_val = [7, 4, 9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC")
+    with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+      y_val = sess.run(y)
+      self.assertAllEqual(y_val, [9, 7, 4, 3])
+
   def testNHWCToNCHW2D(self):
     x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
     x = constant_op.constant(x_val)
@@ -1057,6 +1073,22 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = sess.run(y)
       self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]])
 
+  def testNHWCToHWNC2D(self):
+    x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC")
+    with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+      y_val = sess.run(y)
+      self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]])
+
+  def testHWNCToNHWC2D(self):
+    x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC")
+    with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+      y_val = sess.run(y)
+      self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]])
+
   def testNCHWToNHWC2D(self):
     x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
     x = constant_op.constant(x_val)