Automated g4 rollback of changelist 195748721
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 8 May 2018 09:11:52 +0000 (02:11 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 8 May 2018 20:54:39 +0000 (13:54 -0700)
PiperOrigin-RevId: 195790581

13 files changed:
tensorflow/compiler/jit/BUILD
tensorflow/compiler/jit/create_xla_launch_op.cc
tensorflow/compiler/jit/create_xla_launch_op.h [deleted file]
tensorflow/compiler/jit/create_xla_launch_op_test.cc [deleted file]
tensorflow/compiler/jit/kernels/xla_launch_op.cc
tensorflow/compiler/jit/kernels/xla_launch_op.h
tensorflow/compiler/jit/xla_compile_on_demand_op.cc
tensorflow/compiler/jit/xla_launch_util.cc
tensorflow/compiler/jit/xla_launch_util.h
tensorflow/compiler/tests/BUILD
tensorflow/compiler/tests/eager_test.py
tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
tensorflow/python/eager/function.py

index e942b46..07136d6 100644 (file)
@@ -261,7 +261,6 @@ cc_library(
     name = "create_xla_launch_op",
     srcs = [
         "create_xla_launch_op.cc",
-        "create_xla_launch_op.h",
     ],
     deps = [
         ":common",
@@ -271,27 +270,6 @@ cc_library(
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-    ],
-    alwayslink = 1,
-)
-
-tf_cc_test(
-    name = "create_xla_launch_op_test",
-    srcs = [
-        "create_xla_launch_op.h",
-        "create_xla_launch_op_test.cc",
-    ],
-    deps = [
-        ":create_xla_launch_op",
-        "//tensorflow/core:core_cpu_internal",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:session_options",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
     ],
 )
 
index 6ac84dc..18d9013 100644 (file)
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/compiler/jit/create_xla_launch_op.h"
 
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
@@ -26,189 +25,78 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
-// Utility which searches for values in a sorted list by scanning over it once.
-// No matter how many times ScanForValue is called, the list is scanned at most
-// once. However, if a call to ScanForValue skips over a value, that value is
-// not revisited in future calls to ScanForValue, so callers must take
-// care to order their calls.
+// Givens a NodeDef 'ndef' and the function library runtime 'flr', if
+// 'ndef' is a call to a compilable function defined in 'flr', returns OK
+// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
+// node. Otherwise, returns a non-OK.
 //
-// Useful for merging multiple sorted lists in O(n) time.
-class SinglePassSearch {
- public:
-  // Creates a SinglePassSearch object that can be used to search in `values`.
-  // Does not take ownership of `values`. `values` must outlive this.
-  // `values` must be sorted.
-  explicit SinglePassSearch(const std::vector<int>* values)
-      : current_index_(0), values_(values) {}
-
-  // Scans forward in the vector looking for "value", updating the internal
-  // position in to the vector.
-  // Returns true iff the vector contains the given value at or after current
-  // position.
-  // Not thread-safe.
-  bool ScanForValue(int value) {
-    while (current_index_ < values_->size() &&
-           (*values_)[current_index_] <= value) {
-      if ((*values_)[current_index_] == value) {
-        current_index_++;
-        return true;
-      }
-      current_index_++;
-    }
-    return false;
-  }
-
- private:
-  int current_index_;
-  const std::vector<int>* values_;
-};
-
-Status CompilationRequested(const FunctionLibraryRuntime& flr,
-                            const NodeDef& node_def) {
+// This routine is here so that FunctionLibraryRuntime can jit a
+// specific function call as requested.
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
+                         std::unique_ptr<OpKernel>* kernel) {
   bool xla_compile = false;
-  // Check if op is marked _XlaCompile=true.
-  Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
-      node_def, kXlaCompileAttr, &xla_compile);
-  if (!status.ok() || !xla_compile) {
-    if (VLOG_IS_ON(3)) {
-      if (!status.ok()) {
-        VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
-                << node_def.op() << ". status=" << status.ToString();
-      } else {
-        VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
-      }
-    }
-    return Status(error::INVALID_ARGUMENT, "");
+  if (!flr->GetFunctionLibraryDefinition()
+           ->GetAttr(ndef, kXlaCompileAttr, &xla_compile)
+           .ok() ||
+      !xla_compile) {
+    // Not marked as _XlaCompile=true.
+    return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
+  }
+  // Make sure that kernels have been registered on the JIT device.
+  XlaOpRegistry::RegisterCompilationKernels();
+  if (!IsCompilable(flr, ndef)) {
+    // ndef is calling a function that XLA can't compile.
+    return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
   }
-  return Status::OK();
-}
-
-// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
-// runtime, returns this function's body in `fbody` as well as the indices
-// of its constant and resource arguments.
-// `fbody` is owned by `flr`.
-// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
-// They are sorted in ascending order on this function's return.
-Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
-                                       const NodeDef& node_def,
-                                       const FunctionBody** fbody,
-                                       std::vector<int>* constant_arg_indices,
-                                       std::vector<int>* resource_arg_indices) {
   FunctionLibraryRuntime::Handle handle;
-  // If node_def is not instantiable, e.g., the function does not exist,
+  // If ndef is not instantiable, e.g., the function does not exist,
   // simply bail out.
   TF_RETURN_IF_ERROR(
-      flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
-  *fbody = flr->GetFunctionBody(handle);
-  CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
-  const DataTypeVector& arg_types = (*fbody)->arg_types;
-  std::vector<bool> const_args(arg_types.size());
+      flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
+  const FunctionBody* fbody = flr->GetFunctionBody(handle);
+  CHECK(fbody);  // Can't be nullptr since we just instantiated it.
+  std::vector<bool> const_args(fbody->arg_types.size());
   // If we can't analyze the const args. Bail out.
-  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args));
+  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args));
 
   for (int i = 0; i < const_args.size(); ++i) {
     if (const_args[i]) {
-      constant_arg_indices->push_back(i);
-    }
-  }
-
-  // There can be hundreds of resource variables. Reserve the space for them.
-  // We don't reserve for constants above as they are usually few.
-  resource_arg_indices->reserve(arg_types.size());
-  for (int i = 0; i < arg_types.size(); ++i) {
-    if (arg_types[i] == DT_RESOURCE) {
-      resource_arg_indices->push_back(i);
+      // There is a const arg. Bail out.
+      return errors::InvalidArgument("Const arg: ", i, " in ",
+                                     DebugString(fbody->fdef));
     }
   }
 
-  return Status::OK();
-}
-
-}  // namespace
-
-Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
-                         std::unique_ptr<OpKernel>* kernel) {
-  TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));
-
-  VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString();
-
-  // Make sure that kernels have been registered on the JIT device.
-  XlaOpRegistry::RegisterCompilationKernels();
-  if (!IsCompilable(flr, node_def)) {
-    // node_def is calling a function that XLA can't compile.
-    return errors::InvalidArgument("Not compilable: ",
-                                   node_def.ShortDebugString());
-  }
-
-  // Get function body, constant args, and resource args.
-  const FunctionBody* fbody = nullptr;
-  std::vector<int> constant_arg_indices;
-  std::vector<int> resource_arg_indices;
-  TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
-      flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
-
-  // Set input and output memory types.
+  NodeDef launch_def;
+  launch_def.set_name(ndef.name());
+  launch_def.set_op("_XlaLaunch");
+  launch_def.set_device(flr->device()->name());
+  AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
+  AddNodeAttr("Nresources", 0, &launch_def);
+  AddNodeAttr("Targs", fbody->arg_types, &launch_def);
+  AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
+  NameAttrList func;
+  func.set_name(ndef.op());
+  *(func.mutable_attr()) = ndef.attr();
+  AddNodeAttr("function", func, &launch_def);
+
+  // TODO(b/32387911): Handles the host memory types across function
+  // calls properly. For now, we assume all inputs and outputs are on
+  // the device memory.
   MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
-  // These indices are used only for optimization purposes. They allow us
-  // to loop over constant_arg_indices and resource_arg_indices only once
-  // while iterating over all the function arguments checking if it is a
-  // resource or a constant.
-  // The reason we optimized this code is because functions can have a lot of
-  // captured arguments. For example, the backward pass of ResNet50 takes in all
-  // 214 variables and a similar number of activations.
-  SinglePassSearch constants_search(&constant_arg_indices);
-  SinglePassSearch resources_search(&resource_arg_indices);
-  for (int i = 0; i < fbody->arg_types.size(); ++i) {
-    if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
-      // Compile-time constants and resource handles are expected to be in
-      // host memory.
-      input_memory_types[i] = HOST_MEMORY;
-    }
-  }
-  // One might wonder, about the case where a compile-time constant argument
-  // (which must be in host memory) is also used as an input into an op,
-  // e.g. Add, that expects its inputs in device memory. Here is how it
-  // works now.
-  // First, what do we mean by "op expects an input in XYZ memory"?
-  // There are two types of "ops" here: the tf2xla kernel and the HLO
-  // computation it builds. The tf2xla kernel needs to retrieve the actual
-  // numeric value of the compile-time constant tensors, so it really expects
-  // them to be on in host memory. However, for other inputs, it refers to them
-  // using xla::ComputationDataHandle, which is just a symbolic handle that
-  // xla::ComputationBuilder assigns. How does this handle gets assigned for
-  // constant arguments? Even constant arguments get an _Arg node in the graph
-  // instatiated for Function compilation. The tf2xla kernel for constant _Arg
-  // nodes takes the constant value, converts it to XlaLiteral, and feeds it
-  // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
-  // constant XlaLiteral is included in the HLO graph, and subsequently, in
-  // the actual executable, which is copied to the device before being
-  // executed. Thus, when this executable runs, the constant is available in
-  // device memory.
-
-  // XlaLaunch kernel keeps all outputs (including constants, which it copies),
-  // in device memory
   MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
 
-  // Create the kernel.
-  NameAttrList function;
-  function.set_name(node_def.op());
-  *(function.mutable_attr()) = node_def.attr();
-
   Device* dev = flr->device();
   Status s;
   OpKernelConstruction construction(
       DeviceType(dev->device_type()), dev,
-      dev->GetAllocator(AllocatorAttributes()), &node_def,
+      dev->GetAllocator(AllocatorAttributes()), &launch_def,
       &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
       fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
-
-  *kernel = absl::make_unique<XlaLocalLaunchBase>(
-      &construction, constant_arg_indices, resource_arg_indices, function);
+  kernel->reset(new XlaLocalLaunchOp(&construction));
   return s;
 }
 
