Improved accuracy of op_level_cost_estimator (QuantizeV2, Dequantize, Gather).
authorMax Galkin <maxgalkin@google.com>
Tue, 20 Mar 2018 18:45:23 +0000 (11:45 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 18:48:34 +0000 (11:48 -0700)
PiperOrigin-RevId: 189779691

tensorflow/core/grappler/costs/op_level_cost_estimator.cc
tensorflow/core/grappler/costs/op_level_cost_estimator.h
tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc

index 29ef317..84ad8a3 100644 (file)
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/grappler/clusters/utils.h"
 
 namespace tensorflow {
@@ -46,6 +47,7 @@ constexpr char kShape[] = "Shape";
 constexpr char kSize[] = "Size";
 constexpr char kStopGradient[] = "StopGradient";
 constexpr char kPreventGradient[] = "PreventGradient";
+constexpr char kGather[] = "Gather";
 
 static const Costs::Duration kMinComputeTime(1);
 
@@ -167,6 +169,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
 
       {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
 
+      {kGather, wrap(&OpLevelCostEstimator::PredictGather)},
+
       {kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
       {kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
       {kRefIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
@@ -184,6 +188,17 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
       {kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
       {kSize, wrap(&OpLevelCostEstimator::PredictMetadata)}};
 
+  // Quantize = apply min and max bounds, multiply by scale factor and round.
+  const int quantize_v2_cost =
+      Eigen::internal::functor_traits<
+          Eigen::internal::scalar_product_op<float>>::Cost +
+      Eigen::internal::functor_traits<
+          Eigen::internal::scalar_max_op<float>>::Cost +
+      Eigen::internal::functor_traits<
+          Eigen::internal::scalar_min_op<float>>::Cost +
+      Eigen::internal::functor_traits<
+          Eigen::internal::scalar_round_op<float>>::Cost;
+
   elementwise_ops_ = {
       // Unary ops alphabetically sorted
       {"Acos", Eigen::internal::functor_traits<
@@ -200,6 +215,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
                    Eigen::internal::scalar_ceil_op<float>>::Cost},
       {"Cos", Eigen::internal::functor_traits<
                   Eigen::internal::scalar_cos_op<float>>::Cost},
+      {"Dequantize", Eigen::internal::functor_traits<
+                         Eigen::internal::scalar_product_op<float>>::Cost},
       {"Erf", 1},
       {"Erfc", 1},
       {"Exp", Eigen::internal::functor_traits<
@@ -218,6 +235,7 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
                     Eigen::internal::scalar_log1p_op<float>>::Cost},
       {"Neg", Eigen::internal::functor_traits<
                   Eigen::internal::scalar_opposite_op<float>>::Cost},
+      {"QuantizeV2", quantize_v2_cost},
       {"Reciprocal", Eigen::internal::functor_traits<
                          Eigen::internal::scalar_inverse_op<float>>::Cost},
       {"Rint", 1},
@@ -411,28 +429,33 @@ Costs OpLevelCostEstimator::PredictCostOfAnUnknownOp(
 }
 
 Costs OpLevelCostEstimator::PredictOpCountBasedCost(
-    double operations, const OpInfo& op_features) const {
-  DeviceInfo device_perf = GetDeviceInfo(op_features.device());
-  if (device_perf.gigaops <= 0 || device_perf.gb_per_sec <= 0) {
-    VLOG(1) << "BAD DEVICE. Op:" << op_features.op()
-            << " device type:" << op_features.device().type()
-            << " device model:" << op_features.device().model();
-  }
+    double operations, const OpInfo& op_info) const {
+  bool unknown_shapes = false;
+  const double input_size = CalculateInputSize(op_info, &unknown_shapes);
+  const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
+  const double total_io_bytes = input_size + output_size;
+  Costs costs = PredictOpCountBasedCost(operations, total_io_bytes, op_info);
+  costs.inaccurate = unknown_shapes;
+  costs.max_memory = output_size;
+  return costs;
+}
 
-  Costs::NanoSeconds compute_cost(std::ceil(operations / device_perf.gigaops));
-  VLOG(1) << "Op:" << op_features.op() << " GOps:" << operations / 1e9
-          << " Execution Time (ns):" << compute_cost.count();
+Costs OpLevelCostEstimator::PredictOpCountBasedCost(
+    double operations, double total_io_bytes, const OpInfo& op_info) const {
+  const DeviceInfo device_info = GetDeviceInfo(op_info.device());
+  if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0) {
+    VLOG(1) << "BAD DEVICE. Op:" << op_info.op()
+            << " device type:" << op_info.device().type()
+            << " device model:" << op_info.device().model();
+  }
 
-  bool found_unknown_shapes = false;
-  const double total_input_size =
-      CalculateInputSize(op_features, &found_unknown_shapes);
-  const double total_output_size =
-      CalculateOutputSize(op_features, &found_unknown_shapes);
-  const double total_io_size = total_input_size + total_output_size;
+  Costs::NanoSeconds compute_cost(std::ceil(operations / device_info.gigaops));
+  VLOG(1) << "Op:" << op_info.op() << " GOps:" << operations / 1e9
+          << " Compute Time (ns):" << compute_cost.count();
 
   Costs::NanoSeconds memory_cost(
-      std::ceil(total_io_size / device_perf.gb_per_sec));
-  VLOG(1) << "Op:" << op_features.op() << " Size (KB):" << (total_io_size) / 1e3
+      std::ceil(total_io_bytes / device_info.gb_per_sec));
+  VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
           << " Memory Time (ns):" << memory_cost.count();
 
   Costs costs;
@@ -443,8 +466,6 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
   } else {
     costs.execution_time = compute_cost + memory_cost;
   }
-  costs.inaccurate = found_unknown_shapes;
-  costs.max_memory = total_output_size;
   return costs;
 }
 
@@ -867,7 +888,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
 
 int64 OpLevelCostEstimator::CalculateTensorElementCount(
     const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) const {
-  VLOG(2) << "   with " << tensor.dtype() << " tensor of shape "
+  VLOG(2) << "   with " << DataTypeString(tensor.dtype()) << " tensor of shape "
           << tensor.shape().DebugString();
   int64 tensor_size = 1;
   int num_dims = std::max(1, tensor.shape().dim_size());
@@ -1028,5 +1049,23 @@ Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
   return costs;
 }
 
+Costs OpLevelCostEstimator::PredictGather(const OpContext& op_context) const {
+  // Gather op can have a very large input, but only the size of the output
+  // matters, because indices may select only a very small subset of input.
+
+  const auto& op_info = op_context.op_info;
+
+  bool unknown_shapes = false;
+  const int64 op_count =
+      CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes);
+  const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
+  const double total_io = 2 * output_size;
+  Costs costs = PredictOpCountBasedCost(op_count, total_io, op_info);
+  costs.inaccurate = unknown_shapes;
+  costs.max_memory = output_size;
+
+  return costs;
+}
+
 }  // end namespace grappler
 }  // end namespace tensorflow
index 7bb530f..e5dd31a 100644 (file)
@@ -51,10 +51,15 @@ class OpLevelCostEstimator {
   // Predict cost of an op for which no accurate estimator is defined.
   Costs PredictCostOfAnUnknownOp(const OpContext& op_context) const;
 
-  // Naive cost estimate based on operations divided by device ops/sec,
-  // and input/output tensor sizes.
-  Costs PredictOpCountBasedCost(double operations,
-                                const OpInfo& op_features) const;
+  // Naive cost estimate based on the given operations count and total
+  // input/output tensor sizes of the given op_info combined.
+  Costs PredictOpCountBasedCost(double operations, const OpInfo& op_info) const;
+
+  // Naive cost estimate based on the given operations count and the given total
+  // io size in bytes. Sizes of op_info inputs and outputs are not taken into
+  // consideration.
+  Costs PredictOpCountBasedCost(double operations, double total_io_bytes,
+                                const OpInfo& op_info) const;
 
   // This family of routines counts the number of operations to perform the
   // specified TensorFlow Op.
@@ -125,7 +130,7 @@ class OpLevelCostEstimator {
   // implementation just divides the operations to
   // perform the op (from the "Count" routines,
   // above) by the device peak operations per
-  // second. Override to supply a better estimate.
+  // second.
   // Implementation of costs other than
   // execution_time is optional, depending on the
   // device.
@@ -139,6 +144,7 @@ class OpLevelCostEstimator {
   Costs PredictVariable(const OpContext& op_context) const;
   Costs PredictBatchMatMul(const OpContext& op_context) const;
   Costs PredictMetadata(const OpContext& op_context) const;
+  Costs PredictGather(const OpContext& op_context) const;
 
   // Utility function for safe division. Returns 0
   // if rhs is 0 or negative.
index 4790b9b..d5360cb 100644 (file)
@@ -75,8 +75,8 @@ OpContext DescribeMatMulUnknownShape() {
 // Wrangles the minimum number of proto fields to set up an input of
 // arbitrary rank and type.
 void DescribeArbitraryRankInput(const std::vector<int>& dims, DataType dtype,
-                                OpInfo* op_features) {
-  auto input = op_features->add_inputs();
+                                OpInfo* op_info) {
+  auto input = op_info->add_inputs();
   input->set_dtype(dtype);
   auto shape = input->mutable_shape();
   for (auto d : dims) {
@@ -84,6 +84,18 @@ void DescribeArbitraryRankInput(const std::vector<int>& dims, DataType dtype,
   }
 }
 
+// Wrangles the minimum number of proto fields to set up an output of
+// arbitrary rank and type.
+void DescribeArbitraryRankOutput(const std::vector<int>& dims, DataType dtype,
+                                 OpInfo* op_info) {
+  auto output = op_info->add_outputs();
+  output->set_dtype(dtype);
+  auto shape = output->mutable_shape();
+  for (auto d : dims) {
+    shape->add_dim()->set_size(d);
+  }
+}
+
 // Returns an OpInfo for a BatchMatMul
 OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
                               const std::vector<int>& dims_b) {
@@ -200,6 +212,23 @@ class OpLevelCostEstimatorTest : public ::testing::Test {
   OpLevelCostEstimator estimator_;
 };
 
+TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
+  OpContext op_context;
+  SetCpuDevice(&op_context.op_info);
+  op_context.op_info.set_op("Gather");
+
+  // Huge first input shouldn't affect Gather execution and memory costs.
+  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
+  DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
+  DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
+
+  auto cost = estimator_.PredictCosts(op_context);
+  EXPECT_EQ(Costs::Duration(128), cost.memory_time);
+  EXPECT_EQ(Costs::Duration(16), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(144), cost.execution_time);
+  EXPECT_FALSE(cost.inaccurate);
+}
+
 TEST_F(OpLevelCostEstimatorTest, BiasAddExecutionTime) {
   auto cost = PredictCosts(DescribeBiasAdd(1000, 10));
   EXPECT_EQ(Costs::Duration(8400), cost.memory_time);
@@ -354,7 +383,7 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
   TensorProto tensor_proto;
   TensorShapeProto tensor_shape_proto;
 
-  // Dimention larger than max value; should fail while converting to Tensor
+  // Dimension larger than max value; should fail while converting to Tensor
   // class.
   tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255);
   EXPECT_FALSE(