Support NHWC_VECT_W in MakeShapeFromFormat.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 29 May 2018 17:53:40 +0000 (10:53 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 29 May 2018 17:55:57 +0000 (10:55 -0700)
PiperOrigin-RevId: 198421617

tensorflow/core/framework/common_shape_fns.cc

index 71a31b0..d1b495d 100644 (file)
@@ -303,6 +303,9 @@ Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
   if (format == FORMAT_NCHW_VECT_C) {
     dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
         context->MakeDim(4);
+  } else if (format == FORMAT_NHWC_VECT_W) {
+    dims_actual[GetTensorInnerWidthDimIndex(num_dims, format)] =
+        context->MakeDim(4);
   }
   for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) {
     dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =