From 161ba8ecc433c4ddbdbf88eb3a7a0d38bb253d0b Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Fri, 18 May 2018 11:43:25 -0700
Subject: [PATCH] [TF:XLA] Remove underscore prefix from XlaLaunch operator.

Minor fixes to comments.

PiperOrigin-RevId: 197177582
---
 tensorflow/compiler/jit/build_xla_launch_ops_pass.cc |  4 ++--
 tensorflow/compiler/jit/encapsulate_subgraphs_pass.h |  2 +-
 tensorflow/compiler/jit/kernels/xla_launch_op.cc     |  5 ++---
 tensorflow/compiler/jit/ops/xla_ops.cc               |  4 ++--
 tensorflow/compiler/jit/xla_compilation_cache.cc     |  3 +--
 tensorflow/compiler/jit/xla_compile_on_demand_op.h   |  5 +----
 tensorflow/compiler/jit/xla_device_ops.h             |  4 ++--
 tensorflow/compiler/tests/dense_layer_test.py        | 10 +++++-----
 tensorflow/compiler/tests/jit_test.py                | 16 ++++++++--------
 tensorflow/compiler/tf2xla/xla_compiler.h            |  2 +-
 tensorflow/compiler/tf2xla/xla_op_registry.cc        |  4 ++--
 tensorflow/core/common_runtime/eager/execute.cc      | 14 +++++++-------
 .../core/grappler/optimizers/dependency_optimizer.cc |  6 +++---
 tensorflow/docs_src/performance/xla/jit.md           |  4 ++--
 14 files changed, 39 insertions(+), 44 deletions(-)

diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
index 9a2bb00..b17ff58 100644
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
@@ -40,7 +40,7 @@ static Status BuildLaunchNode(
     Graph* graph, Node** node) {
   NodeDef def;
   def.set_name(graph->NewName(nodename));
-  def.set_op("_XlaLaunch");
+  def.set_op("XlaLaunch");
   def.set_device(device_name);
   AddNodeAttr("Tconstants", constant_dtypes, &def);
   AddNodeAttr("Targs", arg_dtypes, &def);
@@ -79,7 +79,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
       node->input_types().begin() + num_constant_args,
       node->input_types().begin() + num_constant_args + num_nonconst_args);
 
-  // Build a _XlaLaunch operator to execute the function body.
+  // Build a XlaLaunch operator to execute the function body.
   Node* launch_node;
   TF_RETURN_IF_ERROR(BuildLaunchNode(
       graph->NewName(node->name()), node->type_string(), node->def().attr(),
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
index 34be440..5fee36f 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -80,7 +80,7 @@ Status EncapsulateSubgraphsInFunctions(
     std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
 
 // The attribute that marks function calls produced by the encapsulate
-// subgraphs pass and that should in turn be compiled via _XlaLaunch operators.
+// subgraphs pass and that should in turn be compiled via XlaLaunch operators.
 extern const char* const kXlaCompiledKernelAttr;
 
 // Does `node` have the kXlaCompiledKernelAttr attribute?
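
Note: after this rename the rewrite pass must emit the launch node under the new op name. Below is a minimal sketch of the NodeDef assembly that BuildLaunchNode performs; the helper name and trimmed parameter list are hypothetical (the real function also wires up results, resources, and the function attr).

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {
// Hypothetical helper mirroring the BuildLaunchNode change above.
NodeDef MakeXlaLaunchDef(Graph* graph, const string& name,
                         const string& device_name,
                         const DataTypeVector& constant_dtypes,
                         const DataTypeVector& arg_dtypes) {
  NodeDef def;
  def.set_name(graph->NewName(name));
  def.set_op("XlaLaunch");  // previously "_XlaLaunch"
  def.set_device(device_name);
  AddNodeAttr("Tconstants", constant_dtypes, &def);
  AddNodeAttr("Targs", arg_dtypes, &def);
  return def;
}
}  // namespace tensorflow
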
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 9d85634..27287e0 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -256,10 +256,9 @@ XlaLocalLaunchOp::~XlaLocalLaunchOp() {
   VLOG(1) << "XlaLocalLaunchOp destroyed";
 }
 
-REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU),
-                        XlaLocalLaunchOp);
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
 
-REGISTER_KERNEL_BUILDER(Name("_XlaLaunch")
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
                             .Device(DEVICE_GPU)
                             .HostMemory("constants")
                             .HostMemory("resources"),
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index 07320b4..f2473d9 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -17,7 +17,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-REGISTER_OP("_XlaLaunch")
+REGISTER_OP("XlaLaunch")
     .Input("constants: Tconstants")
     .Attr("Tconstants: list(type) >= 0")
     .Input("args: Targs")
@@ -28,7 +28,7 @@ REGISTER_OP("_XlaLaunch")
     .Attr("Tresults: list(type) >= 0")
     .Attr("function: func")
     // XLA random-number generation ops are stateful.
-    // TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch.
+    // TODO(phawkins): create stateful and non-stateful variants of XlaLaunch.
     .SetIsStateful()
     .Doc("XLA Launch Op. For use by the XLA JIT only.");
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 6430975..7ed609c 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -122,8 +122,7 @@ Status XlaCompilationCache::BuildSignature(
 
 namespace {
 
-// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch
-// op.
+// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op.
 Status BuildArguments(const std::map<int, Tensor>& constant_args,
                       const std::map<int, OptionalTensor>& variable_args,
                       OpKernelContext* ctx,
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h
index 23c6f39..7cc3d0e 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h
@@ -29,11 +29,8 @@ limitations under the License.
 namespace tensorflow {
 
 // An OpKernel that compiles an op to an XLA computation and runs it. Unlike
-// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a
+// XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a
 // vanilla TensorFlow op as long as the bridge supports it.
-//
-// Importantly _XlaLaunch assumes all input and output tensors are on the host,
-// whereas XlacompileOnDemandOp works with tensors in device memory.
 class XlaCompileOnDemandOp : public OpKernel {
  public:
   explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 65c0e85..9c00a06 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -33,7 +33,7 @@ namespace tensorflow {
 
 // Dummy OpKernel, used for kernels assigned to an XLA device that should be
 // compiled. Should never be called at runtime since such ops should be
-// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an
+// rewritten to a XlaLaunch op. If it is called, it means the placer placed an
 // operator on an XLA device but the compiler did not compile it.
 class XlaDeviceDummyOp : public OpKernel {
  public:
@@ -42,7 +42,7 @@ class XlaDeviceDummyOp : public OpKernel {
 };
 
 #define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES)  \
-  REGISTER_KERNEL_BUILDER(Name("_XlaLaunch")               \
+  REGISTER_KERNEL_BUILDER(Name("XlaLaunch")                \
                               .Device(DEVICE)              \
                               .HostMemory("constants")     \
                               .HostMemory("resources"),    \
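
Note: device backends pick up the rename through the REGISTER_XLA_LAUNCH_KERNEL macro above. A one-line sketch of an instantiation follows; the names DEVICE_XLA_CPU, XlaLocalLaunchOp, and kCpuAllTypes follow the XLA_CPU device registration and are assumptions here, not part of this patch.

// Instantiates the renamed launch kernel for one device's supported types.
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kCpuAllTypes);
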
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index b0bf1b7..865f60c 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -46,8 +46,8 @@ def InLabels(labels, substr):
 
 
 def XlaLaunchOpCount(labels):
-  """Count how many _XlaLaunch labels are present."""
-  return sum("_XlaLaunch(" in x for x in labels)
+  """Count how many XlaLaunch labels are present."""
+  return sum("XlaLaunch(" in x for x in labels)
 
 
 class DenseLayerTest(test.TestCase):
@@ -55,7 +55,7 @@ class DenseLayerTest(test.TestCase):
   def testDenseLayerAutoJit(self):
     """Tests dense layer compilation in auto-jit mode.
 
-    Dense layer should be compiled into a single _XlaLaunch op in auto-jit mode.
+    Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
     """
 
     os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit")
@@ -83,7 +83,7 @@ class DenseLayerTest(test.TestCase):
     """Tests that the dense layer node is properly compiled in jit scope.
 
     Dense layer with static shape input tensor should be compiled into a single
-    _XlaLaunch op by XLA.
+    XlaLaunch op by XLA.
     """
 
     with self.test_session() as sess:
@@ -110,7 +110,7 @@ class DenseLayerTest(test.TestCase):
     Dense layer uses shape op to get shape of input tensor if its shape is not
     fully defined. XLA does not cluster shape op with other operators. But in
     experimental_jit_scope, XLA is forced to compile shape op into its own
-    cluster, causing dense layer to be split into TWO _XlaLaunch ops.
+    cluster, causing dense layer to be split into TWO XlaLaunch ops.
     """
 
     with self.test_session() as sess:
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 0310cdd..4b0043b 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -78,10 +78,10 @@ def InLabels(labels, substr):
 
 
 def MetadataHasXlaLaunch(run_metadata):
-  """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline."""
+  """Returns true if there is a XlaLaunch kernel in run_metadata's timeline."""
 
   # TODO(phawkins): find a less hacky way to test whether a kernel ran.
-  return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch")
+  return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch")
 
 
 class JitLaunchTest(test.TestCase):
@@ -90,8 +90,8 @@ class JitLaunchTest(test.TestCase):
   # Verifies that the outputs match and that XLA was invoked. 'fn' must take
   # the same number of tensors as arguments that are in 'args', and must return
   # a tuple of output tensors.
-  # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node
-  # actually ran. However, it is sometimes possible for _XlaLaunch ops to be
+  # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node
+  # actually ran. However, it is sometimes possible for XlaLaunch ops to be
   # constant-folded away, so the check is optional.
   def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
     with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
@@ -441,14 +441,14 @@ class XlaCompilationTest(test.TestCase):
     self.assertFalse(InLabels(labels, "Log"))
     self.assertTrue(InLabels(labels, "Reciprocal"))
     self.assertTrue(InLabels(labels, "Mul"))
-    self.assertFalse(InLabels(labels, "_XlaLaunch"))
+    self.assertFalse(InLabels(labels, "XlaLaunch"))
 
-    # Compile the backprop. One _XlaLaunch.
+    # Compile the backprop. One XlaLaunch.
     labels = _Run(compiled=True)
     self.assertFalse(InLabels(labels, "Log"))
     self.assertFalse(InLabels(labels, "Reciprocal"))
     self.assertFalse(InLabels(labels, "Mul"))
-    self.assertTrue(InLabels(labels, "_XlaLaunch"))
+    self.assertTrue(InLabels(labels, "XlaLaunch"))
 
 
 class ElementWiseFusionTest(test.TestCase):
@@ -482,7 +482,7 @@ class ElementWiseFusionTest(test.TestCase):
             trace_level=config_pb2.RunOptions.FULL_TRACE))
 
       labels = RunMetadataLabels(run_metadata)
-      count = sum("_XlaLaunch(" in x for x in labels)
+      count = sum("XlaLaunch(" in x for x in labels)
 
       return output, count
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 621fbc1..bf496bd 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -38,7 +38,7 @@ class XlaContext;
 // It does a symbolic execution of the graph starting from specific input
 // shapes, using a JIT device to convert operators into XLA computations.
 //
-// XlaCompiler is typically invoked from an `_XlaLaunch` operator once the
+// XlaCompiler is typically invoked from an `XlaLaunch` operator once the
 // shapes of all input parameters to the computation are known. This is
 // because the symbolic execution requires known shapes for all operations.
 //
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index e309cb1..4692038 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -39,10 +39,10 @@ const char* const DEVICE_XLA_GPU = "XLA_GPU";
 
 static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) {
   const OpDef* op_def;
-  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def));
+  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("XlaLaunch", &op_def));
   NodeDef node_def;
   node_def.set_name("_XlaLaunch-op");
-  node_def.set_op("_XlaLaunch");
+  node_def.set_op("XlaLaunch");
   string kernel_class_name;
   TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr,
                                    &kernel_class_name));
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 1df4996..ce989f4 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -186,14 +186,14 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
 // primitive op (e.g. matmul).
 //
 // The wrapper function conforms to the function signature expected by
-// _XlaLaunchOp, with input params ordered by <constants, (variable) args and
+// XlaLaunch, with input params ordered by <constants, (variable) args and
 // resources>. For example, if the op has input params <Const1, Arg2, Const3,
 // Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
 // Resource4> as the input params to the synthesized function.
 //
 // It populates `const_input_types`, `arg_input_types` and
 // `op_input_to_func_input` based on the reordering results, that the caller can
-// use them to build an _XlaLaunchOp. On error, it returns NULL, and sets
+// use them to build an XlaLaunch. On error, it returns NULL, and sets
 // `status` accordingly.
 const FunctionDef* OpToFunction(TFE_Op* op,
                                 std::vector<TF_DataType>* const_input_types,
@@ -311,12 +311,12 @@ const FunctionDef* OpToFunction(TFE_Op* op,
   return ret;
 }
 
-// Builds an _XLALaunchOp as a wrapper over 'op', so that 'op' can be executed
+// Builds an XlaLaunch as a wrapper over 'op', so that 'op' can be executed
 // via XLA.
 std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
-  VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->operation.Name();
+  VLOG(1) << "Creating XlaLaunch for TFE_Op " << op->operation.Name();
   auto launch_op = std::unique_ptr<TFE_Op>(
-      TFE_NewOp(op->operation.ctx, "_XlaLaunch", status));
+      TFE_NewOp(op->operation.ctx, "XlaLaunch", status));
   if (TF_GetCode(status) != TF_OK) return nullptr;
   if (op->operation.device) {
     TFE_OpSetDevice(launch_op.get(), op->operation.device->name().c_str(),
@@ -331,7 +331,7 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
   gtl::FlatMap<int, int> op_input_to_func_input;
   if (fdef == nullptr) {
     // See if this is a primitive op, and if so create a function for it, so
-    // that _XlaLaunchOp can access it.
+    // that XlaLaunch can access it.
     fdef = OpToFunction(op, &const_input_types, &arg_input_types,
                         &op_input_to_func_input, status);
     if (!status.ok()) return nullptr;
@@ -423,7 +423,7 @@ Status EagerLocalExecute(EagerOperation* op,
   if (!status.ok()) return status;
 #ifdef TENSORFLOW_EAGER_USE_XLA
   std::unique_ptr<TFE_Op> xla_launch_op;
-  if (op->UseXla() && op->Name() != "_XlaLaunch") {
+  if (op->UseXla() && op->Name() != "XlaLaunch") {
     xla_launch_op = BuildXlaLaunch(op, status);
     if (!status.ok()) return status;
     op = xla_launch_op.get();
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index 7b7fd81..200454b 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -126,9 +126,9 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
     return false;
   }
   const std::unordered_set<string> do_not_rewrite_ops{
-      "Assert",      "CheckNumerics",         "_Retval",
-      "_Arg",        "_ParallelConcatUpdate", "_TPUExecute",
-      "_TPUCompile", "ControlTrigger"};
+      "Assert",     "CheckNumerics",         "_Retval",
+      "_Arg",       "_ParallelConcatUpdate", "TPUExecute",
+      "TPUCompile", "ControlTrigger"};
   if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
     return false;
   }
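
Note: the eager path above constructs the wrapper op by name at runtime, so the lookup string must match the renamed registration. A minimal sketch over the C API follows; `WrapInXlaLaunch` is a hypothetical name, and the attribute/input plumbing that OpToFunction and BuildXlaLaunch perform is elided.

#include "tensorflow/c/eager/c_api.h"

// Hypothetical sketch: creates the wrapper op and places it on a device.
// Setting Tconstants/Targs/Tresults, the "function" attr, and reordering
// inputs into <constants, args, resources> order are all elided here.
TFE_Op* WrapInXlaLaunch(TFE_Context* ctx, const char* device_name,
                        TF_Status* status) {
  TFE_Op* launch = TFE_NewOp(ctx, "XlaLaunch", status);  // was "_XlaLaunch"
  if (TF_GetCode(status) != TF_OK) return nullptr;
  if (device_name != nullptr) {
    TFE_OpSetDevice(launch, device_name, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  return launch;
}
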
diff --git a/tensorflow/docs_src/performance/xla/jit.md b/tensorflow/docs_src/performance/xla/jit.md
index d9a979c..6724d1e 100644
--- a/tensorflow/docs_src/performance/xla/jit.md
+++ b/tensorflow/docs_src/performance/xla/jit.md
@@ -137,12 +137,12 @@ TF_XLA_FLAGS=--xla_generate_hlo_graph=.* python mnist_softmax_xla.py
 ```
 
 Open the timeline file created (`timeline.ctf.json`). The rendered timeline
-should look similar to the picture below with one long bar labeled `_XlaLaunch`.
+should look similar to the picture below with one long bar labeled `XlaLaunch`.
 
 <div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
   <img style="width:100%" src="https://www.tensorflow.org/images/jit_timeline_gpu.png">
 </div>
 
-To understand what is happening in `_XlaLaunch`, look at the console output for
+To understand what is happening in `XlaLaunch`, look at the console output for
 statements similar to the following:
 
 ```shell
-- 
2.7.4