Normally tf2xla (autoclustering, jit_scope and rewrite) rely on graph optimization
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 17 Mar 2018 18:21:02 +0000 (11:21 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 17 Mar 2018 18:25:08 +0000 (11:25 -0700)
passes to outline subgraphs. The XLA device itself only sees Compute() calls for
_XlaLaunch ops. All other ops are registered with a dummy op factory that just
prints an error.

This patch adds an alternative, selected at registration time, that disables
default graph optimization and instead registers a non-dummy op implementation.

This op implementation compiles the op "on demand"; it generates a fake graph containing
_Arg and _Retval nodes and calls into the XlaCompiler code as usual.

This allows the device to be used as a "normal" TensorFlow device, as well as from
Eager mode, at the expense of performance.

Later additions will add the ability to create traces to amortize kernel launch overhead,
and the ability to combine op-by-op/tracing and autoclustering with jit_scope annotations.

PiperOrigin-RevId: 189463593

23 files changed:
tensorflow/compiler/jit/BUILD
tensorflow/compiler/jit/kernels/xla_launch_op.cc
tensorflow/compiler/jit/legacy_flags/BUILD
tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc [new file with mode: 0644]
tensorflow/compiler/jit/legacy_flags/xla_device_flags.h [new file with mode: 0644]
tensorflow/compiler/jit/xla_compilation_cache.cc
tensorflow/compiler/jit/xla_compilation_cache.h
tensorflow/compiler/jit/xla_compile_on_demand_op.cc [new file with mode: 0644]
tensorflow/compiler/jit/xla_compile_on_demand_op.h [new file with mode: 0644]
tensorflow/compiler/jit/xla_cpu_device.cc
tensorflow/compiler/jit/xla_device.cc
tensorflow/compiler/jit/xla_device.h
tensorflow/compiler/jit/xla_device_context.cc
tensorflow/compiler/jit/xla_gpu_device.cc
tensorflow/compiler/jit/xla_launch_util.cc
tensorflow/compiler/jit/xla_tensor_info.h
tensorflow/compiler/tests/BUILD
tensorflow/compiler/tests/xla_test.py
tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
tensorflow/compiler/tf2xla/xla_compiler.cc
tensorflow/compiler/tf2xla/xla_compiler.h

index 39eb390..0475cd9 100644 (file)
@@ -76,6 +76,7 @@ cc_library(
         ":jit_compilation_passes",
         ":xla_device",
         "//tensorflow/compiler/jit/kernels:xla_launch_op",
+        "//tensorflow/compiler/jit/legacy_flags:xla_device_flags",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:cpu_plugin",  # buildcleaner: keep
@@ -136,11 +137,13 @@ cc_library(
 cc_library(
     name = "xla_device",
     srcs = [
+        "xla_compile_on_demand_op.cc",
         "xla_device.cc",
         "xla_device_context.cc",
         "xla_device_ops.cc",
     ],
     hdrs = [
+        "xla_compile_on_demand_op.h",
         "xla_device.h",
         "xla_device_context.h",
         "xla_device_ops.h",
index e24a9a0..8a8e8bb 100644 (file)
@@ -148,7 +148,11 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   const XlaCompiler::CompilationResult* kernel;
   xla::LocalExecutable* executable;
 
-  OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_,
+  std::map<int, Tensor> constant_args;
+  for (int i = 0; i < num_constant_args_; ++i) {
+    constant_args.insert({i, ctx->input(i)});
+  }
+  OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
                                      variables, ctx, &kernel, &executable,
                                      /*compile_options=*/nullptr));
 
index 4491dd6..9cd66fc 100644 (file)
@@ -52,6 +52,18 @@ cc_library(
         ],
 )
 
+cc_library(
+    name = "xla_device_flags",
+    srcs = ["xla_device_flags.cc"],
+    hdrs = ["xla_device_flags.h"],
+    deps =
+        [
+            "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
+            "//tensorflow/core:framework_internal",
+            "//tensorflow/core:lib",
+        ],
+)
+
 # -----------------------------------------------------------------------------
 
 filegroup(
diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc
new file mode 100644 (file)
index 0000000..1bb2fce
--- /dev/null
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for the XLA bridge's xla_device module.
+
+#include <mutex>
+#include <vector>
+
+#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static XlaDeviceFlags* flags;
+static std::vector<Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags.  Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+  flags = new XlaDeviceFlags;
+  flags->tf_xla_compile_on_demand = false;
+  flag_list = new std::vector<Flag>({
+      Flag("tf_xla_compile_on_demand", &flags->tf_xla_compile_on_demand,
+           "Switch a device into 'on-demand' mode, where instead of "
+           "autoclustering ops are compiled one by one just-in-time."),
+  });
+  xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
+}
+
+// Return a pointer to the XlaDeviceFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+XlaDeviceFlags* GetXlaDeviceFlags() {
+  std::call_once(flags_init, &AllocateFlags);
+  return flags;
+}
+
+}  // namespace legacy_flags
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h
new file mode 100644 (file)
index 0000000..27b2212
--- /dev/null
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_
+#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_
+
+// Legacy flags for the XLA bridge's xla_device module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// The values of flags associated with the XLA bridge's
+// xla_device module.
+typedef struct {
+  // Switch the CPU device into "on-demand" mode, where instead of
+  // autoclustering ops are compiled one by one just-in-time.
+  // Enabling this mode by a legacy flag is a temporary mechanism. When this
+  // feature is battle-tested, we will switch this to be a session option.
+  bool tf_xla_compile_on_demand;
+} XlaDeviceFlags;
+
+// Return a pointer to the XlaDeviceFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+XlaDeviceFlags* GetXlaDeviceFlags();
+
+}  // namespace legacy_flags
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_
index 8cc79a9..6430975 100644 (file)
@@ -92,39 +92,30 @@ uint64 XlaCompilationCache::Signature::Hash::operator()(
 }
 
 Status XlaCompilationCache::BuildSignature(
-    const NameAttrList& function, int num_constant_args,
+    const NameAttrList& function, const std::map<int, Tensor>& constant_args,
     const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
     Signature* signature) {
   signature->name = Canonicalize(function.name(), AttrSlice(&function.attr()));
-  signature->arg_values.resize(num_constant_args);
-
-  signature->arg_types.reserve(ctx->num_inputs() - num_constant_args);
-
-  // Inputs are in the order: constants, non-constants, resource variables.
-  int input_num = 0;
-  // Use the values of compile time constants in the signature->
-  while (input_num < num_constant_args) {
-    signature->arg_values[input_num] = ctx->input(input_num);
-    ++input_num;
-  }
-  // Add the types and shapes of the remaining arguments.
-  while (input_num < ctx->num_inputs() - variable_args.size()) {
-    signature->arg_types.emplace_back(ctx->input_dtype(input_num),
-                                      ctx->input(input_num).shape());
-    ++input_num;
-  }
-  // For variable signatures, use the type and shape of the variable's
-  // current value.
-  for (auto& iterator : variable_args) {
-    const OptionalTensor& variable = iterator.second;
-    TF_RET_CHECK(input_num < ctx->num_inputs());
-    if (variable.present) {
-      signature->arg_types.emplace_back(variable.value.dtype(),
-                                        variable.value.shape());
+  signature->arg_values.reserve(constant_args.size());
+
+  signature->arg_types.reserve(ctx->num_inputs() - constant_args.size());
+
+  for (int i = 0; i < ctx->num_inputs(); ++i) {
+    if (constant_args.count(i) > 0) {
+      // Use the values of compile time constants in the signature.
+      signature->arg_values.push_back(constant_args.at(i));
+    } else if (variable_args.count(i) > 0) {
+      const OptionalTensor& variable = variable_args.at(i);
+      if (variable.present) {
+        signature->arg_types.emplace_back(variable.value.dtype(),
+                                          variable.value.shape());
+      } else {
+        signature->arg_types.emplace_back(DT_INVALID, TensorShape());
+      }
     } else {
-      signature->arg_types.emplace_back(DT_INVALID, TensorShape());
+      signature->arg_types.emplace_back(ctx->input_dtype(i),
+                                        ctx->input(i).shape());
     }
-    ++input_num;
   }
   return Status::OK();
 }
@@ -132,74 +123,58 @@ Status XlaCompilationCache::BuildSignature(
 namespace {
 
 // Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch
-// op. The first `num_constant_args` arguments must be host-memory Tensors.
-Status BuildArguments(int num_constant_args,
+// op.
+Status BuildArguments(const std::map<int, Tensor>& constant_args,
                       const std::map<int, OptionalTensor>& variable_args,
                       OpKernelContext* ctx,
                       std::vector<XlaCompiler::Argument>* args) {
   args->resize(ctx->num_inputs());
 
-  int input_num = 0;
-
-  // Handles compile-time constants.
-  TF_RET_CHECK(num_constant_args <= ctx->num_inputs());
-  while (input_num < num_constant_args) {
-    const Tensor& input = ctx->input(input_num);
-    TF_RET_CHECK(input.dtype() != DT_RESOURCE);
-    XlaCompiler::Argument& arg = (*args)[input_num];
-    arg.kind = XlaCompiler::Argument::kConstant;
-    arg.type = input.dtype();
-    arg.shape = input.shape();
-    arg.constant_value = input;
-    ++input_num;
-  }
-
-  // Handles the non-constant arguments.
-  int num_variable_args = variable_args.size();
-  int num_nonconst_args =
-      ctx->num_inputs() - num_variable_args - num_constant_args;
-  TF_RET_CHECK(num_nonconst_args >= 0);
-  while (input_num < num_constant_args + num_nonconst_args) {
-    const Tensor& input = ctx->input(input_num);
-    TF_RET_CHECK(input.dtype() != DT_RESOURCE);
+  for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
     XlaCompiler::Argument& arg = (*args)[input_num];
-    if (input.NumElements() > 0) {
-      arg.kind = XlaCompiler::Argument::kParameter;
-    } else {
+    if (constant_args.count(input_num) > 0) {
+      // Handles compile-time constants.
+      const Tensor& input = constant_args.at(input_num);
+      TF_RET_CHECK(input.dtype() != DT_RESOURCE);
       arg.kind = XlaCompiler::Argument::kConstant;
+      arg.type = input.dtype();
+      arg.shape = input.shape();
       arg.constant_value = input;
-    }
-    arg.type = input.dtype();
-    arg.shape = input.shape();
-    ++input_num;
-  }
-
-  // Handles resource variables.
-  TF_RET_CHECK(input_num + num_variable_args == ctx->num_inputs());
-  for (auto& iterator : variable_args) {
-    const Tensor& input = ctx->input(input_num);
-    TF_RET_CHECK(input.dtype() == DT_RESOURCE);
-
-    XlaCompiler::Argument& arg = (*args)[input_num];
-
-    arg.name = iterator.second.name;
-    arg.kind = XlaCompiler::Argument::kResource;
-    arg.resource_kind = XlaResource::kVariable;
-    if (iterator.second.present) {
-      const Tensor& value = iterator.second.value;
-      arg.type = value.dtype();
-      arg.shape = value.shape();
-      arg.initialized = true;
+    } else if (variable_args.count(input_num) == 0) {
+      // Handles the non-constant arguments.
+      const Tensor& input = ctx->input(input_num);
+      TF_RET_CHECK(input.dtype() != DT_RESOURCE);
+      if (input.NumElements() > 0) {
+        arg.kind = XlaCompiler::Argument::kParameter;
+      } else {
+        arg.kind = XlaCompiler::Argument::kConstant;
+        arg.constant_value = input;
+      }
+      arg.type = input.dtype();
+      arg.shape = input.shape();
     } else {
-      // The values of uninitialized variables are not passed as inputs, since
-      // they are meaningless. However, it is legal to assign to a resource
-      // variable for the first time inside the XLA computation, so we do permit
-      // uninitialized variables.
-      arg.initialized = false;
-      arg.type = DT_INVALID;
-      arg.shape = TensorShape();
+      // Handles resource variables.
+      const Tensor& input = ctx->input(input_num);
+      TF_RET_CHECK(input.dtype() == DT_RESOURCE);
+      const OptionalTensor& variable = variable_args.at(input_num);
+      arg.name = variable.name;
+      arg.kind = XlaCompiler::Argument::kResource;
+      arg.resource_kind = XlaResource::kVariable;
+      if (variable.present) {
+        const Tensor& value = variable.value;
+        arg.type = value.dtype();
+        arg.shape = value.shape();
+        arg.initialized = true;
+      } else {
+        // The values of uninitialized variables are not passed as inputs, since
+        // they are meaningless. However, it is legal to assign to a resource
+        // variable for the first time inside the XLA computation, so we do
+        // permit uninitialized variables.
+        arg.initialized = false;
+        arg.type = DT_INVALID;
+        arg.shape = TensorShape();
+      }
     }
-    ++input_num;
   }
 
   return Status::OK();
@@ -234,16 +209,43 @@ Status XlaCompilationCache::BuildExecutable(
 
 Status XlaCompilationCache::Compile(
     const XlaCompiler::Options& options, const NameAttrList& function,
-    int num_constant_args, const std::map<int, OptionalTensor>& variable_args,
-    OpKernelContext* ctx,
+    const std::map<int, Tensor>& constant_args,
+    const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
     const XlaCompiler::CompilationResult** compilation_result,
     xla::LocalExecutable** executable,
     const XlaCompiler::CompileOptions* compile_options) {
+  return CompileImpl(options, function, constant_args, variable_args, ctx,
+                     compilation_result, executable, compile_options, false);
+}
+
+Status XlaCompilationCache::CompileSingleOp(
+    const XlaCompiler::Options& options,
+    const std::map<int, Tensor>& constant_args,
+    const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult** compilation_result,
+    xla::LocalExecutable** executable,
+    const XlaCompiler::CompileOptions* compile_options) {
+  const NodeDef& def = ctx->op_kernel().def();
+  NameAttrList name;
+  name.set_name(def.op());
+  *name.mutable_attr() = def.attr();
+  return CompileImpl(options, name, constant_args, variable_args, ctx,
+                     compilation_result, executable, compile_options, true);
+}
+
+Status XlaCompilationCache::CompileImpl(
+    const XlaCompiler::Options& options, const NameAttrList& function,
+    const std::map<int, Tensor>& constant_args,
+    const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
+    const XlaCompiler::CompilationResult** compilation_result,
+    xla::LocalExecutable** executable,
+    const XlaCompiler::CompileOptions* compile_options,
+    bool compile_single_op) {
   VLOG(1) << "XlaCompilationCache::Compile " << DebugString();
 
   if (VLOG_IS_ON(2)) {
     VLOG(2) << "num_inputs=" << ctx->num_inputs()
-            << " num_constant_args=" << num_constant_args
+            << " num_constant_args=" << constant_args.size()
             << " num_variable_args=" << variable_args.size();
     for (int i = 0; i < ctx->num_inputs(); i++) {
       TensorShape shape = ctx->input(i).shape();
@@ -264,11 +266,12 @@ Status XlaCompilationCache::Compile(
     }
   }
 
-  TF_RET_CHECK(num_constant_args + variable_args.size() <= ctx->num_inputs());
+  TF_RET_CHECK(constant_args.size() + variable_args.size() <=
+               ctx->num_inputs());
 
   Signature signature;
-  TF_RETURN_IF_ERROR(BuildSignature(function, num_constant_args, variable_args,
-                                    ctx, &signature));
+  TF_RETURN_IF_ERROR(
+      BuildSignature(function, constant_args, variable_args, ctx, &signature));
 
   VLOG(2) << "Signature: " << SignatureDebugString(signature);
   // The outer lock protects the existence of the cache entry. It does not
@@ -295,13 +298,20 @@ Status XlaCompilationCache::Compile(
     // a long time.)
     std::vector<XlaCompiler::Argument> args;
     TF_RETURN_IF_ERROR(
-        BuildArguments(num_constant_args, variable_args, ctx, &args));
+        BuildArguments(constant_args, variable_args, ctx, &args));
 
     XlaCompiler compiler(options);
     entry->compiled = true;
-    entry->compilation_status = compiler.CompileFunction(
-        compile_options ? *compile_options : XlaCompiler::CompileOptions(),
-        function, args, &entry->compilation_result);
+
+    if (compile_single_op) {
+      entry->compilation_status = compiler.CompileSingleOp(
+          compile_options ? *compile_options : XlaCompiler::CompileOptions(),
+          signature.name, ctx, args, &entry->compilation_result);
+    } else {
+      entry->compilation_status = compiler.CompileFunction(
+          compile_options ? *compile_options : XlaCompiler::CompileOptions(),
+          function, args, &entry->compilation_result);
+    }
   }
   *compilation_result = &entry->compilation_result;
   if (entry->compilation_status.ok() && executable) {
index d506378..5c0c79b 100644 (file)
@@ -52,8 +52,8 @@ class XlaCompilationCache : public ResourceBase {
   // Compiles a function into a XlaCompiler::CompilationResult that can be used
   // to execute an XLA Computation. Compilation results are cached.
   // `function` is the name of a Tensorflow function to compile.
-  // `num_constant_args` is the number of compile-time constant arguments to
-  // `function`. `variable_args` is a snapshot of the current values of the
+  // `constant_args` is a maps of tensorflow argument number to constant value.
+  // `variable_args` is a snapshot of the current values of the
   // resource variable arguments to `function`; uninitialized variables are
   // represented by an absent OptionalTensor.
   // The result of compilation is written to `*compilation_result`, which must
@@ -62,19 +62,40 @@ class XlaCompilationCache : public ResourceBase {
   // executable pointer may be null if the computation has no non-constant
   // outputs.
   Status Compile(const XlaCompiler::Options& options,
-                 const NameAttrList& function, int num_constant_args,
+                 const NameAttrList& function,
+                 const std::map<int, Tensor>& constant_args,
                  const std::map<int, OptionalTensor>& variable_args,
                  OpKernelContext* ctx,
                  const XlaCompiler::CompilationResult** compilation_result,
                  xla::LocalExecutable** executable,
                  const XlaCompiler::CompileOptions* compile_options);
 
+  // As above, but calls XlaCompiler::CompileSingleOp instead of
+  // XlaCompiler::CompileFunction.
+  Status CompileSingleOp(
+      const XlaCompiler::Options& options,
+      const std::map<int, Tensor>& constant_args,
+      const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
+      const XlaCompiler::CompilationResult** compilation_result,
+      xla::LocalExecutable** executable,
+      const XlaCompiler::CompileOptions* compile_options);
+
   xla::LocalClient* client() const { return client_; }
   const DeviceType& device_type() const { return device_type_; }
 
   string DebugString() override;
 
  private:
+  // Common implementation of Compile and CompileSingleOp.
+  Status CompileImpl(const XlaCompiler::Options& options,
+                     const NameAttrList& function,
+                     const std::map<int, Tensor>& constant_args,
+                     const std::map<int, OptionalTensor>& variable_args,
+                     OpKernelContext* ctx,
+                     const XlaCompiler::CompilationResult** compilation_result,
+                     xla::LocalExecutable** executable,
+                     const XlaCompiler::CompileOptions* compile_options,
+                     bool compile_single_op);
   // Takes `result` which has been compiled from a Tensorflow subgraph to a
   // XLA computation already, and generates an XLA LocalExecutable `executable`.
   Status BuildExecutable(const XlaCompiler::Options& options,
@@ -104,7 +125,8 @@ class XlaCompilationCache : public ResourceBase {
   static string SignatureDebugString(const Signature& sig);
 
   // Builds the signature for a compilation.
-  Status BuildSignature(const NameAttrList& function, int num_constant_args,
+  Status BuildSignature(const NameAttrList& function,
+                        const std::map<int, Tensor>& constant_args,
                         const std::map<int, OptionalTensor>& variable_args,
                         OpKernelContext* ctx, Signature* signature);
 
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
new file mode 100644 (file)
index 0000000..915b9ce
--- /dev/null
@@ -0,0 +1,178 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Defines the XlaCompileOnDemandOp.
+
+#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+
+namespace {
+std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
+  std::map<int, OptionalTensor> variables;
+  for (int64 i = 0; i < ctx->num_inputs(); ++i) {
+    if (ctx->input(i).dtype() == DT_RESOURCE) {
+      Var* variable = nullptr;
+      ResourceHandle handle = HandleFromInput(ctx, i);
+      OptionalTensor& optional = variables[i];
+      optional.name = handle.name();
+      if (LookupResource(ctx, handle, &variable).ok()) {
+        tf_shared_lock lock(*variable->mu());
+        optional.present = true;
+        optional.value = *variable->tensor();
+      }
+    }
+  }
+  return variables;
+}
+}  // namespace
+
+Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
+                                 const XlaDevice::Metadata& metadata,
+                                 const XlaCompiler::CompilationResult* result,
+                                 xla::LocalExecutable* executable) {
+  std::map<int, OptionalTensor> variables = GetVariables(ctx);
+  int64 num_resource_args = variables.size();
+
+  xla::LocalClient* client = metadata.client();
+  XlaTensorInfoManager* tensor_info_manager = &metadata.tensor_info_manager();
+
+  // Builds an XLA allocator for the device.
+  XlaAllocator xla_allocator(client->platform(), ctx);
+  XlaComputationLaunchContext launch_context(
+      num_resource_args, client, &xla_allocator, tensor_info_manager);
+
+  launch_context.PopulateInputs(ctx, result, variables);
+
+  perftools::gputools::Stream* stream =
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+  TF_RET_CHECK(stream);
+
+  VLOG(2) << "Executing computation.";
+  xla::ExecutableRunOptions run_options;
+  run_options.set_stream(stream);
+  run_options.set_allocator(&xla_allocator);
+  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+
+  auto run_result = executable->Run(launch_context.arguments(), run_options);
+  TF_RETURN_IF_ERROR(run_result.status());
+
+  launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie());
+  return Status::OK();
+}
+
+bool XlaCompileOnDemandOp::MustArgumentBeConstant(const OpKernel* op_kernel,
+                                                  int64 argument_idx) {
+  // TODO(jmolloy): This could be expensive, so memoize.
+  auto* constant_inputs = tensorflow::XlaOpRegistry::CompileTimeConstantInputs(
+      op_kernel->def().op());
+  CHECK(constant_inputs);
+  std::set<int64> constant_input_indices;
+  for (const auto& name : *constant_inputs) {
+    int start, stop;
+    TF_CHECK_OK(op_kernel->InputRange(name, &start, &stop));
+    for (int i = start; i < stop; ++i) {
+      constant_input_indices.insert(i);
+    }
+  }
+  return constant_input_indices.count(argument_idx) > 0;
+}
+
+bool XlaCompileOnDemandOp::ShouldArgumentBeConstant(const OpKernel* op_kernel,
+                                                    int64 argument_idx) {
+  // Right now we only create kConstant arguments when absolutely required, but
+  // there may be benefit in eagerly constant-folding a larger subset of
+  // arguments in the future.
+  return MustArgumentBeConstant(op_kernel, argument_idx);
+}
+
+Status XlaCompileOnDemandOp::Compile(
+    OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
+    const XlaCompiler::CompilationResult** result,
+    xla::LocalExecutable** executable) {
+  XlaTensorInfoManager* tensor_info_manager = &metadata.tensor_info_manager();
+
+  std::map<int, Tensor> constant_arguments;
+  for (int64 i = 0; i < ctx->num_inputs(); ++i) {
+    const Tensor& device_tensor = ctx->input(i);
+    if (const XlaTensorInfo* tensor_info =
+            tensor_info_manager->GetTensorInfo(device_tensor)) {
+      if (tensor_info->has_host_tensor() &&
+          ShouldArgumentBeConstant(&ctx->op_kernel(), i)) {
+        constant_arguments[i] = tensor_info->host_tensor();
+      }
+    }
+    if (constant_arguments.count(i) == 0 &&
+        MustArgumentBeConstant(&ctx->op_kernel(), i)) {
+      // Slow path; the argument is not available as a host constant so we must
+      // fetch it synchronously.
+      Tensor host_tensor;
+      TF_RETURN_IF_ERROR(ctx->allocate_temp(
+          device_tensor.dtype(), device_tensor.shape(), &host_tensor));
+      Notification n;
+      ctx->op_device_context()->CopyDeviceTensorToCPU(
+          &device_tensor, "ConstantArgument",
+          reinterpret_cast<Device*>(ctx->device()), &host_tensor,
+          [&](Status status) { n.Notify(); });
+      n.WaitForNotification();
+      constant_arguments[i] = host_tensor;
+    }
+  }
+
+  // We store information about the JIT-compiled XLA computation
+  // in the ResourceMgr.
+  ResourceMgr* rm = ctx->resource_manager();
+  CHECK(rm);
+
+  XlaCompilationCache* cache;
+  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
+      rm->default_container(), "xla_cache", &cache,
+      [&](XlaCompilationCache** cache) {
+        *cache = new XlaCompilationCache(metadata.client(),
+                                         metadata.jit_device_type());
+        return Status::OK();
+      }));
+  // Hold the reference to the JIT during evaluation. (We could probably
+  // free it sooner because the ResourceMgr will retain a reference, but
+  // this is more obviously correct.)
+  core::ScopedUnref cache_ref(cache);
+
+  XlaCompiler::Options options;
+  DeviceType device_type = metadata.jit_device_type();
+  options.device_type = &device_type;
+  options.client = metadata.client();
+  options.flib_def =
+      new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{});
+
+  std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
+  return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
+                                result, executable,
+                                /*compile_options=*/nullptr);
+}
+
+void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
+  const XlaCompiler::CompilationResult* result;
+  xla::LocalExecutable* executable;
+  const XlaDevice::Metadata* metadata;
+  OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
+  OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable));
+  OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable));
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h
new file mode 100644 (file)
index 0000000..23c6f39
--- /dev/null
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// The XlaCompileOnDemandOp is an OpKernel that, when its Compute method is
+// called, will generate an xla::Computation and run it asynchronously.
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_
+
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// An OpKernel that compiles an op to an XLA computation and runs it. Unlike
+// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a
+// vanilla TensorFlow op as long as the bridge supports it.
+//
+// Importantly _XlaLaunch assumes all input and output tensors are on the host,
+// whereas XlacompileOnDemandOp works with tensors in device memory.
+class XlaCompileOnDemandOp : public OpKernel {
+ public:
+  explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  void Compute(OpKernelContext* ctx) override;
+
+ private:
+  XlaCompiler::Argument CreateCompilerArgument(OpKernelContext* ctx, int64 i);
+  bool ShouldArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx);
+  bool MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx);
+  Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
+                 const XlaCompiler::CompilationResult** result,
+                 xla::LocalExecutable** executable);
+  Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
+             const XlaCompiler::CompilationResult* result,
+             xla::LocalExecutable* executable);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_
index db3bf3e..d2dfdee 100644 (file)
@@ -17,6 +17,8 @@ limitations under the License.
 // operators using XLA via the XLA "Host" (CPU) backend.
 
 #include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
+#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
 #include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/compiler/jit/xla_device_ops.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -34,6 +36,15 @@ class XlaCpuDeviceFactory : public DeviceFactory {
 Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
                                           const string& name_prefix,
                                           std::vector<Device*>* devices) {
+  legacy_flags::XlaDeviceFlags* flags = legacy_flags::GetXlaDeviceFlags();
+  bool compile_on_demand = flags->tf_xla_compile_on_demand;
+
+  XlaOpRegistry::DeviceRegistration registration;
+  registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
+  registration.requires_compilation = !compile_on_demand;
+  registration.enable_jit_by_default = false;
+  registration.compile_resource_ops = true;
+
   static XlaDeviceOpRegistrations* registrations =
       RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT);
   (void)registrations;
