Refactor score definition in GMM operations. This is simplified to be the per-sample...
author A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 7 Feb 2018 16:48:05 +0000 (08:48 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 7 Feb 2018 16:52:21 +0000 (08:52 -0800)
PiperOrigin-RevId: 184843634

tensorflow/contrib/factorization/python/ops/gmm.py
tensorflow/contrib/factorization/python/ops/gmm_ops.py
tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
tensorflow/contrib/factorization/python/ops/gmm_test.py

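The refactor replaces the Mahalanobis-distance scores with per-sample log-likelihoods. As a minimal sketch of the semantics the new scores() exposes (a NumPy stand-in, not the TF graph code; the function and all parameter names below are invented for illustration, and a diagonal-covariance mixture is assumed):

# Per-sample score: log p(x_i) = logsumexp_k [log w_k + log N(x_i | mu_k, S_k)].
import numpy as np

def per_sample_log_likelihood(points, weights, means, variances):
  """Log-likelihood of each row of `points` under a diagonal-covariance GMM."""
  diff = points[:, None, :] - means[None, :, :]                   # (n, k, d)
  log_norm = -0.5 * np.log(2.0 * np.pi * variances).sum(axis=1)   # (k,)
  log_gauss = log_norm - 0.5 * (diff**2 / variances).sum(axis=2)  # (n, k)
  weighted = np.log(weights) + log_gauss                          # (n, k)
  m = weighted.max(axis=1, keepdims=True)                         # stable logsumexp
  return np.squeeze(m, axis=1) + np.log(np.exp(weighted - m).sum(axis=1))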
diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py
index f72280c..b2dfe48 100644
@@ -24,17 +24,16 @@ import numpy as np
 from tensorflow.contrib import framework
 from tensorflow.contrib.factorization.python.ops import gmm_ops
 from tensorflow.contrib.framework.python.framework import checkpoint_utils
-from tensorflow.python.training import training_util
 from tensorflow.contrib.learn.python.learn.estimators import estimator
 from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import logging_ops as logging
-from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops.control_flow_ops import with_dependencies
 from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
 
 
 def _streaming_sum(scalar_tensor):
@@ -70,8 +69,8 @@ class _InitializeClustersHook(session_run_hook.SessionRunHook):
 class GMM(estimator.Estimator):
   """An estimator for GMM clustering."""
   SCORES = 'scores'
+  LOG_LIKELIHOOD = 'loss'
   ASSIGNMENTS = 'assignments'
-  ALL_SCORES = 'all_scores'
 
   def __init__(self,
                num_clusters,
@@ -113,10 +112,7 @@ class GMM(estimator.Estimator):
       yield result[GMM.ASSIGNMENTS]
 
   def score(self, input_fn=None, batch_size=None, steps=None):
-    """Predict total sum of distances to nearest clusters.
-
-    Note that this function is different from the corresponding one in sklearn
-    which returns the negative of the sum of distances.
+    """Predict total log-likelihood.
 
     Args:
       input_fn: see predict.
@@ -124,11 +120,11 @@ class GMM(estimator.Estimator):
       steps: see predict.
 
     Returns:
-      Total sum of distances to nearest clusters.
+      Total log-likelihood.
     """
     results = self.evaluate(input_fn=input_fn, batch_size=batch_size,
                             steps=steps)
-    return np.sum(results[GMM.SCORES])
+    return np.log(np.sum(np.exp(results[GMM.SCORES])))
 
   def weights(self):
     """Returns the cluster weights."""
@@ -158,9 +154,10 @@ class GMM(estimator.Estimator):
     def _model_fn(features, labels, mode, config):
       """Model function."""
       assert labels is None, labels
-      (all_scores,
+      (loss,
+       scores,
        model_predictions,
-       losses, training_op,
+       training_op,
        init_op,
        is_initialized) = gmm_ops.gmm(self._parse_tensor_or_dict(features),
                                      self._training_initial_clusters,
@@ -168,16 +165,15 @@ class GMM(estimator.Estimator):
                                      self._covariance_type,
                                      self._params)
       incr_step = state_ops.assign_add(training_util.get_global_step(), 1)
-      loss = math_ops.reduce_sum(losses)
       training_op = with_dependencies([training_op, incr_step], loss)
       training_hooks = [_InitializeClustersHook(
           init_op, is_initialized, config.is_chief)]
       predictions = {
-          GMM.ALL_SCORES: all_scores[0],
           GMM.ASSIGNMENTS: model_predictions[0][0],
       }
       eval_metric_ops = {
-          GMM.SCORES: _streaming_sum(loss),
+          GMM.SCORES: scores,
+          GMM.LOG_LIKELIHOOD: _streaming_sum(loss),
       }
       return model_fn_lib.ModelFnOps(mode=mode, predictions=predictions,
                                      eval_metric_ops=eval_metric_ops,
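For context, a usage sketch of the refactored estimator, based on the updated tests (my_input_fn is a hypothetical contrib.learn-style input_fn, not part of this change):

from tensorflow.contrib.factorization.python.ops import gmm as gmm_lib

gmm = gmm_lib.GMM(num_clusters=3, random_seed=4)
gmm.fit(input_fn=my_input_fn, steps=10)
# score() now reports a log-likelihood-style quantity, so higher is better;
# the updated test below flips assertGreater(score1, score2) to assertLess.
total_score = gmm.score(input_fn=my_input_fn, steps=1)
assignments = list(gmm.predict_assignments(input_fn=my_input_fn))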
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py
index a61681c..98d6434 100644
@@ -21,7 +21,6 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
-from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -36,7 +35,6 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.embedding_ops import embedding_lookup
-from tensorflow.python.summary import summary
 
 # Machine epsilon.
 MEPS = np.finfo(float).eps
@@ -253,14 +251,16 @@ class GmmAlgorithm(object):
     return ret
 
   def scores(self):
-    """Returns the distances to each class.
+    """Returns the per-sample likelihood fo the data.
 
     Returns:
-      A tuple with two Tensors. The first contains the distance to
-    each class. The second contains the distance to the assigned
-    class.
+      Log probabilities of each data point.
     """
-    return (self._all_scores, self._scores)
+    return self._scores
+
+  def log_likelihood_op(self):
+    """Returns the log-likelihood operation."""
+    return self._log_likelihood_op
 
   def _define_graph(self, data):
     """Define graph for a single iteration.
@@ -276,7 +276,8 @@ class GmmAlgorithm(object):
       self._define_expectation_operation(shard_id)
       self._define_partial_maximization_operation(shard_id, shard)
     self._define_maximization_operation(len(data))
-    self._define_distance_to_clusters(data)
+    self._define_loglikelihood_operation()
+    self._define_score_samples()
 
   def _define_full_covariance_probs(self, shard_id, shard):
     """Defines the full covariance probabilties per example in a class.
@@ -440,50 +441,20 @@ class GmmAlgorithm(object):
                 state_ops.assign(
                     self._covs, new_covs, validate_shape=False))
 
-  def _define_distance_to_clusters(self, data):
-    """Defines the Mahalanobis distance to the assigned Gaussian."""
-    # TODO(xavigonzalvo): reuse (input - mean) * cov^-1 * (input -
-    # mean) from log probability function.
-    self._all_scores = []
-    for shard in data:
-      all_scores = []
-      shard = array_ops.expand_dims(shard, 0)
-      for c in xrange(self._num_classes):
-        if self._covariance_type == FULL_COVARIANCE:
-          cov = self._covs[c, :, :]
-        elif self._covariance_type == DIAG_COVARIANCE:
-          cov = array_ops.diag(self._covs[c, :])
-        inverse = linalg_ops.matrix_inverse(cov + self._min_var)
-        inv_cov = array_ops.tile(
-            array_ops.expand_dims(inverse, 0),
-            array_ops.stack([self._num_examples, 1, 1]))
-        diff = array_ops.transpose(shard - self._means[c, :, :], perm=[1, 0, 2])
-        m_left = math_ops.matmul(diff, inv_cov)
-        all_scores.append(
-            math_ops.sqrt(
-                math_ops.matmul(
-                    m_left, array_ops.transpose(
-                        diff, perm=[0, 2, 1]))))
-      self._all_scores.append(
-          array_ops.reshape(
-              array_ops.concat(all_scores, 1),
-              array_ops.stack([self._num_examples, self._num_classes])))
-
-    # Distance to the associated class.
-    self._all_scores = array_ops.concat(self._all_scores, 0)
-    assignments = array_ops.concat(self.assignments(), 0)
-    rows = math_ops.to_int64(math_ops.range(0, self._num_examples))
-    indices = array_ops.concat(
-        [array_ops.expand_dims(rows, 1), array_ops.expand_dims(assignments, 1)],
-        1)
-    self._scores = array_ops.gather_nd(self._all_scores, indices)
-
   def _define_loglikelihood_operation(self):
     """Defines the total log-likelihood of current iteration."""
-    self._ll_op = []
+    op = []
     for prior_probs in self._prior_probs:
-      self._ll_op.append(math_ops.reduce_sum(math_ops.log(prior_probs)))
-    summary.scalar('ll', math_ops.reduce_sum(self._ll_op))
+      op.append(math_ops.reduce_logsumexp(prior_probs))
+    self._log_likelihood_op = math_ops.reduce_logsumexp(op)
+
+  def _define_score_samples(self):
+    """Defines the likelihood of each data sample."""
+    op = []
+    for shard_id, prior_probs in enumerate(self._prior_probs):
+      op.append(prior_probs + math_ops.log(self._w[shard_id]))
+    self._scores = array_ops.squeeze(
+        math_ops.reduce_logsumexp(op, axis=2, keep_dims=True), axis=0)
 
 
 def gmm(inp,
@@ -511,14 +482,9 @@ def gmm(inp,
   Returns:
     Note: a tuple of lists is returned to be consistent with skflow.
     A tuple consisting of:
-    all_scores: A matrix (or list of matrices) of dimensions (num_input,
-      num_clusters) where the value is the distance of an input vector and a
-      cluster center.
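+    loss: a scalar Tensor holding the total log-likelihood of the input data.
+    scores: per-sample log-likelihoods, one value for each row of 'inp'.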
     assignments: A vector (or list of vectors). Each element in the vector
       corresponds to an input row in 'inp' and specifies the cluster id
       corresponding to the input.
-    scores: Similar to assignments but specifies the distance to the
-      assigned cluster instead.
     training_op: an op that runs an iteration of training.
     init_op: an op that runs the initialization.
   """
@@ -532,6 +498,7 @@ def gmm(inp,
   gmm_tool = GmmAlgorithm(inp, num_clusters, initial_means, params,
                           covariance_type, random_seed)
   assignments = gmm_tool.assignments()
-  all_scores, scores = gmm_tool.scores()
-  return ([all_scores], [assignments], [scores], gmm_tool.training_ops(),
+  scores = gmm_tool.scores()
+  loss = gmm_tool.log_likelihood_op()
+  return (loss, scores, [assignments], gmm_tool.training_ops(),
           gmm_tool.init_ops(), gmm_tool.is_initialized())
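One practical note on the switch to math_ops.reduce_logsumexp: staying in log space avoids the underflow that a naive log(sum(exp(...))) hits on very negative log-probabilities. A NumPy illustration of the general pattern (toy values, not code from this change):

import numpy as np

log_p = np.array([-1000.0, -1001.0, -1002.0])   # extreme log-probabilities
naive = np.log(np.sum(np.exp(log_p)))           # exp underflows to 0 -> -inf
m = log_p.max()
stable = m + np.log(np.sum(np.exp(log_p - m)))  # approx -999.59, finite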
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
index c50e82d..888c3c2 100644
@@ -122,17 +122,23 @@ class GmmOpsTest(test.TestCase):
       g.seed = 5
       with self.test_session() as sess:
         data = constant_op.constant(self.data, dtype=dtypes.float32)
-        _, assignments, _, training_op, init_op, _ = gmm_ops.gmm(
+        loss_op, scores, assignments, training_op, init_op, _ = gmm_ops.gmm(
             data, 'random', num_classes, random_seed=self.seed)
 
         variables.global_variables_initializer().run()
         sess.run(init_op)
+        first_loss = sess.run(loss_op)
         for _ in xrange(self.iterations):
           sess.run(training_op)
         assignments = sess.run(assignments)
+        end_loss = sess.run(loss_op)
+        scores = sess.run(scores)
+        self.assertEqual((self.num_examples, 1), scores.shape)
         accuracy = np.mean(
             np.asarray(self.true_assignments) == np.squeeze(assignments))
         logging.info('Accuracy: %f', accuracy)
+        logging.info('First loss: %f, end loss: %f', first_loss, end_loss)
+        self.assertGreater(end_loss, first_loss)
         self.assertGreater(accuracy, 0.98)
 
   def testParams(self):
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py
index 7717b47..00a4734 100644
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
-from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.contrib.factorization.python.ops import gmm as gmm_lib
 from tensorflow.contrib.learn.python.learn.estimators import kmeans
@@ -30,12 +29,9 @@ from tensorflow.python.framework import random_seed as random_seed_lib
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import flags
 from tensorflow.python.platform import test
 from tensorflow.python.training import queue_runner
 
-FLAGS = flags.FLAGS
-
 
 class GMMTest(test.TestCase):
 
@@ -64,9 +60,8 @@ class GMMTest(test.TestCase):
     self.batch_size = self.num_points
     self.true_centers = self.make_random_centers(self.num_centers,
                                                  self.num_dims)
-    self.points, self.assignments, self.scores = self.make_random_points(
+    self.points, self.assignments = self.make_random_points(
         self.true_centers, self.num_points)
-    self.true_score = np.add.reduce(self.scores)
 
     # Use initial means from kmeans (just like scikit-learn does).
     clusterer = kmeans.KMeansClustering(num_clusters=self.num_centers)
@@ -86,24 +81,7 @@ class GMMTest(test.TestCase):
     offsets = np.round(
         np.random.randn(num_points, num_dims).astype(np.float32) * 20)
     points = centers[assignments] + offsets
-    means = [
-        np.mean(
-            points[assignments == center], axis=0)
-        for center in xrange(num_centers)
-    ]
-    covs = [
-        np.cov(points[assignments == center].T)
-        for center in xrange(num_centers)
-    ]
-    scores = []
-    for r in xrange(num_points):
-      scores.append(
-          np.sqrt(
-              np.dot(
-                  np.dot(points[r, :] - means[assignments[r]],
-                         np.linalg.inv(covs[assignments[r]])), points[r, :] -
-                  means[assignments[r]])))
-    return (points, assignments, scores)
+    return (points, assignments)
 
   def test_weights(self):
     """Tests the shape of the weights."""
@@ -136,8 +114,7 @@ class GMMTest(test.TestCase):
     gmm.fit(input_fn=self.input_fn(), steps=10)
     score2 = gmm.score(input_fn=self.input_fn(batch_size=self.num_points),
                        steps=1)
-    self.assertGreater(score1, score2)
-    self.assertNear(self.true_score, score2, self.true_score * 0.15)
+    self.assertLess(score1, score2)
 
   def test_infer(self):
     gmm = gmm_lib.GMM(self.num_centers,
@@ -149,8 +126,7 @@ class GMMTest(test.TestCase):
 
     # Make a small test set
     num_points = 40
-    points, true_assignments, true_offsets = (
-        self.make_random_points(clusters, num_points))
+    points, true_assignments = self.make_random_points(clusters, num_points)
 
     assignments = []
     for item in gmm.predict_assignments(
@@ -159,11 +135,6 @@ class GMMTest(test.TestCase):
     assignments = np.ravel(assignments)
     self.assertAllEqual(true_assignments, assignments)
 
-    # Test score
-    score = gmm.score(input_fn=self.input_fn(points=points,
-                                             batch_size=num_points), steps=1)
-    self.assertNear(score, np.sum(true_offsets), 4.05)
-
   def _compare_with_sklearn(self, cov_type):
     # sklearn version.
     iterations = 40
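For reference, scikit-learn's GaussianMixture follows the analogous per-sample / aggregate convention that the refactored SCORES and score() now roughly parallel (a sketch, assuming sklearn >= 0.18; not code from this change):

import numpy as np
from sklearn import mixture

x = np.random.randn(200, 2)
model = mixture.GaussianMixture(n_components=3, covariance_type='full').fit(x)
per_sample = model.score_samples(x)  # log p(x_i) for each sample, shape (200,)
avg_log_lik = model.score(x)         # mean per-sample log-likelihood (scalar)
assert np.allclose(avg_log_lik, per_sample.mean())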