bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
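+// Attribute used to store the output shapes inferred by
+// GraphProperties::AnnotateOutputShapes on each node.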
+const char kOutputShapesAttr[] = "_output_shapes";
+
// Shape is symbolically defined if it has a known rank, and each dimension is
// defined, or is an unknown symbol (dim.size <= -2).
bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape) {
// Returns whether `reshape` is an identity op. The tensor that `reshape`
// reshapes is the `output_pos`-th output of node `input`.
bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
- const int output_pos,
- const GraphProperties& graph_properties) {
- const std::vector<OpInfo::TensorProperties>& reshape_props =
- graph_properties.GetOutputProperties(reshape.name());
- const std::vector<OpInfo::TensorProperties>& input_props =
- graph_properties.GetOutputProperties(input.name());
- if (reshape_props.empty() || input_props.empty() ||
- input_props.size() <= output_pos) {
+ const int output_pos) {
+ if (!reshape.attr().count(kOutputShapesAttr) ||
+ !input.attr().count(kOutputShapesAttr) ||
+ reshape.attr().at(kOutputShapesAttr).list().shape_size() < 1 ||
+ input.attr().at(kOutputShapesAttr).list().shape_size() <= output_pos) {
return false;
}
- const PartialTensorShape& src_shape = input_props[output_pos].shape();
- const PartialTensorShape& dst_shape = reshape_props[0].shape();
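+ // Look up the pre-annotated shapes from the _output_shapes attribute
+ // instead of querying GraphProperties at simplification time.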
+ PartialTensorShape src_shape(
+ input.attr().at(kOutputShapesAttr).list().shape(output_pos));
+ PartialTensorShape dst_shape(
+ reshape.attr().at(kOutputShapesAttr).list().shape(0));
if (src_shape.unknown_rank() || dst_shape.unknown_rank()) {
return false;
}
// outputs tensors of shape [M, N] while feeding it with tensors of shape
// [M*N] (or worse). The reshape nodes are then necessary to update the
// tensor metadata to the required shape.
- if (can_use_shapes_ &&
- ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_)) {
+ if (ReshapeIsIdentity(*reshape, *input, output_pos)) {
return reshape->input(0);
}
}
std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
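+ // Shape inference is now required to succeed (see the InferStatically call
+ // below), so shape-dependent stages no longer need a can_use_shapes_ guard.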
- if (options_.combine_add_to_addn && can_use_shapes_) {
+ if (options_.combine_add_to_addn) {
stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
new AddOpsRewriteStage(ctx, ctx_ext)));
}
- if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes_) {
+ if (options_.hoist_common_factor_out_of_aggregation) {
stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
new HoistCommonFactorOutOfAggregation(ctx, ctx_ext)));
}
if (simplified_tensor.empty()) {
for (auto& stage : stages) {
if (stage->IsSupported(node)) {
- const Status stage_status =
- stage->TrySimplify(node, &simplified_tensor);
- // Each stage must be "error safe" (just like exception safe). In
- // case of any error it must leave optimized graph unmodified.
- if (!stage_status.ok()) {
- LOG(WARNING) << "Failed to run arithmetic optimizer stage "
- << stage->stage_name()
- << ". Error: " << stage_status.error_message();
- }
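+ // An error from any stage now aborts the whole optimizer pass instead of
+ // being logged and skipped.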
+ TF_RETURN_IF_ERROR(stage->TrySimplify(node, &simplified_tensor));
if (!simplified_tensor.empty()) {
break;
}
&frame_map_, &num_frames));
// Shapes are only needed in aggressive mode.
graph_properties_.reset(new GraphProperties(item));
- const Status status = graph_properties_->InferStatically(false);
- can_use_shapes_ = status.ok();
- if (!can_use_shapes_) {
- LOG(WARNING) << "Shape inference failed.";
- }
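+ // Infer shapes statically and annotate every node with its output shapes so
+ // that the rewrite stages can read them from the graph itself.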
+ TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
+ // TODO(ezhulenev): Use GraphProperties to look up tensor shapes directly.
+ TF_RETURN_IF_ERROR(graph_properties_->AnnotateOutputShapes(optimized_graph_));
// Perform the optimizations.
DedupComputations();
TF_RETURN_IF_ERROR(SimplifyArithmeticOps());
+ // Remove the temporary _output_shapes annotations from the optimized graph.
+ for (int i = 0; i < optimized_graph_->node_size(); ++i) {
+ optimized_graph_->mutable_node(i)->mutable_attr()->erase(kOutputShapesAttr);
+ }
+
return Status::OK();
}