-namespace {
-
 bool RegisterLaunchOpCreator() {
   RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
   return true;
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h
deleted file mode 100644 (file)
index 98a22e3..0000000
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
-#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
-
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/core/status.h"
-
-namespace tensorflow {
-
-class FunctionLibraryRuntime;
-class OpKernel;
-
-// Given a NodeDef 'node_def' and the function library runtime 'flr', if
-// 'node_def' is a call to a compilable function defined in 'flr', returns OK
-// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
-// node. Otherwise, returns a non-OK.
-Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
-                         std::unique_ptr<OpKernel>* kernel);
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
deleted file mode 100644 (file)
index c222824..0000000
+++ /dev/null
@@ -1,144 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/create_xla_launch_op.h"
-
-#include "tensorflow/core/common_runtime/device_factory.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/framework/function_testlib.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/public/session_options.h"
-#include "tensorflow/core/public/version.h"
-
-namespace tensorflow {
-
-NodeDef ToNodeDef(const string& text) {
-  NodeDef node_def;
-  EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
-  return node_def;
-}
-
-// Create a FunctionDef that takes one resource and one regular param
-FunctionDef XTimesY() {
-  return FunctionDefHelper::Define(
-      // Name
-      "XTimesY",
-      // Args
-      {"x: float", "y: resource"},
-      // Return values
-      {"z: float"},
-      // Attr def
-      {},
-      // Nodes
-      {
-          {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}},
-          {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}},
-      });
-}
-
-class CreateXlaLaunchOpTest : public ::testing::Test {
- protected:
-  void Init(const std::vector<FunctionDef>& flib) {
-    SessionOptions options;
-    auto* device_count = options.config.mutable_device_count();
-    device_count->insert({"CPU", 1});
-    TF_CHECK_OK(DeviceFactory::AddDevices(
-        options, "/job:localhost/replica:0/task:0", &devices_));
-
-    FunctionDefLibrary proto;
-    for (const auto& fdef : flib) {
-      *(proto.add_function()) = fdef;
-    }
-    lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
-        OpRegistry::Global(), proto);
-    OptimizerOptions opts;
-    device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
-    pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
-        device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
-        opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
-    flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
-  }
-
-  FunctionLibraryRuntime* flr_;
-  std::vector<Device*> devices_;
-  std::unique_ptr<DeviceMgr> device_mgr_;
-  std::unique_ptr<FunctionLibraryDefinition> lib_def_;
-  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
-
-  std::unique_ptr<OpKernel> kernel_;
-};
-
-AttrValue BoolAttr(bool b) {
-  AttrValue v;
-  v.set_b(b);
-  return v;
-}
-
-TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) {
-  FunctionDef fdef = XTimesY();
-  (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
-  Init({fdef});
-
-  Status status = CreateXlaLaunchOp(
-      flr_, ToNodeDef(R"pb(
-        name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
-      )pb"), &kernel_);
-  ASSERT_TRUE(status.ok()) << status.ToString();
-
-  EXPECT_EQ("XTimesY", kernel_->name());
-  EXPECT_EQ("XTimesY", kernel_->type_string());
-
-  EXPECT_EQ(2, kernel_->num_inputs());
-  EXPECT_EQ(DT_FLOAT, kernel_->input_type(0));
-  EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1));
-  EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]);
-  EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]);
-
-  EXPECT_EQ(1, kernel_->num_outputs());
-  EXPECT_EQ(DT_FLOAT, kernel_->output_type(0));
-  EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]);
-}
-
-TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) {
-  FunctionDef fdef = XTimesY();
-  Init({fdef});
-
-  Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
-                                      name: 'XTimesY'
-                                      op: 'XTimesY'
-                                      input: 'a'
-                                      input: 'b'
-                                    )proto"), &kernel_);
-  EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
-}
-
-TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) {
-  FunctionDef fdef = XTimesY();
-  (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
-  Init({fdef});
-
-  Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
-                                      name: 'XTimesY'
-                                      op: 'XTimesY'
-                                      input: 'a'
-                                      input: 'b'
-                                    )proto"), &kernel_);
-  EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
-}
-
-}  // namespace tensorflow
index 86a9fd3..049d170 100644 (file)
@@ -39,15 +39,15 @@ limitations under the License.
 
 namespace tensorflow {
 
-XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
-                                       const std::vector<int>& constants,
-                                       const std::vector<int>& resources,
-                                       const NameAttrList& function)
-    : OpKernel(ctx),
-      constants_(constants),
-      resources_(resources),
-      device_type_(ctx->device_type()),
-      function_(function) {
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+    : OpKernel(ctx), device_type_(ctx->device_type()) {
+  const NameAttrList* func;
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
+  function_ = *func;
+  DataTypeVector constant_types;
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
+  num_constant_args_ = constant_types.size();
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
   if (device_type_ == DeviceType(DEVICE_CPU)) {
     platform_id_ = se::host::kHostPlatformId;
   } else if (device_type_ == DeviceType(DEVICE_GPU)) {
@@ -57,8 +57,8 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
   }
 }
 
-Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
-                                                 XlaCompilationCache** cache) {
+Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
+                                               XlaCompilationCache** cache) {
   const XlaDevice::Metadata* metadata;
   Status s = XlaDevice::GetMetadata(ctx, &metadata);
   if (s.ok()) {
@@ -90,8 +90,8 @@ Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
   return Status::OK();
 }
 
-void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
-  VLOG(1) << "XlaLocalLaunchOpBase::Compute "
+void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "XlaLocalLaunchOp::Compute "
           << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
   // We store information about the JIT-compiled XLA computation
   // in the ResourceMgr.
@@ -124,7 +124,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   }
 
   std::map<int, OptionalTensor> variables =
-      SnapshotResourceVariables(ctx, resources_);
+      SnapshotResourceVariables(ctx, num_resource_args_);
 
   xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
 
@@ -161,7 +161,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   xla::LocalExecutable* executable;
 
   std::map<int, Tensor> constant_args;
-  for (int i : constants_) {
+  for (int i = 0; i < num_constant_args_; ++i) {
     constant_args.insert({i, ctx->input(i)});
   }
   OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
@@ -170,8 +170,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
 
   VLOG(1) << "Executing XLA Computation...";
 
-  XlaComputationLaunchContext launch_context(client, xla_allocator,
-                                             allocate_xla_tensors);
+  XlaComputationLaunchContext launch_context(
+      num_resource_args_, client, xla_allocator, allocate_xla_tensors);
   launch_context.PopulateInputs(ctx, kernel, variables);
 
   // Execute the computation.
@@ -194,62 +194,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   VLOG(1) << "Done";
 }
 
