From: A. Unique TensorFlower
Date: Tue, 13 Mar 2018 04:12:52 +0000 (-0700)
Subject: Support NHWC <--> HWNC permutation.
X-Git-Tag: tflite-v0.1.7~201^2~1^2~15
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=cc6a2feea65fb499ec11264a7fe7e81100c810fa;p=platform%2Fupstream%2Ftensorflow.git

Support NHWC <--> HWNC permutation.

PiperOrigin-RevId: 188824058
---

diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc
index fa67545..bea3af9 100644
--- a/tensorflow/core/kernels/data_format_ops.cc
+++ b/tensorflow/core/kernels/data_format_ops.cc
@@ -28,6 +28,15 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
+namespace {
+inline functor::DataFormat FormatNameToEnum(const string& name) {
+  if (name == "NHWC") return functor::DataFormat::NHWC;
+  if (name == "NCHW") return functor::DataFormat::NCHW;
+  if (name == "HWNC") return functor::DataFormat::HWNC;
+  return functor::DataFormat::UNKNOWN;
+}
+}  // namespace
+
 template <typename Device, typename T>
 class DataFormatDimMapOp : public OpKernel {
  public:
@@ -69,12 +78,15 @@ class DataFormatVecPermuteOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
     OP_REQUIRES(context,
                 (src_format == "NHWC" && dst_format == "NCHW") ||
-                    (src_format == "NCHW" && dst_format == "NHWC"),
+                    (src_format == "NCHW" && dst_format == "NHWC") ||
+                    (src_format == "NHWC" && dst_format == "HWNC") ||
+                    (src_format == "HWNC" && dst_format == "NHWC"),
                 errors::InvalidArgument(strings::StrCat(
-                    "Current implementation only supports NCHW-to-NHWC and "
-                    "NHWC-to-NCHW format conversion; got source format ",
+                    "Current implementation only supports NHWC<->NCHW and "
+                    "NHWC<->HWNC conversion; got source format ",
                     src_format, " and destination format ", dst_format)));
-    nhwc_to_nchw_ = (src_format == "NHWC") ? true : false;
+    src_format_ = FormatNameToEnum(src_format);
+    dst_format_ = FormatNameToEnum(dst_format);
   }
 
   void Compute(OpKernelContext* context) override {
@@ -106,11 +118,12 @@ class DataFormatVecPermuteOp : public OpKernel {
                    context->allocate_output(0, input.shape(), &output));
     functor::DataFormatVecPermute<Device, T>()(
         context->eigen_device<Device>(), input.flat<T>(), output->flat<T>(),
-        nhwc_to_nchw_);
+        src_format_, dst_format_);
   }
 
  private:
-  bool nhwc_to_nchw_;
+  functor::DataFormat src_format_;
+  functor::DataFormat dst_format_;
 };
 
 #define REGISTER_KERNEL(T) \
@@ -143,11 +156,12 @@ TF_CALL_int32(DECLARE_GPU_SPECS);
 TF_CALL_int64(DECLARE_GPU_SPECS);
 #undef DECLARE_GPU_SPEC
 
-#define DECLARE_GPU_SPEC(T)                                   \
-  template <>                                                 \
-  void DataFormatVecPermute<GPUDevice, T>::operator()(        \
-      const GPUDevice& d, typename TTypes<T>::ConstFlat x,    \
-      typename TTypes<T>::Vec y, bool nhwc_to_nchw);          \
+#define DECLARE_GPU_SPEC(T)                                        \
+  template <>                                                      \
+  void DataFormatVecPermute<GPUDevice, T>::operator()(             \
+      const GPUDevice& d, typename TTypes<T>::ConstFlat x,         \
+      typename TTypes<T>::Vec y, const DataFormat src_format,      \
+      const DataFormat dst_format);                                \
   extern template struct DataFormatVecPermute<GPUDevice, T>;
 #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
 TF_CALL_int32(DECLARE_GPU_SPECS);
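As a reading aid for the dispatch above: the op permutes a per-axis vector by matching axis letters between the source and destination format strings, and the kernel simply validates the supported pairs and records both formats as enums. The following standalone C++ sketch (illustrative only, not part of the patch; the helper name Permute is made up) derives the same index mapping generically and reproduces the newly supported NHWC -> HWNC case:

// Standalone sketch: derive the vector permutation by matching the axis
// characters of two format strings. Destination position d takes the value
// of the source position that holds the same letter.
#include <array>
#include <cstdio>
#include <string>

std::array<int, 4> Permute(const std::string& src, const std::string& dst,
                           const std::array<int, 4>& x) {
  std::array<int, 4> y{};
  for (int d = 0; d < 4; ++d) {
    y[d] = x[src.find(dst[d])];
  }
  return y;
}

int main() {
  const std::array<int, 4> x = {7, 4, 9, 3};  // values in NHWC order
  const std::array<int, 4> y = Permute("NHWC", "HWNC", x);
  std::printf("%d %d %d %d\n", y[0], y[1], y[2], y[3]);  // prints: 4 9 7 3
  return 0;
}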
diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h
index bf704cc..d27415e 100644
--- a/tensorflow/core/kernels/data_format_ops.h
+++ b/tensorflow/core/kernels/data_format_ops.h
@@ -23,6 +23,13 @@ limitations under the License.
 namespace tensorflow {
 namespace functor {
 
+enum class DataFormat {
+  UNKNOWN = 0,
+  NHWC,
+  NCHW,
+  HWNC,
+};
+
 // Functor used by DataFormatDimMapOP to do the computations.
 template <typename Device, typename T>
 struct DataFormatDimMap {
@@ -97,15 +104,81 @@ struct VecPermuteNCHWToNHWC {
   }
 };
 
+template <typename T>
+struct VecPermuteNHWCToHWNC {
+  Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
+      typename TTypes<T>::ConstFlat input) const {
+    Eigen::DSizes<Eigen::DenseIndex, 1> result;
+    result[0] = input.dimension(0);
+    return result;
+  }
+  template <typename Output, typename Device>
+  void eval(typename TTypes<T>::ConstFlat input, Output& output,
+            const Device& d) const {
+    if (input.size() == 8) {
+      output.template chip<0>(0).device(d) = input.template chip<0>(2);
+      output.template chip<0>(1).device(d) = input.template chip<0>(3);
+      output.template chip<0>(2).device(d) = input.template chip<0>(4);
+      output.template chip<0>(3).device(d) = input.template chip<0>(5);
+      output.template chip<0>(4).device(d) = input.template chip<0>(0);
+      output.template chip<0>(5).device(d) = input.template chip<0>(1);
+      output.template chip<0>(6).device(d) = input.template chip<0>(6);
+      output.template chip<0>(7).device(d) = input.template chip<0>(7);
+    } else {
+      output.template chip<0>(0).device(d) = input.template chip<0>(1);
+      output.template chip<0>(1).device(d) = input.template chip<0>(2);
+      output.template chip<0>(2).device(d) = input.template chip<0>(0);
+      output.template chip<0>(3).device(d) = input.template chip<0>(3);
+    }
+  }
+};
+
+template <typename T>
+struct VecPermuteHWNCToNHWC {
+  Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
+      typename TTypes<T>::ConstFlat input) const {
+    Eigen::DSizes<Eigen::DenseIndex, 1> result;
+    result[0] = input.dimension(0);
+    return result;
+  }
+  template <typename Output, typename Device>
+  void eval(typename TTypes<T>::ConstFlat input, Output& output,
+            const Device& d) const {
+    if (input.size() == 8) {
+      output.template chip<0>(0).device(d) = input.template chip<0>(4);
+      output.template chip<0>(1).device(d) = input.template chip<0>(5);
+      output.template chip<0>(2).device(d) = input.template chip<0>(0);
+      output.template chip<0>(3).device(d) = input.template chip<0>(1);
+      output.template chip<0>(4).device(d) = input.template chip<0>(2);
+      output.template chip<0>(5).device(d) = input.template chip<0>(3);
+      output.template chip<0>(6).device(d) = input.template chip<0>(6);
+      output.template chip<0>(7).device(d) = input.template chip<0>(7);
+    } else {
+      output.template chip<0>(0).device(d) = input.template chip<0>(2);
+      output.template chip<0>(1).device(d) = input.template chip<0>(0);
+      output.template chip<0>(2).device(d) = input.template chip<0>(1);
+      output.template chip<0>(3).device(d) = input.template chip<0>(3);
+    }
+  }
+};
+
 // Functor used by DataFormatVecPermuteOp to do the computations.
 template <typename Device, typename T>
 struct DataFormatVecPermute {
   void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
-                  typename TTypes<T>::Flat y, bool nhwc_to_nchw) {
-    if (nhwc_to_nchw) {
+                  typename TTypes<T>::Flat y, const DataFormat src_format,
+                  const DataFormat dst_format) {
+    if (src_format == DataFormat::NHWC && dst_format == DataFormat::NCHW) {
       y.device(d) = x.customOp(VecPermuteNHWCToNCHW<T>());
-    } else {
+    } else if (src_format == DataFormat::NCHW &&
+               dst_format == DataFormat::NHWC) {
       y.device(d) = x.customOp(VecPermuteNCHWToNHWC<T>());
+    } else if (src_format == DataFormat::NHWC &&
+               dst_format == DataFormat::HWNC) {
+      y.device(d) = x.customOp(VecPermuteNHWCToHWNC<T>());
+    } else if (src_format == DataFormat::HWNC &&
+               dst_format == DataFormat::NHWC) {
+      y.device(d) = x.customOp(VecPermuteHWNCToNHWC<T>());
     }
   }
 };
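The two new functors mirror the existing NHWC/NCHW pair: a size-4 input carries one value per axis, while a size-8 input carries two values per axis (for example a (before, after) padding pair), and each pair travels with its axis. A plain-array sketch of the index mapping hard-coded by VecPermuteNHWCToHWNC above (standalone illustration, not TensorFlow code):

// Standalone sketch of the NHWC -> HWNC index mapping used above.
#include <cstdio>
#include <vector>

std::vector<int> NHWCToHWNC(const std::vector<int>& in) {
  if (in.size() == 8) {
    // Pairs for axes N,H,W,C reordered to H,W,N,C; each pair stays adjacent.
    return {in[2], in[3], in[4], in[5], in[0], in[1], in[6], in[7]};
  }
  // Single value per axis: N,H,W,C -> H,W,N,C.
  return {in[1], in[2], in[0], in[3]};
}

int main() {
  for (int v : NHWCToHWNC({7, 4, 9, 3})) std::printf("%d ", v);
  std::printf("\n");  // prints: 4 9 7 3
  for (int v : NHWCToHWNC({7, 4, 9, 3, 4, 5, 5, 1})) std::printf("%d ", v);
  std::printf("\n");  // prints: 9 3 4 5 7 4 5 1
  return 0;
}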
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 21eea3d..af9dae2 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -1049,6 +1049,22 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = sess.run(y)
       self.assertAllEqual(y_val, [7, 9, 3, 4])
 
+  def testNHWCToHWNC(self):
+    x_val = [7, 4, 9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC")
+    with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+      y_val = sess.run(y)
+      self.assertAllEqual(y_val, [4, 9, 7, 3])
+
+  def testHWNCToNHWC(self):
+    x_val = [7, 4, 9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC")
+    with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+      y_val = sess.run(y)
+      self.assertAllEqual(y_val, [9, 7, 4, 3])
+
   def testNHWCToNCHW2D(self):
     x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
     x = constant_op.constant(x_val)
@@ -1057,6 +1073,22 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = sess.run(y)
       self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]])
 
+  def testNHWCToHWNC2D(self):
+    x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC")
+    with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+      y_val = sess.run(y)
+      self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]])
+
+  def testHWNCToNHWC2D(self):
+    x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC")
+    with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+      y_val = sess.run(y)
+      self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]])
+
   def testNCHWToNHWC2D(self):
     x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
     x = constant_op.constant(x_val)
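The expected values in the new tests can be cross-checked by noting that HWNC -> NHWC is the inverse of NHWC -> HWNC, so chaining the two permutations returns the original vector. A minimal standalone check (illustrative only, not part of the test suite):

// Standalone consistency check for the two new permutations.
#include <cassert>
#include <vector>

std::vector<int> NHWCToHWNC(const std::vector<int>& in) {
  return {in[1], in[2], in[0], in[3]};
}

std::vector<int> HWNCToNHWC(const std::vector<int>& in) {
  return {in[2], in[0], in[1], in[3]};
}

int main() {
  const std::vector<int> x = {7, 4, 9, 3};
  // Matches testNHWCToHWNC and testHWNCToNHWC above.
  assert(NHWCToHWNC(x) == (std::vector<int>{4, 9, 7, 3}));
  assert(HWNCToNHWC(x) == (std::vector<int>{9, 7, 4, 3}));
  // The two mappings are inverses of each other.
  assert(HWNCToNHWC(NHWCToHWNC(x)) == x);
  return 0;
}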