Add analytical cost model for FusedConv2DBiasActivation.
authorRob Sloan <varomodt@google.com>
Sat, 7 Apr 2018 04:55:10 +0000 (21:55 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 7 Apr 2018 04:57:46 +0000 (21:57 -0700)
PiperOrigin-RevId: 191978272

tensorflow/core/grappler/costs/op_level_cost_estimator.cc
tensorflow/core/grappler/costs/op_level_cost_estimator.h
tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc

index 79735e6..087190a 100644 (file)
@@ -30,6 +30,7 @@ constexpr char kConst[] = "Const";
 constexpr char kConv2d[] = "Conv2D";
 constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
 constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
+constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
 constexpr char kMatMul[] = "MatMul";
 constexpr char kSparseMatMul[] = "SparseMatMul";
 constexpr char kPlaceholder[] = "Placeholder";
@@ -196,6 +197,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
        wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)},
       {kConv2dBackpropInput,
        wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
+      {kFusedConv2dBiasActivation,
+       wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation)},
       {kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
       {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
       {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
@@ -545,7 +548,6 @@ int64 OpLevelCostEstimator::CountConv2DOperations(
   ops *= conv_dims.kx * conv_dims.ky;
   ops *= conv_dims.iz * conv_dims.oz;
   ops *= kOpsPerMac;
-  VLOG(1) << "Operations for Conv2D " << ops;
 
   if (conv_info != nullptr) {
     *conv_info = conv_dims;
@@ -983,6 +985,91 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
   return costs;
 }
 
+Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
+    const OpContext& op_context) const {
+  // FusedConv2DBiasActivation computes a fused kernel which implements:
+  // 2D convolution, adds side input with separate scaling on convolution and
+  // side inputs, then adds bias, and finally applies the ReLU activation
+  // function to the result:
+  //
+  // Input -> Conv2D  ->  Add  -> BiasAdd  -> ReLU
+  //            ^          ^         ^
+  //          Filter   Side Input   Bias
+  //
+  // Note that when adding the side input, the operation multiplies the output
+  // of Conv2D by conv_input_scale, confusingly, and the side_input by
+  // side_input_scale.
+  //
+  // Note that in the special case that side_input_scale is 0, which we infer
+  // from side_input having dimensions [], we skip that addition operation.
+  //
+  // For more information, see
+  // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+  auto& conv_input = op_context.op_info.inputs(0);
+  auto& filter = op_context.op_info.inputs(1);
+  auto& bias = op_context.op_info.inputs(2);
+  auto& side_input = op_context.op_info.inputs(3);
+  auto& conv_input_scale = op_context.op_info.inputs(4);
+  auto& side_input_scale = op_context.op_info.inputs(5);
+
+  // Manually compute our convolution dimensions.
+  bool found_unknown_shapes = false;
+  auto dims = ConvolutionDimensionsFromInputs(
+      conv_input.shape(), filter.shape(), op_context.op_info,
+      &found_unknown_shapes);
+
+  // Construct the shape of our output tensor from our convolution dimensions
+  // and format, as it may not be available yet.
+  //
+  // TODO(varomodt): should we centralize the Conv2D input/output shapes?
+  bool unknown_conv_format = false;
+  OpInfo::TensorProperties output;
+  switch (GetConvolutionFormat(op_context)) {
+    case NCHW:
+      output =
+          DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
+      break;
+    case NHWC:
+      output =
+          DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
+      break;
+    default:
+      // TODO(b/77722245): support cost estimation for NCHW_VECT_C.
+      LOG(WARNING) << "unsupported data format: "
+                   << GetDataFormat(op_context.op_info)
+                   << " Defaulting to NHWC.";
+      output =
+          DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
+      unknown_conv_format = true;
+      break;
+  }
+
+  // Add the operations the fused op always computes.
+  std::vector<OpContext> component_ops = {
+      FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}),
+      FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}),
+      FusedChildContext(op_context, "BiasAdd", output, {output, bias}),
+      FusedChildContext(op_context, "Relu", output, {output})};
+
+  // Add our side_input iff it's non-empty.
+  if (side_input.shape().dim_size() > 0) {
+    component_ops.push_back(FusedChildContext(op_context, "Mul", side_input,
+                                              {side_input, side_input_scale}));
+    component_ops.push_back(
+        FusedChildContext(op_context, "Add", output, {side_input, output}));
+  }
+
+  // Construct an op_context which definitely has our output shape.
+  auto op_context_with_output = op_context;
+  op_context_with_output.op_info.mutable_outputs()->Clear();
+  *op_context_with_output.op_info.mutable_outputs()->Add() = output;
+
+  // Construct component operations and run the cost computation.
+  auto costs = PredictFusedOp(op_context_with_output, component_ops);
+  costs.inaccurate |= found_unknown_shapes || unknown_conv_format;
+  return costs;
+}
+
 Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
   const auto& op_features = op_context.op_info;
   bool found_unknown_shapes = false;