-namespace {
-
-// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
-// in error case, it returns RET instead of void.
-#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
-  do {                                                      \
-    ::tensorflow::Status _s(__VA_ARGS__);                   \
-    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
-      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
-      return RET;                                           \
-    }                                                       \
-  } while (0)
-
-// Helper static functions to construct parameters for
-// XlaLocalLaunchBase constructor from OpKernelConstruction.
-std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
-  DataTypeVector constant_types;
-  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
-                        ctx->GetAttr("Tconstants", &constant_types));
-  std::vector<int> constants(constant_types.size());
-  std::iota(constants.begin(), constants.end(), 0);
-  return constants;
-}
-
-std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
-  DataTypeVector constant_types;
-  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
-                        ctx->GetAttr("Tconstants", &constant_types));
-
-  DataTypeVector arg_types;
-  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
-                        ctx->GetAttr("Targs", &arg_types));
-
-  int num_resources;
-  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
-                        ctx->GetAttr("Nresources", &num_resources));
-
-  std::vector<int> resources(num_resources);
-  std::iota(resources.begin(), resources.end(),
-            constant_types.size() + arg_types.size());
-  return resources;
-}
-
-NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
-  const NameAttrList* func;
-  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
-  return *func;
-}
-
-#undef OP_REQUIRES_OK_RETURN
-}  // namespace
-
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
-    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
-                         FunctionAttr(ctx)) {}
-
 XlaLocalLaunchOp::~XlaLocalLaunchOp() {
   VLOG(1) << "XlaLocalLaunchOp destroyed";
 }
