From 30e2b97897d05e47b457ab1d5d0d9c4227b87845 Mon Sep 17 00:00:00 2001 From: Rob Sloan Date: Fri, 6 Apr 2018 21:55:10 -0700 Subject: [PATCH] Add analytical cost model for FusedConv2DBiasActivation. PiperOrigin-RevId: 191978272 --- .../core/grappler/costs/op_level_cost_estimator.cc | 165 ++++++++++++++++++++- .../core/grappler/costs/op_level_cost_estimator.h | 26 ++++ .../grappler/costs/op_level_cost_estimator_test.cc | 64 +++++++- 3 files changed, 249 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 79735e6..087190a 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -30,6 +30,7 @@ constexpr char kConst[] = "Const"; constexpr char kConv2d[] = "Conv2D"; constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter"; constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput"; +constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation"; constexpr char kMatMul[] = "MatMul"; constexpr char kSparseMatMul[] = "SparseMatMul"; constexpr char kPlaceholder[] = "Placeholder"; @@ -196,6 +197,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() { wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)}, {kConv2dBackpropInput, wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)}, + {kFusedConv2dBiasActivation, + wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation)}, {kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)}, {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)}, {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}, @@ -545,7 +548,6 @@ int64 OpLevelCostEstimator::CountConv2DOperations( ops *= conv_dims.kx * conv_dims.ky; ops *= conv_dims.iz * conv_dims.oz; ops *= kOpsPerMac; - VLOG(1) << "Operations for Conv2D " << ops; if (conv_info != nullptr) { *conv_info = conv_dims; @@ -983,6 +985,91 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropFilter( return costs; } +Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( + const OpContext& op_context) const { + // FusedConv2DBiasActivation computes a fused kernel which implements: + // 2D convolution, adds side input with separate scaling on convolution and + // side inputs, then adds bias, and finally applies the ReLU activation + // function to the result: + // + // Input -> Conv2D -> Add -> BiasAdd -> ReLU + // ^ ^ ^ + // Filter Side Input Bias + // + // Note that when adding the side input, the operation multiplies the output + // of Conv2D by conv_input_scale, confusingly, and the side_input by + // side_input_scale. + // + // Note that in the special case that side_input_scale is 0, which we infer + // from side_input having dimensions [], we skip that addition operation. + // + // For more information, see + // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc + auto& conv_input = op_context.op_info.inputs(0); + auto& filter = op_context.op_info.inputs(1); + auto& bias = op_context.op_info.inputs(2); + auto& side_input = op_context.op_info.inputs(3); + auto& conv_input_scale = op_context.op_info.inputs(4); + auto& side_input_scale = op_context.op_info.inputs(5); + + // Manually compute our convolution dimensions. + bool found_unknown_shapes = false; + auto dims = ConvolutionDimensionsFromInputs( + conv_input.shape(), filter.shape(), op_context.op_info, + &found_unknown_shapes); + + // Construct the shape of our output tensor from our convolution dimensions + // and format, as it may not be available yet. + // + // TODO(varomodt): should we centralize the Conv2D input/output shapes? + bool unknown_conv_format = false; + OpInfo::TensorProperties output; + switch (GetConvolutionFormat(op_context)) { + case NCHW: + output = + DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy}); + break; + case NHWC: + output = + DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz}); + break; + default: + // TODO(b/77722245): support cost estimation for NCHW_VECT_C. + LOG(WARNING) << "unsupported data format: " + << GetDataFormat(op_context.op_info) + << " Defaulting to NHWC."; + output = + DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz}); + unknown_conv_format = true; + break; + } + + // Add the operations the fused op always computes. + std::vector component_ops = { + FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}), + FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}), + FusedChildContext(op_context, "BiasAdd", output, {output, bias}), + FusedChildContext(op_context, "Relu", output, {output})}; + + // Add our side_input iff it's non-empty. + if (side_input.shape().dim_size() > 0) { + component_ops.push_back(FusedChildContext(op_context, "Mul", side_input, + {side_input, side_input_scale})); + component_ops.push_back( + FusedChildContext(op_context, "Add", output, {side_input, output})); + } + + // Construct an op_context which definitely has our output shape. + auto op_context_with_output = op_context; + op_context_with_output.op_info.mutable_outputs()->Clear(); + *op_context_with_output.op_info.mutable_outputs()->Add() = output; + + // Construct component operations and run the cost computation. + auto costs = PredictFusedOp(op_context_with_output, component_ops); + costs.inaccurate |= found_unknown_shapes || unknown_conv_format; + return costs; +} + Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const { const auto& op_features = op_context.op_info; bool found_unknown_shapes = false; @@ -1086,6 +1173,66 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice( return costs; } +Costs OpLevelCostEstimator::PredictFusedOp( + const OpContext& op_context, + const std::vector& fused_op_contexts) const { + // Note that PredictOpCountBasedCost will get the correct memory_time from + // the node's inputs and outputs; but we don't want to have to re-implement + // the logic for computing the operation count of each of our component + // operations here; so we simply add the compute times of each component + // operation, then update the execution time. + Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info); + fused_cost.compute_time = 0; + fused_cost.inaccurate = false; + for (auto& fused_op : fused_op_contexts) { + auto op_cost = PredictCosts(fused_op); + fused_cost.compute_time += op_cost.compute_time; + fused_cost.inaccurate |= op_cost.inaccurate; + } + + CombineCostsAndUpdateExecutionTime(&fused_cost); + return fused_cost; +} + +/* static */ +OpContext OpLevelCostEstimator::FusedChildContext( + const OpContext& parent, const string& op_name, + const OpInfo::TensorProperties& output, + const std::vector& inputs) { + // Setup the base parameters of our new context. + OpContext new_context; + new_context.name = op_name; + new_context.device_name = parent.device_name; + new_context.op_info = parent.op_info; + new_context.op_info.set_op(op_name); + + // Setup the inputs of our new context. + new_context.op_info.mutable_inputs()->Clear(); + for (const auto& input : inputs) { + *new_context.op_info.mutable_inputs()->Add() = input; + } + + // Setup the output of our new context. + new_context.op_info.mutable_outputs()->Clear(); + *new_context.op_info.mutable_outputs()->Add() = output; + + return new_context; +} + +/* static */ +OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor( + DataType type, const std::vector& dims) { + OpInfo::TensorProperties ret; + ret.set_dtype(type); + + auto shape = ret.mutable_shape(); + for (const int dim : dims) { + shape->add_dim()->set_size(dim); + } + + return ret; +} + /* static */ OpLevelCostEstimator::ConvolutionDimensions OpLevelCostEstimator::OpDimensionsFromInputs( @@ -1371,6 +1518,21 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad( return costs; } +/* static */ +OpLevelCostEstimator::ConvolutionFormat +OpLevelCostEstimator::GetConvolutionFormat(const OpContext& op_context) { + auto data_format = GetDataFormat(op_context.op_info); + if (data_format == "NCHW") { + return NCHW; + } else if (data_format == "NHWC") { + return NHWC; + } else if (data_format == "NCHW_VECT_C") { + return NCHW_VECT_C; + } + + return UNKNOWN_CONVOLUTION_FORMAT; +} + void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime( Costs* costs) const { if (compute_memory_overlap_) { @@ -1379,6 +1541,5 @@ void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime( costs->execution_time = costs->compute_time + costs->memory_time; } } - } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index 7080264..35649f7 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -82,6 +82,13 @@ class OpLevelCostEstimator { int64 sy; // Stride y. Padding padding; // SAME or VALID. }; + enum ConvolutionFormat { + UNKNOWN_CONVOLUTION_FORMAT, + NHWC, + NCHW, + NCHW_VECT_C, + NCHW_VECT_W, + }; int64 CountConv2DOperations(const OpInfo& op_features, bool* found_unknown_shapes) const; int64 CountConv2DOperations(const OpInfo& op_features, @@ -138,6 +145,7 @@ class OpLevelCostEstimator { Costs PredictCwiseOp(const OpContext& op_context) const; Costs PredictConv2DBackpropInput(const OpContext& op_context) const; Costs PredictConv2DBackpropFilter(const OpContext& op_context) const; + Costs PredictFusedConv2DBiasActivation(const OpContext& op_context) const; Costs PredictMatMul(const OpContext& op_context) const; Costs PredictNoOp(const OpContext& op_context) const; Costs PredictIdentity(const OpContext& op_context) const; @@ -152,6 +160,10 @@ class OpLevelCostEstimator { Costs PredictFusedBatchNorm(const OpContext& op_context) const; Costs PredictFusedBatchNormGrad(const OpContext& op_context) const; + // Generic cost prediction method for fused operations. + Costs PredictFusedOp(const OpContext& op_context, + const std::vector& fused_op_contexts) const; + // Utility function for safe division. Returns 0 // if rhs is 0 or negative. static double SafeDiv(const double lhs, const double rhs) { @@ -173,6 +185,20 @@ class OpLevelCostEstimator { const TensorShapeProto& original_image_shape, const OpInfo& op_info, bool* found_unknown_shapes); + // Helper to construct child operation contexts for the component operations + // of fused ops. + static OpContext FusedChildContext( + const OpContext& parent, const string& op_name, + const OpInfo::TensorProperties& output, + const std::vector& inputs); + + // Helper to construct tensor shapes. + static OpInfo::TensorProperties DescribeTensor( + DataType type, const std::vector& dims); + + // Returns the Conv2D format for this operation. + static ConvolutionFormat GetConvolutionFormat(const OpContext& op_context); + // This method calculates the execution time depending on whether IO can // overlap with computation. It assumes the memory and the compute times have // already been calculated. diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index d797a8a..13ea43b 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -93,6 +93,14 @@ OpContext DescribeBatchMatMul(const std::vector& dims_a, return op_context; } +// Wrangles the minimum number of proto fields to set up a 1D Tensor for cost +// estimation purposes. +void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) { + auto shape = tensor->mutable_shape(); + shape->add_dim()->set_size(dim0); + tensor->set_dtype(DT_FLOAT); +} + // Wrangles the minimum number of proto fields to set up a 4D Tensor for cost // estimation purposes. void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3, @@ -120,6 +128,38 @@ OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2, return op_context; } +// DescribeFusedConv2DBiasActivation constructs an OpContext for a +// FusedConv2DBiasActivation applied to a convolution input tensor with shape +// (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a +// bias tensor with shape (oz), a side input tensor with shape +// (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with +// shape (1). +// +// Note that this assumes the NHWC data format. +OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1, + int iz2, int kx, int ky, int ox, + int oy, int oz, + bool has_side_input) { + OpContext op_context; + SetCpuDevice(&op_context.op_info); + op_context.op_info.set_op("FusedConv2DBiasActivation"); + DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs()); + DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs()); + DescribeTensor1D(oz, op_context.op_info.add_inputs()); + + // Add the side_input, if any. + auto side_input = op_context.op_info.add_inputs(); + if (has_side_input) { + DescribeTensor4D(batch, ox, oy, oz, side_input); + } + + // Add the scaling tensors. + DescribeTensor1D(1, op_context.op_info.add_inputs()); + DescribeTensor1D(1, op_context.op_info.add_inputs()); + + return op_context; +} + // DescribeUnaryOp constructs an OpContext for the given operation applied to // a 4-tensor with shape (size1, 1, 1, 1). OpContext DescribeUnaryOp(const string& op, int size1) { @@ -162,12 +202,9 @@ OpContext DescribeBiasAdd(int size1, int size2) { op_context.op_info.set_op("BiasAdd"); DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_inputs()); + DescribeTensor1D(size1, op_context.op_info.add_inputs()); DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_outputs()); - auto bias = op_context.op_info.add_inputs(); - bias->mutable_shape()->add_dim()->set_size(size1); - bias->set_dtype(DT_FLOAT); - return op_context; } @@ -486,6 +523,25 @@ TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) { SetComputeMemoryOverlap(false); // Set it back to default. } +TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationExecutionTime) { + auto cost = PredictCosts(DescribeFusedConv2DBiasActivation( + 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true)); + EXPECT_EQ(Costs::Duration(1416808), cost.memory_time); + EXPECT_EQ(Costs::Duration(355616770), cost.compute_time); + EXPECT_EQ(Costs::Duration(357033578), cost.execution_time); + EXPECT_FALSE(cost.inaccurate); +} + +TEST_F(OpLevelCostEstimatorTest, + FusedConv2DBiasActivationNoSideInputExecutionTime) { + auto cost = PredictCosts(DescribeFusedConv2DBiasActivation( + 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false)); + EXPECT_EQ(Costs::Duration(825345), cost.memory_time); + EXPECT_EQ(Costs::Duration(355321038), cost.compute_time); + EXPECT_EQ(Costs::Duration(356146383), cost.execution_time); + EXPECT_FALSE(cost.inaccurate); +} + TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) { auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1)); EXPECT_EQ(Costs::Duration(2000), cost.memory_time); -- 2.7.4