Support permutation from NCHW to NHWC.
authorYao Zhang <yaozhang@google.com>
Tue, 12 Dec 2017 23:29:16 +0000 (15:29 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 12 Dec 2017 23:32:56 +0000 (15:32 -0800)
PiperOrigin-RevId: 178824999

tensorflow/core/kernels/data_format_ops.cc
tensorflow/core/kernels/data_format_ops.h
tensorflow/python/ops/nn_test.py

index 0d427eddf3f7cca928bbb04427f6c53765eaa70f..a6ac1190027783a7f091ad11ee98d7612c318d9e 100644 (file)
@@ -72,15 +72,14 @@ class DataFormatVecPermuteOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
     string dst_format;
     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
-    OP_REQUIRES(
-        context, src_format == "NHWC",
-        errors::InvalidArgument(strings::StrCat(
-            "Current implementation doesn't support source data format ",
-            src_format)));
-    OP_REQUIRES(context, dst_format == "NCHW",
+    OP_REQUIRES(context,
+                (src_format == "NHWC" && dst_format == "NCHW") ||
+                    (src_format == "NCHW" && dst_format == "NHWC"),
                 errors::InvalidArgument(strings::StrCat(
-                    "Current implementation doesn't support dst data format ",
-                    dst_format)));
+                    "Current implementation only supports NCHW-to-NHWC and "
+                    "NHWC-to-NCHW format conversion; got source format ",
+                    src_format, " and destination format ", dst_format)));
+    nhwc_to_nchw_ = (src_format == "NHWC") ? true : false;
   }
 
   void Compute(OpKernelContext* context) override {
@@ -96,9 +95,13 @@ class DataFormatVecPermuteOp : public OpKernel {
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, input.shape(), &output));
-    functor::DataFormatVecPermute<Device, T>()(
-        context->eigen_device<Device>(), input.vec<T>(), output->vec<T>());
+    functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(),
+                                               input.vec<T>(), output->vec<T>(),
+                                               nhwc_to_nchw_);
   }
+
+ private:
+  bool nhwc_to_nchw_;
 };
 
 #define REGISTER_KERNEL(T)                                                \
@@ -135,7 +138,7 @@ TF_CALL_int64(DECLARE_GPU_SPECS);
   template <>                                             \
   void DataFormatVecPermute<GPUDevice, T>::operator()(    \
       const GPUDevice& d, typename TTypes<T>::ConstVec x, \
-      typename TTypes<T>::Vec y);                         \
+      typename TTypes<T>::Vec y, bool nhwc_to_nchw);      \
   extern template struct DataFormatVecPermute<GPUDevice, T>;
 #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
 TF_CALL_int32(DECLARE_GPU_SPECS);
index 54798cc6abf345f192735beb0fe9d484130ed3f6..6a25823c73d78cd799db48e59d6342d0c3328f5c 100644 (file)
@@ -40,7 +40,7 @@ struct DataFormatDimMap {
 };
 
 template <typename T>
-struct VecPermute {
+struct VecPermuteNHWCToNCHW {
   Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
       typename TTypes<T>::ConstVec input) const {
     Eigen::DSizes<Eigen::DenseIndex, 1> result;
@@ -57,12 +57,34 @@ struct VecPermute {
   }
 };
 
+template <typename T>
+struct VecPermuteNCHWToNHWC {
+  Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
+      typename TTypes<T>::ConstVec input) const {
+    Eigen::DSizes<Eigen::DenseIndex, 1> result;
+    result[0] = input.dimension(0);
+    return result;
+  }
+  template <typename Output, typename Device>
+  void eval(typename TTypes<T>::ConstVec input, Output& output,
+            const Device& d) const {
+    output.template chip<0>(0).device(d) = input.template chip<0>(0);
+    output.template chip<0>(1).device(d) = input.template chip<0>(2);
+    output.template chip<0>(2).device(d) = input.template chip<0>(3);
+    output.template chip<0>(3).device(d) = input.template chip<0>(1);
+  }
+};
+
 // Functor used by DataFormatVecPermuteOp to do the computations.
 template <typename Device, typename T>
 struct DataFormatVecPermute {
   void operator()(const Device& d, typename TTypes<T>::ConstVec x,
-                  typename TTypes<T>::Vec y) {
-    y.device(d) = x.customOp(VecPermute<T>());
+                  typename TTypes<T>::Vec y, bool nhwc_to_nchw) {
+    if (nhwc_to_nchw) {
+      y.device(d) = x.customOp(VecPermuteNHWCToNCHW<T>());
+    } else {
+      y.device(d) = x.customOp(VecPermuteNCHWToNHWC<T>());
+    }
   }
 };
 
index 8dfd0740bb180d1ac676a7ca353a27fd63b84846..38b8430996bb54783f2c34d0f76d149673f9fb48 100644 (file)
@@ -975,7 +975,7 @@ class DataFormatDimMapTest(test_lib.TestCase):
 
 class DataFormatVectorPermuteTest(test_lib.TestCase):
 
-  def test(self):
+  def testNHWCToNCHW(self):
     x_val = [7, 4, 9, 3]
     x = constant_op.constant(x_val)
     y = nn_ops.data_format_vec_permute(x)
@@ -983,6 +983,14 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = sess.run(y)
       self.assertAllEqual(y_val, [7, 3, 4, 9])
 
+  def testNCHWToNHWC(self):
+    x_val = [7, 4, 9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC")
+    with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+      y_val = sess.run(y)
+      self.assertAllEqual(y_val, [7, 9, 3, 4])
+
 
 if __name__ == "__main__":
   test_lib.main()