ShapeRefiner fix: some variant-type tensors have handle data.
authorSkye Wanderman-Milne <skyewm@google.com>
Mon, 7 May 2018 23:16:32 +0000 (16:16 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 8 May 2018 00:51:42 +0000 (17:51 -0700)
ShapeRefiner::AddNode() would only propagate handle data for
DT_RESOURCE tensors, but not DT_VARIANT. The Python shape inference
logic in common_shapes.py handled this correct, which is why we didn't
notice this earlier. In particular, list ops use DT_VARIANT with
handle data.
PiperOrigin-RevId: 195739586

tensorflow/core/common_runtime/shape_refiner.cc
tensorflow/python/kernel_tests/list_ops_test.py

index 06dbe04..a077271 100644 (file)
@@ -232,13 +232,12 @@ Status ShapeRefiner::AddNode(const Node* node) {
     input_nodes[e->dst_input()] = input;
     input_shapes[e->dst_input()] = c->output(e->src_output());
 
-    // Only propagate handle data of edges which are carrying resource handles.
-    if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
-      const auto* in_v = c->output_handle_shapes_and_types(e->src_output());
-      if (in_v != nullptr) {
-        input_handle_shapes_and_types[e->dst_input()].reset(
-            new std::vector<ShapeAndType>(*in_v));
-      }
+    const auto* in_v = c->output_handle_shapes_and_types(e->src_output());
+    if (in_v != nullptr) {
+      DataType input_type = e->src()->output_type(e->src_output());
+      DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT);
+      input_handle_shapes_and_types[e->dst_input()].reset(
+          new std::vector<ShapeAndType>(*in_v));
     }
   }
 
index 098f972..4985520 100644 (file)
@@ -43,6 +43,7 @@ def scalar_shape():
   return ops.convert_to_tensor([], dtype=dtypes.int32)
 
 
+@test_util.with_c_shapes
 class ListOpsTest(test_util.TensorFlowTestCase):
 
   @test_util.run_in_graph_and_eager_modes()