return Status::OK();
}
- // Remove Shuffle or Transpose op over dimensions of size 1.
- if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
- properties->GetInputProperties(node->name()).size() >= 2) {
- const auto& shape = properties->GetInputProperties(node->name())[0].shape();
- if (shape.unknown_rank()) {
- // Not optimizable.
- return Status::OK();
- }
- const auto& p = properties->GetInputProperties(node->name())[1];
- if (TensorShape::IsValid(p.shape()) && p.has_value()) {
- Tensor perm(p.dtype(), p.shape());
- if (!perm.FromProto(p.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- p.value().DebugString());
- }
- std::vector<int> permutation;
- for (int j = 0; j < perm.NumElements(); ++j) {
- if (perm.dtype() == DT_INT64) {
- permutation.push_back(perm.vec<int64>()(j));
- } else {
- permutation.push_back(perm.vec<int>()(j));
- }
- }
- if (permutation.size() != shape.dim_size()) {
- // Number of elements in perm should be same as dim_size. Skip if not.
- return Status::OK();
- }
- // The node is replaceable iff
- // dim_size == 0 || all dims have size 1 ||
- // all dims with > 1 size are not permuted.
- bool replaceable = true;
- for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
- replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
- }
- if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- return Status::OK();
- }
- }
+ bool remove_shuffle_transpose_successful = false;
+ Status remove_shuffle_transpose_status =
+ RemoveShuffleOrTranspose(*properties, use_shape_info, optimized_graph,
+ node, &remove_shuffle_transpose_successful);
+ if (!remove_shuffle_transpose_status.ok()) {
+ return remove_shuffle_transpose_status;
+ } else if (remove_shuffle_transpose_successful) {
+ return Status::OK();
}
if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) {
return Status::OK();
}
+Status ConstantFolding::RemoveShuffleOrTranspose(
+ const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node, bool* success) {
+ if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
+ properties.GetInputProperties(node->name()).size() >= 2) {
+ const auto& shape = properties.GetInputProperties(node->name())[0].shape();
+ if (shape.unknown_rank()) {
+ // Not optimizable.
+ return Status::OK();
+ }
+ const auto& p = properties.GetInputProperties(node->name())[1];
+ if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+ Tensor perm(p.dtype(), p.shape());
+ if (!perm.FromProto(p.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ p.value().DebugString());
+ }
+ std::vector<int> permutation;
+ for (int j = 0; j < perm.NumElements(); ++j) {
+ if (perm.dtype() == DT_INT64) {
+ permutation.push_back(perm.vec<int64>()(j));
+ } else {
+ permutation.push_back(perm.vec<int>()(j));
+ }
+ }
+ if (permutation.size() != shape.dim_size()) {
+ // Number of elements in perm should be same as dim_size. Skip if not.
+ return Status::OK();
+ }
+ // The node is replaceable iff
+ // dim_size == 0 || all dims have size 1 ||
+ // all dims with > 1 size are not permuted.
+ bool replaceable = true;
+ for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+ replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
+ }
+ if (replaceable) {
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ *success = true;
+ return Status::OK();
+ }
+ }
+ }
+ *success = false;
+ return Status::OK();
+}
bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,