index 8dfc4b3..8f8e646 100644 (file)
@@ -26,41 +26,6 @@ limitations under the License.
 
 namespace tensorflow {
 
-// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
-// The only difference is that it does not require arguments to follow
-// the "constants, then regular args, then resources" order.
-// It takes vectors of constant and resource arguments explicitly.
-// It does not have corresponding OpDef because it is never present
-// in the GraphDef.
-// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
-// this kernel when asked to create a kernel for an XLA-compiled function.
-class XlaLocalLaunchBase : public OpKernel {
- public:
-  XlaLocalLaunchBase(OpKernelConstruction* ctx,
-                     const std::vector<int>& constants,
-                     const std::vector<int>& resources,
-                     const NameAttrList& function);
-  XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
-  XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
-  ~XlaLocalLaunchBase() override = default;
-
-  void Compute(OpKernelContext* ctx) override;
-
- protected:
-  // Builds a XlaCompilationCache class suitable for the current device.
-  Status BuildCompilationCache(OpKernelContext* ctx,
-                               XlaCompilationCache** cache);
-
-  // Indexes of compile-time constant inputs
-  std::vector<int> constants_;
-  // Indexes of resource inputs
-  std::vector<int> resources_;
-
-  DeviceType device_type_;
-  NameAttrList function_;
-  se::Platform::Id platform_id_;
-};
-
 // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
 // which will be compiled and executed using XLA.  The XlaLocalLaunchOp is
 // responsible for handling interactions with the TensorFlow executor.
@@ -70,12 +35,26 @@ class XlaLocalLaunchBase : public OpKernel {
 // XlaLocalLaunchOp uses xla::LocalClient::Compile() and
 // xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
 // memory.
-class XlaLocalLaunchOp : public XlaLocalLaunchBase {
+class XlaLocalLaunchOp : public OpKernel {
  public:
   explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
   ~XlaLocalLaunchOp() override;
 
+  void Compute(OpKernelContext* ctx) override;
+
  private:
+  // Builds a XlaCompilationCache class suitable for the current device.
+  Status BuildCompilationCache(OpKernelContext* ctx,
+                               XlaCompilationCache** compiler);
+
+  DeviceType device_type_;
+  NameAttrList function_;
+  int num_constant_args_;
+  // Number of resource variable arguments.
+  int num_resource_args_;
+
+  se::Platform::Id platform_id_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
 };
 
index 6b83cf6..60458f6 100644 (file)
@@ -48,12 +48,13 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
                                  const XlaCompiler::CompilationResult* result,
                                  xla::LocalExecutable* executable) {
   std::map<int, OptionalTensor> variables = GetVariables(ctx);
+  int64 num_resource_args = variables.size();
 
   xla::LocalClient* client = metadata.client();
 
   // Builds an XLA allocator for the device.
   XlaComputationLaunchContext launch_context(
-      client, client->backend().memory_allocator(), true);
+      num_resource_args, client, client->backend().memory_allocator(), true);
 
   launch_context.PopulateInputs(ctx, result, variables);
 
index 0223f97..33e5361 100644 (file)
@@ -38,13 +38,14 @@ using xla::ScopedShapedBuffer;
 using xla::ShapedBuffer;
 }  // anonymous namespace
 
-std::map<int, OptionalTensor> SnapshotResourceVariables(
-    OpKernelContext* ctx, const std::vector<int>& variables) {
+std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
+                                                        int num_variables) {
   std::map<int, OptionalTensor> snapshot;
-  for (int i : variables) {
+  int first_variable = ctx->num_inputs() - num_variables;
+  for (int i = 0; i < num_variables; ++i) {
     Var* variable = nullptr;
-    ResourceHandle handle = HandleFromInput(ctx, i);
-    OptionalTensor& tensor = snapshot[i];
+    ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
+    OptionalTensor& tensor = snapshot[first_variable + i];
     if (LookupResource(ctx, handle, &variable).ok()) {
       tf_shared_lock lock(*variable->mu());
       tensor.name = handle.name();
@@ -111,9 +112,10 @@ ScopedShapedBuffer ExtractSubShapedBuffer(
 using internal::ExtractSubShapedBuffer;
 
 XlaComputationLaunchContext::XlaComputationLaunchContext(
-    xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
-    bool allocate_xla_tensors)
-    : client_(client),
+    int64 num_resource_args, xla::LocalClient* client,
+    xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors)
+    : num_resource_args_(num_resource_args),
+      client_(client),
       xla_allocator_(xla_allocator),
       allocate_xla_tensors_(allocate_xla_tensors) {}
 
index a243125..38291b0 100644 (file)
@@ -31,17 +31,15 @@ limitations under the License.
 namespace tensorflow {
 class XlaAllocator;
 
-// Takes a snapshot of the values of resource variable arguments, whose
-// indices are specified in `variables` argument. We snapshot tensors that back
+// Takes a snapshot of the values of resource variable arguments, which are
+// the last `num_variables` arguments. We snapshot tensors that back
 // resource variables since concurrent updates may modify the shape, and it is
 // important that the shapes used for compilation match the true shapes of the
 // buffers.
 //
-// Returns a map of TensorFlow argument index to resource variable. If a
-// resource variable is not initialized, the corresponding OptionalTensor
-// will have its `present` field set to false.
-std::map<int, OptionalTensor> SnapshotResourceVariables(
-    OpKernelContext* ctx, const std::vector<int>& variables);
+// Returns a map of TensorFlow argument index to resource variable.
+std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
+                                                        int num_variables);
 
 // Adapter class that wraps a Tensorflow allocator as an XLA allocator.
 // Assumes that the Tensorflow allocator permits asynchronous deallocation:
@@ -74,7 +72,7 @@ class XlaComputationLaunchContext {
   // Create a new launch context. 'allocate_xla_tensors' is true if allocated
   // output tensors and variables are always XlaTensors. If false they are
   // assumed to be "normal" device pointers.
-  XlaComputationLaunchContext(xla::LocalClient* client,
+  XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client,
                               xla::DeviceMemoryAllocator* xla_allocator,
                               bool allocate_xla_tensors);
 
@@ -94,6 +92,7 @@ class XlaComputationLaunchContext {
   const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
 
  private:
+  int64 num_resource_args_;
   xla::LocalClient* client_;
   xla::DeviceMemoryAllocator* xla_allocator_;
   bool allocate_xla_tensors_;
index 9791792..aaea83a 100644 (file)
@@ -327,11 +327,7 @@ tf_xla_py_test(
         ":xla_test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
-        "//tensorflow/python:layers",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:nn",
         "//tensorflow/python:platform_test",
-        "//tensorflow/python/eager:function",
     ],
 )
 
index 5ab1585..bdd0185 100644 (file)
@@ -24,16 +24,10 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
-from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.layers import convolutional
-from tensorflow.python.layers import pooling
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import googletest
 
@@ -49,7 +43,7 @@ class EagerTest(XLATestCase):
 
   def testExecuteListOutputLen0(self):
     with self.test_scope():
-      empty = constant_op.constant([], dtype=dtypes.float32)
+      empty = constant_op.constant([], dtype=dtypes.int32)
       result = array_ops.unstack(empty, 0)
       self.assertTrue(isinstance(result, list))
       self.assertEqual(0, len(result))
@@ -57,7 +51,7 @@ class EagerTest(XLATestCase):
   def testExecuteListOutputLen1(self):
     with self.test_scope():
       split_dim = constant_op.constant(1)
-      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
+      value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
       result = array_ops.split(value, 1, axis=split_dim)
       self.assertTrue(isinstance(result, list))
       self.assertEqual(1, len(result))
@@ -66,7 +60,7 @@ class EagerTest(XLATestCase):
   def testExecuteListOutputLen3(self):
     with self.test_scope():
       split_dim = constant_op.constant(1)
-      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
+      value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
       result = array_ops.split(value, 3, axis=split_dim)
       self.assertTrue(isinstance(result, list))
       self.assertEqual(3, len(result))
@@ -137,105 +131,7 @@ class EagerTest(XLATestCase):
     self.assertEqual(2., grads[0][0].numpy())
 
 
-class EagerFunctionTest(XLATestCase):
-
-  def testBasic(self):
-    with self.test_scope():
-      matmul = function.defun(math_ops.matmul, compiled=True)
-      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
-      sq = matmul(t, t, transpose_a=True)
-      self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
-
-  def testConv(self):
-    if 'GPU' in self.device:
-      # TODO(b/32333178)
-      self.skipTest('Current implementation of RandomStandardNormal kernel '
-                    'is very slow on GPU, and has been blacklisted.')
-    with self.test_scope():
-      data_format = 'channels_last'
-      conv = convolutional.Conv2D(
-          filters=1, kernel_size=2, padding='VALID',
-          data_format=data_format, activation=nn_ops.relu,
-          kernel_initializer=init_ops.ones_initializer(),
-          bias_initializer=init_ops.zeros_initializer())
-      pool = pooling.MaxPooling2D(2, 2, data_format=data_format)
-
-      def model(x):
-        x = conv(x)
-        return pool(x)
-      model = function.defun(model, compiled=True)
-
-      x = array_ops.ones([1, 4, 4, 1])
-      y = model(x)
-      self.assertAllEqual(y.numpy(), [[[[4.]]]])
-
-  def testReadVariable(self):
-    with self.test_scope():
-      v = resource_variable_ops.ResourceVariable(1.0)
-
-      @function.defun(compiled=True)
-      def f():
-        return v.read_value()
-
-      var = f()
-      self.assertEqual(1.0, var.numpy())
-
-  def testUpdateVariable(self):
-    with self.test_scope():
-      v = resource_variable_ops.ResourceVariable(1.0)
-
-      def f(v):
-        v.assign_add(1.0)
-        return v
-
-      f = function.defun(f, compiled=True)
-
-      var = f(v)
-      self.assertEqual(2.0, var.numpy())
-
-  def testAllArgumentKinds(self):
-    """Test a complex function that takes different argument kinds.
-
-    tf2xla machinery that translates, compiles, and runs defuns
-    classifies arguments into: compile-time constants, regular tensors,
-    and resources. This test creates a function with a mix of all these
-    kinds. Moreover, the order of function arguments is intentionally mixed up.
-
-    This also tests the case when the same argument is a compile-time constant
-    as well as used in an operation that normally expects its inputs to be
-    in device memory - addition in this case.
-    """
-    with self.test_scope():
-      def foo(c1, r1, v1, c2, v2, r2):
-        # c1 and c2 are compile-time constants
-        # r1 and r2 are regular tensors
-        # v1 and v2 are resource variables
-        a = c1 + r1
-        b = math_ops.cast(c2, dtypes.float32) + v2
-        c = array_ops.slice(v1, c1, c2)
-        d = r2 * v2
-        return a, b, c, d
-
-      foo = function.defun(foo, compiled=True)
-
-      c1 = [0, 0]
-      c2 = array_ops.ones([2], dtype=dtypes.int32)
-
-      r1 = array_ops.ones([2])
-      r2 = [[2., 2.], [3., 3.]]
-
-      v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
-      v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])
-
-      a, b, c, d = foo(c1, r1, v1, c2, v2, r2)
-
-      self.assertAllEqual([1, 1], a.numpy())
-      self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
-      self.assertAllEqual([[1.]], c.numpy())
-      self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())
-
-
-if __name__ == '__main__':
+if __name__ == "__main__":
   ops.enable_eager_execution(
       config=config_pb2.ConfigProto(log_device_placement=True))
   googletest.main()
index b8f352d..8517a3b 100644 (file)
@@ -36,7 +36,9 @@ def device_and_data_format():
                                                               'channels_last')
 
 
-def random_batch(batch_size, data_format):
+def random_batch(batch_size, device_and_format=None):
+  _, data_format = device_and_format or device_and_data_format()
+
   shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3)
   shape = (batch_size,) + shape
 
@@ -68,7 +70,7 @@ class ResNet50Test(tf.test.TestCase):
     if defun:
       model.call = tfe.defun(model.call)
     with tf.device(device), tfe.execution_mode(execution_mode):
-      images, _ = random_batch(2, data_format)
+      images, _ = random_batch(2)
       output = model(images, training=False)
       tfe.async_wait()
     self.assertEqual((2, 1000), output.shape)
@@ -89,7 +91,7 @@ class ResNet50Test(tf.test.TestCase):
     device, data_format = device_and_data_format()
     model = resnet50.ResNet50(data_format, include_top=False)
     with tf.device(device):
-      images, _ = random_batch(2, data_format)
+      images, _ = random_batch(2)
       output = model(images, training=False)
     output_shape = ((2, 2048, 1, 1)
                     if data_format == 'channels_first' else (2, 1, 1, 2048))
@@ -99,7 +101,7 @@ class ResNet50Test(tf.test.TestCase):
     device, data_format = device_and_data_format()
     model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
     with tf.device(device):
-      images, _ = random_batch(2, data_format)
+      images, _ = random_batch(2)
       output = model(images, training=False)
     self.assertEqual((2, 2048), output.shape)
 
@@ -113,7 +115,7 @@ class ResNet50Test(tf.test.TestCase):
         name='t0').as_default(), tf.contrib.summary.always_record_summaries():
       with tf.device(device), tfe.execution_mode(execution_mode):
         optimizer = tf.train.GradientDescentOptimizer(0.1)
