Support general permutation.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 20 Mar 2018 01:20:12 +0000 (18:20 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 01:24:30 +0000 (18:24 -0700)
PiperOrigin-RevId: 189675019

tensorflow/core/kernels/data_format_ops.cc
tensorflow/core/kernels/data_format_ops.h

index bea3af9..39ef8ee 100644 (file)
@@ -28,15 +28,6 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-namespace {
-inline functor::DataFormat FormatNameToEnum(const string& name) {
-  if (name == "NHWC") return functor::DataFormat::NHWC;
-  if (name == "NCHW") return functor::DataFormat::NCHW;
-  if (name == "HWNC") return functor::DataFormat::HWNC;
-  return functor::DataFormat::UNKNOWN;
-}
-}  // namespace
-
 template <typename Device, typename T>
 class DataFormatDimMapOp : public OpKernel {
  public:
@@ -76,17 +67,8 @@ class DataFormatVecPermuteOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
     string dst_format;
     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
-    OP_REQUIRES(context,
-                (src_format == "NHWC" && dst_format == "NCHW") ||
-                    (src_format == "NCHW" && dst_format == "NHWC") ||
-                    (src_format == "NHWC" && dst_format == "HWNC") ||
-                    (src_format == "HWNC" && dst_format == "NHWC"),
-                errors::InvalidArgument(strings::StrCat(
-                    "Current implementation only supports NHWC<->NCHW and "
-                    "NHWC<->HWNC conversion; got source format ",
-                    src_format, " and destination format ", dst_format)));
-    src_format_ = FormatNameToEnum(src_format);
-    dst_format_ = FormatNameToEnum(dst_format);
+    src_format_ = src_format;
+    dst_format_ = dst_format;
   }
 
   void Compute(OpKernelContext* context) override {
@@ -116,14 +98,34 @@ class DataFormatVecPermuteOp : public OpKernel {
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, input.shape(), &output));
-    functor::DataFormatVecPermute<Device, T>()(
-        context->eigen_device<Device>(), input.flat<T>(), output->flat<T>(),
-        src_format_, dst_format_);
+    // Support 1D and 2D cases.
+    Eigen::DSizes<Eigen::DenseIndex, 8> dst_idx;
+    ComputeDstIndex(input.dims(), &dst_idx);
+
+    functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(),
+                                               input.flat<T>(),
+                                               output->flat<T>(), dst_idx);
   }
 
  private:
-  functor::DataFormat src_format_;
-  functor::DataFormat dst_format_;
+  // Finds out the destination index. Support 1D and 2D cases.
+  // Example: HWNC --> NHWC
+  // 1D: dst = [1, 2, 0, 3],
+  // 2D: dst = [2, 3, 4, 5, 0, 1, 6, 7]
+  void ComputeDstIndex(int num_dim, Eigen::DSizes<Eigen::DenseIndex, 8>* dst) {
+    for (int i = 0; i < src_format_.size(); ++i) {
+      for (int j = 0; j < dst_format_.size(); ++j) {
+        if (dst_format_[j] != src_format_[i]) continue;
+        // Found the dst index. Set output based on the number of dims.
+        for (int k = 0; k < num_dim; ++k) {
+          (*dst)[i * num_dim + k] = j * num_dim + k;
+        }
+      }
+    }
+  }
+
+  string src_format_;
+  string dst_format_;
 };
 
 #define REGISTER_KERNEL(T)                                                \
@@ -156,12 +158,12 @@ TF_CALL_int32(DECLARE_GPU_SPECS);
 TF_CALL_int64(DECLARE_GPU_SPECS);
 #undef DECLARE_GPU_SPEC
 
-#define DECLARE_GPU_SPEC(T)                                   \
-  template <>                                                 \
-  void DataFormatVecPermute<GPUDevice, T>::operator()(        \
-      const GPUDevice& d, typename TTypes<T>::ConstFlat x,    \
-      typename TTypes<T>::Vec y, const DataFormat src_format, \
-      const DataFormat dst_format);                           \
+#define DECLARE_GPU_SPEC(T)                                \
+  template <>                                              \
+  void DataFormatVecPermute<GPUDevice, T>::operator()(     \
+      const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
+      typename TTypes<T>::Vec y,                           \
+      const Eigen::DSizes<Eigen::DenseIndex, 8>& dst_idx); \
   extern template struct DataFormatVecPermute<GPUDevice, T>;
 #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
 TF_CALL_int32(DECLARE_GPU_SPECS);
