row_weights=1,
col_weights=1,
use_factors_weights_cache=True,
- use_gramian_cache=True):
+ use_gramian_cache=True,
+ use_scoped_vars=False):
"""Creates model for WALS matrix factorization.
Args:
weights cache to take effect.
use_gramian_cache: When True, the Gramians will be cached on the workers
before the updates start. Defaults to True.
+ use_scoped_vars: When True, the factor and weight vars will also be nested
+ in a tf.name_scope.
"""
self._input_rows = input_rows
self._input_cols = input_cols
regularization * linalg_ops.eye(self._n_components)
if regularization is not None else None)
assert (row_weights is None) == (col_weights is None)
- self._row_weights = WALSModel._create_weights(
- row_weights, self._input_rows, self._num_row_shards, "row_weights")
- self._col_weights = WALSModel._create_weights(
- col_weights, self._input_cols, self._num_col_shards, "col_weights")
self._use_factors_weights_cache = use_factors_weights_cache
self._use_gramian_cache = use_gramian_cache
- self._row_factors = self._create_factors(
- self._input_rows, self._n_components, self._num_row_shards, row_init,
- "row_factors")
- self._col_factors = self._create_factors(
- self._input_cols, self._n_components, self._num_col_shards, col_init,
- "col_factors")
+
+ if use_scoped_vars:
+ with ops.name_scope("row_weights"):
+ self._row_weights = WALSModel._create_weights(
+ row_weights, self._input_rows, self._num_row_shards, "row_weights")
+ with ops.name_scope("col_weights"):
+ self._col_weights = WALSModel._create_weights(
+ col_weights, self._input_cols, self._num_col_shards, "col_weights")
+ with ops.name_scope("row_factors"):
+ self._row_factors = self._create_factors(
+ self._input_rows, self._n_components, self._num_row_shards,
+ row_init, "row_factors")
+ with ops.name_scope("col_factors"):
+ self._col_factors = self._create_factors(
+ self._input_cols, self._n_components, self._num_col_shards,
+ col_init, "col_factors")
+ else:
+ self._row_weights = WALSModel._create_weights(
+ row_weights, self._input_rows, self._num_row_shards, "row_weights")
+ self._col_weights = WALSModel._create_weights(
+ col_weights, self._input_cols, self._num_col_shards, "col_weights")
+ self._row_factors = self._create_factors(
+ self._input_rows, self._n_components, self._num_row_shards, row_init,
+ "row_factors")
+ self._col_factors = self._create_factors(
+ self._input_cols, self._n_components, self._num_col_shards, col_init,
+ "col_factors")
+
self._row_gramian = self._create_gramian(self._n_components, "row_gramian")
self._col_gramian = self._create_gramian(self._n_components, "col_gramian")
with ops.name_scope("row_prepare_gramian"):
@classmethod
def _create_factors(cls, rows, cols, num_shards, init, name):
"""Helper function to create row and column factors."""
- with ops.name_scope(name):
- if callable(init):
- init = init()
- if isinstance(init, list):
- assert len(init) == num_shards
- elif isinstance(init, str) and init == "random":
- pass
- elif num_shards == 1:
- init = [init]
- sharded_matrix = []
- sizes = cls._shard_sizes(rows, num_shards)
- assert len(sizes) == num_shards
-
- def make_initializer(i, size):
-
- def initializer():
- if init == "random":
- return random_ops.random_normal([size, cols])
- else:
- return init[i]
+ if callable(init):
+ init = init()
+ if isinstance(init, list):
+ assert len(init) == num_shards
+ elif isinstance(init, str) and init == "random":
+ pass
+ elif num_shards == 1:
+ init = [init]
+ sharded_matrix = []
+ sizes = cls._shard_sizes(rows, num_shards)
+ assert len(sizes) == num_shards
+
+ def make_initializer(i, size):
- return initializer
+ def initializer():
+ if init == "random":
+ return random_ops.random_normal([size, cols])
+ else:
+ return init[i]
- for i, size in enumerate(sizes):
- var_name = "%s_shard_%d" % (name, i)
- var_init = make_initializer(i, size)
- sharded_matrix.append(
- variable_scope.variable(
- var_init, dtype=dtypes.float32, name=var_name))
+ return initializer
- return sharded_matrix
+ for i, size in enumerate(sizes):
+ var_name = "%s_shard_%d" % (name, i)
+ var_init = make_initializer(i, size)
+ sharded_matrix.append(
+ variable_scope.variable(
+ var_init, dtype=dtypes.float32, name=var_name))
+
+ return sharded_matrix
@classmethod
def _create_weights(cls, wt_init, num_wts, num_shards, name):
sizes = cls._shard_sizes(num_wts, num_shards)
assert len(sizes) == num_shards
- with ops.name_scope(name):
- def make_wt_initializer(i, size):
+ def make_wt_initializer(i, size):
- def initializer():
- if init_mode == "scalar":
- return wt_init * array_ops.ones([size])
- else:
- return wt_init[i]
+ def initializer():
+ if init_mode == "scalar":
+ return wt_init * array_ops.ones([size])
+ else:
+ return wt_init[i]
- return initializer
+ return initializer
- sharded_weight = []
- for i, size in enumerate(sizes):
- var_name = "%s_shard_%d" % (name, i)
- var_init = make_wt_initializer(i, size)
- sharded_weight.append(
- variable_scope.variable(
- var_init, dtype=dtypes.float32, name=var_name))
+ sharded_weight = []
+ for i, size in enumerate(sizes):
+ var_name = "%s_shard_%d" % (name, i)
+ var_init = make_wt_initializer(i, size)
+ sharded_weight.append(
+ variable_scope.variable(
+ var_init, dtype=dtypes.float32, name=var_name))
- return sharded_weight
+ return sharded_weight
@staticmethod
def _create_gramian(n_components, name):