Making the tf.name_scope blocks related to the factor and weight vars configurable...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 31 May 2018 19:16:54 +0000 (12:16 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 19:19:57 +0000 (12:19 -0700)
PiperOrigin-RevId: 198759754

tensorflow/contrib/factorization/python/ops/factorization_ops.py

index 09745e2..8f73274 100644 (file)
@@ -197,7 +197,8 @@ class WALSModel(object):
                row_weights=1,
                col_weights=1,
                use_factors_weights_cache=True,
-               use_gramian_cache=True):
+               use_gramian_cache=True,
+               use_scoped_vars=False):
     """Creates model for WALS matrix factorization.
 
     Args:
@@ -239,6 +240,8 @@ class WALSModel(object):
         weights cache to take effect.
       use_gramian_cache: When True, the Gramians will be cached on the workers
         before the updates start. Defaults to True.
+      use_scoped_vars: When True, the factor and weight vars will also be nested
+        in a tf.name_scope.
     """
     self._input_rows = input_rows
     self._input_cols = input_cols
@@ -251,18 +254,36 @@ class WALSModel(object):
         regularization * linalg_ops.eye(self._n_components)
         if regularization is not None else None)
     assert (row_weights is None) == (col_weights is None)
-    self._row_weights = WALSModel._create_weights(
-        row_weights, self._input_rows, self._num_row_shards, "row_weights")
-    self._col_weights = WALSModel._create_weights(
-        col_weights, self._input_cols, self._num_col_shards, "col_weights")
     self._use_factors_weights_cache = use_factors_weights_cache
     self._use_gramian_cache = use_gramian_cache
-    self._row_factors = self._create_factors(
-        self._input_rows, self._n_components, self._num_row_shards, row_init,
-        "row_factors")
-    self._col_factors = self._create_factors(
-        self._input_cols, self._n_components, self._num_col_shards, col_init,
-        "col_factors")
+
+    if use_scoped_vars:
+      with ops.name_scope("row_weights"):
+        self._row_weights = WALSModel._create_weights(
+            row_weights, self._input_rows, self._num_row_shards, "row_weights")
+      with ops.name_scope("col_weights"):
+        self._col_weights = WALSModel._create_weights(
+            col_weights, self._input_cols, self._num_col_shards, "col_weights")
+      with ops.name_scope("row_factors"):
+        self._row_factors = self._create_factors(
+            self._input_rows, self._n_components, self._num_row_shards,
+            row_init, "row_factors")
+      with ops.name_scope("col_factors"):
+        self._col_factors = self._create_factors(
+            self._input_cols, self._n_components, self._num_col_shards,
+            col_init, "col_factors")
+    else:
+      self._row_weights = WALSModel._create_weights(
+          row_weights, self._input_rows, self._num_row_shards, "row_weights")
+      self._col_weights = WALSModel._create_weights(
+          col_weights, self._input_cols, self._num_col_shards, "col_weights")
+      self._row_factors = self._create_factors(
+          self._input_rows, self._n_components, self._num_row_shards, row_init,
+          "row_factors")
+      self._col_factors = self._create_factors(
+          self._input_cols, self._n_components, self._num_col_shards, col_init,
+          "col_factors")
+
     self._row_gramian = self._create_gramian(self._n_components, "row_gramian")
     self._col_gramian = self._create_gramian(self._n_components, "col_gramian")
     with ops.name_scope("row_prepare_gramian"):
@@ -313,37 +334,36 @@ class WALSModel(object):
   @classmethod
   def _create_factors(cls, rows, cols, num_shards, init, name):
     """Helper function to create row and column factors."""
-    with ops.name_scope(name):
-      if callable(init):
-        init = init()
-      if isinstance(init, list):
-        assert len(init) == num_shards
-      elif isinstance(init, str) and init == "random":
-        pass
-      elif num_shards == 1:
-        init = [init]
-      sharded_matrix = []
-      sizes = cls._shard_sizes(rows, num_shards)
-      assert len(sizes) == num_shards
-
-      def make_initializer(i, size):
-
-        def initializer():
-          if init == "random":
-            return random_ops.random_normal([size, cols])
-          else:
-            return init[i]
+    if callable(init):
+      init = init()
+    if isinstance(init, list):
+      assert len(init) == num_shards
+    elif isinstance(init, str) and init == "random":
+      pass
+    elif num_shards == 1:
+      init = [init]
+    sharded_matrix = []
+    sizes = cls._shard_sizes(rows, num_shards)
+    assert len(sizes) == num_shards
+
+    def make_initializer(i, size):
 
-        return initializer
+      def initializer():
+        if init == "random":
+          return random_ops.random_normal([size, cols])
+        else:
+          return init[i]
 
-      for i, size in enumerate(sizes):
-        var_name = "%s_shard_%d" % (name, i)
-        var_init = make_initializer(i, size)
-        sharded_matrix.append(
-            variable_scope.variable(
-                var_init, dtype=dtypes.float32, name=var_name))
+      return initializer
 
-      return sharded_matrix
+    for i, size in enumerate(sizes):
+      var_name = "%s_shard_%d" % (name, i)
+      var_init = make_initializer(i, size)
+      sharded_matrix.append(
+          variable_scope.variable(
+              var_init, dtype=dtypes.float32, name=var_name))
+
+    return sharded_matrix
 
   @classmethod
   def _create_weights(cls, wt_init, num_wts, num_shards, name):
@@ -384,26 +404,25 @@ class WALSModel(object):
     sizes = cls._shard_sizes(num_wts, num_shards)
     assert len(sizes) == num_shards
 
-    with ops.name_scope(name):
-      def make_wt_initializer(i, size):
+    def make_wt_initializer(i, size):
 
-        def initializer():
-          if init_mode == "scalar":
-            return wt_init * array_ops.ones([size])
-          else:
-            return wt_init[i]
+      def initializer():
+        if init_mode == "scalar":
+          return wt_init * array_ops.ones([size])
+        else:
+          return wt_init[i]
 
-        return initializer
+      return initializer
 
-      sharded_weight = []
-      for i, size in enumerate(sizes):
-        var_name = "%s_shard_%d" % (name, i)
-        var_init = make_wt_initializer(i, size)
-        sharded_weight.append(
-            variable_scope.variable(
-                var_init, dtype=dtypes.float32, name=var_name))
+    sharded_weight = []
+    for i, size in enumerate(sizes):
+      var_name = "%s_shard_%d" % (name, i)
+      var_init = make_wt_initializer(i, size)
+      sharded_weight.append(
+          variable_scope.variable(
+              var_init, dtype=dtypes.float32, name=var_name))
 
-      return sharded_weight
+    return sharded_weight
 
   @staticmethod
   def _create_gramian(n_components, name):