Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
GraphDef* optimized_graph,
GraphProperties* properties) {
- if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
- ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
- return Status::OK();
- }
-
- if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) {
return Status::OK();
}
return Status::OK();
}
+bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
+ GraphDef* optimized_graph,
+ NodeDef* node) {
+ if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
+ ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
+ return true;
+ }
+
+ if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ return true;
+ }
+ return false;
+}
+
Status ConstantFolding::RemoveShuffleOrTranspose(
const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success) {
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node,
bool* success);
+
+ // Removes Split or SplitV node if possible.
+ bool RemoveSplitOrSplitV(const GraphProperties& properties,
+ GraphDef* optimized_graph, NodeDef* node);
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;