from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import momentum
+from tensorflow.python.training import saver
+from tensorflow.python.training import training_util
from tensorflow.python.util import compat
@test_util.run_in_graph_and_eager_modes()
def testScatterMin(self):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(
- resource_variable_ops.assign_variable_op(
- handle, constant_op.constant([[6]], dtype=dtypes.int32)))
- self.evaluate(
- resource_variable_ops.resource_scatter_min(
- handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[3]])
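+    # Pin to CPU: presumably the int32 scatter kernels are only registered
+    # there, and explicit placement behaves the same in graph and eager mode.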
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_min(handle, [0],
+ constant_op.constant(
+ [[3]],
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[3]])
+
+ def testMetagraph(self):
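+    # Round-trips a graph containing a resource variable and optimizer state
+    # through export_meta_graph/import_meta_graph and verifies the
+    # MetaGraphDefs match.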
+ with ops.Graph().as_default():
+ with variable_scope.variable_scope("foo", use_resource=True):
+ a = variable_scope.get_variable("a", initializer=10.0)
+
+ momentum.MomentumOptimizer(
+ learning_rate=0.001, momentum=0.1).minimize(
+ a,
+ colocate_gradients_with_ops=True,
+ global_step=training_util.get_or_create_global_step())
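+      # The minimize call creates momentum slot variables and a global step,
+      # so the exported MetaGraphDef covers several resource variables.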
+
+ graph = ops.get_default_graph()
+ meta_graph_def = saver.export_meta_graph(graph=graph)
+
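+    # Import into a fresh graph and export again; if resource variables
+    # serialize faithfully, the two MetaGraphDefs must be identical.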
+ with ops.Graph().as_default():
+ saver.import_meta_graph(meta_graph_def, import_scope="")
+      meta_graph_two = saver.export_meta_graph()
+ self.assertEqual(meta_graph_def, meta_graph_two)
@test_util.run_in_graph_and_eager_modes()
def testScatterMax(self):
self._cached_value = g.as_graph_element(
ops.prepend_name_scope(
variable_def.snapshot_name, import_scope=import_scope))
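+      # The snapshot tensor also serves as the graph element that stands in
+      # for the variable when it is fetched from the graph.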
+ self._graph_element = g.as_graph_element(
+ ops.prepend_name_scope(variable_def.snapshot_name,
+ import_scope=import_scope))
else:
self._cached_value = None
+      # Legacy case: the proto carries no snapshot name, so assume the graph
+      # element is the handle's default read op.
+ self._graph_element = g.get_tensor_by_name(
+ self._handle.op.name + "/Read/ReadVariableOp:0")
if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def,
          import_scope=import_scope)
    else:
      self._save_slice_info = None
self._caching_device = None
self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
- self._graph_element = g.get_tensor_by_name(
- self._handle.op.name + "/Read/ReadVariableOp:0")
self._constraint = None
self._cached_shape_as_list = None
if self._cached_value is not None:
var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
export_scope)
+ else:
+      # No cached value: store the graph element's name instead so that
+      # _init_from_proto can restore self._graph_element on import.
+ var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
+ export_scope)
var_def.is_resource = True
if self._save_slice_info:
      var_def.save_slice_info_def.MergeFrom(
          self._save_slice_info.to_proto(export_scope=export_scope))
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
del name
if dtype is not None and dtype != self.dtype:
- print("trying to switch the dtype to ", dtype, " from ", self.dtype)
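+      # Returning NotImplemented lets the tensor-conversion registry try any
+      # other registered conversion functions instead of failing here.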
return NotImplemented
if as_ref:
return self.read_value().op.inputs[0]