typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-namespace {
-inline functor::DataFormat FormatNameToEnum(const string& name) {
- if (name == "NHWC") return functor::DataFormat::NHWC;
- if (name == "NCHW") return functor::DataFormat::NCHW;
- if (name == "HWNC") return functor::DataFormat::HWNC;
- return functor::DataFormat::UNKNOWN;
-}
-} // namespace
-
template <typename Device, typename T>
class DataFormatDimMapOp : public OpKernel {
public:
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
string dst_format;
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
- OP_REQUIRES(context,
- (src_format == "NHWC" && dst_format == "NCHW") ||
- (src_format == "NCHW" && dst_format == "NHWC") ||
- (src_format == "NHWC" && dst_format == "HWNC") ||
- (src_format == "HWNC" && dst_format == "NHWC"),
- errors::InvalidArgument(strings::StrCat(
- "Current implementation only supports NHWC<->NCHW and "
- "NHWC<->HWNC conversion; got source format ",
- src_format, " and destination format ", dst_format)));
- src_format_ = FormatNameToEnum(src_format);
- dst_format_ = FormatNameToEnum(dst_format);
+ src_format_ = src_format;
+ dst_format_ = dst_format;
}
void Compute(OpKernelContext* context) override {
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
- functor::DataFormatVecPermute<Device, T>()(
- context->eigen_device<Device>(), input.flat<T>(), output->flat<T>(),
- src_format_, dst_format_);
+ // Build the flat-index permutation; only the 1D and 2D cases are supported.
+ Eigen::DSizes<Eigen::DenseIndex, 8> dst_idx;
+ ComputeDstIndex(input.dims(), &dst_idx);
+
+ functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(),
+ input.flat<T>(),
+ output->flat<T>(), dst_idx);
}
private:
- functor::DataFormat src_format_;
- functor::DataFormat dst_format_;
+ // Fills *dst with the flat-index permutation that carries a vector laid out
+ // in src_format_ order into dst_format_ order. For every pair (i, j) with
+ // src_format_[i] == dst_format_[j], and every k in [0, num_dim):
+ //   (*dst)[i * num_dim + k] = j * num_dim + k
+ // The caller passes input.dims() as num_dim. Supports the 1D and 2D cases.
+ // Example: HWNC --> NHWC
+ // 1D: dst = [1, 2, 0, 3],
+ // 2D: dst = [2, 3, 4, 5, 0, 1, 6, 7]
+ // Entries of *dst whose source character has no match in dst_format_ are
+ // left untouched. Slots at or beyond the 8 entries of *dst are never
+ // written, so an over-long src_format_ or num_dim cannot overrun the buffer.
+ // NOTE(review): the stored values j * num_dim + k are not range-checked
+ // here; the format strings should be validated where the attributes are
+ // read -- confirm that such validation exists.
+ void ComputeDstIndex(int num_dim, Eigen::DSizes<Eigen::DenseIndex, 8>* dst) {
+   // Capacity is fixed by the DSizes<..., 8> parameter type.
+   const Eigen::DenseIndex kNumSlots = 8;
+   // Signed lengths keep the loop comparisons free of signed/unsigned mixing.
+   const Eigen::DenseIndex src_len =
+       static_cast<Eigen::DenseIndex>(src_format_.size());
+   const Eigen::DenseIndex dst_len =
+       static_cast<Eigen::DenseIndex>(dst_format_.size());
+   for (Eigen::DenseIndex i = 0; i < src_len; ++i) {
+     for (Eigen::DenseIndex j = 0; j < dst_len; ++j) {
+       if (dst_format_[j] != src_format_[i]) continue;
+       // Found the dst index. Set output based on the number of dims.
+       for (int k = 0; k < num_dim; ++k) {
+         const Eigen::DenseIndex slot = i * num_dim + k;
+         if (slot >= kNumSlots) break;  // would write past the end of *dst
+         (*dst)[slot] = j * num_dim + k;
+       }
+     }
+   }
+ }
+
+ string src_format_;
+ string dst_format_;
};
#define REGISTER_KERNEL(T) \
TF_CALL_int64(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void DataFormatVecPermute<GPUDevice, T>::operator()( \
- const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
- typename TTypes<T>::Vec y, const DataFormat src_format, \
- const DataFormat dst_format); \
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void DataFormatVecPermute<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
+ typename TTypes<T>::Vec y, \
+ const Eigen::DSizes<Eigen::DenseIndex, 8>& dst_idx); \
extern template struct DataFormatVecPermute<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
namespace tensorflow {
namespace functor {
-enum class DataFormat {
- UNKNOWN = 0,
- NHWC,
- NCHW,
- HWNC,
-};
-
// Functor used by DataFormatDimMapOP to do the computations.
template <typename Device, typename T>
struct DataFormatDimMap {
};
template <typename T>
-struct VecPermuteNHWCToNCHW {
- Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
- typename TTypes<T>::ConstFlat input) const {
- Eigen::DSizes<Eigen::DenseIndex, 1> result;
- result[0] = input.dimension(0);
- return result;
- }
- template <typename Output, typename Device>
- void eval(typename TTypes<T>::ConstFlat input, Output& output,
- const Device& d) const {
- if (input.size() == 8) {
- output.template chip<0>(0).device(d) = input.template chip<0>(0);
- output.template chip<0>(1).device(d) = input.template chip<0>(1);
- output.template chip<0>(2).device(d) = input.template chip<0>(6);
- output.template chip<0>(3).device(d) = input.template chip<0>(7);
- output.template chip<0>(4).device(d) = input.template chip<0>(2);
- output.template chip<0>(5).device(d) = input.template chip<0>(3);
- output.template chip<0>(6).device(d) = input.template chip<0>(4);
- output.template chip<0>(7).device(d) = input.template chip<0>(5);
- } else {
- output.template chip<0>(0).device(d) = input.template chip<0>(0);
- output.template chip<0>(1).device(d) = input.template chip<0>(3);
- output.template chip<0>(2).device(d) = input.template chip<0>(1);
- output.template chip<0>(3).device(d) = input.template chip<0>(2);
- }
- }
-};
-
-template <typename T>
-struct VecPermuteNCHWToNHWC {
- Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
- typename TTypes<T>::ConstFlat input) const {
- Eigen::DSizes<Eigen::DenseIndex, 1> result;
- result[0] = input.dimension(0);
- return result;
- }
- template <typename Output, typename Device>
- void eval(typename TTypes<T>::ConstFlat input, Output& output,
- const Device& d) const {
- if (input.size() == 8) {
- output.template chip<0>(0).device(d) = input.template chip<0>(0);
- output.template chip<0>(1).device(d) = input.template chip<0>(1);
- output.template chip<0>(2).device(d) = input.template chip<0>(4);
- output.template chip<0>(3).device(d) = input.template chip<0>(5);
- output.template chip<0>(4).device(d) = input.template chip<0>(6);
- output.template chip<0>(5).device(d) = input.template chip<0>(7);
- output.template chip<0>(6).device(d) = input.template chip<0>(2);
- output.template chip<0>(7).device(d) = input.template chip<0>(3);
- } else {
- output.template chip<0>(0).device(d) = input.template chip<0>(0);
- output.template chip<0>(1).device(d) = input.template chip<0>(2);
- output.template chip<0>(2).device(d) = input.template chip<0>(3);
- output.template chip<0>(3).device(d) = input.template chip<0>(1);
- }
- }
-};
-
-template <typename T>
-struct VecPermuteNHWCToHWNC {
+struct VecPermute {
+ VecPermute(const Eigen::DSizes<Eigen::DenseIndex, 8>& dst) : dst_(dst) {}
Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
typename TTypes<T>::ConstFlat input) const {
Eigen::DSizes<Eigen::DenseIndex, 1> result;
template <typename Output, typename Device>
void eval(typename TTypes<T>::ConstFlat input, Output& output,
const Device& d) const {
- if (input.size() == 8) {
- output.template chip<0>(0).device(d) = input.template chip<0>(2);
- output.template chip<0>(1).device(d) = input.template chip<0>(3);
- output.template chip<0>(2).device(d) = input.template chip<0>(4);
- output.template chip<0>(3).device(d) = input.template chip<0>(5);
- output.template chip<0>(4).device(d) = input.template chip<0>(0);
- output.template chip<0>(5).device(d) = input.template chip<0>(1);
- output.template chip<0>(6).device(d) = input.template chip<0>(6);
- output.template chip<0>(7).device(d) = input.template chip<0>(7);
- } else {
- output.template chip<0>(0).device(d) = input.template chip<0>(1);
- output.template chip<0>(1).device(d) = input.template chip<0>(2);
- output.template chip<0>(2).device(d) = input.template chip<0>(0);
- output.template chip<0>(3).device(d) = input.template chip<0>(3);
+ for (int i = 0; i < input.size(); ++i) {
+ output.template chip<0>(dst_[i]).device(d) = input.template chip<0>(i);
}
}
-};
-template <typename T>
-struct VecPermuteHWNCToNHWC {
- Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
- typename TTypes<T>::ConstFlat input) const {
- Eigen::DSizes<Eigen::DenseIndex, 1> result;
- result[0] = input.dimension(0);
- return result;
- }
- template <typename Output, typename Device>
- void eval(typename TTypes<T>::ConstFlat input, Output& output,
- const Device& d) const {
- if (input.size() == 8) {
- output.template chip<0>(0).device(d) = input.template chip<0>(4);
- output.template chip<0>(1).device(d) = input.template chip<0>(5);
- output.template chip<0>(2).device(d) = input.template chip<0>(0);
- output.template chip<0>(3).device(d) = input.template chip<0>(1);
- output.template chip<0>(4).device(d) = input.template chip<0>(2);
- output.template chip<0>(5).device(d) = input.template chip<0>(3);
- output.template chip<0>(6).device(d) = input.template chip<0>(6);
- output.template chip<0>(7).device(d) = input.template chip<0>(7);
- } else {
- output.template chip<0>(0).device(d) = input.template chip<0>(2);
- output.template chip<0>(1).device(d) = input.template chip<0>(0);
- output.template chip<0>(2).device(d) = input.template chip<0>(1);
- output.template chip<0>(3).device(d) = input.template chip<0>(3);
- }
- }
+ private:
+ Eigen::DSizes<Eigen::DenseIndex, 8> dst_;
};
// Functor used by DataFormatVecPermuteOp to do the computations.
template <typename Device, typename T>
struct DataFormatVecPermute {
void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
- typename TTypes<T>::Flat y, const DataFormat src_format,
- const DataFormat dst_format) {
- if (src_format == DataFormat::NHWC && dst_format == DataFormat::NCHW) {
- y.device(d) = x.customOp(VecPermuteNHWCToNCHW<T>());
- } else if (src_format == DataFormat::NCHW &&
- dst_format == DataFormat::NHWC) {
- y.device(d) = x.customOp(VecPermuteNCHWToNHWC<T>());
- } else if (src_format == DataFormat::NHWC &&
- dst_format == DataFormat::HWNC) {
- y.device(d) = x.customOp(VecPermuteNHWCToHWNC<T>());
- } else if (src_format == DataFormat::HWNC &&
- dst_format == DataFormat::NHWC) {
- y.device(d) = x.customOp(VecPermuteHWNCToNHWC<T>());
- }
+ typename TTypes<T>::Flat y,
+ const Eigen::DSizes<Eigen::DenseIndex, 8>& dst) {
+ y.device(d) = x.customOp(VecPermute<T>(dst));
}
};