@@ -41,7 +52,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
   std::unique_ptr<XlaDevice> device;
   TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
                                        DEVICE_CPU_XLA_JIT, options, name_prefix,
-                                       /*register_device_for_compilation=*/true,
+                                       registration,
                                        /*transfer_as_literal=*/false, &device));
   devices->push_back(device.release());
   return Status::OK();
index e4e11d4..82048f5 100644 (file)
@@ -19,6 +19,7 @@ limitations under the License.
 #include <unordered_set>
 
 #include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
 #include "tensorflow/compiler/jit/xla_device_context.h"
 #include "tensorflow/compiler/jit/xla_device_ops.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
@@ -108,21 +109,15 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
 /* static */ Status XlaDevice::Create(
     const string& platform_name, const string& device_name, int device_ordinal,
     const string& jit_device_name, const SessionOptions& options,
-    const string& name_prefix, bool register_device_for_compilation,
+    const string& name_prefix,
+    const XlaOpRegistry::DeviceRegistration& registration,
     bool transfer_as_literal, std::unique_ptr<XlaDevice>* device) {
   VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
           << device_ordinal;
 
-  if (register_device_for_compilation) {
-    // These are no-ops if they have already been done previously for
-    // this device_name/compilation_device_name pair.
-    XlaOpRegistry::DeviceRegistration registration;
-    registration.compilation_device_name = jit_device_name;
-    registration.requires_compilation = true;
-    registration.enable_jit_by_default = false;
-    registration.compile_resource_ops = true;
-    XlaOpRegistry::RegisterCompilationDevice(device_name, registration);
-  }
+  // These are no-ops if they have already been done previously for
+  // this device_name/compilation_device_name pair.
+  XlaOpRegistry::RegisterCompilationDevice(device_name, registration);
 
   auto platform = se::MultiPlatformManager::PlatformWithName(platform_name);
   if (!platform.ok()) {
@@ -306,19 +301,23 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
 
 XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                    const char* jit_device) {
+  // Any op assigned to the device that isn't rewritten by the graph rewriter
+  // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes
+  // it just-in-time.
+  kernel_factory::OpKernelRegistrar::Factory factory =
+      [](OpKernelConstruction* context) -> OpKernel* {
+    return new XlaCompileOnDemandOp(context);
+  };
   XlaOpRegistry::RegisterCompilationKernels();
   XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
-  auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* {
-    return new XlaDeviceDummyOp(context);
-  };
   for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(
            jit_device,
            /*include_compilation_only_kernels=*/false)) {
     KernelDef* def = new KernelDef(*jit_def);
     def->set_device_type(device);
     registrations->op_kernel_registrars.emplace_back(
-        new kernel_factory::OpKernelRegistrar(def, "XlaDeviceDummyOp",
-                                              dummy_factory));
+        new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
+                                              factory));
   }
   return registrations;
 }
