From 1130e5be68e4cc2f83ba6f883e4e3b89d0551bf0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 May 2018 16:59:50 -0700 Subject: [PATCH] Add NCHW_VECT_W tensor format. PiperOrigin-RevId: 197074411 --- .../core/kernels/depthwise_conv_op_gpu.cu.cc | 20 ++-- tensorflow/core/util/tensor_format.cc | 6 + tensorflow/core/util/tensor_format.h | 123 ++++++++++++++------- tensorflow/core/util/tensor_format_test.cc | 4 +- 4 files changed, 103 insertions(+), 50 deletions(-) diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc index 0abd640..5390222 100644 --- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc @@ -613,8 +613,8 @@ void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device, kKnownFilterHeight, kBlockDepth, kKnownEvenHeight>; break; - case FORMAT_NCHW_VECT_C: - LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + default: + LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported"; return; } const int tile_width = args.in_cols + args.filter_cols - 1; @@ -690,8 +690,8 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& device, DepthwiseConv2dGPUKernelNCHW; break; - case FORMAT_NCHW_VECT_C: - LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + default: + LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported"; return; } const int num_outputs = @@ -919,8 +919,8 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device, kernel = DepthwiseConv2dBackpropInputGPUKernelNCHW< T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; break; - case FORMAT_NCHW_VECT_C: - LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + default: + LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported"; return; } const int num_in_backprop = @@ -1559,8 +1559,8 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall< T, 
kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>; break; - case FORMAT_NCHW_VECT_C: - LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + default: + LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported"; return false; } const int num_out_backprop = args.out_rows * args.out_cols * block_count; @@ -1662,8 +1662,8 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device, kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHW< T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; break; - case FORMAT_NCHW_VECT_C: - LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + default: + LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported"; return; } const int num_out_backprop = diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc index 8c83365..d4311d1 100644 --- a/tensorflow/core/util/tensor_format.cc +++ b/tensorflow/core/util/tensor_format.cc @@ -41,6 +41,8 @@ string ToString(TensorFormat format) { return "NCHW"; case FORMAT_NCHW_VECT_C: return "NCHW_VECT_C"; + case FORMAT_NHWC_VECT_W: + return "NHWC_VECT_W"; default: LOG(FATAL) << "Invalid Format: " << static_cast(format); return "INVALID_FORMAT"; @@ -74,6 +76,10 @@ bool FormatFromString(const string& format_str, TensorFormat* format) { *format = FORMAT_NCHW_VECT_C; return true; } + if (format_str == "NHWC_VECT_W") { + *format = FORMAT_NHWC_VECT_W; + return true; + } return false; } diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h index 6466735..58bc79a 100644 --- a/tensorflow/core/util/tensor_format.h +++ b/tensorflow/core/util/tensor_format.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_UTIL_TENSOR_FORMAT_H_ -#define TENSORFLOW_UTIL_TENSOR_FORMAT_H_ +#ifndef TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_ +#define TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_ #include #include @@ -29,6 +29,9 @@ namespace tensorflow { // The mnemonics specify the meaning of each tensor dimension sorted from // largest to smallest memory stride. // N = Batch, H = Image Height, W = Image Width, C = Number of Channels. +// TODO(pauldonnelly): It would probably be better to switch to a registration +// process for tensor formats, so specialized formats could be defined more +// locally to where they are used. enum TensorFormat { // FORMAT_NHWC is the default format in TensorFlow. FORMAT_NHWC = 0, @@ -45,6 +48,17 @@ enum TensorFormat { // NCHW_VECT_C format. // A pre-condition of this format is that C must be a multiple of 4. FORMAT_NCHW_VECT_C = 2, + + // Similar to NHWC, but the size of the W dimension is divided by 4, and a + // new dimension of size 4 is appended, which packs 4 adjacent activations + // in the width dimension. + FORMAT_NHWC_VECT_W = 3, + + // Note: although the current code in this file assumes VECT_C and VECT_W + // enums imply int8x4 vectors, this should not be relied upon. + // In the future we may change the meaning of these enums to include vectors + // of other types such as int16x2, with op implementations automatically + // determining which format is implied based on the datatype. }; // Tensor format for convolutional filters. @@ -89,10 +103,17 @@ string ToString(FilterTensorFormat format); // Returns the number of spatial dims of a tensor of rank 'num_dims' and tensor // format 'format'. inline int GetTensorSpatialDims(int num_dims, TensorFormat format) { - if (format == FORMAT_NCHW_VECT_C) { - return num_dims - 3; // Exclude N,C,InnerC. - } else { - return num_dims - 2; // Exclude N,C. 
+ switch (format) { + case FORMAT_NHWC: + return num_dims - 2; // Exclude N,C. + case FORMAT_NCHW: + return num_dims - 2; // Exclude N,C. + case FORMAT_NCHW_VECT_C: + return num_dims - 3; // Exclude N,C,VectDim. + case FORMAT_NHWC_VECT_W: + // Note: the VECT_W is not counted as an independent spatial dim here, + // since it is just a component of the width dimension. + return num_dims - 3; // Exclude N,C,VectDim. } } @@ -108,10 +129,13 @@ inline int GetFilterTensorSpatialDims(int num_dims, FilterTensorFormat format) { // tensor format 'format'. This is the inverse of GetTensorSpatialDims. inline int GetTensorDimsFromSpatialDims(int num_spatial_dims, TensorFormat format) { - if (format == FORMAT_NCHW_VECT_C) { - return num_spatial_dims + 3; // Include N,C,InnerC. - } else { - return num_spatial_dims + 2; // Include N,C. + switch (format) { + case FORMAT_NHWC: + case FORMAT_NCHW: + return num_spatial_dims + 2; // Include N,C. + case FORMAT_NCHW_VECT_C: + case FORMAT_NHWC_VECT_W: + return num_spatial_dims + 3; // Include N,C,VectDim. } } @@ -132,6 +156,7 @@ inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) { case FORMAT_NHWC: case FORMAT_NCHW: case FORMAT_NCHW_VECT_C: + case FORMAT_NHWC_VECT_W: return 0; default: LOG(FATAL) << "Unknown format " << format; @@ -146,6 +171,8 @@ inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) { switch (format) { case FORMAT_NHWC: return num_dims - 1; + case FORMAT_NHWC_VECT_W: + return num_dims - 2; + case FORMAT_NCHW: case FORMAT_NCHW_VECT_C: return 1; @@ -161,24 +188,34 @@ inline int GetTensorInnerFeatureDimIndex(int num_dims, TensorFormat format) { return num_dims - 1; } -// Returns the index of the `dim`-th spatial dimension. +// Returns the index of the inner width dimension.
+inline int GetTensorInnerWidthDimIndex(int num_dims, TensorFormat format) { + DCHECK_EQ(format, FORMAT_NHWC_VECT_W); + return num_dims - 1; +} + +// Returns the dimension index of the specified 'spatial_dim' within an +// activation tensor. If format is NHWC_VECT_W and spatial_dim is 1, returns +// the index of the outer width dimension (i.e. dimension 2, whose size would +// be width / 4 in this case). inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format, - int dim) { - CHECK(dim >= 0 && dim < GetTensorSpatialDims(num_dims, format)) - << dim << " " << num_dims << " " << ToString(format); + int spatial_dim) { + CHECK(spatial_dim >= 0 && + spatial_dim < GetTensorSpatialDims(num_dims, format)) + << spatial_dim << " " << num_dims << " " << ToString(format); switch (format) { case FORMAT_NHWC: - return dim + 1; + case FORMAT_NHWC_VECT_W: + return spatial_dim + 1; case FORMAT_NCHW: case FORMAT_NCHW_VECT_C: - return dim + 2; + return spatial_dim + 2; default: LOG(FATAL) << "Unknown format " << format; return -1; // Avoid compiler warning about missing return value } } -// Returns the index of the `dim`-th spatial dimension. inline int GetFilterTensorSpatialDimIndex(int num_dims, FilterTensorFormat format, int dim) { CHECK(dim >= 0 && dim < GetFilterTensorSpatialDims(num_dims, format)) @@ -246,7 +283,7 @@ inline int GetFilterTensorOutputChannelsDimIndex(int num_dims, // the outer channel dimension (i.e. 1). template inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { - if (format == FORMAT_NHWC) { + if (format == FORMAT_NHWC || format == FORMAT_NHWC_VECT_W) { // clang-format off switch (dimension) { case 'N': return 0; @@ -404,28 +441,37 @@ string GetConvnet3dDataFormatAttrString(); string GetConvnetFilterFormatAttrString(); string GetConvnet3dFilterFormatAttrString(); -// Return a tensor shape for the given format. Works for both 2D and 3D -// operations. 
If format is FORMAT_NCHW_VECT_C, the output TensorShape has rank -// spatial.size()+3 (N,C,spatial,InnerC); otherwise, it has rank -// spatial.size()+2 (e.g. N,C,spatial or N,spatial,C). +// Returns a tensor shape for the specified format and dimension sizes. +// Works for both 2D and 3D operations. The output shapes are as follows: +// FORMAT_NHWC: (N, spatial, C); rank = spatial.size() + 2 +// FORMAT_NCHW: (N, C, spatial); rank = spatial.size() + 2 +// FORMAT_NCHW_VECT_C: (N, C, spatial, InnerC); rank = spatial.size() + 3 +// FORMAT_NHWC_VECT_W: (N, spatial, C, InnerW); rank = spatial.size() + 3 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, gtl::ArraySlice spatial, int64 C) { const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format); gtl::InlinedVector dim_sizes(dims); dim_sizes[GetTensorBatchDimIndex(dims, format)] = N; for (int dim = 0; static_cast(dim) < spatial.size(); dim++) { - dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = spatial[dim]; + auto dim_size = spatial[dim]; + if (format == FORMAT_NHWC_VECT_W && dim == spatial.size() - 1) { + CHECK_EQ(0, dim_size % 4) + << "FORMAT_NHWC_VECT_W requires W to be a multiple of 4, but W=" + << dim_size; + dim_sizes[GetTensorInnerWidthDimIndex(dims, format)] = 4; + dim_size /= 4; + } + dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = dim_size; } int feature_index = GetTensorFeatureDimIndex(dims, format); if (format == FORMAT_NCHW_VECT_C) { CHECK_EQ(0, C % 4) << "NCHW_VECT_C requires C to be a multiple of 4, but C=" << C; - dim_sizes[feature_index] = C / 4; + C /= 4; dim_sizes[GetTensorInnerFeatureDimIndex(dims, format)] = 4; - } else { - dim_sizes[feature_index] = C; } + dim_sizes[feature_index] = C; return TensorShape(dim_sizes); } @@ -478,19 +524,18 @@ inline TensorShape ShapeFromFormat(TensorFormat dst_format, const int64 batch = GetTensorDim(src_shape, src_format, 'N'); const int64 channels = GetTensorDim(src_shape, src_format, 'C') * (src_format == 
FORMAT_NCHW_VECT_C ? 4 : 1); - - if (GetTensorSpatialDims(src_shape.dims(), src_format) == 3) { - return ShapeFromFormat(dst_format, batch, - {{GetTensorDim(src_shape, src_format, '0'), - GetTensorDim(src_shape, src_format, '1'), - GetTensorDim(src_shape, src_format, '2')}}, - channels); + const int num_src_spatial_dims = + GetTensorSpatialDims(src_shape.dims(), src_format); + std::vector spatial_dims(num_src_spatial_dims); + for (int spatial_dim = 0; spatial_dim < num_src_spatial_dims; ++spatial_dim) { + spatial_dims[spatial_dim] = + gtl::ArraySlice(src_shape.dim_sizes())[GetTensorSpatialDimIndex( + src_shape.dims(), src_format, spatial_dim)]; } - - return ShapeFromFormat(dst_format, batch, - {{GetTensorDim(src_shape, src_format, 'H'), - GetTensorDim(src_shape, src_format, 'W')}}, - channels); + if (src_format == FORMAT_NHWC_VECT_W) { + spatial_dims[num_src_spatial_dims - 1] *= 4; + } + return ShapeFromFormat(dst_format, batch, {spatial_dims}, channels); } // Returns a copy of the specified filter tensor 'src_shape' converted from @@ -525,4 +570,4 @@ inline TensorShape ShapeFromFilterFormat(FilterTensorFormat dst_filter_format, } // namespace tensorflow -#endif // TENSORFLOW_UTIL_TENSOR_FORMAT_H_ +#endif // TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_ diff --git a/tensorflow/core/util/tensor_format_test.cc b/tensorflow/core/util/tensor_format_test.cc index 36698e0..9390229 100644 --- a/tensorflow/core/util/tensor_format_test.cc +++ b/tensorflow/core/util/tensor_format_test.cc @@ -29,6 +29,7 @@ std::pair test_data_formats[] = { EnumStringPair(FORMAT_NHWC), EnumStringPair(FORMAT_NCHW), EnumStringPair(FORMAT_NCHW_VECT_C), + EnumStringPair(FORMAT_NHWC_VECT_W), }; std::pair test_filter_formats[] = { @@ -104,7 +105,8 @@ struct DimMaps { inline constexpr const TensorDimMap& GetTensorDimMap(const int num_spatial_dims, const TensorFormat format) { return - (format == FORMAT_NHWC) ? DimMaps::kTdmNHWC[num_spatial_dims] : + (format == FORMAT_NHWC || + format == FORMAT_NHWC_VECT_W) ? 
DimMaps::kTdmNHWC[num_spatial_dims] : (format == FORMAT_NCHW || format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims] : DimMaps::kTdmInvalid; -- 2.7.4