}
}
}
- // Remove RandomShuffle op if it is scalar or first dimension is of size 1.
- if (use_shape_info && IsRandomShuffle(*node) &&
- !properties->GetInputProperties(node->name()).empty()) {
- const auto& shape = properties->GetInputProperties(node->name())[0].shape();
- // The node is replaceable iff
- // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
- if (!shape.unknown_rank() &&
- (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- return Status::OK();
- }
+
+ if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) {
+ return Status::OK();
}
bool remove_reverse_successful = false;
return Status::OK();
}
+bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph,
+ NodeDef* node) {
+ if (use_shape_info && IsRandomShuffle(*node) &&
+ !properties.GetInputProperties(node->name()).empty()) {
+ const auto& shape = properties.GetInputProperties(node->name())[0].shape();
+ // The node is replaceable iff
+ // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
+ if (!shape.unknown_rank() &&
+ (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ return true;
+ }
+ }
+ return false;
+}
+
Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node,
// Removes Reverse op over dimensions with size 1.
Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success);
+
+ // Removes RandomShuffle op if it is scalar or first dimension is of size 1.
+ bool RemoveRandomShuffle(const GraphProperties& properties,
+ bool use_shape_info, GraphDef* optimized_graph,
+ NodeDef* node);
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;