@@ -1086,6 +1173,66 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
   return costs;
 }
 
+Costs OpLevelCostEstimator::PredictFusedOp(
+    const OpContext& op_context,
+    const std::vector<OpContext>& fused_op_contexts) const {
+  // Note that PredictOpCountBasedCost will get the correct memory_time from
+  // the node's inputs and outputs; but we don't want to have to re-implement
+  // the logic for computing the operation count of each of our component
+  // operations here; so we simply add the compute times of each component
+  // operation, then update the execution time.
+  Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info);
+  fused_cost.compute_time = 0;
+  fused_cost.inaccurate = false;
+  for (auto& fused_op : fused_op_contexts) {
+    auto op_cost = PredictCosts(fused_op);
+    fused_cost.compute_time += op_cost.compute_time;
+    fused_cost.inaccurate |= op_cost.inaccurate;
+  }
+
+  CombineCostsAndUpdateExecutionTime(&fused_cost);
+  return fused_cost;
+}
+
+/* static */
+OpContext OpLevelCostEstimator::FusedChildContext(
+    const OpContext& parent, const string& op_name,
+    const OpInfo::TensorProperties& output,
+    const std::vector<OpInfo::TensorProperties>& inputs) {
+  // Setup the base parameters of our new context.
+  OpContext new_context;
+  new_context.name = op_name;
+  new_context.device_name = parent.device_name;
+  new_context.op_info = parent.op_info;
+  new_context.op_info.set_op(op_name);
+
+  // Setup the inputs of our new context.
+  new_context.op_info.mutable_inputs()->Clear();
+  for (const auto& input : inputs) {
+    *new_context.op_info.mutable_inputs()->Add() = input;
+  }
+
+  // Setup the output of our new context.
+  new_context.op_info.mutable_outputs()->Clear();
+  *new_context.op_info.mutable_outputs()->Add() = output;
+
+  return new_context;
+}
+
+/* static */
+OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
+    DataType type, const std::vector<int64>& dims) {
+  OpInfo::TensorProperties ret;
+  ret.set_dtype(type);
+
+  auto shape = ret.mutable_shape();
+  for (const int dim : dims) {
+    shape->add_dim()->set_size(dim);
+  }
+
+  return ret;
+}
+
 /* static */
 OpLevelCostEstimator::ConvolutionDimensions
 OpLevelCostEstimator::OpDimensionsFromInputs(
@@ -1371,6 +1518,21 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
   return costs;
 }
 
+/* static */
+OpLevelCostEstimator::ConvolutionFormat
+OpLevelCostEstimator::GetConvolutionFormat(const OpContext& op_context) {
+  auto data_format = GetDataFormat(op_context.op_info);
+  if (data_format == "NCHW") {
+    return NCHW;
+  } else if (data_format == "NHWC") {
+    return NHWC;
+  } else if (data_format == "NCHW_VECT_C") {
+    return NCHW_VECT_C;
+  }
+
+  return UNKNOWN_CONVOLUTION_FORMAT;
+}
+
 void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
     Costs* costs) const {
   if (compute_memory_overlap_) {
@@ -1379,6 +1541,5 @@ void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
     costs->execution_time = costs->compute_time + costs->memory_time;
   }
 }
-
 }  // end namespace grappler
 }  // end namespace tensorflow
index 7080264..35649f7 100644 (file)
@@ -82,6 +82,13 @@ class OpLevelCostEstimator {
     int64 sy;         // Stride y.
     Padding padding;  // SAME or VALID.
   };
