Implement the logic to parse TensorProto (the tensor value for input or filter shape...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 22 Feb 2018 23:34:50 +0000 (15:34 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Feb 2018 23:38:52 +0000 (15:38 -0800)
PiperOrigin-RevId: 186685409

tensorflow/core/grappler/costs/op_level_cost_estimator.cc
tensorflow/core/grappler/costs/op_level_cost_estimator.h
tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc

index a57cfdd..983b689 100644 (file)
@@ -718,6 +718,56 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations(
   return ops;
 }
 
+bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
+                                        TensorShapeProto* tensor_shape_proto) {
+  tensor_shape_proto->Clear();
+  // First convert TensorProto into Tensor class so that it correctly parses
+  // data values within TensorProto (whether it's in int_val, int64_val,
+  // tensor_content, or anything.
+  Tensor tensor(tensor_proto.dtype());
+  if (!tensor.FromProto(tensor_proto)) {
+    LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
+                 << "failed to parse TensorProto: "
+                 << tensor_proto.DebugString();
+    return false;
+  }
+  if (tensor.dims() != 1) {
+    LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
+                 << "tensor is not 1D: " << tensor.dims();
+    return false;
+  }
+  // Then, convert it back to TensorProto using AsProtoField, which makes sure
+  // the data is in int_val, int64_val, or such repeated data fields, not in
+  // tensor_content.
+  TensorProto temp_tensor;
+  tensor.AsProtoField(&temp_tensor);
+
+#define TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(type)        \
+  do {                                                   \
+    for (const auto& value : temp_tensor.type##_val()) { \
+      tensor_shape_proto->add_dim()->set_size(value);    \
+    }                                                    \
+  } while (0)
+
+  if (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT16 ||
+      tensor.dtype() == DT_INT8 || tensor.dtype() == DT_UINT8) {
+    TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int);
+  } else if (tensor.dtype() == DT_INT64) {
+    TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int64);
+  } else if (tensor.dtype() == DT_UINT32) {
+    TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint32);
+  } else if (tensor.dtype() == DT_UINT64) {
+    TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint64);
+  } else {
+    LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
+                 << "Unsupported dtype: " << tensor.dtype();
+    return false;
+  }
+#undef TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO
+
+  return true;
+}
+
 // TODO(cliffy): Dedup this method and CountConv2DBackpropFilterOperations.
 int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
     const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims,
@@ -732,20 +782,16 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
   }
 
   TensorShapeProto input_shape;
