Add usage example to KMeans Estimator documentation.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Feb 2018 20:24:53 +0000 (12:24 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Feb 2018 20:32:50 +0000 (12:32 -0800)
PiperOrigin-RevId: 186805772

tensorflow/contrib/factorization/python/ops/kmeans.py

index c861cff..7319eaa 100644 (file)
@@ -61,8 +61,8 @@ class _LossRelativeChangeHook(session_run_hook.SessionRunHook):
     loss = run_values.results
     assert loss is not None
     if self._prev_loss:
-      relative_change = (abs(loss - self._prev_loss) /
-                         (1 + abs(self._prev_loss)))
+      relative_change = (
+          abs(loss - self._prev_loss) / (1 + abs(self._prev_loss)))
       if relative_change < self._tolerance:
         run_context.request_stop()
     self._prev_loss = loss
@@ -233,7 +233,57 @@ class _ModelFn(object):
 
 # TODO(agarwal,ands): support sharded input.
 class KMeansClustering(estimator.Estimator):
-  """An Estimator for K-Means clustering."""
+  """An Estimator for K-Means clustering.
+
+  Example:
+  ```
+  import numpy as np
+  import tensorflow as tf
+
+  num_points = 100
+  dimensions = 2
+  points = np.random.uniform(0, 1000, [num_points, dimensions])
+
+  def input_fn():
+    return tf.train.limit_epochs(
+        tf.convert_to_tensor(points, dtype=tf.float32), num_epochs=1)
+
+  num_clusters = 5
+  kmeans = tf.contrib.factorization.KMeansClustering(
+      num_clusters=num_clusters, use_mini_batch=False)
+
+  # train
+  num_iterations = 10
+  previous_centers = None
+  for _ in xrange(num_iterations):
+    kmeans.train(input_fn)
+    cluster_centers = kmeans.cluster_centers()
+    if previous_centers is not None:
+      print 'delta:', cluster_centers - previous_centers
+    previous_centers = cluster_centers
+    print 'score:', kmeans.score(input_fn)
+  print 'cluster centers:', cluster_centers
+
+  # map the input points to their clusters
+  cluster_indices = list(kmeans.predict_cluster_index(input_fn))
+  for i, point in enumerate(points):
+    cluster_index = cluster_indices[i]
+    center = cluster_centers[cluster_index]
+    print 'point:', point, 'is in cluster', cluster_index, 'centered at', center
+  ```
+
+  The `SavedModel` saved by the `export_savedmodel` method does not include the
+  cluster centers. However, the cluster centers may be retrieved by the
+  latest checkpoint saved during training. Specifically,
+  ```
+  kmeans.cluster_centers()
+  ```
+  is equivalent to
+  ```
+  tf.train.load_variable(
+      kmeans.model_dir, KMeansClustering.CLUSTER_CENTERS_VAR_NAME)
+  ```
+  """
 
   # Valid values for the distance_metric constructor argument.
   SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE
@@ -253,6 +303,9 @@ class KMeansClustering(estimator.Estimator):
   CLUSTER_INDEX = 'cluster_index'
   ALL_DISTANCES = 'all_distances'
 
+  # Variable name used by cluster_centers().
+  CLUSTER_CENTERS_VAR_NAME = clustering_ops.CLUSTERS_VAR_NAME
+
   def __init__(self,
                num_clusters,
                model_dir=None,
@@ -406,4 +459,4 @@ class KMeansClustering(estimator.Estimator):
 
   def cluster_centers(self):
     """Returns the cluster centers."""
-    return self.get_variable_value(clustering_ops.CLUSTERS_VAR_NAME)
+    return self.get_variable_value(KMeansClustering.CLUSTER_CENTERS_VAR_NAME)