Support None trainable variables that don't produce a gradient in replicate_model_fn.
authorIgor Saprykin <isaprykin@google.com>
Tue, 13 Feb 2018 00:24:45 +0000 (16:24 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Feb 2018 00:28:14 +0000 (16:28 -0800)
This fixes #16829.

PiperOrigin-RevId: 185453911

tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py

index dfae034..7134cd3 100644 (file)
@@ -790,7 +790,7 @@ def _extract_tensors(tensors_and_vars):
     tensor, _ = tensor_and_var
     if isinstance(tensor, ops_lib.IndexedSlices):
       tensors.append(tensor.values)
-    else:
+    elif tensor is not None:
       tensors.append(tensor)
   return tensors
 
index ab117e6..d46a18a 100644 (file)
@@ -240,6 +240,13 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
     labels = np.array([[1.0], [2.0]])
 
     with self.test_session() as session:
+      # Add another trainable variable that doesn't produce a gradient to
+      # verify that None gradients are supported.
+      _ = variable_scope.get_variable(
+          'another_variable',
+          initializer=constant_op.constant(1, dtype=dtypes.float64),
+          dtype=dtypes.float64)
+
       replicated_model_fn = replicate_model_fn.replicate_model_fn(
           self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(
@@ -1119,8 +1126,6 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
       feature_shards, label_shards = replicate_model_fn._split_batch(
           features, labels, 2, device='/gpu:0')
 
-      print(feature_shards[0]['x'].eval())
-      print(feature_shards[1]['x'].eval())
       self.assertSparseValuesEqual(
           sparse_tensor.SparseTensorValue(
               indices=[[0, 0], [1, 0], [1, 1]],