Add ExplicitShapes as a new shape inference function for Ops with
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 14 May 2018 22:45:33 +0000 (15:45 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 14 May 2018 22:48:31 +0000 (15:48 -0700)
multiple outputs, each of which is explicitly declared.

PiperOrigin-RevId: 196579920

tensorflow/core/framework/common_shape_fns.cc
tensorflow/core/framework/common_shape_fns.h

index 0916c9b..71a31b0 100644 (file)
@@ -1417,6 +1417,21 @@ Status ExplicitShape(InferenceContext* c) {
   return Status::OK();
 }
 
+Status ExplicitShapes(InferenceContext* c) {
+  std::vector<PartialTensorShape> shapes;
+  TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
+  if (shapes.empty()) {
+    return errors::Internal("shapes attribute is empty");
+  }
+  for (int i = 0; i < shapes.size(); ++i) {
+    ShapeHandle output_shape;
+    TF_RETURN_IF_ERROR(
+        c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape));
+    c->set_output(i, output_shape);
+  }
+  return Status::OK();
+}
+
 }  // namespace shape_inference
 
 }  // namespace tensorflow
index 789746b..87bb133 100644 (file)
@@ -289,6 +289,9 @@ Status ScatterNdUpdateShape(InferenceContext* c);
 // Shape function for ops with an explicit "shape" attribute.
 Status ExplicitShape(InferenceContext* c);
 
+// Shape function for multiple-output ops with an explicit "shapes" attribute.
+Status ExplicitShapes(InferenceContext* c);
+
 }  // namespace shape_inference
 
 }  // namespace tensorflow