More accurate shape inference for TensorArrayGatherV3 and TensorArrayScatterV3
authorBenoit Steiner <bsteiner@google.com>
Wed, 21 Mar 2018 15:40:35 +0000 (08:40 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Mar 2018 15:43:32 +0000 (08:43 -0700)
PiperOrigin-RevId: 189912762

tensorflow/core/ops/data_flow_ops.cc

index 4f946fb..3112f35 100644 (file)
@@ -668,13 +668,31 @@ REGISTER_OP("TensorArrayGatherV3")
     .Attr("dtype: type")
     .Attr("element_shape: shape = { unknown_rank: true }")
     .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle indices;
       ShapeHandle unused;
       DimensionHandle unused_dim;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
-      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices));
       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
-      return shape_inference::UnknownShape(c);
+      auto shapes = c->input_handle_shapes_and_types(0);
+      if (shapes != nullptr && !shapes->empty()) {
+        ShapeHandle tensor_shape = shapes->at(0).shape;
+        ShapeHandle output_shape;
+        TF_RETURN_IF_ERROR(
+            c->Concatenate(indices, tensor_shape, &output_shape));
+        c->set_output(0, output_shape);
+        return Status::OK();
+      } else {
+        PartialTensorShape p;
+        TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p));
+        ShapeHandle s;
+        TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
+        ShapeHandle output_shape;
+        TF_RETURN_IF_ERROR(c->Concatenate(indices, s, &output_shape));
+        c->set_output(0, output_shape);
+        return Status::OK();
+      }
     });
 
 REGISTER_OP("TensorArrayScatterV3")
@@ -685,12 +703,25 @@ REGISTER_OP("TensorArrayScatterV3")
     .Output("flow_out: float")
     .Attr("T: type")
     .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle indices;
       ShapeHandle unused;
       DimensionHandle unused_dim;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
-      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices));
       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      ShapeHandle value_shape;
+      // Assert that the length of the indices tensor is equal to the first
+      // dimension of the value tensor.
+      TF_RETURN_IF_ERROR(
+          c->MergePrefix(c->input(2), indices, &value_shape, &indices));
+      auto shapes = c->input_handle_shapes_and_types(0);
+      if (shapes != nullptr && !shapes->empty()) {
+        ShapeHandle tensor_shape = shapes->at(0).shape;
+        ShapeHandle fed_shape;
+        TF_RETURN_IF_ERROR(c->Subshape(value_shape, 1, &fed_shape));
+        TF_RETURN_IF_ERROR(c->Merge(tensor_shape, fed_shape, &fed_shape));
+      }
       return shape_inference::ScalarShape(c);
     });