}
}
- if (use_shape_info && IsSqueeze(*node) &&
- !properties->GetInputProperties(node->name()).empty()) {
- // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
- // error to squeeze a dimension that is not 1, so we only need to check
- // whether the input has > 1 size for each dimension.
- const auto& shape = properties->GetInputProperties(node->name())[0].shape();
- // The node is replaceable iff
- // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
- bool replaceable = !shape.unknown_rank();
- for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
- replaceable &= shape.dim(j).size() > 1;
- }
- if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- return Status::OK();
- }
+ if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
+ return Status::OK();
}
if (SimplifyPack(optimized_graph, node)) {
return Status::OK();
}
+bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph,
+ NodeDef* node) {
+ if (use_shape_info && IsSqueeze(*node) &&
+ !properties.GetInputProperties(node->name()).empty()) {
+ // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
+ // error to squeeze a dimension that is not 1, so we only need to check
+ // whether the input has > 1 size for each dimension.
+ const auto& shape = properties.GetInputProperties(node->name())[0].shape();
+ // The node is replaceable iff
+ // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
+ bool replaceable = !shape.unknown_rank();
+ for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+ replaceable &= shape.dim(j).size() > 1;
+ }
+ if (replaceable) {
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ return true;
+ }
+ }
+ return false;
+}
+
bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
!OptimizedNodeExists(*node, "_const_axis")) {
// Simplifies Pack operation if applicable.
bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);
+ // Simplifies a Squeeze operation to an Identity operation if applicable.
+ bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node);
+
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;