}
}
}
-
// Remove RandomShuffle op if it is scalar or first dimension is of size 1.
if (use_shape_info && IsRandomShuffle(*node) &&
!properties->GetInputProperties(node->name()).empty()) {
}
}
- // Remove Reverse op over dimensions with size 1.
- if (use_shape_info && node->op() == "ReverseV2" &&
- properties->GetInputProperties(node->name()).size() >= 2) {
- const auto& shape = properties->GetInputProperties(node->name())[0].shape();
- if (shape.unknown_rank()) {
- // Not optimizable.
- return Status::OK();
- }
- const auto& a = properties->GetInputProperties(node->name())[1];
- if (TensorShape::IsValid(a.shape()) && a.has_value()) {
- Tensor axis(a.dtype(), a.shape());
- if (!axis.FromProto(a.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- a.value().DebugString());
- }
- std::set<int> target_axes;
- for (int j = 0; j < axis.NumElements(); ++j) {
- // value of axis can be negative.
- if (axis.dtype() == DT_INT64) {
- target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
- shape.dim_size());
- } else {
- target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
- shape.dim_size());
- }
- }
-
- // The node is replaceable iff
- // unknown_rank == false &&
- // (dim_size == 0 || all dims have size 1 ||
- // all dims with > 1 size are not in target_axes)
- bool replaceable = !shape.unknown_rank();
- for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
- replaceable &= shape.dim(j).size() == 1 ||
- target_axes.find(j) == target_axes.end();
- }
- if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- return Status::OK();
- }
- }
+ bool remove_reverse_successful = false;
+ Status remove_reverse_status =
+ RemoveReverse(*properties, use_shape_info, optimized_graph, node,
+ &remove_reverse_successful);
+ if (!remove_reverse_status.ok()) {
+ return remove_reverse_status;
+ } else if (remove_reverse_successful) {
+ return Status::OK();
}
bool simplify_slice_successful = false;
return Status::OK();
}
+Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node,
+ bool* success) {
+ if (use_shape_info && node->op() == "ReverseV2" &&
+ properties.GetInputProperties(node->name()).size() >= 2) {
+ const auto& shape = properties.GetInputProperties(node->name())[0].shape();
+ if (shape.unknown_rank()) {
+ // Not optimizable.
+ return Status::OK();
+ }
+ const auto& a = properties.GetInputProperties(node->name())[1];
+ if (TensorShape::IsValid(a.shape()) && a.has_value()) {
+ Tensor axis(a.dtype(), a.shape());
+ if (!axis.FromProto(a.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ a.value().DebugString());
+ }
+ std::set<int> target_axes;
+ for (int j = 0; j < axis.NumElements(); ++j) {
+ // value of axis can be negative.
+ if (axis.dtype() == DT_INT64) {
+ target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
+ shape.dim_size());
+ } else {
+ target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
+ shape.dim_size());
+ }
+ }
+
+ // The node is replaceable iff
+ // unknown_rank == false &&
+ // (dim_size == 0 || all dims have size 1 ||
+ // all dims with > 1 size are not in target_axes)
+ bool replaceable = !shape.unknown_rank();
+ for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+ replaceable &= shape.dim(j).size() == 1 ||
+ target_axes.find(j) == target_axes.end();
+ }
+ if (replaceable) {
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ *success = true;
+ return Status::OK();
+ }
+ }
+ }
+ *success = false;
+ return Status::OK();
+}
+
Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node,