-        images, labels = random_batch(2, data_format)
+        images, labels = random_batch(2)
         train_one_step(model, images, labels, optimizer)
         self.assertEqual(320, len(model.variables))
         tfe.async_wait()
@@ -132,7 +134,7 @@ class ResNet50Test(tf.test.TestCase):
     model = resnet50.ResNet50(data_format)
     optimizer = tf.train.GradientDescentOptimizer(0.1)
     with tf.device(device):
-      images, labels = random_batch(2, data_format)
+      images, labels = random_batch(2)
       gc.disable()
       # Warm up. Note that this first run does create significant amounts of
       # garbage to be collected. The hope is that this is a build-only effect,
@@ -200,18 +202,18 @@ class ResNet50Benchmarks(tf.test.Benchmark):
     # which forces a sync. This is a roundabout way, yes.
     tf.constant(1.).cpu()
 
-  def _benchmark_eager_apply(self, label, device_and_format, defun=False,
-                             execution_mode=None, compiled=False):
+  def _benchmark_eager_apply(self, label, defun=False, execution_mode=None,
+                             device_and_format=None):
     with tfe.execution_mode(execution_mode):
-      device, data_format = device_and_format
+      device, data_format = device_and_format or device_and_data_format()
       model = resnet50.ResNet50(data_format)
       if defun:
-        model.call = tfe.defun(model.call, compiled=compiled)
+        model.call = tfe.defun(model.call)
       batch_size = 64
       num_burn = 5
       num_iters = 30
       with tf.device(device):
-        images, _ = random_batch(batch_size, data_format)
+        images, _ = random_batch(batch_size, device_and_format)
         for _ in xrange(num_burn):
           model(images, training=False).cpu()
         if execution_mode:
@@ -225,34 +227,30 @@ class ResNet50Benchmarks(tf.test.Benchmark):
         self._report(label, start, num_iters, device, batch_size, data_format)
 
   def benchmark_eager_apply_sync(self):
-    self._benchmark_eager_apply('eager_apply', device_and_data_format(),
-                                defun=False)
+    self._benchmark_eager_apply('eager_apply', defun=False)
 
   def benchmark_eager_apply_async(self):
     self._benchmark_eager_apply(
-        'eager_apply_async', device_and_data_format(), defun=False,
-        execution_mode=tfe.ASYNC)
+        'eager_apply_async', defun=False, execution_mode=tfe.ASYNC)
 
   def benchmark_eager_apply_with_defun(self):
