[TF:XLA] Remove underscore prefix from XlaLaunch operator.
authorPeter Hawkins <phawkins@google.com>
Fri, 18 May 2018 18:43:25 +0000 (11:43 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 18 May 2018 18:46:04 +0000 (11:46 -0700)
Minor fixes to comments.

PiperOrigin-RevId: 197177582

14 files changed:
tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
tensorflow/compiler/jit/kernels/xla_launch_op.cc
tensorflow/compiler/jit/ops/xla_ops.cc
tensorflow/compiler/jit/xla_compilation_cache.cc
tensorflow/compiler/jit/xla_compile_on_demand_op.h
tensorflow/compiler/jit/xla_device_ops.h
tensorflow/compiler/tests/dense_layer_test.py
tensorflow/compiler/tests/jit_test.py
tensorflow/compiler/tf2xla/xla_compiler.h
tensorflow/compiler/tf2xla/xla_op_registry.cc
tensorflow/core/common_runtime/eager/execute.cc
tensorflow/core/grappler/optimizers/dependency_optimizer.cc
tensorflow/docs_src/performance/xla/jit.md

index 9a2bb00..b17ff58 100644 (file)
@@ -40,7 +40,7 @@ static Status BuildLaunchNode(
     Graph* graph, Node** node) {
   NodeDef def;
   def.set_name(graph->NewName(nodename));
-  def.set_op("_XlaLaunch");
+  def.set_op("XlaLaunch");
   def.set_device(device_name);
   AddNodeAttr("Tconstants", constant_dtypes, &def);
   AddNodeAttr("Targs", arg_dtypes, &def);
@@ -79,7 +79,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
       node->input_types().begin() + num_constant_args,
       node->input_types().begin() + num_constant_args + num_nonconst_args);
 
-  // Build a _XlaLaunch operator to execute the function body.
+  // Build a XlaLaunch operator to execute the function body.
   Node* launch_node;
   TF_RETURN_IF_ERROR(BuildLaunchNode(
       graph->NewName(node->name()), node->type_string(), node->def().attr(),
index 34be440..5fee36f 100644 (file)
@@ -80,7 +80,7 @@ Status EncapsulateSubgraphsInFunctions(
     std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
 
 // The attribute that marks function calls produced by the encapsulate
-// subgraphs pass and that should in turn be compiled via _XlaLaunch operators.
+// subgraphs pass and that should in turn be compiled via XlaLaunch operators.
 extern const char* const kXlaCompiledKernelAttr;
 
 // Does `node` have the kXlaCompiledKernelAttr attribute?
index 9d85634..27287e0 100644 (file)
@@ -256,10 +256,9 @@ XlaLocalLaunchOp::~XlaLocalLaunchOp() {
   VLOG(1) << "XlaLocalLaunchOp destroyed";
 }
 
-REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU),
-                        XlaLocalLaunchOp);
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
 
-REGISTER_KERNEL_BUILDER(Name("_XlaLaunch")
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
                             .Device(DEVICE_GPU)
                             .HostMemory("constants")
                             .HostMemory("resources"),
index 07320b4..f2473d9 100644 (file)
@@ -17,7 +17,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-REGISTER_OP("_XlaLaunch")
+REGISTER_OP("XlaLaunch")
     .Input("constants: Tconstants")
     .Attr("Tconstants: list(type) >= 0")
     .Input("args: Targs")
@@ -28,7 +28,7 @@ REGISTER_OP("_XlaLaunch")
     .Attr("Tresults: list(type) >= 0")
     .Attr("function: func")
     // XLA random-number generation ops are stateful.
-    // TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch.
+    // TODO(phawkins): create stateful and non-stateful variants of XlaLaunch.
     .SetIsStateful()
     .Doc("XLA Launch Op. For use by the XLA JIT only.");
 
index 6430975..7ed609c 100644 (file)
@@ -122,8 +122,7 @@ Status XlaCompilationCache::BuildSignature(
 
 namespace {
 
-// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch
-// op.
+// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op.
 Status BuildArguments(const std::map<int, Tensor>& constant_args,
                       const std::map<int, OptionalTensor>& variable_args,
                       OpKernelContext* ctx,
index 23c6f39..7cc3d0e 100644 (file)
@@ -29,11 +29,8 @@ limitations under the License.
 namespace tensorflow {
 
 // An OpKernel that compiles an op to an XLA computation and runs it. Unlike
-// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a
+// XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a
 // vanilla TensorFlow op as long as the bridge supports it.
-//
-// Importantly _XlaLaunch assumes all input and output tensors are on the host,
-// whereas XlacompileOnDemandOp works with tensors in device memory.
 class XlaCompileOnDemandOp : public OpKernel {
  public:
   explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
index 65c0e85..9c00a06 100644 (file)
@@ -33,7 +33,7 @@ namespace tensorflow {
 
 // Dummy OpKernel, used for kernels assigned to an XLA device that should be
 // compiled. Should never be called at runtime since such ops should be
-// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an
+// rewritten to a XlaLaunch op. If it is called, it means the placer placed an
 // operator on an XLA device but the compiler did not compile it.
 class XlaDeviceDummyOp : public OpKernel {
  public:
@@ -42,7 +42,7 @@ class XlaDeviceDummyOp : public OpKernel {
 };
 
 #define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
-  REGISTER_KERNEL_BUILDER(Name("_XlaLaunch")              \
+  REGISTER_KERNEL_BUILDER(Name("XlaLaunch")               \
                               .Device(DEVICE)             \
                               .HostMemory("constants")    \
                               .HostMemory("resources"),   \
index b0bf1b7..865f60c 100644 (file)
@@ -46,8 +46,8 @@ def InLabels(labels, substr):
 
 
 def XlaLaunchOpCount(labels):
-  """Count how many _XlaLaunch labels are present."""
-  return sum("_XlaLaunch(" in x for x in labels)
+  """Count how many XlaLaunch labels are present."""
+  return sum("XlaLaunch(" in x for x in labels)
 
 
 class DenseLayerTest(test.TestCase):
@@ -55,7 +55,7 @@ class DenseLayerTest(test.TestCase):
   def testDenseLayerAutoJit(self):
     """Tests dense layer compilation in auto-jit mode.
 
-    Dense layer should be compiled into a single _XlaLaunch op in auto-jit mode.
+    Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
     """
 
     os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit")
@@ -83,7 +83,7 @@ class DenseLayerTest(test.TestCase):
     """Tests that the dense layer node is properly compiled in jit scope.
 
     Dense layer with static shape input tensor should be compiled into a single
-    _XlaLaunch op by XLA.
+    XlaLaunch op by XLA.
     """
 
     with self.test_session() as sess:
@@ -110,7 +110,7 @@ class DenseLayerTest(test.TestCase):
     Dense layer uses shape op to get shape of input tensor if its shape is not
     fully defined. XLA does not cluster shape op with other operators. But in
     experimental_jit_scope, XLA is forced to compile shape op into its own
-    cluster, causing dense layer to be split into TWO _XlaLaunch ops.
+    cluster, causing dense layer to be split into TWO XlaLaunch ops.
     """
 
     with self.test_session() as sess:
index 0310cdd..4b0043b 100644 (file)
@@ -78,10 +78,10 @@ def InLabels(labels, substr):
 
 
 def MetadataHasXlaLaunch(run_metadata):
-  """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline."""
+  """Returns true if there is a XlaLaunch kernel in run_metadata's timeline."""
 
   # TODO(phawkins): find a less hacky way to test whether a kernel ran.
-  return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch")
+  return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch")
 
 
 class JitLaunchTest(test.TestCase):
@@ -90,8 +90,8 @@ class JitLaunchTest(test.TestCase):
   # Verifies that the outputs match and that XLA was invoked. 'fn' must take
   # the same number of tensors as arguments that are in 'args', and must return
   # a tuple of output tensors.
-  # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node
-  # actually ran. However, it is sometimes possible for _XlaLaunch ops to be
+  # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node
+  # actually ran. However, it is sometimes possible for XlaLaunch ops to be
   # constant-folded away, so the check is optional.
   def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
     with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
@@ -441,14 +441,14 @@ class XlaCompilationTest(test.TestCase):
     self.assertFalse(InLabels(labels, "Log"))
     self.assertTrue(InLabels(labels, "Reciprocal"))
     self.assertTrue(InLabels(labels, "Mul"))
-    self.assertFalse(InLabels(labels, "_XlaLaunch"))
+    self.assertFalse(InLabels(labels, "XlaLaunch"))
 
-    # Compile the backprop. One _XlaLaunch.
+    # Compile the backprop. One XlaLaunch.
     labels = _Run(compiled=True)
     self.assertFalse(InLabels(labels, "Log"))
     self.assertFalse(InLabels(labels, "Reciprocal"))
     self.assertFalse(InLabels(labels, "Mul"))
-    self.assertTrue(InLabels(labels, "_XlaLaunch"))
+    self.assertTrue(InLabels(labels, "XlaLaunch"))
 
 
 class ElementWiseFusionTest(test.TestCase):
@@ -482,7 +482,7 @@ class ElementWiseFusionTest(test.TestCase):
               trace_level=config_pb2.RunOptions.FULL_TRACE))
 
       labels = RunMetadataLabels(run_metadata)
-      count = sum("_XlaLaunch(" in x for x in labels)
+      count = sum("XlaLaunch(" in x for x in labels)
 
       return output, count
 
index 621fbc1..bf496bd 100644 (file)
@@ -38,7 +38,7 @@ class XlaContext;
 // It does a symbolic execution of the graph starting from specific input
 // shapes, using a JIT device to convert operators into XLA computations.
 //
-// XlaCompiler is typically invoked from an `_XlaLaunch` operator once the
+// XlaCompiler is typically invoked from an `XlaLaunch` operator once the
 // shapes of all input parameters to the computation are known. This is
 // because the symbolic execution requires known shapes for all operations.
 //
index e309cb1..4692038 100644 (file)
@@ -39,10 +39,10 @@ const char* const DEVICE_XLA_GPU = "XLA_GPU";
 
 static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) {
   const OpDef* op_def;
-  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def));
+  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("XlaLaunch", &op_def));
   NodeDef node_def;
   node_def.set_name("_XlaLaunch-op");
-  node_def.set_op("_XlaLaunch");
+  node_def.set_op("XlaLaunch");
   string kernel_class_name;
   TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr,
                                    &kernel_class_name));
index 1df4996..ce989f4 100644 (file)
@@ -186,14 +186,14 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
 // primitive op (e.g. matmul).
 //
 // The wrapper function conforms to the function signature expected by
-// _XlaLaunchOp, with input params ordered by <constants, (variable) args and
+// XlaLaunch, with input params ordered by <constants, (variable) args and
 // resources>. For example, if the op has input params <Const1, Arg2, Const3,
 // Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
 // Resource4> as the input params to the synthesized function.
 //
 // It populates `const_input_types`, `arg_input_types` and
 // `op_input_to_func_input` based on the reordering results, that the caller can
-// use them to build an _XlaLaunchOp. On error, it returns NULL, and sets
+// use them to build an XlaLaunch. On error, it returns NULL, and sets
 // `status` accordingly.
 const FunctionDef* OpToFunction(TFE_Op* op,
                                 std::vector<TF_DataType>* const_input_types,
@@ -311,12 +311,12 @@ const FunctionDef* OpToFunction(TFE_Op* op,
   return ret;
 }
 
-// Builds an _XLALaunchOp as a wrapper over 'op', so that 'op' can be executed
+// Builds an XlaLaunch as a wrapper over 'op', so that 'op' can be executed
 // via XLA.
 std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
-  VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->operation.Name();
+  VLOG(1) << "Creating XlaLaunch for TFE_Op " << op->operation.Name();
   auto launch_op = std::unique_ptr<TFE_Op>(
-      TFE_NewOp(op->operation.ctx, "_XlaLaunch", status));
+      TFE_NewOp(op->operation.ctx, "XlaLaunch", status));
   if (TF_GetCode(status) != TF_OK) return nullptr;
   if (op->operation.device) {
     TFE_OpSetDevice(launch_op.get(), op->operation.device->name().c_str(),
@@ -331,7 +331,7 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
   gtl::FlatMap<int, int> op_input_to_func_input;
   if (fdef == nullptr) {
     // See if this is a primitive op, and if so create a function for it, so
-    // that _XlaLaunchOp can access it.
+    // that XlaLaunch can access it.
     fdef = OpToFunction(op, &const_input_types, &arg_input_types,
                         &op_input_to_func_input, status);
     if (!status.ok()) return nullptr;
@@ -423,7 +423,7 @@ Status EagerLocalExecute(EagerOperation* op,
   if (!status.ok()) return status;
 #ifdef TENSORFLOW_EAGER_USE_XLA
   std::unique_ptr<TFE_Op> xla_launch_op;
-  if (op->UseXla() && op->Name() != "_XlaLaunch") {
+  if (op->UseXla() && op->Name() != "XlaLaunch") {
     xla_launch_op = BuildXlaLaunch(op, status);
     if (!status.ok()) return status;
     op = xla_launch_op.get();
index 7b7fd81..200454b 100644 (file)
@@ -126,9 +126,9 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
     return false;
   }
   const std::unordered_set<string> do_not_rewrite_ops{
-      "Assert",      "CheckNumerics",         "_Retval",
-      "_Arg",        "_ParallelConcatUpdate", "_TPUExecute",
-      "_TPUCompile", "ControlTrigger"};
+      "Assert",     "CheckNumerics",         "_Retval",
+      "_Arg",       "_ParallelConcatUpdate", "TPUExecute",
+      "TPUCompile", "ControlTrigger"};
   if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
     return false;
   }
index d9a979c..6724d1e 100644 (file)
@@ -137,12 +137,12 @@ TF_XLA_FLAGS=--xla_generate_hlo_graph=.* python mnist_softmax_xla.py
 ```
 
 Open the timeline file created (`timeline.ctf.json`).  The rendered timeline
-should look similar to the picture below with one long bar labeled `_XlaLaunch`.
+should look similar to the picture below with one long bar labeled `XlaLaunch`.
 <div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
   <img style="width:100%" src="https://www.tensorflow.org/images/jit_timeline_gpu_xla.png">
 </div>
 
-To understand what is happening in `_XlaLaunch`, look at the console output for
+To understand what is happening in `XlaLaunch`, look at the console output for
 statements similar to the following:
 
 ```shell