Make eager functions runnable on TPU
author Igor Ganichev <iga@google.com>
Tue, 8 May 2018 23:43:54 +0000 (16:43 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 00:09:23 +0000 (17:09 -0700)
PiperOrigin-RevId: 195897321

13 files changed:
tensorflow/compiler/jit/BUILD
tensorflow/compiler/jit/create_xla_launch_op.cc
tensorflow/compiler/jit/create_xla_launch_op.h [new file with mode: 0644]
tensorflow/compiler/jit/create_xla_launch_op_test.cc [new file with mode: 0644]
tensorflow/compiler/jit/kernels/xla_launch_op.cc
tensorflow/compiler/jit/kernels/xla_launch_op.h
tensorflow/compiler/jit/xla_compile_on_demand_op.cc
tensorflow/compiler/jit/xla_launch_util.cc
tensorflow/compiler/jit/xla_launch_util.h
tensorflow/compiler/tests/BUILD
tensorflow/compiler/tests/eager_test.py
tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
tensorflow/python/eager/function.py

index 07136d6..a6b3ce3 100644
@@ -261,6 +261,7 @@ cc_library(
     name = "create_xla_launch_op",
     srcs = [
         "create_xla_launch_op.cc",
+        "create_xla_launch_op.h",
     ],
     deps = [
         ":common",
@@ -270,6 +271,29 @@ cc_library(
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/memory",
+    ],
+    alwayslink = 1,
+)
+
+tf_cc_test(
+    name = "create_xla_launch_op_test",
+    srcs = [
+        "create_xla_launch_op.h",
+        "create_xla_launch_op_test.cc",
+    ],
+    deps = [
+        ":create_xla_launch_op",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:session_options",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "@com_google_absl//absl/memory",
     ],
 )
 
index 18d9013..f35e916 100644
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/compiler/jit/create_xla_launch_op.h"
 
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -25,78 +27,189 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
-// Givens a NodeDef 'ndef' and the function library runtime 'flr', if
-// 'ndef' is a call to a compilable function defined in 'flr', returns OK
-// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
-// node. Otherwise, returns a non-OK.
+// Utility which searches for values in a sorted list by scanning over it once.
+// No matter how many times ScanForValue is called, the list is scanned at most
+// once. However, if a call to ScanForValue skips over a value, that value is
+// not revisited in future calls to ScanForValue, so callers must take
+// care to order their calls.
 //
-// This routine is here so that FunctionLibraryRuntime can jit a
-// specific function call as requested.
-Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
-                         std::unique_ptr<OpKernel>* kernel) {
-  bool xla_compile = false;
-  if (!flr->GetFunctionLibraryDefinition()
-           ->GetAttr(ndef, kXlaCompileAttr, &xla_compile)
-           .ok() ||
-      !xla_compile) {
-    // Not marked as _XlaCompile=true.
-    return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
+// Useful for merging multiple sorted lists in O(n) time.
+class SinglePassSearch {
+ public:
+  // Creates a SinglePassSearch object that can be used to search in `values`.
+  // Does not take ownership of `values`. `values` must outlive this.
+  // `values` must be sorted.
+  explicit SinglePassSearch(const std::vector<int>* values)
+      : current_index_(0), values_(values) {}
+
+  // Scans forward in the vector looking for "value", updating the internal
+  // position into the vector.
+  // Returns true iff the vector contains the given value at or after the
+  // current position.
+  // Not thread-safe.
+  bool ScanForValue(int value) {
+    while (current_index_ < values_->size() &&
+           (*values_)[current_index_] <= value) {
+      if ((*values_)[current_index_] == value) {
+        current_index_++;
+        return true;
+      }
+      current_index_++;
+    }
+    return false;
   }
-  // Make sure that kernels have been registered on the JIT device.
-  XlaOpRegistry::RegisterCompilationKernels();
-  if (!IsCompilable(flr, ndef)) {
-    // ndef is calling a function that XLA can't compile.
-    return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
+
+ private:
+  int current_index_;
+  const std::vector<int>* values_;
+};
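A minimal usage sketch (editorial illustration, not part of the change; values are hypothetical), showing why callers must query in ascending order:

  std::vector<int> sorted = {1, 3, 5};
  SinglePassSearch search(&sorted);
  search.ScanForValue(1);  // true
  search.ScanForValue(2);  // false: 2 is absent; the position now rests on 3
  search.ScanForValue(3);  // true
  search.ScanForValue(5);  // true
  // Out-of-order queries lose values: ScanForValue(4) would scan past 3,
  // so a later ScanForValue(3) would return false even though 3 is present.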
+
+Status CompilationRequested(const FunctionLibraryRuntime& flr,
+                            const NodeDef& node_def) {
+  bool xla_compile = false;
+  // Check if op is marked _XlaCompile=true.
+  Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
+      node_def, kXlaCompileAttr, &xla_compile);
+  if (!status.ok() || !xla_compile) {
+    if (VLOG_IS_ON(3)) {
+      if (!status.ok()) {
+        VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
+                << node_def.op() << ". status=" << status.ToString();
+      } else {
+        VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
+      }
+    }
+    return Status(error::INVALID_ARGUMENT, "");
   }
+  return Status::OK();
+}
+
+// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
+// runtime, returns this function's body in `fbody` as well as the indices
+// of its constant and resource arguments.
+// `fbody` is owned by `flr`.
+// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
+// On return, they are sorted in ascending order.
+Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
+                                       const NodeDef& node_def,
+                                       const FunctionBody** fbody,
+                                       std::vector<int>* constant_arg_indices,
+                                       std::vector<int>* resource_arg_indices) {
   FunctionLibraryRuntime::Handle handle;
-  // If ndef is not instantiable, e.g., the function does not exist,
+  // If node_def is not instantiable, e.g., the function does not exist,
   // simply bail out.
   TF_RETURN_IF_ERROR(
-      flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
-  const FunctionBody* fbody = flr->GetFunctionBody(handle);
-  CHECK(fbody);  // Can't be nullptr since we just instantiated it.
-  std::vector<bool> const_args(fbody->arg_types.size());
+      flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
+  *fbody = flr->GetFunctionBody(handle);
+  CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
+  const DataTypeVector& arg_types = (*fbody)->arg_types;
+  std::vector<bool> const_args(arg_types.size());
   // If we can't analyze the const args, bail out.
-  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args));
+  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args));
 
   for (int i = 0; i < const_args.size(); ++i) {
     if (const_args[i]) {
-      // There is a const arg. Bail out.
-      return errors::InvalidArgument("Const arg: ", i, " in ",
-                                     DebugString(fbody->fdef));
+      constant_arg_indices->push_back(i);
+    }
+  }
+
+  // There can be hundreds of resource variables. Reserve space for them.
+  // We don't reserve space for constants above, as they are usually few.
+  resource_arg_indices->reserve(arg_types.size());
+  for (int i = 0; i < arg_types.size(); ++i) {
+    if (arg_types[i] == DT_RESOURCE) {
+      resource_arg_indices->push_back(i);
     }
   }
 
-  NodeDef launch_def;
-  launch_def.set_name(ndef.name());
-  launch_def.set_op("_XlaLaunch");
-  launch_def.set_device(flr->device()->name());
-  AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
-  AddNodeAttr("Nresources", 0, &launch_def);
-  AddNodeAttr("Targs", fbody->arg_types, &launch_def);
-  AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
-  NameAttrList func;
-  func.set_name(ndef.op());
-  *(func.mutable_attr()) = ndef.attr();
-  AddNodeAttr("function", func, &launch_def);
-
-  // TODO(b/32387911): Handles the host memory types across function
-  // calls properly. For now, we assume all inputs and outputs are on
-  // the device memory.
+  return Status::OK();
+}
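For a concrete illustration (editorial, mirroring the XTimesY test function added further down): given a function with signature (x: float, y: resource) and no compile-time constants, this helper would produce

  constant_arg_indices = {};   // no argument is deduced to be a compile-time constant
  resource_arg_indices = {1};  // y has type DT_RESOURCE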
+
+}  // namespace
+
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+                         std::unique_ptr<OpKernel>* kernel) {
+  TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));
+
+  VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString();
+
+  // Make sure that kernels have been registered on the JIT device.
+  XlaOpRegistry::RegisterCompilationKernels();
+  if (!IsCompilable(flr, node_def)) {
+    // node_def is calling a function that XLA can't compile.
+    return errors::InvalidArgument("Not compilable: ",
+                                   node_def.ShortDebugString());
+  }
+
+  // Get function body, constant args, and resource args.
+  const FunctionBody* fbody = nullptr;
+  std::vector<int> constant_arg_indices;
+  std::vector<int> resource_arg_indices;
+  TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
+      flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
+
+  // Set input and output memory types.
   MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
+  // These indices are used only for optimization purposes. They allow us
+  // to loop over constant_arg_indices and resource_arg_indices only once
+  // while iterating over all the function arguments, checking whether each
+  // one is a resource or a constant.
+  // We optimized this code because functions can have a lot of captured
+  // arguments. For example, the backward pass of ResNet50 takes in all
+  // 214 variables and a similar number of activations.
+  SinglePassSearch constants_search(&constant_arg_indices);
+  SinglePassSearch resources_search(&resource_arg_indices);
+  for (int i = 0; i < fbody->arg_types.size(); ++i) {
+    if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
+      // Compile-time constants and resource handles are expected to be in
+      // host memory.
+      input_memory_types[i] = HOST_MEMORY;
+    }
+  }
+  // One might wonder about the case where a compile-time constant argument
+  // (which must be in host memory) is also used as an input to an op,
+  // e.g. Add, that expects its inputs in device memory. Here is how it
+  // works now.
+  // First, what do we mean by "op expects an input in XYZ memory"?
+  // There are two types of "ops" here: the tf2xla kernel and the HLO
+  // computation it builds. The tf2xla kernel needs to retrieve the actual
+  // numeric value of the compile-time constant tensors, so it really expects
+  // them to be in host memory. However, for other inputs, it refers to them
+  // using xla::ComputationDataHandle, which is just a symbolic handle that
+  // xla::ComputationBuilder assigns. How does this handle get assigned for
+  // constant arguments? Even constant arguments get an _Arg node in the graph
+  // instantiated for Function compilation. The tf2xla kernel for constant _Arg
+  // nodes takes the constant value, converts it to an XlaLiteral, and feeds it
+  // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
+  // constant XlaLiteral is included in the HLO graph, and subsequently, in
+  // the actual executable, which is copied to the device before being
+  // executed. Thus, when this executable runs, the constant is available in
+  // device memory.
+
+  // The XlaLaunch kernel keeps all outputs (including constants, which it
+  // copies) in device memory.
   MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
 
+  // Create the kernel.
+  NameAttrList function;
+  function.set_name(node_def.op());
+  *(function.mutable_attr()) = node_def.attr();
+
   Device* dev = flr->device();
   Status s;
   OpKernelConstruction construction(
       DeviceType(dev->device_type()), dev,
-      dev->GetAllocator(AllocatorAttributes()), &launch_def,
+      dev->GetAllocator(AllocatorAttributes()), &node_def,
       &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
       fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
-  kernel->reset(new XlaLocalLaunchOp(&construction));
+
+  *kernel = absl::make_unique<XlaLocalLaunchBase>(
+      &construction, constant_arg_indices, resource_arg_indices, function);
   return s;
 }
 
+namespace {
+
 bool RegisterLaunchOpCreator() {
   RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
   return true;
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h
new file mode 100644
index 0000000..98a22e3
--- /dev/null
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class FunctionLibraryRuntime;
+class OpKernel;
+
+// Given a NodeDef 'node_def' and the function library runtime 'flr', if
+// 'node_def' is a call to a compilable function defined in 'flr', returns OK
+// and fills in 'kernel' with an XlaLaunchOp kernel that computes the node.
+// Otherwise, returns a non-OK status.
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+                         std::unique_ptr<OpKernel>* kernel);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
new file mode 100644
index 0000000..bcd5e75
--- /dev/null
@@ -0,0 +1,145 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+NodeDef ToNodeDef(const string& text) {
+  NodeDef node_def;
+  EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
+  return node_def;
+}
+
+// Creates a FunctionDef that takes one resource and one regular parameter.
+FunctionDef XTimesY() {
+  return FunctionDefHelper::Define(
+      // Name
+      "XTimesY",
+      // Args
+      {"x: float", "y: resource"},
+      // Return values
+      {"z: float"},
+      // Attr def
+      {},
+      // Nodes
+      {
+          {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}},
+          {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}},
+      });
+}
+
+class CreateXlaLaunchOpTest : public ::testing::Test {
+ protected:
+  void Init(const std::vector<FunctionDef>& flib) {
+    SessionOptions options;
+    auto* device_count = options.config.mutable_device_count();
+    device_count->insert({"CPU", 1});
+    TF_CHECK_OK(DeviceFactory::AddDevices(
+        options, "/job:localhost/replica:0/task:0", &devices_));
+
+    FunctionDefLibrary proto;
+    for (const auto& fdef : flib) {
+      *(proto.add_function()) = fdef;
+    }
+    lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
+        OpRegistry::Global(), proto);
+    OptimizerOptions opts;
+    device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
+    pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
+        device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
+        opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
+    flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
+  }
+
+  FunctionLibraryRuntime* flr_;
+  std::vector<Device*> devices_;
+  std::unique_ptr<DeviceMgr> device_mgr_;
+  std::unique_ptr<FunctionLibraryDefinition> lib_def_;
+  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+
+  std::unique_ptr<OpKernel> kernel_;
+};
+
+AttrValue BoolAttr(bool b) {
+  AttrValue v;
+  v.set_b(b);
+  return v;
+}
+
+TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) {
+  FunctionDef fdef = XTimesY();
+  (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
+  Init({fdef});
+
+  Status status = CreateXlaLaunchOp(
+      flr_, ToNodeDef(R"pb(
+        name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
+      )pb"), &kernel_);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  EXPECT_EQ("XTimesY", kernel_->name());
+  EXPECT_EQ("XTimesY", kernel_->type_string());
+
+  EXPECT_EQ(2, kernel_->num_inputs());
+  EXPECT_EQ(DT_FLOAT, kernel_->input_type(0));
+  EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1));
+  EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]);
+  EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]);
+
+  EXPECT_EQ(1, kernel_->num_outputs());
+  EXPECT_EQ(DT_FLOAT, kernel_->output_type(0));
+  EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]);
+}
+
+TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) {
+  FunctionDef fdef = XTimesY();
+  Init({fdef});
+
+  Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
+                                      name: 'XTimesY'
+                                      op: 'XTimesY'
+                                      input: 'a'
+                                      input: 'b'
+                                    )proto"), &kernel_);
+  EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
+}
+
+TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) {
+  FunctionDef fdef = XTimesY();
+  (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
+  Init({fdef});
+
+  Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
+                                      name: 'XTimesY'
+                                      op: 'XTimesY'
+                                      input: 'a'
+                                      input: 'b'
+                                    )proto"), &kernel_);
+  EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
+}
+
+}  // namespace tensorflow
index 049d170..86a9fd3 100644
@@ -39,15 +39,15 @@ limitations under the License.
 
 namespace tensorflow {
 
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
-    : OpKernel(ctx), device_type_(ctx->device_type()) {
-  const NameAttrList* func;
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
-  function_ = *func;
-  DataTypeVector constant_types;
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
-  num_constant_args_ = constant_types.size();
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
+XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
+                                       const std::vector<int>& constants,
+                                       const std::vector<int>& resources,
+                                       const NameAttrList& function)
+    : OpKernel(ctx),
+      constants_(constants),
+      resources_(resources),
+      device_type_(ctx->device_type()),
+      function_(function) {
   if (device_type_ == DeviceType(DEVICE_CPU)) {
     platform_id_ = se::host::kHostPlatformId;
   } else if (device_type_ == DeviceType(DEVICE_GPU)) {
@@ -57,8 +57,8 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
   }
 }
 
-Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
-                                               XlaCompilationCache** cache) {
+Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
+                                                 XlaCompilationCache** cache) {
   const XlaDevice::Metadata* metadata;
   Status s = XlaDevice::GetMetadata(ctx, &metadata);
   if (s.ok()) {
@@ -90,8 +90,8 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
   return Status::OK();
 }
 
-void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
-  VLOG(1) << "XlaLocalLaunchOp::Compute "
+void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "XlaLocalLaunchBase::Compute "
           << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
   // We store information about the JIT-compiled XLA computation
   // in the ResourceMgr.
@@ -124,7 +124,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   }
 
   std::map<int, OptionalTensor> variables =
-      SnapshotResourceVariables(ctx, num_resource_args_);
+      SnapshotResourceVariables(ctx, resources_);
 
   xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
 
@@ -161,7 +161,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   xla::LocalExecutable* executable;
 
   std::map<int, Tensor> constant_args;
-  for (int i = 0; i < num_constant_args_; ++i) {
+  for (int i : constants_) {
     constant_args.insert({i, ctx->input(i)});
   }
   OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
@@ -170,8 +170,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
   VLOG(1) << "Executing XLA Computation...";
 
-  XlaComputationLaunchContext launch_context(
-      num_resource_args_, client, xla_allocator, allocate_xla_tensors);
+  XlaComputationLaunchContext launch_context(client, xla_allocator,
+                                             allocate_xla_tensors);
   launch_context.PopulateInputs(ctx, kernel, variables);
 
   // Execute the computation.
@@ -194,6 +194,62 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "Done";
 }
 
+namespace {
+
+// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that, in the
+// error case, it returns RET instead of void.
+#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
+  do {                                                      \
+    ::tensorflow::Status _s(__VA_ARGS__);                   \
+    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
+      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+      return RET;                                           \
+    }                                                       \
+  } while (0)
+
+// Static helper functions to construct parameters for the
+// XlaLocalLaunchBase constructor from OpKernelConstruction.
+std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
+  DataTypeVector constant_types;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Tconstants", &constant_types));
+  std::vector<int> constants(constant_types.size());
+  std::iota(constants.begin(), constants.end(), 0);
+  return constants;
+}
+
+std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
+  DataTypeVector constant_types;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Tconstants", &constant_types));
+
+  DataTypeVector arg_types;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Targs", &arg_types));
+
+  int num_resources;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Nresources", &num_resources));
+
+  std::vector<int> resources(num_resources);
+  std::iota(resources.begin(), resources.end(),
+            constant_types.size() + arg_types.size());
+  return resources;
+}
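Both helpers rely on the _XlaLaunch convention that inputs are ordered as constants, then regular args, then resources. An editorial illustration with hypothetical attribute sizes:

  // With Tconstants of size 2, Targs of size 3, and Nresources == 2:
  //   ConstantsVector(ctx) -> {0, 1}
  //   ResourcesVector(ctx) -> {5, 6}  // std::iota starting at 2 + 3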
+
+NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
+  const NameAttrList* func;
+  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
+  return *func;
+}
+
+#undef OP_REQUIRES_OK_RETURN
+}  // namespace
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
+                         FunctionAttr(ctx)) {}
+
 XlaLocalLaunchOp::~XlaLocalLaunchOp() {
   VLOG(1) << "XlaLocalLaunchOp destroyed";
 }
index 8f8e646..8dfc4b3 100644
@@ -26,6 +26,41 @@ limitations under the License.
 
 namespace tensorflow {
 
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes the indices of constant and resource arguments explicitly.
+// It does not have a corresponding OpDef because it is never present
+// in the GraphDef.
+// Currently, it is used by the eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+  XlaLocalLaunchBase(OpKernelConstruction* ctx,
+                     const std::vector<int>& constants,
+                     const std::vector<int>& resources,
+                     const NameAttrList& function);
+  XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+  XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+  ~XlaLocalLaunchBase() override = default;
+
+  void Compute(OpKernelContext* ctx) override;
+
+ protected:
+  // Builds a XlaCompilationCache class suitable for the current device.
+  Status BuildCompilationCache(OpKernelContext* ctx,
+                               XlaCompilationCache** cache);
+
+  // Indices of compile-time constant inputs.
+  std::vector<int> constants_;
+  // Indices of resource inputs.
+  std::vector<int> resources_;
+
+  DeviceType device_type_;
+  NameAttrList function_;
+  se::Platform::Id platform_id_;
+};
+
 // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
 // which will be compiled and executed using XLA.  The XlaLocalLaunchOp is
 // responsible for handling interactions with the TensorFlow executor.
@@ -35,26 +70,12 @@ namespace tensorflow {
 // XlaLocalLaunchOp uses xla::LocalClient::Compile() and
 // xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
 // memory.
-class XlaLocalLaunchOp : public OpKernel {
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
  public:
   explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
   ~XlaLocalLaunchOp() override;
 
-  void Compute(OpKernelContext* ctx) override;
-
  private:
-  // Builds a XlaCompilationCache class suitable for the current device.
-  Status BuildCompilationCache(OpKernelContext* ctx,
-                               XlaCompilationCache** compiler);
-
-  DeviceType device_type_;
-  NameAttrList function_;
-  int num_constant_args_;
-  // Number of resource variable arguments.
-  int num_resource_args_;
-
-  se::Platform::Id platform_id_;
-
   TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
 };
 
index 60458f6..6b83cf6 100644
@@ -48,13 +48,12 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
                                  const XlaCompiler::CompilationResult* result,
                                  xla::LocalExecutable* executable) {
   std::map<int, OptionalTensor> variables = GetVariables(ctx);
-  int64 num_resource_args = variables.size();
 
   xla::LocalClient* client = metadata.client();
 
   // Builds an XLA allocator for the device.
   XlaComputationLaunchContext launch_context(
-      num_resource_args, client, client->backend().memory_allocator(), true);
+      client, client->backend().memory_allocator(), true);
 
   launch_context.PopulateInputs(ctx, result, variables);
 
index 33e5361..0223f97 100644
@@ -38,14 +38,13 @@ using xla::ScopedShapedBuffer;
 using xla::ShapedBuffer;
 }  // anonymous namespace
 
-std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
-                                                        int num_variables) {
+std::map<int, OptionalTensor> SnapshotResourceVariables(
+    OpKernelContext* ctx, const std::vector<int>& variables) {
   std::map<int, OptionalTensor> snapshot;
-  int first_variable = ctx->num_inputs() - num_variables;
-  for (int i = 0; i < num_variables; ++i) {
+  for (int i : variables) {
     Var* variable = nullptr;
-    ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
-    OptionalTensor& tensor = snapshot[first_variable + i];
+    ResourceHandle handle = HandleFromInput(ctx, i);
+    OptionalTensor& tensor = snapshot[i];
     if (LookupResource(ctx, handle, &variable).ok()) {
       tf_shared_lock lock(*variable->mu());
       tensor.name = handle.name();
@@ -112,10 +111,9 @@ ScopedShapedBuffer ExtractSubShapedBuffer(
 using internal::ExtractSubShapedBuffer;
 
 XlaComputationLaunchContext::XlaComputationLaunchContext(
-    int64 num_resource_args, xla::LocalClient* client,
-    xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors)
-    : num_resource_args_(num_resource_args),
-      client_(client),
+    xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
+    bool allocate_xla_tensors)
+    : client_(client),
       xla_allocator_(xla_allocator),
       allocate_xla_tensors_(allocate_xla_tensors) {}
 
index 38291b0..a243125 100644
@@ -31,15 +31,17 @@ limitations under the License.
 namespace tensorflow {
 class XlaAllocator;
 
-// Takes a snapshot of the values of resource variable arguments, which are
-// the last `num_variables` arguments. We snapshot tensors that back
+// Takes a snapshot of the values of resource variable arguments, whose
+// indices are specified in the `variables` argument. We snapshot tensors that back
 // resource variables since concurrent updates may modify the shape, and it is
 // important that the shapes used for compilation match the true shapes of the
 // buffers.
 //
-// Returns a map of TensorFlow argument index to resource variable.
-std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
-                                                        int num_variables);
+// Returns a map of TensorFlow argument index to resource variable. If a
+// resource variable is not initialized, the corresponding OptionalTensor
+// will have its `present` field set to false.
+std::map<int, OptionalTensor> SnapshotResourceVariables(
+    OpKernelContext* ctx, const std::vector<int>& variables);
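A short sketch of the new call shape (editorial illustration; the indices and the in-scope `ctx` are hypothetical):

  // With variables == {1, 3}, the returned map has keys 1 and 3; each value
  // snapshots the tensor backing that resource input. An uninitialized
  // variable yields an OptionalTensor whose `present` field is false.
  std::map<int, OptionalTensor> snapshot =
      SnapshotResourceVariables(ctx, /*variables=*/{1, 3});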
 
 // Adapter class that wraps a Tensorflow allocator as an XLA allocator.
 // Assumes that the Tensorflow allocator permits asynchronous deallocation:
@@ -72,7 +74,7 @@ class XlaComputationLaunchContext {
   // Create a new launch context. 'allocate_xla_tensors' is true if allocated
   // output tensors and variables are always XlaTensors. If false they are
   // assumed to be "normal" device pointers.
-  XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client,
+  XlaComputationLaunchContext(xla::LocalClient* client,
                               xla::DeviceMemoryAllocator* xla_allocator,
                               bool allocate_xla_tensors);
 
@@ -92,7 +94,6 @@ class XlaComputationLaunchContext {
   const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
 
  private:
-  int64 num_resource_args_;
   xla::LocalClient* client_;
   xla::DeviceMemoryAllocator* xla_allocator_;
   bool allocate_xla_tensors_;
index aaea83a..9791792 100644
@@ -327,7 +327,11 @@ tf_xla_py_test(
         ":xla_test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:layers",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:nn",
         "//tensorflow/python:platform_test",
+        "//tensorflow/python/eager:function",
     ],
 )
 
index bdd0185..5ab1585 100644
@@ -24,10 +24,16 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.layers import convolutional
+from tensorflow.python.layers import pooling
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import googletest
 
@@ -43,7 +49,7 @@ class EagerTest(XLATestCase):
 
   def testExecuteListOutputLen0(self):
     with self.test_scope():
-      empty = constant_op.constant([], dtype=dtypes.int32)
+      empty = constant_op.constant([], dtype=dtypes.float32)
       result = array_ops.unstack(empty, 0)
       self.assertTrue(isinstance(result, list))
       self.assertEqual(0, len(result))
@@ -51,7 +57,7 @@ class EagerTest(XLATestCase):
   def testExecuteListOutputLen1(self):
     with self.test_scope():
       split_dim = constant_op.constant(1)
-      value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
       result = array_ops.split(value, 1, axis=split_dim)
       self.assertTrue(isinstance(result, list))
       self.assertEqual(1, len(result))
@@ -60,7 +66,7 @@ class EagerTest(XLATestCase):
   def testExecuteListOutputLen3(self):
     with self.test_scope():
       split_dim = constant_op.constant(1)
-      value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
       result = array_ops.split(value, 3, axis=split_dim)
       self.assertTrue(isinstance(result, list))
       self.assertEqual(3, len(result))
@@ -131,7 +137,105 @@ class EagerTest(XLATestCase):
     self.assertEqual(2., grads[0][0].numpy())
 
 
-if __name__ == "__main__":
+class EagerFunctionTest(XLATestCase):
+
+  def testBasic(self):
+    with self.test_scope():
+      matmul = function.defun(math_ops.matmul, compiled=True)
+      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+      sq = matmul(t, t, transpose_a=True)
+      self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
+
+  def testConv(self):
+    if 'GPU' in self.device:
+      # TODO(b/32333178)
+      self.skipTest('Current implementation of RandomStandardNormal kernel '
+                    'is very slow on GPU, and has been blacklisted.')
+    with self.test_scope():
+      data_format = 'channels_last'
+      conv = convolutional.Conv2D(
+          filters=1, kernel_size=2, padding='VALID',
+          data_format=data_format, activation=nn_ops.relu,
+          kernel_initializer=init_ops.ones_initializer(),
+          bias_initializer=init_ops.zeros_initializer())
+      pool = pooling.MaxPooling2D(2, 2, data_format=data_format)
+
+      def model(x):
+        x = conv(x)
+        return pool(x)
+      model = function.defun(model, compiled=True)
+
+      x = array_ops.ones([1, 4, 4, 1])
+      y = model(x)
+      self.assertAllEqual(y.numpy(), [[[[4.]]]])
+
+  def testReadVariable(self):
+    with self.test_scope():
+      v = resource_variable_ops.ResourceVariable(1.0)
+
+      @function.defun(compiled=True)
+      def f():
+        return v.read_value()
+
+      var = f()
+      self.assertEqual(1.0, var.numpy())
+
+  def testUpdateVariable(self):
+    with self.test_scope():
+      v = resource_variable_ops.ResourceVariable(1.0)
+
+      def f(v):
+        v.assign_add(1.0)
+        return v
+
+      f = function.defun(f, compiled=True)
+
+      var = f(v)
+      self.assertEqual(2.0, var.numpy())
+
+  def testAllArgumentKinds(self):
+    """Test a complex function that takes different argument kinds.
+
+    The tf2xla machinery that translates, compiles, and runs defuns
+    classifies arguments into three kinds: compile-time constants, regular
+    tensors, and resources. This test creates a function with a mix of all
+    these kinds. Moreover, the order of function arguments is intentionally
+    mixed up.
+
+    This also tests the case where the same argument is a compile-time
+    constant and is also used in an operation that normally expects its
+    inputs to be in device memory - addition in this case.
+    """
+    with self.test_scope():
+      def foo(c1, r1, v1, c2, v2, r2):
+        # c1 and c2 are compile-time constants
+        # r1 and r2 are regular tensors
+        # v1 and v2 are resource variables
+        a = c1 + r1
+        b = math_ops.cast(c2, dtypes.float32) + v2
+        c = array_ops.slice(v1, c1, c2)
+        d = r2 * v2
+        return a, b, c, d
+
+      foo = function.defun(foo, compiled=True)
+
+      c1 = [0, 0]
+      c2 = array_ops.ones([2], dtype=dtypes.int32)
+
+      r1 = array_ops.ones([2])
+      r2 = [[2., 2.], [3., 3.]]
+
+      v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
+      v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])
+
+      a, b, c, d = foo(c1, r1, v1, c2, v2, r2)
+
+      self.assertAllEqual([1, 1], a.numpy())
+      self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
+      self.assertAllEqual([[1.]], c.numpy())
+      self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())
+
+
+if __name__ == '__main__':
   ops.enable_eager_execution(
       config=config_pb2.ConfigProto(log_device_placement=True))
   googletest.main()
index 8517a3b..b8f352d 100644
@@ -36,9 +36,7 @@ def device_and_data_format():
                                                               'channels_last')
 
 
-def random_batch(batch_size, device_and_format=None):
-  _, data_format = device_and_format or device_and_data_format()
-
+def random_batch(batch_size, data_format):
   shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3)
   shape = (batch_size,) + shape
 
@@ -70,7 +68,7 @@ class ResNet50Test(tf.test.TestCase):
     if defun:
       model.call = tfe.defun(model.call)
     with tf.device(device), tfe.execution_mode(execution_mode):
-      images, _ = random_batch(2)
+      images, _ = random_batch(2, data_format)
       output = model(images, training=False)
       tfe.async_wait()
     self.assertEqual((2, 1000), output.shape)
@@ -91,7 +89,7 @@ class ResNet50Test(tf.test.TestCase):
     device, data_format = device_and_data_format()
     model = resnet50.ResNet50(data_format, include_top=False)
     with tf.device(device):
-      images, _ = random_batch(2)
+      images, _ = random_batch(2, data_format)
       output = model(images, training=False)
     output_shape = ((2, 2048, 1, 1)
                     if data_format == 'channels_first' else (2, 1, 1, 2048))
@@ -101,7 +99,7 @@ class ResNet50Test(tf.test.TestCase):
     device, data_format = device_and_data_format()
     model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
     with tf.device(device):
-      images, _ = random_batch(2)
+      images, _ = random_batch(2, data_format)
       output = model(images, training=False)
     self.assertEqual((2, 2048), output.shape)
 
@@ -115,7 +113,7 @@ class ResNet50Test(tf.test.TestCase):
         name='t0').as_default(), tf.contrib.summary.always_record_summaries():
       with tf.device(device), tfe.execution_mode(execution_mode):
         optimizer = tf.train.GradientDescentOptimizer(0.1)
-        images, labels = random_batch(2)
+        images, labels = random_batch(2, data_format)
         train_one_step(model, images, labels, optimizer)
         self.assertEqual(320, len(model.variables))
         tfe.async_wait()
@@ -134,7 +132,7 @@ class ResNet50Test(tf.test.TestCase):
     model = resnet50.ResNet50(data_format)
     optimizer = tf.train.GradientDescentOptimizer(0.1)
     with tf.device(device):
-      images, labels = random_batch(2)
+      images, labels = random_batch(2, data_format)
       gc.disable()
       # Warm up. Note that this first run does create significant amounts of
       # garbage to be collected. The hope is that this is a build-only effect,
@@ -202,18 +200,18 @@ class ResNet50Benchmarks(tf.test.Benchmark):
     # which forces a sync. This is a roundabout way, yes.
     tf.constant(1.).cpu()
 
-  def _benchmark_eager_apply(self, label, defun=False, execution_mode=None,
-                             device_and_format=None):
+  def _benchmark_eager_apply(self, label, device_and_format, defun=False,
+                             execution_mode=None, compiled=False):
     with tfe.execution_mode(execution_mode):
-      device, data_format = device_and_format or device_and_data_format()
+      device, data_format = device_and_format
       model = resnet50.ResNet50(data_format)
       if defun:
-        model.call = tfe.defun(model.call)
+        model.call = tfe.defun(model.call, compiled=compiled)
       batch_size = 64
       num_burn = 5
       num_iters = 30
       with tf.device(device):
-        images, _ = random_batch(batch_size, device_and_format)
+        images, _ = random_batch(batch_size, data_format)
         for _ in xrange(num_burn):
           model(images, training=False).cpu()
         if execution_mode:
@@ -227,30 +225,34 @@ class ResNet50Benchmarks(tf.test.Benchmark):
         self._report(label, start, num_iters, device, batch_size, data_format)
 
   def benchmark_eager_apply_sync(self):
-    self._benchmark_eager_apply('eager_apply', defun=False)
+    self._benchmark_eager_apply('eager_apply', device_and_data_format(),
+                                defun=False)
 
   def benchmark_eager_apply_async(self):
     self._benchmark_eager_apply(
-        'eager_apply_async', defun=False, execution_mode=tfe.ASYNC)
+        'eager_apply_async', device_and_data_format(), defun=False,
+        execution_mode=tfe.ASYNC)
 
   def benchmark_eager_apply_with_defun(self):
-    self._benchmark_eager_apply('eager_apply_with_defun', defun=True)
+    self._benchmark_eager_apply('eager_apply_with_defun',
+                                device_and_data_format(), defun=True)
 
   def _benchmark_eager_train(self,
                              label,
                              make_iterator,
+                             device_and_format,
                              defun=False,
                              execution_mode=None,
-                             device_and_format=None):
+                             compiled=False):
     with tfe.execution_mode(execution_mode):
-      device, data_format = device_and_format or device_and_data_format()
+      device, data_format = device_and_format
       for batch_size in self._train_batch_sizes():
-        (images, labels) = random_batch(batch_size, device_and_format)
+        (images, labels) = random_batch(batch_size, data_format)
         num_burn = 3
         num_iters = 10
         model = resnet50.ResNet50(data_format)
         if defun:
-          model.call = tfe.defun(model.call)
+          model.call = tfe.defun(model.call, compiled=compiled)
         optimizer = tf.train.GradientDescentOptimizer(0.1)
 
         with tf.device(device):
@@ -273,18 +275,21 @@ class ResNet50Benchmarks(tf.test.Benchmark):
           self._report(label, start, num_iters, device, batch_size, data_format)
 
   def benchmark_eager_train_sync(self):
-    self._benchmark_eager_train('eager_train', MockIterator, defun=False)
+    self._benchmark_eager_train('eager_train', MockIterator,
+                                device_and_data_format(), defun=False)
 
   def benchmark_eager_train_async(self):
     self._benchmark_eager_train(
         'eager_train_async',
         MockIterator,
+        device_and_data_format(),
         defun=False,
         execution_mode=tfe.ASYNC)
 
   def benchmark_eager_train_with_defun(self):
     self._benchmark_eager_train(
-        'eager_train_with_defun', MockIterator, defun=True)
+        'eager_train_with_defun', MockIterator,
+        device_and_data_format(), defun=True)
 
   def benchmark_eager_train_datasets(self):
 
@@ -294,7 +299,8 @@ class ResNet50Benchmarks(tf.test.Benchmark):
       return tfe.Iterator(ds)
 
     self._benchmark_eager_train(
-        'eager_train_dataset', make_iterator, defun=False)
+        'eager_train_dataset', make_iterator,
+        device_and_data_format(), defun=False)
 
   def benchmark_eager_train_datasets_with_defun(self):
 
@@ -304,7 +310,8 @@ class ResNet50Benchmarks(tf.test.Benchmark):
       return tfe.Iterator(ds)
 
     self._benchmark_eager_train(
-        'eager_train_dataset_with_defun', make_iterator, defun=True)
+        'eager_train_dataset_with_defun', make_iterator,
+        device_and_data_format(), defun=True)
 
 
 if __name__ == '__main__':
index 89257bb..b478b6b 100644
@@ -23,6 +23,7 @@ import collections
 
 import numpy as np
 
+from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
@@ -227,7 +228,7 @@ def _inference_name(n):
 class _EagerDefinedFunction(object):
   """Function object with the interface of tf _DefinedFunction."""
 
-  def __init__(self, name, graph, operations, inputs, outputs):
+  def __init__(self, name, graph, operations, inputs, outputs, attrs):
     """Initializes an eager defined function.
 
     Args:
@@ -237,6 +238,7 @@ class _EagerDefinedFunction(object):
         which will be in the function
       inputs: the tensors in the graph to be used as inputs to the function
       outputs: the tensors in the graph which will be outputs to the function
+      attrs: dict mapping names of attributes to their AttrValue values
     """
     fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
         graph._c_graph,  # pylint: disable=protected-access
@@ -248,6 +250,14 @@ class _EagerDefinedFunction(object):
         [],
         None,
         compat.as_str(""))
+
+    for name, attr_value in attrs.items():
+      serialized = attr_value.SerializeToString()
+      # TODO(iga): this creates and deletes a new TF_Status for every attr.
+      # It might be worth creating a convenient way to re-use status.
+      pywrap_tensorflow.TF_FunctionSetAttrValueProto(
+          fn, compat.as_str(name), serialized)
+
     # TODO(apassos) avoid creating a FunctionDef (especially to grab the
     # signature, but also in general it's nice not to depend on it).
     with c_api_util.tf_buffer() as buffer_:
@@ -289,25 +299,6 @@ def _flatten(sequence):
 
 class GraphModeFunction(object):
   """Callable object representing a graph-mode function.
-
-  Args:
-    name: str the name of the created function
-    input_placeholders: list of placeholder values (tensors) to feed when
-      calling the wrapped function.
-    extra_inputs: Tensor inputs this function definition closed over which
-      are passed as arguments. Need to track so gradients are supported
-      correctly.
-    graph: the Graph from which the operations will be pulled. Used as
-      a context when computing gradients.
-    operations: the subset of Operations in the graph used in the function
-      definition.
-    outputs: a flat list of the Tensors in the graph used as outputs to the
-      function
-    func_outputs: a possibly nested python object which will be returned by
-      this function. The Tensors in this structure will be replaced by their
-      corresponding values in outputs.
-    output_shapes: List of shapes of all tensors in outputs
-    variables: (optional) List of variables to watch during function execution.
   """
 
   def __init__(self,
@@ -319,9 +310,36 @@ class GraphModeFunction(object):
                outputs,
                func_outputs,
                output_shapes,
-               variables=None):
+               variables=None,
+               attrs=None):
+    """Initialize a GraphModeFunction.
+
+    Args:
+      name: str the name of the created function
+      input_placeholders: list of placeholder values (tensors) to feed when
+        calling the wrapped function.
+      extra_inputs: Tensor inputs this function definition closed over which
+        are passed as arguments. Need to track so gradients are supported
+        correctly.
+      graph: the Graph from which the operations will be pulled. Used as
+        a context when computing gradients.
+      operations: the subset of Operations in the graph used in the function
+        definition.
+      outputs: a flat list of the Tensors in the graph used as outputs to the
+        function
+      func_outputs: a possibly nested python object which will be returned by
+        this function. The Tensors in this structure will be replaced by their
+        corresponding values in outputs.
+      output_shapes: List of shapes of all tensors in outputs
+      variables: (optional) List of variables to watch during function
+        execution.
+      attrs: (optional) dict mapping names of attributes to their AttrValue
+        values. Attributes in `attrs` will be included in this function's
+        definition.
+    """
+    self._attrs = attrs or {}
     defined_function = _EagerDefinedFunction(
-        name, graph, operations, input_placeholders, outputs)
+        name, graph, operations, input_placeholders, outputs, self._attrs)
     if len(input_placeholders) != len(defined_function.signature.input_arg):
       raise ValueError("Internal error: invalid lengths. %s %s" % (
           len(input_placeholders), len(defined_function.signature.input_arg)))
@@ -374,7 +392,7 @@ class GraphModeFunction(object):
     forward_name = _forward_name(self._func_name)
     self._forward_fdef = _EagerDefinedFunction(
         forward_name, self._graph, self._ops, self._input_placeholders,
-        filtered_outputs + captures)
+        filtered_outputs + captures, self._attrs)
     all_inputs = self._out_grad_placeholders + captures
     # Excluding input ops from the body as we do not intend to execute these
     # operations when the function is executed.
@@ -388,7 +406,7 @@ class GraphModeFunction(object):
     bname = _backward_name(self._func_name)
     self._backward_function = GraphModeFunction(
         bname, all_inputs, [], self._graph, function_def_ops,
-        backward_outputs, in_gradients, output_shapes)
+        backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
 
   def _backprop_call(self, args):
     """Calls the wrapped function and records the result on a tape."""
@@ -562,7 +580,7 @@ def _get_defun_inputs(args):
   return nest.pack_sequence_as(args, ret)
 
 
-def _defun_internal(name, func, args, kwds):
+def _defun_internal(name, func, compiled, args, kwds):
   """Defines and returns graph-mode version of func."""
   graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
   with context.graph_mode():
@@ -627,9 +645,14 @@ def _defun_internal(name, func, args, kwds):
     for f in tmp_graph._functions.values():  # pylint: disable=protected-access
       # TODO(ashankar): What about the gradient registry?
       _register(f._c_func.func)  # pylint: disable=protected-access
+
+  attrs = {}
+  if compiled:
+    attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True)
+
   return GraphModeFunction(
       fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
-      func_outputs, output_shapes, variables)
+      func_outputs, output_shapes, variables, attrs)
 
 
 # Defun uses this instead of Tensor as a cache key. Using dtype because
@@ -671,7 +694,7 @@ def _register(fn):
 
 
 # TODO(apassos): better error messages for non-hashable arguments.
-def named_defun(func, name):
+def named_defun(func, name, compiled=False):
   """Defines a function with a given name.
 
   See the documentation for `defun` for more information on the semantics of the
@@ -680,6 +703,7 @@ def named_defun(func, name):
   Args:
     func: the function to be wrapped.
     name: the name given to it.
+    compiled: if true, the framework will attempt to compile func with XLA.
 
   Returns:
     the wrapped function.
@@ -696,13 +720,13 @@ def named_defun(func, name):
 
     if cache_key not in arguments_to_functions:
       arguments_to_functions[cache_key] = _defun_internal(
-          name, func, args, kwds)
+          name, func, compiled, args, kwds)
     return arguments_to_functions[cache_key](*args)
 
   return decorated
 
 
-def defun(func):
+def defun(func=None, compiled=False):
   """Decorator to compile func into graph_mode.
 
   `defun` converts a function that constructs a TensorFlow graph into a function
@@ -745,18 +769,45 @@ def defun(func):
   ```
 
   Args:
-    func: function to be compiled.
+    func: function to be compiled. If `func` is None, returns a
+      decorator that can be invoked with a single argument - `func`. The
+      end result is equivalent to providing all the arguments up front.
+      In other words, defun(compiled=True)(func) is equivalent to
+      defun(func, compiled=True). The former allows the following use case:
+        @tfe.defun(compiled=True)
+        def foo(...):
+          ...
+    compiled: If True, an attempt to compile `func` with XLA will be made.
+      If it fails, the function will be run normally. Experimental.
+      Currently supported only for execution on TPUs.
 
   Returns:
-     A callable that will execute the compiled function (and return zero
-     or more `tf.Tensor` objects).
+     If `func` is not None, returns callable that will execute the compiled
+     function (and return zero or more `tf.Tensor` objects).
+     If `func` is None, returns a decorator that, when invoked with a single
+     `func` argument, returns a callable equivalent to the case above.
   """
   # TODO(apassos): deal with captured global state. Deal with control flow.
-  try:
-    name = func.__name__
-  except AttributeError:
-    name = "function"
-  return tf_decorator.make_decorator(func, named_defun(func, name))
+  def decorated(function):
+    try:
+      name = function.__name__
+    except AttributeError:
+      name = "function"
+    return tf_decorator.make_decorator(
+        function, named_defun(function, name, compiled=compiled))
+
+  # This code path is for the `foo = tfe.defun(foo, ...)` use case
+  if func is not None:
+    return decorated(func)
+
+  # This code path is for the
+  #
+  # @tfe.defun(...)
+  # def foo(...):
+  #    ...
+  #
+  # use case, which is equivalent to `foo = tfe.defun(...)(foo)`
+  return decorated
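A minimal usage sketch of the two equivalent invocation forms (editorial illustration; `square` and `scale` are hypothetical, and eager execution is assumed to be enabled):

  # Decorator form: equivalent to `square = defun(square, compiled=True)`.
  @defun(compiled=True)
  def square(x):
    return x * x

  # Direct form, wrapping an existing function.
  def scale(x, k):
    return x * k
  scale = defun(scale, compiled=True)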
 
 
 def make_defun_op(func, *args, **kwds):
@@ -808,7 +859,7 @@ def make_defun_op(func, *args, **kwds):
   name = func.__name__
   if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
     raise ValueError("Tensor keyword arguments are not supported.")
-  return _defun_internal(name, func, args, kwds)
+  return _defun_internal(name, func, False, args, kwds)
 
 
 class AutomaticControlDependencies(object):