}
}
- if (use_shape_info && IsStridedSlice(*node) &&
- properties->GetInputProperties(node->name()).size() == 4) {
- if (node->attr().at("new_axis_mask").i() != 0 ||
- node->attr().at("shrink_axis_mask").i() != 0) {
- // Skip nodes with new/shrink axis mask, since they involve dimension
- // changes.
- return Status::OK();
- }
- const auto& input = properties->GetInputProperties(node->name())[0];
- for (int j = 0; j < input.shape().dim_size(); ++j) {
- // Skip if input shape is not fully determined.
- if (input.shape().dim(j).size() < 0) {
- return Status::OK();
- }
- }
- const auto& b = properties->GetInputProperties(node->name())[1];
- const auto& e = properties->GetInputProperties(node->name())[2];
- const auto& s = properties->GetInputProperties(node->name())[3];
- if (TensorShape::IsValid(b.shape()) && b.has_value() &&
- TensorShape::IsValid(e.shape()) && e.has_value() &&
- TensorShape::IsValid(s.shape()) && s.has_value()) {
- Tensor begin(b.dtype(), b.shape());
- if (!begin.FromProto(b.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- b.value().DebugString());
- }
- Tensor end(e.dtype(), e.shape());
- if (!end.FromProto(e.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- e.value().DebugString());
- }
- Tensor strides(s.dtype(), s.shape());
- if (!strides.FromProto(s.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- s.value().DebugString());
- }
- int begin_mask = node->attr().at("begin_mask").i();
- int end_mask = node->attr().at("end_mask").i();
- std::set<int> expanded_ellipsis_indices;
- int ellipsis_index = -1;
- for (int j = 0; j < input.shape().dim_size(); ++j) {
- // find the ellipsis_mask. If not found, insert one in the end if
- // necessary.
- if (node->attr().at("ellipsis_mask").i() & 1 << j ||
- (ellipsis_index == -1 && j >= strides.NumElements())) {
- ellipsis_index = j;
- }
- // insert the indices that are immediately after ellipsis_index if
- // necessary.
- if (ellipsis_index != -1 &&
- input.shape().dim_size() >
- strides.NumElements() + j - ellipsis_index) {
- expanded_ellipsis_indices.insert(j);
- }
- }
-
- // The node is replaceable iff unknown_rank == false &&
- // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
- // && strides == 1) for all dimensions.
- bool replaceable = !input.shape().unknown_rank();
- for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
- if (expanded_ellipsis_indices.find(j) !=
- expanded_ellipsis_indices.end()) {
- // ellipsis_mask is effective on current dimension.
- continue;
- }
- // when we have ellipsis_mask in between, input.shape().dim_size() will
- // be greater than strides.NumElements(), since we will insert
- // as many as expanded_ellipsis_indices.size() axes during computation.
- // We need to subtract this number from j.
- int i = j;
- if (ellipsis_index != -1 &&
- j >= ellipsis_index + expanded_ellipsis_indices.size()) {
- i = j - expanded_ellipsis_indices.size();
- }
- int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
- : begin.vec<int64>()(i);
- int e =
- end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
- int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
- : strides.vec<int64>()(i);
- replaceable &=
- (begin_mask & 1 << i || b == 0) &&
- (end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1;
- }
- if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- return Status::OK();
- }
- }
+ bool simplify_strided_slice_successful = false;
+ Status simplify_strided_slice_status =
+ SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node,
+ &simplify_strided_slice_successful);
+ if (!simplify_strided_slice_status.ok()) {
+ return simplify_strided_slice_status;
+ } else if (simplify_strided_slice_successful) {
+ return Status::OK();
}
bool simplify_tile_successful = false;
return Status::OK();
}
+Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph,
+ NodeDef* node, bool* success) {
+ if (use_shape_info && IsStridedSlice(*node) &&
+ properties.GetInputProperties(node->name()).size() == 4) {
+ if (node->attr().at("new_axis_mask").i() != 0 ||
+ node->attr().at("shrink_axis_mask").i() != 0) {
+ // Skip nodes with new/shrink axis mask, since they involve dimension
+ // changes.
+ return Status::OK();
+ }
+ const auto& input = properties.GetInputProperties(node->name())[0];
+ for (int j = 0; j < input.shape().dim_size(); ++j) {
+ // Skip if input shape is not fully determined.
+ if (input.shape().dim(j).size() < 0) {
+ return Status::OK();
+ }
+ }
+ const auto& b = properties.GetInputProperties(node->name())[1];
+ const auto& e = properties.GetInputProperties(node->name())[2];
+ const auto& s = properties.GetInputProperties(node->name())[3];
+ if (TensorShape::IsValid(b.shape()) && b.has_value() &&
+ TensorShape::IsValid(e.shape()) && e.has_value() &&
+ TensorShape::IsValid(s.shape()) && s.has_value()) {
+ Tensor begin(b.dtype(), b.shape());
+ if (!begin.FromProto(b.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ b.value().DebugString());
+ }
+ Tensor end(e.dtype(), e.shape());
+ if (!end.FromProto(e.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ e.value().DebugString());
+ }
+ Tensor strides(s.dtype(), s.shape());
+ if (!strides.FromProto(s.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ s.value().DebugString());
+ }
+ int begin_mask = node->attr().at("begin_mask").i();
+ int end_mask = node->attr().at("end_mask").i();
+ std::set<int> expanded_ellipsis_indices;
+ int ellipsis_index = -1;
+ for (int j = 0; j < input.shape().dim_size(); ++j) {
+ // find the ellipsis_mask. If not found, insert one in the end if
+ // necessary.
+ if (node->attr().at("ellipsis_mask").i() & 1 << j ||
+ (ellipsis_index == -1 && j >= strides.NumElements())) {
+ ellipsis_index = j;
+ }
+ // insert the indices that are immediately after ellipsis_index if
+ // necessary.
+ if (ellipsis_index != -1 &&
+ input.shape().dim_size() >
+ strides.NumElements() + j - ellipsis_index) {
+ expanded_ellipsis_indices.insert(j);
+ }
+ }
+
+ // The node is replaceable iff unknown_rank == false &&
+ // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
+ // && strides == 1) for all dimensions.
+ bool replaceable = !input.shape().unknown_rank();
+ for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
+ if (expanded_ellipsis_indices.find(j) !=
+ expanded_ellipsis_indices.end()) {
+ // ellipsis_mask is effective on current dimension.
+ continue;
+ }
+ // when we have ellipsis_mask in between, input.shape().dim_size() will
+ // be greater than strides.NumElements(), since we will insert
+ // as many as expanded_ellipsis_indices.size() axes during computation.
+ // We need to subtract this number from j.
+ int i = j;
+ if (ellipsis_index != -1 &&
+ j >= ellipsis_index + expanded_ellipsis_indices.size()) {
+ i = j - expanded_ellipsis_indices.size();
+ }
+ int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
+ : begin.vec<int64>()(i);
+ int e =
+ end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
+ int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
+ : strides.vec<int64>()(i);
+ replaceable &=
+ (begin_mask & 1 << i || b == 0) &&
+ (end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1;
+ }
+ if (replaceable) {
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ *success = true;
+ return Status::OK();
+ }
+ }
+ }
+ *success = false;
+ return Status::OK();
+}
+
Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node,