return Status::OK();
}
-Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
- GraphProperties* properties,
- bool use_shape_info) {
+Status ConstantFolding::SimplifyGraph(bool use_shape_info,
+ GraphDef* optimized_graph,
+ GraphProperties* properties) {
for (int i = 0; i < optimized_graph->node_size(); ++i) {
- TF_RETURN_IF_ERROR(SimplifyNode(optimized_graph->mutable_node(i),
- optimized_graph, properties,
- use_shape_info));
+ TF_RETURN_IF_ERROR(SimplifyNode(use_shape_info,
+ optimized_graph->mutable_node(i),
+ optimized_graph, properties));
}
return Status::OK();
}
-Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
- GraphProperties* properties,
- bool use_shape_info) {
+Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
+ GraphDef* optimized_graph,
+ GraphProperties* properties) {
if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
return Status::OK();
graph_modified_ = true;
return Status::OK();
}
- if (use_shape_info && IsSimplifiableReshape(*node, *properties)) {
- DataType output_type = node->attr().at("T").type();
- node->set_op("Identity");
- node->clear_attr();
- (*node->mutable_attr())["T"].set_type(output_type);
- *node->mutable_input(1) = AsControlDependency(node->input(1));
+
+ if (SimplifyReshape(*properties, use_shape_info, node)) {
graph_modified_ = true;
return Status::OK();
}
bool arithmetic_simplification_succeed = false;
- Status simplify_arithmetic_status = SimplifyArithmeticOperations(
- optimized_graph, properties, node, use_shape_info,
- &arithmetic_simplification_succeed);
+ Status simplify_arithmetic_status =
+ SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph,
+ node, &arithmetic_simplification_succeed);
if (!simplify_arithmetic_status.ok()) {
return simplify_arithmetic_status;
} else if (arithmetic_simplification_succeed) {
return Status::OK();
}
+bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
+ bool use_shape_info, NodeDef* node) {
+ if (!use_shape_info) return false;
+ if (!IsSimplifiableReshape(*node, properties)) return false;
+ DataType output_type = node->attr().at("T").type();
+ node->set_op("Identity");
+ node->clear_attr();
+ (*node->mutable_attr())["T"].set_type(output_type);
+ *node->mutable_input(1) = AsControlDependency(node->input(1));
+ return true;
+}
+
Status ConstantFolding::SimplifyArithmeticOperations(
- GraphDef* optimized_graph, GraphProperties* properties, NodeDef* node,
- bool use_shape_info, bool* success) {
+ const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node, bool* success) {
const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
const bool is_matmul = IsMatMul(*node);
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
// Simplify arithmetic operations with ones or zeros.
if (use_shape_info &&
(is_mul || is_matmul || is_add || is_sub || is_any_div) &&
- properties->HasInputProperties(node->name()) &&
- properties->HasOutputProperties(node->name())) {
+ properties.HasInputProperties(node->name()) &&
+ properties.HasOutputProperties(node->name())) {
const NodeDef* x = node_map_->GetNode(node->input(0));
const NodeDef* y = node_map_->GetNode(node->input(1));
if (x == nullptr || y == nullptr) {
node->DebugString());
}
const TensorShapeProto& output_shape =
- properties->GetOutputProperties(node->name())[0].shape();
+ properties.GetOutputProperties(node->name())[0].shape();
// Simplify element-wise multiplication by ones or addition/subtraction
// of zeros.
const TensorShapeProto& y_shape =
- properties->GetInputProperties(node->name())[1].shape();
+ properties.GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
if (y_matches_output_shape &&
((is_mul && x_is_one) || (is_add && x_is_zero))) {
// 1 * y = y or 0 + y = y.
- ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph);
+ ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
}
const TensorShapeProto& x_shape =
- properties->GetInputProperties(node->name())[0].shape();
+ properties.GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero))) {
// x * 1 = x or x / 1 = x or x +/- 0 = x
- ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph);
+ ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
const PartialTensorShape shp(output_shape);
if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
bool replace_succeed = false;
- Status replace_op_status =
- ReplaceOperationWithConstant(1, *properties, output_shape, node,
- optimized_graph, &replace_succeed);
+ Status replace_op_status = ReplaceOperationWithConstant(
+ 1, properties, output_shape, node, optimized_graph, &replace_succeed);
if (!replace_op_status.ok()) {
return replace_op_status;
} else if (replace_succeed) {
if (shp.IsFullyDefined()) {
bool replace_succeed = false;
Status replace_op_status =
- ReplaceOperationWithConstant(0, *properties, output_shape, node,
+ ReplaceOperationWithConstant(0, properties, output_shape, node,
optimized_graph, &replace_succeed);
if (!replace_op_status.ok()) {
return replace_op_status;
// matches the output shape and thus forward the corresponding zero
// input.
if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
} else if (is_mul && y_is_zero && y_matches_output_shape) {
- ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
+ ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
TF_RETURN_IF_ERROR(FoldGraph(optimized_graph));
node_map_.reset(new NodeMap(optimized_graph));
TF_RETURN_IF_ERROR(
- SimplifyGraph(optimized_graph, &properties, can_use_shape_info));
+ SimplifyGraph(can_use_shape_info, optimized_graph, &properties));
return Status::OK();
}
const GraphProperties& properties) const;
bool IsSimplifiableReshape(const NodeDef& node,
const GraphProperties& properties) const;
- Status SimplifyGraph(GraphDef* output, GraphProperties* properties,
- bool use_shape_info);
- Status SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
- GraphProperties* properties, bool use_shape_info);
+ Status SimplifyGraph(bool use_shape_info, GraphDef* optimized_graph,
+ GraphProperties* properties);
+ Status SimplifyNode(bool use_shape_info, NodeDef* node,
+ GraphDef* optimized_graph, GraphProperties* properties);
Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
GraphDef* output);
// Simplifies arithmetic operations with ones or zeros. Returns the status,
// and updates the success input argument that denotes if any simplification
// was applied.
- Status SimplifyArithmeticOperations(GraphDef* optimized_graph,
- GraphProperties* properties,
- NodeDef* node, bool use_shape_info,
+ Status SimplifyArithmeticOperations(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node,
bool* success);
+ // Simplifies a Reshape operation to an Identity operation if the input node
+ // to the operation is a constant.
+ bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
+ NodeDef* node);
+
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;