OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
string dst_format;
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
- OP_REQUIRES(
- context, src_format == "NHWC",
- errors::InvalidArgument(strings::StrCat(
- "Current implementation doesn't support source data format ",
- src_format)));
- OP_REQUIRES(context, dst_format == "NCHW",
+ OP_REQUIRES(context,
+ (src_format == "NHWC" && dst_format == "NCHW") ||
+ (src_format == "NCHW" && dst_format == "NHWC"),
errors::InvalidArgument(strings::StrCat(
- "Current implementation doesn't support dst data format ",
- dst_format)));
+ "Current implementation only supports NCHW-to-NHWC and "
+ "NHWC-to-NCHW format conversion; got source format ",
+ src_format, " and destination format ", dst_format)));
+ nhwc_to_nchw_ = (src_format == "NHWC") ? true : false;
}
// Kernel entry point: allocates output 0 with the same shape as `input`, then
// runs the DataFormatVecPermute functor on the context's Eigen device to fill
// the output vector.
// NOTE(review): `input` is not declared in the visible lines; presumably it is
// obtained from the context in elided code. Confirm.
void Compute(OpKernelContext* context) override {
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
- functor::DataFormatVecPermute<Device, T>()(
- context->eigen_device<Device>(), input.vec<T>(), output->vec<T>());
// The trailing flag tells the functor which permutation direction to apply.
+ functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(),
+ input.vec<T>(), output->vec<T>(),
+ nhwc_to_nchw_);
}
+
+ private:
+ bool nhwc_to_nchw_;  // Derived from the src_format attribute: true iff it equals "NHWC".
};
// For element type T, declares the explicit specialization of
// DataFormatVecPermute<GPUDevice, T>::operator() and an explicit instantiation
// declaration (extern template) of the struct.
// NOTE(review): the wrapper below expands DECLARE_GPU_SPEC, which is not
// defined in the visible lines, while this macro is named REGISTER_KERNEL.
// Confirm which name is intended.
#define REGISTER_KERNEL(T) \
template <> \
void DataFormatVecPermute<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstVec x, \
- typename TTypes<T>::Vec y); \
+ typename TTypes<T>::Vec y, bool nhwc_to_nchw); \
extern template struct DataFormatVecPermute<GPUDevice, T>;
// Wrapper that adds a trailing semicolon; handed to TF_CALL_int32 (presumably
// expanding it for int32 only; confirm against the TF_CALL_* definitions).
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
};
// Custom Eigen op that permutes a 4-element vector from NHWC order to NCHW
// order: output = {input[0], input[3], input[1], input[2]}.
template <typename T>
-struct VecPermute {
+struct VecPermuteNHWCToNCHW {
// The output vector has the same length as the input vector.
Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
typename TTypes<T>::ConstVec input) const {
Eigen::DSizes<Eigen::DenseIndex, 1> result;
result[0] = input.dimension(0);
return result;
}
// Copies each input element to its permuted slot, evaluating on device `d`.
template <typename Output, typename Device>
void eval(typename TTypes<T>::ConstVec input, Output& output,
const Device& d) const {
output.template chip<0>(0).device(d) = input.template chip<0>(0);  // N
output.template chip<0>(1).device(d) = input.template chip<0>(3);  // C
output.template chip<0>(2).device(d) = input.template chip<0>(1);  // H
output.template chip<0>(3).device(d) = input.template chip<0>(2);  // W
}
};
// Custom Eigen op that permutes a 4-element vector from NCHW order to NHWC
// order: output = {input[0], input[2], input[3], input[1]}.
+template <typename T>
+struct VecPermuteNCHWToNHWC {
// The output vector has the same length as the input vector.
+ Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
+ typename TTypes<T>::ConstVec input) const {
+ Eigen::DSizes<Eigen::DenseIndex, 1> result;
+ result[0] = input.dimension(0);
+ return result;
+ }
// Copies each input element to its permuted slot, evaluating on device `d`.
+ template <typename Output, typename Device>
+ void eval(typename TTypes<T>::ConstVec input, Output& output,
+ const Device& d) const {
+ output.template chip<0>(0).device(d) = input.template chip<0>(0);  // N
+ output.template chip<0>(1).device(d) = input.template chip<0>(2);  // H
+ output.template chip<0>(2).device(d) = input.template chip<0>(3);  // W
+ output.template chip<0>(3).device(d) = input.template chip<0>(1);  // C
+ }
+};
+
// Functor used by DataFormatVecPermuteOp to do the computations.
// Writes the permutation of `x` into `y`, evaluated on device `d`: the
// NHWC-to-NCHW custom op is applied when `nhwc_to_nchw` is true, and the
// NCHW-to-NHWC custom op otherwise.
template <typename Device, typename T>
struct DataFormatVecPermute {
void operator()(const Device& d, typename TTypes<T>::ConstVec x,
- typename TTypes<T>::Vec y) {
- y.device(d) = x.customOp(VecPermute<T>());
+ typename TTypes<T>::Vec y, bool nhwc_to_nchw) {
+ if (nhwc_to_nchw) {
+ y.device(d) = x.customOp(VecPermuteNHWCToNCHW<T>());
+ } else {
+ y.device(d) = x.customOp(VecPermuteNCHWToNHWC<T>());
+ }
}
};
class DataFormatVectorPermuteTest(test_lib.TestCase):
- def test(self):
+ def testNHWCToNCHW(self):
x_val = [7, 4, 9, 3]
x = constant_op.constant(x_val)
y = nn_ops.data_format_vec_permute(x)
y_val = sess.run(y)
self.assertAllEqual(y_val, [7, 3, 4, 9])
+ def testNCHWToNHWC(self):
+ x_val = [7, 4, 9, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, [7, 9, 3, 4])
+
# Script entry point: hands control to test_lib.main() when run directly.
if __name__ == "__main__":
test_lib.main()