Convert a local variable and mutex to a struct so GUARDED_BY annotation works correctly.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 19 Apr 2018 22:27:19 +0000 (15:27 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 19 Apr 2018 22:30:02 +0000 (15:30 -0700)
PiperOrigin-RevId: 193584438

tensorflow/core/kernels/sdca_ops.cc

index 55e68b3..05c835e 100644 (file)
@@ -156,8 +156,10 @@ void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
   } else {
     examples.RandomShuffle();
   }
-  mutex mu;
-  Status train_step_status GUARDED_BY(mu);
+  struct {
+    mutex mu;
+    Status value GUARDED_BY(mu);
+  } train_step_status;
   std::atomic<std::int64_t> atomic_index(-1);
   auto train_step = [&](const int64 begin, const int64 end) {
     // The static_cast here is safe since begin and end can be at most
@@ -171,8 +173,8 @@ void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
       const Status conversion_status =
           options.loss_updater->ConvertLabel(&example_label);
       if (!conversion_status.ok()) {
-        mutex_lock l(mu);
-        train_step_status = conversion_status;
+        mutex_lock l(train_step_status.mu);
+        train_step_status.value = conversion_status;
         // Return from this worker thread - the calling thread is
         // responsible for checking context status and returning on error.
         return;
@@ -217,7 +219,8 @@ void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
 
   Shard(worker_threads.num_threads, worker_threads.workers,
         examples.num_examples(), kCostPerUnit, train_step);
-  OP_REQUIRES_OK(context, train_step_status);
+  mutex_lock l(train_step_status.mu);
+  OP_REQUIRES_OK(context, train_step_status.value);
 }
 
 }  // namespace