+  bool shape_found = false;
   if (op_features.inputs(0).has_value()) {
     const TensorProto& value = op_features.inputs(0).value();
-    if (value.int64_val_size() > 0) {
-      for (int i = 0; i < value.int64_val_size(); ++i) {
-        input_shape.add_dim()->set_size(value.int64_val(i));
-      }
-    } else {
-      for (int i = 0; i < value.int_val_size(); ++i) {
-        input_shape.add_dim()->set_size(value.int_val(i));
-      }
-    }
-  } else if (op_features.outputs_size() == 1) {
+    shape_found = GetTensorShapeProtoFromTensorProto(value, &input_shape);
+  }
+  if (!shape_found && op_features.outputs_size() == 1) {
     input_shape = op_features.outputs(0).shape();
-  } else {
+    shape_found = true;
+  }
+  if (!shape_found) {
     // Set the minimum filter size that's feasible.
     for (int i = 0; i < 4; ++i) {
       input_shape.add_dim()->set_size(1);
@@ -778,20 +824,16 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
   DCHECK_EQ(kConv2dBackpropFilter, op_features.op());
 
   TensorShapeProto filter_shape;
+  bool shape_found = false;
   if (op_features.inputs_size() >= 2 && op_features.inputs(1).has_value()) {
     const TensorProto& value = op_features.inputs(1).value();
-    if (value.int64_val_size() > 0) {
-      for (int i = 0; i < value.int64_val_size(); ++i) {
-        filter_shape.add_dim()->set_size(value.int64_val(i));
-      }
-    } else {
-      for (int i = 0; i < value.int_val_size(); ++i) {
-        filter_shape.add_dim()->set_size(value.int_val(i));
-      }
-    }
-  } else if (op_features.outputs_size() == 1) {
+    shape_found = GetTensorShapeProtoFromTensorProto(value, &filter_shape);
+  }
+  if (!shape_found && op_features.outputs_size() == 1) {
     filter_shape = op_features.outputs(0).shape();
-  } else {
+    shape_found = true;
+  }
+  if (!shape_found) {
     // Set the minimum filter size that's feasible.
     for (int i = 0; i < 4; ++i) {
       filter_shape.add_dim()->set_size(1);
index a292e5e..7bb530f 100644 (file)
@@ -28,6 +28,9 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
+                                        TensorShapeProto* tensor_shape_proto);
+
 class OpLevelCostEstimator {
  public:
   OpLevelCostEstimator();
index 60fc783..583d261 100644 (file)
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/platform/test.h"
@@ -247,5 +249,108 @@ TEST_F(OpLevelCostEstimatorTest, BatchMatMul) {
   EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
 }
 
+// Helper functions for testing GetTensorShapeProtoFromTensorProto().
+void GetTensorProto(const DataType dtype, const std::vector<int64>& shape,
+                    const std::vector<int64> values, const bool tensor_content,
+                    TensorProto* tensor_proto) {
+  tensor_proto->Clear();
+  TensorProto temp_tensor_proto;
+  temp_tensor_proto.set_dtype(dtype);
+  for (const auto& x : shape) {
+    temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x);
+  }
+  for (const auto& x : values) {
+    if (dtype == DT_INT64) {
+      temp_tensor_proto.add_int64_val(x);
+    } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 ||
+               dtype == DT_UINT8) {
+      temp_tensor_proto.add_int_val(x);
+    } else if (dtype == DT_UINT32) {
+      temp_tensor_proto.add_uint32_val(x);
+    } else if (dtype == DT_UINT64) {
+      temp_tensor_proto.add_uint64_val(x);
+    } else {
+      CHECK(false) << "Unsupported dtype: " << dtype;
+    }
+  }
+  Tensor tensor(dtype);
+  CHECK(tensor.FromProto(temp_tensor_proto));
+  if (tensor_content) {
+    tensor.AsProtoTensorContent(tensor_proto);
+  } else {
+    tensor.AsProtoField(tensor_proto);
+  }
+}
+
+void ExpectTensorShape(const std::vector<int64>& expected,
+                       const TensorShapeProto& tensor_shape_proto) {
+  TensorShape tensor_shape_expected(expected);
+  TensorShape tensor_shape(tensor_shape_proto);
+
+  LOG(INFO) << "Expected: " << tensor_shape_expected.DebugString();
+  LOG(INFO) << "TensorShape: " << tensor_shape.DebugString();
+  EXPECT_TRUE(tensor_shape_expected == tensor_shape);
+}
+
+TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
+  TensorProto tensor_proto;
+  TensorShapeProto tensor_shape_proto;
+
+  // Dimention larger than max value; should fail while converting to Tensor
+  // class.
+  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255);
+  EXPECT_FALSE(
+      GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+
+  tensor_proto.Clear();
+  // Expect only 1D shape.
+  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
+  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
+  EXPECT_FALSE(
+      GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+
+  // Expect only handle integer data types.
+  GetTensorProto(DT_FLOAT, {}, {}, /*tensor_content=*/false, &tensor_proto);
+  EXPECT_FALSE(
+      GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+
+  // Check GetTensorShapeProtoFromTensorProto() resturns correct values.
+  {
+    std::vector<int64> shape_expected = {10, 20, 30, 40};
+    GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/false,
+                   &tensor_proto);
+    EXPECT_TRUE(
+        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+    ExpectTensorShape(shape_expected, tensor_shape_proto);
+  }
+
+  {
+    std::vector<int64> shape_expected = {40, 20, 90, 40};
+    GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/false,
+                   &tensor_proto);
+    EXPECT_TRUE(
+        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+    ExpectTensorShape(shape_expected, tensor_shape_proto);
+  }
+
+  {
+    std::vector<int64> shape_expected = {10, 20, 30, 40};
+    GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/true,
+                   &tensor_proto);
+    EXPECT_TRUE(
+        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+    ExpectTensorShape(shape_expected, tensor_shape_proto);
+  }
+
+  {
+    std::vector<int64> shape_expected = {40, 20, 90, 40};
+    GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/true,
+                   &tensor_proto);
+    EXPECT_TRUE(
+        GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
+    ExpectTensorShape(shape_expected, tensor_shape_proto);
+  }
+}
+
 }  // end namespace grappler
 }  // end namespace tensorflow