Further improve accuracy of op_level_cost_estimator (Gather, GatherV2, Slice).
authorMax Galkin <maxgalkin@google.com>
Wed, 21 Mar 2018 19:53:04 +0000 (12:53 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Mar 2018 19:55:38 +0000 (12:55 -0700)
PiperOrigin-RevId: 189952132

tensorflow/core/grappler/costs/op_level_cost_estimator.cc
tensorflow/core/grappler/costs/op_level_cost_estimator.h
tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc

index 84ad8a3..d3ffa03 100644 (file)
@@ -48,6 +48,8 @@ constexpr char kSize[] = "Size";
 constexpr char kStopGradient[] = "StopGradient";
 constexpr char kPreventGradient[] = "PreventGradient";
 constexpr char kGather[] = "Gather";
+constexpr char kGatherV2[] = "GatherV2";
+constexpr char kSlice[] = "Slice";
 
 static const Costs::Duration kMinComputeTime(1);
 
@@ -169,7 +171,9 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
 
       {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
 
-      {kGather, wrap(&OpLevelCostEstimator::PredictGather)},
+      {kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
+      {kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
+      {kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
 
       {kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
       {kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
@@ -1049,17 +1053,33 @@ Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
   return costs;
 }
 
-Costs OpLevelCostEstimator::PredictGather(const OpContext& op_context) const {
-  // Gather op can have a very large input, but only the size of the output
-  // matters, because indices may select only a very small subset of input.
-
+Costs OpLevelCostEstimator::PredictGatherOrSlice(
+    const OpContext& op_context) const {
+  // Gather & Slice ops can have a very large input, but only access a small
+  // part of it. For these op the size of the output determines the memory cost.
   const auto& op_info = op_context.op_info;
 
   bool unknown_shapes = false;
+
+  // Each output element is a copy of some element from input.
+  // For roofline estimate we assume each copy has a unit cost.
   const int64 op_count =
       CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes);
+
   const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
-  const double total_io = 2 * output_size;
+  double input_size = output_size;
+  if (op_info.op() == "Slice") {
+    // Add 'begin' & 'size' tensors sizes.
+    input_size +=
+        CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes) +
+        CalculateTensorElementCount(op_info.inputs(2), &unknown_shapes);
+  } else {
+    // Assuming this is "Gather" or "GatherV2" op, add 'indices' size.
+    input_size +=
+        CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes);
+  }
+
+  const double total_io = input_size + output_size;
   Costs costs = PredictOpCountBasedCost(op_count, total_io, op_info);
   costs.inaccurate = unknown_shapes;
   costs.max_memory = output_size;
index e5dd31a..1b3babb 100644 (file)
@@ -144,7 +144,7 @@ class OpLevelCostEstimator {
   Costs PredictVariable(const OpContext& op_context) const;
   Costs PredictBatchMatMul(const OpContext& op_context) const;
   Costs PredictMetadata(const OpContext& op_context) const;
-  Costs PredictGather(const OpContext& op_context) const;
+  Costs PredictGatherOrSlice(const OpContext& op_context) const;
 
   // Utility function for safe division. Returns 0
   // if rhs is 0 or negative.
index a92f230..f2a9615 100644 (file)
@@ -206,9 +206,27 @@ TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
   DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
 
   auto cost = estimator_.PredictCosts(op_context);
-  EXPECT_EQ(Costs::Duration(128), cost.memory_time);
+  EXPECT_EQ(Costs::Duration(130), cost.memory_time);
   EXPECT_EQ(Costs::Duration(16), cost.compute_time);
-  EXPECT_EQ(Costs::Duration(144), cost.execution_time);
+  EXPECT_EQ(Costs::Duration(146), cost.execution_time);
+  EXPECT_FALSE(cost.inaccurate);
+}
+
+TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
+  OpContext op_context;
+  SetCpuDevice(&op_context.op_info);
+  op_context.op_info.set_op("Slice");
+
+  // Huge first input shouldn't affect Slice execution and memory costs.
+  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
+  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);
+
+  auto cost = estimator_.PredictCosts(op_context);
+  EXPECT_EQ(Costs::Duration(81), cost.memory_time);
+  EXPECT_EQ(Costs::Duration(10), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(91), cost.execution_time);
   EXPECT_FALSE(cost.inaccurate);
 }