return Status::OK();
}
+ if (ConstantPushDown(node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (PartialConstPropThroughIdentityN(node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (PartialConcatConstFolding(optimized_graph, properties, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ return Status::OK();
+}
+
+bool ConstantFolding::ConstantPushDown(NodeDef* node) {
// Consider the transformation
//
// + + = parent
// division/multiplication.
// Don't touch BiasAdd since they can't handle vectors as their first
// inputs.
- if (has_fetch_ && (IsAdd(*node) || is_mul) &&
+ if (has_fetch_ && (IsAdd(*node) || IsMul(*node)) &&
NumNonControlInputs(*node) == 2) {
NodeDef* left_child = node_map_->GetNode(node->input(0));
NodeDef* right_child = node_map_->GetNode(node->input(1));
// One child must be constant, and the other the same op as the parent.
if (node->op() != left_child->op() && node->op() != right_child->op()) {
- return Status::OK();
+ return false;
}
const bool left_child_is_constant = IsReallyConstant(*left_child);
const bool right_child_is_constant = IsReallyConstant(*right_child);
if (!left_child_is_constant && !right_child_is_constant) {
- return Status::OK();
+ return false;
}
if (node->device() != left_child->device() ||
node->device() != right_child->device()) {
- return Status::OK();
+ return false;
}
NodeDef* op_child_node = left_child_is_constant ? right_child : left_child;
NodeDef* const_child_node =
nodes_to_preserve_.find(op_child_node->name()) !=
nodes_to_preserve_.end() ||
NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
- return Status::OK();
+ return false;
}
// Identify the nodes to swap.
const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
if (left_leaf_is_constant && right_leaf_is_constant) {
// Child is already foldable, leave it alone.
- return Status::OK();
+ return false;
}
const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
const int parent_const_input = left_child_is_constant ? 0 : 1;
node->input(parent_const_input));
std::swap(*node->mutable_input(parent_const_input),
*op_child_node->mutable_input(non_const_leaf_input));
- graph_modified_ = true;
- return Status::OK();
+ return true;
}
+ return false;
+}
+bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
// Partial constant propagation through IdentityN.
if (IsIdentityN(*node) && NumNonControlInputs(*node) > 0) {
const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
for (NodeDef* consumer : consumers) {
DedupControlInputs(consumer);
}
- graph_modified_ = true;
- return Status::OK();
+ return true;
}
}
-
- if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
- graph_modified_ = true;
- return Status::OK();
- }
-
- if (PartialConcatConstFolding(optimized_graph, properties, node)) {
- graph_modified_ = true;
- return Status::OK();
- }
-
- return Status::OK();
+ return false;
}
bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,