Allow other types of variables to act as a resource variable.
authorIgor Saprykin <isaprykin@google.com>
Tue, 13 Feb 2018 19:18:15 +0000 (11:18 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Feb 2018 19:21:46 +0000 (11:21 -0800)
Introduce resource_variable_ops.is_resource_variable() function that returns true
if an _should_act_as_resource_variable attribute is set.

PiperOrigin-RevId: 185559202

tensorflow/python/ops/gradients_impl.py
tensorflow/python/ops/resource_variable_ops.py
tensorflow/python/training/slot_creator.py

index 9f06c0e..1418c0b 100644 (file)
@@ -494,7 +494,7 @@ def gradients(ys,
       list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
     ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
     xs = [
-        x.handle if isinstance(x, resource_variable_ops.ResourceVariable) else x
+        x.handle if resource_variable_ops.is_resource_variable(x) else x
         for x in xs
     ]
     xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
index 75cb57f..11f452f 100644 (file)
@@ -957,3 +957,9 @@ ops.register_proto_function(
     proto_type=variable_pb2.VariableDef,
     to_proto=_to_proto_fn,
     from_proto=_from_proto_fn)
+
+
+def is_resource_variable(var):
+  """"Returns True if `var` is to be considered a ResourceVariable."""
+  return isinstance(var, ResourceVariable) or hasattr(
+      var, "_should_act_as_resource_variable")
index 18a5b89..75ef3d5 100644 (file)
@@ -48,11 +48,6 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 
 
-def _is_resource(v):
-  """Returns true if v is something you get from a resource variable."""
-  return isinstance(v, resource_variable_ops.ResourceVariable)
-
-
 def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
   """Helper function for creating a slot variable."""
 
@@ -65,7 +60,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
   shape = shape if callable(val) else None
   slot = variable_scope.get_variable(
       scope, initializer=val, trainable=False,
-      use_resource=_is_resource(primary),
+      use_resource=resource_variable_ops.is_resource_variable(primary),
       shape=shape, dtype=dtype,
       validate_shape=validate_shape)
   variable_scope.get_variable_scope().set_partitioner(current_partitioner)