+  enum ConvolutionFormat {
+    UNKNOWN_CONVOLUTION_FORMAT,
+    NHWC,
+    NCHW,
+    NCHW_VECT_C,
+    NCHW_VECT_W,
+  };
   int64 CountConv2DOperations(const OpInfo& op_features,
                               bool* found_unknown_shapes) const;
   int64 CountConv2DOperations(const OpInfo& op_features,
@@ -138,6 +145,7 @@ class OpLevelCostEstimator {
   Costs PredictCwiseOp(const OpContext& op_context) const;
   Costs PredictConv2DBackpropInput(const OpContext& op_context) const;
   Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
+  Costs PredictFusedConv2DBiasActivation(const OpContext& op_context) const;
   Costs PredictMatMul(const OpContext& op_context) const;
   Costs PredictNoOp(const OpContext& op_context) const;
   Costs PredictIdentity(const OpContext& op_context) const;
@@ -152,6 +160,10 @@ class OpLevelCostEstimator {
   Costs PredictFusedBatchNorm(const OpContext& op_context) const;
   Costs PredictFusedBatchNormGrad(const OpContext& op_context) const;
 
+  // Generic cost prediction method for fused operations.
+  Costs PredictFusedOp(const OpContext& op_context,
+                       const std::vector<OpContext>& fused_op_contexts) const;
+
   // Utility function for safe division. Returns 0
   // if rhs is 0 or negative.
   static double SafeDiv(const double lhs, const double rhs) {
@@ -173,6 +185,20 @@ class OpLevelCostEstimator {
       const TensorShapeProto& original_image_shape, const OpInfo& op_info,
       bool* found_unknown_shapes);
 
+  // Helper to construct child operation contexts for the component operations
+  // of fused ops.
+  static OpContext FusedChildContext(
+      const OpContext& parent, const string& op_name,
+      const OpInfo::TensorProperties& output,
+      const std::vector<OpInfo::TensorProperties>& inputs);
+
+  // Helper to construct tensor shapes.
+  static OpInfo::TensorProperties DescribeTensor(
+      DataType type, const std::vector<int64>& dims);
+
+  // Returns the Conv2D format for this operation.
+  static ConvolutionFormat GetConvolutionFormat(const OpContext& op_context);
+
   // This method calculates the execution time depending on whether IO can
   // overlap with computation. It assumes the memory and the compute times have
   // already been calculated.
index d797a8a..13ea43b 100644 (file)
@@ -93,6 +93,14 @@ OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
   return op_context;
 }
 
+// Wrangles the minimum number of proto fields to set up a 1D Tensor for cost
+// estimation purposes.
+void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) {
+  auto shape = tensor->mutable_shape();
+  shape->add_dim()->set_size(dim0);
+  tensor->set_dtype(DT_FLOAT);
+}
+
 // Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
 // estimation purposes.
 void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
@@ -120,6 +128,38 @@ OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
   return op_context;
 }
 
+// DescribeFusedConv2DBiasActivation constructs an OpContext for a
+// FusedConv2DBiasActivation applied to a convolution input tensor with shape
+// (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a
+// bias tensor with shape (oz), a side input tensor with shape
+// (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with
+// shape (1).
+//
+// Note that this assumes the NHWC data format.
+OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
+                                            int iz2, int kx, int ky, int ox,
+                                            int oy, int oz,
+                                            bool has_side_input) {
+  OpContext op_context;
+  SetCpuDevice(&op_context.op_info);
+  op_context.op_info.set_op("FusedConv2DBiasActivation");
+  DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
+  DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
+  DescribeTensor1D(oz, op_context.op_info.add_inputs());
+
+  // Add the side_input, if any.
+  auto side_input = op_context.op_info.add_inputs();
+  if (has_side_input) {
+    DescribeTensor4D(batch, ox, oy, oz, side_input);
+  }
+
+  // Add the scaling tensors.
+  DescribeTensor1D(1, op_context.op_info.add_inputs());
+  DescribeTensor1D(1, op_context.op_info.add_inputs());
+
+  return op_context;
+}
+
 // DescribeUnaryOp constructs an OpContext for the given operation applied to
 // a 4-tensor with shape (size1, 1, 1, 1).
 OpContext DescribeUnaryOp(const string& op, int size1) {
@@ -162,12 +202,9 @@ OpContext DescribeBiasAdd(int size1, int size2) {
   op_context.op_info.set_op("BiasAdd");
 
   DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_inputs());
+  DescribeTensor1D(size1, op_context.op_info.add_inputs());
   DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_outputs());
 
-  auto bias = op_context.op_info.add_inputs();
-  bias->mutable_shape()->add_dim()->set_size(size1);
-  bias->set_dtype(DT_FLOAT);
-
   return op_context;
 }
 
@@ -486,6 +523,25 @@ TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
   SetComputeMemoryOverlap(false);  // Set it back to default.
 }
 
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationExecutionTime) {
+  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true));
+  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
+  EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+  EXPECT_FALSE(cost.inaccurate);
+}
+
+TEST_F(OpLevelCostEstimatorTest,
+       FusedConv2DBiasActivationNoSideInputExecutionTime) {
+  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false));
+  EXPECT_EQ(Costs::Duration(825345), cost.memory_time);
+  EXPECT_EQ(Costs::Duration(355321038), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(356146383), cost.execution_time);
+  EXPECT_FALSE(cost.inaccurate);
+}
+
 TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
   auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
   EXPECT_EQ(Costs::Duration(2000), cost.memory_time);