index d27415e..2ccc919 100644 (file)
@@ -23,13 +23,6 @@ limitations under the License.
 namespace tensorflow {
 namespace functor {
 
-enum class DataFormat {
-  UNKNOWN = 0,
-  NHWC,
-  NCHW,
-  HWNC,
-};
-
 // Functor used by DataFormatDimMapOP to do the computations.
 template <typename Device, typename T>
 struct DataFormatDimMap {
@@ -47,65 +40,8 @@ struct DataFormatDimMap {
 };
 
 template <typename T>
-struct VecPermuteNHWCToNCHW {
-  Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
-      typename TTypes<T>::ConstFlat input) const {
-    Eigen::DSizes<Eigen::DenseIndex, 1> result;
-    result[0] = input.dimension(0);
-    return result;
-  }
-  template <typename Output, typename Device>
-  void eval(typename TTypes<T>::ConstFlat input, Output& output,
-            const Device& d) const {
-    if (input.size() == 8) {
-      output.template chip<0>(0).device(d) = input.template chip<0>(0);
-      output.template chip<0>(1).device(d) = input.template chip<0>(1);
-      output.template chip<0>(2).device(d) = input.template chip<0>(6);
-      output.template chip<0>(3).device(d) = input.template chip<0>(7);
-      output.template chip<0>(4).device(d) = input.template chip<0>(2);
-      output.template chip<0>(5).device(d) = input.template chip<0>(3);
-      output.template chip<0>(6).device(d) = input.template chip<0>(4);
-      output.template chip<0>(7).device(d) = input.template chip<0>(5);
-    } else {
-      output.template chip<0>(0).device(d) = input.template chip<0>(0);
-      output.template chip<0>(1).device(d) = input.template chip<0>(3);
-      output.template chip<0>(2).device(d) = input.template chip<0>(1);
-      output.template chip<0>(3).device(d) = input.template chip<0>(2);
-    }
-  }
-};
-
-template <typename T>
-struct VecPermuteNCHWToNHWC {
-  Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
-      typename TTypes<T>::ConstFlat input) const {
-    Eigen::DSizes<Eigen::DenseIndex, 1> result;
-    result[0] = input.dimension(0);
-    return result;
-  }
-  template <typename Output, typename Device>
-  void eval(typename TTypes<T>::ConstFlat input, Output& output,
-            const Device& d) const {
-    if (input.size() == 8) {
-      output.template chip<0>(0).device(d) = input.template chip<0>(0);
-      output.template chip<0>(1).device(d) = input.template chip<0>(1);
-      output.template chip<0>(2).device(d) = input.template chip<0>(4);
-      output.template chip<0>(3).device(d) = input.template chip<0>(5);
-      output.template chip<0>(4).device(d) = input.template chip<0>(6);
-      output.template chip<0>(5).device(d) = input.template chip<0>(7);
-      output.template chip<0>(6).device(d) = input.template chip<0>(2);
-      output.template chip<0>(7).device(d) = input.template chip<0>(3);
-    } else {
-      output.template chip<0>(0).device(d) = input.template chip<0>(0);
-      output.template chip<0>(1).device(d) = input.template chip<0>(2);
-      output.template chip<0>(2).device(d) = input.template chip<0>(3);
-      output.template chip<0>(3).device(d) = input.template chip<0>(1);
-    }
-  }
-};
-
-template <typename T>
-struct VecPermuteNHWCToHWNC {
+struct VecPermute {
+  VecPermute(const Eigen::DSizes<Eigen::DenseIndex, 8>& dst) : dst_(dst) {}
   Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
       typename TTypes<T>::ConstFlat input) const {
     Eigen::DSizes<Eigen::DenseIndex, 1> result;
@@ -115,71 +51,22 @@ struct VecPermuteNHWCToHWNC {
   template <typename Output, typename Device>
   void eval(typename TTypes<T>::ConstFlat input, Output& output,
             const Device& d) const {
-    if (input.size() == 8) {
-      output.template chip<0>(0).device(d) = input.template chip<0>(2);
-      output.template chip<0>(1).device(d) = input.template chip<0>(3);
-      output.template chip<0>(2).device(d) = input.template chip<0>(4);
-      output.template chip<0>(3).device(d) = input.template chip<0>(5);
-      output.template chip<0>(4).device(d) = input.template chip<0>(0);
-      output.template chip<0>(5).device(d) = input.template chip<0>(1);
-      output.template chip<0>(6).device(d) = input.template chip<0>(6);
-      output.template chip<0>(7).device(d) = input.template chip<0>(7);
-    } else {
-      output.template chip<0>(0).device(d) = input.template chip<0>(1);
-      output.template chip<0>(1).device(d) = input.template chip<0>(2);
-      output.template chip<0>(2).device(d) = input.template chip<0>(0);
-      output.template chip<0>(3).device(d) = input.template chip<0>(3);
+    for (int i = 0; i < input.size(); ++i) {
+      output.template chip<0>(dst_[i]).device(d) = input.template chip<0>(i);
     }
   }
-};
 
-template <typename T>
-struct VecPermuteHWNCToNHWC {
-  Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
-      typename TTypes<T>::ConstFlat input) const {
-    Eigen::DSizes<Eigen::DenseIndex, 1> result;
-    result[0] = input.dimension(0);
-    return result;
-  }
-  template <typename Output, typename Device>
-  void eval(typename TTypes<T>::ConstFlat input, Output& output,
-            const Device& d) const {
-    if (input.size() == 8) {
-      output.template chip<0>(0).device(d) = input.template chip<0>(4);
-      output.template chip<0>(1).device(d) = input.template chip<0>(5);
-      output.template chip<0>(2).device(d) = input.template chip<0>(0);
-      output.template chip<0>(3).device(d) = input.template chip<0>(1);
-      output.template chip<0>(4).device(d) = input.template chip<0>(2);
-      output.template chip<0>(5).device(d) = input.template chip<0>(3);
-      output.template chip<0>(6).device(d) = input.template chip<0>(6);
-      output.template chip<0>(7).device(d) = input.template chip<0>(7);
-    } else {
-      output.template chip<0>(0).device(d) = input.template chip<0>(2);
-      output.template chip<0>(1).device(d) = input.template chip<0>(0);
-      output.template chip<0>(2).device(d) = input.template chip<0>(1);
-      output.template chip<0>(3).device(d) = input.template chip<0>(3);
-    }
-  }
+ private:
+  Eigen::DSizes<Eigen::DenseIndex, 8> dst_;
 };
 
 // Functor used by DataFormatVecPermuteOp to do the computations.
 template <typename Device, typename T>
 struct DataFormatVecPermute {
   void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
-                  typename TTypes<T>::Flat y, const DataFormat src_format,
-                  const DataFormat dst_format) {
-    if (src_format == DataFormat::NHWC && dst_format == DataFormat::NCHW) {
-      y.device(d) = x.customOp(VecPermuteNHWCToNCHW<T>());
-    } else if (src_format == DataFormat::NCHW &&
-               dst_format == DataFormat::NHWC) {
-      y.device(d) = x.customOp(VecPermuteNCHWToNHWC<T>());
-    } else if (src_format == DataFormat::NHWC &&
-               dst_format == DataFormat::HWNC) {
-      y.device(d) = x.customOp(VecPermuteNHWCToHWNC<T>());
-    } else if (src_format == DataFormat::HWNC &&
-               dst_format == DataFormat::NHWC) {
-      y.device(d) = x.customOp(VecPermuteHWNCToNHWC<T>());
-    }
+                  typename TTypes<T>::Flat y,
+                  const Eigen::DSizes<Eigen::DenseIndex, 8>& dst) {
+    y.device(d) = x.customOp(VecPermute<T>(dst));
   }
 };