-    self._benchmark_eager_apply('eager_apply_with_defun',
-                                device_and_data_format(), defun=True)
+    self._benchmark_eager_apply('eager_apply_with_defun', defun=True)
 
   def _benchmark_eager_train(self,
                              label,
                              make_iterator,
-                             device_and_format,
                              defun=False,
                              execution_mode=None,
-                             compiled=False):
+                             device_and_format=None):
     with tfe.execution_mode(execution_mode):
-      device, data_format = device_and_format
+      device, data_format = device_and_format or device_and_data_format()
       for batch_size in self._train_batch_sizes():
-        (images, labels) = random_batch(batch_size, data_format)
+        (images, labels) = random_batch(batch_size, device_and_format)
         num_burn = 3
         num_iters = 10
         model = resnet50.ResNet50(data_format)
         if defun:
-          model.call = tfe.defun(model.call, compiled=compiled)
+          model.call = tfe.defun(model.call)
         optimizer = tf.train.GradientDescentOptimizer(0.1)
 
         with tf.device(device):
@@ -275,21 +273,18 @@ class ResNet50Benchmarks(tf.test.Benchmark):
           self._report(label, start, num_iters, device, batch_size, data_format)
 
   def benchmark_eager_train_sync(self):
-    self._benchmark_eager_train('eager_train', MockIterator,
-                                device_and_data_format(), defun=False)
+    self._benchmark_eager_train('eager_train', MockIterator, defun=False)
 
   def benchmark_eager_train_async(self):
     self._benchmark_eager_train(
         'eager_train_async',
         MockIterator,
-        device_and_data_format(),
         defun=False,
         execution_mode=tfe.ASYNC)
 
   def benchmark_eager_train_with_defun(self):
     self._benchmark_eager_train(
-        'eager_train_with_defun', MockIterator,
-        device_and_data_format(), defun=True)
+        'eager_train_with_defun', MockIterator, defun=True)
 
   def benchmark_eager_train_datasets(self):
 
@@ -299,8 +294,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
       return tfe.Iterator(ds)
 
     self._benchmark_eager_train(
-        'eager_train_dataset', make_iterator,
-        device_and_data_format(), defun=False)
+        'eager_train_dataset', make_iterator, defun=False)
 
   def benchmark_eager_train_datasets_with_defun(self):
 
@@ -310,8 +304,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
       return tfe.Iterator(ds)
 
     self._benchmark_eager_train(
-        'eager_train_dataset_with_defun', make_iterator,
-        device_and_data_format(), defun=True)
+        'eager_train_dataset_with_defun', make_iterator, defun=True)
 
 
 if __name__ == '__main__':
index 60cfacc..741bd2a 100644 (file)
@@ -23,7 +23,6 @@ import collections
 
 import numpy as np
 
-from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
@@ -226,7 +225,7 @@ def _inference_name(n):
 class _EagerDefinedFunction(object):
   """Function object with the interface of tf _DefinedFunction."""
 
-  def __init__(self, name, graph, operations, inputs, outputs, attrs):
+  def __init__(self, name, graph, operations, inputs, outputs):
     """Initializes an eager defined function.
 
     Args:
@@ -236,7 +235,6 @@ class _EagerDefinedFunction(object):
         which will be in the function
       inputs: the tensors in the graph to be used as inputs to the function
       outputs: the tensors in the graph which will be outputs to the function
-      attrs: dict mapping names of attributes to their AttrValue values
     """
     fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
         graph._c_graph,  # pylint: disable=protected-access
@@ -248,14 +246,6 @@ class _EagerDefinedFunction(object):
         [],
         None,
         compat.as_str(""))
-
-    for name, attr_value in attrs.items():
-      serialized = attr_value.SerializeToString()
-      # TODO(iga): this creates and deletes a new TF_Status for every attr.
-      # It might be worth creating a convenient way to re-use status.
-      pywrap_tensorflow.TF_FunctionSetAttrValueProto(
-          fn, compat.as_str(name), serialized)
-
     # TODO(apassos) avoid creating a FunctionDef (specially to grab the
     # signature, but also in general it's nice not to depend on it.
     with c_api_util.tf_buffer() as buffer_:
@@ -297,6 +287,25 @@ def _flatten(sequence):
 
 class GraphModeFunction(object):
   """Callable object representing a graph-mode function.
+
+  Args:
+    name: str the name of the created function
+    input_placeholders: list of placeholder values (tensors) to feed when
+      calling the wrapped function.
+    extra_inputs: Tensor inputs this function definition closed over which
+      are passed as arguments. Need to track so gradients are supported
+      correctly.
+    graph: the Graph from which the operations will be pulled. Used as
+      a context when computing gradients.
+    operations: the subset of Operations in the graph used in the function
+      definition.
+    outputs: a flat list of the Tensors in the graph used as outputs to the
+      function
+    func_outputs: a possibly nested python object which will be returned by
+      this function. The Tensors in this structure will be replaced by their
+      corresponding values in outputs.
+    output_shapes: List of shapes of all tensors in outputs
+    variables: (optional) List of variables to watch during function execution.
   """
 
   def __init__(self,
@@ -308,36 +317,9 @@ class GraphModeFunction(object):
                outputs,
                func_outputs,
                output_shapes,
-               variables=None,
-               attrs=None):
-    """Initialize a GraphModeFunction.
-
-    Args:
-      name: str the name of the created function
-      input_placeholders: list of placeholder values (tensors) to feed when
-        calling the wrapped function.
-      extra_inputs: Tensor inputs this function definition closed over which
-        are passed as arguments. Need to track so gradients are supported
-        correctly.
-      graph: the Graph from which the operations will be pulled. Used as
-        a context when computing gradients.
-      operations: the subset of Operations in the graph used in the function
-        definition.
-      outputs: a flat list of the Tensors in the graph used as outputs to the
-        function
-      func_outputs: a possibly nested python object which will be returned by
-        this function. The Tensors in this structure will be replaced by their
-        corresponding values in outputs.
-      output_shapes: List of shapes of all tensors in outputs
-      variables: (optional) List of variables to watch during function
-        execution.
-      attrs: (optional) dict mapping names of attributes to their AttrValue
-        values. Attributes in `attrs` will be included in this function's
-        definition.
-    """
-    self._attrs = attrs or {}
+               variables=None):
     defined_function = _EagerDefinedFunction(
-        name, graph, operations, input_placeholders, outputs, self._attrs)
+        name, graph, operations, input_placeholders, outputs)
     if len(input_placeholders) != len(defined_function.signature.input_arg):
       raise ValueError("Internal error: invalid lengths. %s %s" % (
           len(input_placeholders), len(defined_function.signature.input_arg)))
