From: A. Unique TensorFlower Date: Mon, 14 May 2018 22:45:33 +0000 (-0700) Subject: Add ExplicitShapes as a new shape inference function for Ops with X-Git-Tag: upstream/v1.9.0_rc1~116^2^2~28 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=16dcaf3e09c72d5702ecef20d6a24957981b34a4;p=platform%2Fupstream%2Ftensorflow.git Add ExplicitShapes as a new shape inference function for Ops with multiple outputs, each of which is explicitly declared. PiperOrigin-RevId: 196579920 --- diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 0916c9b..71a31b0 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -1417,6 +1417,21 @@ Status ExplicitShape(InferenceContext* c) { return Status::OK(); } +Status ExplicitShapes(InferenceContext* c) { + std::vector shapes; + TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes)); + if (shapes.empty()) { + return errors::Internal("shapes attribute is empty"); + } + for (int i = 0; i < shapes.size(); ++i) { + ShapeHandle output_shape; + TF_RETURN_IF_ERROR( + c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape)); + c->set_output(i, output_shape); + } + return Status::OK(); +} + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index 789746b..87bb133 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -289,6 +289,9 @@ Status ScatterNdUpdateShape(InferenceContext* c); // Shape function for ops with an explicit "shape" attribute. Status ExplicitShape(InferenceContext* c); +// Shape function for multiple-output ops with an explicit "shapes" attribute. +Status ExplicitShapes(InferenceContext* c); + } // namespace shape_inference } // namespace tensorflow