// Convenience aliases for the Eigen device types the kernels below are
// specialized on (C++11 `using` form of the original typedefs).
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
+namespace {
+inline functor::DataFormat FormatNameToEnum(const string& name) {
+ if (name == "NHWC") return functor::DataFormat::NHWC;
+ if (name == "NCHW") return functor::DataFormat::NCHW;
+ if (name == "HWNC") return functor::DataFormat::HWNC;
+ return functor::DataFormat::UNKNOWN;
+}
+} // namespace
+
template <typename Device, typename T>
class DataFormatDimMapOp : public OpKernel {
public:
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
OP_REQUIRES(context,
(src_format == "NHWC" && dst_format == "NCHW") ||
- (src_format == "NCHW" && dst_format == "NHWC"),
+ (src_format == "NCHW" && dst_format == "NHWC") ||
+ (src_format == "NHWC" && dst_format == "HWNC") ||
+ (src_format == "HWNC" && dst_format == "NHWC"),
errors::InvalidArgument(strings::StrCat(
- "Current implementation only supports NCHW-to-NHWC and "
- "NHWC-to-NCHW format conversion; got source format ",
+ "Current implementation only supports NHWC<->NCHW and "
+ "NHWC<->HWNC conversion; got source format ",
src_format, " and destination format ", dst_format)));
- nhwc_to_nchw_ = (src_format == "NHWC") ? true : false;
+ src_format_ = FormatNameToEnum(src_format);
+ dst_format_ = FormatNameToEnum(dst_format);
}
void Compute(OpKernelContext* context) override {
context->allocate_output(0, input.shape(), &output));
functor::DataFormatVecPermute<Device, T>()(
context->eigen_device<Device>(), input.flat<T>(), output->flat<T>(),
- nhwc_to_nchw_);
+ src_format_, dst_format_);
}
private:
- bool nhwc_to_nchw_;
+ functor::DataFormat src_format_;
+ functor::DataFormat dst_format_;
};
// NOTE(review): this region is a garbled diff fragment — the REGISTER_KERNEL
// body is missing, hunks appear out of order, and the old (`-`) and new (`+`)
// DECLARE_GPU_SPEC definitions are interleaved. The `+` version (taking
// src_format/dst_format enums instead of bool nhwc_to_nchw) is the intended
// one; restore the surrounding file before compiling.
#define REGISTER_KERNEL(T) \
TF_CALL_int64(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void DataFormatVecPermute<GPUDevice, T>::operator()( \
- const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
- typename TTypes<T>::Vec y, bool nhwc_to_nchw); \
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void DataFormatVecPermute<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
+ typename TTypes<T>::Vec y, const DataFormat src_format, \
+ const DataFormat dst_format); \
 extern template struct DataFormatVecPermute<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
