Expose the adaptive sampling option for SDCA and shuffle the data when adaptive sampling is off.
author    A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 02:33:58 +0000 (19:33 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 02:36:19 +0000 (19:36 -0700)
PiperOrigin-RevId: 191836004

tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
tensorflow/core/kernels/sdca_internal.cc
tensorflow/core/kernels/sdca_internal.h
tensorflow/core/kernels/sdca_ops.cc
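
This change plumbs a single boolean option, adaptive, from the public Python API (SDCAOptimizer) through SdcaModel down to the SDCA kernel. When adaptive sampling is disabled, the kernel now shuffles the examples uniformly instead of visiting them in input order. A minimal sketch of the new option at the SdcaModel level, using only names that appear in this diff (the rest of the options dict mirrors the new test below):

    options = dict(
        symmetric_l1_regularization=0.0,
        symmetric_l2_regularization=1.0,
        loss_type='logistic_loss',
        adaptive=False)  # new key; SdcaModel._adaptive() defaults it to True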

diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index cfe62fa..ac50699 100644
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import random
 import threading
 
 from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
@@ -102,6 +103,33 @@ def make_example_dict(example_protos, example_weights):
       example_ids=['%d' % i for i in range(0, len(example_protos))])
 
 
+def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):
+  random.seed(1)
+  sparse_features = [
+      SparseFeatureColumn(
+          [int(i / num_non_zero) for i in range(num_examples * num_non_zero)],
+          [int(random.random() * dim) for _ in range(
+              num_examples * num_non_zero)],
+          [num_non_zero**(-0.5) for _ in range(num_examples * num_non_zero)])
+  ]
+  examples_dict = dict(
+      sparse_features=sparse_features,
+      dense_features=[],
+      example_weights=[random.random() for _ in range(num_examples)],
+      example_labels=[
+          1. if random.random() > 0.5 else 0. for _ in range(num_examples)
+      ],
+      example_ids=[str(i) for i in range(num_examples)])
+
+  weights = variables_lib.Variable(
+      array_ops.zeros([dim], dtype=dtypes.float32))
+  variables_dict = dict(
+      sparse_features_weights=[weights],
+      dense_features_weights=[])
+
+  return examples_dict, variables_dict
+
+
 def make_variable_dict(max_age, max_gender):
   # TODO(sibyl-toe9oF2e):  Figure out how to derive max_age & max_gender from
   # examples_dict.
@@ -235,6 +263,32 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
         self.assertAllClose(
             0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
 
+  def testSparseRandom(self):
+    dim = 20
+    num_examples = 1000
+    # Number of non-zero features per example.
+    non_zeros = 10
+    # Setup test data.
+    with self._single_threaded_test_session():
+      examples, variables = make_random_examples_and_variables_dicts(
+          num_examples, dim, non_zeros)
+      options = dict(
+          symmetric_l2_regularization=.1,
+          symmetric_l1_regularization=0,
+          num_table_shards=1,
+          adaptive=False,
+          loss_type='logistic_loss')
+
+      lr = SdcaModel(examples, variables, options)
+      variables_lib.global_variables_initializer().run()
+      train_op = lr.minimize()
+      for _ in range(4):
+        train_op.run()
+      lr.update_weights(train_op).run()
+      # Duality gap is 1.4e-5.
+      # It would be 0.01 without shuffling and 0.02 with adaptive sampling.
+      self.assertNear(0.0, lr.approximate_duality_gap().eval(), err=1e-3)
+
   def testDistributedSimple(self):
     # Setup test data
     example_protos = [
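
The helper above gives every example exactly num_non_zero active features, each with value num_non_zero**(-0.5), so each example's sparse feature vector has unit L2 norm. A quick pure-Python check (illustrative, not part of the commit):

    num_non_zero = 10
    values = [num_non_zero ** (-0.5)] * num_non_zero
    print(sum(v * v for v in values))  # 1.0, up to float rounding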
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 3f5fdc1..f980746 100644
@@ -168,6 +168,10 @@ class SdcaModel(object):
     # of workers
     return self._options.get('num_loss_partitions', 1)
 
+  def _adaptive(self):
+    # Whether to use adaptive sampling. Defaults to True when unset.
+    return self._options.get('adaptive', True)
+
   def _num_table_shards(self):
     # Number of hash table shards.
     # Return 1 if not specified or if the value is 'None'
@@ -344,7 +348,8 @@ class SdcaModel(object):
           l1=self._options['symmetric_l1_regularization'],
           l2=self._symmetric_l2_regularization(),
           num_loss_partitions=self._num_loss_partitions(),
-          num_inner_iterations=1)
+          num_inner_iterations=1,
+          adaptative=self._adaptive())
       # pylint: enable=protected-access
 
       with ops.control_dependencies([esu]):
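
Note the deliberate spelling mismatch above: the Python option is 'adaptive', but the keyword at the kernel call site is 'adaptative', matching the attr name registered for the underlying SDCA op; it is not a typo to fix here. Because _adaptive() falls back to True, callers that never set the option keep the previous adaptive-sampling behavior. The lookup semantics, as a trivial sketch:

    options = {'loss_type': 'logistic_loss'}  # 'adaptive' never set
    print(options.get('adaptive', True))      # True: adaptive sampling stays on
    options['adaptive'] = False
    print(options.get('adaptive', True))      # False: the kernel shuffles instead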
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
index 92d022f..dffddda 100644
@@ -71,12 +71,14 @@ class SDCAOptimizer(object):
                num_loss_partitions=1,
                num_table_shards=None,
                symmetric_l1_regularization=0.0,
-               symmetric_l2_regularization=1.0):
+               symmetric_l2_regularization=1.0,
+               adaptive=True):
     self._example_id_column = example_id_column
     self._num_loss_partitions = num_loss_partitions
     self._num_table_shards = num_table_shards
     self._symmetric_l1_regularization = symmetric_l1_regularization
     self._symmetric_l2_regularization = symmetric_l2_regularization
+    self._adaptive = adaptive
 
   def get_name(self):
     return 'SDCAOptimizer'
@@ -101,6 +103,10 @@ class SDCAOptimizer(object):
   def symmetric_l2_regularization(self):
     return self._symmetric_l2_regularization
 
+  @property
+  def adaptive(self):
+    return self._adaptive
+
   def get_train_step(self, columns_to_variables, weight_column_name, loss_type,
                      features, targets, global_step):
     """Returns the training operation of an SdcaModel optimizer."""
@@ -228,6 +234,7 @@ class SDCAOptimizer(object):
         options=dict(
             symmetric_l1_regularization=self._symmetric_l1_regularization,
             symmetric_l2_regularization=self._symmetric_l2_regularization,
+            adaptive=self._adaptive,
             num_loss_partitions=self._num_loss_partitions,
             num_table_shards=self._num_table_shards,
             loss_type=loss_type))
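
With the new constructor argument, estimator users can turn adaptive sampling off when building the optimizer. A minimal sketch, assuming the usual tf.contrib.learn wiring (the feature column and classifier below are illustrative assumptions, not part of this commit):

    from tensorflow.contrib import layers, learn
    from tensorflow.contrib.linear_optimizer import SDCAOptimizer

    optimizer = SDCAOptimizer(
        example_id_column='example_id',  # input_fn must supply this feature
        symmetric_l2_regularization=1.0,
        adaptive=False)  # new in this change; defaults to True

    classifier = learn.LinearClassifier(
        feature_columns=[layers.real_valued_column('x')],
        optimizer=optimizer)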
diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc
index 5a389a6..623de2a 100644
@@ -302,6 +302,11 @@ Status Examples::SampleAdaptiveProbabilities(
   return Status::OK();
 }
 
+void Examples::RandomShuffle() {
+  std::iota(sampled_index_.begin(), sampled_index_.end(), 0);
+  std::random_shuffle(sampled_index_.begin(), sampled_index_.end());
+}
+
 // TODO(sibyl-Aix6ihai): Refactor/shorten this function.
 Status Examples::Initialize(OpKernelContext* const context,
                             const ModelWeights& weights,
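
RandomShuffle fills sampled_index_ with 0..num_examples-1 and permutes it uniformly, so the non-adaptive path visits examples in random order. (std::random_shuffle was deprecated in C++14 and removed in C++17; std::shuffle with an explicit generator is the modern replacement.) The two visit orders, as a rough Python sketch in which the adaptive branch is only a stand-in for SampleAdaptiveProbabilities, not its actual math:

    import random

    def visit_order(num_examples, adaptive, probabilities=None):
      if adaptive:
        # Stand-in: adaptive SDCA draws indices according to per-example
        # probabilities derived from the duality gaps.
        return random.choices(
            range(num_examples), weights=probabilities, k=num_examples)
      # Non-adaptive path, added by this change: a uniform shuffle.
      order = list(range(num_examples))
      random.shuffle(order)
      return order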
diff --git a/tensorflow/core/kernels/sdca_internal.h b/tensorflow/core/kernels/sdca_internal.h
index 1665b12..bfdb3fe 100644
@@ -322,10 +322,7 @@ class Examples {
     return examples_.at(example_index);
   }
 
-  int sampled_index(const int id, const bool adaptive) const {
-    if (adaptive) return sampled_index_[id];
-    return id;
-  }
+  int sampled_index(const int id) const { return sampled_index_[id]; }
 
   // Adaptive SDCA in the current implementation only works for
   // binary classification, where the input argument for num_weight_vectors
@@ -337,6 +334,8 @@ class Examples {
       const std::unique_ptr<DualLossUpdater>& loss_updater,
       const int num_weight_vectors);
 
+  void RandomShuffle();
+
   int num_examples() const { return examples_.size(); }
 
   int num_features() const { return num_features_; }
diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc
index 5b63057..55e68b3 100644
@@ -153,8 +153,9 @@ void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
                        options.num_loss_partitions, options.regularizations,
                        model_weights, example_state_data, options.loss_updater,
                        /*num_weight_vectors =*/1));
+  } else {
+    examples.RandomShuffle();
   }
-
   mutex mu;
   Status train_step_status GUARDED_BY(mu);
   std::atomic<std::int64_t> atomic_index(-1);
@@ -162,8 +163,7 @@ void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
     // The static_cast here is safe since begin and end can be at most
     // num_examples which is an int.
     for (int id = static_cast<int>(begin); id < end; ++id) {
-      const int64 example_index =
-          examples.sampled_index(++atomic_index, options.adaptive);
+      const int64 example_index = examples.sampled_index(++atomic_index);
       const Example& example = examples.example(example_index);
       const float dual = example_state_data(example_index, 0);
       const float example_weight = example.example_weight();
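
Because sampled_index_ is now always populated (by adaptive sampling in one branch, by the shuffle in the other), sampled_index() no longer needs the adaptive flag: each shard pulls the next slot from the shared atomic counter and maps it through the permutation, so the shards jointly visit every example exactly once. A single-threaded Python sketch of that indexing pattern (illustrative only):

    import itertools

    sampled_index = [3, 0, 2, 1]      # a permutation, e.g. from the shuffle
    atomic_index = itertools.count()  # stands in for ++atomic_index

    def shard(begin, end):
      for _ in range(begin, end):
        example_index = sampled_index[next(atomic_index)]
        # ... run the dual/primal update for example(example_index) ...

    shard(0, 2)
    shard(2, 4)  # together the two shards cover all four examples once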