Make adaptive SDCA the default.
author: A. Unique TensorFlower <gardener@tensorflow.org>
Thu, 8 Mar 2018 21:36:46 +0000 (13:36 -0800)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Thu, 8 Mar 2018 21:41:15 +0000 (13:41 -0800)
PiperOrigin-RevId: 188380039

tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
tensorflow/core/api_def/base_api/api_def_SdcaOptimizer.pbtxt
tensorflow/core/kernels/sdca_internal.cc
tensorflow/core/kernels/sdca_internal.h
tensorflow/core/kernels/sdca_ops.cc
tensorflow/tools/api/golden/tensorflow.train.pbtxt

index 70f777f..cfe62fa 100644 (file)
@@ -270,14 +270,14 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
 
           train_op = lr.minimize()
 
-          def Minimize():
+          def minimize():
             with self._single_threaded_test_session():
               for _ in range(_MAX_ITERATIONS):
-                train_op.run()
+                train_op.run()  # pylint: disable=cell-var-from-loop
 
           threads = []
           for _ in range(num_loss_partitions):
-            threads.append(threading.Thread(target=Minimize))
+            threads.append(threading.Thread(target=minimize))
             threads[-1].start()
 
           for t in threads:
@@ -395,7 +395,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
         predicted_labels = get_binary_predictions_for_logistic(predictions)
         self.assertAllClose([0, 1, 1, 1], predicted_labels.eval())
         self.assertAllClose(
-            0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+            0.0, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
 
   def testFractionalExampleLabel(self):
     # Setup test data with 1 positive, and 1 mostly-negative example.
@@ -407,7 +407,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
         make_example_proto({
             'age': [1],
             'gender': [1]
-        }, 1),
+        }, 0.9),
     ]
     example_weights = [1.0, 1.0]
     for num_shards in _SHARD_NUMBERS:
