Cleaning up tracing code.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 30 Apr 2018 11:21:09 +0000 (04:21 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 30 Apr 2018 11:23:46 +0000 (04:23 -0700)
PiperOrigin-RevId: 194768567

21 files changed:
tensorflow/compiler/jit/xla_device.cc
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
tensorflow/core/common_runtime/copy_tensor.cc
tensorflow/core/common_runtime/gpu/gpu_device.cc
tensorflow/core/common_runtime/gpu/gpu_util.cc
tensorflow/core/common_runtime/process_util.cc
tensorflow/core/common_runtime/sycl/sycl_device.cc
tensorflow/core/common_runtime/threadpool_device.cc
tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
tensorflow/core/framework/dataset.h
tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
tensorflow/core/kernels/function_ops.cc
tensorflow/core/lib/core/threadpool.cc
tensorflow/core/platform/default/device_tracer.cc
tensorflow/core/platform/default/tracing.cc
tensorflow/core/platform/default/tracing_impl.h
tensorflow/core/platform/posix/tracing.cc [deleted file]
tensorflow/core/platform/tracing.cc
tensorflow/core/platform/tracing.h

index c814b7e..70263b1 100644 (file)
@@ -260,11 +260,10 @@ Status XlaDevice::FillContextMap(const Graph* graph,
 void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
   VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":"
           << op_kernel->type_string();
-  // When TraceMe profiling is off (which is the default), the
-  // following TraceMe constructor is simply a conditional test of
-  // false value. Measurements show that its overhead is negligible.
-  port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(),
-                                  op_kernel->IsExpensive());
+  // When Xprof profiling is off (which is the default), constructing the
+  // activity is simple enough that its overhead is negligible.
+  tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
+                                   op_kernel->IsExpensive());
   op_kernel->Compute(context);
 }
 
@@ -272,8 +271,8 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                              AsyncOpKernel::DoneCallback done) {
   VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
           << op_kernel->type_string();
-  port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(),
-                                  op_kernel->IsExpensive());
+  tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
+                                   op_kernel->IsExpensive());
   op_kernel->ComputeAsync(context, done);
 }
 
index 30bfc93..796c307 100644 (file)
@@ -100,7 +100,7 @@ namespace gpu {
 
 namespace {
 
-using tensorflow::port::Tracing;
+namespace tracing = tensorflow::tracing;
 
 // Returns the directory containing nvvm libdevice files.  config_cuda_data_dir
 // should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the
@@ -410,7 +410,7 @@ void WarnIfBadDriverJITVersion() {
 // code (i.e. a cubin) as a byte array.
 StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
                                         int cc_minor) {
-  Tracing::TraceMe annotation("Compile PTX", /*is_expensive=*/true);
+  tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true);
   const string ptxas_path =
       tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas");
   VLOG(2) << "Using ptxas at " << ptxas_path;
@@ -481,8 +481,8 @@ StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
     DeviceMemoryAllocator* device_allocator) {
   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
-  Tracing::TraceMe annotation("HLO Transforms", module->name(),
-                              /*is_expensive=*/true);
+  tracing::ScopedActivity activity("HLO Transforms", module->name(),
+                                   /*is_expensive=*/true);
   TF_RETURN_IF_ERROR(
       OptimizeHloModule(module.get(), stream_exec, device_allocator));
   return std::move(module);
@@ -692,7 +692,7 @@ std::vector<uint8> GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx,
                                                             int cc_major,
                                                             int cc_minor) {
   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::CompilePtxOrGetCachedResult");
-  Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true);
+  tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true);
   bool inserted;
   decltype(compilation_cache_.begin()) iter;
   // Pointers into compilation_cache_ where the ptx and (optional) cubin are
index df9d9be..d70cb07 100644 (file)
@@ -491,7 +491,7 @@ StatusOr<string> CompileToPtx(llvm::Module* module,
 
   string ptx;
   {
-    tensorflow::port::Tracing::TraceMe annotation(
+    tensorflow::tracing::ScopedActivity activity(
         "Compiling IR", llvm_ir::AsString(module->getName()),
         /*is_expensive=*/true);
     XLA_SCOPED_LOGGING_TIMER("Compile module " +
index e355487..08d120c 100644 (file)
@@ -237,7 +237,7 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
                         const AllocatorAttributes dst_alloc_attr,
                         const Tensor* input, Tensor* output,
                         StatusCallback done) {
-  port::Tracing::ScopedAnnotation annotation(edge_name);
+  tracing::ScopedAnnotation annotation(edge_name);
   VLOG(1) << "Copy " << edge_name;
 
   const DeviceType src_device_type(
index 944f0c8..9b434e5 100644 (file)
@@ -406,12 +406,8 @@ Status BaseGPUDevice::FillContextMap(const Graph* graph,
 }
 
 void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
-  // ScopedActivity is cheap when tracing is not active, but we
-  // can avoid computing the Hash64.
-  // TODO(pbar) This would no longer be needed if Ops have a unique id.
-  const uint64 id = port::Tracing::IsActive() ? Hash64(op_kernel->name()) : 0;
-  port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute,
-                                       id);
+  tracing::ScopedRegion region(tracing::EventCategory::kCompute,
+                               op_kernel->name());
 
   // NOTE(tucker): We need to discriminate between Eigen GPU
   // operations and all others.  If an operation is Eigen
@@ -425,11 +421,9 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
   if (op_kernel->is_internal() && op_kernel->type_string() == "_Recv") {
     context->SetStatus(errors::Internal(
         "Invalid synchronous 'Compute' on GPU for '_Recv' op"));
-  } else if (port::Tracing::ScopedAnnotation::Enabled()) {
-    port::Tracing::ScopedAnnotation annotation(op_kernel->name(),
-                                               op_kernel->type_string());
-    ComputeHelper(op_kernel, context);
   } else {
+    tracing::ScopedAnnotation annotation(op_kernel->name(),
+                                         op_kernel->type_string());
     ComputeHelper(op_kernel, context);
   }
 }
@@ -527,11 +521,10 @@ void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel,
           << op_kernel->type_string() << " on GPU" << tf_gpu_id_ << " stream["
           << stream_id << "]";
 
-  // When TraceMe profiling is off (which is the default), the
-  // following TraceMe constructor is simply a conditional test of
-  // false value. Measurements show that its overhead is negligible.
-  port::Tracing::TraceMe activity(op_kernel->name(), op_kernel->type_string(),
-                                  op_kernel->IsExpensive());
+  // When Xprof profiling is off (which is the default), constructing the
+  // activity is simple enough that its overhead is negligible.
+  tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
+                                   op_kernel->IsExpensive());
   se::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()};
   op_kernel->ComputeAsync(context, done);
 }
