From aa890776f062f3429bcedb0a080b712ebb97793b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Wed, 4 Apr 2018 22:37:35 -0700 Subject: [PATCH] Support arbitrary permutations for DataFormatDimMap. PiperOrigin-RevId: 191696203 --- tensorflow/core/kernels/data_format_ops.cc | 40 +++++++++++++++++++----------- tensorflow/core/kernels/data_format_ops.h | 18 +++++++++++--- tensorflow/python/ops/nn_test.py | 36 +++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 18 deletions(-) diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc index 39ef8ee..4485152 100644 --- a/tensorflow/core/kernels/data_format_ops.cc +++ b/tensorflow/core/kernels/data_format_ops.cc @@ -37,25 +37,37 @@ class DataFormatDimMapOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format)); string dst_format; OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format)); + OP_REQUIRES(context, src_format.size() == 4, + errors::InvalidArgument(strings::StrCat( + "Source format must of length 4, received src_format = ", + src_format))); OP_REQUIRES( - context, src_format == "NHWC", + context, dst_format.size() == 4, errors::InvalidArgument(strings::StrCat( - "Current implementation doesn't support source data format ", - src_format))); - OP_REQUIRES(context, dst_format == "NCHW", - errors::InvalidArgument(strings::StrCat( - "Current implementation doesn't support dst data format ", - dst_format))); + "Destination format must of length 4, received dst_format = ", + dst_format))); + dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())}); + for (int i = 0; i < src_format.size(); ++i) { + for (int j = 0; j < dst_format.size(); ++j) { + if (dst_format[j] == src_format[i]) { + dst_idx_.vec<int>()(i) = j; + break; + } + } + } } void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); - Tensor* output = nullptr; + Tensor* output; OP_REQUIRES_OK(context, 
context->allocate_output(0, input.shape(), &output)); functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(), - input.flat<T>(), output->flat<T>()); + input.flat<T>(), output->flat<T>(), + dst_idx_.vec<int>()); } + + Tensor dst_idx_; }; template <typename Device, typename T> @@ -147,11 +159,11 @@ TF_CALL_int64(REGISTER_KERNEL); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void DataFormatDimMap<GPUDevice, T>::operator()( \ - const GPUDevice& d, typename TTypes<T>::ConstFlat x, \ - typename TTypes<T>::Flat y); \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void DataFormatDimMap<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstFlat x, \ + typename TTypes<T>::Flat y, const TTypes<int>::Vec dst); \ extern template struct DataFormatDimMap<GPUDevice, T>; #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); TF_CALL_int32(DECLARE_GPU_SPECS); diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h index 2ccc919..1ca144c 100644 --- a/tensorflow/core/kernels/data_format_ops.h +++ b/tensorflow/core/kernels/data_format_ops.h @@ -27,15 +27,25 @@ namespace functor { template <typename Device, typename T> struct DataFormatDimMap { void operator()(const Device& d, typename TTypes<T>::ConstFlat x, - typename TTypes<T>::Flat y) { + typename TTypes<T>::Flat y, const TTypes<int>::Vec dst) { auto zero = x.constant(0); auto one = x.constant(1); - auto three = x.constant(3); + auto two = x.constant(2); + + auto f_zero = x.constant(dst(0)); + auto f_one = x.constant(dst(1)); + auto f_two = x.constant(dst(2)); + auto f_three = x.constant(dst(3)); + auto four = x.constant(4); auto x_mod = (x + four) % 4; + auto is_zero = (x_mod == zero); - auto is_three = (x_mod == three); - y.device(d) = is_zero.select(zero, is_three.select(one, x_mod + one)); + auto is_one = (x_mod == one); + auto is_two = (x_mod == two); + + y.device(d) = is_zero.select( + f_zero, is_one.select(f_one, is_two.select(f_two, f_three))); } }; diff --git a/tensorflow/python/ops/nn_test.py 
b/tensorflow/python/ops/nn_test.py index da86d5f..46a5f4f 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -1081,6 +1081,42 @@ class DataFormatDimMapTest(test_lib.TestCase): self._test([1, -3, -2], [2, 2, 3]) self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]]) + def testNHWCtoNCHW(self): + x_val = [1, -3, -2] + y_val_expected = [2, 2, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="NCHW") + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertAllEqual(y_val, y_val_expected) + + def testNHWCtoHWNC(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [2, 0, 1, 3, 2, 0, 1, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="HWNC") + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertAllEqual(y_val, y_val_expected) + + def testNHWCtoWHCN(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [3, 1, 0, 2, 3, 1, 0, 2] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="WHCN") + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertAllEqual(y_val, y_val_expected) + + def testArbitraryASCII(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="qwer", dst_format="rewq") + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertAllEqual(y_val, y_val_expected) + class DataFormatVectorPermuteTest(test_lib.TestCase): -- 2.7.4