index 0f44762..9cd9167 100644 (file)
@@ -27,6 +27,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
 
 #include "tensorflow/compiler/jit/xla_tensor_info.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/local_device.h"
@@ -81,7 +82,7 @@ class XlaDevice : public LocalDevice {
   static Status Create(const string& platform_name, const string& device_name,
                        int device_ordinal, const string& jit_device_name,
                        const SessionOptions& options, const string& name_prefix,
-                       bool register_device_for_compilation,
+                       const XlaOpRegistry::DeviceRegistration& registration,
                        bool transfer_as_literal,
                        std::unique_ptr<XlaDevice>* device);
 
@@ -113,7 +114,7 @@ class XlaDevice : public LocalDevice {
   // Which hardware device in the client's platform this XlaDevice controls.
   const int device_ordinal_;
   // The name of the device that is used to compile Ops for this XlaDevice.
-  const DeviceType& jit_device_name_;
+  DeviceType jit_device_name_;
   // Memory allocator associated with this device.
   Allocator* xla_allocator_;                   // Not owned.
   ::perftools::gputools::Platform* platform_;  // Not owned.
@@ -134,7 +135,7 @@ class XlaDevice : public LocalDevice {
   bool transfer_as_literal_;
 };
 
-// Builds dummy OpKernel registrations on 'device' for the JIT operators
+// Builds OpKernel registrations on 'device' for the JIT operators
 // registered on 'jit_device'. Returns ownership of a XlaDeviceOpRegistrations
 // object that encapsulates the kernel registrations.
 struct XlaDeviceOpRegistrations {
index b57f82f..88f7c15 100644 (file)
@@ -93,6 +93,10 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
       }
     }
 
+    XlaTensorInfo* tensor_info =
+        tensor_info_manager_->GetOrCreateTensorInfo(*device_tensor);
+    tensor_info->set_host_tensor(*cpu_tensor);
+
     done(status);
     return;
   }
