}
}
- if (use_shape_info && IsPad(*node) &&
- properties->GetInputProperties(node->name()).size() >= 2) {
- const auto& p = properties->GetInputProperties(node->name())[1];
- if (TensorShape::IsValid(p.shape()) && p.has_value()) {
- Tensor paddings(p.dtype(), p.shape());
- if (!paddings.FromProto(p.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- p.value().DebugString());
- }
- // The node is replaceable iff all values in paddings are 0.
- bool replaceable = true;
- // The operation requires it to be int32 value so we don't check for
- // 1nt64.
- const auto flatten = paddings.flat<int32>();
- for (int j = 0; replaceable && j < flatten.size(); ++j) {
- replaceable &= flatten(j) == 0;
- }
- if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- return Status::OK();
- }
- }
+ bool simplify_pad_successful = false;
+ Status simplify_pad_status =
+ SimplifyPad(*properties, use_shape_info, optimized_graph, node,
+ &simplify_pad_successful);
+ if (!simplify_pad_status.ok()) {
+ return simplify_pad_status;
+ } else if (simplify_pad_successful) {
+ return Status::OK();
}
if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
return Status::OK();
}
+Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node,
+ bool* success) {
+ if (use_shape_info && IsPad(*node) &&
+ properties.GetInputProperties(node->name()).size() >= 2) {
+ const auto& p = properties.GetInputProperties(node->name())[1];
+ if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+ Tensor paddings(p.dtype(), p.shape());
+ if (!paddings.FromProto(p.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ p.value().DebugString());
+ }
+ // The node is replaceable iff all values in paddings are 0.
+ bool replaceable = true;
+ // The operation requires it to be int32 value so we don't check for
+ // 1nt64.
+ const auto flatten = paddings.flat<int32>();
+ for (int j = 0; replaceable && j < flatten.size(); ++j) {
+ replaceable &= flatten(j) == 0;
+ }
+ if (replaceable) {
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ *success = true;
+ return Status::OK();
+ }
+ }
+ }
+ *success = false;
+ return Status::OK();
+}
+
bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,