index b0b58ac..9da0e12 100644 (file)
@@ -97,8 +97,11 @@ END
   }
   attr {
     name: "adaptative"
+    default_value {
+      b: True
+    }
     description: <<END
-Whether to use Adapative SDCA for the inner loop.
+Whether to use Adaptive SDCA for the inner loop.
 END
   }
   attr {
index 066a4b8..5a389a6 100644 (file)
@@ -226,7 +226,7 @@ const ExampleStatistics Example::ComputeWxAndWeightedExampleNorm(
 }
 
 // Examples contains all the training examples that SDCA uses for a mini-batch.
-Status Examples::SampleAdaptativeProbabilities(
+Status Examples::SampleAdaptiveProbabilities(
     const int num_loss_partitions, const Regularizations& regularization,
     const ModelWeights& model_weights,
     const TTypes<float>::Matrix example_state_data,
index 4591569..1665b12 100644 (file)
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_KERNELS_SDCA_INTERNAL_H_
-#define TENSORFLOW_KERNELS_SDCA_INTERNAL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SDCA_INTERNAL_H_
+#define TENSORFLOW_CORE_KERNELS_SDCA_INTERNAL_H_
 
 #define EIGEN_USE_THREADS
 
@@ -75,7 +75,7 @@ struct ExampleStatistics {
 
 class Regularizations {
  public:
-  Regularizations(){};
+  Regularizations() {}
 
   // Initialize() must be called immediately after construction.
   Status Initialize(OpKernelConstruction* const context) {
@@ -199,7 +199,7 @@ class FeatureWeightsDenseStorage {
   FeatureWeightsDenseStorage(const TTypes<const float>::Matrix nominals,
                              TTypes<float>::Matrix deltas)
       : nominals_(nominals), deltas_(deltas) {
-    CHECK(deltas.rank() > 1);
+    CHECK_GT(deltas.rank(), 1);
   }
 
   // Check if a feature index is with-in the bounds.
@@ -322,15 +322,15 @@ class Examples {
     return examples_.at(example_index);
   }
 
-  int sampled_index(const int id, const bool adaptative) const {
-    if (adaptative) return sampled_index_[id];
+  int sampled_index(const int id, const bool adaptive) const {
+    if (adaptive) return sampled_index_[id];
     return id;
   }
 
   // Adaptive SDCA in the current implementation only works for
   // binary classification, where the input argument for num_weight_vectors
   // is 1.
-  Status SampleAdaptativeProbabilities(
+  Status SampleAdaptiveProbabilities(
       const int num_loss_partitions, const Regularizations& regularization,
       const ModelWeights& model_weights,
       const TTypes<float>::Matrix example_state_data,
@@ -378,7 +378,7 @@ class Examples {
   // All examples in the batch.
   std::vector<Example> examples_;
 
-  // Adaptative sampling variables
+  // Adaptive sampling variables.
   std::vector<float> probabilities_;
   std::vector<int> sampled_index_;
   std::vector<int> sampled_count_;
@@ -391,4 +391,4 @@ class Examples {
 }  // namespace sdca
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_KERNELS_SDCA_INTERNAL_H_
+#endif  // TENSORFLOW_CORE_KERNELS_SDCA_INTERNAL_H_
index dbe0177..5b63057 100644 (file)
@@ -80,7 +80,7 @@ struct ComputeOptions {
           context, false,
           errors::InvalidArgument("Unsupported loss type: ", loss_type));
     }
-    OP_REQUIRES_OK(context, context->GetAttr("adaptative", &adaptative));
+    OP_REQUIRES_OK(context, context->GetAttr("adaptative", &adaptive));
     OP_REQUIRES_OK(
         context, context->GetAttr("num_sparse_features", &num_sparse_features));
     OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values",
@@ -113,7 +113,7 @@ struct ComputeOptions {
   int num_dense_features = 0;
   int num_inner_iterations = 0;
   int num_loss_partitions = 0;
-  bool adaptative = false;
+  bool adaptive = true;
   Regularizations regularizations;
 };
 
@@ -147,9 +147,9 @@ void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
   OP_REQUIRES_OK(context, context->set_output("out_example_state_data",
                                               mutable_example_state_data_t));
 
-  if (options.adaptative) {
+  if (options.adaptive) {
     OP_REQUIRES_OK(context,
-                   examples.SampleAdaptativeProbabilities(
+                   examples.SampleAdaptiveProbabilities(
                        options.num_loss_partitions, options.regularizations,
                        model_weights, example_state_data, options.loss_updater,
                        /*num_weight_vectors =*/1));
@@ -163,7 +163,7 @@ void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
     // num_examples which is an int.
     for (int id = static_cast<int>(begin); id < end; ++id) {
       const int64 example_index =
-          examples.sampled_index(++atomic_index, options.adaptative);
+          examples.sampled_index(++atomic_index, options.adaptive);
       const Example& example = examples.example(example_index);
       const float dual = example_state_data(example_index, 0);
       const float example_weight = example.example_weight();
index e49c719..3b06aaf 100644 (file)
@@ -402,7 +402,7 @@ tf_module {
   }
   member_method {
     name: "sdca_optimizer"
-    argspec: "args=[\'sparse_example_indices\', \'sparse_feature_indices\', \'sparse_feature_values\', \'dense_features\', \'example_weights\', \'example_labels\', \'sparse_indices\', \'sparse_weights\', \'dense_weights\', \'example_state_data\', \'loss_type\', \'l1\', \'l2\', \'num_loss_partitions\', \'num_inner_iterations\', \'adaptative\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+    argspec: "args=[\'sparse_example_indices\', \'sparse_feature_indices\', \'sparse_feature_values\', \'dense_features\', \'example_weights\', \'example_labels\', \'sparse_indices\', \'sparse_weights\', \'dense_weights\', \'example_state_data\', \'loss_type\', \'l1\', \'l2\', \'num_loss_partitions\', \'num_inner_iterations\', \'adaptative\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
   member_method {
     name: "sdca_shrink_l1"