index 383ed87..5a1db81 100644 (file)
@@ -34,15 +34,21 @@ class XlaGpuDeviceFactory : public DeviceFactory {
 Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
                                           const string& name_prefix,
                                           std::vector<Device*>* devices) {
+  XlaOpRegistry::DeviceRegistration registration;
+  registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
+  registration.requires_compilation = true;
+  registration.enable_jit_by_default = false;
+  registration.compile_resource_ops = true;
+
   static XlaDeviceOpRegistrations* registrations =
       RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
   (void)registrations;
 
   std::unique_ptr<XlaDevice> device;
-  Status status = XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0,
-                                    DEVICE_GPU_XLA_JIT, options, name_prefix,
-                                    /*register_device_for_compilation=*/true,
-                                    /*transfer_as_literal=*/false, &device);
+  Status status =
+      XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
+                        name_prefix, registration,
+                        /*transfer_as_literal=*/false, &device);
   if (!status.ok()) {
     // Treat failures as non-fatal; there might not be a GPU in the machine.
     VLOG(1) << "Failed to create XLA_GPU device: " << status;
index 689fa32..076cbd2 100644 (file)
@@ -176,21 +176,33 @@ void XlaComputationLaunchContext::PopulateOutputs(
     if (kernel->outputs[i].is_constant) {
       // Output is a constant.
       const Tensor& const_tensor = kernel->outputs[i].constant_value;
+      Tensor* output_tensor;
       const size_t total_bytes = const_tensor.TotalBytes();
       if (stream && total_bytes > 0) {
         // Copy host -> device. (Empty tensors don't have backing buffers.)
         VLOG(1) << "Constant output tensor on device";
-        Tensor* output_tensor;
+
         TF_CHECK_OK(
             ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
 
         const void* src_ptr = DMAHelper::base(&const_tensor);
         void* dst_ptr = DMAHelper::base(output_tensor);
         gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
+        // Memcpying asynchronously is safe for the GPU, but the CPU uses a
+        // shared allocator so hold a reference to the copied-to buffer until
+        // complete.
+        TensorReference ref(*output_tensor);
         stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
+        stream->ThenDoHostCallback([ref] { ref.Unref(); });
       } else {
         // No copy required.
         ctx->set_output(i, const_tensor);
+        output_tensor = ctx->mutable_output(i);
+      }
+      if (tensor_info_manager_) {
+        XlaTensorInfo* tensor_info =
+            tensor_info_manager_->GetOrCreateTensorInfo(*output_tensor);
+        tensor_info->set_host_tensor(const_tensor);
       }
     } else {
       const TensorShape& shape = kernel->outputs[i].shape;
index 0b0736b..fbd6ad7 100644 (file)
@@ -43,9 +43,25 @@ class XlaTensorInfo {
     shaped_buffer_.reset(new xla::ShapedBuffer(std::move(shaped_buffer)));
   }
 
+  // Some tensors on the device may have known values on the host. We use these
+  // in on-demand mode to avoid re-copying values from the device if we know the
+  // host value already.
+
+  // Return true if this TensorInfo contains a host tensor.
+  bool has_host_tensor() const { return host_tensor_ != nullptr; }
+  // Return the contained host tensor.
+  // REQUIRES: has_host_tensor()
+  const Tensor& host_tensor() const { return *host_tensor_; }
+  // Sets the contained host tensor.
+  void set_host_tensor(const Tensor& tensor) {
+    host_tensor_.reset(new Tensor(tensor));
+  }
+
  private:
   // The optional contained ShapedBuffer.
   std::unique_ptr<xla::ShapedBuffer> shaped_buffer_;
+  // An optional host tensor value.
+  std::unique_ptr<Tensor> host_tensor_;
 };
 
 // Manages XlaTensorInfo objects. This class is also an Allocator, so that
index 85a2ada..bbb6089 100644 (file)
@@ -86,7 +86,10 @@ tf_xla_py_test(
     # ArgMax needs CustomCall on CPU, which is not available in normal
     # (not precompiled) TensorFlow. The flag below excludes the CPU
     # backend.
-    disabled_backends = "cpu",
+    disabled_backends = [
+        "cpu",
+        "cpu_ondemand",
+    ],
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
@@ -315,6 +318,8 @@ tf_xla_py_test(
     name = "function_test",
     size = "small",
     srcs = ["function_test.py"],
+    # Functions are not implemented in the on-demand compilation model yet.
+    disabled_backends = "cpu_ondemand",
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
@@ -551,6 +556,8 @@ tf_xla_py_test(
     name = "stack_ops_test",
     size = "small",
     srcs = ["stack_ops_test.py"],
+    # Stack ops are not implemented in the on-demand compilation model yet.
+    disabled_backends = "cpu_ondemand",
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
@@ -577,6 +584,8 @@ tf_xla_py_test(
     name = "tensor_array_ops_test",
     size = "small",
     srcs = ["tensor_array_ops_test.py"],
+    # TensorArray ops are not implemented in the on-demand compilation model yet.
+    disabled_backends = "cpu_ondemand",
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
index cc778f1..e924fe1 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import contextlib
+import os
 import random
 import re
 
@@ -44,6 +45,8 @@ flags.DEFINE_string('test_device', None,
 flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.')
 flags.DEFINE_string('disabled_manifest', None,
                     'Path to a file with a list of tests that should not run.')
+flags.DEFINE_string('tf_xla_flags', None,
+                    'Value to set the TF_XLA_FLAGS environment variable to')
 
 
 class XLATestCase(test.TestCase):
@@ -97,6 +100,8 @@ class XLATestCase(test.TestCase):
       disabled_tests = []
       disabled_method_types = []
       for l in manifest_file.read().splitlines():
+        if not l:
+          continue
         entry = comments_re.sub('', l).strip().split(' ')
         if len(entry) == 1:
           disabled_tests.append(entry[0])
@@ -113,6 +118,9 @@ class XLATestCase(test.TestCase):
             for name in types])
       manifest_file.close()
 
+    if FLAGS.tf_xla_flags is not None:
+      os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags
+
   @property
   def all_tf_types(self):
     name = '{}.{}'.format(type(self).__name__, self._testMethodName)
index cbade79..569950c 100644 (file)
@@ -184,9 +184,7 @@ class BatchToSpaceOp : public XlaOpKernel {
  private:
   int block_size_;
 };
-REGISTER_XLA_OP(Name("BatchToSpace")
-                    .CompileTimeConstInput("crops")
-                    .CompileTimeConstInput("block_shape"),
+REGISTER_XLA_OP(Name("BatchToSpace").CompileTimeConstInput("crops"),
                 BatchToSpaceOp);
 
 }  // namespace
index 80d6df6..498342a 100644 (file)
@@ -83,7 +83,9 @@ class UnsortedSegmentSum : public XlaOpKernel {
   DataType dtype_;
 };
 
-REGISTER_XLA_OP(Name("UnsortedSegmentSum"), UnsortedSegmentSum);
+REGISTER_XLA_OP(
+    Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"),
+    UnsortedSegmentSum);
 
 }  // namespace
 }  // namespace tensorflow
index b10880d..5bb773d 100644 (file)
@@ -239,6 +239,7 @@ class StatelessRandomUniformOp : public XlaOpKernel {
 
 // TODO(phawkins): generalize to non-float, non-int32 seed types.
 REGISTER_XLA_OP(Name("StatelessRandomUniform")
+                    .CompileTimeConstInput("shape")
                     .TypeConstraint("dtype", DT_FLOAT)
                     .TypeConstraint("Tseed", DT_INT32),
                 StatelessRandomUniformOp);
@@ -272,6 +273,7 @@ class StatelessRandomNormalOp : public XlaOpKernel {
 
 // TODO(phawkins): generalize to non-float, non-int32 seed types.
 REGISTER_XLA_OP(Name("StatelessRandomNormal")
+                    .CompileTimeConstInput("shape")
                     .TypeConstraint("dtype", DT_FLOAT)
                     .TypeConstraint("Tseed", DT_INT32),
                 StatelessRandomNormalOp);
index 7cdf4d1..86263d8 100644 (file)
@@ -600,6 +600,48 @@ Status XlaCompiler::BuildArguments(
   return Status::OK();
 }
 
+Status XlaCompiler::CompileSingleOp(
+    const XlaCompiler::CompileOptions& options, string const& name,
+    OpKernelContext* ctx, const std::vector<XlaCompiler::Argument>& args,
+    CompilationResult* result) {
+  // TODO(b/74182462): We implement this by creating a new dummy Graph including
+  // _Arg nodes, and let CompileGraph walk it. This could be optimized.
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+  Status status;
+  // First create the actual node we care about computing.
+  Node* main_node = graph->AddNode(ctx->op_kernel().def(), &status);
+  TF_RETURN_IF_ERROR(status);
+
+  // Create dummy _Arg nodes. Link these to `node` and also via a control
+  // dependency edge to the _SOURCE node.
+  for (int64 i = 0; i < ctx->num_inputs(); ++i) {
+    Node* node;
+    string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_arg");
+    Status status = NodeBuilder(name, "_Arg")
+                        .ControlInput(graph->source_node())
+                        .Attr("T", ctx->input_dtype(i))
+                        .Attr("index", i)
+                        .Finalize(graph.get(), &node);
+    TF_RETURN_IF_ERROR(status);
+    graph->AddEdge(node, 0, main_node, i);
+  }
+
+  // Similarly with return values, create dummy _Retval nodes fed by `node`.
+  for (int64 i = 0; i < ctx->num_outputs(); ++i) {
+    Node* node;
+    string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_retval");
+    Status status = NodeBuilder(name, "_Retval")
+                        .Input(main_node, i)
+                        .Attr("T", ctx->expected_output_dtype(i))
+                        .Attr("index", i)
+                        .Finalize(graph.get(), &node);
+    TF_RETURN_IF_ERROR(status);
+  }
+
+  return CompileGraph(options, name, std::move(graph), args, result);
+}
+
 Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
                                  string const& name,
                                  std::unique_ptr<Graph> graph,
index 5f1c631..a6747bb 100644 (file)
@@ -289,6 +289,14 @@ class XlaCompiler {
                       const std::vector<Argument>& args,
                       CompilationResult* result);
 
+  // Compiles a single Op, given by an OpKernelContext, into an
+  // xla::Computation. Similar to CompileFunction but takes a single Op as
+  // input.
+  Status CompileSingleOp(const CompileOptions& options, string const& name,
+                         OpKernelContext* ctx,
+                         const std::vector<Argument>& args,
+                         CompilationResult* result);
+
   // Returns the shape of the XLA parameter for an argument 'arg'.
   // See the class comment for more details about the argument passing
   // convention.