@@ -573,7 +566,7 @@ Status BaseGPUDevice::MaybeCopyTensorToGPU(
         },
         std::move(done), std::placeholders::_1);
 
-    port::Tracing::ScopedAnnotation annotation("MakeTensorFromProto");
+    tracing::ScopedAnnotation annotation("MakeTensorFromProto");
     device_contexts_[0]->CopyCPUTensorToDevice(&from, this, copy,
                                                std::move(wrapped_done));
     return Status::OK();
index 7ba853f..d38413d 100644 (file)
@@ -149,7 +149,7 @@ void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
   char* buf = nullptr;
   const int64 total_bytes = is_dead ? 0 : tensor.TotalBytes();
   if (total_bytes > 0) {
-    port::Tracing::ScopedAnnotation annotation("SetProtoFromGPU");
+    tracing::ScopedAnnotation annotation("SetProtoFromGPU");
     alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
     buf = alloc->Allocate<char>(total_bytes);
     if (LogMemory::IsEnabled()) {
index f8f3a1e..2191223 100644 (file)
@@ -79,21 +79,18 @@ thread::ThreadPool* NewThreadPoolFromSessionOptions(
 }
 
 void SchedClosure(std::function<void()> closure) {
-  if (port::Tracing::IsActive()) {
-    const uint64 id = port::Tracing::UniqueId();
-    port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure,
-                               id);
-    std::function<void()> wrapper = std::bind(
-        [id](std::function<void()> closure) {
-          port::Tracing::ScopedActivity region(
-              port::Tracing::EventCategory::kRunClosure, id);
-          closure();
-        },
-        std::move(closure));
-    Env::Default()->SchedClosure(std::move(wrapper));
-  } else {
-    Env::Default()->SchedClosure(std::move(closure));
+  if (!tracing::EventCollector::IsEnabled()) {
+    return Env::Default()->SchedClosure(std::move(closure));
   }
+  uint64 id = tracing::GetUniqueArg();
+  tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
+
+  Env::Default()->SchedClosure(std::bind(
+      [id](std::function<void()> closure) {
+        tracing::ScopedRegion region(tracing::EventCategory::kRunClosure, id);
+        closure();
+      },
+      std::move(closure)));
 }
 
 void SchedNonBlockingClosureAfter(int64 micros, std::function<void()> closure) {
index 6e1a45b..f3bd72f 100644 (file)
@@ -27,12 +27,11 @@ SYCLDevice::~SYCLDevice() {}
 
 void SYCLDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
   assert(context);
-  if (port::Tracing::IsActive()) {
-    // TODO(pbar) We really need a useful identifier of the graph node.
-    const uint64 id = Hash64(op_kernel->name());
-    port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute,
-                                         id);
-  }
+  // When ThreadScape profiling is off (which is the default), constructing the
+  // following code is simple enough that its overhead is negligible.
+  tracing::ScopedRegion region(tracing::EventCategory::kCompute,
+                               op_kernel->name());
+
   op_kernel->Compute(context);
 }
 
index 6d8de6a..f7a07fe 100644 (file)
@@ -48,20 +48,14 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
 ThreadPoolDevice::~ThreadPoolDevice() {}
 
 void ThreadPoolDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
-  // When TraceMe profiling is off (which is the default), the
-  // following TraceMe constructor is simply a conditional test of
-  // false value. Measurements show that its overhead is negligible.
-  port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(),
-                                  op_kernel->IsExpensive());
-  if (port::Tracing::IsActive()) {
-    // TODO(pbar) We really need a useful identifier of the graph node.
-    const uint64 id = Hash64(op_kernel->name());
-    port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute,
-                                         id);
-    op_kernel->Compute(context);
-  } else {
-    op_kernel->Compute(context);
-  }
+  // When Xprof/ThreadScape profiling is off (which is the default), the
+  // following code is simple enough that its overhead is negligible.
+  tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
+                                   op_kernel->IsExpensive());
+  tracing::ScopedRegion region(tracing::EventCategory::kCompute,
+                               op_kernel->name());
+
+  op_kernel->Compute(context);
 }
 
 Allocator* ThreadPoolDevice::GetAllocator(AllocatorAttributes attr) {
index 23968e2..e025e55 100644 (file)
@@ -285,7 +285,7 @@ class GrpcMasterService : public AsyncServiceInterface {
 #undef ENQUEUE_REQUEST
 
   // Start tracing, including the ID attached to the RPC.
-  port::Tracing::TraceMe* TraceRpc(
+  tracing::ScopedActivity* TraceRpc(
       StringPiece name,
       const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) {
     StringPiece id;
@@ -293,7 +293,7 @@ class GrpcMasterService : public AsyncServiceInterface {
     if (it != metadata.end()) {
       id = StringPiece(it->second.data(), it->second.size());
     }
-    return new port::Tracing::TraceMe(name, id);
+    return new tracing::ScopedActivity(name, id);
   }
 
   TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService);
index 1b92a79..b832a21 100644 (file)
@@ -119,11 +119,11 @@ class GrpcRemoteMaster : public MasterInterface {
 
  private:
   // Start tracing, attaching a unique ID to both the trace and the RPC.
-  port::Tracing::TraceMe TraceRpc(StringPiece name,
-                                  ::grpc::ClientContext* ctx) {
-    string trace_id = strings::StrCat(port::Tracing::UniqueId());
+  tracing::ScopedActivity TraceRpc(StringPiece name,
+                                   ::grpc::ClientContext* ctx) {
+    string trace_id = strings::StrCat(tracing::GetUniqueArg());
     ctx->AddMetadata(GrpcIdKey(), trace_id);
-    return port::Tracing::TraceMe(name, trace_id);
+    return tracing::ScopedActivity(name, trace_id);
   }
 
   void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms) {
index 8d127ba..775d9f6 100644 (file)
@@ -521,7 +521,7 @@ class DatasetIterator : public IteratorBase {
 
   Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                  bool* end_of_sequence) final {
-    port::Tracing::TraceMe activity(params_.prefix);
+    tracing::ScopedActivity activity(params_.prefix);
     Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
     if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
       s = errors::Internal(
index 605ef3c..7bc43e2 100644 (file)
@@ -468,7 +468,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
 
       void StartInvocationBatch(IteratorContext* ctx, int64 batch_index)
           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-        port::Tracing::TraceMe activity(strings::StrCat(prefix(), "::Start"));
+        tracing::ScopedActivity activity(strings::StrCat(prefix(), "::Start"));
         // Initialize batch result.
         {
           mutex_lock l(batch_results_[batch_index].mu);
@@ -493,7 +493,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
 
       Status WaitForBatch(int64 batch_index, int64* num_elements)
           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-        port::Tracing::TraceMe activity(strings::StrCat(prefix(), "::Wait"));
+        tracing::ScopedActivity activity(strings::StrCat(prefix(), "::Wait"));
         batch_results_[batch_index].counter->Wait();
         Status status = Status::OK();
         for (size_t i = 0; i < dataset()->batch_size_; ++i, ++*num_elements) {
index f8e0267..8f66f0a 100644 (file)
@@ -324,7 +324,7 @@ class RemoteCallOp : public AsyncOpKernel {
         handle = cached_entry->second;
       } else {
         VLOG(1) << "Instantiating " << func_.name() << " on " << target_device;
-        port::Tracing::TraceMe activity(strings::StrCat(
+        tracing::ScopedActivity activity(strings::StrCat(
             "RemoteCall: Instantiate: ", func_.name(), " on ", target_device));
         OP_REQUIRES_OK_ASYNC(
             ctx,
@@ -355,12 +355,12 @@ class RemoteCallOp : public AsyncOpKernel {
       args.push_back(argument);
     }
     auto* rets = new std::vector<Tensor>;
-    auto* trace = new port::Tracing::TraceMe(strings::StrCat(
+    auto* activity = new tracing::ScopedActivity(strings::StrCat(
         "RemoteCall: Run: ", func_.name(), " on ", target_device));
     VLOG(1) << "Running " << func_.name() << " on " << target_device
             << " with handle: " << handle;
     lib->Run(opts, handle, args, rets,
-             [rets, trace, done, ctx](const Status& status) {
+             [rets, activity, done, ctx](const Status& status) {
                if (!status.ok()) {
                  ctx->SetStatus(status);
                } else {
@@ -369,7 +369,7 @@ class RemoteCallOp : public AsyncOpKernel {
                  }
                }
                delete rets;
-               delete trace;
+               delete activity;
                done();
              });
   }
index e55ed79..99684ae 100644 (file)
@@ -59,10 +59,9 @@ struct EigenEnvironment {
 
   Task CreateTask(std::function<void()> f) {
     uint64 id = 0;
-    if (port::Tracing::IsActive()) {
-      id = port::Tracing::UniqueId();
-      port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure,
-                                 id);
+    if (tracing::EventCollector::IsEnabled()) {
+      id = tracing::GetUniqueArg();
+      tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
     }
     return Task{
         std::unique_ptr<TaskImpl>(new TaskImpl{
@@ -75,13 +74,9 @@ struct EigenEnvironment {
 
   void ExecuteTask(const Task& t) {
     WithContext wc(t.f->context);
-    if (t.f->trace_id != 0) {
-      port::Tracing::ScopedActivity region(
-          port::Tracing::EventCategory::kRunClosure, t.f->trace_id);
-      t.f->f();
-    } else {
-      t.f->f();
-    }
+    tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
+                                 t.f->trace_id);
+    t.f->f();
   }
 };
 
index 8e60a7f..ccddf1e 100644 (file)
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/step_stats_collector.h"
 #include "tensorflow/core/framework/step_stats.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/cupti_wrapper.h"
 #include "tensorflow/core/platform/env.h"
@@ -288,7 +289,7 @@ TF_STATIC_THREAD_LOCAL_POD(const char *, tls_current_annotation);
 
 class DeviceTracerImpl : public DeviceTracer,
                          public CUPTIClient,
-                         public port::Tracing::Engine {
+                         public tracing::TraceCollector {
  public:
   DeviceTracerImpl();
   ~DeviceTracerImpl() override;
@@ -298,25 +299,25 @@ class DeviceTracerImpl : public DeviceTracer,
   Status Stop() override;
   Status Collect(StepStatsCollector *collector) override;
 
-  // port::Tracing::Engine interface:
-  bool IsEnabled() const override {
-    // We only register the Engine while tracing is enabled.
-    return true;
-  }
-  Annotation *PushAnnotation(StringPiece name) override {
-    VLOG(2) << "PushAnnotation " << name;
-    struct Impl : public port::Tracing::Engine::Annotation {
+  // tracing::TraceCollector interface:
+  virtual std::unique_ptr<Handle> CreateAnnotationHandle(
+      StringPiece name_part1, StringPiece name_part2) const {
+    struct Impl : public tracing::TraceCollector::Handle {
       string annotation;
-      explicit Impl(StringPiece n) : annotation(n.ToString()) {
+      explicit Impl(string &&name_scope) : annotation(name_scope) {
+        VLOG(2) << "CreateAnnotationHandle " << annotation;
         // Remember the most recent ScopedAnnotation for each thread.
         tls_current_annotation.get() = annotation.c_str();
       }
       ~Impl() override { tls_current_annotation.get() = nullptr; }
     };
-    return new Impl(name);
+    return std::unique_ptr<Handle>(
+        new Impl{ConcatenateNames(name_part1, name_part2)});
   }
-  Tracer *StartTracing(StringPiece label, bool is_expensive) override {
-    // We don't do anything with 'TraceMe' regions yet.
+
+  virtual std::unique_ptr<Handle> CreateActivityHandle(StringPiece, StringPiece,
+                                                       bool) const {
+    // We don't do anything with 'Activities' yet.
     return nullptr;
   }
 
@@ -410,7 +411,7 @@ Status DeviceTracerImpl::Start() {
   }
 
   // Register as a TraceEngine to receive ScopedAnnotations.
-  port::Tracing::RegisterEngine(this);
+  tracing::SetTraceCollector(this);
 
   // Intercept launch and memcpy calls to capture the Op name annotation.
   // TODO(pbar) Add callbacks for memcpy variants.
@@ -458,7 +459,7 @@ Status DeviceTracerImpl::Stop() {
     return Status::OK();
   }
   CUPTI_CALL(Unsubscribe(subscriber_));
-  port::Tracing::RegisterEngine(nullptr);
+  tracing::SetTraceCollector(nullptr);
   TF_RETURN_IF_ERROR(cupti_manager_->DisableTrace());
   end_walltime_us_ = NowInUsec();
   CUPTI_CALL(GetTimestamp(&end_timestamp_));
index 422564f..3efcef0 100644 (file)
@@ -15,21 +15,33 @@ limitations under the License.
 
 #include "tensorflow/core/platform/tracing.h"
 
-namespace tensorflow {
-namespace port {
-
-void Tracing::RegisterEvent(EventCategory id, const char* name) {
-  // TODO(opensource): implement
-}
+#include <cstdlib>
 
-void Tracing::Initialize() {}
+#ifndef PLATFORM_WINDOWS
+#include <unistd.h>
+#endif
 
-static bool DoInit() {
-  Tracing::Initialize();
-  return true;
+namespace tensorflow {
+namespace tracing {
+namespace {
+bool TryGetEnv(const char* name, const char** value) {
+  *value = getenv(name);
+  return *value != nullptr && (*value)[0] != '\0';
 }
-
-static const bool dummy = DoInit();
-
-}  // namespace port
+}  // namespace
+
+void EventCollector::SetCurrentThreadName(const char*) {}
+
+const char* GetLogDir() {
+  const char* dir;
+  if (TryGetEnv("TEST_TMPDIR", &dir)) return dir;
+  if (TryGetEnv("TMP", &dir)) return dir;
+  if (TryGetEnv("TMPDIR", &dir)) return dir;
+#ifndef PLATFORM_WINDOWS
+  dir = "/tmp";
+  if (access(dir, R_OK | W_OK | X_OK) == 0) return dir;
+#endif
+  return ".";  // Default to current directory.
+}
+}  // namespace tracing
 }  // namespace tensorflow
index 7834548..b161378 100644 (file)
@@ -21,13 +21,8 @@ limitations under the License.
 // IWYU pragma: private, include "third_party/tensorflow/core/platform/tracing.h"
 // IWYU pragma: friend third_party/tensorflow/core/platform/tracing.h
 
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/tracing.h"
 
-namespace tensorflow {
-namespace port {
-
 // Definitions that do nothing for platforms that don't have underlying thread
 // tracing support.
 #define TRACELITERAL(a) \
@@ -40,21 +35,12 @@ namespace port {
   do {                           \
   } while (0)
 
-inline uint64 Tracing::UniqueId() { return random::New64(); }
-inline bool Tracing::IsActive() { return false; }
-inline void Tracing::RegisterCurrentThread(const char* name) {}
-
-// Posts an atomic threadscape event with the supplied category and arg.
-inline void Tracing::RecordEvent(EventCategory category, uint64 arg) {
-  // TODO(opensource): Implement
-}
-
-inline Tracing::ScopedActivity::ScopedActivity(EventCategory category,
-                                               uint64 arg) {}
+namespace tensorflow {
+namespace tracing {
 
-inline Tracing::ScopedActivity::~ScopedActivity() {}
+inline bool EventCollector::IsEnabled() { return false; }
 
-}  // namespace port
+}  // namespace tracing
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_
diff --git a/tensorflow/core/platform/posix/tracing.cc b/tensorflow/core/platform/posix/tracing.cc
deleted file mode 100644 (file)
index 1d1aa53..0000000
+++ /dev/null
@@ -1,40 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/platform/tracing.h"
-
-#include <stdlib.h>
-#include <unistd.h>
-
-namespace tensorflow {
-namespace port {
-
-static bool TryGetEnv(const char* name, const char** value) {
-  *value = getenv(name);
-  return *value != nullptr && (*value)[0] != '\0';
-}
-
-const char* Tracing::LogDir() {
-  const char* dir;
-  if (TryGetEnv("TEST_TMPDIR", &dir)) return dir;
-  if (TryGetEnv("TMP", &dir)) return dir;
-  if (TryGetEnv("TMPDIR", &dir)) return dir;
-  dir = "/tmp";
-  if (access(dir, R_OK | W_OK | X_OK) == 0) return dir;
-  return ".";  // Default to current directory.
-}
-
-}  // namespace port
-}  // namespace tensorflow
index f7d2a8e..c0386c0 100644 (file)
@@ -15,24 +15,24 @@ limitations under the License.
 
 #include "tensorflow/core/platform/tracing.h"
 
+#include <array>
 #include <atomic>
 #include <map>
 #include <string>
 #include <vector>
+#include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace tensorflow {
+namespace tracing {
+namespace {
+std::atomic<uint64> unique_arg{1};
+std::atomic<const TraceCollector*> trace_collector;
+}  // namespace
 
-namespace port {
-
-int32 Tracing::category_id_[kEventCategoryMax];
-uint64 Tracing::event_mask_ = 0;
-std::map<string, int32>* Tracing::name_map_ = new std::map<string, int32>;
-
-// This needs to be kept in sync with the EventCategory enumeration.
-const char* Tracing::EventCategoryString(EventCategory category) {
+const char* GetEventCategoryName(EventCategory category) {
   switch (category) {
     case EventCategory::kScheduleClosure:
       return "ScheduleClosure";
@@ -40,63 +40,45 @@ const char* Tracing::EventCategoryString(EventCategory category) {
       return "RunClosure";
     case EventCategory::kCompute:
       return "Compute";
-    case EventCategory::kEventCategoryMax:
-      return "EventCategoryMax";
+    default:
+      return "Unknown";
   }
-  return "Unknown";
 }
 
-// This function allows the user to specify arbitrary subsets of the
-// supported Threadscape events and activities.
-bool Tracing::ParseEventMask(const char* flagname, const string& value) {
-  VLOG(1) << flagname << " set to " << value;
-  int64 new_mask = 0;
-  std::vector<string> events =
-      str_util::Split(value, ',', str_util::SkipEmpty());
-  for (string name : events) {
-    bool clear = false;
-    int64 mask = 0;
-    if (name[0] == '!') {
-      // invert the sense of the flag
-      clear = true;
-      name = name.substr(1);
-    }
-    if (name == "ALL") {
-      mask = ~0;
-    } else {
-      auto it = name_map_->find(name);
-      int32 id;
-      if (it == name_map_->end()) {
-        id = -1;
-      } else {
-        id = it->second;
-      }
-      if (id < 0) {
-        LOG(ERROR) << "Can't parse event mask name " << name;
-        return false;
-      }
-      mask = 1 << id;
-    }
-    if (clear) {
-      new_mask &= ~mask;
-    } else {
-      new_mask |= mask;
-    }
-  }
-  // parsing was successful; set the permanent event mask
-  event_mask_ = new_mask;
-  return true;
+std::array<const EventCollector*, GetNumEventCategories()>
+    EventCollector::instances_;
+
+void SetEventCollector(EventCategory category,
+                       const EventCollector* collector) {
+  EventCollector::instances_[static_cast<unsigned>(category)] = collector;
+}
+
+uint64 GetUniqueArg() {
+  return unique_arg.fetch_add(1, std::memory_order_relaxed);
 }
 
-/*static*/ std::atomic<Tracing::Engine*> Tracing::tracing_engine_;
+uint64 GetArgForName(StringPiece name) {
+  return Hash64(name.data(), name.size());
+}
 
-void Tracing::RegisterEngine(Engine* e) {
-  tracing_engine_.store(e, std::memory_order_release);
+string TraceCollector::ConcatenateNames(StringPiece first, StringPiece second) {
+  std::string result;
+  bool has_two_parts = !first.empty() && !second.empty();
+  result.reserve(first.size() + second.size() +
+                 static_cast<int>(has_two_parts));
+  result.append(first.data(), first.size());
+  if (has_two_parts) result.append({':'});
+  result.append(second.data(), second.size());
+  return result;
 }
 
-Tracing::Engine::~Engine() {}
-Tracing::Engine::Annotation::~Annotation() {}
-Tracing::Engine::Tracer::~Tracer() {}
+void SetTraceCollector(const TraceCollector* collector) {
+  return trace_collector.store(collector, std::memory_order_release);
+}
+
+const TraceCollector* GetTraceCollector() {
+  return trace_collector.load(std::memory_order_acquire);
+}
 
-}  // namespace port
+}  // namespace tracing
 }  // namespace tensorflow
index 3c6e7b0..c322777 100644 (file)
@@ -18,6 +18,7 @@ limitations under the License.
 
 // Tracing interface
 
+#include <array>
 #include <atomic>
 #include <map>
 #include <memory>
@@ -30,255 +31,205 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
+namespace tracing {
+
+// This enumeration contains the identifiers of all TensorFlow CPU profiler
+// events. It must be kept in sync with the code in GetEventCategoryName().
+enum struct EventCategory : unsigned {
+  kScheduleClosure = 0,
+  kRunClosure = 1,
+  kCompute = 2,
+  kNumCategories = 3  // sentinel - keep last
+};
+constexpr unsigned GetNumEventCategories() {
+  return static_cast<unsigned>(EventCategory::kNumCategories);
+}
+const char* GetEventCategoryName(EventCategory);
 
-namespace port {
-
-class Tracing {
+// Interface for CPU profiler events.
+class EventCollector {
  public:
-  // This enumeration contains the identifiers of all TensorFlow
-  // threadscape events and code regions.  Threadscape assigns its
-  // own identifiers at runtime when we register our events and we
-  // cannot know in advance what IDs it will choose.  The "RecordEvent"
-  // method and "ScopedActivity" use these event IDs for consistency
-  // and remap them to threadscape IDs at runtime.  This enum is limited
-  // to 64 values since we use a bitmask to configure which events are
-  // enabled.  It must also be kept in step with the code in
-  // "Tracing::EventCategoryString".
-  enum EventCategory {
-    kScheduleClosure = 0,
-    kRunClosure = 1,
-    kCompute = 2,
-    kEventCategoryMax = 3  // sentinel - keep last
-  };
-  // Note: We currently only support up to 64 categories.
-  static_assert(kEventCategoryMax <= 64, "only support up to 64 events");
+  virtual ~EventCollector() {}
+  virtual void RecordEvent(uint64 arg) const = 0;
+  virtual void StartRegion(uint64 arg) const = 0;
+  virtual void StopRegion() const = 0;
 
-  // Called by main programs to initialize tracing facilities
-  static void Initialize();
+  // Annotates the current thread with a name.
+  static void SetCurrentThreadName(const char* name);
+  // Returns whether event collection is enabled.
+  static bool IsEnabled();
 
-  // Return the pathname of the directory where we are writing log files.
-  static const char* LogDir();
+ private:
+  friend void SetEventCollector(EventCategory, const EventCollector*);
+  friend const EventCollector* GetEventCollector(EventCategory);
 
-  // Returns a non-zero identifier which can be used to correlate
-  // related events.
-  static inline uint64 UniqueId();
+  static std::array<const EventCollector*, GetNumEventCategories()> instances_;
+};
+// Set the callback for RecordEvent and ScopedRegion of category.
+// Not thread safe. Only call while EventCollector::IsEnabled returns false.
+void SetEventCollector(EventCategory category, const EventCollector* collector);
+
+// Returns the callback for RecordEvent and ScopedRegion of category if
+// EventCollector::IsEnabled(), otherwise returns null.
+inline const EventCollector* GetEventCollector(EventCategory category) {
+  if (EventCollector::IsEnabled()) {
+    return EventCollector::instances_[static_cast<unsigned>(category)];
+  }
+  return nullptr;
+}
 
-  // Returns true if a trace is in progress.  Can be used to reduce tracing
-  // overheads in fast-path code.
-  static inline bool IsActive();
+// Returns a unique id to pass to RecordEvent/ScopedRegion. Never returns zero.
+uint64 GetUniqueArg();
 
-  // Associate name with the current thread.
-  static void RegisterCurrentThread(const char* name);
+// Returns an id for name to pass to RecordEvent/ScopedRegion.
+uint64 GetArgForName(StringPiece name);
 
-  // Posts an event with the supplied category and arg.
-  static void RecordEvent(EventCategory category, uint64 arg);
+// Records an atomic event through the currently registered EventCollector.
+inline void RecordEvent(EventCategory category, uint64 arg) {
+  if (auto collector = GetEventCollector(category)) {
+    collector->RecordEvent(arg);
+  }
+}
 
-  // Traces a region of code.  Posts a tracing "EnterCodeRegion" event
-  // when created and an "ExitCodeRegion" event when destroyed.
-  class ScopedActivity {
-   public:
-    explicit ScopedActivity(EventCategory category, uint64 arg);
-    ~ScopedActivity();
+// Records an event for the duration of the instance lifetime through the
+// currently registered EventCollector.
+class ScopedRegion {
+  ScopedRegion(ScopedRegion&) = delete;             // Not copy-constructible.
+  ScopedRegion& operator=(ScopedRegion&) = delete;  // Not assignable.
 
  private:
-#if defined(PLATFORM_GOOGLE)
-    const bool enabled_;
-    const int32 region_id_;
-#endif
public:
+  ScopedRegion(ScopedRegion&& other) noexcept  // Move-constructible.
+      : collector_(other.collector_) {
+    other.collector_ = nullptr;
+  }
 
-    TF_DISALLOW_COPY_AND_ASSIGN(ScopedActivity);
-  };
+  ScopedRegion(EventCategory category, uint64 arg)
+      : collector_(GetEventCollector(category)) {
+    if (collector_) {
+      collector_->StartRegion(arg);
+    }
+  }
 
-  // Trace collection engine can be registered with this module.
-  // If no engine is registered, ScopedAnnotation and TraceMe are no-ops.
-  class Engine;
-  static void RegisterEngine(Engine*);
+  // Same as ScopedRegion(category, GetUniqueArg()), but faster if
+  // EventCollector::IsEnaled() returns false.
+  ScopedRegion(EventCategory category)
+      : collector_(GetEventCollector(category)) {
+    if (collector_) {
+      collector_->StartRegion(GetUniqueArg());
+    }
+  }
 
-  // Forward declaration of the GPU utility classes.
-  class ScopedAnnotation;
-  class TraceMe;
+  // Same as ScopedRegion(category, GetArgForName(name)), but faster if
+  // EventCollector::IsEnaled() returns false.
+  ScopedRegion(EventCategory category, StringPiece name)
+      : collector_(GetEventCollector(category)) {
+    if (collector_) {
+      collector_->StartRegion(GetArgForName(name));
+    }
+  }
 
- private:
-  friend class TracingTest;
-  friend class ScopedAnnotation;
-  friend class TraceMe;
-
-  // TODO: TF_EXPORT is for building //tensorflow/contrib/data:_dataset_ops.so
-  //       on Windows. Figure out a way to remove TF_EXPORT here.
-  TF_EXPORT static std::atomic<Tracing::Engine*> tracing_engine_;
-  static Tracing::Engine* engine() {
-    return tracing_engine_.load(std::memory_order_acquire);
+  ~ScopedRegion() {
+    if (collector_) {
+      collector_->StopRegion();
+    }
   }
 
-  static void RegisterEvent(EventCategory id, const char* name);
-  static const char* EventCategoryString(EventCategory category);
-
-  //
-  // Parses event mask expressions in 'value' of the form:
-  //   expr ::= <term> (,<term>)*
-  //   term ::= <event> | "!" <event>
-  //   event ::= "ALL" | <wait_event> | <other_event>
-  //   wait_event ::= "ENewSession" | "ECloseSession" | ...
-  //   other_event ::= "Send" | "Wait" | ...
-  // ALL denotes all events, <event> turns on tracing for this event, and
-  // !<event> turns off tracing for this event.
-  // If the expression can be parsed correctly it returns true and sets
-  // the event_mask_. Otherwise it returns false and the event_mask_ is left
-  // unchanged.
-  static bool ParseEventMask(const char* flagname, const string& value);
-
-  // Bit mask of enabled trace categories.
-  static uint64 event_mask_;
-
-  // Records the mappings between Threadscape IDs and the "EventCategory" enum.
-  static int32 category_id_[kEventCategoryMax];
-  static std::map<string, int32>* name_map_;
+  bool IsEnabled() const { return collector_ != nullptr; }
+
+ private:
+  const EventCollector* collector_;
 };
 
-// Trace collection engine that actually implements collection.
-class Tracing::Engine {
+// Interface for accelerator profiler annotations.
+class TraceCollector {
  public:
-  Engine() {}
-  virtual ~Engine();
-
-  // Returns true if Tracing is currently enabled.
-  virtual bool IsEnabled() const = 0;
-
-  // Represents an active annotation.
-  class Annotation {
+  class Handle {
    public:
-    Annotation() {}
-    virtual ~Annotation();
+    virtual ~Handle() {}
   };
 
-  // Represents an active trace.
-  class Tracer {
-   public:
-    Tracer() {}
-    virtual ~Tracer();
-  };
+  virtual ~TraceCollector() {}
+  virtual std::unique_ptr<Handle> CreateAnnotationHandle(
+      StringPiece name_part1, StringPiece name_part2) const = 0;
+  virtual std::unique_ptr<Handle> CreateActivityHandle(
+      StringPiece name_part1, StringPiece name_part2,
+      bool is_expensive) const = 0;
 
- private:
-  friend class ScopedAnnotation;
-  friend class TraceMe;
-
-  // Register the specified name as an annotation on the current thread.
-  // Caller should delete the result to remove the annotation.
-  // Annotations from the same thread are destroyed in a LIFO manner.
-  // May return nullptr if annotations are not supported.
-  virtual Annotation* PushAnnotation(StringPiece name) = 0;
-
-  // Start tracing under the specified label. Caller should delete the result
-  // to stop tracing.
-  // May return nullptr if tracing is not supported.
-  virtual Tracer* StartTracing(StringPiece label, bool is_expensive) = 0;
-  // Same as above, but implementations can avoid copying the string.
-  virtual Tracer* StartTracing(string&& label, bool is_expensive) {
-    return StartTracing(StringPiece(label), is_expensive);
-  }
+ protected:
+  static string ConcatenateNames(StringPiece first, StringPiece second);
 
-  // Backwards compatibility one arg variants (assume is_expensive=true).
-  Tracer* StartTracing(StringPiece label) {
-    return StartTracing(label, /*is_expensive=*/true);
-  }
-  Tracer* StartTracing(string&& label) {
-    return StartTracing(StringPiece(label), /*is_expensive=*/true);
-  }
+ private:
+  friend void SetTraceCollector(const TraceCollector*);
+  friend const TraceCollector* GetTraceCollector();
 };
+// Set the callback for ScopedAnnotation and ScopedActivity.
+void SetTraceCollector(const TraceCollector* collector);
+// Returns the callback for ScopedAnnotation and ScopedActivity.
+const TraceCollector* GetTraceCollector();
 
-// This class permits a user to apply annotation on kernels and memcpys
-// when launching them. While an annotation is in scope, all activities
-// within that scope get their names replaced by the annotation. The kernel
-// name replacement is done when constructing the protobuf for sending out to
-// a client (e.g., the stubby requestor) for both API and Activity records.
-//
-// Ownership: The creator of ScopedAnnotation assumes ownership of the object.
+// Adds an annotation to all activities for the duration of the instance
+// lifetime through the currently registered TraceCollector.
 //
 // Usage: {
-//          ScopedAnnotation annotation("first set of kernels");
+//          ScopedAnnotation annotation("my kernels");
 //          Kernel1<<<x,y>>>;
-//          LaunchKernel2(); // Which eventually launches a cuda kernel.
+//          LaunchKernel2(); // Launches a CUDA kernel.
 //        }
-// In the above scenario, the GPUProf UI would show 2 kernels with the name
-// "first set of kernels" executing -- they will appear as the same kernel.
-class Tracing::ScopedAnnotation {
+// This will add 'my kernels' to both kernels in the profiler UI
+class ScopedAnnotation {
  public:
-  explicit ScopedAnnotation(StringPiece name);
+  explicit ScopedAnnotation(StringPiece name)
+      : ScopedAnnotation(name, StringPiece()) {}
 
-  // If tracing is enabled, set up an annotation with a label of
-  // "<name_part1>:<name_part2>".  Can be cheaper than the
+  // If tracing is enabled, add a name scope of
+  // "<name_part1>:<name_part2>".  This can be cheaper than the
   // single-argument constructor because the concatenation of the
   // label string is only done if tracing is enabled.
-  ScopedAnnotation(StringPiece name_part1, StringPiece name_part2);
+  ScopedAnnotation(StringPiece name_part1, StringPiece name_part2)
+      : handle_([&] {
+          auto trace_collector = GetTraceCollector();
+          return trace_collector ? trace_collector->CreateAnnotationHandle(
+                                       name_part1, name_part2)
+                                 : nullptr;
+        }()) {}
 
-  // Returns true iff scoped annotations are active.
-  static bool Enabled() {
-    auto e = Tracing::engine();
-    return e && e->IsEnabled();
-  }
+  bool IsEnabled() const { return static_cast<bool>(handle_); }
 
  private:
-  std::unique_ptr<Engine::Annotation> annotation_;
+  std::unique_ptr<TraceCollector::Handle> handle_;
 };
 
-// TODO(opensource): clean up the scoped classes for GPU tracing.
-// This class permits user-specified (CPU) tracing activities. A trace
-// activity is started when an object of this class is created and stopped
-// when the object is destroyed.
-class Tracing::TraceMe {
+// Adds an activity through the currently registered TraceCollector.
+// The activity starts when an object of this class is created and stops when
+// the object is destroyed.
+class ScopedActivity {
  public:
-  explicit TraceMe(StringPiece name);
-  TraceMe(StringPiece name, bool is_expensive);
+  explicit ScopedActivity(StringPiece name, bool is_expensive = true)
+      : ScopedActivity(name, StringPiece(), is_expensive) {}
 
-  // If tracing is enabled, set up a traceMe with a label of
+  // If tracing is enabled, set up an activity with a label of
   // "<name_part1>:<name_part2>".  This can be cheaper than the
   // single-argument constructor because the concatenation of the
   // label string is only done if tracing is enabled.
-  TraceMe(StringPiece name_part1, StringPiece name_part2);
-  TraceMe(StringPiece name_part1, StringPiece name_part2, bool is_expensive);
+  ScopedActivity(StringPiece name_part1, StringPiece name_part2,
+                 bool is_expensive = true)
+      : handle_([&] {
+          auto trace_collector = GetTraceCollector();
+          return trace_collector ? trace_collector->CreateActivityHandle(
+                                       name_part1, name_part2, is_expensive)
+                                 : nullptr;
+        }()) {}
+
+  bool IsEnabled() const { return static_cast<bool>(handle_); }
 
  private:
-  std::unique_ptr<Engine::Tracer> tracer_;
+  std::unique_ptr<TraceCollector::Handle> handle_;
 };
 
-inline Tracing::ScopedAnnotation::ScopedAnnotation(StringPiece name) {
-  auto e = Tracing::engine();
-  if (e && e->IsEnabled()) {
-    annotation_.reset(e->PushAnnotation(name));
-  }
-}
-
-inline Tracing::ScopedAnnotation::ScopedAnnotation(StringPiece name_part1,
-                                                   StringPiece name_part2) {
-  auto e = Tracing::engine();
-  if (e && e->IsEnabled()) {
-    annotation_.reset(
-        e->PushAnnotation(strings::StrCat(name_part1, ":", name_part2)));
-  }
-}
-
-inline Tracing::TraceMe::TraceMe(StringPiece name) : TraceMe(name, true) {}
-
-inline Tracing::TraceMe::TraceMe(StringPiece name, bool is_expensive) {
-  auto e = Tracing::engine();
-  if (e && e->IsEnabled()) {
-    tracer_.reset(e->StartTracing(name, is_expensive));
-  }
-}
-
-inline Tracing::TraceMe::TraceMe(StringPiece name_part1, StringPiece name_part2)
-    : TraceMe(name_part1, name_part2, true) {}
-
-inline Tracing::TraceMe::TraceMe(StringPiece name_part1, StringPiece name_part2,
-                                 bool is_expensive) {
-  auto e = Tracing::engine();
-  if (e && e->IsEnabled()) {
-    tracer_.reset(e->StartTracing(strings::StrCat(name_part1, ":", name_part2),
-                                  is_expensive));
-  }
-}
+// Return the pathname of the directory where we are writing log files.
+const char* GetLogDir();
 
-}  // namespace port
+}  // namespace tracing
 }  // namespace tensorflow
 
 #if defined(PLATFORM_GOOGLE)