namespace tensorflow {
namespace functor {
// Tensor data layouts understood by the data-format permute functors below.
// UNKNOWN (0) is the sentinel returned for unrecognized format strings.
enum class DataFormat {
  UNKNOWN = 0,
  NHWC,
  NCHW,
  HWNC,
};
+
// Functor used by DataFormatDimMapOP to do the computations.
// NOTE(review): this declaration is truncated in this chunk — the functor
// body (its operator()) is missing and the closing braces below (`}` then
// `};`) do not balance. Restore the full definition before compiling.
template <typename Device, typename T>
struct DataFormatDimMap {
}
};
+template <typename T>
+struct VecPermuteNHWCToHWNC {
+ Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
+ typename TTypes<T>::ConstFlat input) const {
+ Eigen::DSizes<Eigen::DenseIndex, 1> result;
+ result[0] = input.dimension(0);
+ return result;
+ }
+ template <typename Output, typename Device>
+ void eval(typename TTypes<T>::ConstFlat input, Output& output,
+ const Device& d) const {
+ if (input.size() == 8) {
+ output.template chip<0>(0).device(d) = input.template chip<0>(2);
+ output.template chip<0>(1).device(d) = input.template chip<0>(3);
+ output.template chip<0>(2).device(d) = input.template chip<0>(4);
+ output.template chip<0>(3).device(d) = input.template chip<0>(5);
+ output.template chip<0>(4).device(d) = input.template chip<0>(0);
+ output.template chip<0>(5).device(d) = input.template chip<0>(1);
+ output.template chip<0>(6).device(d) = input.template chip<0>(6);
+ output.template chip<0>(7).device(d) = input.template chip<0>(7);
+ } else {
+ output.template chip<0>(0).device(d) = input.template chip<0>(1);
+ output.template chip<0>(1).device(d) = input.template chip<0>(2);
+ output.template chip<0>(2).device(d) = input.template chip<0>(0);
+ output.template chip<0>(3).device(d) = input.template chip<0>(3);
+ }
+ }
+};
+
+template <typename T>
+struct VecPermuteHWNCToNHWC {
+ Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
+ typename TTypes<T>::ConstFlat input) const {
+ Eigen::DSizes<Eigen::DenseIndex, 1> result;
+ result[0] = input.dimension(0);
+ return result;
+ }
+ template <typename Output, typename Device>
+ void eval(typename TTypes<T>::ConstFlat input, Output& output,
+ const Device& d) const {
+ if (input.size() == 8) {
+ output.template chip<0>(0).device(d) = input.template chip<0>(4);
+ output.template chip<0>(1).device(d) = input.template chip<0>(5);
+ output.template chip<0>(2).device(d) = input.template chip<0>(0);
+ output.template chip<0>(3).device(d) = input.template chip<0>(1);
+ output.template chip<0>(4).device(d) = input.template chip<0>(2);
+ output.template chip<0>(5).device(d) = input.template chip<0>(3);
+ output.template chip<0>(6).device(d) = input.template chip<0>(6);
+ output.template chip<0>(7).device(d) = input.template chip<0>(7);
+ } else {
+ output.template chip<0>(0).device(d) = input.template chip<0>(2);
+ output.template chip<0>(1).device(d) = input.template chip<0>(0);
+ output.template chip<0>(2).device(d) = input.template chip<0>(1);
+ output.template chip<0>(3).device(d) = input.template chip<0>(3);
+ }
+ }
+};
+
// Functor used by DataFormatVecPermuteOp to do the computations.
template <typename Device, typename T>
struct DataFormatVecPermute {
void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
- typename TTypes<T>::Flat y, bool nhwc_to_nchw) {
- if (nhwc_to_nchw) {
+ typename TTypes<T>::Flat y, const DataFormat src_format,
+ const DataFormat dst_format) {
+ if (src_format == DataFormat::NHWC && dst_format == DataFormat::NCHW) {
y.device(d) = x.customOp(VecPermuteNHWCToNCHW<T>());
- } else {
+ } else if (src_format == DataFormat::NCHW &&
+ dst_format == DataFormat::NHWC) {
y.device(d) = x.customOp(VecPermuteNCHWToNHWC<T>());
+ } else if (src_format == DataFormat::NHWC &&
+ dst_format == DataFormat::HWNC) {
+ y.device(d) = x.customOp(VecPermuteNHWCToHWNC<T>());
+ } else if (src_format == DataFormat::HWNC &&
+ dst_format == DataFormat::NHWC) {
+ y.device(d) = x.customOp(VecPermuteHWNCToNHWC<T>());
}
}
};
# NOTE(review): diff-hunk tail of a test method whose `def` line is not in
# this chunk; the expected value [7, 9, 3, 4] is consistent with a 1-D
# NCHW->NHWC permute of [7, 4, 9, 3] — confirm against the full file.
y_val = sess.run(y)
self.assertAllEqual(y_val, [7, 9, 3, 4])
+ def testNHWCToHWNC(self):
+ x_val = [7, 4, 9, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, [4, 9, 7, 3])
+
+ def testHWNCToNHWC(self):
+ x_val = [7, 4, 9, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, [9, 7, 4, 3])
+
  def testNHWCToNCHW2D(self):
    x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
    x = constant_op.constant(x_val)
    # NOTE(review): diff-hunk gap — the data_format_vec_permute call and the
    # session setup lines are missing from this chunk; `y` and `sess` below
    # are undefined as shown. Expected value matches NHWC->NCHW row reorder.
    y_val = sess.run(y)
    self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]])
+ def testNHWCToHWNC2D(self):
+ x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]])
+
+ def testHWNCToNHWC2D(self):
+ x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]])
+
  def testNCHWToNHWC2D(self):
    x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
    x = constant_op.constant(x_val)
    # NOTE(review): method body continues past the end of this chunk.