@@ -390,7 +372,7 @@ class GraphModeFunction(object):
     forward_name = _forward_name(self._func_name)
     self._forward_fdef = _EagerDefinedFunction(
         forward_name, self._graph, self._ops, self._input_placeholders,
-        filtered_outputs + captures, self._attrs)
+        filtered_outputs + captures)
     all_inputs = self._out_grad_placeholders + captures
     # Excluding input ops from the body as we do not intend to execute these
     # operations when the function is executed.
@@ -404,7 +386,7 @@ class GraphModeFunction(object):
     bname = _backward_name(self._func_name)
     self._backward_function = GraphModeFunction(
         bname, all_inputs, [], self._graph, function_def_ops,
-        backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
+        backward_outputs, in_gradients, output_shapes)
 
   def _backprop_call(self, args):
     """Calls the wrapped function and records the result on a tape."""
@@ -578,7 +560,7 @@ def _get_defun_inputs(args):
   return nest.pack_sequence_as(args, ret)
 
 
-def _defun_internal(name, func, compiled, args, kwds):
+def _defun_internal(name, func, args, kwds):
   """Defines and returns graph-mode version of func."""
   graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
   with context.graph_mode():
@@ -643,14 +625,9 @@ def _defun_internal(name, func, compiled, args, kwds):
     for f in tmp_graph._functions.values():  # pylint: disable=protected-access
       # TODO(ashankar): What about the gradient registry?
       _register(f._c_func.func)  # pylint: disable=protected-access
-
-  attrs = {}
-  if compiled:
-    attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True)
-
   return GraphModeFunction(
       fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
-      func_outputs, output_shapes, variables, attrs)
+      func_outputs, output_shapes, variables)
 
 
 # Defun uses this instead of Tensor as a cache key. Using dtype because
@@ -692,7 +669,7 @@ def _register(fn):
 
 
 # TODO(apassos): better error messages for non-hashable arguments.
-def named_defun(func, name, compiled=False):
+def named_defun(func, name):
   """Defines a function with a given name.
 
   See the documentation for `defun` for more information on the semantics of the
@@ -701,7 +678,6 @@ def named_defun(func, name, compiled=False):
   Args:
     func: the function to be wrapped.
     name: the name given to it.
-    compiled: if true, the framework will attempt to compile func with XLA.
 
   Returns:
     the wrapped function.
@@ -718,13 +694,13 @@ def named_defun(func, name, compiled=False):
 
     if cache_key not in arguments_to_functions:
       arguments_to_functions[cache_key] = _defun_internal(
-          name, func, compiled, args, kwds)
+          name, func, args, kwds)
     return arguments_to_functions[cache_key](*args)
 
   return decorated
 
 
-def defun(func=None, compiled=False):
+def defun(func):
   """Decorator to compile func into graph_mode.
 
   `defun` converts a function that constructs a TensorFlow graph into a function
@@ -767,45 +743,18 @@ def defun(func=None, compiled=False):
   ```
 
   Args:
-    func: function to be compiled. If `func` is None, returns a
-      decorator that can be invoked with a single argument - `func`. The
-      end result is equivalent to providing all the arguments up front.
-      In other words, defun(compiled=True)(func) is equivalent to
-      defun(func, compiled=True). The former allows the following use case:
-        @tfe.defun(compiled=True)
-        def foo(...):
-          ...
-    compiled: If True, an attempt to compile `func` with XLA will be made.
-      If it fails, function will be run normally. Experimental.
-      Currently, supported only for execution on TPUs.
+    func: function to be compiled.
 
   Returns:
-     If `func` is not None, returns callable that will execute the compiled
-     function (and return zero or more `tf.Tensor` objects).
-     If `func` is None, returns a decorator that, when invoked with a single
-     `func` argument, returns a callable equivalent to the case above.
+     A callable that will execute the compiled function (and return zero
+     or more `tf.Tensor` objects).
   """
   # TODO(apassos): deal with captured global state. Deal with control flow.
-  def decorated(function):
-    try:
-      name = function.__name__
-    except AttributeError:
-      name = "function"
-    return tf_decorator.make_decorator(
-        function, named_defun(function, name, compiled=compiled))
-
-  # This code path is for the `foo = tfe.defun(foo, ...)` use case
-  if func is not None:
-    return decorated(func)
-
-  # This code path is for the
-  #
-  # @tfe.defun(...)
-  # def foo(...):
-  #    ...
-  #
-  # use case, which is equivalent to `foo = tfe.defun(...)(foo)`
-  return decorated
+  try:
+    name = func.__name__
+  except AttributeError:
+    name = "function"
+  return tf_decorator.make_decorator(func, named_defun(func, name))
 
 
 def make_defun_op(func, *args, **kwds):
@@ -857,7 +806,7 @@ def make_defun_op(func, *args, **kwds):
   name = func.__name__
   if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
     raise ValueError("Tensor keyword arguments are not supported.")
-  return _defun_internal(name, func, False, args, kwds)
+  return _defun_internal(name, func, args, kwds)
 
 
 class AutomaticControlDependencies(object):