[XLA:python] Plumb hlo_profile flag.
authorChris Leary <leary@google.com>
Fri, 16 Mar 2018 19:34:34 +0000 (12:34 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 16 Mar 2018 19:38:29 +0000 (12:38 -0700)
PiperOrigin-RevId: 189377860

tensorflow/compiler/xla/client/executable_build_options.cc
tensorflow/compiler/xla/client/executable_build_options.h
tensorflow/compiler/xla/python/local_computation_builder.i
tensorflow/compiler/xla/python/xla_client.py
tensorflow/compiler/xla/service/local_service.cc

index 804e34f..d84f201 100644 (file)
@@ -76,4 +76,11 @@ ExecutableBuildOptions::generate_hlo_graph() const {
   return generate_hlo_graph_;
 }
 
+ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) {
+  hlo_profile_ = enabled;
+  return *this;
+}
+
+bool ExecutableBuildOptions::hlo_profile() const { return hlo_profile_; }
+
 }  // namespace xla
index 3a52dba..3e18e5d 100644 (file)
@@ -57,11 +57,17 @@ class ExecutableBuildOptions {
   ExecutableBuildOptions& set_generate_hlo_graph(string regex);
   const tensorflow::gtl::optional<string>& generate_hlo_graph() const;
 
+  // If set, specifies that we should record an HLO profile during execution and
+  // log it after execution (as in DebugOptions).
+  ExecutableBuildOptions& set_hlo_profile(bool enabled);
+  bool hlo_profile() const;
+
   // Returns a string representation of the build options, suitable for
   // debugging.
   string ToString() const;
 
  private:
+  bool hlo_profile_ = false;
   int device_ordinal_ = -1;
   Shape result_layout_;
   bool result_layout_set_ = false;
index b2681d5..ca91cf0 100644 (file)
@@ -833,6 +833,19 @@ tensorflow::ImportNumpy();
     }
     Py_DECREF(o);
 
+    o = PyObject_GetAttrString($input, "hlo_profile");
+    if (o == NULL) {
+      return NULL;
+    }
+    if (o != Py_None) {
+      if (!PyBool_Check(o)) {
+        PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None.");
+        return NULL;
+      }
+      build_options.set_hlo_profile(o == Py_True);
+    }
+    Py_DECREF(o);
+
     o = PyObject_GetAttrString($input, "result_shape");
     if (o == nullptr) {
       return nullptr;
index 90cda42..d747a0b 100644 (file)
@@ -320,6 +320,7 @@ class CompileOptions(object):
 
   def __init__(self):
     self.generate_hlo_graph = None
+    self.hlo_profile = False
 
 
 def transfer_to_infeed(value, replica_number=None):
index 07f989d..74aa6ea 100644 (file)
@@ -119,6 +119,8 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
   }
 
   ExecutionOptions execution_options = CreateDefaultExecutionOptions();
+  execution_options.mutable_debug_options()->set_xla_hlo_profile(
+      build_options.hlo_profile());
   if (build_options.generate_hlo_graph().has_value()) {
     execution_options.mutable_debug_options()->set_xla_generate_hlo_graph(
         build_options.generate_hlo_graph().value());